
Commit d0131bb

harshmotw-db authored and yhuang-db committed
[SPARK-50815][PYTHON] Fix Variant Local Data to Arrow Conversion
### What changes were proposed in this pull request?

This PR removes unnecessary code for converting Variants in PySpark from the local to the Arrow representation. This allows `createDataFrame` and Python Data Sources to work seamlessly with Variants.

### Why are the changes needed?

[This PR](apache#45826) introduced code to convert Variants from the internal representation to the Arrow representation (LocalDataToArrowConversion). However, the internal representation was assumed to be `dict` and the Arrow representation was assumed to be `VariantVal`, even though it should be the other way around. The code written in that PR does not appear to be exercised by any tests. This caused `createDataFrame` to not work with Variants, and the [attempted fix](apache#49487) added a special case (`variants_as_dicts`) to this code, even though the special case was actually the only use case. This PR removes the old unnecessary code and keeps only the "special case" code as the main path for converting Variants from local (`VariantVal`) to Arrow (`dict`).

### Does this PR introduce _any_ user-facing change?

This will allow users to use Python Data Sources with Variants.

### How was this patch tested?

Existing tests should pass, and a new unit test for Python Data Sources was added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51082 from harshmotw-db/harsh-motwani_data/experimental_variant_fix.

Authored-by: Harsh Motwani <harsh.motwani@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
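For illustration only (not part of the patch): a minimal sketch of the `createDataFrame` path this change is described as unblocking. The column names, schema string, and JSON literal below are assumptions.

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import VariantVal

spark = SparkSession.builder.getOrCreate()

# Local rows carry VariantVal values; the fixed conversion lowers them to the
# Arrow representation when the DataFrame is created.
df = spark.createDataFrame(
    [(1, VariantVal.parseJson('{"a": 1}'))],
    "id INT, v VARIANT",
)
df.selectExpr("id", "to_json(v) AS v_json").show()
```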
1 parent ad8a991 commit d0131bb

File tree

2 files changed: +34 −23 lines changed


python/pyspark/sql/conversion.py

Lines changed: 5 additions & 17 deletions
@@ -95,7 +95,6 @@ def _need_converter(
     def _create_converter(
         dataType: DataType,
         nullable: bool = True,
-        variants_as_dicts: bool = False,  # some code paths may require python internal types
     ) -> Callable:
         assert dataType is not None and isinstance(dataType, DataType)
         assert isinstance(nullable, bool)
@@ -117,9 +116,7 @@ def convert_null(value: Any) -> Any:
             dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
-                LocalDataToArrowConversion._create_converter(
-                    field.dataType, field.nullable, variants_as_dicts
-                )
+                LocalDataToArrowConversion._create_converter(field.dataType, field.nullable)
                 for field in dataType.fields
             ]
 
@@ -161,7 +158,7 @@ def convert_struct(value: Any) -> Any:
 
         elif isinstance(dataType, ArrayType):
             element_conv = LocalDataToArrowConversion._create_converter(
-                dataType.elementType, dataType.containsNull, variants_as_dicts
+                dataType.elementType, dataType.containsNull
             )
 
             def convert_array(value: Any) -> Any:
@@ -178,7 +175,7 @@ def convert_array(value: Any) -> Any:
         elif isinstance(dataType, MapType):
             key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
             value_conv = LocalDataToArrowConversion._create_converter(
-                dataType.valueType, dataType.valueContainsNull, variants_as_dicts
+                dataType.valueType, dataType.valueContainsNull
             )
 
             def convert_map(value: Any) -> Any:
@@ -288,14 +285,7 @@ def convert_variant(value: Any) -> Any:
                 if not nullable:
                     raise PySparkValueError(f"input for {dataType} must not be None")
                 return None
-                elif (
-                    isinstance(value, dict)
-                    and all(key in value for key in ["value", "metadata"])
-                    and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
-                    and not variants_as_dicts
-                ):
-                    return VariantVal(value["value"], value["metadata"])
-                elif isinstance(value, VariantVal) and variants_as_dicts:
+                elif isinstance(value, VariantVal):
                     return VariantType().toInternal(value)
                 else:
                     raise PySparkValueError(errorClass="MALFORMED_VARIANT")
@@ -325,9 +315,7 @@ def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool)
         column_names = schema.fieldNames()
 
         column_convs = [
-            LocalDataToArrowConversion._create_converter(
-                field.dataType, field.nullable, variants_as_dicts=True
-            )
+            LocalDataToArrowConversion._create_converter(field.dataType, field.nullable)
            for field in schema.fields
        ]

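For context, a minimal sketch of what the local-to-Arrow Variant conversion now produces for a single value. The dict layout (`"value"`/`"metadata"` keys holding bytes) is an assumption inferred from the removed branch above and the commit message, not taken verbatim from the patch.

```python
from pyspark.sql.types import VariantType, VariantVal

v = VariantVal.parseJson('{"c": 1}')

# The converter now calls VariantType().toInternal(value); per the commit
# message the Arrow-side representation is a dict, and the removed branch
# expected "value"/"metadata" keys holding bytes.
internal = VariantType().toInternal(v)
assert all(k in internal for k in ("value", "metadata"))
assert all(isinstance(internal[k], bytes) for k in ("value", "metadata"))

# Reconstructing a VariantVal from that dict, mirroring the removed branch:
roundtrip = VariantVal(internal["value"], internal["metadata"])
print(roundtrip.toJson())  # expected to print '{"c":1}'
```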
python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 29 additions & 6 deletions
@@ -48,7 +48,7 @@
 )
 from pyspark.sql.functions import spark_partition_id
 from pyspark.sql.session import SparkSession
-from pyspark.sql.types import Row, StructType
+from pyspark.sql.types import Row, StructType, VariantVal
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import (
     SPARK_HOME,
@@ -88,32 +88,55 @@ def read(self, partition):
     def test_data_source_register(self):
         class TestReader(DataSourceReader):
             def read(self, partition):
-                yield (0, 1)
+                yield (
+                    0,
+                    1,
+                    VariantVal.parseJson('{"c":1}'),
+                    {"v": VariantVal.parseJson('{"d":2}')},
+                    [VariantVal.parseJson('{"e":3}')],
+                    {"v1": VariantVal.parseJson('{"f":4}'), "v2": VariantVal.parseJson('{"g":5}')},
+                )
 
         class TestDataSource(DataSource):
             def schema(self):
-                return "a INT, b INT"
+                return (
+                    "a INT, b INT, c VARIANT, d STRUCT<v VARIANT>, e ARRAY<VARIANT>,"
+                    "f MAP<STRING, VARIANT>"
+                )
 
             def reader(self, schema):
                 return TestReader()
 
         self.spark.dataSource.register(TestDataSource)
         df = self.spark.read.format("TestDataSource").load()
-        assertDataFrameEqual(df, [Row(a=0, b=1)])
+        assertDataFrameEqual(
+            df.selectExpr(
+                "a", "b", "to_json(c) c", "to_json(d.v) d", "to_json(e[0]) e", "to_json(f['v2']) f"
+            ),
+            [Row(a=0, b=1, c='{"c":1}', d='{"d":2}', e='{"e":3}', f='{"g":5}')],
+        )
 
         class MyDataSource(TestDataSource):
             @classmethod
             def name(cls):
                 return "TestDataSource"
 
             def schema(self):
-                return "c INT, d INT"
+                return (
+                    "c INT, d INT, e VARIANT, f STRUCT<v VARIANT>, g ARRAY<VARIANT>,"
+                    "h MAP<STRING, VARIANT>"
+                )
 
         # Should be able to register the data source with the same name.
         self.spark.dataSource.register(MyDataSource)
 
         df = self.spark.read.format("TestDataSource").load()
-        assertDataFrameEqual(df, [Row(c=0, d=1)])
+        assertDataFrameEqual(
+            df.selectExpr(
+                "c", "d", "to_json(e) e", "to_json(f.v) f", "to_json(g[0]) g", "to_json(h['v2']) h"
+            ),
+            [Row(c=0, d=1, e='{"c":1}', f='{"d":2}', g='{"e":3}', h='{"g":5}')],
        )
 
     def register_data_source(
         self,
