How to group by a count based on a condition over an aggregated function in PySpark?
Suppose I build the following example dataset:
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "10g")
    .config("spark.sql.repl.eagerEval.enabled", True)  # to display df in pretty HTML
    .getOrCreate()
)
df = spark.createDataFrame(
[
("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
("US", "US_SL_A", datetime(2022, 1, 5), 1.),
("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
("US", "US_SL_B", datetime(2022, 1, 3), 9.),
("US", "US_SL_C", datetime(2022, 1, 1), 1.),
("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
],
schema = ("country", "platform", "timestamp", "size")
)
>> df.show()
+-------+--------+-------------------+----+
|country|platform| timestamp|size|
+-------+--------+-------------------+----+
| US| US_SL_A|2022-01-01 00:00:00| 3.8|
| US| US_SL_A|2022-01-02 00:00:00| 4.3|
| US| US_SL_A|2022-01-03 00:00:00| 4.3|
| US| US_SL_A|2022-01-04 00:00:00|3.95|
| US| US_SL_A|2022-01-05 00:00:00| 1.0|
| US| US_SL_B|2022-01-01 00:00:00| 4.3|
| US| US_SL_B|2022-01-02 00:00:00| 3.8|
| US| US_SL_B|2022-01-03 00:00:00| 9.0|
| US| US_SL_C|2022-01-01 00:00:00| 1.0|
| ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
| ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
| ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
| FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+
My goal is to count the number of outliers in the size column, after first grouping by country and platform. As a criterion I want to use the interquartile range (IQR); that is, I want to count all the sizes whose value is below the 0.25 quantile (Q1) minus 1.5 times the interquartile range.
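To make the rule concrete, here is a plain-Python sketch of the check for the (US, US_SL_A) group, using the q1 and q3 values that percentile_approx reports further below (other quantile implementations may return slightly different estimates):
sizes = [3.8, 4.3, 4.3, 3.95, 1.0]              # sizes of the (US, US_SL_A) group
q1, q3 = 3.8, 4.3                               # quartiles as reported by percentile_approx below
threshold = q1 - 1.5 * (q3 - q1)                # 3.8 - 0.75 = 3.05
n_outliers = sum(s < threshold for s in sizes)  # only 1.0 falls below 3.05
print(threshold, n_outliers)                    # 3.05 1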
I can get the different quantile parameters and desired threshold per group by doing:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.round(1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)), 2)
).alias("1.5xInterquartile"),
F.percentile_approx("size", 0.25).alias("q1"),
F.percentile_approx("size", 0.75).alias("q3"),
).withColumn(
    "threshold", F.col("q1") - F.col("`1.5xInterquartile`")  # Q1 - 1.5*IQR
).show()
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
| US| US_SL_A| 0.75|3.8|4.3| 3.05|
| US| US_SL_B| 7.8|3.8|9.0| -4.0|
| US| US_SL_C| 0.0|1.0|1.0| 1.0|
| ES| ES_SL_A| 4.8|1.0|4.2| -3.8|
| FR| FR_SL_A| 0.0|2.0|2.0| 2.0|
| ES| ES_SL_B| 0.0|2.0|2.0| 2.0|
+-------+--------+-----------------+---+---+---------+
But this is not exactly what I want. Instead of aggregating the quartile values themselves, I want to aggregate a count of the number of rows per group whose size falls below the outlier threshold.
Desired output would be something like this:
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| US| US_SL_A| 1 |
| US| US_SL_B| 0 |
| US| US_SL_C| 0 |
| ES| ES_SL_A| 0 |
| FR| FR_SL_A| 0 |
| ES| ES_SL_B| 0 |
+-------+--------+----------+
This is because only the (US, US_SL_A) group has a value (1.0) below its outlier threshold of 3.05.
Here's my attempt to achieve that:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.count(
F.when(
F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)),
True
)
)
).alias("n_outliers"),
)
But I get an error, which states:
AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false
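As the error suggests, one workaround is to compute the inner aggregates in a separate query: derive the per-group threshold first, then join it back and count. A minimal sketch of that join-based variant (the thresholds DataFrame and threshold column names are just illustrative choices):
thresholds = df.groupBy("country", "platform").agg(
    (F.percentile_approx("size", 0.25)
     - 1.5 * (F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25))
    ).alias("threshold")   # Q1 - 1.5*IQR per group
)

(df.join(thresholds, on=["country", "platform"])
   .groupBy("country", "platform")
   .agg(F.count(F.when(F.col("size") < F.col("threshold"), 1)).alias("n_outliers"))
   .show())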
Solution 1:[1]
The key here is the use of window functions prior to the aggregation:
import pyspark.sql.window as W
w = W.Window.partitionBy(["country", "platform"])
(df
.withColumn("1.5xInterquartile", F.round(1.5*(F.percentile_approx("size", 0.75).over(w) - F.percentile_approx("size", 0.25).over(w)), 2))
.withColumn("q1", F.percentile_approx("size", 0.25).over(w))
.withColumn("q3", F.percentile_approx("size", 0.75).over(w))
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`")) # Q1 - 1.5*IQR
.groupBy(["country", "platform"])
.agg(F.count(F.when(F.col("size") < F.col("q1") - 1.5*(F.col("q3") - F.col("q1")), 1)).alias("n_outliers"))
.show()
)
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| ES| ES_SL_A| 0|
| ES| ES_SL_B| 0|
| FR| FR_SL_A| 0|
| US| US_SL_A| 1|
| US| US_SL_B| 0|
| US| US_SL_C| 0|
+-------+--------+----------+
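As a side note, the conditional count in the final agg can equivalently be written as a sum of 0/1 flags. A minimal sketch of just that aggregation, assuming the q1 and q3 window columns from the chain above are already in place (df_with_quartiles is a hypothetical name for that intermediate DataFrame):
outlier_flag = F.when(
    F.col("size") < F.col("q1") - 1.5 * (F.col("q3") - F.col("q1")), 1
).otherwise(0)   # 1 if the row is below Q1 - 1.5*IQR, else 0

(df_with_quartiles            # hypothetical: the DataFrame after adding the window columns above
 .groupBy("country", "platform")
 .agg(F.sum(outlier_flag).alias("n_outliers"))
 .show())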
Solution 2:[2]
Your count and percentile_approx are both aggregate functions, and Spark does not allow one aggregate inside the argument of another, so a single agg on top cannot handle both.
You can instead use window functions for all of the aggregations, which adds the n_outliers count to every record. Afterwards you can use distinct to keep only one record per group.
from pyspark.sql import Window

w = Window.partitionBy("country", "platform")
df = (df.withColumn('n_outliers',
                    F.count(F.when(
                        F.col("size") < (F.percentile_approx("size", 0.25).over(w) - 1.5*(F.percentile_approx("size", 0.75).over(w) - F.percentile_approx("size", 0.25).over(w))),
                        1
                    )).over(w))
        .select('country', 'platform', 'n_outliers')
        .distinct())
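Because the window makes n_outliers identical on every row of a partition, the final distinct() collapses each group to one row. An equivalent way to write that last step is dropDuplicates on the group keys (a sketch; df_flagged is a hypothetical name for the DataFrame right after the withColumn call):
(df_flagged                                    # hypothetical: df after withColumn('n_outliers', ...)
 .select('country', 'platform', 'n_outliers')
 .dropDuplicates(['country', 'platform'])      # same effect as .distinct() here
 .show())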
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | David דודו Markovitz |
| Solution 2 | |
