Recursive computation with PySpark

Within a window, I want to compute a running sum of two columns and modify the value of one column whenever this sum becomes odd. That modification in turn changes the sum, and so on. However, I don't know how to scan my data row by row in an efficient way. Do you have a tip for that?

For clarity, here are a few example rows and the result I would like to achieve. The window will be partitioned by the ID column:

my_data = spark.createDataFrame([
    (1, 1, 0), 
    (1, 0, 0), 
    (1, 0, 1), 
    (1, 0, 0), 
    (1, 0, 1), 
    (1, 0, 0), 
    (1, 1, 0),
    (1, 0, 0),
    (1, 0, 1),
    (1, 0, 0), 
    (1, 1, 0),
    (1, 0, 0),
    (1, 1, 0),
    (1, 0, 1),
],
    ['ID','flag_1','flag_2'])

So my issue is to derive the cumulative sum while at the same time modifying flag_2 whenever the sum would become odd. Below, sum is the expected result and flag_2_results is the "cleaned" version of flag_2, as explained:

my_data = spark.createDataFrame([
    (1, 1, 0, 0, 1), 
    (1, 0, 0, 0, 1), 
    (1, 0, 1, 1, 2), 
    (1, 0, 0, 0, 2), 
    (1, 0, 1, 0, 2), 
    (1, 0, 0, 0, 2), 
    (1, 1, 0, 0, 3),
    (1, 0, 0, 0, 3),
    (1, 0, 1, 1, 4),
    (1, 0, 0, 0, 4),
    (1, 1, 0, 0, 5),
    (1, 0, 0, 0, 5),
    (1, 1, 0, 0, 6),
    (1, 0, 1, 0, 6),],
    ['ID','flag_1','flag_2', 'flag_2_results', 'sum'])
  • Row 3: we keep flag_2 = 1 because the sum was odd.
  • Row 5: we do not keep flag_2 = 1 because the sum was even, so the sum does not change until the next flag_1 = 1.
  • Last row: we do not keep flag_2 = 1 (even though it is the first one after a flag_1 = 1) because it would lead to an odd cumulative sum.
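The rule above is inherently sequential, which is worth pinning down before reaching for Spark. A minimal pure-Python sketch of the intended logic (clean_flags is an illustrative helper name, not part of the question):

```python
def clean_flags(rows):
    """rows: iterable of (flag_1, flag_2) pairs in window order.
    Returns a list of (flag_2_result, sum) pairs: flag_2 = 1 is kept
    only when the running sum is odd, so keeping it makes the sum even."""
    acc = 0
    out = []
    for flag_1, flag_2 in rows:
        acc += flag_1
        if flag_2 and acc % 2 == 1:
            acc += 1              # keep this flag_2: it evens out the sum
            out.append((1, acc))
        else:
            out.append((0, acc))  # flag_2 was 0, or keeping it stays dropped
    return out
```

Running it on the example flags above reproduces the flag_2_results and sum columns of the expected output.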

Thank you for your help,



Solution 1:[1]

According to your last comment, you do not have that many rows to process, so I'd advise you to use a UDF, possibly only on the rows where flag_1 + flag_2 > 0:

from pyspark.sql import functions as F, types as T


# NB: this assumes my_data also has an ordering column "posTime"
# (1, 2, 3, ...); it is needed to restore the row order inside each group.
df = my_data.groupBy("ID").agg(
    F.collect_list(F.struct(F.col("posTime"), F.col("flag_1"), F.col("flag_2"))).alias(
        "data"
    )
)

schm = T.ArrayType(
    T.StructType(
        [
            T.StructField("posTime", T.IntegerType()),
            T.StructField("flag_1", T.IntegerType()),
            T.StructField("flag_2", T.IntegerType()),
            T.StructField("flag_2_result", T.IntegerType()),
            T.StructField("sum", T.IntegerType()),
        ]
    )
)


@F.udf(schm)
def process(data):
    accumulator = 0
    out = []
    # Restore the row order inside the group, then scan sequentially.
    data.sort(key=lambda x: x["posTime"])
    for row in data:
        flag_2_result = 0
        accumulator += row["flag_1"]
        # Keep flag_2 only when the running sum is odd.
        if row["flag_2"] and accumulator % 2 == 1:
            accumulator += row["flag_2"]
            flag_2_result = 1
        out.append((row["posTime"], row["flag_1"], row["flag_2"], flag_2_result, accumulator))
    return out


df.select("ID", F.explode(process(F.col("data"))).alias("data")).select(
    "ID", "data.*"
).show()

and the result:

+---+-------+------+------+-------------+---+                                   
| ID|posTime|flag_1|flag_2|flag_2_result|sum|
+---+-------+------+------+-------------+---+
|  1|      1|     1|     0|            0|  1|
|  1|      2|     0|     0|            0|  1|
|  1|      3|     0|     1|            1|  2|
|  1|      4|     0|     0|            0|  2|
|  1|      5|     0|     1|            0|  2|
|  1|      6|     0|     0|            0|  2|
|  1|      7|     1|     0|            0|  3|
|  1|      8|     0|     0|            0|  3|
|  1|      9|     0|     1|            1|  4|
|  1|     10|     0|     0|            0|  4|
|  1|     11|     1|     0|            0|  5|
|  1|     12|     0|     0|            0|  5|
|  1|     13|     1|     0|            0|  6|
|  1|     14|     0|     1|            0|  6|
+---+-------+------+------+-------------+---+
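If the data per ID grows beyond what collect_list comfortably handles, the same sequential rule can be run per group with a pandas grouped-map UDF (applyInPandas, Spark 3.0+) instead of collecting into an array column. A minimal sketch under the same posTime-ordering assumption; the inner function is plain pandas, and the Spark wiring (including the schema string, which is an assumption matching the columns above) is shown in a comment:

```python
import pandas as pd

def clean_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # Sequentially apply the rule within one ID, ordered by posTime.
    pdf = pdf.sort_values("posTime").copy()
    acc = 0
    results, sums = [], []
    for f1, f2 in zip(pdf["flag_1"], pdf["flag_2"]):
        acc += f1
        keep = 1 if (f2 and acc % 2 == 1) else 0  # keep flag_2 only on odd sums
        acc += keep
        results.append(keep)
        sums.append(acc)
    pdf["flag_2_result"] = results
    pdf["sum"] = sums
    return pdf

# Spark wiring (sketch):
# out = my_data.groupBy("ID").applyInPandas(
#     clean_group,
#     schema="ID long, posTime long, flag_1 long, flag_2 long, "
#            "flag_2_result long, sum long",
# )
```

This keeps each group's processing on a single executor without building an array column, at the cost of a pandas conversion per group.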

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Steven