How to apply a function to sequential data within a group in Spark?
I have a custom function that depends on the order of the data. I want to apply this function to each group in Spark, with the groups processed in parallel. How can I do this?
For example,
public ArrayList<Integer> my_logic(ArrayList<Integer> glist) {
    boolean b = true;
    ArrayList<Integer> result = new ArrayList<>();
    for (int i = 1; i < glist.size(); i++) { // size is around 30000
        if (b && glist.get(i - 1) > glist.get(i)) {
            // some logic, then set b to false
            result.add(glist.get(i));
        } else {
            // some logic, then set b to true
        }
    }
    return result;
}
My data:
Col1  Col2
a     1
b     2
a     3
c     4
c     3
...   ...
I want something similar to the below:
df.group_by(col("Col1")).apply(my_logic(col("Col2")));
// output
a [1,3,5…]
b [2,5,8…]
…. ….
Solution 1:[1]
In Spark you can use window aggregate functions directly; I will show that here in Scala.
Here is your input data (my preparation):
import scala.collection.JavaConversions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
val schema = StructType(
StructField("Col1", StringType, false) ::
StructField("Col2", IntegerType, false) :: Nil
)
val row = Seq(Row("a", 1),Row("b", 8),Row("b", 2),Row("a", 5),Row("b", 5),Row("a", 3))
val df = spark.createDataFrame(row, schema)
df.show(false)
//input:
// +----+----+
// |Col1|Col2|
// +----+----+
// |a |1 |
// |b |8 |
// |b |2 |
// |a |5 |
// |b |5 |
// |a |3 |
// +----+----+
Here is the code to obtain the desired output:
import org.apache.spark.sql.expressions.Window
df
// NEWCOLUMN: EVALUATE/CREATE LIST OF VALUES FOR EACH RECORD OVER THE WINDOW AS FRAME MOVES
.withColumn(
"collected_list",
collect_list(col("Col2")) over Window
.partitionBy(col("Col1"))
.orderBy(col("Col2"))
)
// NEWCOLUMN: MAX SIZE OF COLLECTED LIST IN EACH WINDOW
.withColumn(
"max_size",
max(size(col("collected_list"))) over Window.partitionBy(col("Col1"))
)
// FILTER TO GET ONLY HIGHEST SIZED ARRAY ROW
.where(col("max_size") - size(col("collected_list")) === 0)
.orderBy(col("Col1"))
.drop("Col2", "max_size")
.show(false)
// output:
// +----+--------------+
// |Col1|collected_list|
// +----+--------------+
// |a |[1, 3, 5] |
// |b |[2, 5, 8] |
// +----+--------------+
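If the final goal is to run the order-dependent my_logic from the question on each group's list, one hedged option is to port it to Scala and apply it as a UDF on the collected array column. This is only a sketch: it assumes you assign the result of the query above (without the final .show) to a DataFrame named groupedDf, and the placeholder comments stand in for your real logic.
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical Scala port of the question's my_logic, applied to each group's ordered list
val myLogicUdf = udf { (glist: Seq[Int]) =>
  var b = true
  val result = ArrayBuffer.empty[Int]
  for (i <- 1 until glist.size) {
    if (b && glist(i - 1) > glist(i)) {
      // some logic, then set b to false
      result += glist(i)
      b = false
    } else {
      // some logic, then set b to true
      b = true
    }
  }
  result.toSeq
}

groupedDf
  .withColumn("my_logic_result", myLogicUdf(col("collected_list")))
  .show(false)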
Note:
- You can use the collect_list() aggregate function with a plain groupBy directly, but then you cannot guarantee the order of the collected list (see the sketch after these notes for one way around that).
- You can explore the collect_set() aggregate function if you want to eliminate duplicates (with some changes to the above query).
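As a hedged alternative to the window approach: in this particular example, where the ordering key is Col2 itself, you can get the same ordered list with a plain groupBy by sorting the collected array explicitly. If you need to order by a different column, collect struct(orderCol, value) pairs instead, sort_array them, and project out the value field.
// Plain groupBy alternative: collect the values, then sort the array explicitly.
// This works here only because the ordering key is Col2 itself.
df.groupBy(col("Col1"))
  .agg(sort_array(collect_list(col("Col2"))).as("collected_list"))
  .orderBy(col("Col1"))
  .show(false)
// Same result as the window version: a -> [1, 3, 5], b -> [2, 5, 8]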
EDIT 2: You can also write your own custom collect_list() as a UDAF (UserDefinedAggregateFunction) in Scala Spark for DataFrames, like this:
Online Docs
The code below targets Spark version 2.3.0.
import scala.collection.mutable
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

object Your_Collect_Array extends UserDefinedAggregateFunction {
override def inputSchema: StructType = StructType(
StructField("yourInputToAggFunction", LongType, false) :: Nil
)
override def dataType: ArrayType = ArrayType(LongType, false)
override def deterministic: Boolean = true
override def bufferSchema: StructType = {
StructType(
StructField("yourCollectedArray", ArrayType(LongType, false), false) :: Nil
)
}
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = new Array[Long](0)
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(
0,
buffer.getAs[mutable.WrappedArray[Long]](0) :+ input.getLong(0)
)
}
override def merge(
buffer1: MutableAggregationBuffer,
buffer2: Row
): Unit = {
buffer1.update(
0,
buffer1.getAs[mutable.WrappedArray[Long]](0) ++ buffer2
.getAs[mutable.WrappedArray[Long]](0)
)
}
override def evaluate(buffer: Row): Any =
buffer.getAs[mutable.WrappedArray[Long]](0)
}
// Below is the query with just one line changed, i.e. the call to the custom UDAF written above
df
// NEWCOLUMN : USING OUR CUSTOM UDF
.withColumn(
"your_collected_list",
Your_Collect_Array(col("Col2")) over Window
.partitionBy(col("Col1"))
.orderBy(col("Col2"))
)
// NEWCOLUMN: MAX SIZE OF COLLECTED LIST IN EACH WINDOW
.withColumn(
"max_size",
max(size(col("your_collected_list"))) over Window.partitionBy(col("Col1"))
)
// FILTER TO GET ONLY HIGHEST SIZED ARRAY ROW
.where(col("max_size") - size(col("your_collected_list")) === 0)
.orderBy(col("Col1"))
.drop("Col2", "max_size")
.show(false)
//Output:
// +----+-------------------+
// |Col1|your_collected_list|
// +----+-------------------+
// |a |[1, 3, 5] |
// |b |[2, 5, 8] |
// +----+-------------------+
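If you also want to call this UDAF from Spark SQL, you can register it first. A minimal sketch follows; the SQL function name your_collect_array and the view name my_table are just placeholders, and note that with a plain GROUP BY (unlike the window query above) the order of the collected elements is not guaranteed.
// Register the UDAF under a name usable in SQL (name and view are placeholders)
spark.udf.register("your_collect_array", Your_Collect_Array)
df.createOrReplaceTempView("my_table")

spark.sql(
  """SELECT Col1, your_collect_array(CAST(Col2 AS BIGINT)) AS your_collected_list
    |FROM my_table
    |GROUP BY Col1""".stripMargin
).show(false)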
Note:
- UDFs and UDAFs are not as efficient as Spark's built-in functions (the optimizer treats them as a black box), so use them only when you absolutely need them, i.e. when your analytics logic cannot be expressed with the built-ins. Spark 3.x also replaces this UDAF API; see the sketch below.
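From Spark 3.0 onwards, UserDefinedAggregateFunction is deprecated in favour of the Aggregator API registered through functions.udaf. Below is only a rough sketch of the same collect logic under that API; it assumes Spark 3.x, and the object name CollectLongs and the use of ExpressionEncoder for the Seq[Long] buffer/output encoders are my own choices, not part of the original answer.
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Hypothetical Spark 3.x replacement for Your_Collect_Array using the Aggregator API
object CollectLongs extends Aggregator[Long, Seq[Long], Seq[Long]] {
  override def zero: Seq[Long] = Seq.empty[Long]                             // empty buffer
  override def reduce(buf: Seq[Long], value: Long): Seq[Long] = buf :+ value // add one input row
  override def merge(b1: Seq[Long], b2: Seq[Long]): Seq[Long] = b1 ++ b2     // combine partial buffers
  override def finish(reduction: Seq[Long]): Seq[Long] = reduction           // final array
  override def bufferEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
  override def outputEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
}

// Wrap it as a column function; it can then be used in place of the UDAF above
val your_collect_array_v3 = udaf(CollectLongs)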
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Stack Overflow |
