
Commit 4325d3a

chenhao-db authored and cloud-fan committed
[SPARK-52153][SQL] Fix from_json and to_json with variant
### What changes were proposed in this pull request?

It fixes two minor issues with `from_json(variant)` and `to_json(variant)`:

- `from_json(variant)` currently ignores any JSON options. This is inconsistent with `from_json(nested type containing variant)`, which respects JSON options.
- `to_json(variant)`, when the variant contains special floating-point values (Infinity, NaN), does not wrap them in quotes. This is inconsistent with `to_json(nested type containing floating points)`.
  - For example, the result of `to_json(named_struct('a', cast('NaN' as double)))` is `{"a":"NaN"}`, while the result of `to_json(to_variant_object(named_struct('a', cast('NaN' as double))))` is `{"a":NaN}`.
  - Although `{"a":NaN}` can be parsed by `from_json` when the `allowNonNumericNumbers` option is true, it is still not valid JSON according to the spec. `to_json` should produce valid JSON.

### Why are the changes needed?

This makes variant-related JSON handling more consistent with non-variant JSON handling.

### Does this PR introduce _any_ user-facing change?

Yes, as stated above.

### How was this patch tested?

Unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50901 from chenhao-db/fix_variant_from_to_json.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
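To make the behavior change concrete, here is a hypothetical spark-shell sketch against a build that includes this fix (the expressions mirror the examples above; the result comments are expected values, not captured output):

```scala
// Assumes a SparkSession `spark` on a patched build.
// to_json(variant) now quotes non-finite floating-point values:
spark.sql(
  "SELECT to_json(to_variant_object(named_struct('a', cast('NaN' as double))))"
).show(false)
// {"a":"NaN"}   -- previously {"a":NaN}, which is not valid JSON

// from_json(variant) now respects JSON options:
spark.sql(
  """SELECT from_json('{"a": NaN}', 'variant', map('allowNonNumericNumbers', 'false'))"""
).show(false)
// NULL -- the record is treated as malformed in the default PERMISSIVE mode
```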
1 parent 451a7a6 · commit 4325d3a

File tree: 4 files changed (+62, -30 lines)

common/variant/src/main/java/org/apache/spark/types/variant/Variant.java

Lines changed: 16 additions & 4 deletions
```diff
@@ -316,9 +316,15 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
       case STRING:
         sb.append(escapeJson(VariantUtil.getString(value, pos)));
         break;
-      case DOUBLE:
-        sb.append(VariantUtil.getDouble(value, pos));
+      case DOUBLE: {
+        double d = VariantUtil.getDouble(value, pos);
+        if (Double.isFinite(d)) {
+          sb.append(d);
+        } else {
+          appendQuoted(sb, Double.toString(d));
+        }
         break;
+      }
       case DECIMAL:
         sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
         break;
@@ -333,9 +339,15 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
         appendQuoted(sb, TIMESTAMP_NTZ_FORMATTER.format(
             microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC)));
         break;
-      case FLOAT:
-        sb.append(VariantUtil.getFloat(value, pos));
+      case FLOAT: {
+        float f = VariantUtil.getFloat(value, pos);
+        if (Float.isFinite(f)) {
+          sb.append(f);
+        } else {
+          appendQuoted(sb, Float.toString(f));
+        }
         break;
+      }
       case BINARY:
         appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
         break;
```
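The quoting matters because `NaN` and `Infinity` are not legal JSON tokens, so strict parsers reject them. A standalone Scala sketch (illustrative only, not part of the patch) using Jackson, the parser behind Spark's JSON reader; newer Jackson versions expose the same toggle as `JsonReadFeature.ALLOW_NON_NUMERIC_NUMBERS`, which the Scala file below imports:

```scala
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}

object BareNanDemo extends App {
  // Pull every token so the parser actually scans the whole input.
  def drain(factory: JsonFactory): Unit = {
    val parser = factory.createParser("""{"a": NaN}""")
    while (parser.nextToken() != null) {}
  }

  // Spec-compliant default: bare NaN is a parse error.
  try drain(new JsonFactory())
  catch { case e: Exception => println(s"strict: ${e.getMessage}") }

  // With the non-numeric-numbers feature enabled (what Spark's
  // allowNonNumericNumbers option maps to), the same input parses.
  val lenient = new JsonFactory()
  lenient.enable(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS)
  drain(lenient)
  println("lenient: parsed OK")
}
```

This is why `to_json` should emit `{"a":"NaN"}` rather than `{"a":NaN}`: the quoted form round-trips through any compliant JSON parser.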

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala

Lines changed: 3 additions & 8 deletions
```diff
@@ -26,7 +26,6 @@ import com.fasterxml.jackson.core.json.JsonReadFeature
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
-import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions}
 import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode}
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -123,6 +122,8 @@ case class JsonToStructsEvaluator(
       (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
     case _: MapType =>
       (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
+    case _: VariantType =>
+      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getVariant(0) else null
   }

   @transient
@@ -152,13 +153,7 @@ case class JsonToStructsEvaluator(

   final def evaluate(json: UTF8String): Any = {
     if (json == null) return null
-    nullableSchema match {
-      case _: VariantType =>
-        VariantExpressionEvalUtils.parseJson(json,
-          allowDuplicateKeys = variantAllowDuplicateKeys)
-      case _ =>
-        converter(parser.parse(json))
-    }
+    converter(parser.parse(json))
   }
 }
```
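With the `VariantType` special case gone, `from_json(..., 'variant')` flows through the same `FailureSafeParser`/`JacksonParser` pipeline as struct, map, and array schemas, which is what makes JSON options and parse modes take effect. A hedged usage sketch mirroring the updated tests (assumes a spark-shell session on a patched build; result comments are expected values):

```scala
import spark.implicits._

val dup = Seq("""{"a": 1, "a": 2}""").toDF("j")

// With VARIANT_ALLOW_DUPLICATE_KEYS disabled, the default PERMISSIVE mode
// now maps the bad record to null instead of throwing:
dup.selectExpr("from_json(j, 'variant')").show()
// NULL

// The mode option applies too: FAILFAST surfaces the underlying error.
dup.selectExpr("from_json(j, 'variant', map('mode', 'FAILFAST'))").collect()
// org.apache.spark.SparkException: [MALFORMED_RECORD_IN_PARSING...],
// caused by VARIANT_DUPLICATE_KEY
```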

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -108,6 +108,9 @@ class JacksonParser(
    */
   private def makeRootConverter(dt: DataType): JsonParser => Iterable[InternalRow] = {
     dt match {
+      case _: VariantType => (parser: JsonParser) => {
+        Some(InternalRow(parseVariant(parser)))
+      }
       case _: StructType if options.singleVariantColumn.isDefined => (parser: JsonParser) => {
         Some(InternalRow(parseVariant(parser)))
       }
```

sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala

Lines changed: 40 additions & 18 deletions
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarArray
 import org.apache.spark.types.variant.VariantBuilder
 import org.apache.spark.types.variant.VariantUtil._
-import org.apache.spark.unsafe.types.VariantVal
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

 class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -101,6 +101,26 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     )
     // scalastyle:on nonascii
     check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+
+    // Validate options work.
+    checkAnswer(Seq("""{"a": NaN}""").toDF("v")
+      .selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'false'))"), Row(null))
+    checkAnswer(Seq("""{"a": NaN}""").toDF("v")
+      .selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))"),
+      Row(
+        VariantExpressionEvalUtils.castToVariant(InternalRow(Double.NaN),
+          StructType.fromDDL("a double"))))
+    // String input "NaN" will remain a string instead of double.
+    checkAnswer(Seq("""{"a": "NaN"}""").toDF("v")
+      .selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))"),
+      Row(
+        VariantExpressionEvalUtils.castToVariant(InternalRow(UTF8String.fromString("NaN")),
+          StructType.fromDDL("a string"))))
+    // to_json should put special floating point values in quotes.
+    checkAnswer(Seq("""{"a": NaN}""").toDF("v")
+      .selectExpr("to_json(from_json(v, 'variant', map('allowNonNumericNumbers', 'true')))"),
+      Row("""{"a":"NaN"}"""))
+
   }

   test("try_parse_json/to_json round-trip") {
@@ -346,6 +366,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {

   test("from_json(_, 'variant') with duplicate keys") {
     val json: String = """{"a": 1, "b": 2, "c": "3", "a": 4}"""
+
     withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") {
       val df = Seq(json).toDF("j")
         .selectExpr("from_json(j,'variant')")
@@ -359,24 +380,25 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
       val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c')
       assert(actual === new VariantVal(expectedValue, expectedMetadata))
     }
-    // Check whether the parse_json and from_json expressions throw the correct exception.
-    Seq("from_json(j, 'variant')", "parse_json(j)").foreach { expr =>
-      withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") {
-        val df = Seq(json).toDF("j").selectExpr(expr)
-        val exception = intercept[SparkException] {
-          df.collect()
-        }
-        checkError(
-          exception = exception,
-          condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
-          parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST")
-        )
-        checkError(
-          exception = exception.getCause.asInstanceOf[SparkRuntimeException],
-          condition = "VARIANT_DUPLICATE_KEY",
-          parameters = Map("key" -> "a")
-        )
+
+    withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") {
+      // In default mode (PERMISSIVE), JSON with duplicate keys is still invalid, but no error will
+      // be thrown.
+      checkAnswer(Seq(json).toDF("j").selectExpr("from_json(j, 'variant')"), Row(null))
+
+      val exception = intercept[SparkException] {
+        Seq(json).toDF("j").selectExpr("from_json(j, 'variant', map('mode', 'FAILFAST'))").collect()
       }
+      checkError(
+        exception = exception,
+        condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
+        parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST")
+      )
+      checkError(
+        exception = exception.getCause.asInstanceOf[SparkRuntimeException],
+        condition = "VARIANT_DUPLICATE_KEY",
+        parameters = Map("key" -> "a")
+      )
     }
   }
```
