Commit b923243

austinrwarner authored and yhuang-db committed
[SPARK-52355][PYTHON] Infer VariantVal object type as VariantType when creating a DataFrame
### What changes were proposed in this pull request?

When creating a `DataFrame` from Python using `spark.createDataFrame`, infer the type of any `VariantVal` objects as `VariantType`. This is implemented by adding a case mapping `VariantVal` to `VariantType` in the `pyspark.sql.types._infer_type` function.

### Why are the changes needed?

Currently, when creating a `DataFrame` that includes locally instantiated `VariantVal` objects in Python, the type is inferred as `struct<metadata:binary,value:binary>` rather than `VariantType`. This leads to unintended behavior when creating a `DataFrame` locally, or in situations such as `df.rdd.map(...).toDF`, which calls `createDataFrame` under the hood. The bug only occurs when the schema of the `DataFrame` is not passed explicitly.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the bug described above.

### How was this patch tested?

Added a test in `python/pyspark/sql/tests/test_types.py` that checks that the inferred type is `VariantType`, and that the `VariantVal` has the correct `value` and `metadata` after inference.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#51065 from austinrwarner/SPARK-52355.

Authored-by: Austin Warner <austin.richard.warner@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent ecb181e · commit b923243

File tree

2 files changed: +21 −0 lines changed

python/pyspark/sql/tests/test_types.py

Lines changed: 19 additions & 0 deletions

```diff
@@ -477,6 +477,25 @@ def test_infer_map_pair_type_with_nested_maps(self):
             df.first(),
         )
 
+    def test_infer_variant_type(self):
+        # SPARK-52355: Test inferring variant type
+        value = VariantVal.parseJson('{"a": 1}')
+
+        data = [Row(f1=value)]
+        df = self.spark.createDataFrame(data)
+        actual = df.first()["f1"]
+
+        self.assertEqual(type(df.schema["f1"].dataType), VariantType)
+        # As of writing VariantVal can also include bytearray
+        self.assertEqual(
+            bytes(actual.value),
+            bytes(value.value),
+        )
+        self.assertEqual(
+            bytes(actual.metadata),
+            bytes(value.metadata),
+        )
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{"a": 1}], ["b"])
         self.assertEqual(df.columns, ["b"])
```

python/pyspark/sql/types.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -2307,6 +2307,8 @@ def _infer_type(
             errorClass="UNSUPPORTED_DATA_TYPE",
             messageParameters={"data_type": f"array({obj.typecode})"},
         )
+    elif isinstance(obj, VariantVal):
+        return VariantType()
     else:
         try:
             return _infer_schema(
```
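The two-line patch works because `_infer_type` is an `isinstance`-dispatch chain: the new `VariantVal` branch must be checked before the generic fallback that would otherwise infer a `struct<metadata:binary,value:binary>` from the object's fields. Below is a minimal, self-contained sketch of that dispatch pattern; the `VariantVal`, `VariantType`, and other classes here are simplified stand-in stubs, not the real PySpark implementations.

```python
# Sketch of the isinstance-dispatch pattern used by a type-inference
# function like pyspark.sql.types._infer_type. All classes are stubs
# for illustration only -- NOT the real PySpark types.

class DataType: ...
class LongType(DataType): ...
class StringType(DataType): ...
class VariantType(DataType): ...

class VariantVal:
    """Stub mirroring VariantVal's binary value/metadata fields."""
    def __init__(self, value: bytes, metadata: bytes):
        self.value = value
        self.metadata = metadata

def infer_type(obj) -> DataType:
    if isinstance(obj, int):
        return LongType()
    elif isinstance(obj, str):
        return StringType()
    elif isinstance(obj, VariantVal):
        # The fix: match VariantVal explicitly so it does not fall
        # through to generic struct inference over its attributes.
        return VariantType()
    else:
        raise TypeError(f"unsupported type: {type(obj).__name__}")

print(type(infer_type(VariantVal(b"\x01", b"\x00"))).__name__)
print(type(infer_type(42)).__name__)
```

The ordering matters only relative to the fallback branch: any case placed after a catch-all would never fire, which is why the real patch inserts the `elif` just before the final `else`.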
