Exception when trying to use saved Spark ML model for distributed computations with SHAP

I am getting this error:

Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

and

PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

I have a saved LightGBM Spark model that I want to use with the SHAP package to get explanations for my predictions.

# Loading the LightGBM model
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, DoubleType, FloatType
from synapse.ml.lightgbm import LightGBMClassificationModel
import shap

loaded_model = LightGBMClassificationModel.loadNativeModelFromFile(model_path)

# Custom predict_proba method

assembler = VectorAssembler(handleInvalid="keep",
                            inputCols=features,
                            outputCol="features")

def pandas_to_spark(X):
    # SHAP hands predict_proba a pandas DataFrame; convert it back to Spark
    return spark.createDataFrame(X)

def predict_proba(X):
    sdf = assembler.transform(pandas_to_spark(X).select(*features))
    getNegative = F.udf(lambda x: float(x[0]), FloatType())
    getPositive = F.udf(lambda x: float(x[1]), FloatType())

    predictions = (
        loaded_model.transform(sdf)
        .select("probability")
        .withColumn("0", getNegative(F.col("probability")))
        .withColumn("1", getPositive(F.col("probability")))
        .select("0", "1")
    )
    return predictions.toPandas()

# Creating SHAP explainer

explainer = shap.KernelExplainer(model=predict_proba, 
                                 data=data_explainer.select(*features).limit(100).toPandas())

# Trying to get explanations

def explain_df(explainer, df):
    return [e.tolist() for e in explainer.shap_values(df)[1]]

explain = F.udf(lambda x: explain_df(explainer, x), ArrayType(ArrayType(DoubleType())))

gr = datap.groupBy('_partition').agg(F.collect_list('features').alias('features'))

The last step apparently doesn't work because I am trying to use a Spark model inside a UDF. The same approach works with sklearn models.



Solution 1:[1]

I ran into the same issue and discovered that Synapse ML has its own explainers built in. The documentation is severely lacking, however. The alternative if you really need to use the shap package is to re-build the model with the lightgbm package—you'd just need to convert the Synapse ML LightGBMClassifier hyperparameters to the format accepted by lightgbm.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Will Wright