PySpark: Return all column names of max values

I have a DataFrame like this:

from pyspark.sql import functions as f
from pyspark.sql.types import StringType

data = [("ID1", 3, 5, 5), ("ID2", 4, 5, 6), ("ID3", 3, 3, 3)]
df = spark.createDataFrame(data, ["ID", "colA", "colB", "colC"])
df.show()

cols = df.columns
# row.index(max(row)) is the position within the struct of value columns;
# +1 skips over "ID", which is the first entry in cols
maxcol = f.udf(lambda row: cols[row.index(max(row)) + 1], StringType())

maxDF = df.withColumn("Max_col", maxcol(f.struct([df[x] for x in df.columns[1:]])))
maxDF.show(truncate=False)

+---+----+----+----+
| ID|colA|colB|colC|
+---+----+----+----+
|ID1|   3|   5|   5|
|ID2|   4|   5|   6|
|ID3|   3|   3|   3|
+---+----+----+----+

+---+----+----+----+-------+
|ID |colA|colB|colC|Max_col|
+---+----+----+----+-------+
|ID1|3   |5   |5   |colB   |
|ID2|4   |5   |6   |colC   |
|ID3|3   |3   |3   |colA   |
+---+----+----+----+-------+

In case of ties, I want to return all the column names that hold the max value. How can I achieve this in PySpark, like this:

+---+----+----+----+--------------+
|ID |colA|colB|colC|Max_col       |
+---+----+----+----+--------------+
|ID1|3   |5   |5   |colB,colC     |
|ID2|4   |5   |6   |colC          |
|ID3|3   |3   |3   |colA,colB,colC|
+---+----+----+----+--------------+

Thank you.



Solution 1:[1]

This looks like a case for a UDF: iterate over the columns (passing their names in as input to the function), use plain Python to find the max and check which columns share that value, and return a list (i.e. an array) of those column names.

from pyspark.sql.types import ArrayType, StringType

@f.udf(returnType=ArrayType(StringType()))
def collect_same_max(*vals):
    # keep every column name whose value ties for the row maximum
    m = max(vals)
    return [c for c, v in zip(cols[1:], vals) if v == m]
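
To apply it, pass the value columns in and, if you want the exact comma-separated output from the question, collapse the resulting array with array_join (a sketch building on the snippet above; f, df and cols come from the question, and array_join needs Spark 2.4+):

result = df.withColumn(
    "Max_col",
    f.array_join(collect_same_max(*[df[c] for c in cols[1:]]), ","),
)
result.show(truncate=False)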

Or, if it is doable in your case, you can try the transform function from https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.functions.transform.html
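
For a UDF-free variant in that spirit, one option is to pack each column into a (name, val) struct, keep the structs whose val ties for the row maximum with filter, and map the survivors back to their names with transform. A minimal sketch, assuming Spark 3.1+ for the DataFrame-API versions of transform and filter (value_cols, pairs and row_max are illustrative names, not from the original):

from pyspark.sql import functions as f

value_cols = df.columns[1:]
# one {name, val} struct per value column
pairs = f.array(*[f.struct(f.lit(c).alias("name"), f.col(c).alias("val")) for c in value_cols])
row_max = f.greatest(*[f.col(c) for c in value_cols])

result = df.withColumn(
    "Max_col",
    f.array_join(
        f.transform(f.filter(pairs, lambda x: x["val"] == row_max), lambda x: x["name"]),
        ",",
    ),
)
result.show(truncate=False)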

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1: Benny Elgazar