Skip to content

Commit 9fdb4a8

Browse files
davidm-db and cloud-fan
authored and committed
[SPARK-48356][FOLLOW UP][SQL] Improve FOR statement's column schema inference
### What changes were proposed in this pull request? This pull request changes `FOR` statement to infer column schemas from the query DataFrame, and no longer implicitly infer column schema in SetVariable. This is necessary due to type mismatch errors with complex nested types, e.g. `ARRAY<STRUCT<..>>`. ### Why are the changes needed? Bug fix for FOR statement. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit test that specifically targets problematic case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#51053 from davidm-db/for_schema_inference. Lead-authored-by: David Milicevic <david.milicevic@databricks.com> Co-authored-by: David Milicevic <163021185+davidm-db@users.noreply.github.com> Co-authored-by: Wenchen Fan <cloud0fan@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 4953f89 commit 9fdb4a8

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.Exceptio
3030
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
3131
import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
3232
import org.apache.spark.sql.errors.SqlScriptingErrors
33-
import org.apache.spark.sql.types.BooleanType
33+
import org.apache.spark.sql.types.{BooleanType, DataType}
3434

3535
/**
3636
* Trait for all SQL scripting execution nodes used during interpretation phase.
@@ -997,10 +997,14 @@ class ForStatementExec(
997997
private var state = ForState.VariableAssignment
998998

999999
private var queryResult: util.Iterator[Row] = _
1000+
private var queryColumnNameToDataType: Map[String, DataType] = _
10001001
private var isResultCacheValid = false
10011002
private def cachedQueryResult(): util.Iterator[Row] = {
10021003
if (!isResultCacheValid) {
1003-
queryResult = query.buildDataFrame(session).toLocalIterator()
1004+
val df = query.buildDataFrame(session)
1005+
queryResult = df.toLocalIterator()
1006+
queryColumnNameToDataType = df.schema.fields.map(f => f.name -> f.dataType).toMap
1007+
10041008
query.isExecuted = true
10051009
isResultCacheValid = true
10061010
}
@@ -1063,7 +1067,7 @@ class ForStatementExec(
10631067
val variableInitStatements = row.schema.names.toSeq
10641068
.map { colName => (colName, createExpressionFromValue(row.getAs(colName))) }
10651069
.flatMap { case (colName, expr) => Seq(
1066-
createDeclareVarExec(colName, expr),
1070+
createDeclareVarExec(colName),
10671071
createSetVarExec(colName, expr)
10681072
) }
10691073

@@ -1166,8 +1170,9 @@ class ForStatementExec(
11661170
case _ => Literal(value)
11671171
}
11681172

1169-
private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = {
1170-
val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null")
1173+
private def createDeclareVarExec(varName: String): SingleStatementExec = {
1174+
val defaultExpression = DefaultValueExpression(
1175+
Literal(null, queryColumnNameToDataType(varName)), "null")
11711176
val declareVariable = CreateVariable(
11721177
UnresolvedIdentifier(Seq(varName)),
11731178
defaultExpression,

sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2720,7 +2720,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
27202720
| SELECT varL3;
27212721
| SELECT 1/0;
27222722
| END;
2723-
27242723
| SELECT 5;
27252724
| SELECT 1/0;
27262725
| SELECT 6;

sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3450,4 +3450,32 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
34503450
verifySqlScriptResult(commands, expected)
34513451
}
34523452
}
3453+
3454+
test("for statement - structs in array have different values") {
3455+
withTable("t") {
3456+
val sqlScript =
3457+
"""
3458+
|BEGIN
3459+
| CREATE TABLE t(
3460+
| array_column ARRAY<STRUCT<id: INT, strCol: STRING, intArrayCol: ARRAY<INT>>>
3461+
| );
3462+
| INSERT INTO t VALUES
3463+
| Array(Struct(1, null, Array(10)),
3464+
| Struct(2, "name", Array()));
3465+
| FOR SELECT * FROM t DO
3466+
| SELECT array_column;
3467+
| END FOR;
3468+
|END
3469+
|""".stripMargin
3470+
3471+
val expected = Seq(
3472+
Seq.empty[Row], // create table
3473+
Seq.empty[Row], // insert
3474+
Seq.empty[Row], // declare array_column
3475+
Seq.empty[Row], // set array_column
3476+
Seq(Row(Seq(Row(1, null, Seq(10)), Row(2, "name", Seq.empty))))
3477+
)
3478+
verifySqlScriptResult(sqlScript, expected)
3479+
}
3480+
}
34533481
}

0 commit comments

Comments
 (0)