Spark SQL: Find the number of extensions for a record

I have a dataset as below

col1 extension_col1
2345 2246
2246 2134
2134 2091
2091 Null
1234 1111
1111 Null

I need to find the number of extensions available for each record in col1. The records are already sorted, and each chain of extensions is contiguous and terminated by a Null.

The final result should be as below:

col1 extension_col1 No_Of_Extensions
2345 2246 3
2246 2134 2
2134 2091 1
2091 Null 0
1234 1111 1
1111 Null 0

Value 2345 extends as 2345 > 2246 > 2134 > 2091 > Null, and hence it has 3 extension relations, excluding the Null.

How can I get the third column (No_Of_Extensions) using Spark SQL/Scala?
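
For anyone who wants to reproduce the sample data, it can be built as a DataFrame like this (a minimal sketch, assuming a spark-shell style SparkSession named spark, with None standing in for Null):

import spark.implicits._

// Sample data from the question; None marks the end of an extension chain.
val sample = Seq(
  (Some(2345), Some(2246)), (Some(2246), Some(2134)),
  (Some(2134), Some(2091)), (Some(2091), None),
  (Some(1234), Some(1111)), (Some(1111), None)
).toDF("col1", "extension_col1")

sample.show()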



Solution 1:[1]

You can achieve this using window functions. First, create a group column grp by taking a cumulative conditional sum over extension_col1 with a window ordered by col1 descending. Then, applying the row_number function over a window partitioned by grp and ordered by col1 ascending gives the desired result:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  (Some(99985), Some(94904)), (Some(94904), Some(89884)),
  (Some(89884), Some(88592)), (Some(88592), Some(86367)),
  (Some(86367), Some(84121)), (Some(84121), None)
).toDF("col1", "extension_col1")

val w1 = Window.orderBy(desc("col1"))
val w2 = Window.partitionBy("grp").orderBy("col1")

val result = df.withColumn(
    "grp",
    sum(when(col("extension_col1").isNull, 1).otherwise(0)).over(w1)
).withColumn(
    "No_Of_Extensions",
    when(col("extension_col1").isNull, 0).otherwise(row_number().over(w2))
).drop("grp").orderBy(desc("col1"))

result.show
                        
//+-----+--------------+----------------+
//| col1|extension_col1|No_Of_Extensions|
//+-----+--------------+----------------+
//|99985|         94904|               5|
//|94904|         89884|               4|
//|89884|         88592|               3|
//|88592|         86367|               2|
//|86367|         84121|               1|
//|84121|          null|               0|
//+-----+--------------+----------------+

Note that the first sum uses a non-partitioned window, so all the data will be moved into a single partition, which could affect performance.
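
If you want to see this in the physical plan, an un-partitioned window shows up as an exchange into a single partition. A quick check, plus a hypothetical mitigation assuming the data carried some coarse grouping column (which it does not in this example), could look like:

// Inspect the physical plan: the window over w1 (no partitionBy) forces all rows
// into a single partition (exact plan output varies by Spark version).
result.explain()

// Hypothetical mitigation: if the data had a coarse key such as "batch_id"
// (not present in this example), both windows could be partitioned by it:
// val w1 = Window.partitionBy("batch_id").orderBy(desc("col1"))
// val w2 = Window.partitionBy("batch_id", "grp").orderBy("col1")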


Spark-SQL equivalent query:

SELECT col1, 
       extension_col1, 
       case when extension_col1 is null then 0 else row_number() over(partition by grp order by col1) end as No_Of_Extensions
FROM  (
      SELECT *, 
             sum(case when extension_col1 is null then 1 else 0 end) over(order by col1 desc) as grp
      FROM df
)
ORDER BY col1 desc
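
To run the SQL version from Scala, register the DataFrame as a temporary view first. A minimal sketch (the subquery alias t is added here only for portability):

// Expose the DataFrame to Spark SQL under the name `df`, then run the query above.
df.createOrReplaceTempView("df")

val sqlResult = spark.sql("""
  SELECT col1,
         extension_col1,
         case when extension_col1 is null then 0
              else row_number() over (partition by grp order by col1) end as No_Of_Extensions
  FROM (
        SELECT *,
               sum(case when extension_col1 is null then 1 else 0 end)
                 over (order by col1 desc) as grp
        FROM df
  ) t
  ORDER BY col1 desc
""")

sqlResult.show()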

Solution 2:[2]

An alternative to blackbishop's answer, in which I assume the data may not always be ordered and hence do some alternative processing. I like the conditional summing, but it is not applicable here.

In all honesty, this is a bad use case for Spark at scale, as I also cannot get around the single-partition aspect that the other answer mentions. But partition sizes have increased in newer Spark versions, and maybe the 'lists' are long in this example.

Part 1 - Generate data

// 1. Generate data.
import spark.implicits._

val df = Seq( (Some(2345), Some(2246)), (Some(2246), Some(2134)), (Some(2134), Some(2091)), (Some(2091), None),
              (Some(1234), Some(1111)), (Some(1111), None)
            ).toDF("col1", "extCol1")

Part 2 - Actual processing

//2. Narrow transform: add the position in the dataset, as values may not always be descending or ascending.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField,StructType,IntegerType, ArrayType, LongType}
val newSchema = StructType(df.schema.fields ++ Array(StructField("rowid", LongType, false)))
val rdd = df.rdd.zipWithIndex
val df2 = spark.createDataFrame(rdd.map{ case (row, index) => Row.fromSeq(row.toSeq ++ Array(index))}, newSchema)  // Some cost
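// Note: zipWithIndex yields consecutive 0-based indices in the existing row order
// (unlike monotonically_increasing_id, which is not consecutive); the range join in
// step 4 and the rowidTo - rowid arithmetic in step 5 rely on that.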


//3. Make groupings of record ranges. We cannot avoid the single-partition aspect, so this only works
//   if the data can fit into a single partition. At scale this would not really be feasible unless
//   there is some grouping characteristic.
val dfg = df2.filter(df2("extCol1").isNull)

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val winSpec1 = Window.orderBy(asc("rowid"))

val dfg2 = dfg.withColumn("prev_rowid_tmp", lag("rowid", 1, -1).over(winSpec1))
              .withColumn("rowidFrom", $"prev_rowid_tmp" + 1)
              .drop("prev_rowid_tmp")
              .drop("extCol1")
              .withColumnRenamed("rowid","rowidTo")

//4. Apply grouping of ranges of rows to data.
val df3 = df2.as("df2").join(dfg2.as("dfg2"), 
          $"df2.rowid" >= $"dfg2.rowidFrom" && $"df2.rowid" <= $"dfg2.rowidTo", "inner")             

//5. Do the calcs.
val res = df3.withColumn("numExtensions", $"rowidTo" - $"rowid") 
res.select("df2.col1", "extCol1", "numExtensions").show(false)

returns:

+----+-------+-------------+
|col1|extCol1|numExtensions|
+----+-------+-------------+
|2345|2246   |3            |
|2246|2134   |2            |
|2134|2091   |1            |
|2091|null   |0            |
|1234|1111   |1            |
|1111|null   |0            |
+----+-------+-------------+
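
For reference, the intermediate boundary table dfg2 built in step 3 contains one row per null-terminated chain, with the rowid range that chain covers. For the generated data it should look roughly like this (illustration only, derived from the rowids 0..5 assigned above):

dfg2.show(false)
// +----+-------+---------+
// |col1|rowidTo|rowidFrom|
// +----+-------+---------+
// |2091|3      |0        |
// |1111|5      |4        |
// +----+-------+---------+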

Solution 3:[3]

For this scenario, where the data table's first column is already ordered, create a new group each time the value of the second column in the previous record is null, and within each group add a number column according to the requirement. It is a hassle to achieve this in SQL: you first need to create row numbers and a marker column, and then perform the grouping according to the marker column and the row numbers. A common alternative is to fetch the original data out of the database and process it in Python or SPL. SPL, an open-source Java package, is easier to integrate into a Java program and produces much simpler code. It expresses the algorithm in only two lines:

A
1 =MYSQL.query("select * from t4")
2 =A1.group@i(#2[-1]==null).run(len=~.len(),~=~.derive(len-#:No_Of_Extensions)).conj()
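
For readers unfamiliar with SPL: the same idea (split the already-ordered rows into groups that end at a null, then number each group from the bottom up) can be sketched with plain Scala collections. This is an illustration only, not a Spark solution, using the question's data:

// Rows in their existing order; None marks the end of a chain.
val rows: Seq[(Int, Option[Int])] = Seq(
  (2345, Some(2246)), (2246, Some(2134)), (2134, Some(2091)), (2091, None),
  (1234, Some(1111)), (1111, None)
)

// Split into groups, closing a group whenever extension_col1 is null.
val groups = rows.foldLeft(Vector(Vector.empty[(Int, Option[Int])])) {
  case (acc, row @ (_, ext)) =>
    val withRow = acc.init :+ (acc.last :+ row)
    if (ext.isEmpty) withRow :+ Vector.empty[(Int, Option[Int])] else withRow
}.filter(_.nonEmpty)

// Within each group, the number of extensions is the distance to the group's last row.
val withCounts = groups.flatMap { g =>
  g.zipWithIndex.map { case ((c1, ext), i) => (c1, ext, g.length - 1 - i) }
}

withCounts.foreach(println)
// prints one tuple per line:
// (2345,Some(2246),3), (2246,Some(2134),2), (2134,Some(2091),1), (2091,None,0),
// (1234,Some(1111),1), (1111,None,0)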

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1
Solution 2
Solution 3