Function to select max per row

I have a dataset of videogames that includes their sales per region (NA, EU, JP, Other), with one column per region.

Game        NA_Sales  EU sales  JP Sales  Other Sales
Wii Sports  10        5         8         2
Mario Kart  5         3         8         1

I want to create a function that iterates over each row and returns the max value for each game. So when I run the UDF to create a new column, it should return 10 for Wii Sports and 8 for Mario Kart.

Any comment or help is highly appreciated.



Solution 1:[1]

For a UDF-free solution, you can find the maximum sales across columns using greatest, then chain when expressions to find the column that holds this value.

from pyspark.sql import functions as F
from pyspark.sql import Column
from typing import List

data = [("Wii Sports", 10, 5, 8, 2,),
        ("Mario Kart", 5, 3, 8, 1,), ]

df = spark.createDataFrame(data, ("Game", "NA_Sales", "EU sales", "JP Sales", "Other Sales",))

# Columns to compare for each row.
metric_cols = ["NA_Sales", "EU sales", "JP Sales", "Other Sales"]

def find_region_max_sales(cols: List[str]) -> Column:
    # Row-wise maximum across the given columns.
    max_sales = F.greatest(*[F.col(c) for c in cols])
    # Build a chained when() expression: the first iteration uses the
    # module-level F.when, and later iterations chain on the returned Column.
    max_col_expr = F
    for c in cols:
        max_col_expr = max_col_expr.when(F.col(c) == max_sales, c)
    return max_col_expr

df.withColumn("region_with_maximum_sales", find_region_max_sales(metric_cols)).show()

"""
+----------+--------+--------+--------+-----------+-------------------------+
|      Game|NA_Sales|EU sales|JP Sales|Other Sales|region_with_maximum_sales|
+----------+--------+--------+--------+-----------+-------------------------+
|Wii Sports|      10|       5|       8|          2|                 NA_Sales|
|Mario Kart|       5|       3|       8|          1|                 JP Sales|
+----------+--------+--------+--------+-----------+-------------------------+
"""

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source

Solution 1: Nithish