Skip to content

Commit 791f7ce

Browse files
fix(spark): enable aliased expressions to round-trip (#348)
A number of the TPC-DS tests were failing because their queries contain multiple aliases to the same expression, causing a potential mismatch in the reference index. Although the plans were equivalent, the Substrait POJO comparison failed. This commit uses the `hint` field of the Rel message to store the alias names, and restores them in the Spark plan so that it matches the original.
1 parent 2ce3501 commit 791f7ce

File tree

4 files changed

+39
-10
lines changed

4 files changed

+39
-10
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ protected Project newProject(ProjectRel rel) {
423423

424424
builder
425425
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
426-
.remap(optionalRelmap(rel.getCommon()));
426+
.remap(optionalRelmap(rel.getCommon()))
427+
.hint(optionalHint(rel.getCommon()));
427428
if (rel.hasAdvancedExtension()) {
428429
builder.extension(advancedExtension(rel.getAdvancedExtension()));
429430
}

spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package io.substrait.spark.logical
1818

19-
import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSparkType}
19+
import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSparkType, ToSubstraitType}
2020
import io.substrait.spark.expression._
2121

2222
import org.apache.spark.sql.SparkSession
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
3333
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
3434
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
3535
import org.apache.spark.sql.internal.SQLConf
36-
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, IntegerType, MapType, StructField, StructType}
36+
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
3737

3838
import io.substrait.`type`.{NamedStruct, StringTypeVisitor, Type}
3939
import io.substrait.{expression => exp}
@@ -244,18 +244,43 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
244244
}
245245
}
246246

247+
/**
248+
* Returns the top level field (column) names for the given relation, if they have been specified
249+
* in the optional `hint` message. Does not include the field names of any inner structs.
250+
* @param rel
251+
* @return
252+
* Optional list of names.
253+
*/
254+
private def fieldNames(rel: relation.Rel): Option[Seq[String]] = {
255+
if (rel.getHint.isPresent && !rel.getHint.get().getOutputNames.isEmpty) {
256+
Some(
257+
ToSubstraitType
258+
.toNamedStruct(ToSparkType.toStructType(
259+
NamedStruct.of(rel.getHint.get.getOutputNames, rel.getRecordType)))
260+
.names
261+
.asScala)
262+
} else {
263+
None
264+
}
265+
}
266+
247267
override def visit(project: relation.Project): LogicalPlan = {
248268
val child = project.getInput.accept(this)
249269
val (output, createProject) = child match {
250270
case a: Aggregate => (a.aggregateExpressions, false)
251271
case other => (other.output, true)
252272
}
273+
val names = fieldNames(project).getOrElse(List.empty)
253274

254275
withOutput(output) {
255-
val projectList =
276+
val projectExprs =
256277
project.getExpressions.asScala
257278
.map(expr => expr.accept(expressionConverter))
258-
.map(toNamedExpression)
279+
val projectList = if (names.size == projectExprs.size) {
280+
projectExprs.zip(names).map { case (expr, name) => Alias(expr, name)() }
281+
} else {
282+
projectExprs.map(toNamedExpression)
283+
}
259284
if (createProject) {
260285
Project(projectList, child)
261286
} else {
@@ -267,7 +292,9 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
267292

268293
override def visit(expand: relation.Expand): LogicalPlan = {
269294
val child = expand.getInput.accept(this)
270-
val names = expand.getHint.get().getOutputNames.asScala
295+
val names = fieldNames(expand).getOrElse(
296+
expand.getFields.asScala.zipWithIndex.map { case (_, i) => s"col$i" }
297+
)
271298

272299
withChild(child) {
273300
val projections = expand.getFields.asScala

spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,12 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
173173
}
174174
}
175175
val projects = projectExpressions.map(toExpression(newOutput))
176+
val names = ToSubstraitType.toNamedStruct(agg.schema).names()
176177

177178
relation.Project.builder
178179
.remap(relation.Rel.Remap.offset(newOutput.size, projects.size))
179180
.expressions(projects.asJava)
181+
.hint(Hint.builder.addAllOutputNames(names).build())
180182
.input(substraitAgg)
181183
.build()
182184
}
@@ -340,9 +342,11 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
340342
p.child.output.count(o => !existenceJoins.contains(o.exprId.id)),
341343
expressions.size
342344
)
345+
343346
relation.Project.builder
344347
.remap(remap)
345348
.expressions(expressions.asJava)
349+
.hint(Hint.builder.addAllOutputNames(ToSubstraitType.toNamedStruct(p.schema).names()).build())
346350
.input(child)
347351
.build()
348352
}
@@ -357,11 +361,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
357361
.build()
358362
})
359363

360-
val names = p.output.map(_.name)
361-
362364
relation.Expand.builder
363365
.fields(fields.asJava)
364-
.hint(Hint.builder.addAllOutputNames(names.asJava).build())
366+
.hint(Hint.builder.addAllOutputNames(ToSubstraitType.toNamedStruct(p.schema).names()).build())
365367
.input(visit(p.child))
366368
.build()
367369
}

spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
3333

3434
// spotless:off
3535
val failingSQL: Set[String] = Set(
36-
"q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait)
3736
"q72" //requires implementation of date_add()
3837
)
3938
// spotless:on

0 commit comments

Comments
 (0)