Window function with PySpark

I have a PySpark DataFrame and my goal is to create a Flag column whose value depends on the Amount column. For each Group, I want to know whether any of the first three months has an amount greater than 0. If so, the Flag column should be 1 for the whole group; otherwise it should be 0.

Here is an example to make this clearer.

Initial PySpark Dataframe:

Group Month Amount
A 1 0
A 2 0
A 3 35
A 4 0
A 5 0
B 1 0
B 2 0
C 1 0
C 2 0
C 3 0
C 4 13
D 1 0
D 2 24
D 3 0

Final PySpark Dataframe:

Group Month Amount Flag
A 1 0 1
A 2 0 1
A 3 35 1
A 4 0 1
A 5 0 1
B 1 0 0
B 2 0 0
C 1 0 0
C 2 0 0
C 3 0 0
C 4 13 0
D 1 0 1
D 2 24 1
D 3 0 1

Basically, for each group I want to sum the amounts of the first 3 months. If that sum is greater than 0, the flag is 1 for every row of the group; otherwise it is 0.
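To pin down the expected output independently of Spark, here is a plain-Python sketch of the rule as stated. It assumes rows are (Group, Month, Amount) tuples and that "the first three months" means Month values 1 to 3; the helper name add_flag and its max_month parameter are hypothetical, chosen only for illustration.

```python
from collections import defaultdict

# Sample rows from the question: (Group, Month, Amount)
rows = [
    ("A", 1, 0), ("A", 2, 0), ("A", 3, 35), ("A", 4, 0), ("A", 5, 0),
    ("B", 1, 0), ("B", 2, 0),
    ("C", 1, 0), ("C", 2, 0), ("C", 3, 0), ("C", 4, 13),
    ("D", 1, 0), ("D", 2, 24), ("D", 3, 0),
]


def add_flag(rows, max_month=3):
    """Append a Flag to every row: 1 if the group's amounts in months 1..max_month sum to more than 0."""
    first_months_total = defaultdict(int)
    for group, month, amount in rows:
        if month <= max_month:
            first_months_total[group] += amount
    # Every row of a group gets the same flag
    return [(g, m, a, 1 if first_months_total[g] > 0 else 0) for g, m, a in rows]


flags = {g: f for g, _, _, f in add_flag(rows)}
print(flags)  # {'A': 1, 'B': 0, 'C': 0, 'D': 1}
```

Group C stays at 0 because its only positive amount falls in month 4, outside the first three months.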



Solution 1:[1]

You can create the flag column with a window function. First create a pseudo-column that is 1 when the criterion is met (month among the first three and amount greater than 0) and 0 otherwise. Then sum the pseudo-column over each group: if the sum is greater than 0, at least one row met the criterion, so the flag is set to 1.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window as W

spark = SparkSession.builder.getOrCreate()

data = [("A", 1, 0),
        ("A", 2, 0),
        ("A", 3, 35),
        ("A", 4, 0),
        ("A", 5, 0),
        ("B", 1, 0),
        ("B", 2, 0),
        ("C", 1, 0),
        ("C", 2, 0),
        ("C", 3, 0),
        ("C", 4, 13),
        ("D", 1, 0),
        ("D", 2, 24),
        ("D", 3, 0)]

df = spark.createDataFrame(data, ["Group", "Month", "Amount"])

# Frame spanning the whole Group partition, so every row sees all rows of its group
ws = W.partitionBy("Group").orderBy("Month").rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

# Pseudo-column: 1 if the row is in the first three months and has a positive amount, else 0
criteria = F.when((F.col("Month") < 4) & (F.col("Amount") > 0), F.lit(1)).otherwise(F.lit(0))

# Flag the whole group if at least one of its rows met the criterion
(df.withColumn("flag", F.when(F.sum(criteria).over(ws) > 0, F.lit(1)).otherwise(F.lit(0)))
).show()

"""
+-----+-----+------+----+
|Group|Month|Amount|flag|
+-----+-----+------+----+
|    A|    1|     0|   1|
|    A|    2|     0|   1|
|    A|    3|    35|   1|
|    A|    4|     0|   1|
|    A|    5|     0|   1|
|    B|    1|     0|   0|
|    B|    2|     0|   0|
|    C|    1|     0|   0|
|    C|    2|     0|   0|
|    C|    3|     0|   0|
|    C|    4|    13|   0|
|    D|    1|     0|   1|
|    D|    2|    24|   1|
|    D|    3|     0|   1|
+-----+-----+------+----+
"""

Solution 2:[2]

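You can use a window function with count and when. F.when without otherwise returns null for rows that do not match, and F.count only counts non-null values, so the count is the number of matching rows in the group.

from pyspark.sql import Window, functions as F

w = Window.partitionBy('Group')
# Count matching rows per group (months 1-3 with a positive amount), then turn the count into a 0/1 flag
df = (df.withColumn('Flag', F.count(F.when((F.col('Month') < 4) & (F.col('Amount') > 0), True)).over(w))
        .withColumn('Flag', F.when(F.col('Flag') > 0, 1).otherwise(0)))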
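The count-based approach relies on null handling, which a small plain-Python sketch can make concrete. It assumes rows are (Group, Month, Amount) tuples and uses None to stand in for Spark's null; flag_via_count and max_month are hypothetical names, and the rows are the C and D groups from the question.

```python
from collections import defaultdict


def flag_via_count(rows, max_month=3):
    """Mimic count(when(cond, True)) over a Group window.

    when() without otherwise() yields null for rows where cond is false,
    and count() over the partition counts only the non-null values.
    """
    # None plays the role of Spark's null
    when_values = [True if month <= max_month and amount > 0 else None for _, month, amount in rows]
    non_null_count = defaultdict(int)
    for (group, _, _), value in zip(rows, when_values):
        if value is not None:  # count() skips nulls
            non_null_count[group] += 1
    # Second withColumn step: turn the per-group count into a 0/1 flag
    return [(g, m, a, 1 if non_null_count[g] > 0 else 0) for g, m, a in rows]


rows = [("C", 1, 0), ("C", 2, 0), ("C", 3, 0), ("C", 4, 13), ("D", 1, 0), ("D", 2, 24), ("D", 3, 0)]
print([flag for *_, flag in flag_via_count(rows)])  # [0, 0, 0, 0, 1, 1, 1]
```

Using True as the matched value is arbitrary; any non-null literal would give the same count.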

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Nithish
Solution 2 Emma