Group by ID and create a column based on priority in PySpark
Can someone help me with the below? I have this input DataFrame:
ID | process_type | STP_stagewise |
---|---|---|
1 | loan_creation | Manual |
1 | loan_creation | NSTP |
1 | reimbursement | STP |
2 | loan_creation | STP |
2 | reimbursement | NSTP |
3 | loan_creation | Manual |
3 | loan_creation | STP |
4 | loan_creation | Manual |
4 | loan_creation | NSTP |
Required output DataFrame:
ID | process_type | STP_stagewise | STP_type |
---|---|---|---|
1 | loan_creation | Manual | Manual |
1 | loan_creation | NSTP | Manual |
1 | reimbursement | STP | STP |
2 | loan_creation | STP | STP |
2 | reimbursement | NSTP | NSTP |
3 | loan_creation | Manual | Manual |
3 | loan_creation | STP | Manual |
4 | loan_creation | Manual | Manual |
4 | loan_creation | NSTP | Manual |
I need to group by the ID and process_type columns, apply the priority order Manual >> NSTP >> STP, and create a new column holding the winning value for each group. Can someone suggest an approach? Thanks in advance.
Edit: a slight change: along with ID, the grouping should also be done on process_type.
Solution 1:
One way to solve this is to aggregate at the ID level, collect all the distinct STP_stagewise values into a list, sort that list with a custom_sort_map so the highest-priority element lands at index 0, and finally join the result back to the main DataFrame.
Data Preparation
from io import StringIO

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

spark = SparkSession.builder.getOrCreate()

s = StringIO("""
ID STP_stagewise
1 Manual
1 NSTP
1 STP
2 STP
2 NSTP
3 Manual
3 STP
4 Manual
4 NSTP
""")

# Whitespace-separated sample data (the original answer used a tab delimiter)
df = pd.read_csv(s, sep=r'\s+')

sparkDF = spark.createDataFrame(df)
sparkDF.show()
+---+-------------+
| ID|STP_stagewise|
+---+-------------+
| 1| Manual|
| 1| NSTP|
| 1| STP|
| 2| STP|
| 2| NSTP|
| 3| Manual|
| 3| STP|
| 4| Manual|
| 4| NSTP|
+---+-------------+
Aggregation - Collect Set & Sort
# Lower rank = higher priority: Manual >> NSTP >> STP
custom_sort_map = {'Manual': 0, 'NSTP': 1, 'STP': 2}

# UDF that sorts an array of stage values by their priority rank
udf_custom_sort = F.udf(lambda arr: sorted(arr, key=lambda v: custom_sort_map[v]), ArrayType(StringType()))

stpAgg = sparkDF.groupBy(F.col('ID')).agg(F.collect_set(F.col('STP_stagewise')).alias('STP_stagewise_set'))\
                .withColumn('sorted_STP_stagewise_set', udf_custom_sort('STP_stagewise_set'))\
                .withColumn('STP_type', F.col('sorted_STP_stagewise_set').getItem(0))
stpAgg.show()
+---+-------------------+------------------------+--------+
| ID| STP_stagewise_set|sorted_STP_stagewise_set|STP_type|
+---+-------------------+------------------------+--------+
| 1|[STP, NSTP, Manual]| [Manual, NSTP, STP]| Manual|
| 3| [STP, Manual]| [Manual, STP]| Manual|
| 2| [STP, NSTP]| [NSTP, STP]| NSTP|
| 4| [NSTP, Manual]| [Manual, NSTP]| Manual|
+---+-------------------+------------------------+--------+
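As a side note, the Python UDF can be avoided entirely. Here is a minimal sketch of the same aggregation using Spark's built-in struct ordering, where min over a (priority, value) struct acts as an "argmin by priority"; it assumes the same sparkDF and the functions import F from above, and the names stpAggNoUdf, p, and v are just illustrative:

# Sketch: UDF-free variant of the aggregation above.
# min() over a struct compares field by field, so the row with the
# smallest priority rank wins and we read its stage value back out.
priority = (F.when(F.col('STP_stagewise') == 'Manual', 0)
             .when(F.col('STP_stagewise') == 'NSTP', 1)
             .otherwise(2))

stpAggNoUdf = (sparkDF.groupBy('ID')
               .agg(F.min(F.struct(priority.alias('p'),
                                   F.col('STP_stagewise').alias('v'))).alias('m'))
               .select('ID', F.col('m.v').alias('STP_type')))
stpAggNoUdf.show()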
Join
# Join the per-ID result back, keeping every original row plus its STP_type
sparkDF = sparkDF.join(stpAgg
                ,sparkDF['ID'] == stpAgg['ID']
                ,'inner'
                ).select(sparkDF['*'],stpAgg['STP_type'])
sparkDF.show()
+---+-------------+--------+
| ID|STP_stagewise|STP_type|
+---+-------------+--------+
| 1| Manual| Manual|
| 1| NSTP| Manual|
| 1| STP| Manual|
| 3| Manual| Manual|
| 3| STP| Manual|
| 2| STP| NSTP|
| 2| NSTP| NSTP|
| 4| Manual| Manual|
| 4| NSTP| Manual|
+---+-------------+--------+
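Note that the answer above groups on ID only, while the question's edit asks to group on process_type as well. A minimal sketch of that variant using a window function instead of an aggregate-plus-join, assuming sparkDF also carries the process_type column from the question's data:

from pyspark.sql import Window

# Priority rank: lower value = higher priority (Manual >> NSTP >> STP)
priority = (F.when(F.col('STP_stagewise') == 'Manual', 0)
             .when(F.col('STP_stagewise') == 'NSTP', 1)
             .otherwise(2))

# Smallest rank within each (ID, process_type) group
w = Window.partitionBy('ID', 'process_type')
min_priority = F.min(priority).over(w)

# Map the winning rank back to its label on every row of the group
result = sparkDF.withColumn(
    'STP_type',
    F.when(min_priority == 0, 'Manual')
     .when(min_priority == 1, 'NSTP')
     .otherwise('STP')
)
result.show()

Because the window keeps every input row, no join back is needed, and extending the grouping is just a matter of adding columns to partitionBy.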
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Vaebhav |