Window function with PySpark
I have a PySpark DataFrame and my goal is to create a Flag column whose value depends on the value of the Amount column.
Basically, for each Group, I want to know whether any of the first three months has an Amount greater than 0. If so, the Flag column should be 1 for every row in the group; otherwise it should be 0.
Here is an example to make this clearer.
Initial PySpark DataFrame:
| Group | Month | Amount |
|---|---|---|
| A | 1 | 0 |
| A | 2 | 0 |
| A | 3 | 35 |
| A | 4 | 0 |
| A | 5 | 0 |
| B | 1 | 0 |
| B | 2 | 0 |
| C | 1 | 0 |
| C | 2 | 0 |
| C | 3 | 0 |
| C | 4 | 13 |
| D | 1 | 0 |
| D | 2 | 24 |
| D | 3 | 0 |
Final PySpark DataFrame:
| Group | Month | Amount | Flag |
|---|---|---|---|
| A | 1 | 0 | 1 |
| A | 2 | 0 | 1 |
| A | 3 | 35 | 1 |
| A | 4 | 0 | 1 |
| A | 5 | 0 | 1 |
| B | 1 | 0 | 0 |
| B | 2 | 0 | 0 |
| C | 1 | 0 | 0 |
| C | 2 | 0 | 0 |
| C | 3 | 0 | 0 |
| C | 4 | 13 | 0 |
| D | 1 | 0 | 1 |
| D | 2 | 24 | 1 |
| D | 3 | 0 | 1 |
In other words, for each group I want to sum the Amount over the first 3 months. If that sum is greater than 0, the Flag is 1 for every row of the group; otherwise it is 0.
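For a quick sanity check outside Spark, here is a minimal plain-Python sketch of the same rule, assuming the rows fit in a list of `(Group, Month, Amount)` tuples. The `group_flags` helper and its `months` parameter are hypothetical, not part of PySpark. Note that "the sum of the first three months is greater than 0" and "any of the first three months has an amount greater than 0" only agree when amounts are non-negative, which holds for this example.

```python
# Example rows from the question: (Group, Month, Amount)
rows = [
    ("A", 1, 0), ("A", 2, 0), ("A", 3, 35), ("A", 4, 0), ("A", 5, 0),
    ("B", 1, 0), ("B", 2, 0),
    ("C", 1, 0), ("C", 2, 0), ("C", 3, 0), ("C", 4, 13),
    ("D", 1, 0), ("D", 2, 24), ("D", 3, 0),
]


def group_flags(rows, months=3):
    """Hypothetical helper: {group: 1 or 0}, 1 if Amounts in Months 1..months sum to > 0."""
    totals = {}
    for group, month, amount in rows:
        totals.setdefault(group, 0)
        if month <= months:
            totals[group] += amount
    return {group: int(total > 0) for group, total in totals.items()}


flags = group_flags(rows)
# Attach the group's flag to every row, as in the desired output table
result = [(g, m, a, flags[g]) for g, m, a in rows]
print(flags)  # {'A': 1, 'B': 0, 'C': 0, 'D': 1}
```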
Solution 1:[1]
You can create the flag column with a Window function. Create a pseudo-column that is 1 when the criterion is met (Month < 4 and Amount > 0) and 0 otherwise, then sum the pseudo-column over each group's window. If the sum is greater than 0, at least one row met the criterion, so the flag is set to 1.
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window as W

spark = SparkSession.builder.getOrCreate()

data = [("A", 1, 0),
        ("A", 2, 0),
        ("A", 3, 35),
        ("A", 4, 0),
        ("A", 5, 0),
        ("B", 1, 0),
        ("B", 2, 0),
        ("C", 1, 0),
        ("C", 2, 0),
        ("C", 3, 0),
        ("C", 4, 13),
        ("D", 1, 0),
        ("D", 2, 24),
        ("D", 3, 0), ]
df = spark.createDataFrame(data, ("Group", "Month", "Amount"))

# Frame spans the whole partition, so every row sees the group's total
ws = W.partitionBy("Group").orderBy("Month").rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
# Pseudo-column: 1 if the row is in the first three months and has a positive Amount
criteria = F.when((F.col("Month") < 4) & (F.col("Amount") > 0), F.lit(1)).otherwise(F.lit(0))
(df.withColumn("flag", F.when(F.sum(criteria).over(ws) > 0, F.lit(1)).otherwise(F.lit(0)))
 ).show()
"""
+-----+-----+------+----+
|Group|Month|Amount|flag|
+-----+-----+------+----+
|    A|    1|     0|   1|
|    A|    2|     0|   1|
|    A|    3|    35|   1|
|    A|    4|     0|   1|
|    A|    5|     0|   1|
|    B|    1|     0|   0|
|    B|    2|     0|   0|
|    C|    1|     0|   0|
|    C|    2|     0|   0|
|    C|    3|     0|   0|
|    C|    4|    13|   0|
|    D|    1|     0|   1|
|    D|    2|    24|   1|
|    D|    3|     0|   1|
+-----+-----+------+----+
"""
```
Solution 2:[2]
You can use a Window function with count and when.
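```python
from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('Group')
# Count rows in the first three months with Amount > 0, then turn the count into 0/1
df = (df.withColumn('Flag', F.count(
          F.when((F.col('Month') < 4) & (F.col('Amount') > 0), True)).over(w))
        .withColumn('Flag', F.when(F.col('Flag') > 0, 1).otherwise(0)))
```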
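This relies on two behaviours: `F.when` without `.otherwise` returns null for rows that fail the condition, and `F.count` on a column counts only non-null values. Because `w` has no `orderBy`, the window covers the whole group. Here is a minimal plain-Python sketch of that logic for groups C and D, with `None` standing in for null (the helper name `when_no_otherwise` is hypothetical).

```python
# Groups C and D from the question: (Group, Month, Amount)
rows = [
    ("C", 1, 0), ("C", 2, 0), ("C", 3, 0), ("C", 4, 13),
    ("D", 1, 0), ("D", 2, 24), ("D", 3, 0),
]


def when_no_otherwise(month, amount):
    """Mimic F.when(cond, True) with no .otherwise: None stands in for null."""
    return True if month < 4 and amount > 0 else None


# Mimic F.count(...).over(Window.partitionBy('Group')): count non-null values per group
counts = {}
for group, month, amount in rows:
    value = when_no_otherwise(month, amount)
    counts[group] = counts.get(group, 0) + (value is not None)

# Second withColumn: turn the count into a 0/1 flag
flags = {group: 1 if n > 0 else 0 for group, n in counts.items()}
print(counts)  # {'C': 0, 'D': 1}
```

Group C shows why the Month filter matters: its only positive amount is in month 4, so its count stays 0.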
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Nithish |
| Solution 2 | Emma |
