@@ -76,6 +76,21 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
s"unmatched child schema for GetStructField: ${projSchema.toString}"
)
}
case GetStructField(child, ordinal, nameOpt) =>
getProjection(child).map(p => (p, p.dataType)).map {
case (projection, projSchema: StructType) =>
// Look up the field name in the original (unpruned) child schema by ordinal
val originalFieldName = nameOpt.getOrElse {
child.dataType.asInstanceOf[StructType](ordinal).name
}
// Find the new ordinal in the pruned schema
GetStructField(projection, projSchema.fieldIndex(originalFieldName),
Some(originalFieldName))
case (_, projSchema) =>
throw new IllegalStateException(
s"unmatched child schema for GetStructField: ${projSchema.toString}"
)
}
case ElementAt(left, right, defaultValueOutOfBound, failOnError) if right.foldable =>
getProjection(left).map(p => ElementAt(p, right, defaultValueOutOfBound, failOnError))
case _ =>
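As a standalone illustration of what the new `GetStructField` branch above does (the schema and ordinals below are made up for the example), an ordinal captured against the full child schema is re-resolved by name against the pruned projection schema:

```scala
import org.apache.spark.sql.types._

// Full child schema as seen at analysis time.
val full = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType),
  StructField("c", LongType)))

// Schema after pruning dropped the unused field "a".
val pruned = StructType(Seq(
  StructField("b", StringType),
  StructField("c", LongType)))

// A GetStructField built against `full` with ordinal 2 points at "c";
// in the pruned schema that field now lives at ordinal 1.
val fieldName = full(2).name                  // "c"
val newOrdinal = pruned.fieldIndex(fieldName) // 1
```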
@@ -59,7 +59,9 @@ object ExtractValue {
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
// SPARK-47230: Keep the field name as written in the schema (not the user-provided name)
// so that later by-name lookups still resolve during schema pruning under case-insensitive analysis
GetArrayStructFields(child, fields(ordinal),
ordinal, fields.length, containsNull || fields(ordinal).nullable)

case (_: ArrayType, _) => GetArrayItem(child, extraction)
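A small sketch (hypothetical schema and resolver, not part of the patch) of why keeping `fields(ordinal)` rather than `copy(name = fieldName)` matters: the resolver matches names ignoring case, but later exact by-name lookups such as `StructType.fieldIndex` only succeed with the casing stored in the schema.

```scala
import org.apache.spark.sql.types._

val fields = Array(
  StructField("a", IntegerType),
  StructField("b", StringType))
val schema = StructType(fields)

// Case-insensitive resolver, as used when spark.sql.caseSensitive is false.
val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

// The user wrote "A"; resolution still finds the field at ordinal 0.
val ordinal = fields.indexWhere(f => resolver(f.name, "A")) // 0

schema.fieldIndex(fields(ordinal).name) // 0: "a" exists in the schema
// schema.fieldIndex("A")               // would throw: exact-name lookup
```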
@@ -196,9 +198,22 @@ case class GetArrayStructFields(
val values = ctx.freshName("values")
val j = ctx.freshName("j")
val row = ctx.freshName("row")
val actualOrdinal = ctx.freshName("actualOrdinal")
val actualNumFields = ctx.freshName("actualNumFields")

// SPARK-47230: Resolve the ordinal dynamically so that schema pruning can move fields around.
// The element struct schema comes from the child expression's current dataType and is
// stored as a reference object in the generated code.
val elementSchema = child.dataType.asInstanceOf[ArrayType].elementType
.asInstanceOf[StructType]
val schemaRef = ctx.addReferenceObj("elementSchema", elementSchema,
classOf[StructType].getName)
// field.name contains the resolved field name from the schema (case-preserved)
val fieldNameRef = ctx.addReferenceObj("fieldName", field.name, classOf[String].getName)

val nullSafeEval = if (field.nullable) {
s"""
if ($row.isNullAt($ordinal)) {
if ($row.isNullAt($actualOrdinal)) {
$values[$j] = null;
} else
"""
@@ -209,13 +224,17 @@
s"""
final int $n = $eval.numElements();
final Object[] $values = new Object[$n];
// SPARK-47230: Look up the ordinal by field name in the current (possibly pruned) schema
final int $actualOrdinal = $schemaRef.fieldIndex($fieldNameRef);
final int $actualNumFields = $schemaRef.size();

for (int $j = 0; $j < $n; $j++) {
if ($eval.isNullAt($j)) {
$values[$j] = null;
} else {
final InternalRow $row = $eval.getStruct($j, $numFields);
final InternalRow $row = $eval.getStruct($j, $actualNumFields);
$nullSafeEval {
$values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)};
$values[$j] = ${CodeGenerator.getValue(row, field.dataType, actualOrdinal)};
}
}
}
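For orientation, here is roughly what the generated code above computes, written as plain interpreted Scala (the function name, signature, and structure are illustrative only; the real path stays in codegen):

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types.{DataType, StructType}

// Sketch: extract one struct field from every element of an array of structs,
// resolving the ordinal by name against the element schema the child reports
// *now* (which may be a pruned schema), not the ordinal captured at analysis.
def extractField(
    input: ArrayData,
    elementSchema: StructType,
    fieldName: String,
    fieldType: DataType,
    fieldNullable: Boolean): ArrayData = {
  val ordinal = elementSchema.fieldIndex(fieldName)
  val numFields = elementSchema.size
  val values = new Array[Any](input.numElements())
  var i = 0
  while (i < input.numElements()) {
    if (!input.isNullAt(i)) {
      val row = input.getStruct(i, numFields)
      if (!(fieldNullable && row.isNullAt(ordinal))) {
        values(i) = row.get(ordinal, fieldType)
      }
    }
    i += 1
  }
  new GenericArrayData(values)
}
```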
@@ -407,15 +407,13 @@ object GeneratorNestedColumnAliasing {
attrToExtractValuesOnGenerator.flatMap(_._2).toSeq, Seq.empty,
collectNestedGetStructFields)

// Pruning on `Generator`'s output. We only process single field case.
// For multiple field case, we cannot directly move field extractor into
// the generator expression. A workaround is to re-construct array of struct
// from multiple fields. But it will be more complicated and may not worth.
// TODO(SPARK-34956): support multiple fields.
// SPARK-47230/SPARK-34956: Prune nested fields on the Generator's output.
// For a single field, we can push the field access directly into the generator.
// Multiple fields are not handled yet and fall through to the branch below.
val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
if (nestedFieldsOnGenerator.isEmpty) {
Some(pushedThrough)
} else {
} else if (nestedFieldsOnGenerator.size == 1) {
// Only one nested column accessor.
// E.g., df.select(explode($"items").as("item")).select($"item.a")
val nestedFieldOnGenerator = nestedFieldsOnGenerator.head
@@ -456,6 +454,11 @@
// We should not reach here.
throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
}
} else {
// TODO(SPARK-34956): Handle multiple nested fields on generator output.
// For now, we skip the optimization when there are multiple fields.
// The `rewritePlanWithAliases` approach doesn't work because the generator's output attributes are fixed.
Some(pushedThrough)
}

case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
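The query shapes the single-field and multi-field branches above distinguish can be reproduced in a few lines (assumed local session and toy data, for illustration): one nested field read from an exploded array of structs is eligible for pushing the extraction into the generator, while selecting two fields falls into the new multi-field branch and is currently left as-is.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
import spark.implicits._

val df = spark.range(1).selectExpr(
  "array(named_struct('a', 1, 'b', 'x'), named_struct('a', 2, 'b', 'y')) AS items")

// Single nested field on the generator output: the size == 1 branch applies.
df.select(explode($"items").as("item")).select($"item.a").explain(true)

// Two nested fields: handled by the new else branch (optimization skipped).
df.select(explode($"items").as("item")).select($"item.a", $"item.b").explain(true)
```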
@@ -63,7 +63,20 @@ case class GenerateExec(
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {

override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput
override def output: Seq[Attribute] = requiredChildOutput ++ correctedGeneratorOutput

/**
* SPARK-47230: Create corrected generator output attributes with data types from the
* bound generator's element schema. This ensures projections use the correct ordinals
* after schema pruning.
*/
private lazy val correctedGeneratorOutput: Seq[Attribute] = {
val elementFields = boundGenerator.elementSchema.fields
generatorOutput.zip(elementFields).map { case (attr, field) =>
// Create a new attribute with the correct data type from the generator's schema
attr.withDataType(field.dataType)
}
}

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -77,6 +90,7 @@
protected override def doExecute(): RDD[InternalRow] = {
// boundGenerator.terminate() should be triggered after all of the rows in the partition
val numOutputRows = longMetric("numOutputRows")

child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
boundGenerator.foreach {
case n: Nondeterministic => n.initialize(index)
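A standalone sketch of the correction performed by `correctedGeneratorOutput` above (the attribute and types are made up): the attribute created at analysis time still carries the full struct type, while the bound generator's element schema may have been pruned, so the attribute's dataType is rewritten to follow it.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types._

val analyzedType = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType),
  StructField("c", LongType)))
val prunedType = StructType(Seq(StructField("a", IntegerType)))

val attr = AttributeReference("item", analyzedType, nullable = true)()
// Same name and exprId, but the dataType now matches the pruned element schema,
// so downstream GetStructField ordinals resolve against the right struct.
val corrected = attr.withDataType(prunedType)
```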
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.{GeneratorOrdinalRewriting, PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -36,7 +36,9 @@ class SparkOptimizer(

override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
Seq(SchemaPruning) :+
// SPARK-47230: GeneratorOrdinalRewriting must run IMMEDIATELY AFTER SchemaPruning
// to fix ordinals before other optimizer rules transform the plan
Seq(SchemaPruning, GeneratorOrdinalRewriting) :+
GroupBasedRowLevelOperationScanPlanning :+
V1Writes :+
V2ScanRelationPushDown :+