PySpark dataframe: How to filter data? [closed]
I have a dataframe with department, item id, and the count of those ids. There are 127 departments, and I want to get the top 10 items for each department and list them. That is, based on the item count, I want to list the top 10 items for each department separately. I have been trying to do this using groupBy and agg.max but was not able to. An example of the dataframe is shown below.
| Department | Item id | count |
|---|---|---|
| A | 101 | 10 |
| B | 102 | 5 |
| A | 104 | 12 |
| C | 101 | 5 |
| D | 104 | 14 |
| C | 108 | 10 |
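For reproducibility, the sample above can be recreated as a small dataframe. This is a minimal sketch (not part of the original question), assuming an active SparkSession named `spark`:

# Sketch: recreate the question's sample data shown in the table above.
data = [('A', 101, 10), ('B', 102, 5), ('A', 104, 12),
        ('C', 101, 5), ('D', 104, 14), ('C', 108, 10)]
df = spark.createDataFrame(data, ['Department', 'Item id', 'count'])
df.show()

Note that the answer below generates its own random sample instead of using this exact data.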
Solution 1:[1]
The solution is based on the row_number() window function.
- In this demo I returned the top 3 records per department. Feel free to change it to 10.
- `qualify` is new to Spark SQL. If your Spark version doesn't support it, then a wrapping query is needed and the filter is done using a WHERE clause on the outer query (see the sketch after the Spark SQL solution below).
- I added the `Item id` to the ORDER BY in order to break `count` ties in a deterministic way.
Data Sample Creation
df = spark.sql('''
    select  char(ascii('A') + d.i) as Department,
            100 + i.i              as `Item id`,
            int(rand()*100)        as count
    from    range(3) as d(i), range(7) as i(i)
    order by 1, 3 desc
''')
df.show(999)
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
| A| 103| 89|
| A| 106| 68|
| A| 104| 54|
| A| 100| 52|
| A| 105| 50|
| A| 102| 40|
| A| 101| 30|
| B| 104| 94|
| B| 101| 87|
| B| 106| 74|
| B| 105| 66|
| B| 102| 48|
| B| 100| 32|
| B| 103| 14|
| C| 105| 95|
| C| 103| 94|
| C| 102| 90|
| C| 104| 82|
| C| 100| 9|
| C| 101| 6|
| C| 106| 3|
+----------+-------+-----+
Spark SQL Solution
df.createOrReplaceTempView('t')
sql_query = '''
select *
from t
qualify row_number() over (partition by Department order by count desc, `Item id`) <= 3
'''
spark.sql(sql_query).show(999)
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
| A| 103| 89|
| A| 106| 68|
| A| 104| 54|
| B| 104| 94|
| B| 101| 87|
| B| 106| 74|
| C| 105| 95|
| C| 103| 94|
| C| 102| 90|
+----------+-------+-----+
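If your Spark version doesn't support `qualify`, the same result can be produced by wrapping the window query and filtering in the outer query, as mentioned above. A minimal sketch against the same temp view `t` (the alias `ranked` is just for illustration):

sql_query_no_qualify = '''
select Department, `Item id`, `count`
from (select t.*,
             row_number() over (partition by Department order by count desc, `Item id`) as rn
      from   t) as ranked
where rn <= 3
'''

spark.sql(sql_query_no_qualify).show(999)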
PySpark Solution
import pyspark.sql.functions as F
import pyspark.sql.window as W

# Rank the rows within each department by descending count (Item id breaks ties),
# keep the top 3 per department, then drop the helper column.
(df.withColumn('rn', F.row_number().over(
        W.Window.partitionBy('Department').orderBy(df['count'].desc(), df['Item id'])))
   .where('rn <= 3')
   .drop('rn')
   .show(999)
)
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
| A| 103| 89|
| A| 106| 68|
| A| 104| 54|
| B| 104| 94|
| B| 101| 87|
| B| 106| 74|
| C| 105| 95|
| C| 103| 94|
| C| 102| 90|
+----------+-------+-----+
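The question also asks to list the top items per department. As a follow-up sketch (not part of the original answer), the filtered rows can be aggregated into one list per department with collect_list; `top_n` is just an illustrative name for the ranked-and-filtered dataframe from above:

import pyspark.sql.functions as F
import pyspark.sql.window as W

# Sketch: keep the ranked top-3 rows, then collect the item ids into a list per department.
top_n = (df.withColumn('rn', F.row_number().over(
             W.Window.partitionBy('Department').orderBy(df['count'].desc(), df['Item id'])))
           .where('rn <= 3'))

(top_n.groupBy('Department')
      .agg(F.collect_list('Item id').alias('top_items'))
      .show(999, truncate=False))
# Note: element order inside collect_list is not guaranteed after the shuffle;
# sort the list (e.g. by carrying `rn` along) if the ranking order matters.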
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | David דודו Markovitz |
