Can PySpark ML models be run on only parts of a dataframe, depending on a condition?

I have trained a logistic regression model to match job titles and descriptions to a set of 4-digit numeric codes, and it does this very well. It will form part of a pipeline that first tries to match these data by joining to a reference database. That join leaves some rows of the dataframe matched to a 4-digit code and others holding a dummy code indicating they are still to be matched. The state of my dataframe just before I run the ML algorithm is therefore:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = spark.createDataFrame(
    [
        ('1', 'BUTCHER', 'MEAT PERSON', '-1'),
        ('2', 'BAKER', 'BREAD AND PASTRY AND CAKE', '1468'),
        ('3', 'CANDLESTICK MAKER', 'LET THERE BE LIGHT', '-1')
    ],
    [
        'ID',
        'COLUMN_TO_VECTORIZE_1',
        'COLUMN_TO_VECTORIZE_2',
        'RESULTS_COLUMN'
    ]
)

where '-1' is the dummy code for 'as yet unmatched' and 'BAKER' has been matched to '1468' already.

If I wanted to match the whole dataframe, I would go on to write

data = pretrained_model_pipeline.transform(data) # vectorizes/assembles feature column, runs ML algorithm

# other code to perform index to string conversion on labels, and place labels into RESULTS_COLUMN

But that runs the whole dataframe through the algorithm, and anything previously matched does not need to be matched again.

I COULD create a new dataframe from the old one, selecting just those rows with '-1' in the 'RESULTS_COLUMN', but my limited understanding of Spark says that this would essentially double my memory usage.

Is there a way to give the pretrained model the whole dataframe to transform, but with some sort of mask telling it to skip rows that already have a code (anything other than '-1') in RESULTS_COLUMN?



Solution 1:[1]

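Spark evaluates DataFrames lazily. A filter such as where() only adds a step to the query plan and does not copy the data. Rows are read when an action (for example count(), show() or a write) runs, and they are only kept in memory afterwards if you explicitly cache() or persist() the DataFrame. So filtering down to the unmatched rows before calling transform is enough: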

from pyspark.sql import functions as F
scored = pretrained_model_pipeline.transform(data.where(F.col('RESULTS_COLUMN') == '-1'))  # only '-1' rows are scored

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 pltc