ValueError when applying pandas_udf to grouped Spark DataFrame

Sample PySpark Dataframe: join_df

+----------+----------+-------+---------+----------+----------+
|        ID|        ds|      y|     yhat|yhat_upper|yhat_lower|
+----------+----------+-------+---------+----------+----------+
|    Ax849b|2021-07-01|1165.59| 1298.809| 1939.1261| 687.48206|
|    Ax849b|2021-07-02|1120.69| 1295.552| 1892.4929|   693.786|
|    Ax849b|2021-07-03|1120.69| 1294.079| 1923.0253|  664.1514|
|    Ax849b|2021-07-04|1120.69|1295.0399| 1947.6392|  639.4879|
|    Bz383J|2021-07-03|1108.71|1159.4934| 1917.6515| 652.76624|
|    Bz383J|2021-07-04|1062.77|1191.2385| 1891.9268|  665.9529|
+----------+----------+-------+---------+----------+----------+

  • y - real value
  • yhat - predicted value

final_schema = StructType([
    StructField('ID', IntegerType()),
    StructField('ds', DateType()),
    StructField('y', FloatType()),
    StructField('yhat', FloatType()),
    StructField('yhat_upper', FloatType()),
    StructField('yhat_lower', FloatType()),
    StructField('mape', FloatType())
])

I have created a UDF and applied it per ID using the applyInPandas function.

from sklearn.metrics import mean_absolute_percentage_error
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf(final_schema, PandasUDFType.GROUPED_MAP)
def gr_mape_val(join_df):
  
  mape = mean_absolute_percentage_error(join_df["y"], join_df["yhat"]) 
  join_df['mape'] = mape
  
  return join_df

df_apply = join_df.groupby('ID').applyInPandas(gr_mape_val, final_schema)
df_apply.show()

I have made multiple attempts but still get this error:

ValueError: Invalid function: pandas_udf with function type GROUPED_MAP or the function in groupby.applyInPandas must take either one argument (data) or two arguments (key, data).

I wonder whether gr_mape_val() needs to take a pandas DataFrame as its argument, or a Spark DataFrame (as it is now). I couldn't figure out what I am doing wrong here.



Solution 1:[1]

You don't need the @pandas_udf decorator when you use applyInPandas. The decorator wraps your function, so Spark no longer sees a plain function taking either one argument (data) or two arguments (key, data), which is exactly what the ValueError complains about. Remove (or comment out) the decorator line and pass the plain function to applyInPandas; the schema is already supplied as the second argument.
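As a minimal sketch of the corrected approach: each group arrives in the function as a plain pandas DataFrame, so the function can be tested locally without Spark. MAPE is computed manually here to keep the sketch self-contained; it matches the definition used by sklearn's mean_absolute_percentage_error.

```python
import pandas as pd

def gr_mape_val(pdf: pd.DataFrame) -> pd.DataFrame:
    # No @pandas_udf decorator: applyInPandas expects a plain function
    # that receives one group as a pandas DataFrame and returns one.
    # MAPE = mean(|y - yhat| / |y|), equivalent to
    # sklearn.metrics.mean_absolute_percentage_error(pdf["y"], pdf["yhat"]).
    mape = ((pdf["y"] - pdf["yhat"]).abs() / pdf["y"].abs()).mean()
    pdf = pdf.copy()
    pdf["mape"] = mape
    return pdf

# In Spark, this would then be applied per group (schema as in the question):
# df_apply = join_df.groupby('ID').applyInPandas(gr_mape_val, final_schema)

# Local check on one group's worth of data:
pdf = pd.DataFrame({"y": [100.0, 200.0], "yhat": [110.0, 180.0]})
out = gr_mape_val(pdf)
print(out["mape"].iloc[0])  # 0.1 (both rows are off by 10%)
```

Note also that the sample ID values (e.g. Ax849b) are strings, so the schema's IntegerType for the ID column will likely cause a separate error once the decorator issue is fixed; StringType matches the data shown.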

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
