Recursive computation with PySpark
I want to be able, within a window, to calculate a cumulative sum over two columns and modify the value of one column whenever that sum becomes odd. The modification of this column will de facto modify the sum, and so on. However, I don't know how to scan my data row by row in an efficient way. Would you have a tip for that?
I am attaching a few example rows and what I would like to achieve for clarity:
My window will be based on the `ID` column.
```python
my_data = spark.createDataFrame([
    (1, 1, 0),
    (1, 0, 0),
    (1, 0, 1),
    (1, 0, 0),
    (1, 0, 1),
    (1, 0, 0),
    (1, 1, 0),
    (1, 0, 0),
    (1, 0, 1),
    (1, 0, 0),
    (1, 1, 0),
    (1, 0, 0),
    (1, 1, 0),
    (1, 0, 1),
], ['ID', 'flag_1', 'flag_2'])
```
Thus my issue is to derive the sum and, at the same time, to modify `flag_2` when the sum becomes odd. Here `sum` is the expected result and `flag_2_results` the "cleaned" version of `flag_2`, as explained:
```python
my_data = spark.createDataFrame([
    (1, 1, 0, 0, 1),
    (1, 0, 0, 0, 1),
    (1, 0, 1, 1, 2),
    (1, 0, 0, 0, 2),
    (1, 0, 1, 0, 2),
    (1, 0, 0, 0, 2),
    (1, 1, 0, 0, 3),
    (1, 0, 0, 0, 3),
    (1, 0, 1, 1, 4),
    (1, 0, 0, 0, 4),
    (1, 1, 0, 0, 5),
    (1, 0, 0, 0, 5),
    (1, 1, 0, 0, 6),
    (1, 0, 1, 0, 6),
], ['ID', 'flag_1', 'flag_2', 'flag_2_results', 'sum'])
```
- Row n°3: we keep `flag_2 = 1` as the `sum` was odd.
- Row n°5: we do not keep `flag_2 = 1` as the `sum` was even; thus the `sum` does not change until the next `flag_1 = 1`.
- Last row: we do not keep `flag_2 = 1` (even though it's the first one after a `flag_1 = 1`) because it would lead to an odd cumulative `sum`.
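Before worrying about Spark, the row-by-row rule above can be sketched in plain Python (a minimal illustration; `clean_flags` is a hypothetical helper name, not part of the question):

```python
def clean_flags(rows):
    """rows: list of (flag_1, flag_2) pairs, already in order.
    Returns (flag_1, flag_2, flag_2_result, sum) tuples per row."""
    total = 0
    out = []
    for flag_1, flag_2 in rows:
        total += flag_1
        # keep flag_2 = 1 only when the running sum is odd at that point
        kept = 1 if (flag_2 == 1 and total % 2 == 1) else 0
        total += kept  # a kept flag_2 is itself added to the cumulative sum
        out.append((flag_1, flag_2, kept, total))
    return out
```

Running this on the fourteen example rows reproduces the `flag_2_results` and `sum` columns shown above.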
Thank you for your help,
Solution 1:
According to your last comment, you do not have that many rows to process, so I'd advise you to use a UDF, possibly applied only to the rows where `flag_1 + flag_2 > 0`. (The code below assumes a `posTime` column that defines the row order within each `ID`, since Spark does not otherwise guarantee row order.)
```python
from pyspark.sql import functions as F, types as T

# collect each ID group into a single array of structs
df = my_data.groupBy("ID").agg(
    F.collect_list(
        F.struct(F.col("posTime"), F.col("flag_1"), F.col("flag_2"))
    ).alias("data")
)

schm = T.ArrayType(
    T.StructType(
        [
            T.StructField("posTime", T.IntegerType()),
            T.StructField("flag_1", T.IntegerType()),
            T.StructField("flag_2", T.IntegerType()),
            T.StructField("flag_2_result", T.IntegerType()),
            T.StructField("sum", T.IntegerType()),
        ]
    )
)

@F.udf(schm)
def process(data):
    accumulator = 0
    out = []
    data.sort(key=lambda x: x["posTime"])  # enforce row order within the group
    for row in data:
        flag_2_result = 0
        accumulator += row["flag_1"]
        # keep flag_2 only when the running sum is odd at that point
        if row["flag_2"] and accumulator % 2 == 1:
            accumulator += row["flag_2"]
            flag_2_result = 1
        out.append(
            (row["posTime"], row["flag_1"], row["flag_2"], flag_2_result, accumulator)
        )
    return out

df.select("ID", F.explode(process(F.col("data"))).alias("data")).select(
    "ID", "data.*"
).show()
```
and the result:
```
+---+-------+------+------+-------------+---+
| ID|posTime|flag_1|flag_2|flag_2_result|sum|
+---+-------+------+------+-------------+---+
|  1|      1|     1|     0|            0|  1|
|  1|      2|     0|     0|            0|  1|
|  1|      3|     0|     1|            1|  2|
|  1|      4|     0|     0|            0|  2|
|  1|      5|     0|     1|            0|  2|
|  1|      6|     0|     0|            0|  2|
|  1|      7|     1|     0|            0|  3|
|  1|      8|     0|     0|            0|  3|
|  1|      9|     0|     1|            1|  4|
|  1|     10|     0|     0|            0|  4|
|  1|     11|     1|     0|            0|  5|
|  1|     12|     0|     0|            0|  5|
|  1|     13|     1|     0|            0|  6|
|  1|     14|     0|     1|            0|  6|
+---+-------+------+------+-------------+---+
```
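As a variant, if you are on Spark 3.0+ the same per-group loop can be sketched with `groupBy(...).applyInPandas`, which avoids collecting each group into an array column. This is a sketch under the same assumptions (a `posTime` ordering column exists); `process_group` is a name chosen here:

```python
import pandas as pd

def process_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # Called once per ID group with a pandas DataFrame; sort to enforce order.
    pdf = pdf.sort_values("posTime").reset_index(drop=True)
    accumulator = 0
    results, sums = [], []
    for f1, f2 in zip(pdf["flag_1"], pdf["flag_2"]):
        accumulator += f1
        # keep flag_2 only when the running sum is odd at that point
        kept = 1 if (f2 == 1 and accumulator % 2 == 1) else 0
        accumulator += kept
        results.append(kept)
        sums.append(accumulator)
    return pdf.assign(flag_2_result=results, sum=sums)

# In Spark:
# result = my_data.groupBy("ID").applyInPandas(
#     process_group,
#     schema="ID long, posTime long, flag_1 long, flag_2 long, "
#            "flag_2_result long, sum long",
# )
```

Note that each group is still processed sequentially on a single worker, so this only parallelizes across distinct `ID` values.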
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Steven |
