Group by ID and create a column based on priority in PySpark

Can someone help me with the below? I have an input DataFrame:

ID  process_type   STP_stagewise
1   loan_creation  Manual
1   loan_creation  NSTP
1   reimbursement  STP
2   loan_creation  STP
2   reimbursement  NSTP
3   loan_creation  Manual
3   loan_creation  STP
4   loan_creation  Manual
4   loan_creation  NSTP

Required output DataFrame:

ID  process_type   STP_stagewise  STP_type
1   loan_creation  Manual         Manual
1   loan_creation  NSTP           Manual
1   reimbursement  STP            STP
2   loan_creation  STP            STP
2   reimbursement  NSTP           NSTP
3   loan_creation  Manual         Manual
3   loan_creation  STP            Manual
4   loan_creation  Manual         Manual
4   loan_creation  NSTP           Manual

I need to group by the ID and process_type columns, prioritize the stages as Manual >> NSTP >> STP, and create a new column holding the winning stage.

Can someone provide an approach to solve this? Thanks in advance.

Edit: slight change — along with ID, the group-by should also be done on process_type.



Solution 1:[1]

One way to solve this is to aggregate by ID, collect the distinct STP_stagewise values into a list, sort that list with a custom_sort_map, take the element at the first index, and finally join the result back to your main DataFrame. (This answer groups by ID only; to honour the follow-up requirement, add process_type to both the groupBy and the join keys.)

Data Preparation

from io import StringIO

import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

s = StringIO("""
ID  STP_stagewise
1   Manual
1   NSTP
1   STP
2   STP
2   NSTP
3   Manual
3   STP
4   Manual
4   NSTP
""")

df = pd.read_csv(s, sep=r'\s+')  # whitespace-separated; robust to tabs or spaces

sparkDF = sql.createDataFrame(df)  # `sql` is the active SparkSession

sparkDF.show()

+---+-------------+
| ID|STP_stagewise|
+---+-------------+
|  1|       Manual|
|  1|         NSTP|
|  1|          STP|
|  2|          STP|
|  2|         NSTP|
|  3|       Manual|
|  3|          STP|
|  4|       Manual|
|  4|         NSTP|
+---+-------------+

Aggregation - Collect Set & Sort

custom_sort_map = {'Manual': 0, 'NSTP': 1, 'STP': 2}

# Sort an array of stage labels by their mapped priority (lowest value first)
udf_custom_sort = F.udf(lambda arr: sorted(arr, key=lambda s: custom_sort_map[s]),
                        ArrayType(StringType()))

stpAgg = sparkDF.groupBy(F.col('ID')).agg(F.collect_set(F.col('STP_stagewise')).alias('STP_stagewise_set'))\
                .withColumn('sorted_STP_stagewise_set', udf_custom_sort('STP_stagewise_set'))\
                .withColumn('STP_type', F.col('sorted_STP_stagewise_set').getItem(0))

stpAgg.show()

+---+-------------------+------------------------+--------+
| ID|  STP_stagewise_set|sorted_STP_stagewise_set|STP_type|
+---+-------------------+------------------------+--------+
|  1|[STP, NSTP, Manual]|     [Manual, NSTP, STP]|  Manual|
|  3|      [STP, Manual]|           [Manual, STP]|  Manual|
|  2|        [STP, NSTP]|             [NSTP, STP]|    NSTP|
|  4|     [NSTP, Manual]|          [Manual, NSTP]|  Manual|
+---+-------------------+------------------------+--------+
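In plain Python, the sort the UDF applies to each collected set behaves like this (a minimal illustration of the same custom_sort_map key; the `custom_sort` wrapper name is just for demonstration):

```python
# Plain-Python illustration of the sort applied inside udf_custom_sort.
custom_sort_map = {'Manual': 0, 'NSTP': 1, 'STP': 2}

def custom_sort(stages):
    """Order stage labels by priority: Manual before NSTP before STP."""
    return sorted(stages, key=lambda s: custom_sort_map[s])

print(custom_sort(['STP', 'NSTP', 'Manual']))  # ['Manual', 'NSTP', 'STP']
print(custom_sort(['STP', 'Manual'])[0])       # Manual -> becomes STP_type
```

Because the list is sorted ascending by priority value, `.getItem(0)` always picks the highest-priority stage present in the group.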

Join

sparkDF = sparkDF.join(stpAgg
                       ,sparkDF['ID'] == stpAgg['ID']
                       ,'inner'
                      ).select(sparkDF['*'],stpAgg['STP_type'])

sparkDF.show()

+---+-------------+--------+
| ID|STP_stagewise|STP_type|
+---+-------------+--------+
|  1|       Manual|  Manual|
|  1|         NSTP|  Manual|
|  1|          STP|  Manual|
|  3|       Manual|  Manual|
|  3|          STP|  Manual|
|  2|          STP|    NSTP|
|  2|         NSTP|    NSTP|
|  4|       Manual|  Manual|
|  4|         NSTP|  Manual|
+---+-------------+--------+

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Vaebhav