Using sum in withColumn pyspark / python

I need your help with a small issue that I actually solved in a manual way, but I would like to get it "right".

Data:

customer_id  gender
abc          m
def          w

etc.

Now it gets aggregated the following way:

from pyspark.sql import functions as F

gender_count = data.groupBy('gender').agg(F.countDistinct('customer_id').alias('amount'))

gender_count:

gender  amount
m       4
w       6

Now I would like to create a new column with the total number of customers in order to compute the share of each gender. However, I could not find an aggregate function that works inside "withColumn". So what I do is sum up the number of customers beforehand and insert it as a literal value:

gender_count = gender_count.withColumn('total', F.lit(10)).withColumn('share', (F.col('amount') / F.col('total')))
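
The 10 itself comes from counting the distinct customers beforehand, roughly like this (the variable name is just for illustration):

total_customers = data.select('customer_id').distinct().count()  # manual step, yields 10 here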

Result:

gender  amount  total  share
m       4       10     0.4
w       6       10     0.6

Do you have any idea how I could replace the F.lit(10)? That would save me one manual step and a potential source of error.

Thank you!



Solution 1:[1]

You can use count as a window function before aggregation to count all records.

from pyspark.sql import functions as F
from pyspark.sql import Window as W

(df
    # attach the total row count to every row via a single global window partition
    .withColumn('count', F.count('*').over(W.partitionBy(F.lit(1))))
    .groupBy('gender', 'count')
    .agg(F.countDistinct('customer_id').alias('amount'))
    # share = distinct customers per gender / total rows
    .withColumn('share', F.col('amount') / F.col('count'))
    .show()
)

+------+-----+------+-----+
|gender|count|amount|share|
+------+-----+------+-----+
|     m|   10|     4|  0.4|
|     w|   10|     6|  0.6|
+------+-----+------+-----+
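
Alternatively, if you already have the aggregated gender_count DataFrame, the same window trick can sum the amount column directly. This is only a sketch and assumes each customer appears under exactly one gender, so the per-gender distinct counts add up to the true total:

from pyspark.sql import functions as F
from pyspark.sql import Window as W

(gender_count
    # grand total across all genders, attached to every row via a global window
    .withColumn('total', F.sum('amount').over(W.partitionBy(F.lit(1))))
    .withColumn('share', F.col('amount') / F.col('total'))
    .show()
)

This answers the literal question of replacing F.lit(10): the total is derived from the data itself instead of being hardcoded.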

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: pltc