Using F.lit() in parametrize or as a default value throws a NoneType error
The following code works fine from the pyspark interpreter:
spark_utils.py
--------------
from typing import List, Optional
from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as F

def add_new_cols_from_list(
    df: DataFrame, columns: List[str], default_value: Optional[Column] = F.lit(0.0)
) -> DataFrame:
    return df.select(*df.columns, *[default_value.alias(column) for column in columns])
test.py
--------------
from spark_utils import add_new_cols_from_list
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StructField, StructType

# in the pyspark shell a SparkSession named `spark` already exists
sdf = spark.createDataFrame(
    [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
    schema=StructType(
        [
            StructField("id", IntegerType()),
            StructField("a", IntegerType()),
        ]
    ),
)
result = add_new_cols_from_list(sdf, ["b", "c"])
[Row(b=0.0, c=0.0), Row(b=0.0, c=0.0)] == result.select("b", "c").collect()  # True
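This works in the shell because pyspark has already created a SparkSession (and with it a SparkContext) before spark_utils is imported, so the F.lit(0.0) default finds a live context. A quick way to convince yourself, inspecting the same private attribute the traceback below touches (for illustration only):

from pyspark import SparkContext
SparkContext._active_spark_context  # a live SparkContext in the shell, not None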
but if you run pytest test.py::test_add_new_cols_from_list with this test:
def test_add_new_cols_from_list(spark_session):
    spark = spark_session
    sdf = spark.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )
    result = add_new_cols_from_list(sdf, ["b", "c"])
    assert [Row(b=0.0, c=0.0), Row(b=0.0, c=0.0)] == result.select("b", "c").collect()
you get an error:
test_spark_utils.py:6: in <module>
    from spark_utils import add_new_cols_from_list
nht/io/spark/spark_utils.py:74: in <module>
    df: DataFrame, columns: List[str], default_value: Optional[F.column] = F.lit(0.0)
/opt/conda/lib/python3.8/site-packages/pyspark/sql/functions.py:98: in lit
    return col if isinstance(col, Column) else _invoke_function("lit", col)
/opt/conda/lib/python3.8/site-packages/pyspark/sql/functions.py:57: in _invoke_function
    jf = _get_get_jvm_function(name, SparkContext._active_spark_context)
/opt/conda/lib/python3.8/site-packages/pyspark/sql/functions.py:49: in _get_get_jvm_function
    return getattr(sc._jvm.functions, name)
E   AttributeError: 'NoneType' object has no attribute '_jvm'
You get the same error if you take out the default but run the test parametrized (running the test with no default and no parametrization does not throw an error):
def add_new_cols_from_list(
    df: DataFrame, columns: List[str], default_value: Column
) -> DataFrame:
    return df.select(*df.columns, *[default_value.alias(column) for column in columns])
@pytest.mark.parametrize(
    "column_value, expected",
    [
        (F.lit(1.0), [Row(b=1.0, c=1.0), Row(b=1.0, c=1.0)]),
        (
            F.when(F.col("a") == 2, 0.5).otherwise(F.lit(6.0)),
            [Row(b=6.0, c=6.0), Row(b=0.5, c=0.5)],
        ),
    ],
)
def test_add_new_cols_from_list(spark_session, column_value, expected):
    spark = spark_session
    sdf = spark.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )
    result = add_new_cols_from_list(sdf, ["b", "c"], column_value)
    assert expected == result.select("b", "c").collect()
I don't understand why the code doesn't work under pytest, or what NoneType the error is referring to.
Solution 1:[1]
From this post https://stackoverflow.com/a/48250353/7186374 I figured out that you can't use pyspark functions before a SparkContext exists, which is exactly what happens with a default value or a parametrization: both are evaluated when the module is loaded, before pytest has created the spark_session fixture. The NoneType in the error is SparkContext._active_spark_context, which is still None at import time, so sc._jvm fails. I decided to just not have a default, but if you want one you can do:
def add_new_cols_from_list(
    df: DataFrame, columns: List[str], default_value: Optional[Column] = None
) -> DataFrame:
    if default_value is None:
        default_value = F.lit(0.0)  # now evaluated at call time, when a SparkContext exists
    return df.select(*df.columns, *[default_value.alias(column) for column in columns])
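This works because the None sentinel defers the F.lit(0.0) call until the function body runs, by which point a SparkSession exists. A minimal plain-Python sketch (illustrative names, no Spark involved) of the underlying gotcha, namely that default values are evaluated once, at import time:

import datetime

def log(message, when=datetime.datetime.now()):  # runs once, when the module is imported
    print(when, message)

# Every call prints the same import-time timestamp, just as F.lit(0.0) in the
# original signature ran at import, before any SparkContext existed.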
To parametrize, you can use your spark_session fixture as a parameter of another fixture and parametrize that fixture indirectly, so you can do this:
@pytest.fixture
def result(spark_session, request):
    spark = spark_session
    sdf = spark.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )
    if request.param == "literal":
        result = add_new_cols_from_list(sdf, ["b", "c"], F.lit(1.0))
    elif request.param == "when":
        result = add_new_cols_from_list(
            sdf, ["b", "c"], F.when(F.col("a") == 2, 0.5).otherwise(F.lit(6.0))
        )
    else:
        result = None
    return result

@pytest.mark.parametrize(
    "result, expected",
    [
        ("literal", [Row(b=1.0, c=1.0), Row(b=1.0, c=1.0)]),
        ("when", [Row(b=6.0, c=6.0), Row(b=0.5, c=0.5)]),
    ],
    indirect=["result"],
)
def test_add_new_cols_from_list(result, expected):
    assert expected == result.select("b", "c").collect()
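Another option, a sketch that is not from the original answer, is to keep plain parametrization but hide the Column construction behind a zero-argument callable: the lambda objects are created at import, but F.lit() and F.when() only run inside the test body, after the spark_session fixture has created a context:

@pytest.mark.parametrize(
    "make_column, expected",
    [
        (lambda: F.lit(1.0), [Row(b=1.0, c=1.0), Row(b=1.0, c=1.0)]),
        (
            lambda: F.when(F.col("a") == 2, 0.5).otherwise(F.lit(6.0)),
            [Row(b=6.0, c=6.0), Row(b=0.5, c=0.5)],
        ),
    ],
)
def test_add_new_cols_from_list_deferred(spark_session, make_column, expected):
    sdf = spark_session.createDataFrame(
        [{"id": 1, "a": 1}, {"id": 2, "a": 2}],
        schema=StructType(
            [
                StructField("id", IntegerType()),
                StructField("a", IntegerType()),
            ]
        ),
    )
    # the Column is built here, lazily, once a SparkContext exists
    result = add_new_cols_from_list(sdf, ["b", "c"], make_column())
    assert expected == result.select("b", "c").collect()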
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | CLedbetter |
