Get value from first lead row that has a different value

I have rows with an id, a message sequence number (seq) and a value (ts, e.g. a timestamp). Multiple rows can have the same sequence number. There are some other columns with different values in every row, but I excluded them because they are not relevant here.

Within all messages from a deviceId (column id in the example, =partitionBy), I need to sort by sequence_number (column seq, =orderBy) and add the 'ts'-value of the next message with a different sequence_number to all messages of the current sequence_number.

I have gotten as far as retrieving the value of the next row if that row has a different sequence number. But since the "next row with a different sequence number" could be x rows away, I would have to add a specific .when(condition, ...) block for each of the x rows ahead.

I was wondering whether there is a better solution that works no matter how "far away" the next row with a different sequence number is. I tried .otherwise(lead(col("next_value"), 1)), but it doesn't work because the column is still being built at that point.

My Code & reproducible example:

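from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, lead, when

# Reuses the active session if one already exists (e.g. in a notebook).
spark = SparkSession.builder.getOrCreate()

data = [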
    (1, 1, "A"),
    (2, 1, "G"),
    (2, 2, "F"),
    (3, 1, "A"),
    (4, 1, "A"),
    (4, 2, "B"),
    (4, 3, "C"),
    (4, 3, "C"),
    (4, 3, "C"),
    (4, 4, "D")
]

df = spark.createDataFrame(data=data, schema=["id", "seq", "ts"])

df.printSchema()
df.show(10, False)


# Within each id, order the messages by seq.
window = Window.partitionBy("id").orderBy("seq")

# I could potentially do this 100x if the next lead-value is 100 rows away, but I wonder if there isn't a better solution.
is_different_seq1 = lead(col("seq"), 1).over(window) != col("seq")
is_different_seq2 = lead(col("seq"), 2).over(window) != col("seq")

df = df.withColumn(
    "lead_value",
    when(is_different_seq1, lead(col("ts"), 1).over(window))
    .when(is_different_seq2, lead(col("ts"), 2).over(window)),
)

df.printSchema()
df.show(10, False)

Ideal output in column "next_value" for id=4:

id  seq  ts  next_value
4   1    A   B
4   2    B   C
4   3    C   D
4   3    C   D
4   3    C   D
4   4    D   Null
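
The requirement can also be pinned down without Spark, which gives a small reference to check any window-based solution against. The following is only a plain-Python sketch of the intended semantics on the sample data; the helper name add_next_value is hypothetical and not part of the question.

```python
from itertools import groupby

# Same sample rows as in the question: (id, seq, ts).
rows = [
    (1, 1, "A"), (2, 1, "G"), (2, 2, "F"), (3, 1, "A"),
    (4, 1, "A"), (4, 2, "B"), (4, 3, "C"), (4, 3, "C"), (4, 3, "C"), (4, 4, "D"),
]


def add_next_value(rows):
    """Hypothetical helper: append the ts of the next distinct seq within each id."""
    result = []
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    for _, id_rows in groupby(ordered, key=lambda r: r[0]):
        # Collapse each id into runs of rows that share a seq.
        runs = [list(run) for _, run in groupby(id_rows, key=lambda r: r[1])]
        for i, run in enumerate(runs):
            # The first row of the following run carries the wanted ts; None at the end.
            next_ts = runs[i + 1][0][2] if i + 1 < len(runs) else None
            result.extend(row + (next_ts,) for row in run)
    return result


expected = add_next_value(rows)
for row in expected:
    print(row)
```

Because each run of equal seq values is handled as a unit, the distance to the next different seq does not matter, which is the property the Spark solution needs as well.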


Solution 1:[1]

I haven't tried more complicated cases, so this might still need some adjustment, but I think you can combine lead with the last function.

With just the lead function, the result looks like this.

id  seq  ts  lead_value
4   1    A   B
4   2    B   C
4   3    C   C
4   3    C   C
4   3    C   D
4   4    D   Null

You want to overwrite the lead_value of the 3rd and 4th rows with "D", which is the last lead_value within the same id and seq group.

from pyspark.sql import Window, functions as F

# Partition by "id", the column that holds the device id in the example data.
lead_window = (Window
    .partitionBy("id")
    .orderBy("seq"))

last_window = (Window
    .partitionBy("id", "seq")
    .rowsBetween(0, Window.unboundedFollowing))

df = df.withColumn('next_value', F.last(
        F.lead(F.col('ts')).over(lead_window)
    ).over(last_window))

Result.

id  seq  ts  next_value
4   1    A   B
4   2    B   C
4   3    C   D
4   3    C   D
4   3    C   D
4   4    D   Null
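
To see why taking the last lead value per id and seq group gives this result, the two steps can be traced by hand. This is only a plain-Python illustration for the rows of id=4, not Spark code, and it assumes the rows with the same seq keep one fixed order across both steps.

```python
# Rows of id=4 from the question, already sorted by seq: (seq, ts).
rows = [(1, "A"), (2, "B"), (3, "C"), (3, "C"), (3, "C"), (4, "D")]

# Step 1: lead(ts, 1) over the whole id partition; the last row has no successor.
lead_values = [ts for _, ts in rows[1:]] + [None]

# Step 2: last(...) over the rows sharing the same seq, i.e. the lead value
# of the final row in each seq group, copied to every row of that group.
last_lead_per_seq = {}
for (seq, _), lead_value in zip(rows, lead_values):
    last_lead_per_seq[seq] = lead_value  # later rows overwrite earlier ones

next_values = [last_lead_per_seq[seq] for seq, _ in rows]

print(lead_values)  # ['B', 'C', 'C', 'C', 'D', None]
print(next_values)  # ['B', 'C', 'D', 'D', 'D', None]
```

Only the final row of a seq group looks across the group boundary, so its lead value is the one every other row in the group should inherit.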

Solution 2:[2]

I found a solution, although it is horribly slow, so if someone comes up with a better one, please add your answer!

I reduce the data to one row per "message" with distinct(), compute lead(1) on that reduced DataFrame, and join the result back to the original DataFrame to recover the rest of the columns.

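# One row per message, then the next message's ts within each id.
df_filtered = df.select("id", "seq", "ts").distinct()
df_filtered = df_filtered.withColumn("lead_value", lead(col("ts"), 1).over(window))
# Attach lead_value to every original row, including the duplicates.
df = df.join(df_filtered, on=["id", "seq", "ts"])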
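
The distinct, lead and join steps can be mirrored in plain Python to check the logic on the sample data. This is a sketch only, not Spark code, and it assumes that rows with the same id and seq always carry the same ts, as they do in the example; if they did not, distinct would keep several rows per seq and lead(1) could return a ts from the same seq.

```python
# Same sample rows as in the question: (id, seq, ts).
rows = [
    (1, 1, "A"), (2, 1, "G"), (2, 2, "F"), (3, 1, "A"),
    (4, 1, "A"), (4, 2, "B"), (4, 3, "C"), (4, 3, "C"), (4, 3, "C"), (4, 4, "D"),
]

# Step 1: distinct() keeps one row per (id, seq, ts) combination.
distinct_rows = sorted(set(rows))

# Step 2: lead(ts, 1) within each id, ordered by seq.
lead_by_key = {}
for current, following in zip(distinct_rows, distinct_rows[1:] + [None]):
    same_id = following is not None and following[0] == current[0]
    lead_by_key[current] = following[2] if same_id else None

# Step 3: join back on (id, seq, ts) so every original row gets its lead value.
joined = [row + (lead_by_key[row],) for row in rows]

for row in joined:
    print(row)
```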

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution    Source
Solution 1  Emma
Solution 2  Cribber