Using F.lit() in parametrize or as a default value throws a NoneType error

The following code works fine from the pyspark interpreter

spark_utils.py
--------------
from typing import List, Optional
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

def add_new_cols_from_list(
    df: DataFrame, columns: List[str], default_value: Optional[F.column] = F.lit(0.0)
) -> DataFrame:
    return df.select(*df.columns, *[default_value.alias(column) for column in columns])

test.py
--------------
from spark_utils import add_new_cols_from_list
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, IntegerType

# `spark` is the SparkSession the pyspark shell provides
sdf = spark.createDataFrame(
    [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
    schema=StructType(
        [
            StructField("id", IntegerType()),
            StructField("a", IntegerType()),
        ]
    ),
)

result = add_new_cols_from_list(sdf, ["b", "c"])
[Row(b=0.0, c=0.0), Row(b=0.0, c=0.0)] == result.select("b", "c").collect()

but if you put the same code in a test and run pytest test.py::test_add_new_cols_from_list

def test_add_new_cols_from_list(spark_session):
    spark = spark_session
    sdf = spark.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )

    result = add_new_cols_from_list(sdf, ["b", "c"])
    assert [Row(b=0.0, c=0.0), Row(b=0.0, c=0.0)] == result.select("b", "c").collect()

you get this error:

test_spark_utils.py:6: in <module>
    from spark_utils import add_new_cols_from_list
nht/io/spark/spark_utils.py:74: in <module>
    df: DataFrame, columns: List[str], default_value: Optional[F.column] = F.lit(0.0)
/opt/conda/lib/python3.8/site-packages/pyspark/sql/functions.py:98: in lit
    return col if isinstance(col, Column) else _invoke_function("lit", col)
/opt/conda/lib/python3.8/site-packages/pyspark/sql/functions.py:57: in _invoke_function
    jf = _get_get_jvm_function(name, SparkContext._active_spark_context)
/opt/conda/lib/python3.8/site-packages/pyspark/sql/functions.py:49: in _get_get_jvm_function
    return getattr(sc._jvm.functions, name)
E   AttributeError: 'NoneType' object has no attribute '_jvm'

You get the same error if you remove the default but parametrize the test (running the test with no default and no parametrization does not throw an error):

def add_new_cols_from_list(
    df: DataFrame, columns: List[str], default_value: F.column
) -> DataFrame:
    return df.select(*df.columns, *[default_value.alias(column) for column in columns])


@pytest.mark.parametrize(
    "column_value, expected",
    [
        (F.lit(1.0), [Row(b=1.0, c=1.0), Row(b=1.0, c=1.0)]),
        (F.when(F.col("a") == 2, 0.5).otherwise(F.lit(6.0)), [Row(b=6.0, c=6.0), Row(b=0.5, c=0.5)]),
    ],
)
def test_add_new_cols_from_list(spark_session, column_value, expected):
    spark = spark_session
    sdf = spark.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )

    result = add_new_cols_from_list(sdf, ["b", "c"], column_value)
    assert expected == result.select("b", "c").collect()

I don't understand why the code doesn't work under pytest, or what NoneType the error is referring to.



Solution 1:[1]

From this post https://stackoverflow.com/a/48250353/7186374 I figured out that you can't use PySpark functions before there is an active SparkContext. Both default argument values and pytest.mark.parametrize arguments are evaluated when the module is loaded, before the spark_session fixture has created a context, so F.lit() fails; the NoneType in the error is SparkContext._active_spark_context, which is still None at that point. I decided to just not have a default, but if you want one you can defer the F.lit() call to inside the function:

from typing import List, Optional
from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as F

def add_new_cols_from_list(
    df: DataFrame, columns: List[str], default_value: Optional[Column] = None
) -> DataFrame:
    if default_value is None:
        # F.lit() now runs at call time, when an active SparkContext exists
        default_value = F.lit(0.0)
    return df.select(*df.columns, *[default_value.alias(column) for column in columns])
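
With the None default, importing spark_utils no longer touches the JVM, so the module can be loaded during pytest collection before any Spark context exists. Here is a minimal sketch of how that plays out, assuming a local SparkSession and the spark_utils module from above:

# Importing is now safe even though no SparkContext exists yet
from pyspark import SparkContext
from pyspark.sql import Row, SparkSession

from spark_utils import add_new_cols_from_list

print(SparkContext._active_spark_context)  # None -- the NoneType from the traceback

spark = SparkSession.builder.master("local[1]").getOrCreate()
sdf = spark.createDataFrame([{"id": 1, "a": 1}, {"id": 2, "a": 2}])

# default_value falls back to F.lit(0.0) inside the call, after the context exists
result = add_new_cols_from_list(sdf, ["b", "c"])
assert result.select("b", "c").collect() == [Row(b=0.0, c=0.0), Row(b=0.0, c=0.0)]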

To parametrize, you can use your spark_session fixture as a parameter in another fixture (pytest's indirect parametrization), so you can do this:

@pytest.fixture
def result(spark_session, request):
    spark = spark_session
    sdf = spark.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )
    if request.param == "literal":
        result = add_new_cols_from_list(sdf, ["b", "c"], F.lit(1.0))
    elif request.param == "when":
        result = add_new_cols_from_list(sdf, ["b", "c"], F.when(F.col("a") == 2, 0.5).otherwise(F.lit(6.0)))
    else:
        result = None
    return result


@pytest.mark.parametrize(
    "result, expected",
    [
        ("literal", [Row(b=1.0, c=1.0), Row(b=1.0, c=1.0)]),
        ("when", [Row(b=6.0, c=6.0), Row(b=0.5, c=0.5)]),
    ],
    indirect=["result"],
)
def test_add_new_cols_from_list(result, expected):
    assert expected == result.select("b", "c").collect()
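
Another way to parametrize without the indirect fixture (an untested sketch along the same lines) is to pass zero-argument callables, so the Column expressions are only built inside the test body, after the spark_session fixture has started Spark:

import pytest
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, IntegerType

from spark_utils import add_new_cols_from_list


@pytest.mark.parametrize(
    "make_column, expected",
    [
        # the lambdas are created at collection time, but F.lit()/F.when() only run
        # inside the test, once spark_session has created the context
        (lambda: F.lit(1.0), [Row(b=1.0, c=1.0), Row(b=1.0, c=1.0)]),
        (
            lambda: F.when(F.col("a") == 2, 0.5).otherwise(F.lit(6.0)),
            [Row(b=6.0, c=6.0), Row(b=0.5, c=0.5)],
        ),
    ],
)
def test_add_new_cols_from_list(spark_session, make_column, expected):
    sdf = spark_session.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType([StructField("id", IntegerType()), StructField("a", IntegerType())]),
    )
    result = add_new_cols_from_list(sdf, ["b", "c"], make_column())
    assert expected == result.select("b", "c").collect()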

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: CLedbetter