From 5e83f0dce9a6c0f32e354082dac56bcdd9b8c57c Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 23 May 2025 13:12:34 -0700 Subject: [PATCH 01/32] 1 --- .../resources/error/error-conditions.json | 32 + project/SparkBuild.scala | 2 + sql/pipelines/pom.xml | 196 +++++- .../spark/sql/pipelines/AnalysisWarning.scala | 33 + .../apache/spark/sql/pipelines/Language.scala | 24 + .../graph/CoreDataflowNodeProcessor.scala | 232 +++++++ .../sql/pipelines/graph/DataflowGraph.scala | 251 +++++++ .../graph/DataflowGraphTransformer.scala | 374 ++++++++++ .../spark/sql/pipelines/graph/Flow.scala | 226 ++++++ .../sql/pipelines/graph/FlowAnalysis.scala | 316 +++++++++ .../pipelines/graph/FlowAnalysisContext.scala | 74 ++ .../graph/GraphElementTypeUtils.scala | 37 + .../sql/pipelines/graph/GraphErrors.scala | 123 ++++ .../graph/GraphIdentifierManager.scala | 345 ++++++++++ .../sql/pipelines/graph/GraphOperations.scala | 177 +++++ .../graph/GraphRegistrationContext.scala | 211 ++++++ .../pipelines/graph/GraphValidations.scala | 275 ++++++++ .../sql/pipelines/graph/PipelinesErrors.scala | 153 +++++ .../graph/PipelinesTableProperties.scala | 112 +++ .../sql/pipelines/graph/QueryOrigin.scala | 153 +++++ .../sql/pipelines/graph/ViewHelpers.scala | 39 ++ .../spark/sql/pipelines/graph/elements.scala | 277 ++++++++ .../sql/pipelines/util/InputReadInfo.scala | 56 ++ .../pipelines/util/SchemaInferenceUtils.scala | 167 +++++ .../pipelines/util/SchemaMergingUtils.scala | 26 + .../graph/ConnectInvalidPipelineSuite.scala | 466 +++++++++++++ .../graph/ConnectValidPipelineSuite.scala | 589 ++++++++++++++++ .../sql/pipelines/utils/PipelineTest.scala | 643 ++++++++++++++++++ .../pipelines/utils/SparkErrorTestMixin.scala | 105 +++ .../utils/TestGraphRegistrationContext.scala | 225 ++++++ 30 files changed, 5938 insertions(+), 1 deletion(-) create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala create mode 100644 
sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8dddca5077a67..4b38ac8954f67 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1322,6 +1322,12 @@ ], "sqlState" : "42713" }, + "DUPLICATE_FLOW_SQL_CONF": { + "message": [ + "Found duplicate sql conf for dataset '': '' is defined by both '' and ''" + ], + "sqlState": "42710" + }, "DUPLICATED_MAP_KEY" : { "message" : [ "Duplicate map key was found, please check the input data.", @@ -2019,6 +2025,18 @@ ], "sqlState" : "42613" }, + "INCOMPATIBLE_BATCH_VIEW_READ": { + "message": [ + "View is not a streaming view and must be referenced using read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + ], + "sqlState": "42000" + }, + "INCOMPATIBLE_STREAMING_VIEW_READ": { + "message": [ + "View is a streaming view and must be referenced using readStream. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + ], + "sqlState": "42000" + }, "INCOMPATIBLE_VIEW_SCHEMA_CHANGE" : { "message" : [ "The SQL query of view has an incompatible schema change and column cannot be resolved. 
Expected columns named but got .", @@ -6571,6 +6589,20 @@ ], "sqlState" : "P0001" }, + "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE": { + "message": [ + "Table '' has a user-specified schema that is incompatible with the schema", + " inferred from its query.", + "", + "", + "Declared schema:", + "", + "", + "Inferred schema:", + "" + ], + "sqlState": "42000" + }, "VARIABLE_ALREADY_EXISTS" : { "message" : [ "Cannot create the variable because it already exists.", diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 77001e6bdf227..cdf87dcc142e4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -418,6 +418,8 @@ object SparkBuild extends PomBuild { enable(HiveThriftServer.settings)(hiveThriftServer) + enable(SparkConnect.settings)(pipelines) + enable(SparkConnectCommon.settings)(connectCommon) enable(SparkConnect.settings)(connect) enable(SparkConnectClient.settings)(connectClient) diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 7d796a83af69d..7d791e57bd000 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -16,7 +16,8 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark @@ -49,9 +50,202 @@ test-jar test + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + com.google.protobuf + protobuf-java-util + compile + + + + com.google.guava + guava + ${connect.guava.version} + compile + + + com.google.guava + failureaccess + ${guava.failureaccess.version} + compile + + + + + io.grpc + grpc-stub + ${io.grpc.version} + + + io.grpc + grpc-netty + ${io.grpc.version} + + + io.grpc + grpc-protobuf + ${io.grpc.version} + + + io.grpc + grpc-services + ${io.grpc.version} + + + + org.scala-lang.modules + scala-parser-combinators_${scala.binary.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + + + + org.scala-lang.modules + scala-parallel-collections_${scala.binary.version} + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + kr.motd.maven + os-maven-plugin + 1.7.0 + true + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} + + src/main/protobuf + grpc-java + io.grpc:protoc-gen-grpc-java:${io.grpc.version}:exe:${os.detected.classifier} + + + + + generate-sources + + compile + compile-custom + test-compile + + + + + + + net.alchim31.maven + scala-maven-plugin + 4.9.2 + + + + compile + testCompile + + + + + + diff 
--git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala new file mode 100644 index 0000000000000..35b8185c255e1 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines + +/** Represents a warning generated as part of graph analysis. */ +sealed trait AnalysisWarning + +object AnalysisWarning { + + /** + * Warning that some streaming reader options are being dropped + * + * @param sourceName Source for which reader options are being dropped. + * @param droppedOptions Set of reader options that are being dropped for a specific source. + */ + case class StreamingReaderOptionsDropped(sourceName: String, droppedOptions: Seq[String]) + extends AnalysisWarning +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala new file mode 100644 index 0000000000000..77ca6f84b7fbc --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines + +sealed trait Language {} + +case class Python() extends Language {} + +case class Sql() extends Language {} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala new file mode 100644 index 0000000000000..e1400331f4148 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ + TransformNodeFailedException, + TransformNodeRetryableException +} + +/** + * Core processor that is responsible for analyzing each flow and sort the nodes in + * topological order + */ +class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { + + private val flowResolver = new FlowResolver(rawGraph) + + // Map of input identifier to resolved [[Input]]. + private val resolvedInputs = new ConcurrentHashMap[TableIdentifier, Input]() + // Map & queue of resolved flows identifiers + // queue is there to track the topological order while map is used to store the id -> flow + // mapping + private val resolvedFlowNodesMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]() + private val resolvedFlowNodesQueue = new ConcurrentLinkedQueue[ResolvedFlow]() + + private def processUnresolvedFlow(flow: UnresolvedFlow): ResolvedFlow = { + val resolvedFlow = flowResolver.attemptResolveFlow( + flow, + rawGraph.inputIdentifiers, + resolvedInputs.asScala.toMap + ) + resolvedFlowNodesQueue.add(resolvedFlow) + resolvedFlowNodesMap.put(flow.identifier, resolvedFlow) + resolvedFlow + } + + /** + * Processes the node of the graph, re-arranging them if they are not topologically sorted. + * Takes care of resolving the flows and virtualization if needed for the nodes. + * @param node The node to process + * @param upstreamNodes Upstream nodes for the node + * @return + */ + def processNode(node: GraphElement, upstreamNodes: Seq[GraphElement]): Seq[GraphElement] = { + node match { + case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow)) + case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow)) + case table: Table => + // Ensure all upstreamNodes for a table are flows + val flowsToTable = upstreamNodes.map { + case flow: Flow => flow + case _ => + throw new IllegalArgumentException( + s"Unsupported upstream node type for table ${table.displayName}: " + + s"${upstreamNodes.getClass}" + ) + } + val resolvedFlowsToTable = flowsToTable.map { flow => + resolvedFlowNodesMap.get(flow.identifier) + } + + // Assign isStreamingTable (MV or ST) to the table based on the resolvedFlowsToTable + val tableWithType = table.copy( + isStreamingTableOpt = Option(resolvedFlowsToTable.exists(f => f.df.isStreaming)) + ) + + // Table will be virtual in either of the following scenarios: + // 1. If table is present in context.fullRefreshTables + // 2. If table has any virtual inputs (flows or tables) + // 3. 
If the table pre-existing metadata is different from current metadata + val virtualTableInput = VirtualTableInput( + identifier = table.identifier, + specifiedSchema = table.specifiedSchema, + incomingFlowIdentifiers = flowsToTable.map(_.identifier).toSet, + availableFlows = resolvedFlowsToTable + ) + resolvedInputs.put(table.identifier, virtualTableInput) + Seq(tableWithType) + case view: View => + // For view, add the flow to resolvedInputs and return empty. + require(upstreamNodes.size == 1, "Found multiple flows to view") + upstreamNodes.head match { + case f: Flow => + resolvedInputs.put(view.identifier, resolvedFlowNodesMap.get(f.destinationIdentifier)) + Seq(view) + case _ => + throw new IllegalArgumentException( + s"Unsupported upstream node type for view ${view.displayName}: " + + s"${upstreamNodes.getClass}" + ) + } + case _ => + throw new IllegalArgumentException(s"Unsupported node type: ${node.getClass}") + } + } +} + +private class FlowResolver(rawGraph: DataflowGraph) extends Logging { + + /** Helper used to track which confs were set by which flows. */ + private case class FlowConf(key: String, value: String, flowIdentifier: TableIdentifier) + + /** Attempts resolving a single flow using the map of resolved inputs. */ + def attemptResolveFlow( + flowToResolve: UnresolvedFlow, + allInputs: Set[TableIdentifier], + availableResolvedInputs: Map[TableIdentifier, Input]): ResolvedFlow = { + val flowFunctionResult = flowToResolve.func.call( + allInputs = allInputs, + availableInputs = availableResolvedInputs.values.toList, + configuration = flowToResolve.sqlConf, + currentCatalog = flowToResolve.currentCatalog, + currentDatabase = flowToResolve.currentDatabase + ) + val result = + flowFunctionResult match { + case f if f.dataFrame.isSuccess => + // Merge the flow's inputs' confs into confs for this flow, throwing if any conflict + // But do not merge confs from inputs that are tables; we don't want to propagate confs + // past materialization points. + val allFConfs = + (flowToResolve +: + f.inputs.toSeq + .map(availableResolvedInputs(_)) + .filter { + // Input is a flow implies that the upstream table is a View. + case _: Flow => true + // We stop in all other cases. + case _ => false + }).collect { + case g: Flow => + g.sqlConf.toSeq.map { case (k, v) => FlowConf(k, v, g.identifier) } + }.flatten + + allFConfs + .groupBy(_.key) // Key name -> Seq[FlowConf] + .filter(_._2.length > 1) // Entries where key was set more than once + .find(_._2.map(_.value).toSet.size > 1) // Entry where key was set with diff vals + .foreach { + case (key, confs) => + val sortedByVal = confs.sortBy(_.value) + throw new AnalysisException( + "DUPLICATE_FLOW_SQL_CONF", + Map( + "key" -> key, + "datasetName" -> flowToResolve.displayName, + "flowName1" -> sortedByVal.head.flowIdentifier.unquotedString, + "flowName2" -> sortedByVal.last.flowIdentifier.unquotedString + ) + ) + } + + val newSqlConf = allFConfs.map(fc => fc.key -> fc.value).toMap + // if the new sql confs are different from the original sql confs the flow was resolved + // with, resolve again. 
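          /*
           * A minimal, self-contained sketch (hypothetical flow names) of the duplicate-conf check
           * above: confs collected from a flow and its non-materialized upstream flows are grouped
           * by key, and DUPLICATE_FLOW_SQL_CONF is raised only when one key is bound to two
           * different values by different flows.
           *
           *   case class ConfEntry(key: String, value: String, flow: String)
           *   val gathered = Seq(
           *     ConfEntry("spark.sql.shuffle.partitions", "10", "flow_a"),
           *     ConfEntry("spark.sql.shuffle.partitions", "20", "flow_b"),
           *     ConfEntry("spark.sql.ansi.enabled", "true", "flow_a"),
           *     ConfEntry("spark.sql.ansi.enabled", "true", "flow_b"))
           *   val conflicting = gathered
           *     .groupBy(_.key)                           // key -> every entry that sets it
           *     .filter(_._2.map(_.value).toSet.size > 1) // keep only keys bound to differing values
           *   assert(conflicting.keySet == Set("spark.sql.shuffle.partitions"))
           */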
+ val maybeNewFuncResult = if (newSqlConf != flowToResolve.sqlConf) { + flowToResolve.func.call( + allInputs = allInputs, + availableInputs = availableResolvedInputs.values.toList, + configuration = newSqlConf, + currentCatalog = flowToResolve.currentCatalog, + currentDatabase = flowToResolve.currentDatabase + ) + } else { + f + } + convertResolvedToTypedFlow(flowToResolve, maybeNewFuncResult) + + // If flow failed due to unresolved dataset, throw a retryable exception, otherwise just + // return the failed flow. + case f => + f.dataFrame.failed.toOption.collectFirst { + case e: UnresolvedDatasetException => e + case _ => None + } match { + case Some(e: UnresolvedDatasetException) => + throw TransformNodeRetryableException( + e.identifier, + new ResolutionFailedFlow(flowToResolve, flowFunctionResult) + ) + case _ => + throw TransformNodeFailedException( + new ResolutionFailedFlow(flowToResolve, flowFunctionResult) + ) + } + } + result + } + + private def convertResolvedToTypedFlow( + flow: UnresolvedFlow, + funcResult: FlowFunctionResult): ResolvedFlow = { + val typedFlow = flow match { + case f: UnresolvedFlow if f.once => new AppendOnceFlow(flow, funcResult) + case f: UnresolvedFlow if funcResult.dataFrame.get.isStreaming => + // If there's more than 1 flow to this flow's destination, we should not allow it + // to be planned with an output mode other than Append, as the other flows will + // then get their results overwritten. + val mustBeAppend = rawGraph.flowsTo(f.destinationIdentifier).size > 1 + new StreamingFlow(flow, funcResult, mustBeAppend = mustBeAppend) + case _: UnresolvedFlow => new CompleteFlow(flow, funcResult) + } + if (!funcResult.resolved) { + logError(s"Failed to resolve ${flow.displayName}: ${funcResult.failure.mkString("\n\n\n")}") + } else { + logInfo(s"Successfully resolved ${flow.displayName}") + } + typedFlow + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala new file mode 100644 index 0000000000000..958dab2b647a8 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.Try + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique +import org.apache.spark.sql.pipelines.util.SchemaMergingUtils +import org.apache.spark.sql.types.StructType + +/** + * DataflowGraph represents the core graph structure for Spark declarative pipelines. 
+ * It manages the relationships between logical flows, tables, and views, providing + * operations for graph traversal, validation, and transformation. + */ +class DataflowGraph(val flows: Seq[Flow], val tables: Seq[Table], val views: Seq[View]) + extends GraphOperations + with GraphValidations { + + /** Returns a [[Output]] given its identifier */ + lazy val output: Map[TableIdentifier, Output] = mapUnique(tables, "output")(_.identifier) + + /** + * Returns a [[TableInput]], if one is available, that can be read from by downstream flows. + */ + def tableInput(identifier: TableIdentifier): Option[TableInput] = table.get(identifier) + + /** + * Returns [[Flow]]s in this graph that need to get planned and potentially executed when + * executing the graph. Flows that write to logical views are excluded. + */ + lazy val materializedFlows: Seq[ResolvedFlow] = { + resolvedFlows.filter( + f => output.contains(f.destinationIdentifier) + ) + } + + /** Returns the identifiers of [[materializedFlows]]. */ + val materializedFlowIdentifiers: Set[TableIdentifier] = materializedFlows.map(_.identifier).toSet + + /** Returns a [[Table]] given its identifier */ + lazy val table: Map[TableIdentifier, Table] = + mapUnique(tables, "table")(_.identifier) + + /** Returns a [[Flow]] given its identifier */ + lazy val flow: Map[TableIdentifier, Flow] = { + // Better error message than using mapUnique. + val flowsByIdentifier = flows.groupBy(_.identifier) + flowsByIdentifier + .find(_._2.size > 1) + .foreach { + case (flowIdentifier, flows) => + // We don't expect this to ever actually be hit, graph registration should validate for + // unique flow names. + throw new AnalysisException( + errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW", + messageParameters = Map( + "flowName" -> flowIdentifier.unquotedString, + "datasetNames" -> flows.map(_.destinationIdentifier).mkString(",") + ) + ) + } + // Flows with non-default names shouldn't conflict with table names + flows + .filterNot(f => f.identifier == f.destinationIdentifier) + .filter(f => table.contains(f.identifier)) + .foreach { f => + throw new AnalysisException( + "FLOW_NAME_CONFLICTS_WITH_TABLE", + Map( + "flowName" -> f.identifier.toString(), + "target" -> f.destinationIdentifier.toString(), + "tableName" -> f.identifier.toString() + ) + ) + } + flowsByIdentifier.view.mapValues(_.head).toMap + } + + /** Returns a [[View]] given its identifier */ + lazy val view: Map[TableIdentifier, View] = mapUnique(views, "view")(_.identifier) + + /** Returns the [[PersistedView]]s of the graph */ + lazy val persistedViews: Seq[PersistedView] = views.collect { + case v: PersistedView => v + } + + /** Returns all the [[Input]]s in the current DataflowGraph. */ + lazy val inputIdentifiers: Set[TableIdentifier] = { + (flows ++ tables).map(_.identifier).toSet + } + + /** Returns the [[Flow]]s that write to a given destination. 
*/ + lazy val flowsTo: Map[TableIdentifier, Seq[Flow]] = flows.groupBy(_.destinationIdentifier) + + lazy val resolvedFlows: Seq[ResolvedFlow] = { + flows.collect { case f: ResolvedFlow => f } + } + + lazy val resolvedFlow: Map[TableIdentifier, ResolvedFlow] = { + resolvedFlows.map { f => + f.identifier -> f + }.toMap + } + + lazy val resolutionFailedFlows: Seq[ResolutionFailedFlow] = { + flows.collect { case f: ResolutionFailedFlow => f } + } + + lazy val resolutionFailedFlow: Map[TableIdentifier, ResolutionFailedFlow] = { + resolutionFailedFlows.map { f => + f.identifier -> f + }.toMap + } + + /** Returns a copy of this [[DataflowGraph]] with optionally replaced components. */ + def copy( + flows: Seq[Flow] = flows, + tables: Seq[Table] = tables, + views: Seq[View] = views): DataflowGraph = { + new DataflowGraph(flows, tables, views) + } + + /** + * Used to reanalyze the flow's DF for a given table. This is done by finding all upstream + * flows (until a table is reached) for the specified source and reanalyzing all upstream + * flows. + * + * @param srcFlow The flow that writes into the table that we will start from when finding + * upstream flows + * @return The reanalyzed flow + */ + protected[graph] def reanalyzeFlow(srcFlow: Flow): ResolvedFlow = { + val upstreamDatasetIdentifiers = dfsInternal( + flowNodes(srcFlow.identifier).output, + downstream = false, + stopAtMaterializationPoints = true + ) + val upstreamFlows = + resolvedFlows + .filter(f => upstreamDatasetIdentifiers.contains(f.destinationIdentifier)) + .map(_.flow) + val upstreamViews = upstreamDatasetIdentifiers.flatMap(identifier => view.get(identifier)).toSeq + + val subgraph = new DataflowGraph( + flows = upstreamFlows, + views = upstreamViews, + tables = Seq(table(srcFlow.destinationIdentifier)) + ) + subgraph.resolve().resolvedFlow(srcFlow.identifier) + } + + /** + * Returns a map of the inferred schema of each table, computed by merging the analyzed schemas + * of all flows writing to that table. + */ + lazy val inferredSchema: Map[TableIdentifier, StructType] = { + flowsTo.view.mapValues { flows => + flows + .map { flow => + resolvedFlow(flow.identifier).schema + } + .reduce(SchemaMergingUtils.mergeSchemas) + }.toMap + } + + /** Ensure that the [[DataflowGraph]] is valid and throws errors if not. */ + def validate(): DataflowGraph = { + validationFailure.toOption match { + case Some(exception) => throw exception + case None => this + } + } + + /** + * Validate the current [[DataflowGraph]] and cache the validation failure. + * + * To add more validations, add them in a helper function that throws an exception if the + * validation fails, and invoke the helper function here. + */ + private lazy val validationFailure: Try[Throwable] = Try { + validateSuccessfulFlowAnalysis() + validateUserSpecifiedSchemas() + // Connecting the graph sorts it topologically + validateGraphIsTopologicallySorted() + validateMultiQueryTables() + validatePersistedViewSources() + validateEveryDatasetHasFlow() + validateTablesAreResettable() + validateAppendOnceFlows() + inferredSchema + }.failed + + /** Enforce every dataset has at least once input flow. For example its possible to define + * streaming tables without a query; such tables should still have at least one flow + * writing to it. 
*/ + def validateEveryDatasetHasFlow(): Unit = { + (tables.map(_.identifier) ++ views.map(_.identifier)).foreach { identifier => + if (!flows.exists(_.destinationIdentifier == identifier)) { + throw new AnalysisException( + "PIPELINE_DATASET_WITHOUT_FLOW", + Map("identifier" -> identifier.quotedString) + ) + } + } + } + + /** Returns true iff all [[Flow]]s are successfully analyzed. */ + def resolved: Boolean = + flows.forall(f => resolvedFlow.contains(f.identifier)) + + def resolve(): DataflowGraph = + DataflowGraphTransformer.withDataflowGraphTransformer(this) { transformer => + val coreDataflowNodeProcessor = + new CoreDataflowNodeProcessor(rawGraph = this) + transformer + .transformDownNodes(coreDataflowNodeProcessor.processNode) + .getDataflowGraph + } +} + +object DataflowGraph { + protected[graph] def mapUnique[K, A](input: Seq[A], tpe: String)(f: A => K): Map[K, A] = { + val grouped = input.groupBy(f) + grouped.filter(_._2.length > 1).foreach { + case (name, _) => + throw new AnalysisException( + errorClass = "DUPLICATE_GRAPH_ELEMENT", + messageParameters = Map("graphElementType" -> tpe, "graphElementName" -> name.toString) + ) + } + grouped.view.mapValues(_.head).toMap + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala new file mode 100644 index 0000000000000..5ae3b4e9a389c --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.{ + ConcurrentHashMap, + ConcurrentLinkedDeque, + ConcurrentLinkedQueue, + ExecutionException, + Future +} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.util.control.NoStackTrace + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.util.ThreadUtils + +/** + * Resolves the [[DataflowGraph]] by processing each node in the graph. This class exposes visitor + * functionality to resolve/analyze graph nodes. + * We only expose simple visitor abilities to transform different entities of the + * graph. + * For advanced transformations we also expose a mechanism to walk the graph over entity by entity. + * + * Assumptions: + * 1. Each output will have at-least 1 flow to it. + * 2. Each flow may or may not have a destination table. If a flow does not have a destination + * table, the destination is a view. + * + * The way graph is structured is that flows, tables and sinks all are graph elements or nodes. 
+ * While we expose transformation functions for each of these entities, we also expose a way to + * process to walk over the graph. + * + * Constructor is private as all usages should be via + * DataflowGraphTransformer.withDataflowGraphTransformer. + * @param graph: Any Dataflow Graph + */ +class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { + import DataflowGraphTransformer._ + + private var tables: Seq[Table] = graph.tables + private var tableMap: Map[TableIdentifier, Table] = computeTableMap() + private var flows: Seq[Flow] = graph.flows + private var flowsTo: Map[TableIdentifier, Seq[Flow]] = computeFlowsTo() + private var views: Seq[View] = graph.views + private var viewMap: Map[TableIdentifier, View] = computeViewMap() + + // Fail analysis nodes + // Failed flows are flows that are failed to resolve or its inputs are not available or its + // destination failed to resolve. + private var failedFlows: Seq[ResolutionCompletedFlow] = Seq.empty + // We define a dataset is failed to resolve if: + // 1. It is a destination of a flow that is unresolved. + private var failedTables: Seq[Table] = Seq.empty + + private val parallelism = 10 + + // Executor used to resolve nodes in parallel. It is lazily initialized to avoid creating it + // for scenarios its not required. To track if the lazy val was evaluated or not we use a + // separate variable so we know if we need to shutdown the executor or not. + private var fixedPoolExecutorInitialized = false + lazy private val fixedPoolExecutor = { + fixedPoolExecutorInitialized = true + ThreadUtils.newDaemonFixedThreadPool( + parallelism, + prefix = "data-flow-graph-transformer-" + ) + } + private val selfExecutor = ThreadUtils.sameThreadExecutorService() + + private def computeTableMap(): Map[TableIdentifier, Table] = synchronized { + tables.map(table => table.identifier -> table).toMap + } + + private def computeViewMap(): Map[TableIdentifier, View] = synchronized { + views.map(view => view.identifier -> view).toMap + } + + private def computeFlowsTo(): Map[TableIdentifier, Seq[Flow]] = synchronized { + flows.groupBy(_.destinationIdentifier) + } + + def transformTables(transformer: Table => Table): DataflowGraphTransformer = { + tables = tables.map(transformer) + tableMap = computeTableMap() + this + } + + private def defaultOnFailedDependentTables( + failedTableDependencies: Map[TableIdentifier, Seq[Table]]): Unit = { + require( + failedTableDependencies.isEmpty, + "Dependency failure happened and some tables were not resolved" + ) + } + + /** + * Example graph: [Flow1, Flow 2] -> ST -> Flow3 -> MV + * Order of processing: Flow1, Flow2, ST, Flow3, MV. + * @param transformer function that transforms any graph entity. 
+ * transformer( + * nodeToTransform: GraphElement, upstreamNodes: Seq[GraphElement] + * ) => transformedNodes: Seq[GraphElement] + * @return this + */ + def transformDownNodes( + transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], + disableParallelism: Boolean = false): DataflowGraphTransformer = { + val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor + val batchSize = if (disableParallelism) 1 else parallelism + // List of resolved tables, sinks and flows + val resolvedFlows = new ConcurrentLinkedQueue[ResolutionCompletedFlow]() + val resolvedTables = new ConcurrentLinkedQueue[Table]() + val resolvedViews = new ConcurrentLinkedQueue[View]() + // Flow identifier to a list of transformed flows mapping to track resolved flows + val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, Seq[Flow]]() + val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]() + val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]() + val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]() + + var futures = ArrayBuffer[Future[Unit]]() + val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]() + toBeResolvedFlows.addAll(flows.asJava) + + while (futures.nonEmpty || toBeResolvedFlows.peekFirst() != null) { + val (done, notDone) = futures.partition(_.isDone) + // Explicitly call future.get() to propagate exceptions one by one if any + try { + done.foreach(_.get()) + } catch { + case exn: ExecutionException => + // Computation threw the exception that is the cause of exn + throw exn.getCause + } + futures = notDone + val flowOpt = { + // We only schedule [[batchSize]] number of flows in parallel. + if (futures.size < batchSize) { + Option(toBeResolvedFlows.pollFirst()) + } else { + None + } + } + if (flowOpt.isDefined) { + val flow = flowOpt.get + futures.append( + executor.submit( + () => + try { + try { + // Note: Flow don't need their inputs passed, so for now we send empty Seq. + val result = transformer(flow, Seq.empty) + require( + result.forall(_.isInstanceOf[ResolvedFlow]), + "transformer must return a Seq[Flow]" + ) + + val transformedFlows = result.map(_.asInstanceOf[ResolvedFlow]) + resolvedFlowsMap.put(flow.identifier, transformedFlows) + resolvedFlows.addAll(transformedFlows.asJava) + } catch { + case e: TransformNodeRetryableException => + val datasetIdentifier = e.datasetIdentifier + failedDependentFlows.compute( + datasetIdentifier, + (_, flows) => { + // Don't add the input flow back but the failed flow object + // back which has relevant failure information. + val failedFlow = e.failedNode + if (flows == null) { + Seq(failedFlow) + } else { + flows :+ failedFlow + } + } + ) + // Between the time the flow started and finished resolving, perhaps the + // dependent dataset was resolved + resolvedFlowDestinationsMap.computeIfPresent( + datasetIdentifier, + (_, resolved) => { + if (resolved) { + // Check if the dataset that the flow is dependent on has been resolved + // and if so, remove all dependent flows from the failedDependentFlows and + // add them to the toBeResolvedFlows queue for retry. 
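                        /*
                         * A condensed, self-contained sketch (hypothetical names, single writer per
                         * dataset) of the retry scheme used in this method: a flow that reads a
                         * not-yet-resolved dataset is parked under that dataset and re-queued once
                         * the dataset resolves, so resolution converges without a precomputed order.
                         *
                         *   import scala.collection.mutable
                         *   val writes = Map("flow1" -> "st", "flow3" -> "mv")   // flow -> dataset it writes
                         *   val reads  = Map("flow1" -> Set.empty[String], "flow3" -> Set("st"))
                         *   val resolvedDatasets = mutable.Set.empty[String]
                         *   val resolutionOrder  = mutable.ListBuffer.empty[String]
                         *   val queue  = mutable.Queue("flow3", "flow1")         // deliberately unsorted
                         *   val parked = mutable.Map.empty[String, List[String]] // missing dataset -> waiters
                         *   while (queue.nonEmpty) {
                         *     val flow = queue.dequeue()
                         *     reads(flow).find(d => !resolvedDatasets.contains(d)) match {
                         *       case Some(missing) =>                            // ~ TransformNodeRetryableException
                         *         parked(missing) = flow :: parked.getOrElse(missing, Nil)
                         *       case None =>
                         *         resolutionOrder += flow
                         *         resolvedDatasets += writes(flow)
                         *         parked.remove(writes(flow)).getOrElse(Nil).foreach(queue.enqueue(_))
                         *     }
                         *   }
                         *   assert(resolutionOrder.toList == List("flow1", "flow3"))
                         */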
+ failedDependentFlows.computeIfPresent( + datasetIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) + null + } + ) + } + resolved + } + ) + case other: Throwable => throw other + } + // If all flows to this particular destination are resolved, move to the destination + // node transformer + if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination => + resolvedFlowsMap.containsKey(flowToDestination.identifier) + })) { + // If multiple flows completed in parallel, ensure we resolve the destination only + // once by electing a leader via computeIfAbsent + var isCurrentThreadLeader = false + resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => { + isCurrentThreadLeader = true + // Set initial value as false as flow destination is not resolved yet. + false + }) + if (isCurrentThreadLeader) { + if (tableMap.contains(flow.destinationIdentifier)) { + val transformed = + transformer( + tableMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + resolvedTables.addAll( + transformed.collect { case t: Table => t }.asJava + ) + resolvedFlows.addAll( + transformed.collect { case f: ResolvedFlow => f }.asJava + ) + } else { + if (viewMap.contains(flow.destinationIdentifier)) { + resolvedViews.addAll { + val transformed = + transformer( + viewMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + transformed.map(_.asInstanceOf[View]).asJava + } + } else { + throw new IllegalArgumentException( + s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" + + s" in flow: ${flow.displayName} at transformDownNodes" + ) + } + } + // Set flow destination as resolved now. + resolvedFlowDestinationsMap.computeIfPresent( + flow.destinationIdentifier, + (_, _) => { + // If there are any other node failures dependent on this destination, retry + // them + failedDependentFlows.computeIfPresent( + flow.destinationIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) + null + } + ) + true + } + ) + } + } + } catch { + case ex: TransformNodeFailedException => failedFlowsQueue.add(ex.failedNode) + } + ) + ) + } + } + + // Mutate the fail analysis entities + // A table is failed to analyze if: + // - It does not exist in the resolvedFlowDestinationsMap + failedTables = tables.filterNot { table => + resolvedFlowDestinationsMap.getOrDefault(table.identifier, false) + } + + // We maintain the topological sort order of successful flows always + val (resolvedFlowsWithResolvedDest, resolvedFlowsWithFailedDest) = + resolvedFlows.asScala.toSeq.partition(flow => { + resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false) + }) + + // A flow is failed to analyze if: + // - It is non-retryable + // - It is retryable but could not be retried, i.e. the dependent dataset is still unresolved + // - It might be resolvable but it writes into a destination that is failed to analyze + // To note: because we are transform down nodes, all downstream nodes of any pruned nodes + // will also be pruned + failedFlows = + // All transformed flows that write to a destination that is failed to analyze. 
+ resolvedFlowsWithFailedDest ++ + // All failed flows thrown by TransformNodeFailedException + failedFlowsQueue.asScala.toSeq ++ + // All flows that have not been transformed and resolved yet + failedDependentFlows.values().asScala.flatten.toSeq + + // Mutate the resolved entities + flows = resolvedFlowsWithResolvedDest + flowsTo = computeFlowsTo() + tables = resolvedTables.asScala.toSeq + views = resolvedViews.asScala.toSeq + tableMap = computeTableMap() + viewMap = computeViewMap() + this + } + + def getDataflowGraph: DataflowGraph = { + graph.copy( + // Returns all flows (resolved and failed) in topological order. + // The relative order between flows and failed flows doesn't matter here. + // For failed flows that were resolved but were marked failed due to destination failure, + // they will be front of the list in failedFlows and thus by definition topologically sorted + // in the combined sequence too. + flows = flows ++ failedFlows, + tables = tables ++ failedTables + ) + } + + override def close(): Unit = { + if (fixedPoolExecutorInitialized) { + fixedPoolExecutor.shutdown() + } + } +} + +object DataflowGraphTransformer { + + /** + * Exception thrown when a node in the graph fails to be transformed because at least one of its + * dependencies weren't yet transformed. + * + * @param datasetIdentifier The identifier for an untransformed dependency table identifier in the + * dataflow graph. + */ + case class TransformNodeRetryableException( + datasetIdentifier: TableIdentifier, + failedNode: ResolutionFailedFlow) + extends Exception + with NoStackTrace + + case class TransformNodeFailedException(failedNode: ResolutionFailedFlow) + extends Exception + with NoStackTrace + + /** + * Autocloseable wrapper around DataflowGraphTransformer to ensure that the transformer is closed + * without clients needing to remember to close it. It takes in the same arguments as + * [[DataflowGraphTransformer]] constructor. It exposes the DataflowGraphTransformer instance + * within the callable scope. + */ + def withDataflowGraphTransformer[T](graph: DataflowGraph)(f: DataflowGraphTransformer => T): T = { + val dataflowGraphTransformer = new DataflowGraphTransformer(graph) + try { + f(dataflowGraphTransformer) + } finally { + dataflowGraphTransformer.close() + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala new file mode 100644 index 0000000000000..f0a80962f0994 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.Try + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.classic.DataFrame +import org.apache.spark.sql.pipelines.AnalysisWarning +import org.apache.spark.sql.pipelines.util.InputReadOptions +import org.apache.spark.sql.types.StructType + +trait Flow extends GraphElement with Logging { + + /** The [[FlowFunction]] containing the user's query. */ + def func: FlowFunction + + val identifier: TableIdentifier + + /** + * The dataset that this Flow represents a write to. Since the DataflowGraph doesn't have a first- + * class concept of views, writing to a destination that isn't a Table or a Sink represents a + * view. + */ + val destinationIdentifier: TableIdentifier + + /** + * Whether this is a ONCE flow. ONCE flows should run only once per full refresh. + */ + def once: Boolean = false + + /** The current catalog in the execution context when the query is defined. */ + def currentCatalog: Option[String] + + /** The current database in the execution context when the query is defined. */ + def currentDatabase: Option[String] + + /** The comment associated with this flow */ + def comment: Option[String] + + def sqlConf: Map[String, String] +} + +/** A wrapper for a resolved internal input that includes the identifier used in SubqueryAlias */ +case class ResolvedInput(input: Input, aliasIdentifier: AliasIdentifier) + +/** A wrapper for the lambda function that defines a [[Flow]]. */ +trait FlowFunction extends Logging { + + /** + * This function defines the transformations performed by a flow, expressed as a [[DataFrame]]. + * + * @param allInputs the set of identifiers for all the [[Input]]s defined in the + * [[DataflowGraph]]. + * @param availableInputs the list of all [[Input]]s available to this flow + * @param configuration the spark configurations that apply to this flow. + * @param currentCatalog The current catalog in execution context when the query is defined. + * @param currentDatabase The current database in execution context when the query is defined. + * @return the inputs actually used, and the [[DataFrame]] expression for the flow + */ + def call( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + configuration: Map[String, String], + currentCatalog: Option[String], + currentDatabase: Option[String] + ): FlowFunctionResult +} + +/** + * Holds the [[DataFrame]] returned by a [[FlowFunction]] along with the inputs used to + * construct it. + * @param usedBatchInputs the identifiers of the complete inputs read by the flow + * @param usedStreamingInputs the identifiers of the incremental inputs read by the flow + * @param usedExternalInputs the identifiers of the external inputs read by the flow + * @param dataFrame the [[DataFrame]] expression executed by the flow if the flow can be resolved + */ +case class FlowFunctionResult( + requestedInputs: Set[TableIdentifier], + usedBatchInputs: Set[ResolvedInput], + usedStreamingInputs: Set[ResolvedInput], + usedExternalInputs: Set[String], + dataFrame: Try[DataFrame], + sqlConf: Map[String, String], + analysisWarnings: Seq[AnalysisWarning] = Nil) { + + /** + * Returns the names of all of the [[Input]]s used when resolving this [[Flow]]. If the + * flow failed to resolve, we return the all the datasets that were requested when evaluating the + * flow. 
+ */ + def inputs: Set[TableIdentifier] = { + (batchInputs ++ streamingInputs).map(_.input.identifier) + } + + /** Names of [[Input]]s read completely by this [[Flow]]. */ + def batchInputs: Set[ResolvedInput] = usedBatchInputs + + /** Names of [[Input]]s read incrementally by this [[Flow]]. */ + def streamingInputs: Set[ResolvedInput] = usedStreamingInputs + + /** Returns errors that occurred when attempting to analyze this [[Flow]]. */ + def failure: Seq[Throwable] = { + dataFrame.failed.toOption.toSeq + } + + /** Whether this [[Flow]] is successfully analyzed. */ + final def resolved: Boolean = failure.isEmpty // don't override this, override failure +} + +/** A [[Flow]] whose output schema and dependencies aren't known. */ +class UnresolvedFlow( + val identifier: TableIdentifier, + val destinationIdentifier: TableIdentifier, + val func: FlowFunction, + val currentCatalog: Option[String], + val currentDatabase: Option[String], + val sqlConf: Map[String, String], + val comment: Option[String] = None, + override val once: Boolean, + override val origin: QueryOrigin +) extends Flow { + def copy( + identifier: TableIdentifier = identifier, + destinationIdentifier: TableIdentifier = destinationIdentifier, + sqlConf: Map[String, String] = sqlConf + ): UnresolvedFlow = { + new UnresolvedFlow( + identifier = identifier, + destinationIdentifier = destinationIdentifier, + func = func, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase, + sqlConf = sqlConf, + comment = comment, + once = once, + origin = origin + ) + } +} + +/** + * A [[Flow]] whose flow function has been invoked, meaning either: + * - Its output schema and dependencies are known. + * - It failed to resolve. + */ +trait ResolutionCompletedFlow extends Flow { + def flow: UnresolvedFlow + def funcResult: FlowFunctionResult + + val identifier: TableIdentifier = flow.identifier + val destinationIdentifier: TableIdentifier = flow.destinationIdentifier + def func: FlowFunction = flow.func + def currentCatalog: Option[String] = flow.currentCatalog + def currentDatabase: Option[String] = flow.currentDatabase + def comment: Option[String] = flow.comment + def sqlConf: Map[String, String] = funcResult.sqlConf + def origin: QueryOrigin = flow.origin +} + +/** A [[Flow]] whose flow function has failed to resolve. */ +class ResolutionFailedFlow(val flow: UnresolvedFlow, val funcResult: FlowFunctionResult) + extends ResolutionCompletedFlow { + assert(!funcResult.resolved) + + def failure: Seq[Throwable] = funcResult.failure +} + +/** A [[Flow]] whose flow function has successfully resolved. */ +trait ResolvedFlow extends ResolutionCompletedFlow with Input { + assert(funcResult.resolved) + + /** The logical plan for this flow's query. */ + def df: DataFrame = funcResult.dataFrame.get + + /** Returns the schema of the output of this [[Flow]]. */ + def schema: StructType = df.schema + override def load(readOptions: InputReadOptions): DataFrame = df + def inputs: Set[TableIdentifier] = funcResult.inputs +} + +/** A [[Flow]] that represents stateful movement of data to some target. */ +class StreamingFlow( + val flow: UnresolvedFlow, + val funcResult: FlowFunctionResult, + val mustBeAppend: Boolean = false +) extends ResolvedFlow {} + +/** A [[Flow]] that declares exactly what data should be in the target table. 
*/ +class CompleteFlow( + val flow: UnresolvedFlow, + val funcResult: FlowFunctionResult, + val mustBeAppend: Boolean = false +) extends ResolvedFlow {} + +/** A [[Flow]] that reads source[s] completely and appends data to the target, just once. + */ +class AppendOnceFlow( + val flow: UnresolvedFlow, + val funcResult: FlowFunctionResult +) extends ResolvedFlow { + + /** + * Whether the flow was declared as once or not in UnresolvedFlow. If false, then it means the + * flow is created from batch query. + */ + val definedAsOnce: Boolean = flow.once + + override val once = true +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala new file mode 100644 index 0000000000000..57bb82fb77fad --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.Try + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession} +import org.apache.spark.sql.pipelines.{AnalysisWarning, Sql} +import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ + ExternalDatasetIdentifier, + InternalDatasetIdentifier +} +import org.apache.spark.sql.pipelines.util.{ + BatchReadOptions, + InputReadOptions, + StreamingReadOptions +} + +object FlowAnalysis { + def createFlowFunctionFromLogicalPlan(plan: LogicalPlan): FlowFunction = { + new FlowFunction { + override def call( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + confs: Map[String, String], + currentCatalog: Option[String], + currentDatabase: Option[String] + ): FlowFunctionResult = { + val ctx = FlowAnalysisContext( + allInputs = allInputs, + availableInputs = availableInputs, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase, + spark = SparkSession.active + ) + val df = try { + confs.foreach { case (k, v) => ctx.setConf(k, v) } + Try(FlowAnalysis.analyze(ctx, plan)) + } finally { + ctx.restoreOriginalConf() + } + FlowFunctionResult( + requestedInputs = ctx.requestedInputs.toSet, + usedBatchInputs = ctx.batchInputs.toSet, + usedStreamingInputs = ctx.streamingInputs.toSet, + usedExternalInputs = ctx.externalInputs.toSet, + dataFrame = df, + sqlConf = confs, + analysisWarnings = ctx.analysisWarnings.toList + ) + } + } + } + + /** + * Constructs an 
analyzed [[DataFrame]] from a [[LogicalPlan]] by resolving Pipelines specific + * TVFs and datasets that cannot be resolved directly by Catalyst. This is used in SQL pipelines + * and Python pipelines (when users call spark.sql). + * + * This function shouldn't call any singleton as it will break concurrent access to graph + * analysis; or any thread local variables as graph analysis and this function will use + * different threads in python repl. + * + * @param plan The [[LogicalPlan]] defining a flow. + * @param forAnalysisApi Whether the query is being analyzed through Analysis API algorithm. + * @return An analyzed [[DataFrame]]. + */ + private def analyze( + context: FlowAnalysisContext, + plan: LogicalPlan, + forAnalysisApi: Boolean = false + ): DataFrame = { + // Users can define CTEs within their CREATE statements. For example, + // + // CREATE STREAMING TABLE a + // WITH b AS ( + // SELECT * FROM STREAM upstream + // ) + // SELECT * FROM b + // + // The relation defined using the WITH keyword is not included in the children of the main + // plan so the specific analysis we do below will not be applied to those relations. + // Instead, we call an analyzer rule to inline all of the CTE relations in the main plan before + // we do analysis. This rule would be called during analysis anyways, but we just call it + // earlier so we only need to apply analysis to a single logical plan. + val planWithInlinedCTEs = CTESubstitution(plan) + + val spark = context.spark + // Traverse the user's query plan and recursively resolve nodes that reference Pipelines + // features that the Spark analyzer is unable to resolve + val resolvedPlan = planWithInlinedCTEs transformWithSubqueries { + // Streaming read on another dataset + // This branch will be hit for the following kinds of queries: + // - SELECT ... FROM STREAM(t1) + // - SELECT ... FROM STREAM t1 + case u: UnresolvedRelation if u.isStreaming => + readStreamInput( + context, + name = IdentifierHelper.toQuotedString(u.multipartIdentifier), + spark.readStream, + streamingReadOptions = StreamingReadOptions( + apiLanguage = Sql() + ) + ).queryExecution.analyzed + + // Batch read on another dataset in the pipeline + case u: UnresolvedRelation => + readBatchInput( + context, + name = IdentifierHelper.toQuotedString(u.multipartIdentifier), + batchReadOptions = BatchReadOptions(apiLanguage = Sql()) + ).queryExecution.analyzed + } + Dataset.ofRows(spark, resolvedPlan) + + } + + /** + * Internal helper to reference the batch dataset (i.e., non-streaming dataset) with the given + * name. + * 1. The dataset can be a table, view, or a named flow. + * 2. The dataset can be a dataset defined in the same DataflowGraph or a table in the external + * catalog. + * All the public APIs that read from a dataset should call this function to read the dataset. + * + * @param name the name of the Dataset to be read. + * @param batchReadOptions Options for this batch read + * @return batch DataFrame that represents data from the specified Dataset. 
+ */ + final private def readBatchInput( + context: FlowAnalysisContext, + name: String, + batchReadOptions: BatchReadOptions + ): DataFrame = { + GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match { + case inputIdentifier: InternalDatasetIdentifier => + readGraphInput(context, inputIdentifier, batchReadOptions) + + case inputIdentifier: ExternalDatasetIdentifier => + readExternalBatchInput( + context, + inputIdentifier = inputIdentifier, + name = name + ) + } + } + + /** + * Internal helper to reference the streaming dataset with the given name. + * 1. The dataset can be a table, view, or a named flow. + * 2. The dataset can be a dataset defined in the same DataflowGraph or a table in the external + * catalog. + * All the public APIs that read from a dataset should call this function to read the dataset. + * + * @param name the name of the Dataset to be read. + * @param streamReader The [[DataStreamReader]] that may hold read options specified by the user. + * @param streamingReadOptions Options for this streaming read. + * @return streaming DataFrame that represents data from the specified Dataset. + */ + final private def readStreamInput( + context: FlowAnalysisContext, + name: String, + streamReader: DataStreamReader, + streamingReadOptions: StreamingReadOptions + ): DataFrame = { + GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match { + case inputIdentifier: InternalDatasetIdentifier => + readGraphInput( + context, + inputIdentifier, + streamingReadOptions + ) + + case inputIdentifier: ExternalDatasetIdentifier => + readExternalStreamInput( + context, + inputIdentifier = inputIdentifier, + streamReader = streamReader, + name = name + ) + } + } + + /** + * Internal helper to reference dataset defined in the same [[DataflowGraph]]. + * + * @param inputIdentifier The identifier of the Dataset to be read. + * @param readOptions Options for this read (may be either streaming or batch options) + * @return streaming or batch DataFrame that represents data from the specified Dataset. + */ + final private def readGraphInput( + ctx: FlowAnalysisContext, + inputIdentifier: InternalDatasetIdentifier, + readOptions: InputReadOptions + ): DataFrame = { + val datasetIdentifier = inputIdentifier.identifier + + ctx.requestedInputs += datasetIdentifier + + val i = if (!ctx.allInputs.contains(datasetIdentifier)) { + // Dataset not defined in the dataflow graph + throw GraphErrors.pipelineLocalDatasetNotDefinedError(datasetIdentifier.unquotedString) + } else if (!ctx.availableInput.contains(datasetIdentifier)) { + // Dataset defined in the dataflow graph but not yet resolved + throw UnresolvedDatasetException(datasetIdentifier) + } else { + // Dataset is resolved, so we can read from it + ctx.availableInput(datasetIdentifier) + } + + val inputDF = i.load(readOptions) + i match { + // If the referenced input is a [[Flow]], because the query plans will be fused + // together, we also need to fuse their confs. + case f: Flow => f.sqlConf.foreach { case (k, v) => ctx.setConf(k, v) } + case _ => + } + + val incompatibleViewReadCheck = + ctx.spark.conf.get("pipelines.incompatibleViewCheck.enabled", "true").toBoolean + + // Wrap the DF in an alias so that columns in the DF can be referenced with + // the following in the query: + // - ... + // - .. + // - . 
+ val aliasIdentifier = AliasIdentifier( + name = datasetIdentifier.table, + qualifier = Seq(datasetIdentifier.catalog, datasetIdentifier.database).flatten + ) + + readOptions match { + case sro: StreamingReadOptions => + if (!inputDF.isStreaming && incompatibleViewReadCheck) { + throw new AnalysisException( + "INCOMPATIBLE_BATCH_VIEW_READ", + Map("datasetIdentifier" -> datasetIdentifier.toString) + ) + } + + if (sro.droppedUserOptions.nonEmpty) { + ctx.analysisWarnings += AnalysisWarning.StreamingReaderOptionsDropped( + sourceName = datasetIdentifier.unquotedString, + droppedOptions = sro.droppedUserOptions.keys.toSeq + ) + } + ctx.streamingInputs += ResolvedInput(i, aliasIdentifier) + case _ => + if (inputDF.isStreaming && incompatibleViewReadCheck) { + throw new AnalysisException( + "INCOMPATIBLE_STREAMING_VIEW_READ", + Map("datasetIdentifier" -> datasetIdentifier.toString) + ) + } + ctx.batchInputs += ResolvedInput(i, aliasIdentifier) + } + Dataset.ofRows( + ctx.spark, + SubqueryAlias(identifier = aliasIdentifier, child = inputDF.queryExecution.logical) + ) + } + + /** + * Internal helper to reference batch dataset (i.e., non-streaming dataset) defined in an external + * catalog or as a path. + * + * @param inputIdentifier The identifier of the dataset to be read. + * @return streaming or batch DataFrame that represents data from the specified Dataset. + */ + final private def readExternalBatchInput( + context: FlowAnalysisContext, + inputIdentifier: ExternalDatasetIdentifier, + name: String): DataFrame = { + + val spark = context.spark + context.externalInputs += name + spark.read.table(inputIdentifier.identifier.quotedString) + } + + /** + * Internal helper to reference dataset defined in an external catalog or as a path. + * + * @param inputIdentifier The identifier of the dataset to be read. + * @param streamReader The [[DataStreamReader]] that may hold additional read options specified by + * the user. + * @return streaming or batch DataFrame that represents data from the specified Dataset. + */ + final private def readExternalStreamInput( + context: FlowAnalysisContext, + inputIdentifier: ExternalDatasetIdentifier, + streamReader: DataStreamReader, + name: String): DataFrame = { + + context.externalInputs += name + streamReader.table(inputIdentifier.identifier.quotedString) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala new file mode 100644 index 0000000000000..db6d6435701c0 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
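The batch/streaming compatibility checks above are gated on the `pipelines.incompatibleViewCheck.enabled` session conf. A minimal standalone sketch of the decision, with illustrative names:

// Simplified mirror of the check in readGraphInput: a streaming read (STREAM / readStream)
// must target a streaming DataFrame and a batch read a non-streaming one, unless the check
// is disabled via `pipelines.incompatibleViewCheck.enabled = false`.
def incompatibleReadError(
    isStreamingRead: Boolean,
    dfIsStreaming: Boolean,
    checkEnabled: Boolean): Option[String] = {
  if (!checkEnabled) None
  else if (isStreamingRead && !dfIsStreaming) Some("INCOMPATIBLE_BATCH_VIEW_READ")
  else if (!isStreamingRead && dfIsStreaming) Some("INCOMPATIBLE_STREAMING_VIEW_READ")
  else None
}

assert(incompatibleReadError(isStreamingRead = true, dfIsStreaming = false, checkEnabled = true)
  .contains("INCOMPATIBLE_BATCH_VIEW_READ"))
assert(incompatibleReadError(isStreamingRead = true, dfIsStreaming = false, checkEnabled = false).isEmpty)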
+ */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.pipelines.AnalysisWarning + +/** + * A context used when evaluating a [[Flow]]'s query into a concrete [[DataFrame]]. + * + * @param allInputs Set of identifiers for all [[Input]]s defined in the DataflowGraph. + * @param availableInputs Inputs available to be referenced with `read` or `readStream`. + * @param currentCatalog The current catalog in execution context when the query is defined. + * @param currentDatabase The current schema in execution context when the query is defined. + * @param requestedInputs A mutable buffer populated with names of all inputs that were + * requested. + * @param spark the spark session to be used. + * @param externalInputs The names of external inputs that were used to evaluate + * the flow's query. + */ +private[pipelines] case class FlowAnalysisContext( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + currentCatalog: Option[String], + currentDatabase: Option[String], + batchInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, + streamingInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, + requestedInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty, + shouldLowerCaseNames: Boolean = false, + analysisWarnings: mutable.Buffer[AnalysisWarning] = new ListBuffer[AnalysisWarning], + spark: SparkSession, + externalInputs: mutable.HashSet[String] = mutable.HashSet.empty +) { + + /** Map from [[Input]] name to the actual [[Input]] */ + val availableInput: Map[TableIdentifier, Input] = + availableInputs.map(i => i.identifier -> i).toMap + + /** The confs set in this context that should be undone when exiting this context. */ + private val confsToRestore = mutable.HashMap[String, Option[String]]() + + /** Sets a Spark conf within this context that will be undone by `restoreOriginalConf`. */ + def setConf(key: String, value: String): Unit = { + if (!confsToRestore.contains(key)) { + confsToRestore.put(key, spark.conf.getOption(key)) + } + spark.conf.set(key, value) + } + + /** Restores the Spark conf to its state when this context was creating by undoing confs set. */ + def restoreOriginalConf(): Unit = confsToRestore.foreach { + case (k, Some(v)) => spark.conf.set(k, v) + case (k, None) => spark.conf.unset(k) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala new file mode 100644 index 0000000000000..bb3d290c3d3e1 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
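The `setConf`/`restoreOriginalConf` pair above implements scoped conf overrides. A minimal standalone sketch of the same save-and-restore idea, using a plain mutable map in place of the session conf (the `ScopedConf` name is made up for this example):

import scala.collection.mutable

class ScopedConf(conf: mutable.Map[String, String]) {
  // Original values (or absence) for keys overridden in this scope, captured on first set.
  private val toRestore = mutable.HashMap[String, Option[String]]()

  def set(key: String, value: String): Unit = {
    if (!toRestore.contains(key)) toRestore.put(key, conf.get(key))
    conf(key) = value
  }

  def restore(): Unit = toRestore.foreach {
    case (k, Some(v)) => conf(k) = v
    case (k, None)    => conf.remove(k)
  }
}

// Usage: overrides are visible while a flow is analyzed, then rolled back.
val sessionConf = mutable.Map("spark.sql.shuffle.partitions" -> "200")
val scoped = new ScopedConf(sessionConf)
scoped.set("spark.sql.shuffle.partitions", "1")
scoped.set("some.flow.conf", "x")
scoped.restore()
assert(sessionConf == mutable.Map("spark.sql.shuffle.partitions" -> "200"))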
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.pipelines.common.DatasetType + +object GraphElementTypeUtils { + + /** + * Helper function to obtain the DatasetType. This function should be called with all + * the flows that are writing into a table and all the flows should be resolved. + */ + def getDatasetTypeForMaterializedViewOrStreamingTable( + flowsToTable: Seq[ResolvedFlow]): DatasetType = { + val isStreamingTable = flowsToTable.exists(f => f.df.isStreaming || f.once) + if (isStreamingTable) { + DatasetType.STREAMING_TABLE + } else { + DatasetType.MATERIALIZED_VIEW + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala new file mode 100644 index 0000000000000..1a01a8df3f911 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.common.DatasetType +import org.apache.spark.sql.types.StructType + +/** Collection of errors that can be thrown during graph resolution / analysis. 
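A small standalone illustration of the rule in `getDatasetTypeForMaterializedViewOrStreamingTable` above: the target is a streaming table if any flow writing to it is streaming or declared as once, otherwise it is a materialized view. The tuples below are stand-ins for resolved flows, not real `ResolvedFlow` instances.

sealed trait DatasetKind
case object StreamingTable extends DatasetKind
case object MaterializedView extends DatasetKind

// (isStreaming, once) pairs, one per flow writing into the table.
def kindFor(flows: Seq[(Boolean, Boolean)]): DatasetKind =
  if (flows.exists { case (isStreaming, once) => isStreaming || once }) StreamingTable
  else MaterializedView

assert(kindFor(Seq((false, false), (false, false))) == MaterializedView)
assert(kindFor(Seq((false, false), (true, false))) == StreamingTable)
assert(kindFor(Seq((false, true))) == StreamingTable) // once-flow, e.g. an AppendOnceFlow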
*/ +object GraphErrors { + + def pipelineLocalDatasetNotDefinedError(datasetName: String): SparkException = { + // TODO: this should be an internal error, as we never expect this to happen + new SparkException( + errorClass = "PIPELINE_LOCAL_DATASET_NOT_DEFINED", + messageParameters = Map("datasetName" -> datasetName), + cause = null + ) + } + + /** + * Throws when the catalog or schema name in the "USE CATALOG | SCHEMA" command is invalid + * + * @param command string "USE CATALOG" or "USE SCHEMA" + * @param name the invalid catalog or schema name + * @param reason the reason why the name is invalid + */ + def invalidNameInUseCommandError( + command: String, + name: String, + reason: String + ): SparkException = { + new SparkException( + errorClass = "INVALID_NAME_IN_USE_COMMAND", + messageParameters = Map("command" -> command, "name" -> name, "reason" -> reason), + cause = null + ) + } + + def unresolvedTablePath(identifier: TableIdentifier): SparkException = { + new SparkException( + errorClass = "UNRESOLVED_TABLE_PATH", + messageParameters = Map("identifier" -> identifier.toString), + cause = null + ) + } + + def incompatibleUserSpecifiedAndInferredSchemasError( + tableIdentifier: TableIdentifier, + datasetType: DatasetType, + specifiedSchema: StructType, + inferredSchema: StructType, + cause: Option[Throwable] = None + ): AnalysisException = { + val streamingTableHint = + if (datasetType == DatasetType.STREAMING_TABLE) { + s"""" + |Streaming tables are stateful and remember data that has already been + |processed. If you want to recompute the table from scratch, please full refresh + |the table. + """.stripMargin + } else { + "" + } + + new AnalysisException( + errorClass = "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE", + messageParameters = Map( + "tableName" -> tableIdentifier.unquotedString, + "streamingTableHint" -> streamingTableHint, + "specifiedSchema" -> specifiedSchema.treeString, + "inferredDataSchema" -> inferredSchema.treeString + ), + cause = Option(cause.orNull) + ) + } + + def unableToInferSchemaError( + tableIdentifier: TableIdentifier, + inferredSchema: StructType, + incompatibleSchema: StructType, + cause: Option[Throwable] = None + ): AnalysisException = { + new AnalysisException( + errorClass = "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA", + messageParameters = Map( + "tableName" -> tableIdentifier.unquotedString, + "inferredDataSchema" -> inferredSchema.treeString, + "incompatibleDataSchema" -> incompatibleSchema.treeString + ), + cause = Option(cause.orNull) + ) + } + + def persistedViewReadsFromTemporaryView( + persistedViewIdentifier: TableIdentifier, + temporaryViewIdentifier: TableIdentifier): AnalysisException = { + new AnalysisException( + "PERSISTED_VIEW_READS_FROM_TEMPORARY_VIEW", + Map( + "persistedViewName" -> persistedViewIdentifier.toString, + "temporaryViewName" -> temporaryViewIdentifier.toString + ) + ) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala new file mode 100644 index 0000000000000..49d8dd35b5e2f --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
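A hypothetical example of how the schema-compatibility error above could be constructed, assuming a pipeline table `cat.db.events` whose declared schema is missing a column produced by its flows; the identifiers and schemas are invented for illustration.

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.pipelines.common.DatasetType
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val specified = new StructType().add("id", IntegerType)
val inferred = new StructType().add("id", IntegerType).add("value", StringType)

// Raised by schema validation when the inferred schema differs from the declared one;
// the streaming-table hint is included because the dataset type is STREAMING_TABLE.
val err = GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError(
  tableIdentifier = TableIdentifier("events", Some("db"), Some("cat")),
  datasetType = DatasetType.STREAMING_TABLE,
  specifiedSchema = specified,
  inferredSchema = inferred
)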
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.execution.datasources.DataSource + +/** + * Responsible for properly qualify the identifiers for datasets inside or referenced by the + * dataflow graph. + */ +object GraphIdentifierManager { + + import IdentifierHelper._ + + def parseTableIdentifier(name: String, spark: SparkSession): TableIdentifier = { + toTableIdentifier(spark.sessionState.sqlParser.parseMultipartIdentifier(name)) + } + + /** + * Fully qualify (if needed) the user-specified identifier used to reference datasets, and + * categorizing the dataset we're referencing (i.e. dataset from this pipeline or dataset that is + * external to this pipeline). + * + * Returns whether the input dataset should be read as a dataset and also the qualified + * identifier. + * + * @param rawInputName the user-specified name when referencing datasets. + */ + def parseAndQualifyInputIdentifier( + context: FlowAnalysisContext, + rawInputName: String): DatasetIdentifier = { + resolveDatasetReadInsideQueryDefinition(context = context, rawInputName = rawInputName) + } + + /** + * Resolve dataset reads that happens inside the dataset query definition (i.e., inside + * the @materialized_view() annotation in Python). + */ + private def resolveDatasetReadInsideQueryDefinition( + context: FlowAnalysisContext, + rawInputName: String + ): DatasetIdentifier = { + // After identifier is pre-processed, we first check whether we're referencing a + // single-part-name dataset (e.g., temp view). If so, don't fully qualified the identifier + // and directly read from it, because single-part-name datasets always out-mask other + // fully-qualified-datasets that have the same name. For example, if there's a view named + // "a" and also a table named "catalog.schema.a" defined in the graph. "SELECT * FROM a" + // would always read the view "a". To read table "a", user would need to read it using + // fully/partially qualified name (e.g., "SELECT * FROM catalog.schema.a" or "SELECT * FROM + // schema.a"). + val inputIdentifier = parseTableIdentifier(rawInputName, context.spark) + // TODO: once pipeline spec catalog/schema is propagated, use that catalog via DSv2 API + val catalog = context.spark.sessionState.catalog + + /** Return whether we're referencing a dataset that is part of the pipeline. 
*/ + def isInternalDataset(identifier: TableIdentifier): Boolean = { + context.allInputs.contains(identifier) + } + + if (isSinglePartIdentifier(inputIdentifier) && + isInternalDataset(inputIdentifier)) { + // reading a single-part-name dataset defined in the dataflow graph (e.g., a view) + InternalDatasetIdentifier(identifier = inputIdentifier) + } else if (catalog.isTempView(inputIdentifier)) { + // referencing a temp view or temp table in the current spark session + ExternalDatasetIdentifier(identifier = inputIdentifier) + } else if (isPathIdentifier(context.spark, inputIdentifier)) { + // path-based reference, always read as external dataset + ExternalDatasetIdentifier(identifier = inputIdentifier) + } else { + val fullyQualifiedInputIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = inputIdentifier, + currentCatalog = context.currentCatalog, + currentDatabase = context.currentDatabase + ) + assertIsFullyQualifiedForRead(identifier = fullyQualifiedInputIdentifier) + + if (isInternalDataset(fullyQualifiedInputIdentifier)) { + InternalDatasetIdentifier(identifier = fullyQualifiedInputIdentifier) + } else { + ExternalDatasetIdentifier(fullyQualifiedInputIdentifier) + } + } + } + + /** + * @param rawDatasetIdentifier the dataset identifier specified by the user. + */ + @throws[AnalysisException] + private def parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier: TableIdentifier): InternalDatasetIdentifier = { + InternalDatasetIdentifier(identifier = rawDatasetIdentifier) + } + + /** + * Parses the table identifier from the raw table identifier and fully qualifies it. + * + * @param rawTableIdentifier the raw table identifier + * @return the parsed table identifier + * @throws AnalysisException if the table identifier is not allowed + */ + @throws[AnalysisException] + def parseAndQualifyTableIdentifier( + rawTableIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String] + ): InternalDatasetIdentifier = { + val pipelineDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawTableIdentifier + ) + val fullyQualifiedTableIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = pipelineDatasetIdentifier.identifier, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase + ) + // assert the identifier is properly fully qualified + assertIsFullyQualifiedForCreate(fullyQualifiedTableIdentifier) + InternalDatasetIdentifier(identifier = fullyQualifiedTableIdentifier) + } + + /** + * Parses and validates the view identifier from the raw view identifier for temporary views. + * + * @param rawViewIdentifier the raw view identifier + * @return the parsed view identifier + * @throws AnalysisException if the view identifier is not allowed + */ + @throws[AnalysisException] + def parseAndValidateTemporaryViewIdentifier( + rawViewIdentifier: TableIdentifier): TableIdentifier = { + val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawViewIdentifier + ) + // Temporary views are not persisted to the catalog in use, therefore should not be qualified. + if (!isSinglePartIdentifier(internalDatasetIdentifier.identifier)) { + throw new AnalysisException( + "MULTIPART_TEMPORARY_VIEW_NAME_NOT_SUPPORTED", + Map("viewName" -> rawViewIdentifier.unquotedString) + ) + } + internalDatasetIdentifier.identifier + } + + /** + * Parses and validates the view identifier from the raw view identifier for persisted views. 
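The resolution precedence implemented above can be summarized as a short decision chain. The sketch below mirrors it with plain booleans; the predicate parameters are stand-ins for the real catalog, temp-view, and path checks.

sealed trait ResolvedAs
case object InternalSinglePart extends ResolvedAs // e.g. a view defined in the dataflow graph
case object ExternalTempView extends ResolvedAs   // session temp view outside the pipeline
case object ExternalPath extends ResolvedAs       // `datasource`.`path` style reference
case object QualifiedLookup extends ResolvedAs    // qualify with current catalog/schema, then re-check

def resolve(
    isSinglePart: Boolean,
    isInternal: Boolean, // defined in the dataflow graph
    isSessionTempView: Boolean,
    isPathReference: Boolean): ResolvedAs = {
  if (isSinglePart && isInternal) InternalSinglePart
  else if (isSessionTempView) ExternalTempView
  else if (isPathReference) ExternalPath
  else QualifiedLookup
}

assert(
  resolve(isSinglePart = true, isInternal = true, isSessionTempView = false, isPathReference = false)
    == InternalSinglePart)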
+ * + * @param rawViewIdentifier the raw view identifier + * @param currentCatalog the catalog + * @param currentDatabase the schema + * @return the parsed view identifier + * @throws AnalysisException if the view identifier is not allowed + */ + def parseAndValidatePersistedViewIdentifier( + rawViewIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String]): TableIdentifier = { + val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawViewIdentifier + ) + // Persisted views have fully qualified names + val fullyQualifiedViewIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = internalDatasetIdentifier.identifier, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase + ) + // assert the identifier is properly fully qualified + assertIsFullyQualifiedForCreate(fullyQualifiedViewIdentifier) + fullyQualifiedViewIdentifier + } + + /** + * Parses the flow identifier from the raw flow identifier and fully qualify it. + * + * @param rawFlowIdentifier the raw flow identifier + * @return the parsed flow identifier + * @throws AnalysisException if the flow identifier is not allowed + */ + @throws[AnalysisException] + def parseAndQualifyFlowIdentifier( + rawFlowIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String] + ): InternalDatasetIdentifier = { + val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawFlowIdentifier + ) + + val fullyQualifiedFlowIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = internalDatasetIdentifier.identifier, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase + ) + + // assert the identifier is properly fully qualified + assertIsFullyQualifiedForCreate(fullyQualifiedFlowIdentifier) + InternalDatasetIdentifier(identifier = fullyQualifiedFlowIdentifier) + } + + /** Represents the identifier for a dataset that is defined or referenced in a pipeline. */ + sealed trait DatasetIdentifier + + /** Represents the identifier for a dataset that is defined by the current pipeline. */ + case class InternalDatasetIdentifier private ( + identifier: TableIdentifier + ) extends DatasetIdentifier + + /** Represents the identifier for a dataset that is external to the current pipeline. */ + case class ExternalDatasetIdentifier(identifier: TableIdentifier) extends DatasetIdentifier +} + +object IdentifierHelper { + + /** + * Returns the quoted string for the name parts. + * + * @param nameParts the dataset name parts. + * @return the quoted string for the name parts. + */ + def toQuotedString(nameParts: Seq[String]): String = { + toTableIdentifier(nameParts).quotedString + } + + /** + * Returns the table identifier constructed from the name parts. + * + * @param nameParts the dataset name parts. + * @return the table identifier constructed from the name parts. + * @throws UnsupportedOperationException if the name parts have more than 3 parts. 
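A usage sketch of the two view helpers above, with example identifiers: a temporary view must keep a single-part name, while a persisted view is qualified with the current catalog and schema.

import org.apache.spark.sql.catalyst.TableIdentifier

// Temporary views keep their single-part name; a multi-part name would throw
// MULTIPART_TEMPORARY_VIEW_NAME_NOT_SUPPORTED.
val tempView = GraphIdentifierManager.parseAndValidateTemporaryViewIdentifier(
  rawViewIdentifier = TableIdentifier("my_temp_view")
)

// Persisted views are fully qualified; here the result would be `cat`.`db`.`my_view`.
val persistedView = GraphIdentifierManager.parseAndValidatePersistedViewIdentifier(
  rawViewIdentifier = TableIdentifier("my_view"),
  currentCatalog = Some("cat"),
  currentDatabase = Some("db")
)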
+ */ + @throws[UnsupportedOperationException] + def toTableIdentifier(nameParts: Seq[String]): TableIdentifier = { + nameParts.length match { + case 1 => TableIdentifier(tableName = nameParts.head) + case 2 => TableIdentifier(table = nameParts(1), database = Option(nameParts.head)) + case 3 => + TableIdentifier( + table = nameParts(2), + database = Option(nameParts(1)), + catalog = Option(nameParts.head) + ) + case _ => + throw new UnsupportedOperationException( + s"4+ part table identifier ${nameParts.mkString(".")} is not supported." + ) + } + } + + /** + * Returns the table identifier constructed from the logical plan. + * + * @param table the logical plan. + * @return the table identifier constructed from the logical plan. + * @throws SparkException if the table identifier cannot be resolved. + */ + def toTableIdentifier(table: LogicalPlan): TableIdentifier = { + val parts = table match { + case r: ResolvedIdentifier => r.identifier.namespace.toSeq :+ r.identifier.name + case u: UnresolvedIdentifier => u.nameParts + case u: UnresolvedRelation => u.multipartIdentifier + case _ => + throw new UnsupportedOperationException(s"Unable to resolve name for $table.") + } + toTableIdentifier(parts) + } + + /** Return whether the input identifier is a single-part identifier. */ + def isSinglePartIdentifier(identifier: TableIdentifier): Boolean = { + identifier.database.isEmpty && identifier.catalog.isEmpty + } + + /** + * Return true if the identifier should be resolved as a path-based reference + * (i.e., `datasource`.`path`). + */ + def isPathIdentifier(spark: SparkSession, identifier: TableIdentifier): Boolean = { + if (identifier.nameParts.length != 2) { + return false + } + val Seq(datasource, path) = identifier.nameParts + val sqlConf = spark.sessionState.conf + + def isDatasourceValid = { + try { + DataSource.lookupDataSource(datasource, sqlConf) + true + } catch { + case _: ClassNotFoundException => false + } + } + + // Whether the provided datasource is valid. + isDatasourceValid + } + + /** Fully qualifies provided identifier with provided catalog & schema. */ + def fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String] + ): TableIdentifier = { + maybeFullyQualifiedIdentifier.copy( + database = maybeFullyQualifiedIdentifier.database.orElse(currentDatabase), + catalog = maybeFullyQualifiedIdentifier.catalog.orElse(currentCatalog) + ) + } + + /** Assert whether the identifier is properly fully qualified when creating a dataset. */ + def assertIsFullyQualifiedForCreate(identifier: TableIdentifier): Unit = { + // TODO: validate catalog exists + assert( + identifier.catalog.isDefined && identifier.database.isDefined, + s"Dataset identifier $identifier is not properly fully qualified, expect a " + + s"three-part-name .." + ) + } + + /** Assert whether the identifier is properly qualified when reading a dataset in a pipeline. */ + def assertIsFullyQualifiedForRead(identifier: TableIdentifier): Unit = { + assert( + identifier.catalog.isDefined && identifier.database.isDefined, + s"Failed to reference dataset $identifier, expect a " + + s"three-part-name ..
" + ) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala new file mode 100644 index 0000000000000..c0b5a360afea7 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * @param identifier The identifier of the flow. + * @param inputs The identifiers of nodes used as inputs to this flow. + * @param output The identifier of the output that this flow writes to. + */ +case class FlowNode( + identifier: TableIdentifier, + inputs: Set[TableIdentifier], + output: TableIdentifier) + +trait GraphOperations { + this: DataflowGraph => + + /** The set of all outputs in the graph. */ + private lazy val destinationSet: Set[TableIdentifier] = + flows.map(_.destinationIdentifier).toSet + + /** A map from flow identifier to [[FlowNode]], which contains the input/output nodes. */ + lazy val flowNodes: Map[TableIdentifier, FlowNode] = { + flows.map { f => + val identifier = f.identifier + val n = + FlowNode( + identifier = identifier, + inputs = resolvedFlow(f.identifier).inputs, + output = f.destinationIdentifier + ) + identifier -> n + }.toMap + } + + /** Map from dataset identifier to all reachable upstream destinations, including itself. */ + private lazy val upstreamDestinations = + mutable.HashMap + .empty[TableIdentifier, Set[TableIdentifier]] + .withDefault(key => dfsInternal(startDestination = key, downstream = false)) + + /** Map from dataset identifier to all reachable downstream destinations, including itself. */ + private lazy val downstreamDestinations = + mutable.HashMap + .empty[TableIdentifier, Set[TableIdentifier]] + .withDefault(key => dfsInternal(startDestination = key, downstream = true)) + + /** + * Performs a DFS starting from `startNode` and returns the set of nodes (datasets) reached. + * @param startDestination The identifier of the node to start from. + * @param downstream if true, traverse output edges (search downstream) + * if false, traverse input edges (search upstream). + * @param stopAtMaterializationPoints If true, stop when we reach a materialization point (table). + * If false, keep going until the end. 
+ */ + protected def dfsInternal( + startDestination: TableIdentifier, + downstream: Boolean, + stopAtMaterializationPoints: Boolean = false): Set[TableIdentifier] = { + assert( + destinationSet.contains(startDestination), + s"$startDestination is not a valid start node" + ) + val visited = new mutable.HashSet[TableIdentifier] + + // Same semantics as a stack. Need to be able to push/pop items. + var nextNodes = List[TableIdentifier]() + nextNodes = startDestination :: nextNodes + + while (nextNodes.nonEmpty) { + val currNode = nextNodes.head + nextNodes = nextNodes.tail + + if (!visited.contains(currNode) + // When we stop at materialization points, skip non-start nodes that are materialized. + && !(stopAtMaterializationPoints && table.contains(currNode) + && currNode != startDestination)) { + visited.add(currNode) + val neighbors = if (downstream) { + flowNodes.values.filter(_.inputs.contains(currNode)).map(_.output) + } else { + flowNodes.values.filter(_.output == currNode).flatMap(_.inputs) + } + nextNodes = neighbors.toList ++ nextNodes + } + } + visited.toSet + } + + /** + * An implementation of DFS that takes in a sequence of start nodes and returns the + * "reachability set" of nodes from the start nodes. + * + * @param downstream Walks the graph via the input edges if true, otherwise via the output + * edges. + * @return A map from visited nodes to its origin[s] in `datasetIdentifiers`, e.g. + * Let graph = a -> b c -> d (partitioned graph) + * + * reachabilitySet(Seq("a", "c"), downstream = true) + * -> ["a" -> ["a"], "b" -> ["a"], "c" -> ["c"], "d" -> ["c"]] + */ + private def reachabilitySet( + datasetIdentifiers: Seq[TableIdentifier], + downstream: Boolean): Map[TableIdentifier, Set[TableIdentifier]] = { + // Seq of the form "node identifier" -> Set("dependency1, "dependency2") + val deps = datasetIdentifiers.map(n => n -> reachabilitySet(n, downstream)) + + // Invert deps so that we get a map from each dependency to node identifier + val finalMap = mutable.HashMap + .empty[TableIdentifier, Set[TableIdentifier]] + .withDefaultValue(Set.empty[TableIdentifier]) + deps.foreach { + case (start, reachableNodes) => + reachableNodes.foreach(n => finalMap.put(n, finalMap(n) + start)) + } + finalMap.toMap + } + + /** + * Returns all datasets that can be reached from `destinationIdentifier`. + */ + private def reachabilitySet( + destinationIdentifier: TableIdentifier, + downstream: Boolean): Set[TableIdentifier] = { + if (downstream) downstreamDestinations(destinationIdentifier) + else upstreamDestinations(destinationIdentifier) + } + + /** Returns the set of flows reachable from `flowIdentifier` via output (child) edges. */ + def downstreamFlows(flowIdentifier: TableIdentifier): Set[TableIdentifier] = { + assert(flowNodes.contains(flowIdentifier), s"$flowIdentifier is not a valid start flow") + val downstreamDatasets = reachabilitySet(flowNodes(flowIdentifier).output, downstream = true) + flowNodes.values.filter(_.inputs.exists(downstreamDatasets.contains)).map(_.identifier).toSet + } + + /** Returns the set of flows reachable from `flowIdentifier` via input (parent) edges. 
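The reachability helpers above reduce to a DFS over flow edges. A standalone sketch using the partitioned graph from the comment (`a -> b`, `c -> d`) and plain strings in place of identifiers:

// Edges: input -> output for each flow.
val edges = Seq("a" -> "b", "c" -> "d")

def reachable(start: String, downstream: Boolean): Set[String] = {
  val next: String => Seq[String] =
    if (downstream) n => edges.collect { case (i, o) if i == n => o }
    else n => edges.collect { case (i, o) if o == n => i }
  var visited = Set.empty[String]
  var stack = List(start)
  while (stack.nonEmpty) {
    val cur = stack.head
    stack = stack.tail
    if (!visited(cur)) {
      visited += cur
      stack = next(cur).toList ++ stack
    }
  }
  visited
}

// Reachability set keyed by visited node, mapping back to the start node(s) that reached it.
val starts = Seq("a", "c")
val byOrigin: Map[String, Set[String]] =
  starts
    .flatMap(s => reachable(s, downstream = true).map(_ -> s))
    .groupBy(_._1)
    .map { case (node, pairs) => node -> pairs.map(_._2).toSet }

assert(byOrigin == Map("a" -> Set("a"), "b" -> Set("a"), "c" -> Set("c"), "d" -> Set("c")))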
*/ + def upstreamFlows(flowIdentifier: TableIdentifier): Set[TableIdentifier] = { + assert(flowNodes.contains(flowIdentifier), s"$flowIdentifier is not a valid start flow") + val upstreamDatasets = + flowNodes(flowIdentifier).inputs.flatMap(reachabilitySet(_, downstream = false)) + flowNodes.values.filter(e => upstreamDatasets.contains(e.output)).map(_.identifier).toSet + } + + /** Returns the set of datasets reachable from `datasetIdentifier` via input (parent) edges. */ + def upstreamDatasets(datasetIdentifier: TableIdentifier): Set[TableIdentifier] = + reachabilitySet(datasetIdentifier, downstream = false) - datasetIdentifier + + /** + * Traverses the graph upstream starting from the specified `datasetIdentifiers` to return the + * reachable nodes. The return map's keyset consists of all datasets reachable from + * `datasetIdentifiers`. For each entry in the response map, the value of that element refers + * to which of `datasetIdentifiers` was able to reach the key. If multiple of `datasetIdentifiers` + * could reach that key, one is picked arbitrarily. + */ + def upstreamDatasets( + datasetIdentifiers: Seq[TableIdentifier]): Map[TableIdentifier, Set[TableIdentifier]] = + reachabilitySet(datasetIdentifiers, downstream = false) -- datasetIdentifiers +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala new file mode 100644 index 0000000000000..646ed68465c95 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * @param defaultCatalog The pipeline's default catalog. + * @param defaultDatabase The pipeline's default schema. 
+ */ +class GraphRegistrationContext( + val defaultCatalog: String, + val defaultDatabase: String, + val defaultSqlConf: Map[String, String]) { + import GraphRegistrationContext._ + + protected val tables = new mutable.ListBuffer[Table] + protected val views = new mutable.ListBuffer[View] + protected val flows = new mutable.ListBuffer[UnresolvedFlow] + + def registerTable(tableDef: Table): Unit = { + tables += tableDef + } + + def registerView(viewDef: View): Unit = { + views += viewDef + } + + def registerFlow(flowDef: UnresolvedFlow): Unit = { + flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf) + } + + def toDataflowGraph: DataflowGraph = { + val qualifiedTables = tables.toSeq.map { t => + // TODO: Track session-level catalog changes. + t.copy( + identifier = GraphIdentifierManager + .parseAndQualifyTableIdentifier( + rawTableIdentifier = t.identifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + .identifier + ) + } + + val validatedViews = views.toSeq.collect { + case v: TemporaryView => + v.copy( + identifier = GraphIdentifierManager + .parseAndValidateTemporaryViewIdentifier( + rawViewIdentifier = v.identifier + ) + ) + case v: PersistedView => + v.copy( + identifier = GraphIdentifierManager + .parseAndValidatePersistedViewIdentifier( + rawViewIdentifier = v.identifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + ) + } + + val qualifiedFlows = flows.toSeq.map { f => + val isImplicitFlow = f.identifier == f.destinationIdentifier + val flowWritesToView = + validatedViews + .filter(_.isInstanceOf[TemporaryView]) + .exists(_.identifier == f.destinationIdentifier) + + // If the flow is created implicitly as part of defining a view, then we do not + // qualify the flow identifier and the flow destination. This is because views are + // not permitted to have multipart + if (isImplicitFlow && flowWritesToView) { + f + } else { + // TODO: Track session-level catalog changes. + f.copy( + identifier = GraphIdentifierManager + .parseAndQualifyFlowIdentifier( + rawFlowIdentifier = f.identifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + .identifier, + destinationIdentifier = GraphIdentifierManager + .parseAndQualifyFlowIdentifier( + rawFlowIdentifier = f.destinationIdentifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + .identifier + ) + } + } + + assertNoDuplicates( + qualifiedTables = qualifiedTables, + validatedViews = validatedViews, + qualifiedFlows = qualifiedFlows + ) + + new DataflowGraph( + tables = qualifiedTables, + views = validatedViews, + flows = qualifiedFlows + ) + } + + private def assertNoDuplicates( + qualifiedTables: Seq[Table], + validatedViews: Seq[View], + qualifiedFlows: Seq[UnresolvedFlow]): Unit = { + + (qualifiedTables.map(_.identifier) ++ validatedViews.map(_.identifier)) + .foreach { identifier => + assertDatasetIdentifierIsUnique( + identifier = identifier, + tables = qualifiedTables, + views = validatedViews + ) + } + + qualifiedFlows.foreach { flow => + assertFlowIdentifierIsUnique( + flow = flow, + datasetType = TableType, + flows = qualifiedFlows + ) + } + } + + private def assertDatasetIdentifierIsUnique( + identifier: TableIdentifier, + tables: Seq[Table], + views: Seq[View]): Unit = { + + // We need to check for duplicates in both tables and views, as they can have the same name. 
+ val allDatasets = tables.map(t => t.identifier -> TableType) ++ views.map( + v => v.identifier -> ViewType + ) + + val grouped = allDatasets.groupBy { case (id, _) => id } + + grouped(identifier).toList match { + case (_, firstType) :: (_, secondType) :: _ => + // Sort the types in lexicographic order to ensure consistent error messages. + val sortedTypes = Seq(firstType.toString, secondType.toString).sorted + throw new AnalysisException( + errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET", + messageParameters = Map( + "datasetName" -> identifier.quotedString, + "datasetType1" -> sortedTypes.head, + "datasetType2" -> sortedTypes.last + ) + ) + case _ => // No duplicates found. + } + } + + private def assertFlowIdentifierIsUnique( + flow: UnresolvedFlow, + datasetType: DatasetType, + flows: Seq[UnresolvedFlow]): Unit = { + flows.groupBy(i => i.identifier).get(flow.identifier).filter(_.size > 1).foreach { + duplicateFlows => + val duplicateFlow = duplicateFlows.filter(_ != flow).head + throw new AnalysisException( + errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW", + messageParameters = Map( + "flowName" -> flow.identifier.unquotedString, + "datasetNames" -> Set( + flow.destinationIdentifier.quotedString, + duplicateFlow.destinationIdentifier.quotedString + ).mkString(",") + ) + ) + } + } +} + +object GraphRegistrationContext { + sealed trait DatasetType + + private object TableType extends DatasetType { + override def toString: String = "TABLE" + } + + private object ViewType extends DatasetType { + override def toString: String = "VIEW" + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala new file mode 100644 index 0000000000000..6a3a9e18ddf96 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique +import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils + +/** Validations performed on a [[DataflowGraph]]. */ +trait GraphValidations extends Logging { + this: DataflowGraph => + + /** + * Validate multi query table correctness. Exposed for Python unit testing, which currently cannot + * run anything which invokes the flow function as there's no persistent Python to run it. 
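A standalone sketch of the duplicate check above: tables and views share one namespace, so every registered dataset is grouped by identifier and any identifier appearing more than once is an error. Strings stand in for `TableIdentifier`s.

// (identifier, datasetType) for every registered table and view.
val datasets = Seq(
  "cat.db.events" -> "TABLE",
  "cat.db.daily" -> "TABLE",
  "cat.db.events" -> "VIEW" // conflicts with the table of the same name
)

val duplicates: Map[String, Seq[String]] =
  datasets
    .groupBy { case (id, _) => id }
    .collect { case (id, entries) if entries.size > 1 => id -> entries.map(_._2).sorted }

// Would drive a PIPELINE_DUPLICATE_IDENTIFIERS.DATASET error for `cat.db.events`.
assert(duplicates == Map("cat.db.events" -> Seq("TABLE", "VIEW")))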
+ * + * @return the multi-query tables by destination + */ + protected[pipelines] def validateMultiQueryTables(): Map[TableIdentifier, Seq[Flow]] = { + val multiQueryTables = flowsTo.filter(_._2.size > 1) + // Non-streaming tables do not support multiflow. + multiQueryTables + .find { + case (dest, flows) => + flows.exists(f => !resolvedFlow(f.identifier).df.isStreaming) && + table.contains(dest) + } + .foreach { + case (dest, flows) => + throw new AnalysisException( + "MATERIALIZED_VIEW_WITH_MULTIPLE_QUERIES", + Map( + "tableName" -> dest.unquotedString, + "queries" -> flows.map(_.identifier).mkString(",") + ) + ) + } + + multiQueryTables + } + + /** Throws an exception if the flows in this graph are not topologically sorted. */ + protected[graph] def validateGraphIsTopologicallySorted(): Unit = { + val visitedNodes = mutable.Set.empty[TableIdentifier] // Set of visited nodes + val visitedEdges = mutable.Set.empty[TableIdentifier] // Set of visited edges + flows.foreach { f => + // Unvisited inputs of the current flow + val unvisitedInputNodes = + resolvedFlow(f.identifier).inputs -- visitedNodes + unvisitedInputNodes.headOption match { + case None => + visitedEdges.add(f.identifier) + if (flowsTo(f.destinationIdentifier).map(_.identifier).forall(visitedEdges.contains)) { + // A node is marked visited if all its inputs are visited + visitedNodes.add(f.destinationIdentifier) + } + case Some(unvisitedInput) => + throw new AnalysisException( + "PIPELINE_GRAPH_NOT_TOPOLOGICALLY_SORTED", + Map( + "flowName" -> f.identifier.unquotedString, + "inputName" -> unvisitedInput.unquotedString + ) + ) + } + } + } + + /** + * Validate that all tables are resettable. This is a best-effort check that will only catch + * upstream tables that are resettable but have a non-resettable downstream dependency. + */ + protected def validateTablesAreResettable(): Seq[GraphValidationWarning] = { + validateTablesAreResettable(tables) + } + + /** Validate that all specified tables are resettable. */ + protected def validateTablesAreResettable(tables: Seq[Table]): Seq[GraphValidationWarning] = { + val tableLookup = mapUnique(tables, "table")(_.identifier) + val nonResettableTables = + tables.filter(t => !PipelinesTableProperties.resetAllowed.fromMap(t.properties)) + val upstreamResettableTables = upstreamDatasets(nonResettableTables.map(_.identifier)) + .collect { + // Filter for upstream datasets that are tables with downstream streaming tables + case (upstreamDataset, nonResettableDownstreams) if table.contains(upstreamDataset) => + nonResettableDownstreams + .filter( + t => flowsTo(t).exists(f => resolvedFlow(f.identifier).df.isStreaming) + ) + .map(id => (tableLookup(upstreamDataset), tableLookup(id).displayName)) + } + .flatten + .toSeq + .filter { + case (t, _) => PipelinesTableProperties.resetAllowed.fromMap(t.properties) + } // Filter for resettable + + upstreamResettableTables + .groupBy(_._2) // Group-by non-resettable downstream tables + .view + .mapValues(_.map(_._1)) + .toSeq + .sortBy(_._2.size) // Output errors from largest to smallest + .reverse + .map { + case (nameForEvent, tables) => + InvalidResettableDependencyException(nameForEvent, tables) + } + } + + /** + * Validate if we have any append only flows writing into a streaming table but was created + * from a batch query. 
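A simplified standalone sketch of the topological check above: walking flows in order, every input of a flow must already have been produced by an earlier flow. It simplifies the real rule (which marks a destination visited only once all flows into it are visited), and the flow names are invented.

final case class MiniFlow(name: String, inputs: Set[String], output: String)

// Flows listed in the order they appear in the graph.
val miniFlows = Seq(
  MiniFlow("f1", inputs = Set.empty, output = "bronze"),
  MiniFlow("f2", inputs = Set("bronze"), output = "silver"),
  MiniFlow("f3", inputs = Set("silver"), output = "gold")
)

// Returns the first (flow, unvisited input) pair that violates topological order, if any.
def firstOutOfOrder(orderedFlows: Seq[MiniFlow]): Option[(String, String)] = {
  val visited = scala.collection.mutable.Set.empty[String]
  orderedFlows.iterator
    .map { f =>
      val unvisited = f.inputs.find(in => !visited.contains(in))
      if (unvisited.isEmpty) visited += f.output
      f.name -> unvisited
    }
    .collectFirst { case (name, Some(input)) => name -> input }
}

assert(firstOutOfOrder(miniFlows).isEmpty)
assert(firstOutOfOrder(miniFlows.reverse) == Some("f3" -> "silver"))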
+ */ + protected def validateAppendOnceFlows(): Seq[GraphValidationWarning] = { + flows + .filter { + case af: AppendOnceFlow => !af.definedAsOnce + case _ => false + } + .groupBy(_.destinationIdentifier) + .flatMap { + case (destination, flows) => + table + .get(destination) + .map(t => AppendOnceFlowCreatedFromBatchQueryException(t, flows.map(_.identifier))) + } + .toSeq + } + + protected def validateUserSpecifiedSchemas(): Unit = { + flows.flatMap(f => tableInput(f.identifier)).foreach { t: TableInput => + // The output inferred schema of a table is the declared schema merged with the + // schema of all incoming flows. This must be equivalent to the declared schema. + val inferredSchema = SchemaInferenceUtils + .inferSchemaFromFlows( + flowsTo(t.identifier).map(f => resolvedFlow(f.identifier)), + userSpecifiedSchema = t.specifiedSchema + ) + + t.specifiedSchema.foreach { ss => + // Check the inferred schema matches the specified schema. Used to catch errors where the + // inferred user-facing schema has columns that are not in the specified one. + if (inferredSchema != ss) { + val datasetType = GraphElementTypeUtils + .getDatasetTypeForMaterializedViewOrStreamingTable( + flowsTo(t.identifier).map(f => resolvedFlow(f.identifier)) + ) + throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError( + t.identifier, + datasetType, + ss, + inferredSchema + ) + } + } + } + } + + /** + * Validates that all flows are resolved. If there are unresolved flows, + * detects a possible cyclic dependency and throw the appropriate execption. + */ + protected def validateSuccessfulFlowAnalysis(): Unit = { + // all failed flows with their errors + val flowAnalysisFailures = resolutionFailedFlows.flatMap( + f => f.failure.headOption.map(err => (f.identifier, err)) + ) + // only proceed if there are unresolved flows + if (flowAnalysisFailures.nonEmpty) { + val failedFlowIdentifiers = flowAnalysisFailures.map(_._1).toSet + // used to collect the subgraph of only the unresolved flows + // maps every unresolved flow to the set of unresolved flows writing to one if its inputs + val failedFlowsSubgraph = mutable.Map[TableIdentifier, Seq[TableIdentifier]]() + val (downstreamFailures, directFailures) = flowAnalysisFailures.partition { + case (flowIdentifier, _) => + // If a failed flow writes to any of the requested datasets, we mark this flow as a + // downstream failure + val failedFlowsWritingToRequestedDatasets = + resolutionFailedFlow(flowIdentifier).funcResult.requestedInputs + .flatMap(d => flowsTo.getOrElse(d, Seq())) + .map(_.identifier) + .intersect(failedFlowIdentifiers) + .toSeq + failedFlowsSubgraph += (flowIdentifier -> failedFlowsWritingToRequestedDatasets) + failedFlowsWritingToRequestedDatasets.nonEmpty + } + // if there are flow that failed due to unresolved upstream flows, check for a cycle + if (failedFlowsSubgraph.nonEmpty) { + detectCycle(failedFlowsSubgraph.toMap).foreach { + case (upstream, downstream) => + val upstreamDataset = flow(upstream).destinationIdentifier + val downstreamDataset = flow(downstream).destinationIdentifier + throw CircularDependencyException( + downstreamDataset, + upstreamDataset + ) + } + } + // otherwise report what flows failed directly vs. depending on a failed flow + throw UnresolvedPipelineException( + this, + directFailures.map { case (id, value) => (id, value) }.toMap, + downstreamFailures.map { case (id, value) => (id, value) }.toMap + ) + } + } + + /** + * Generic method to detect a cycle in directed graph via DFS traversal. 
+ * The graph is given as a reverse adjacency map, that is, a map from + * each node to its ancestors. + * @return the start and end node of a cycle if found, None otherwise + */ + private def detectCycle(ancestors: Map[TableIdentifier, Seq[TableIdentifier]]) + : Option[(TableIdentifier, TableIdentifier)] = { + var cycle: Option[(TableIdentifier, TableIdentifier)] = None + val visited = mutable.Set[TableIdentifier]() + def visit(f: TableIdentifier, currentPath: List[TableIdentifier]): Unit = { + if (cycle.isEmpty && !visited.contains(f)) { + if (currentPath.contains(f)) { + cycle = Option((currentPath.head, f)) + } else { + ancestors(f).foreach(visit(_, f :: currentPath)) + visited += f + } + } + } + ancestors.keys.foreach(visit(_, Nil)) + cycle + } + + /** Validates that persisted views don't read from invalid sources */ + protected[graph] def validatePersistedViewSources(): Unit = { + val viewToFlowMap = ViewHelpers.persistedViewIdentifierToFlow(graph = this) + + persistedViews + .foreach { persistedView => + val flow = viewToFlowMap(persistedView.identifier) + val funcResult = resolvedFlow(flow.identifier).funcResult + val inputIdentifiers = (funcResult.batchInputs ++ funcResult.streamingInputs) + .map(_.input.identifier) + + inputIdentifiers + .flatMap(view.get) + .foreach { + case tempView: TemporaryView => + throw GraphErrors.persistedViewReadsFromTemporaryView( + persistedViewIdentifier = persistedView.identifier, + temporaryViewIdentifier = tempView.identifier + ) + case _ => + } + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala new file mode 100644 index 0000000000000..9a88ad70a7601 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.annotation.unused + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * Exception raised when a flow tries to read from a dataset that exists but is unresolved + * + * @param identifier The identifier of the dataset + */ +case class UnresolvedDatasetException(identifier: TableIdentifier) + extends AnalysisException( + s"Failed to read dataset '${identifier.unquotedString}'. Dataset is defined in the " + + s"pipeline but could not be resolved." 
+ ) + +/** + * Exception raised when a flow fails to read from a table defined within the pipeline + * + * @param name The name of the table + * @param cause The cause of the failure + */ +case class LoadTableException(name: String, override val cause: Option[Throwable]) + extends AnalysisException(s"Failed to load table '$name'", cause = cause) + +/** + * Exception raised when a pipeline has one or more flows that cannot be resolved + * + * @param directFailures Mapping between the name of flows that failed to resolve (due to an + * error in that flow) and the error that occurred when attempting to + * resolve them + * @param downstreamFailures Mapping between the name of flows that failed to resolve (because they + * failed to read from other unresolved flows) and the error that occurred + * when attempting to resolve them + */ +case class UnresolvedPipelineException( + graph: DataflowGraph, + directFailures: Map[TableIdentifier, Throwable], + downstreamFailures: Map[TableIdentifier, Throwable], + additionalHint: Option[String] = None) + extends AnalysisException( + s""" + |Failed to resolve flows in the pipeline. + | + |A flow can fail to resolve because the flow itself contains errors or because it reads + |from an upstream flow which failed to resolve. + |${additionalHint.getOrElse("")} + |Flows with errors: ${directFailures.keys.map(_.unquotedString).toSeq.sorted.mkString(", ")} + |Flows that failed due to upstream errors: ${downstreamFailures.keys + .map(_.unquotedString) + .toSeq + .sorted + .mkString(", ")} + | + |To view the exceptions that were raised while resolving these flows, look for FlowProgress + |logs with status FAILED that precede this log.""".stripMargin + ) + +/** A validation error that can either be thrown as an exception or logged as a warning. */ +trait GraphValidationWarning extends Logging { + + /** The exception to throw when this validation fails. */ + protected def exception: AnalysisException + + /** The details of the event to construct when this validation fails. */ + // protected def warningEventDetails: EventDetails => EventDetails + + // /** Log the exception message and construct a warning event log. */ + // def logAndConstructWarningEventLog(origin: Origin): PipelineEvent = { + // logWarning(exception.getMessage) + // ConstructPipelineEvent( + // origin = origin, + // level = EventLevel.WARN, + // message = exception.getMessage, + // details = warningEventDetails + // ) + // } + + /** Log the exception message and throw the exception. */ + @unused + def logAndThrow(): Unit = { + logError(exception.getMessage) + throw exception + } +} + +/** + * Raised when there's a circular dependency in the current pipeline. That is, a downstream + * table is referenced while creating a upstream table. + */ +case class CircularDependencyException( + downstreamTable: TableIdentifier, + upstreamDataset: TableIdentifier) + extends AnalysisException( + s"The downstream table '${downstreamTable.unquotedString}' is referenced when " + + s"creating the upstream table or view '${upstreamDataset.unquotedString}'. " + + s"Circular dependencies are not supported in a pipeline. Please remove the dependency " + + s"between '${upstreamDataset.unquotedString}' and '${downstreamTable.unquotedString}'." + ) + +/** + * Raised when some tables in the current pipeline are not resettable due to some non-resettable + * downstream dependencies. 
+ */ +case class InvalidResettableDependencyException(originName: String, tables: Seq[Table]) + extends GraphValidationWarning { + override def exception: AnalysisException = new AnalysisException( + "INVALID_RESETTABLE_DEPENDENCY", + Map( + "downstreamTable" -> originName, + "upstreamResettableTables" -> tables + .map(_.displayName) + .sorted + .map(t => s"'$t'") + .mkString(", "), + "resetAllowedKey" -> PipelinesTableProperties.resetAllowed.key + ) + ) +} + +/** + * Warn if the append once flows was declared from batch query if there was a run before. + * Throw an exception if not. + * @param table the streaming destination that contains Append Once flows declared with batch query. + * @param flows the append once flows that are declared with batch query. + */ +case class AppendOnceFlowCreatedFromBatchQueryException(table: Table, flows: Seq[TableIdentifier]) + extends GraphValidationWarning { + override def exception: AnalysisException = new AnalysisException( + "APPEND_ONCE_FROM_BATCH_QUERY", + Map("table" -> table.displayName) + ) +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala new file mode 100644 index 0000000000000..c8627a9a0a2e4 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.Locale + +import scala.collection.mutable +import scala.util.control.NonFatal + +/** + * Interface for validating and accessing Pipeline-specific table properties. + */ +object PipelinesTableProperties { + + /** Prefix used for all table properties. */ + val pipelinesPrefix = "pipelines." + + /** Map of keys to property entries. */ + private val entries: mutable.HashMap[String, PipelineTableProperty[_]] = mutable.HashMap.empty + + /** Whether the table should be reset when a Reset is triggered. */ + val resetAllowed: PipelineTableProperty[Boolean] = + buildProp("pipelines.reset.allowed", default = true) + + /** + * Validates that all known pipeline properties are valid and can be parsed. + * Also warn the user if they try to set an unknown pipelines property, or any + * property that looks suspiciously similar to a known pipeline property. + * + * @param rawProps Raw table properties to validate and canonicalize + * @param warnFunction Function to warn users of potential issues. Fatal errors are still thrown. + * @return a map of table properties, with canonical-case keys for valid pipelines properties, + * and excluding invalid pipelines properties. 
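+   *
+   * A purely illustrative example (hypothetical keys and values):
+   * {{{
+   *   validateAndCanonicalize(
+   *     Map("PIPELINES.RESET.ALLOWED" -> "true", "owner" -> "data-eng"),
+   *     warnFunction = _ => ()
+   *   )
+   *   // => Map("pipelines.reset.allowed" -> "true", "owner" -> "data-eng")
+   *   // An unrecognized "pipelines.*" key would instead be dropped with a warning.
+   * }}}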
+ */ + def validateAndCanonicalize( + rawProps: Map[String, String], + warnFunction: String => Unit): Map[String, String] = { + rawProps.flatMap { + case (originalCaseKey, value) => + val k = originalCaseKey.toLowerCase(Locale.ROOT) + + // Make sure all pipelines properties are valid + if (k.startsWith(pipelinesPrefix)) { + entries.get(k) match { + case Some(prop) => + prop.fromMap(rawProps) // Make sure the value can be retrieved + Option(prop.key -> value) // Canonicalize the case + case None => + warnFunction(s"Unknown pipelines table property '$originalCaseKey'.") + // exclude this property - the pipeline won't recognize it anyway, + // and setting it would allow users to use `pipelines.xyz` for their custom + // purposes - which we don't want to encourage. + None + } + } else { + // Make sure they weren't trying to set a pipelines property + val similarProp = entries.get(s"${pipelinesPrefix}$k") + similarProp.foreach { c => + warnFunction( + s"You are trying to set table property '$originalCaseKey', which has a similar name" + + s" to Pipelines table property '${c.key}'. If you are trying to set the Pipelines " + + s"table property, please include the correct prefix." + ) + } + Option(originalCaseKey -> value) + } + } + } + + /** Registers a pipelines table property with the specified key and default value. */ + private def buildProp[T]( + key: String, + default: String, + fromString: String => T): PipelineTableProperty[T] = { + val prop = PipelineTableProperty(key, default, fromString) + entries.put(key.toLowerCase(Locale.ROOT), prop) + prop + } + + private def buildProp(key: String, default: Boolean): PipelineTableProperty[Boolean] = + buildProp(key, default.toString, _.toBoolean) +} + +case class PipelineTableProperty[T](key: String, default: String, fromString: String => T) { + def fromMap(rawProps: Map[String, String]): T = parseFromString(rawProps.getOrElse(key, default)) + + private def parseFromString(value: String): T = { + try { + fromString(value) + } catch { + case NonFatal(_) => + throw new IllegalArgumentException( + s"Could not parse value '$value' for table property '$key'" + ) + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala new file mode 100644 index 0000000000000..d85c59287c683 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.control.{NonFatal, NoStackTrace} + +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.pipelines.Language +import org.apache.spark.sql.pipelines.logging.SourceCodeLocation + +/** + * Records information used to track the provenance of a given query to user code. + * + * @param language The language used by the user to define the query. + * @param fileName The file name of the user code that defines the query. + * @param cellNumber The cell number of the user code that defines the query. + * Cell numbers are 1-indexed. + * @param sqlText The SQL text of the query. + * @param line The line number of the query in the user code. + * Line numbers are 1-indexed. + * @param startPosition The start position of the query in the user code. + * @param objectType The type of the object that the query is associated with. (Table, View, etc) + * @param objectName The name of the object that the query is associated with. + */ +case class QueryOrigin( + language: Option[Language] = None, + fileName: Option[String] = None, + cellNumber: Option[Int] = None, + sqlText: Option[String] = None, + line: Option[Int] = None, + startPosition: Option[Int] = None, + objectType: Option[String] = None, + objectName: Option[String] = None +) { + /** + * Merges this origin with another one. + * + * The result has fields set to the value in the other origin if it is defined, or if not, then + * the value in this origin. + */ + def merge(other: QueryOrigin): QueryOrigin = { + QueryOrigin( + language = other.language.orElse(language), + fileName = other.fileName.orElse(fileName), + sqlText = other.sqlText.orElse(sqlText), + line = other.line.orElse(line), + cellNumber = other.cellNumber.orElse(cellNumber), + startPosition = other.startPosition.orElse(startPosition), + objectType = other.objectType.orElse(objectType), + objectName = other.objectName.orElse(objectName) + ) + } + + /** + * Merge values from the catalyst origin. + * + * The result has fields set to the value in the other origin if it is defined, or if not, then + * the value in this origin. + */ + def merge(other: Origin): QueryOrigin = { + merge( + QueryOrigin( + sqlText = other.sqlText, + line = other.line, + startPosition = other.startPosition + ) + ) + } + + /** Generates a [[SourceCodeLocation]] using the details present in the query origin. */ + def toSourceCodeLocation: SourceCodeLocation = SourceCodeLocation( + path = fileName, + // QueryOrigin tracks line numbers using a 1-indexed numbering scheme whereas SourceCodeLocation + // tracks them using a 0-indexed numbering scheme. + lineNumber = line.map(_ - 1), + columnNumber = startPosition, + endingLineNumber = None, + endingColumnNumber = None + ) +} + +object QueryOrigin { + + /** An empty QueryOrigin without any provenance information. */ + val empty: QueryOrigin = QueryOrigin() + + /** + * An exception that wraps [[QueryOrigin]] and lets us store it in errors as suppressed + * exceptions. + */ + private case class QueryOriginWrapper(origin: QueryOrigin) + extends Exception + with NoStackTrace + + implicit class ExceptionHelpers(t: Throwable) { + + /** + * Stores `origin` inside the given throwable using suppressed exceptions. + * + * We rely on suppressed exceptions since that lets us preserve the original exception class + * and type. + */ + def addOrigin(origin: QueryOrigin): Throwable = { + // Only try to add the query context if one has not already been added. 
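+      // The origin is attached as a suppressed `QueryOriginWrapper` exception, which preserves
+      // the original exception class while letting callers recover the origin later via
+      // `QueryOrigin.getOrigin(t)`. Illustrative usage (hypothetical origin values):
+      //   someThrowable.addOrigin(QueryOrigin(fileName = Some("pipeline_definitions.sql")))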
+ try { + // Do not add the origin again if one is already present. + // This also handles the case where the throwable is `null`. + if (getOrigin(t).isEmpty) { + t.addSuppressed(QueryOriginWrapper(origin)) + } + } catch { + case NonFatal(e) => + // logger.error("Failed to add pipeline context", e) + } + t + } + } + + /** Returns the [[QueryOrigin]] stored as a suppressed exception in the given throwable. + * + * @return Some(origin) if the origin is recorded as part of the given throwable, `None` + * otherwise. + */ + def getOrigin(t: Throwable): Option[QueryOrigin] = { + try { + // Wrap in an `Option(_)` first to handle `null` throwables. + Option(t).flatMap { ex => + ex.getSuppressed.collectFirst { + case QueryOriginWrapper(context) => context + } + } + } catch { + case NonFatal(e) => + // logger.error("Failed to get pipeline context", e) + None + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala new file mode 100644 index 0000000000000..9f05219c383dd --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.catalyst.TableIdentifier + +object ViewHelpers { + + /** Map of view identifier to corresponding unresolved flow */ + def persistedViewIdentifierToFlow(graph: DataflowGraph): Map[TableIdentifier, Flow] = { + graph.persistedViews.map { v => + require( + graph.flowsTo.get(v.identifier).isDefined, + s"No flows to view ${v.identifier} were found" + ) + val flowsToView = graph.flowsTo(v.identifier) + require( + flowsToView.size == 1, + s"Expected a single flow to the view, found ${flowsToView.size} flows to ${v.identifier}" + ) + (v.identifier, flowsToView.head) + }.toMap + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala new file mode 100644 index 0000000000000..348c6ecd344a6 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util + +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.classic.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.common.DatasetType +import org.apache.spark.sql.pipelines.util.{ + BatchReadOptions, + InputReadOptions, + SchemaInferenceUtils, + StreamingReadOptions +} +import org.apache.spark.sql.types.StructType + +/** An element in a [[DataflowGraph]]. */ +trait GraphElement { + + /** + * Contains provenance to tie back this GraphElement to the user code that defined it. + * + * This must be set when a [[GraphElement]] is directly created by some user code. + * Subsequently, this initial origin must be propagated as is without modification. + * If this [[GraphElement]] is copied or converted to a different type, then this origin must be + * copied as is. + */ + def origin: QueryOrigin + + protected def spark: SparkSession = SparkSession.getActiveSession.get + + /** Returns the unique identifier for this [[GraphElement]]. */ + def identifier: TableIdentifier + + /** + * Returns a user-visible name for the element. + */ + def displayName: String = identifier.unquotedString +} + +/** + * Specifies an input that can be referenced by another Dataset's query. + */ +trait Input extends GraphElement { + + /** + * Returns a [[DataFrame]] that is a result of loading data from this [[Input]]. + * @param readOptions Type of input. Used to determine streaming/batch + * @return Streaming or batch [[DataFrame]] of this Input's data. + */ + def load(readOptions: InputReadOptions): DataFrame +} + +/** Represents a node in a [[DataflowGraph]] that can be written to by a [[Flow]]. */ +sealed trait Output { + + /** + * Normalized storage location used for storing materializations for this [[Output]]. + * If [[None]], it means this [[Output]] has not been normalized yet. + */ + def normalizedPath: Option[String] + + /** Return whether the storage location for this [[Output]] has been normalized. */ + final def normalized: Boolean = normalizedPath.isDefined + + /** + * Return the normalized storage location for this [[Output]] and throw if the + * storage location has not been normalized. + */ + @throws[SparkException] + def path: String +} + +/** A type of [[Input]] where data is loaded from a table. */ +sealed trait TableInput extends Input { + + /** The user-specified schema for this table. */ + def specifiedSchema: Option[StructType] +} + +/** + * A table representing a materialized dataset in a [[DataflowGraph]]. + * + * @param identifier The identifier of this table within the graph. + * @param specifiedSchema The user-specified schema for this table. + * @param partitionCols What columns the table should be partitioned by when materialized. 
+ * @param normalizedPath Normalized storage location for the table based on the user-specified table + * path (if not defined, we will normalize a managed storage path for it). + * @param properties Table Properties to set in table metadata. + * @param sqlText For SQL-defined pipelines, the original string of the SELECT query. + * @param comment User-specified comment that can be placed on the table. + * @param isStreamingTableOpt if the table is a streaming table, will be None until we have resolved + * flows into table + */ +case class Table( + identifier: TableIdentifier, + specifiedSchema: Option[StructType], + partitionCols: Option[Seq[String]], + normalizedPath: Option[String], + properties: Map[String, String] = Map.empty, + sqlText: Option[String], + comment: Option[String], + baseOrigin: QueryOrigin, + isStreamingTableOpt: Option[Boolean], + format: Option[String] +) extends GraphElement + with Output + with TableInput { + + override val origin: QueryOrigin = baseOrigin.copy( + objectType = Some("table"), + objectName = Some(identifier.unquotedString) + ) + + // Load this table's data from underlying storage. + override def load(readOptions: InputReadOptions): DataFrame = { + try { + lazy val tableName = identifier.quotedString + + val df = readOptions match { + case sro: StreamingReadOptions => + spark.readStream.options(sro.userOptions).table(tableName) + case _: BatchReadOptions => + spark.read.table(tableName) + case _ => + throw new IllegalArgumentException("Unhandled `InputReadOptions` type when loading table") + } + + df + } catch { + case NonFatal(e) => throw LoadTableException(displayName, Option(e)) + } + } + + /** Returns the normalized storage location to this [[Table]]. */ + override def path: String = { + if (!normalized) { + throw GraphErrors.unresolvedTablePath(identifier) + } + normalizedPath.get + } + + /** + * Tell if a table is a streaming table or not. This property is not set until we have resolved + * the flows into the table. The exception reminds engineers that they cant call at random time. + */ + def isStreamingTable: Boolean = isStreamingTableOpt.getOrElse { + throw new IllegalStateException( + "Cannot identify whether the table is streaming table or not. You may need to resolve the " + + "flows into table." + ) + } + + /** + * Get the DatasetType of the table + */ + def datasetType: DatasetType = { + if (isStreamingTable) { + DatasetType.STREAMING_TABLE + } else { + DatasetType.MATERIALIZED_VIEW + } + } +} + +/** + * A type of [[TableInput]] that returns data from a specified schema or from the inferred + * [[Flow]]s that write to the table. + */ +case class VirtualTableInput( + identifier: TableIdentifier, + specifiedSchema: Option[StructType], + incomingFlowIdentifiers: Set[TableIdentifier], + availableFlows: Seq[ResolvedFlow] = Nil +) extends TableInput + with Logging { + override def origin: QueryOrigin = QueryOrigin() + + assert(availableFlows.forall(_.destinationIdentifier == identifier)) + override def load(readOptions: InputReadOptions): DataFrame = { + // Infer the schema for this virtual table + def getFinalSchema: StructType = { + specifiedSchema match { + // This is not a backing table, and we have a user-specified schema, so use it directly. + case Some(ss) => ss + // Otherwise infer the schema from a combination of the incoming flows and the + // user-specified schema, if provided. + case _ => + SchemaInferenceUtils.inferSchemaFromFlows(availableFlows, specifiedSchema) + } + } + + // create empty streaming/batch df based on input type. 
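+    // The returned DataFrame never carries rows; it only needs to expose the inferred schema so
+    // that downstream flows can be analyzed against it. Streaming reads are backed by an empty
+    // MemoryStream, batch reads by an empty local DataFrame.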
+ def createEmptyDF(schema: StructType): DataFrame = readOptions match { + case _: StreamingReadOptions => + MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark.sqlContext) + .toDF() + case _ => spark.createDataFrame(new util.ArrayList[Row](), schema) + } + + val df = createEmptyDF(getFinalSchema) + df + } +} + +/** + * Representing a view in the [[DataflowGraph]]. + */ +trait View extends GraphElement { + + /** Returns the unique identifier for this [[View]]. */ + val identifier: TableIdentifier + + /** Properties of this view */ + val properties: Map[String, String] + + /** (SQL-specific) The raw query that defines the [[View]]. */ + val sqlText: Option[String] + + /** User-specified comment that can be placed on the [[View]]. */ + val comment: Option[String] +} + +/** + * Representing a temporary [[View]] in a [[DataflowGraph]]. + * + * @param identifier The identifier of this view within the graph. + * @param properties Properties of the view + * @param sqlText Raw SQL query that defines the view. + * @param comment when defining a view + */ +case class TemporaryView( + identifier: TableIdentifier, + properties: Map[String, String], + sqlText: Option[String], + comment: Option[String], + origin: QueryOrigin +) extends View {} + +/** + * Representing a persisted [[View]] in a [[DataflowGraph]]. + * + * @param identifier The identifier of this view within the graph. + * @param properties Properties of the view + * @param sqlText Raw SQL query that defines the view. + * @param comment when defining a view + */ +case class PersistedView( + identifier: TableIdentifier, + properties: Map[String, String], + sqlText: Option[String], + comment: Option[String], + origin: QueryOrigin +) extends View {} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala new file mode 100644 index 0000000000000..6e3ef2c6a7da8 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.pipelines.Language +import org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions + +/** + * Generic options for a read of an input. + */ +sealed trait InputReadOptions { + // The language of the public API that called this function. + def apiLanguage: Language +} + +/** + * Options for a batch read of an input. + * + * @param apiLanguage The language of the public API that called this function. 
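+ *
+ * Inputs dispatch on the concrete [[InputReadOptions]] subtype to choose between a batch and a
+ * streaming read; a sketch mirroring `Table.load` in the graph package, where `name` and
+ * `spark` come from the enclosing input:
+ * {{{
+ *   readOptions match {
+ *     case sro: StreamingReadOptions => spark.readStream.options(sro.userOptions).table(name)
+ *     case _: BatchReadOptions => spark.read.table(name)
+ *   }
+ * }}}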
+ */ +final case class BatchReadOptions(apiLanguage: Language) extends InputReadOptions + +/** + * Options for a streaming read of an input. + * + * @param apiLanguage The language of the public API that called this function. + * @param userOptions Holds the user defined read options. + * @param droppedUserOptions Holds the options that were specified by the user but + * not actually used. This is a bug but we are preserving this behavior + * for now to avoid making a backwards incompatible change. + */ +final case class StreamingReadOptions( + apiLanguage: Language, + userOptions: CaseInsensitiveMap[String] = EmptyUserOptions, + droppedUserOptions: CaseInsensitiveMap[String] = EmptyUserOptions +) extends InputReadOptions + +object StreamingReadOptions { + val EmptyUserOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map()) +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala new file mode 100644 index 0000000000000..d5bd4887ce248 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.pipelines.common.DatasetType +import org.apache.spark.sql.pipelines.graph.{GraphElementTypeUtils, GraphErrors, ResolvedFlow} +import org.apache.spark.sql.types.{StructField, StructType} + + +object SchemaInferenceUtils { + + /** + * Given a set of flows that write to the same destination and possibly a user-specified schema, + * we infer the resulting schema. + * + * Returns an error if encountered during schema inference or merging the inferred schema with + * the user-specified one. 
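+   *
+   * Purely illustrative example (hypothetical flows):
+   * {{{
+   *   // flowA writes (id INT) and flowB writes (id INT, name STRING) to the same destination.
+   *   inferSchemaFromFlows(Seq(flowA, flowB), userSpecifiedSchema = None)
+   *   // => StructType with columns id INT and name STRING
+   * }}}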
+ */ + def inferSchemaFromFlows( + flows: Seq[ResolvedFlow], + userSpecifiedSchema: Option[StructType]): StructType = { + if (flows.isEmpty) { + return userSpecifiedSchema.getOrElse(new StructType()) + } + + require( + flows.forall(_.destinationIdentifier == flows.head.destinationIdentifier), + "Expected all flows to have the same destination" + ) + + val inferredSchema = flows.map(_.schema).fold(new StructType()) { (schemaSoFar, schema) => + try { + SchemaMergingUtils.mergeSchemas(schemaSoFar, schema) + } catch { + case NonFatal(e) => + throw GraphErrors.unableToInferSchemaError( + flows.head.destinationIdentifier, + schemaSoFar, + schema, + cause = Option(e) + ) + } + } + + val identifier = flows.head.destinationIdentifier + val datasetType = GraphElementTypeUtils.getDatasetTypeForMaterializedViewOrStreamingTable(flows) + // We merge the inferred schema with the user-specified schema to pick up any schema metadata + // that is provided by the user, e.g., comments or column masks. + mergeInferredAndUserSchemasIfNeeded( + identifier, + datasetType, + inferredSchema, + userSpecifiedSchema + ) + } + + private def mergeInferredAndUserSchemasIfNeeded( + tableIdentifier: TableIdentifier, + datasetType: DatasetType, + inferredSchema: StructType, + userSpecifiedSchema: Option[StructType]): StructType = { + userSpecifiedSchema match { + case Some(userSpecifiedSchema) => + try { + // Merge the inferred schema with the user-provided schema hint + SchemaMergingUtils.mergeSchemas(userSpecifiedSchema, inferredSchema) + } catch { + case NonFatal(e) => + throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError( + tableIdentifier, + datasetType, + userSpecifiedSchema, + inferredSchema, + cause = Option(e) + ) + } + case None => inferredSchema + } + } + + /** + * Determines the column changes needed to transform the current schema into the target schema. + * + * This function compares the current schema with the target schema and produces a sequence of + * TableChange objects representing: + * 1. New columns that need to be added + * 2. 
Existing columns that need type updates + * + * @param currentSchema The current schema of the table + * @param targetSchema The target schema that we want the table to have + * @return A sequence of TableChange objects representing the necessary changes + */ + def diffSchemas(currentSchema: StructType, targetSchema: StructType): Seq[TableChange] = { + val changes = scala.collection.mutable.ArrayBuffer.empty[TableChange] + + // Helper function to get a map of field name to field + def getFieldMap(schema: StructType): Map[String, StructField] = { + schema.fields.map(field => field.name -> field).toMap + } + + val currentFields = getFieldMap(currentSchema) + val targetFields = getFieldMap(targetSchema) + + // Find columns to add (in target but not in current) + val columnsToAdd = targetFields.keySet.diff(currentFields.keySet) + columnsToAdd.foreach { columnName => + val field = targetFields(columnName) + changes += TableChange.addColumn( + Array(columnName), + field.dataType, + field.nullable, + field.getComment().orNull + ) + } + + // Find columns to delete (in current but not in target) + val columnsToDelete = currentFields.keySet.diff(targetFields.keySet) + columnsToDelete.foreach { columnName => + changes += TableChange.deleteColumn(Array(columnName), false) + } + + // Find columns with type changes (in both but with different types) + val commonColumns = currentFields.keySet.intersect(targetFields.keySet) + commonColumns.foreach { columnName => + val currentField = currentFields(columnName) + val targetField = targetFields(columnName) + + // If data types are different, add a type update change + if (currentField.dataType != targetField.dataType) { + changes += TableChange.updateColumnType(Array(columnName), targetField.dataType) + } + + // If nullability is different, add a nullability update change + if (currentField.nullable != targetField.nullable) { + changes += TableChange.updateColumnNullability(Array(columnName), targetField.nullable) + } + + // If comments are different, add a comment update change + val currentComment = currentField.getComment().orNull + val targetComment = targetField.getComment().orNull + if (currentComment != targetComment) { + changes += TableChange.updateColumnComment(Array(columnName), targetComment) + } + } + + changes.toSeq + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala new file mode 100644 index 0000000000000..d15e7ac6425cc --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.pipelines.util
+
+import org.apache.spark.sql.types.StructType
+
+/** Helpers for merging the schemas written to a single dataset by multiple flows. */
+object SchemaMergingUtils {
+  /**
+   * Merges `dataSchema` into `tableSchema`, failing if the two schemas contain incompatible
+   * types for the same column.
+   */
+  def mergeSchemas(tableSchema: StructType, dataSchema: StructType): StructType = {
+    StructType.merge(tableSchema, dataSchema).asInstanceOf[StructType]
+  }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
new file mode 100644
index 0000000000000..091b77424b3b4
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
@@ -0,0 +1,466 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+/**
+ * Test suite for converting one or more [[Pipeline]]s into a connected [[DataflowGraph]]. These
+ * examples all register successfully, but each contains a logical error that should be detected
+ * when the graph is connected and thrown when validate() is called.
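+ *
+ * The tests below follow a common pattern, sketched here with the helpers from
+ * [[TestGraphRegistrationContext]] and [[PipelineTest]]:
+ * {{{
+ *   val dfg = new TestGraphRegistrationContext(spark) {
+ *     registerView("b", query = readFlowFunc("a")) // "a" is never defined
+ *   }.resolveToDataflowGraph()
+ *   assert(!dfg.resolved)
+ *   intercept[UnresolvedPipelineException] { dfg.validate() }
+ * }}}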
+ */ +class ConnectInvalidPipelineSuite extends PipelineTest { + + import originalSpark.implicits._ + test("Missing source") { + class P extends TestGraphRegistrationContext(spark) { + registerView("b", query = readFlowFunc("a")) + } + + val dfg = new P().resolveToDataflowGraph() + assert(!dfg.resolved, "Pipeline should not have resolved properly") + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert(ex.getMessage.contains("Failed to resolve flows in the pipeline")) + assertAnalysisException( + ex.directFailures(fullyQualifiedIdentifier("b", isView = true)), + "TABLE_OR_VIEW_NOT_FOUND" + ) + } + + test("Correctly differentiate between upstream and downstream errors") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(spark.range(5).toDF())) + registerView("b", query = readFlowFunc("nonExistentFlow")) + registerView("c", query = readFlowFunc("b")) + registerView("d", query = dfFlowFunc(spark.range(5).toDF())) + registerView("e", query = sqlFlowFunc(spark, "SELECT nonExistentColumn FROM RANGE(5)")) + registerView("f", query = readFlowFunc("e")) + } + + val dfg = new P().resolveToDataflowGraph() + assert(!dfg.resolved, "Pipeline should not have resolved properly") + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert(ex.getMessage.contains("Failed to resolve flows in the pipeline")) + assert( + ex.getMessage.contains( + s"Flows with errors: " + + s"${fullyQualifiedIdentifier("b", isView = true).unquotedString}," + + s" ${fullyQualifiedIdentifier("e", isView = true).unquotedString}" + ) + ) + assert( + ex.getMessage.contains( + s"Flows that failed due to upstream errors: " + + s"${fullyQualifiedIdentifier("c", isView = true).unquotedString}, " + + s"${fullyQualifiedIdentifier("f", isView = true).unquotedString}" + ) + ) + assert( + ex.directFailures.keySet == Set( + fullyQualifiedIdentifier("b", isView = true), + fullyQualifiedIdentifier("e", isView = true) + ) + ) + assert( + ex.downstreamFailures.keySet == Set( + fullyQualifiedIdentifier("c", isView = true), + fullyQualifiedIdentifier("f", isView = true) + ) + ) + assertAnalysisException( + ex.directFailures(fullyQualifiedIdentifier("b", isView = true)), + "TABLE_OR_VIEW_NOT_FOUND" + ) + assert( + ex.directFailures(fullyQualifiedIdentifier("e", isView = true)) + .isInstanceOf[AnalysisException] + ) + assert( + ex.directFailures(fullyQualifiedIdentifier("e", isView = true)) + .getMessage + .contains("nonExistentColumn") + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("c", isView = true)) + .isInstanceOf[UnresolvedDatasetException] + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains( + s"Failed to read dataset " + + s"'${fullyQualifiedIdentifier("b", isView = true).unquotedString}'. " + + s"Dataset is defined in the pipeline but could not be resolved" + ) + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("f", isView = true)) + .isInstanceOf[UnresolvedDatasetException] + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("f", isView = true)) + .getMessage + .contains( + s"Failed to read dataset " + + s"'${fullyQualifiedIdentifier("e", isView = true).unquotedString}'. 
" + + s"Dataset is defined in the pipeline but could not be resolved" + ) + ) + } + + test("correctly identify direct and downstream errors for multi-flow pipelines") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a") + registerFlow("a", "a", dfFlowFunc(spark.range(5).toDF())) + registerFlow("a", "a_2", sqlFlowFunc(spark, "SELECT non_existent_col FROM RANGE(5)")) + registerTable("b", query = Option(readFlowFunc("a"))) + } + val ex = intercept[UnresolvedPipelineException] { new P().resolveToDataflowGraph().validate() } + assert(ex.directFailures.keySet == Set(fullyQualifiedIdentifier("a_2"))) + assert(ex.downstreamFailures.keySet == Set(fullyQualifiedIdentifier("b"))) + + } + + test("Missing attribute in the schema") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("z"))) + registerView("b", query = sqlFlowFunc(spark, "SELECT x FROM a")) + } + + val dfg = new P().resolveToDataflowGraph() + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + }.directFailures(fullyQualifiedIdentifier("b", isView = true)).getMessage + verifyUnresolveColumnError(ex, "x", Seq("z")) + } + + test("Joining on a column with different names") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq("a", "b", "c").toDF("y"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a JOIN b USING (x)")) + } + + val dfg = new P().resolveToDataflowGraph() + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert( + ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains("USING column `x` cannot be resolved on the right side") + ) + } + + test("Writing to one table by unioning flows with different schemas") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq(true, false).toDF("x"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT x FROM a UNION SELECT x FROM b")) + } + + val dfg = new P().resolveToDataflowGraph() + assert(!dfg.resolved) + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + // This error message has changed between DBR versions. 
+ assert( + ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains("compatible column types") || + ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains("Failed to merge incompatible data types") + ) + } + + test("Self reference") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = readFlowFunc("a")) + } + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset == fullyQualifiedIdentifier("a", isView = true)) + assert(e.downstreamTable == fullyQualifiedIdentifier("a", isView = true)) + } + + test("Cyclic graph - simple") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = readFlowFunc("b")) + registerView("b", query = readFlowFunc("a")) + } + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + val cycle = Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ) + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d")) + registerView("c", query = readFlowFunc("b")) + registerView("d", query = readFlowFunc("c")) + } + val cycle = + Set( + fullyQualifiedIdentifier("b", isView = true), + fullyQualifiedIdentifier("c", isView = true), + fullyQualifiedIdentifier("d", isView = true) + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph with materialized nodes") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x")))) + registerTable( + "b", + query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d")) + ) + registerTable("c", query = Option(readFlowFunc("b"))) + registerTable("d", query = Option(readFlowFunc("c"))) + } + val cycle = + Set( + fullyQualifiedIdentifier("b"), + fullyQualifiedIdentifier("c"), + fullyQualifiedIdentifier("d") + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph - second query makes it cyclic") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x")))) + registerTable("b") + registerFlow("b", "b", readFlowFunc("a")) + registerFlow("b", "b2", readFlowFunc("d")) + registerTable("c", query = Option(readFlowFunc("b"))) + registerTable("d", query = Option(readFlowFunc("c"))) + } + val cycle = + Set( + fullyQualifiedIdentifier("b"), + fullyQualifiedIdentifier("c"), + fullyQualifiedIdentifier("d") + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph - all named queries") { 
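+    // Every flow below is registered under an explicit backticked flow name rather than the
+    // destination table name; cycle detection still operates on the destination tables, so the
+    // cycle among b, c and d is reported as a CircularDependencyException.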
+ class P extends TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x")))) + registerTable("b") + registerFlow("b", "`b-name`", sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d")) + registerTable("c") + registerFlow("c", "`c-name`", readFlowFunc("b")) + registerTable("d") + registerFlow("d", "`d-name`", readFlowFunc("c")) + } + val cycle = + Set( + fullyQualifiedIdentifier("b"), + fullyQualifiedIdentifier("c"), + fullyQualifiedIdentifier("d") + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("view-table conf conflict") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "a-val")) + registerTable("b", query = Option(readFlowFunc("a")), sqlConf = Map("x" -> "b-val")) + } + val ex = intercept[AnalysisException] { p.resolveToDataflowGraph() } + assert( + ex.getMessage.contains( + s"Found duplicate sql conf for dataset " + + s"'${fullyQualifiedIdentifier("b").unquotedString}':" + ) + ) + assert( + ex.getMessage.contains( + s"'x' is defined by both " + + s"'${fullyQualifiedIdentifier("a", isView = true).unquotedString}' " + + s"and '${fullyQualifiedIdentifier("b").unquotedString}'" + ) + ) + } + + test("view-view conf conflict") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "a-val")) + registerView("b", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "b-val")) + registerTable( + "c", + query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM b")), + sqlConf = Map("y" -> "c-val") + ) + } + val ex = intercept[AnalysisException] { p.resolveToDataflowGraph() } + assert( + ex.getMessage.contains( + s"Found duplicate sql conf for dataset " + + s"'${fullyQualifiedIdentifier("c").unquotedString}':" + ) + ) + assert( + ex.getMessage.contains( + s"'x' is defined by both " + + s"'${fullyQualifiedIdentifier("a", isView = true).unquotedString}' " + + s"and '${fullyQualifiedIdentifier("b", isView = true).unquotedString}'" + ) + ) + } + + test("reading a complete view incrementally") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF())) + registerTable("b", query = Option(readStreamFlowFunc("a"))) + } + val ex = intercept[UnresolvedPipelineException] { p.resolveToDataflowGraph().validate() } + assert( + ex.directFailures(fullyQualifiedIdentifier("b")) + .getMessage + .contains( + s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString}" + + s" is not a streaming view and must be referenced using read." 
+ ) + ) + } + + test("reading an incremental view completely") { + val p = new TestGraphRegistrationContext(spark) { + val mem = MemoryStream[Int] + mem.addData(1) + registerView("a", query = dfFlowFunc(mem.toDF())) + registerTable("b", query = Option(readFlowFunc("a"))) + } + val ex = intercept[UnresolvedPipelineException] { p.resolveToDataflowGraph().validate() } + assert( + ex.directFailures(fullyQualifiedIdentifier("b")) + .getMessage + .contains( + s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString} " + + s"is a streaming view and must be referenced using readStream" + ) + ) + } + + test("Inferred schema that isn't a subset of user-specified schema") { + val graph1 = new TestGraphRegistrationContext(spark) { + registerTable( + "a", + query = Option(dfFlowFunc(Seq(1, 2).toDF("incorrect-col-name"))), + specifiedSchema = Option(new StructType().add("x", IntegerType)) + ) + }.resolveToDataflowGraph() + val ex1 = intercept[AnalysisException] { graph1.validate() } + assert( + ex1.getMessage.contains( + s"'${fullyQualifiedIdentifier("a").unquotedString}' " + + s"has a user-specified schema that is incompatible" + ) + ) + assert(ex1.getMessage.contains("incorrect-col-name")) + + val graph2 = new TestGraphRegistrationContext(spark) { + registerTable("a", specifiedSchema = Option(new StructType().add("x", IntegerType))) + registerFlow("a", "a", query = dfFlowFunc(Seq(true, false).toDF("x")), once = true) + }.resolveToDataflowGraph() + val ex2 = intercept[AnalysisException] { graph2.validate() } + assert( + ex2.getMessage.contains( + s"'${fullyQualifiedIdentifier("a").unquotedString}' " + + s"has a user-specified schema that is incompatible" + ) + ) + assert(ex2.getMessage.contains("boolean") && ex2.getMessage.contains("integer")) + + val streamingTableHint = "please full refresh" + assert(!ex1.getMessage.contains(streamingTableHint)) + assert(ex2.getMessage.contains(streamingTableHint)) + } + +// test("Sink has wrong type of data coming in") { +// class P extends TestPipelineDefinition { +// val mem = MemoryStream[Int] +// mem.addData(1, 2, 3) +// createView("a", query =(mem.toDF()) +// createView("b") , query =(readStream("a")) +// .sink( +// func = df => df.writeStream, +// userValidation = df => +// if (df.schema != new StructType().add("value", StringType, false)) { +// throw new IllegalStateException() +// } +// ) +// } +// val p = DataflowGraph(new P).connect +// intercept[IllegalStateException] { p.validate() } +// } + +// test("Flow to sink cannot be resolved") { +// class P extends TestPipelineDefinition { +// createView("a", query =(() => sys.error("not available")) +// createView("b", query =(readStream("a")).sink(df => df.writeStream) +// } +// val p = DataflowGraph(new P()).connect +// assert(!p.resolved) +// val ex = intercept[UnresolvedPipelineException] { p.validate() } +// assert(ex.directFailures.keySet == Set("a")) +// assert(ex.downstreamFailures.keySet == Set("b")) +// assert(ex.directFailures("a").getMessage.contains("not available")) +// } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala new file mode 100644 index 0000000000000..5f53ab84b8b7a --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Test suite for converting a [[PipelineDefinition]]s into a connected [[DataflowGraph]]. These + * examples are all semantically correct and logically correct and connect should not result in any + * errors. + */ +class ConnectValidPipelineSuite extends PipelineTest { + + import originalSpark.implicits._ + + test("Extra simple") { + class P extends TestGraphRegistrationContext(spark) { + registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF("y"))) + } + val p = new P().resolveToDataflowGraph() + val outSchema = new StructType().add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), outSchema) + } + + test("Simple") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = sqlFlowFunc(spark, "SELECT x as y FROM a")) + } + val p = new P().resolveToDataflowGraph() + verifyFlowSchema( + p, + fullyQualifiedIdentifier("a", isView = true), + new StructType().add("x", IntegerType, false) + ) + val outSchema = new StructType().add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), outSchema) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Dependencies") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT y as z FROM b")) + registerView("b", query = sqlFlowFunc(spark, "SELECT x as y FROM a")) + } + val p = new P().resolveToDataflowGraph() + val schemaAB = new StructType().add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), schemaAB) + val schemaBC = new StructType().add("z", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaBC) + } + + test("Multi-hop schema merging") { + class P extends TestGraphRegistrationContext(spark) { + +// import org.apache.spark.sql.functions._ fixme + + registerView( + "b", + query = sqlFlowFunc(spark, """SELECT * FROM VALUES ((1)) OUTER JOIN d ON false""") + ) + registerView("e", query = readFlowFunc("b")) + registerView("d", query = dfFlowFunc(Seq(1).toDF("y"))) + } + val p 
= new P().resolveToDataflowGraph() + val schemaE = new StructType().add("col1", IntegerType, false).add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), schemaE) + verifyFlowSchema(p, fullyQualifiedIdentifier("e", isView = true), schemaE) + } + + test("Cross product join merges schema") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq(4, 5, 6).toDF("y"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a CROSS JOIN b")) + } + val p = new P().resolveToDataflowGraph() + val schemaC = new StructType().add("x", IntegerType, false).add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaC) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Real join merges schema") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq((1, "a"), (2, "b"), (3, "c")).toDF("x", "y"))) + registerView("b", query = dfFlowFunc(Seq((2, "m"), (3, "n"), (4, "o")).toDF("x", "z"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a JOIN b USING (x)")) + } + val p = new P().resolveToDataflowGraph() + val schemaC = new StructType() + .add("x", IntegerType, false) + .add("y", StringType) + .add("z", StringType) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaC) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Union of streaming and batch Dataframes") { + class P extends TestGraphRegistrationContext(spark) { + val ints = MemoryStream[Int] + ints.addData(1, 2, 3, 4) + registerView("a", query = dfFlowFunc(ints.toDF())) + registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF())) + registerView( + "c", + query = FlowAnalysis.createFlowFunctionFromLogicalPlan( + Union( + Seq( + UnresolvedRelation( + TableIdentifier("a"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ), + UnresolvedRelation(TableIdentifier("b")) + ) + ) + ) + ) + } + + val p = new P().resolveToDataflowGraph() + verifyFlowSchema( + p, + fullyQualifiedIdentifier("c", isView = true), + new StructType().add("value", IntegerType, false) + ) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Union of two streaming Dataframes") { + class P extends TestGraphRegistrationContext(spark) { + val ints1 = MemoryStream[Int] + ints1.addData(1, 2, 3, 4) + val ints2 = MemoryStream[Int] + ints2.addData(1, 2, 3, 4) + registerView("a", query = dfFlowFunc(ints1.toDF())) + registerView("b", query = dfFlowFunc(ints2.toDF())) + registerView( + "c", + query = FlowAnalysis.createFlowFunctionFromLogicalPlan( + Union( + Seq( + UnresolvedRelation( + TableIdentifier("a"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ), + UnresolvedRelation( + TableIdentifier("b"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ) + ) + ) + ) + ) + } 
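+    // Both legs of the union are built as UnresolvedRelation with isStreaming = true, roughly
+    // the logical-plan counterpart of referencing the views via readStream rather than read.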
+ + val p = new P().resolveToDataflowGraph() + verifyFlowSchema( + p, + fullyQualifiedIdentifier("c", isView = true), + new StructType().add("value", IntegerType, false) + ) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("MultipleInputs") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq(4, 5, 6).toDF("y"))) + registerView( + "c", + query = sqlFlowFunc(spark, "SELECT x AS z FROM a UNION SELECT y AS z FROM b") + ) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("z", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schema) + } + +// test("connect from actual table") { +// val rootDir = Files.createTempDirectory("evaluation").toString +// +// def tablePath(tableName: String): String = s"$rootDir/$tableName" +// +// Seq(3, 4).toDF().write.format("delta").save(tablePath("b")) +// val p = legacyNormalizeTablePaths( +// DataflowGraph(new TestPipelineDefinition { +// createView("a", query = () => Seq(3, 4).toDF("incorrect-col-name")) +// createTable("b", query = read("a"), path = Option(tablePath("b"))) +// createView("c", query = read("b")) +// createView("d", query = () => spark.readStream.table("b")) +// }), +// dbfsRoot = rootDir +// ).connect.validate() +// assert(p.flowByName("c").df.schema == new StructType().add("value", IntegerType)) +// assert(p.flowByName("d").df.schema == new StructType().add("value", IntegerType)) +// } +// +// test("connect from source") { +// val rootDir = Files.createTempDirectory("evaluation").toString +// val inputPath = rootDir + "/input" +// val g = legacyNormalizeTablePaths( +// DataflowGraph(new TestPipelineDefinition { +// createView("a", query = () => spark.spark.table.format("delta").load(inputPath)) +// createTable("b", query = read("a")).expectOrFail("blah", $"value" > 0) +// }), +// rootDir +// ) +// Seq(1).toDF().write.format("delta").save(inputPath) +// Seq(3).toDF().write.format("delta").save(g.tableByName("b").path) +// val graph1 = g.virtualize.connect.validate() +// assert(graph1.resolved) +// assert(graph1.flowByName("b").df.schema == new StructType().add("value", IntegerType)) +// +// val graph2 = g.connect.validate() +// assert(graph2.resolved) +// assert(graph2.flowByName("b").df.schema == new StructType().add("value", IntegerType)) +// } + + test("Connect retains and fuses confs") { + // a -> b \ + // d + // c / + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF("x")), Map("a" -> "a-val")) + registerView("b", query = readFlowFunc("a"), Map("b" -> "b-val")) + registerView("c", query = dfFlowFunc(Seq(2).toDF("x")), Map("c" -> "c-val")) + registerTable( + "d", + query = Option(sqlFlowFunc(spark, "SELECT * FROM b UNION SELECT * FROM c")), + Map("d" -> "d-val") + ) + } + val graph = p.resolveToDataflowGraph() + assert( + graph + .flow(fullyQualifiedIdentifier("d", isView = false)) + .sqlConf == Map("a" -> "a-val", "b" -> "b-val", "c" -> "c-val", "d" -> "d-val") + ) + } + + test("Confs aren't fused past materialization points") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF("x")), Map("a" -> "a-val")) + registerTable("b", query = Option(readFlowFunc("a")), 
Map("b" -> "b-val")) + registerView("c", query = dfFlowFunc(Seq(2).toDF("x")), sqlConf = Map("c" -> "c-val")) + registerTable( + "d", + query = Option(sqlFlowFunc(spark, "SELECT * FROM b UNION SELECT * FROM c")), + Map("d" -> "d-val") + ) + } + val graph = p.resolveToDataflowGraph() + assert(graph.flow(fullyQualifiedIdentifier("a", isView = true)).sqlConf == Map("a" -> "a-val")) + assert( + graph + .flow(fullyQualifiedIdentifier("b")) + .sqlConf == Map("a" -> "a-val", "b" -> "b-val") + ) + assert(graph.flow(fullyQualifiedIdentifier("c", isView = true)).sqlConf == Map("c" -> "c-val")) + assert( + graph + .flow(fullyQualifiedIdentifier("d")) + .sqlConf == Map("c" -> "c-val", "d" -> "d-val") + ) + } + + test("Setting the same conf with the same value is totally cool") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")), Map("key" -> "val")) + registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")), Map("key" -> "val")) + registerTable( + "c", + query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM b")), + Map("key" -> "val") + ) + } + val graph = p.resolveToDataflowGraph() + assert(graph.flow(fullyQualifiedIdentifier("c")).sqlConf == Map("key" -> "val")) + } + + test("Named query only") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerTable("b") + registerFlow("b", "`b-query`", readFlowFunc("a")) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("x", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("a", isView = true), schema) + verifyFlowSchema(p, fullyQualifiedIdentifier("b-query"), schema) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b-query")).inputs == Set( + fullyQualifiedIdentifier("a", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Default query and named query") { + class P extends TestGraphRegistrationContext(spark) { + val mem = MemoryStream[Int] + registerView("a", query = dfFlowFunc(mem.toDF())) + registerTable("b") + registerFlow("b", "b", dfFlowFunc(mem.toDF().select($"value" as "y"))) + registerFlow("b", "b2", readStreamFlowFunc("a")) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("value", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("a", isView = true), schema) + verifyFlowSchema(p, fullyQualifiedIdentifier("b2"), schema) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b2")).inputs == Set( + fullyQualifiedIdentifier("a", isView = true) + ), + "Flow did not have the expected inputs" + ) + verifyFlowSchema( + p, + fullyQualifiedIdentifier("b"), + new StructType().add("y", IntegerType, false) + ) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b")).inputs == Set.empty, + "Flow did not have the expected inputs" + ) + } + + test("Multi-query table with 2 complete queries") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a") + registerFlow("a", "a", query = dfFlowFunc(spark.range(5).toDF())) + registerFlow("a", "a2", query = dfFlowFunc(spark.range(6).toDF())) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("id", LongType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("a"), schema) + } + + test("Correct types of flows after connection") { + val graph = new TestGraphRegistrationContext(spark) { + val mem = MemoryStream[Int] + mem.addData(1, 2) + registerView("complete-view", query = 
dfFlowFunc(Seq(1, 2).toDF("x"))) + registerView("incremental-view", query = dfFlowFunc(mem.toDF())) + registerTable("`complete-table`", query = Option(readFlowFunc("complete-view"))) + registerTable("`incremental-table`") + registerFlow( + "`incremental-table`", + "`incremental-table`", + FlowAnalysis.createFlowFunctionFromLogicalPlan( + UnresolvedRelation( + TableIdentifier("incremental-view"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ) + ) + ) + registerFlow( + "`incremental-table`", + "`append-once`", + dfFlowFunc(Seq(1, 2).toDF("x")), + once = true + ) + }.resolveToDataflowGraph() + + assert( + graph + .flow(fullyQualifiedIdentifier("complete-view", isView = true)) + .isInstanceOf[CompleteFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("incremental-view", isView = true)) + .isInstanceOf[StreamingFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("complete-table")) + .isInstanceOf[CompleteFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("incremental-table")) + .isInstanceOf[StreamingFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("append-once")) + .isInstanceOf[AppendOnceFlow] + ) + } + + test("Pipeline level default spark confs are applied with correct precedence") { + val P = new TestGraphRegistrationContext( + spark, + Map("default.conf" -> "value") + ) { + registerTable( + "a", + query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x"))), + sqlConf = Map("other.conf" -> "value") + ) + registerTable( + "b", + query = Option(sqlFlowFunc(spark, "SELECT x as y FROM a")), + sqlConf = Map("default.conf" -> "other-value") + ) + } + val p = P.resolveToDataflowGraph() + + assert( + p.flow(fullyQualifiedIdentifier("a")).sqlConf == Map( + "default.conf" -> "value", + "other.conf" -> "value" + ) + ) + + assert( + p.flow(fullyQualifiedIdentifier("b")).sqlConf == Map( + "default.conf" -> "other-value" + ) + ) + } + +// test("Dataset names are case sensitive and can be spark.table correctly") { +// class P extends TestPipelineDefinition { +// createView("A", query = dfFlowFunc(spark.range(5)) +// createView("a", query = readFlowFunc("A").filter("id = 0").toDF("value")) +// createView("b", query = readFlowFunc("a")) +// } +// +// val p = new P().resolveToDataflowGraph() +// verifyFlowSchema(p, "b", new StructType().add("value", LongType, nullable = false)) +// } + +// test("Dataset names are case sensitive and will not spark.table in a case insensitive way") { +// class P extends TestPipelineDefinition { +// createView("A", query = () => spark.range(5).toDF()) +// createView("b", query = sql(spark, "SELECT x as y FROM a"))) +// } +// +// val dfg = DataflowGraph(new P).connect +// assert(!dfg.resolved, "Pipeline should not have resolved properly") +// val ex = intercept[UnresolvedPipelineException] { +// dfg.validate() +// } +// assert(ex.getMessage.contains("Failed to resolve flows in the pipeline")) +// assertSparkException( +// ex.directFailures("b"), +// "DATASET_NOT_DEFINED", +// Map("datasetName" -> "a") +// ) +// } + +// test("Verify duplicate table names due to lower-casing is caught") { +// class P extends TestPipelineDefinition { +// createTable("A", query = Option(() => spark.range(5).toDF())) +// createTable("a", query = read("A").filter("id = 0").toDF("value")) +// } +// val ex = intercept[AnalysisException] { +// (new P).resolveToDataflowGraph(shouldLowerCaseNames = true) +// } +// assert(ex.getMessage.contains("Found duplicate table")) +// } +// +// test("Verify duplicate view names due to lower-casing is 
caught") { +// class P extends TestPipelineDefinition { +// createTable("A", query = () => spark.range(5)) +// createView("a", query = read("A").filter("id = 0").toDF("value")) +// } +// val ex = intercept[AnalysisException] { +// (new P).resolveToDataflowGraph(shouldLowerCaseNames = true) +// } +// assert(ex.getMessage.contains("Found duplicate dataset")) +// } + + /** Verifies the [[DataflowGraph]] has the specified [[Flow]] with the specified schema. */ + private def verifyFlowSchema( + pipeline: DataflowGraph, + identifier: TableIdentifier, + expected: StructType): Unit = { + assert( + pipeline.flow.contains(identifier), + s"Flow ${identifier.unquotedString} not found," + + s" all flow names: ${pipeline.flow.keys.map(_.unquotedString)}" + ) + assert( + pipeline.resolvedFlow.contains(identifier), + s"Flow ${identifier.unquotedString} has not been resolved" + ) + assert( + pipeline.resolvedFlow(identifier).schema == expected, + s"Flow ${identifier.unquotedString} has the wrong schema" + ) + } + +// test("external sink") { +// SparkSessionUtils.withSQLConf( +// spark, +// ("pipelines.externalSink.enabled", true.toString) +// ) { +// class P extends TestPipelineDefinition { +// val mem = MemoryStream[Int] +// mem.addData(1, 2) +// createView("a", query = mem.toDF().select($"value" as "x")) +// createExternalSink("sink_a", format = "memory") +// createSinkForTopLevelFlow("sink_a", query = "sink_flow", readStream("a")) +// } +// val g = new P().resolveToDataflowGraph() +// g.validate() +// assert(g.resolved) +// assert(g.sinkByName("sink_a").isInstanceOf[ExternalSink]) +// val externalSink = g.sinkByName("sink_a").asInstanceOf[ExternalSink] +// assert(externalSink.format == "memory") +// assert(g.flowByName("sink_flow").isInstanceOf[ExternalSinkFlow]) +// } +// } +// +// test("Writing from view to sink") { +// class P extends TestPipelineDefinition { +// val mem = MemoryStream[Int] +// mem.addData(1, 2, 3) +// createView("a", query = mem.toDF()) +// createView("b") +// .query(readStream("a")) +// .sink( +// func = df => df.writeStream, +// userValidation = +// df => assert(df.schema == new StructType().add("value", IntegerType, false)) +// ) +// } +// val p = DataflowGraph(new P).connect +// assert(p.resolved) +// } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala new file mode 100644 index 0000000000000..d13fd1902621e --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala @@ -0,0 +1,643 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.utils + +import java.io.{BufferedReader, File, FileNotFoundException, InputStreamReader} +import java.nio.file.{Files, Paths} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Try} +import scala.util.control.NonFatal + +import org.scalactic.source +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag} +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, QueryTest, Row, TypedColumn} +import org.apache.spark.sql.SparkSession.{clearActiveSession, setActiveSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession, SQLContext} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, createTempDir} + +abstract class PipelineTest + extends SparkFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach + with Matchers +// with SQLImplicits + with SparkErrorTestMixin + with TargetCatalogAndSchemaMixin + with Logging { + + final protected val storageRoot = createTempDir() + + var spark: SparkSession = createAndInitializeSpark() + val originalSpark: SparkSession = spark.cloneSession() + + implicit def sqlContext: SQLContext = spark.sqlContext + def sql(text: String): DataFrame = spark.sql(text) + + /** + * Spark confs for [[originalSpark]]. Spark confs set here will be the default spark confs for + * all spark sessions created in tests. + */ + protected def sparkConf: SparkConf = { + var conf = new SparkConf() + .set("spark.sql.shuffle.partitions", "2") + .set("spark.sql.session.timeZone", "UTC") + + if (schemaInPipelineSpec.isDefined) { + conf = conf.set("pipelines.schema", schemaInPipelineSpec.get) + } + + if (Option(System.getenv("ENABLE_SPARK_UI")).exists(s => java.lang.Boolean.valueOf(s))) { + conf = conf.set("spark.ui.enabled", "true") + } + conf + } + + /** Returns the dataset name in the event log. */ + protected def eventLogName( + name: String, + catalog: Option[String] = catalogInPipelineSpec, + schema: Option[String] = schemaInPipelineSpec, + isView: Boolean = false + ): String = { + fullyQualifiedIdentifier(name, catalog, schema, isView).unquotedString + } + + /** Returns the fully qualified identifier. */ + protected def fullyQualifiedIdentifier( + name: String, + catalog: Option[String] = catalogInPipelineSpec, + schema: Option[String] = schemaInPipelineSpec, + isView: Boolean = false + ): TableIdentifier = { + if (isView) { + TableIdentifier(name) + } else { + TableIdentifier( + catalog = catalog, + database = schema, + table = name + ) + } + } + +// /** Returns the [[PipelineApiConf]] constructed from the current spark session */ +// def pipelineApiConf: PipelineApiConf = PipelineApiConf.instance + + /** + * Runs the given function with the given spark conf, and resets the conf after the function + * completes. 
+ */ + def withSparkConfs[T](confs: Map[String, String])(f: => T): T = { + val originalConfs = confs.keys.map(k => k -> spark.conf.getOption(k)).toMap + confs.foreach { case (k, v) => spark.conf.set(k, v) } + try f + finally originalConfs.foreach { + case (k, v) => + v match { + case Some(v) => spark.conf.set(k, v) + case None => spark.conf.unset(k) + } + } + } + + /** + * This exists temporarily for compatibility with tests that become invalid when multiple + * executors are available. + */ + protected def master = "local[*]" + + /** Creates and returns a initialized spark session. */ + def createAndInitializeSpark(): SparkSession = { + val newSparkSession = SparkSession + .builder() + .config(sparkConf) + .master(master) + .getOrCreate() + newSparkSession + } + + /** Set up the spark session before each test. */ + protected def initializeSparkBeforeEachTest(): Unit = { + clearActiveSession() + spark = originalSpark.newSession() + setActiveSession(spark) + } + + override def beforeEach(): Unit = { + super.beforeEach() + initializeSparkBeforeEachTest() + cleanupMetastore(spark) + (catalogInPipelineSpec, schemaInPipelineSpec) match { + case (Some(catalog), Some(schema)) => + sql(s"CREATE SCHEMA IF NOT EXISTS `$catalog`.`$schema`") + case _ => + schemaInPipelineSpec.foreach(s => sql(s"CREATE SCHEMA IF NOT EXISTS `$s`")) + } + } + + override def afterEach(): Unit = { + cleanupMetastore(spark) + super.afterEach() + } + + override def afterAll(): Unit = { + spark.stop() + } + + protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])( + testFun: A => Unit): Unit = { + namedGridTest(testNamePrefix, testTags: _*)(params.map(a => a.toString -> a).toMap)(testFun) + } + + override def test(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */ )( + implicit pos: source.Position): Unit = super.test(testName, testTags: _*) { + runWithInstrumentation(testFun) + } + + /** + * Adds custom instrumentation for tests. + * + * This instrumentation runs after `beforeEach` and + * before `afterEach` which lets us instrument the state of a test and its environment + * after any setup and before any clean-up done for a test. + */ + private def runWithInstrumentation(testFunc: => Any): Any = { + try { + testFunc + } catch { + case e: TestFailedDueToTimeoutException => +// val stackTraces = StackTraceReporter.dumpAllStackTracesToString() +// logInfo( +// s""" +// |Triggering thread dump since test failed with a timeout exception: +// |$stackTraces +// |""".stripMargin +// ) + throw e + } + } + + /** + * Creates individual tests for all items in [[params]]. + * + * The full test name will be " ( = )" where is one + * item in [[params]]. + * + * @param testNamePrefix The test name prefix. + * @param paramName A descriptive name for the parameter. + * @param testTags Extra tags for the test. + * @param params The list of parameters for which to generate tests. + * @param testFun The actual test function. This function will be called with one argument of + * type [[A]]. + * @tparam A The type of the params. + */ + protected def gridTest[A](testNamePrefix: String, paramName: String, testTags: Tag*)( + params: Seq[A])(testFun: A => Unit): Unit = + namedGridTest(testNamePrefix, testTags: _*)( + params.map(a => s"$paramName = $a" -> a).toMap + )(testFun) + + /** + * Specialized version of gridTest where the params are two boolean values - [[true]] and + * [[false]]. 
+ */ + protected def booleanGridTest(testNamePrefix: String, paramName: String, testTags: Tag*)( + testFun: Boolean => Unit): Unit = { + gridTest(testNamePrefix, paramName, testTags: _*)(Seq(true, false))(testFun) + } + + protected def namedGridTest[A](testNamePrefix: String, testTags: Tag*)(params: Map[String, A])( + testFun: A => Unit): Unit = { + for (param <- params) { + test(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2)) + } + } + + protected def namedGridIgnore[A](testNamePrefix: String, testTags: Tag*)(params: Map[String, A])( + testFun: A => Unit): Unit = { + for (param <- params) { + ignore(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2)) + } + } + + /** + * Returns a [[Seq]] of JARs generated by compiling this test. + * + * Includes a "delta-pipelines-repo" to ensure the export_test, which compiles differently, + * still succeeds. + */ + protected def getTestJars: Seq[String] = + getUniqueAbsoluteTestJarPaths.map(_.getName) :+ "delta-pipelines-repo" + + /** + * Returns a [[Seq]] of absolute paths of all JAR files found in the + * current directory. See [[getUniqueAbsoluteTestJarPaths]]. + */ + protected def getTestJarPaths: Seq[String] = + getUniqueAbsoluteTestJarPaths.map(_.getAbsolutePath) + + /** + * Returns a sequence of JARs found in the current directory. In a bazel test, + * the current directory includes all jars that are required to run the test + * (its run files). This allows us to include these jars in the class path + * for the graph loading class loader. + * + * Because dependent jars can be included multiple times in this list, we deduplicate + * by file name (ignoring the path). + */ + private def getUniqueAbsoluteTestJarPaths: Seq[File] = + Files + .walk(Paths.get(".")) + .iterator() + .asScala + .map(_.toFile) + .filter( + f => + f.isFile && + // This filters JARs to match 2 main cases: + // - JARs built by Bazel that are usually suffixed with deploy.jar; + // - classpath.jar that Scala test template can also create if the classpath is too long. + f.getName.matches("classpath.jar|.*deploy.jar") + ) + .toSeq + .groupBy(_.getName) + .flatMap(_._2.headOption) + .toSeq + +// /** +// * Returns a [[DataFrame]] given the path to json encoded data stored in the project's +// * test resources. Schema is parsed from first line. +// */ +// protected def jsonData(path: String): DataFrame = { +// val contents = loadResource(path) +// val data = contents.tail +// val schema = contents.head +// jsonData(schema, data) +// } +// +// /** Returns a [[DataFrame]] given the string representation of it schema and data. */ +// protected def jsonData(schemaString: String, data: Seq[String]): DataFrame = { +// val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +// spark.read.schema(schema).json(data.toDS()) +// } + + /** Loads a package resources as a Seq of lines. 
*/ + protected def loadResource(path: String): Seq[String] = { + val stream = Thread.currentThread.getContextClassLoader.getResourceAsStream(path) + if (stream == null) { + throw new FileNotFoundException(path) + } + val reader = new BufferedReader(new InputStreamReader(stream)) + val data = new ArrayBuffer[String] + var line = reader.readLine() + while (line != null) { + data.append(line) + line = reader.readLine() + } + data.toSeq + } + + private def checkAnswerAndPlan( + df: => DataFrame, + expectedAnswer: Seq[Row], + checkPlan: Option[SparkPlan => Unit]): Unit = { + QueryTest.checkAnswer(df, expectedAnswer) + + // To help with test development, you can dump the plan to the log by passing + // `--test_env=DUMP_PLAN=true` to `bazel test`. + if (Option(System.getenv("DUMP_PLAN")).exists(s => java.lang.Boolean.valueOf(s))) { + log.info(s"Spark plan:\n${df.queryExecution.executedPlan}") + } + checkPlan.foreach(_.apply(df.queryExecution.executedPlan)) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + checkAnswerAndPlan(df, expectedAnswer, None) + } + + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + case class ValidationArgs( + ignoreFieldOrder: Boolean = false, + ignoreFieldCase: Boolean = false + ) + +// /** +// * Runs the query and makes sure the answer matches the expected result. +// * +// * @param validateSchema Whether or not the exact schema fields are validated. This validates +// * the schema types and field names, but does not validate field +// * nullability. +// */ +// protected def checkAnswer( +// df: => DataFrame, +// expectedAnswer: DataFrame, +// validateSchema: Boolean = false, +// validationArgs: ValidationArgs = ValidationArgs(), +// checkPlan: Option[SparkPlan => Unit] = None +// ): Unit = { +// // Evaluate `df` so we get a constant DF. +// val dfByVal = df +// val actualSchema = dfByVal.schema +// val expectedSchema = expectedAnswer.schema +// +// def transformSchema(original: StructType): StructType = { +// var result = original +// if (validationArgs.ignoreFieldOrder) { +// result = StructType(result.fields.sortBy(_.name)) +// } +// if (validationArgs.ignoreFieldCase) { +// result = StructType(result.fields.map { field => +// field.copy(name = field.name.toLowerCase(Locale.ROOT)) +// }) +// } +// result +// } +// +// def transformDataFrame(original: DataFrame): DataFrame = { +// var result = original +// if (validationArgs.ignoreFieldOrder) { +// result = result.select( +// result.columns.sorted.map { columnName => +// result.col(UnresolvedAttribute.quoted(columnName).name) +// }: _* +// ) +// } +// result +// } +// +// if (validateSchema) { +// assert( +// transformSchema(actualSchema.asNullable) == transformSchema(expectedSchema.asNullable), +// s"Expected and actual schemas are different:\n" + +// s"Expected: $expectedSchema\n" + +// s"Actual: $actualSchema" +// ) +// } +// checkAnswerAndPlan( +// transformDataFrame(dfByVal), +// transformDataFrame(expectedAnswer).collect().toIndexedSeq, +// checkPlan +// ) +// } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer. 
+ */ + protected def checkDataset[T](ds: => Dataset[T], expectedAnswer: T*): Unit = { + val result = getResult(ds) + + if (!QueryTest.compare(result.toSeq, expectedAnswer)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer, after sort. + */ + protected def checkDatasetUnorderly[T: Ordering](result: Array[T], expectedAnswer: T*): Unit = { + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + protected def checkDatasetUnorderly[T: Ordering](ds: => Dataset[T], expectedAnswer: T*): Unit = { + val result = getResult(ds) + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + private def getResult[T](ds: => Dataset[T]): Array[T] = { + ds + + try ds.collect() + catch { + case NonFatal(e) => + fail( + s""" + |Exception collecting dataset as objects + |${ds.queryExecution} + """.stripMargin, + e + ) + } + } + + /** Holds a parsed version along with the original json of a test. */ + private case class TestSequence(json: Seq[String], rows: Seq[Row]) { + require(json.size == rows.size) + } + + /** + * Helper method to verify unresolved column error message. We expect three elements to be present + * in the message: error class, unresolved column name, list of suggested columns. There are three + * significant differences between different versions of DBR: + * - Error class changed in DBR 11.3 from `MISSING_COLUMN` to `UNRESOLVED_COLUMN.WITH_SUGGESTION` + * - Name parts in suggested columns are escaped with backticks starting from DBR 11.3, + * e.g. table.column => `table`.`column` + * - Starting from DBR 13.1 suggested columns qualification matches unresolved column, i.e. if + * unresolved column is a single-part identifier then suggested column will be as well. E.g. + * for unresolved column `x` suggested columns will omit catalog/schema or `LIVE` qualifier. For + * this reason we verify only last part of suggested column name. + */ + protected def verifyUnresolveColumnError( + errorMessage: String, + unresolved: String, + suggested: Seq[String]): Unit = { + assert(errorMessage.contains(unresolved)) + assert( + errorMessage.contains("[UNRESOLVED_COLUMN.WITH_SUGGESTION]") || + errorMessage.contains("[MISSING_COLUMN]") + ) + suggested.foreach { x => + if (errorMessage.contains("[UNRESOLVED_COLUMN.WITH_SUGGESTION]")) { + assert(errorMessage.contains(s"`$x`")) + } else { + assert(errorMessage.contains(x)) + } + } + } + + /** Evaluates the given column and returns the result. */ + def eval(column: Column): Any = { + spark.range(1).select(column).collect().head.get(0) + } + + /** Evaluates a column as part of a query and returns the result */ + def eval[T](col: TypedColumn[Any, T]): T = { + spark.range(1).select(col).head() + } + +// /** +// * Helper class to create a SQLPipeline that is eligible to resolve flows parallely. 
+// */ +// class SQLPipelineWithParallelResolve( +// queries: Seq[String], +// notebookPath: Option[String] = None, +// catalog: Option[String] = catalogInPipelineSpec, +// schema: Option[String] = schemaInPipelineSpec +// ) extends SQLPipeline(queries, notebookPath, catalog, schema) { +// override def eligibleForResolvingFlowsParallely = true +// } + +// /** +// * Helper method to create a [[SQLPipelineWithParallelResolve]] with catalog and schema set to +// * the test's catalog and schema. +// */ +// protected def createSqlParallelPipeline( +// queries: Seq[String], +// notebookPath: Option[String] = None, +// catalog: Option[String] = catalogInPipelineSpec, +// schema: Option[String] = schemaInPipelineSpec +// ): SQLPipelineWithParallelResolve = { +// new SQLPipelineWithParallelResolve( +// queries = queries, +// notebookPath = notebookPath, +// catalog = catalog, +// schema = schema +// ) +// } +} + +/** + * A trait that provides a way to specify the target catalog and schema for a test. + */ +trait TargetCatalogAndSchemaMixin { + + protected def catalogInPipelineSpec: Option[String] = Option( + TestGraphRegistrationContext.DEFAULT_CATALOG + ) + + protected def schemaInPipelineSpec: Option[String] = Option( + TestGraphRegistrationContext.DEFAULT_DATABASE + ) +} + +object PipelineTest extends Logging { + /** System schemas per-catalog that's can't be directly deleted. */ + protected val systemSchemas: Set[String] = Set("default", "information_schema") + + /** System catalogs that are read-only and cannot be modified/dropped. */ + private val systemCatalogs: Set[String] = Set("samples") + + /** Catalogs that cannot be dropped but schemas or tables under it can be cleaned up. */ + private val undroppableCatalogs: Set[String] = Set( + "hive_metastore", + "spark_catalog", + "system", + "main" + ) + + /** Creates a temporary directory. */ + protected def createTempDir(): String = { + Files.createTempDirectory(getClass.getSimpleName).normalize.toString + } + + /** + * Try to drop the schema in the catalog and return whether it is successfully dropped. + */ + private def dropSchemaIfPossible( + spark: SparkSession, + catalogName: String, + schemaName: String): Boolean = { + try { + spark.sql(s"DROP SCHEMA IF EXISTS `$catalogName`.`$schemaName` CASCADE") + true + } catch { + case NonFatal(e) => + logInfo( + s"Failed to drop schema $schemaName in catalog $catalogName, ex:${e.getMessage}" + ) + false + } + } + + /** Cleanup resources created in the metastore by tests. 
*/ + def cleanupMetastore(spark: SparkSession): Unit = synchronized { + // some tests stop the spark session and managed the cleanup by themself, so no need to + // cleanup if no active spark session found + if (spark.sparkContext.isStopped) { + return + } + val catalogs = + spark.sql(s"SHOW CATALOGS").collect().map(_.getString(0)).filterNot(systemCatalogs.contains) + catalogs.foreach { catalog => + if (undroppableCatalogs.contains(catalog)) { + val schemas = + spark.sql(s"SHOW SCHEMAS IN `$catalog`").collect().map(_.getString(0)) + schemas.foreach { schema => + if (systemSchemas.contains(schema) || !dropSchemaIfPossible(spark, catalog, schema)) { + spark + .sql(s"SHOW tables in `$catalog`.`$schema`") + .collect() + .map(_.getString(0)) + .foreach { table => + Try(spark.sql(s"DROP table IF EXISTS `$catalog`.`$schema`.`$table`")) match { + case Failure(e) => + logInfo( + s"Failed to drop table $table in schema $schema in catalog $catalog, " + + s"ex:${e.getMessage}" + ) + case _ => + } + } + } + } + } else { + Try(spark.sql(s"DROP CATALOG IF EXISTS `$catalog` CASCADE")) match { + case Failure(e) => + logInfo(s"Failed to drop catalog $catalog, ex:${e.getMessage}") + case _ => + } + } + } + spark.sessionState.catalog.reset() + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala new file mode 100644 index 0000000000000..d891d6529de96 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.utils + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.AnalysisException + +/** + * Collection of helper methods to simplify working with exceptions in tests. + */ +trait SparkErrorTestMixin { + this: SparkFunSuite => + + /** + * Asserts that the given exception is a [[AnalysisException]] with the specific error class. + */ + def assertAnalysisException(ex: Throwable, errorClass: String): Unit = { + ex match { + case t: AnalysisException => + assert( + t.getCondition == errorClass, + s"Expected analysis exception with error class $errorClass, but got ${t.getCondition}" + ) + case _ => fail(s"Expected analysis exception but got ${ex.getClass}") + } + } + + /** + * Asserts that the given exception is a [[AnalysisException]]h the specific error class + * and metadata. 
+ */ + def assertAnalysisException( + ex: Throwable, + errorClass: String, + metadata: Map[String, String] + ): Unit = { + ex match { + case t: AnalysisException => + assert( + t.getCondition == errorClass, + s"Expected analysis exception with error class $errorClass, but got ${t.getCondition}" + ) + assert( + t.getMessageParameters.asScala.toMap == metadata, + s"Expected analysis exception with metadata $metadata, but got " + + s"${t.getMessageParameters}" + ) + case _ => fail(s"Expected analysis exception but got ${ex.getClass}") + } + } + + /** + * Asserts that the given exception is a [[SparkException]] with the specific error class. + */ + def assertSparkException(ex: Throwable, errorClass: String): Unit = { + ex match { + case t: SparkException => + assert( + t.getCondition == errorClass, + s"Expected spark exception with error class $errorClass, but got ${t.getCondition}" + ) + case _ => fail(s"Expected spark exception but got ${ex.getClass}") + } + } + + /** + * Asserts that the given exception is a [[SparkException]] with the specific error class + * and metadata. + */ + def assertSparkException( + ex: Throwable, + errorClass: String, + metadata: Map[String, String] + ): Unit = { + ex match { + case t: SparkException => + assert( + t.getCondition == errorClass, + s"Expected spark exception with error class $errorClass, but got ${t.getCondition}" + ) + assert( + t.getMessageParameters.asScala.toMap == metadata, + s"Expected spark exception with metadata $metadata, but got ${t.getMessageParameters}" + ) + case _ => fail(s"Expected spark exception but got ${ex.getClass}") + } + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala new file mode 100644 index 0000000000000..c74352722b44d --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.utils + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{LocalTempView, UnresolvedRelation, ViewType} +import org.apache.spark.sql.classic.{DataFrame, SparkSession} +import org.apache.spark.sql.pipelines.graph.{ + DataflowGraph, + FlowAnalysis, + FlowFunction, + GraphIdentifierManager, + GraphRegistrationContext, + PersistedView, + QueryOrigin, + Table, + TemporaryView, + UnresolvedFlow +} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A test class to simplify the creation of pipelines and datasets for unit testing. 
+ */ +class TestGraphRegistrationContext( + val spark: SparkSession, + val sqlConf: Map[String, String] = Map.empty) + extends GraphRegistrationContext( + defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG, + defaultDatabase = TestGraphRegistrationContext.DEFAULT_DATABASE, + defaultSqlConf = sqlConf + ) { + + // scalastyle:off + // Disable scalastyle to ignore argument count. + def registerTable( + name: String, + query: Option[FlowFunction] = None, + sqlConf: Map[String, String] = Map.empty, + comment: Option[String] = None, + sqlText: Option[String] = None, + specifiedSchema: Option[StructType] = None, + partitionCols: Option[Seq[String]] = None, + properties: Map[String, String] = Map.empty, + baseOrigin: QueryOrigin = QueryOrigin.empty, + format: Option[String] = None, + catalog: Option[String] = None, + database: Option[String] = None + ): Unit = { + // scalastyle:on + val tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark) + registerTable( + Table( + identifier = GraphIdentifierManager.parseTableIdentifier(name, spark), + comment = comment, + sqlText = sqlText, + specifiedSchema = specifiedSchema, + partitionCols = partitionCols, + properties = properties, + baseOrigin = baseOrigin, + format = format.orElse(Some("parquet")), + normalizedPath = None, + isStreamingTableOpt = None + ) + ) + + if (query.isDefined) { + registerFlow( + new UnresolvedFlow( + identifier = tableIdentifier, + destinationIdentifier = tableIdentifier, + func = query.get, + sqlConf = sqlConf, + once = false, + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)), + comment = comment, + origin = baseOrigin + ) + ) + } + } + + def registerView( + name: String, + query: FlowFunction, + sqlConf: Map[String, String] = Map.empty, + comment: Option[String] = None, + sqlText: Option[String] = None, + origin: QueryOrigin = QueryOrigin.empty, + viewType: ViewType = LocalTempView, + catalog: Option[String] = None, + database: Option[String] = None + ): Unit = { + + val viewIdentifier = GraphIdentifierManager + .parseAndValidateTemporaryViewIdentifier(rawViewIdentifier = TableIdentifier(name)) + + registerView( + viewType match { + case LocalTempView => + TemporaryView( + identifier = viewIdentifier, + comment = comment, + sqlText = sqlText, + origin = origin, + properties = Map.empty + ) + case _ => + PersistedView( + identifier = viewIdentifier, + comment = comment, + sqlText = sqlText, + origin = origin, + properties = Map.empty + ) + } + ) + + registerFlow( + new UnresolvedFlow( + identifier = viewIdentifier, + destinationIdentifier = viewIdentifier, + func = query, + sqlConf = sqlConf, + once = false, + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)), + comment = comment, + origin = origin + ) + ) + } + + def registerFlow( + destinationName: String, + name: String, + query: FlowFunction, + once: Boolean = false, + catalog: Option[String] = None, + database: Option[String] = None + ): Unit = { + val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark) + val flowDestinationIdentifier = + GraphIdentifierManager.parseTableIdentifier(destinationName, spark) + + registerFlow( + new UnresolvedFlow( + identifier = flowIdentifier, + destinationIdentifier = flowDestinationIdentifier, + func = query, + sqlConf = Map.empty, + once = once, + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)), 
+ comment = None, + origin = QueryOrigin() + ) + ) + } + + /** + * Creates a flow function from a logical plan that reads from a table with the given name. + */ + def readFlowFunc(name: String): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan(UnresolvedRelation(TableIdentifier(name))) + } + + /** + * Creates a flow function from a logical plan that reads a stream from a table with the given + * name. + */ + def readStreamFlowFunc(name: String): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan( + UnresolvedRelation( + TableIdentifier(name), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ) + ) + } + + /** + * Creates a flow function from a logical plan parsed from the given SQL text. + */ + def sqlFlowFunc(spark: SparkSession, sql: String): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan(spark.sessionState.sqlParser.parsePlan(sql)) + } + + /** + * Creates a flow function from a logical plan from the given DataFrame. This is meant for + * DataFrames that don't read from tables within the pipeline. + */ + def dfFlowFunc(df: DataFrame): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan(df.logicalPlan) + } + + /** + * Generates a dataflow graph from this pipeline definition and resolves it. + * @return + */ + def resolveToDataflowGraph(): DataflowGraph = toDataflowGraph.resolve() +} + +object TestGraphRegistrationContext { + // TODO: Add support for custom catalogs in tests. + val DEFAULT_CATALOG = "spark_catalog" + val DEFAULT_DATABASE = "test_db" +} From 573c0bf719b026231239ec13ce25ec3896dbbfa8 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 23 May 2025 13:13:23 -0700 Subject: [PATCH 02/32] 2 --- .../spark/sql/pipelines/graph/GraphRegistrationContext.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala index 646ed68465c95..2d63c3a94d82b 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala @@ -49,7 +49,6 @@ class GraphRegistrationContext( def toDataflowGraph: DataflowGraph = { val qualifiedTables = tables.toSeq.map { t => - // TODO: Track session-level catalog changes. t.copy( identifier = GraphIdentifierManager .parseAndQualifyTableIdentifier( @@ -93,7 +92,6 @@ class GraphRegistrationContext( if (isImplicitFlow && flowWritesToView) { f } else { - // TODO: Track session-level catalog changes. 
f.copy( identifier = GraphIdentifierManager .parseAndQualifyFlowIdentifier( From 1aa2253972863ea47f5be0663f151d0e618de01f Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 23 May 2025 13:36:28 -0700 Subject: [PATCH 03/32] 3 --- project/SparkBuild.scala | 90 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index cdf87dcc142e4..eb41554c00974 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -418,7 +418,7 @@ object SparkBuild extends PomBuild { enable(HiveThriftServer.settings)(hiveThriftServer) - enable(SparkConnect.settings)(pipelines) + enable(SparkDeclarativePipelines.settings)(pipelines) enable(SparkConnectCommon.settings)(connectCommon) enable(SparkConnect.settings)(connect) @@ -886,6 +886,94 @@ object SparkConnectClient { ) } +object SparkDeclarativePipelines { + import BuildCommons.protoVersion + + lazy val settings = Seq( + // For some reason the resolution from the imported Maven build does not work for some + // of these dependendencies that we need to shade later on. + libraryDependencies ++= { + val guavaVersion = + SbtPomKeys.effectivePom.value.getProperties.get( + "connect.guava.version").asInstanceOf[String] + val guavaFailureaccessVersion = + SbtPomKeys.effectivePom.value.getProperties.get( + "guava.failureaccess.version").asInstanceOf[String] + Seq( + "com.google.guava" % "guava" % guavaVersion, + "com.google.guava" % "failureaccess" % guavaFailureaccessVersion, + "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" + ) + }, + + dependencyOverrides ++= { + val guavaVersion = + SbtPomKeys.effectivePom.value.getProperties.get( + "connect.guava.version").asInstanceOf[String] + val guavaFailureaccessVersion = + SbtPomKeys.effectivePom.value.getProperties.get( + "guava.failureaccess.version").asInstanceOf[String] + Seq( + "com.google.guava" % "guava" % guavaVersion, + "com.google.guava" % "failureaccess" % guavaFailureaccessVersion, + "com.google.protobuf" % "protobuf-java" % protoVersion + ) + }, + + (assembly / test) := { }, + + (assembly / logLevel) := Level.Info, + + // Exclude `scala-library` from assembly. + (assembly / assemblyPackageScala / assembleArtifact) := false, + + // SPARK-46733: Include `spark-connect-*.jar`, `unused-*.jar`,`guava-*.jar`, + // `failureaccess-*.jar`, `annotations-*.jar`, `grpc-*.jar`, `protobuf-*.jar`, + // `gson-*.jar`, `error_prone_annotations-*.jar`, `j2objc-annotations-*.jar`, + // `animal-sniffer-annotations-*.jar`, `perfmark-api-*.jar`, + // `proto-google-common-protos-*.jar` in assembly. + // This needs to be consistent with the content of `maven-shade-plugin`. 
+ (assembly / assemblyExcludedJars) := { + val cp = (assembly / fullClasspath).value + val validPrefixes = Set("spark-connect", "unused-", "guava-", "failureaccess-", + "annotations-", "grpc-", "protobuf-", "gson", "error_prone_annotations", + "j2objc-annotations", "animal-sniffer-annotations", "perfmark-api", + "proto-google-common-protos") + cp filterNot { v => + validPrefixes.exists(v.data.getName.startsWith) + } + }, + + (assembly / assemblyShadeRules) := Seq( + ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@0").inAll, + ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll, + ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll, + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll, + ShadeRule.rename("android.annotation.**" -> "org.sparkproject.connect.android_annotation.@1").inAll, + ShadeRule.rename("io.perfmark.**" -> "org.sparkproject.connect.io_perfmark.@1").inAll, + ShadeRule.rename("org.codehaus.mojo.animal_sniffer.**" -> "org.sparkproject.connect.animal_sniffer.@1").inAll, + ShadeRule.rename("com.google.j2objc.annotations.**" -> "org.sparkproject.connect.j2objc_annotations.@1").inAll, + ShadeRule.rename("com.google.errorprone.annotations.**" -> "org.sparkproject.connect.errorprone_annotations.@1").inAll, + ShadeRule.rename("org.checkerframework.**" -> "org.sparkproject.connect.checkerframework.@1").inAll, + ShadeRule.rename("com.google.gson.**" -> "org.sparkproject.connect.gson.@1").inAll, + ShadeRule.rename("com.google.api.**" -> "org.sparkproject.connect.google_protos.api.@1").inAll, + ShadeRule.rename("com.google.cloud.**" -> "org.sparkproject.connect.google_protos.cloud.@1").inAll, + ShadeRule.rename("com.google.geo.**" -> "org.sparkproject.connect.google_protos.geo.@1").inAll, + ShadeRule.rename("com.google.logging.**" -> "org.sparkproject.connect.google_protos.logging.@1").inAll, + ShadeRule.rename("com.google.longrunning.**" -> "org.sparkproject.connect.google_protos.longrunning.@1").inAll, + ShadeRule.rename("com.google.rpc.**" -> "org.sparkproject.connect.google_protos.rpc.@1").inAll, + ShadeRule.rename("com.google.type.**" -> "org.sparkproject.connect.google_protos.type.@1").inAll + ), + + (assembly / assemblyMergeStrategy) := { + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard + // Drop all proto files that are not needed as artifacts of the build. 
+ case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard + case _ => MergeStrategy.first + } + ) +} + object SparkProtobuf { import BuildCommons.protoVersion From 249fc3ab944a7b71a53db2f0366b10d4bbea0b10 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 23 May 2025 13:46:58 -0700 Subject: [PATCH 04/32] 4 --- .../graph/ConnectInvalidPipelineSuite.scala | 32 ----- .../graph/ConnectValidPipelineSuite.scala | 134 ------------------ 2 files changed, 166 deletions(-) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index 091b77424b3b4..a384143fed0a2 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -186,7 +186,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest { val ex = intercept[UnresolvedPipelineException] { dfg.validate() } - // This error message has changed between DBR versions. assert( ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) .getMessage @@ -432,35 +431,4 @@ class ConnectInvalidPipelineSuite extends PipelineTest { assert(!ex1.getMessage.contains(streamingTableHint)) assert(ex2.getMessage.contains(streamingTableHint)) } - -// test("Sink has wrong type of data coming in") { -// class P extends TestPipelineDefinition { -// val mem = MemoryStream[Int] -// mem.addData(1, 2, 3) -// createView("a", query =(mem.toDF()) -// createView("b") , query =(readStream("a")) -// .sink( -// func = df => df.writeStream, -// userValidation = df => -// if (df.schema != new StructType().add("value", StringType, false)) { -// throw new IllegalStateException() -// } -// ) -// } -// val p = DataflowGraph(new P).connect -// intercept[IllegalStateException] { p.validate() } -// } - -// test("Flow to sink cannot be resolved") { -// class P extends TestPipelineDefinition { -// createView("a", query =(() => sys.error("not available")) -// createView("b", query =(readStream("a")).sink(df => df.writeStream) -// } -// val p = DataflowGraph(new P()).connect -// assert(!p.resolved) -// val ex = intercept[UnresolvedPipelineException] { p.validate() } -// assert(ex.directFailures.keySet == Set("a")) -// assert(ex.downstreamFailures.keySet == Set("b")) -// assert(ex.directFailures("a").getMessage.contains("not available")) -// } } diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala index 5f53ab84b8b7a..7009786fc1b18 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -79,9 +79,6 @@ class ConnectValidPipelineSuite extends PipelineTest { test("Multi-hop schema merging") { class P extends TestGraphRegistrationContext(spark) { - -// import org.apache.spark.sql.functions._ fixme - registerView( "b", query = sqlFlowFunc(spark, """SELECT * FROM VALUES ((1)) OUTER JOIN d ON false""") @@ -230,46 +227,6 @@ class ConnectValidPipelineSuite extends PipelineTest { verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schema) } -// test("connect from actual table") { -// val rootDir = 
Files.createTempDirectory("evaluation").toString -// -// def tablePath(tableName: String): String = s"$rootDir/$tableName" -// -// Seq(3, 4).toDF().write.format("delta").save(tablePath("b")) -// val p = legacyNormalizeTablePaths( -// DataflowGraph(new TestPipelineDefinition { -// createView("a", query = () => Seq(3, 4).toDF("incorrect-col-name")) -// createTable("b", query = read("a"), path = Option(tablePath("b"))) -// createView("c", query = read("b")) -// createView("d", query = () => spark.readStream.table("b")) -// }), -// dbfsRoot = rootDir -// ).connect.validate() -// assert(p.flowByName("c").df.schema == new StructType().add("value", IntegerType)) -// assert(p.flowByName("d").df.schema == new StructType().add("value", IntegerType)) -// } -// -// test("connect from source") { -// val rootDir = Files.createTempDirectory("evaluation").toString -// val inputPath = rootDir + "/input" -// val g = legacyNormalizeTablePaths( -// DataflowGraph(new TestPipelineDefinition { -// createView("a", query = () => spark.spark.table.format("delta").load(inputPath)) -// createTable("b", query = read("a")).expectOrFail("blah", $"value" > 0) -// }), -// rootDir -// ) -// Seq(1).toDF().write.format("delta").save(inputPath) -// Seq(3).toDF().write.format("delta").save(g.tableByName("b").path) -// val graph1 = g.virtualize.connect.validate() -// assert(graph1.resolved) -// assert(graph1.flowByName("b").df.schema == new StructType().add("value", IntegerType)) -// -// val graph2 = g.connect.validate() -// assert(graph2.resolved) -// assert(graph2.flowByName("b").df.schema == new StructType().add("value", IntegerType)) -// } - test("Connect retains and fuses confs") { // a -> b \ // d @@ -476,58 +433,6 @@ class ConnectValidPipelineSuite extends PipelineTest { ) } -// test("Dataset names are case sensitive and can be spark.table correctly") { -// class P extends TestPipelineDefinition { -// createView("A", query = dfFlowFunc(spark.range(5)) -// createView("a", query = readFlowFunc("A").filter("id = 0").toDF("value")) -// createView("b", query = readFlowFunc("a")) -// } -// -// val p = new P().resolveToDataflowGraph() -// verifyFlowSchema(p, "b", new StructType().add("value", LongType, nullable = false)) -// } - -// test("Dataset names are case sensitive and will not spark.table in a case insensitive way") { -// class P extends TestPipelineDefinition { -// createView("A", query = () => spark.range(5).toDF()) -// createView("b", query = sql(spark, "SELECT x as y FROM a"))) -// } -// -// val dfg = DataflowGraph(new P).connect -// assert(!dfg.resolved, "Pipeline should not have resolved properly") -// val ex = intercept[UnresolvedPipelineException] { -// dfg.validate() -// } -// assert(ex.getMessage.contains("Failed to resolve flows in the pipeline")) -// assertSparkException( -// ex.directFailures("b"), -// "DATASET_NOT_DEFINED", -// Map("datasetName" -> "a") -// ) -// } - -// test("Verify duplicate table names due to lower-casing is caught") { -// class P extends TestPipelineDefinition { -// createTable("A", query = Option(() => spark.range(5).toDF())) -// createTable("a", query = read("A").filter("id = 0").toDF("value")) -// } -// val ex = intercept[AnalysisException] { -// (new P).resolveToDataflowGraph(shouldLowerCaseNames = true) -// } -// assert(ex.getMessage.contains("Found duplicate table")) -// } -// -// test("Verify duplicate view names due to lower-casing is caught") { -// class P extends TestPipelineDefinition { -// createTable("A", query = () => spark.range(5)) -// createView("a", query = 
read("A").filter("id = 0").toDF("value")) -// } -// val ex = intercept[AnalysisException] { -// (new P).resolveToDataflowGraph(shouldLowerCaseNames = true) -// } -// assert(ex.getMessage.contains("Found duplicate dataset")) -// } - /** Verifies the [[DataflowGraph]] has the specified [[Flow]] with the specified schema. */ private def verifyFlowSchema( pipeline: DataflowGraph, @@ -547,43 +452,4 @@ class ConnectValidPipelineSuite extends PipelineTest { s"Flow ${identifier.unquotedString} has the wrong schema" ) } - -// test("external sink") { -// SparkSessionUtils.withSQLConf( -// spark, -// ("pipelines.externalSink.enabled", true.toString) -// ) { -// class P extends TestPipelineDefinition { -// val mem = MemoryStream[Int] -// mem.addData(1, 2) -// createView("a", query = mem.toDF().select($"value" as "x")) -// createExternalSink("sink_a", format = "memory") -// createSinkForTopLevelFlow("sink_a", query = "sink_flow", readStream("a")) -// } -// val g = new P().resolveToDataflowGraph() -// g.validate() -// assert(g.resolved) -// assert(g.sinkByName("sink_a").isInstanceOf[ExternalSink]) -// val externalSink = g.sinkByName("sink_a").asInstanceOf[ExternalSink] -// assert(externalSink.format == "memory") -// assert(g.flowByName("sink_flow").isInstanceOf[ExternalSinkFlow]) -// } -// } -// -// test("Writing from view to sink") { -// class P extends TestPipelineDefinition { -// val mem = MemoryStream[Int] -// mem.addData(1, 2, 3) -// createView("a", query = mem.toDF()) -// createView("b") -// .query(readStream("a")) -// .sink( -// func = df => df.writeStream, -// userValidation = -// df => assert(df.schema == new StructType().add("value", IntegerType, false)) -// ) -// } -// val p = DataflowGraph(new P).connect -// assert(p.resolved) -// } } From a2c3b39f6636b81a182e96ec72071662e35de3cb Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 23 May 2025 15:55:21 -0700 Subject: [PATCH 05/32] 5 --- project/SparkBuild.scala | 35 ---------- sql/pipelines/pom.xml | 66 ++----------------- .../spark/sql/pipelines/graph/Flow.scala | 4 ++ .../sql/pipelines/graph/FlowAnalysis.scala | 7 +- .../graph/GraphIdentifierManager.scala | 6 -- .../graph/GraphRegistrationContext.scala | 2 + .../sql/pipelines/graph/PipelinesErrors.scala | 23 ------- .../sql/pipelines/graph/QueryOrigin.scala | 13 ++-- .../utils/TestGraphRegistrationContext.scala | 1 - 9 files changed, 18 insertions(+), 139 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index eb41554c00974..1802c6f5ab282 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -906,20 +906,6 @@ object SparkDeclarativePipelines { ) }, - dependencyOverrides ++= { - val guavaVersion = - SbtPomKeys.effectivePom.value.getProperties.get( - "connect.guava.version").asInstanceOf[String] - val guavaFailureaccessVersion = - SbtPomKeys.effectivePom.value.getProperties.get( - "guava.failureaccess.version").asInstanceOf[String] - Seq( - "com.google.guava" % "guava" % guavaVersion, - "com.google.guava" % "failureaccess" % guavaFailureaccessVersion, - "com.google.protobuf" % "protobuf-java" % protoVersion - ) - }, - (assembly / test) := { }, (assembly / logLevel) := Level.Info, @@ -944,27 +930,6 @@ object SparkDeclarativePipelines { } }, - (assembly / assemblyShadeRules) := Seq( - ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@0").inAll, - ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll, - ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll, - 
ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll, - ShadeRule.rename("android.annotation.**" -> "org.sparkproject.connect.android_annotation.@1").inAll, - ShadeRule.rename("io.perfmark.**" -> "org.sparkproject.connect.io_perfmark.@1").inAll, - ShadeRule.rename("org.codehaus.mojo.animal_sniffer.**" -> "org.sparkproject.connect.animal_sniffer.@1").inAll, - ShadeRule.rename("com.google.j2objc.annotations.**" -> "org.sparkproject.connect.j2objc_annotations.@1").inAll, - ShadeRule.rename("com.google.errorprone.annotations.**" -> "org.sparkproject.connect.errorprone_annotations.@1").inAll, - ShadeRule.rename("org.checkerframework.**" -> "org.sparkproject.connect.checkerframework.@1").inAll, - ShadeRule.rename("com.google.gson.**" -> "org.sparkproject.connect.gson.@1").inAll, - ShadeRule.rename("com.google.api.**" -> "org.sparkproject.connect.google_protos.api.@1").inAll, - ShadeRule.rename("com.google.cloud.**" -> "org.sparkproject.connect.google_protos.cloud.@1").inAll, - ShadeRule.rename("com.google.geo.**" -> "org.sparkproject.connect.google_protos.geo.@1").inAll, - ShadeRule.rename("com.google.logging.**" -> "org.sparkproject.connect.google_protos.logging.@1").inAll, - ShadeRule.rename("com.google.longrunning.**" -> "org.sparkproject.connect.google_protos.longrunning.@1").inAll, - ShadeRule.rename("com.google.rpc.**" -> "org.sparkproject.connect.google_protos.rpc.@1").inAll, - ShadeRule.rename("com.google.type.**" -> "org.sparkproject.connect.google_protos.type.@1").inAll - ), - (assembly / assemblyMergeStrategy) := { case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard // Drop all proto files that are not needed as artifacts of the build. diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 7d791e57bd000..b18ee7adc8fb9 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -43,64 +43,6 @@ ${project.version} test - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - - com.google.protobuf - protobuf-java - ${protobuf.version} - - - com.google.protobuf - protobuf-java-util - compile - - - - com.google.guava - guava - ${connect.guava.version} - compile - - - com.google.guava - failureaccess - ${guava.failureaccess.version} - compile - - - - - io.grpc - grpc-stub - ${io.grpc.version} - - - io.grpc - grpc-netty - ${io.grpc.version} - - - io.grpc - grpc-protobuf - ${io.grpc.version} - - - io.grpc - grpc-services - ${io.grpc.version} - - - - org.scala-lang.modules - scala-parser-combinators_${scala.binary.version} - org.apache.spark spark-core_${scala.binary.version} @@ -108,10 +50,8 @@ org.apache.spark - spark-core_${scala.binary.version} + spark-sql_${scala.binary.version} ${project.version} - test-jar - test org.apache.spark @@ -121,8 +61,10 @@ org.apache.spark - spark-sql_${scala.binary.version} + spark-core_${scala.binary.version} ${project.version} + test-jar + test org.apache.spark diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index f0a80962f0994..cc45f8897a776 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -26,6 +26,10 @@ import org.apache.spark.sql.pipelines.AnalysisWarning import org.apache.spark.sql.pipelines.util.InputReadOptions import org.apache.spark.sql.types.StructType +/** + * A [[Flow]] 
is a node of data transformation in a dataflow graph. It describes the movement + * of data into a particular dataset. + */ trait Flow extends GraphElement with Logging { /** The [[FlowFunction]] containing the user's query. */ diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 57bb82fb77fad..941ceec25916a 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -73,21 +73,18 @@ object FlowAnalysis { /** * Constructs an analyzed [[DataFrame]] from a [[LogicalPlan]] by resolving Pipelines specific - * TVFs and datasets that cannot be resolved directly by Catalyst. This is used in SQL pipelines - * and Python pipelines (when users call spark.sql). + * TVFs and datasets that cannot be resolved directly by Catalyst. * * This function shouldn't call any singleton as it will break concurrent access to graph * analysis; or any thread local variables as graph analysis and this function will use * different threads in python repl. * * @param plan The [[LogicalPlan]] defining a flow. - * @param forAnalysisApi Whether the query is being analyzed through Analysis API algorithm. * @return An analyzed [[DataFrame]]. */ private def analyze( context: FlowAnalysisContext, - plan: LogicalPlan, - forAnalysisApi: Boolean = false + plan: LogicalPlan ): DataFrame = { // Users can define CTEs within their CREATE statements. For example, // diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala index 49d8dd35b5e2f..f2bc76d634546 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala @@ -69,8 +69,6 @@ object GraphIdentifierManager { // fully/partially qualified name (e.g., "SELECT * FROM catalog.schema.a" or "SELECT * FROM // schema.a"). val inputIdentifier = parseTableIdentifier(rawInputName, context.spark) - // TODO: once pipeline spec catalog/schema is propagated, use that catalog via DSv2 API - val catalog = context.spark.sessionState.catalog /** Return whether we're referencing a dataset that is part of the pipeline. */ def isInternalDataset(identifier: TableIdentifier): Boolean = { @@ -81,9 +79,6 @@ object GraphIdentifierManager { isInternalDataset(inputIdentifier)) { // reading a single-part-name dataset defined in the dataflow graph (e.g., a view) InternalDatasetIdentifier(identifier = inputIdentifier) - } else if (catalog.isTempView(inputIdentifier)) { - // referencing a temp view or temp table in the current spark session - ExternalDatasetIdentifier(identifier = inputIdentifier) } else if (isPathIdentifier(context.spark, inputIdentifier)) { // path-based reference, always read as external dataset ExternalDatasetIdentifier(identifier = inputIdentifier) @@ -326,7 +321,6 @@ object IdentifierHelper { /** Assert whether the identifier is properly fully qualified when creating a dataset. 
*/ def assertIsFullyQualifiedForCreate(identifier: TableIdentifier): Unit = { - // TODO: validate catalog exists assert( identifier.catalog.isDefined && identifier.database.isDefined, s"Dataset identifier $identifier is not properly fully qualified, expect a " + diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala index 2d63c3a94d82b..0e2ba42b15e59 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier /** + * A mutable context for registering tables, views, and flows in a dataflow graph. + * * @param defaultCatalog The pipeline's default catalog. * @param defaultDatabase The pipeline's default schema. */ diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala index 9a88ad70a7601..1b083a0c8fdf4 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.pipelines.graph -import scala.annotation.unused - import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -81,27 +79,6 @@ trait GraphValidationWarning extends Logging { /** The exception to throw when this validation fails. */ protected def exception: AnalysisException - - /** The details of the event to construct when this validation fails. */ - // protected def warningEventDetails: EventDetails => EventDetails - - // /** Log the exception message and construct a warning event log. */ - // def logAndConstructWarningEventLog(origin: Origin): PipelineEvent = { - // logWarning(exception.getMessage) - // ConstructPipelineEvent( - // origin = origin, - // level = EventLevel.WARN, - // message = exception.getMessage, - // details = warningEventDetails - // ) - // } - - /** Log the exception message and throw the exception. */ - @unused - def logAndThrow(): Unit = { - logError(exception.getMessage) - throw exception - } } /** diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala index d85c59287c683..3467f3d88d630 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.pipelines.graph import scala.util.control.{NonFatal, NoStackTrace} +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.pipelines.Language import org.apache.spark.sql.pipelines.logging.SourceCodeLocation @@ -47,6 +48,7 @@ case class QueryOrigin( objectType: Option[String] = None, objectName: Option[String] = None ) { + /** * Merges this origin with another one. 
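// Illustrative aside, not part of the patch: a small sketch of how a QueryOrigin is
// typically built up, based only on members visible in this series. A bare QueryOrigin()
// (used elsewhere in these diffs) is the same as QueryOrigin.empty, and callers refine an
// origin via `copy`, as elements.scala does when stamping tables. The `merge` method
// documented here then combines two such origins. The object name below is made up.
//
// val base: QueryOrigin = QueryOrigin()   // equivalent to QueryOrigin.empty
// val tableOrigin: QueryOrigin = base.copy(
//   objectType = Some("table"),
//   objectName = Some("spark_catalog.test_db.events"))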
* @@ -94,7 +96,7 @@ case class QueryOrigin( ) } -object QueryOrigin { +object QueryOrigin extends Logging { /** An empty QueryOrigin without any provenance information. */ val empty: QueryOrigin = QueryOrigin() @@ -103,9 +105,7 @@ object QueryOrigin { * An exception that wraps [[QueryOrigin]] and lets us store it in errors as suppressed * exceptions. */ - private case class QueryOriginWrapper(origin: QueryOrigin) - extends Exception - with NoStackTrace + private case class QueryOriginWrapper(origin: QueryOrigin) extends Exception with NoStackTrace implicit class ExceptionHelpers(t: Throwable) { @@ -124,8 +124,7 @@ object QueryOrigin { t.addSuppressed(QueryOriginWrapper(origin)) } } catch { - case NonFatal(e) => - // logger.error("Failed to add pipeline context", e) + case NonFatal(e) => logError("Failed to add pipeline context", e) } t } @@ -146,7 +145,7 @@ object QueryOrigin { } } catch { case NonFatal(e) => - // logger.error("Failed to get pipeline context", e) + logError("Failed to get pipeline context", e) None } } diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala index c74352722b44d..3257a39b2a95f 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -219,7 +219,6 @@ class TestGraphRegistrationContext( } object TestGraphRegistrationContext { - // TODO: Add support for custom catalogs in tests. val DEFAULT_CATALOG = "spark_catalog" val DEFAULT_DATABASE = "test_db" } From 2c22e191dbff1ab5ad405ed4b3a0d72a68a1d162 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 23 May 2025 16:11:21 -0700 Subject: [PATCH 06/32] 6 --- project/SparkBuild.scala | 2 -- .../apache/spark/sql/pipelines/graph/DataflowGraph.scala | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1802c6f5ab282..c1bb8c93b1116 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -906,8 +906,6 @@ object SparkDeclarativePipelines { ) }, - (assembly / test) := { }, - (assembly / logLevel) := Level.Info, // Exclude `scala-library` from assembly. diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index 958dab2b647a8..37b3032c78109 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -132,9 +132,9 @@ class DataflowGraph(val flows: Seq[Flow], val tables: Seq[Table], val views: Seq /** Returns a copy of this [[DataflowGraph]] with optionally replaced components. 
*/ def copy( - flows: Seq[Flow] = flows, - tables: Seq[Table] = tables, - views: Seq[View] = views): DataflowGraph = { + flows: Seq[Flow] = flows, + tables: Seq[Table] = tables, + views: Seq[View] = views): DataflowGraph = { new DataflowGraph(flows, tables, views) } From 61e10004fc3d3f30e560d7d9579a8f15fffd48c4 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Tue, 27 May 2025 22:43:28 -0700 Subject: [PATCH 07/32] comments p1 --- .../pipelines/graph/CoreDataflowNodeProcessor.scala | 11 +++++------ .../pipelines/graph/ConnectInvalidPipelineSuite.scala | 2 +- .../pipelines/graph/ConnectValidPipelineSuite.scala | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index e1400331f4148..74715d40b0278 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ } /** - * Core processor that is responsible for analyzing each flow and sort the nodes in + * Processor that is responsible for analyzing each flow and sort the nodes in * topological order */ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { @@ -58,10 +58,11 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { /** * Processes the node of the graph, re-arranging them if they are not topologically sorted. - * Takes care of resolving the flows and virtualization if needed for the nodes. + * Takes care of resolving the flows and virtualizing tables (i.e. removing tables to + * ensure resolution is internally consistent) if needed for the nodes. * @param node The node to process * @param upstreamNodes Upstream nodes for the node - * @return + * @return The resolved nodes generated by processing this element. */ def processNode(node: GraphElement, upstreamNodes: Seq[GraphElement]): Seq[GraphElement] = { node match { @@ -137,9 +138,7 @@ private class FlowResolver(rawGraph: DataflowGraph) extends Logging { val result = flowFunctionResult match { case f if f.dataFrame.isSuccess => - // Merge the flow's inputs' confs into confs for this flow, throwing if any conflict - // But do not merge confs from inputs that are tables; we don't want to propagate confs - // past materialization points. + // Merge confs from any upstream views into confs for this flow. val allFConfs = (flowToResolve +: f.inputs.toSeq diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index a384143fed0a2..552aa18d00f73 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistration import org.apache.spark.sql.types.{IntegerType, StructType} /** - * Test suite for converting one or more [[Pipeline]]s into a connected [[DataflowGraph]]. These + * Test suite for resolving the flows in a [[DataflowGraph]]. 
These * examples are all semantically correct but contain logical errors which should be found * when connect is called and thrown when validate() is called. */ diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala index 7009786fc1b18..f8b5133ff167d 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * Test suite for converting a [[PipelineDefinition]]s into a connected [[DataflowGraph]]. These + * Test suite for resolving the flows in a [[DataflowGraph]]. These * examples are all semantically correct and logically correct and connect should not result in any * errors. */ From d2effa237475128560a18146788f6b262c58c6fb Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Tue, 27 May 2025 22:52:13 -0700 Subject: [PATCH 08/32] 1 --- .../resources/error/error-conditions.json | 2 +- project/SparkBuild.scala | 4 +- .../apache/spark/sql/pipelines/Language.scala | 6 ++- .../spark/sql/pipelines/graph/Flow.scala | 50 +++++-------------- .../sql/pipelines/graph/FlowAnalysis.scala | 21 +++----- 5 files changed, 26 insertions(+), 57 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 4b38ac8954f67..35621b509489a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2027,7 +2027,7 @@ }, "INCOMPATIBLE_BATCH_VIEW_READ": { "message": [ - "View is not a streaming view and must be referenced using read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + "View is not a batch view and must be referenced using read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." 
], "sqlState": "42000" }, diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c1bb8c93b1116..e3db94eb39b3d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -896,12 +896,12 @@ object SparkDeclarativePipelines { val guavaVersion = SbtPomKeys.effectivePom.value.getProperties.get( "connect.guava.version").asInstanceOf[String] - val guavaFailureaccessVersion = + val guavaFailureAccessVersion = SbtPomKeys.effectivePom.value.getProperties.get( "guava.failureaccess.version").asInstanceOf[String] Seq( "com.google.guava" % "guava" % guavaVersion, - "com.google.guava" % "failureaccess" % guavaFailureaccessVersion, + "com.google.guava" % "failureaccess" % guavaFailureAccessVersion, "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" ) }, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala index 77ca6f84b7fbc..c627850b667be 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.pipelines sealed trait Language {} -case class Python() extends Language {} +object Language { + case class Python() extends Language {} + case class Sql() extends Language {} +} -case class Sql() extends Language {} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index cc45f8897a776..68a587c79a63f 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -90,15 +90,15 @@ trait FlowFunction extends Logging { /** * Holds the [[DataFrame]] returned by a [[FlowFunction]] along with the inputs used to * construct it. - * @param usedBatchInputs the identifiers of the complete inputs read by the flow - * @param usedStreamingInputs the identifiers of the incremental inputs read by the flow + * @param batchInputs the complete inputs read by the flow + * @param sreamingInputs the incremental inputs read by the flow * @param usedExternalInputs the identifiers of the external inputs read by the flow * @param dataFrame the [[DataFrame]] expression executed by the flow if the flow can be resolved */ case class FlowFunctionResult( requestedInputs: Set[TableIdentifier], - usedBatchInputs: Set[ResolvedInput], - usedStreamingInputs: Set[ResolvedInput], + batchInputs: Set[ResolvedInput], + streamingInputs: Set[ResolvedInput], usedExternalInputs: Set[String], dataFrame: Try[DataFrame], sqlConf: Map[String, String], @@ -113,12 +113,6 @@ case class FlowFunctionResult( (batchInputs ++ streamingInputs).map(_.input.identifier) } - /** Names of [[Input]]s read completely by this [[Flow]]. */ - def batchInputs: Set[ResolvedInput] = usedBatchInputs - - /** Names of [[Input]]s read incrementally by this [[Flow]]. */ - def streamingInputs: Set[ResolvedInput] = usedStreamingInputs - /** Returns errors that occurred when attempting to analyze this [[Flow]]. */ def failure: Seq[Throwable] = { dataFrame.failed.toOption.toSeq @@ -129,35 +123,17 @@ case class FlowFunctionResult( } /** A [[Flow]] whose output schema and dependencies aren't known. 
*/ -class UnresolvedFlow( - val identifier: TableIdentifier, - val destinationIdentifier: TableIdentifier, - val func: FlowFunction, - val currentCatalog: Option[String], - val currentDatabase: Option[String], - val sqlConf: Map[String, String], - val comment: Option[String] = None, +case class UnresolvedFlow( + identifier: TableIdentifier, + destinationIdentifier: TableIdentifier, + func: FlowFunction, + currentCatalog: Option[String], + currentDatabase: Option[String], + sqlConf: Map[String, String], + comment: Option[String] = None, override val once: Boolean, override val origin: QueryOrigin -) extends Flow { - def copy( - identifier: TableIdentifier = identifier, - destinationIdentifier: TableIdentifier = destinationIdentifier, - sqlConf: Map[String, String] = sqlConf - ): UnresolvedFlow = { - new UnresolvedFlow( - identifier = identifier, - destinationIdentifier = destinationIdentifier, - func = func, - currentCatalog = currentCatalog, - currentDatabase = currentDatabase, - sqlConf = sqlConf, - comment = comment, - once = once, - origin = origin - ) - } -} +) extends Flow /** * A [[Flow]] whose flow function has been invoked, meaning either: diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 941ceec25916a..98ad5e149b7ba 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -17,23 +17,14 @@ package org.apache.spark.sql.pipelines.graph -import scala.util.Try - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} -import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession} -import org.apache.spark.sql.pipelines.{AnalysisWarning, Sql} -import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ - ExternalDatasetIdentifier, - InternalDatasetIdentifier -} -import org.apache.spark.sql.pipelines.util.{ - BatchReadOptions, - InputReadOptions, - StreamingReadOptions -} +import org.apache.spark.sql.classic.{DataFrame, DataStreamReader, Dataset, SparkSession} +import org.apache.spark.sql.pipelines.{AnalysisWarning, Language} +import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier} +import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions} object FlowAnalysis { def createFlowFunctionFromLogicalPlan(plan: LogicalPlan): FlowFunction = { @@ -115,7 +106,7 @@ object FlowAnalysis { name = IdentifierHelper.toQuotedString(u.multipartIdentifier), spark.readStream, streamingReadOptions = StreamingReadOptions( - apiLanguage = Sql() + apiLanguage = Language.Sql() ) ).queryExecution.analyzed @@ -124,7 +115,7 @@ object FlowAnalysis { readBatchInput( context, name = IdentifierHelper.toQuotedString(u.multipartIdentifier), - batchReadOptions = BatchReadOptions(apiLanguage = Sql()) + batchReadOptions = BatchReadOptions(apiLanguage = Language.Sql()) ).queryExecution.analyzed } Dataset.ofRows(spark, resolvedPlan) From 676ee9df967b345d05d63d3b5588629047763f85 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Tue, 27 May 2025 23:38:09 -0700 Subject: [PATCH 09/32] queryContext 
addition --- .../graph/CoreDataflowNodeProcessor.scala | 6 ++--- .../spark/sql/pipelines/graph/Flow.scala | 26 +++++++++---------- .../sql/pipelines/graph/FlowAnalysis.scala | 14 +++++----- .../pipelines/graph/FlowAnalysisContext.scala | 6 ++--- .../graph/GraphIdentifierManager.scala | 4 +-- .../graph/ConnectInvalidPipelineSuite.scala | 2 +- .../utils/TestGraphRegistrationContext.scala | 19 +++++++++----- 7 files changed, 40 insertions(+), 37 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 74715d40b0278..9866ad3f61bd9 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -132,8 +132,7 @@ private class FlowResolver(rawGraph: DataflowGraph) extends Logging { allInputs = allInputs, availableInputs = availableResolvedInputs.values.toList, configuration = flowToResolve.sqlConf, - currentCatalog = flowToResolve.currentCatalog, - currentDatabase = flowToResolve.currentDatabase + queryContext = flowToResolve.queryContext ) val result = flowFunctionResult match { @@ -179,8 +178,7 @@ private class FlowResolver(rawGraph: DataflowGraph) extends Logging { allInputs = allInputs, availableInputs = availableResolvedInputs.values.toList, configuration = newSqlConf, - currentCatalog = flowToResolve.currentCatalog, - currentDatabase = flowToResolve.currentDatabase + queryContext = flowToResolve.queryContext ) } else { f diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index 68a587c79a63f..ff87a1e5cc70f 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -26,6 +26,13 @@ import org.apache.spark.sql.pipelines.AnalysisWarning import org.apache.spark.sql.pipelines.util.InputReadOptions import org.apache.spark.sql.types.StructType +/** + * Contains the catalog and database context information for query execution. + */ +case class QueryContext( + currentCatalog: Option[String], + currentDatabase: Option[String]) + /** * A [[Flow]] is a node of data transformation in a dataflow graph. It describes the movement * of data into a particular dataset. @@ -49,11 +56,8 @@ trait Flow extends GraphElement with Logging { */ def once: Boolean = false - /** The current catalog in the execution context when the query is defined. */ - def currentCatalog: Option[String] - - /** The current database in the execution context when the query is defined. */ - def currentDatabase: Option[String] + /** The current query context (catalog and database) when the query is defined. */ + def queryContext: QueryContext /** The comment associated with this flow */ def comment: Option[String] @@ -74,16 +78,14 @@ trait FlowFunction extends Logging { * [[DataflowGraph]]. * @param availableInputs the list of all [[Input]]s available to this flow * @param configuration the spark configurations that apply to this flow. - * @param currentCatalog The current catalog in execution context when the query is defined. - * @param currentDatabase The current database in execution context when the query is defined. + * @param queryContext The context of the query being evaluated. 
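// Illustrative aside, not part of the patch: a minimal sketch of driving FlowFunction.call
// now that the catalog and database travel together as a single QueryContext value. It
// relies only on signatures visible in this diff and assumes the pipelines.graph types are
// in scope; the helper name `resolveOnce` and the literal catalog/database defaults are
// assumptions made for the example.
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

def resolveOnce(
    plan: LogicalPlan,
    allInputs: Set[TableIdentifier],
    availableInputs: Seq[Input],
    confs: Map[String, String]): FlowFunctionResult = {
  // Build the flow function from the user's logical plan, then analyze it once.
  val fn: FlowFunction = FlowAnalysis.createFlowFunctionFromLogicalPlan(plan)
  fn.call(
    allInputs = allInputs,
    availableInputs = availableInputs,
    configuration = confs,
    // The two former parameters are now carried by one value object.
    queryContext = QueryContext(
      currentCatalog = Some("spark_catalog"),
      currentDatabase = Some("default")))
}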
* @return the inputs actually used, and the [[DataFrame]] expression for the flow */ def call( allInputs: Set[TableIdentifier], availableInputs: Seq[Input], configuration: Map[String, String], - currentCatalog: Option[String], - currentDatabase: Option[String] + queryContext: QueryContext ): FlowFunctionResult } @@ -127,8 +129,7 @@ case class UnresolvedFlow( identifier: TableIdentifier, destinationIdentifier: TableIdentifier, func: FlowFunction, - currentCatalog: Option[String], - currentDatabase: Option[String], + queryContext: QueryContext, sqlConf: Map[String, String], comment: Option[String] = None, override val once: Boolean, @@ -147,8 +148,7 @@ trait ResolutionCompletedFlow extends Flow { val identifier: TableIdentifier = flow.identifier val destinationIdentifier: TableIdentifier = flow.destinationIdentifier def func: FlowFunction = flow.func - def currentCatalog: Option[String] = flow.currentCatalog - def currentDatabase: Option[String] = flow.currentDatabase + def queryContext: QueryContext = flow.queryContext def comment: Option[String] = flow.comment def sqlConf: Map[String, String] = funcResult.sqlConf def origin: QueryOrigin = flow.origin diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 98ad5e149b7ba..bdc685741624f 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.pipelines.graph +import scala.util.Try + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} -import org.apache.spark.sql.classic.{DataFrame, DataStreamReader, Dataset, SparkSession} +import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession} import org.apache.spark.sql.pipelines.{AnalysisWarning, Language} import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier} import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions} @@ -33,14 +35,12 @@ object FlowAnalysis { allInputs: Set[TableIdentifier], availableInputs: Seq[Input], confs: Map[String, String], - currentCatalog: Option[String], - currentDatabase: Option[String] + queryContext: QueryContext ): FlowFunctionResult = { val ctx = FlowAnalysisContext( allInputs = allInputs, availableInputs = availableInputs, - currentCatalog = currentCatalog, - currentDatabase = currentDatabase, + queryContext = queryContext, spark = SparkSession.active ) val df = try { @@ -51,8 +51,8 @@ object FlowAnalysis { } FlowFunctionResult( requestedInputs = ctx.requestedInputs.toSet, - usedBatchInputs = ctx.batchInputs.toSet, - usedStreamingInputs = ctx.streamingInputs.toSet, + batchInputs = ctx.batchInputs.toSet, + streamingInputs = ctx.streamingInputs.toSet, usedExternalInputs = ctx.externalInputs.toSet, dataFrame = df, sqlConf = confs, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala index db6d6435701c0..68385c6cedee2 100644 --- 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.pipelines.AnalysisWarning * * @param allInputs Set of identifiers for all [[Input]]s defined in the DataflowGraph. * @param availableInputs Inputs available to be referenced with `read` or `readStream`. - * @param currentCatalog The current catalog in execution context when the query is defined. - * @param currentDatabase The current schema in execution context when the query is defined. + * @param queryContext The context of the query being evaluated. * @param requestedInputs A mutable buffer populated with names of all inputs that were * requested. * @param spark the spark session to be used. @@ -40,8 +39,7 @@ import org.apache.spark.sql.pipelines.AnalysisWarning private[pipelines] case class FlowAnalysisContext( allInputs: Set[TableIdentifier], availableInputs: Seq[Input], - currentCatalog: Option[String], - currentDatabase: Option[String], + queryContext: QueryContext, batchInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, streamingInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, requestedInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala index f2bc76d634546..2561198aa2b5a 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala @@ -85,8 +85,8 @@ object GraphIdentifierManager { } else { val fullyQualifiedInputIdentifier = fullyQualifyIdentifier( maybeFullyQualifiedIdentifier = inputIdentifier, - currentCatalog = context.currentCatalog, - currentDatabase = context.currentDatabase + currentCatalog = context.queryContext.currentCatalog, + currentDatabase = context.queryContext.currentDatabase ) assertIsFullyQualifiedForRead(identifier = fullyQualifiedInputIdentifier) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index 552aa18d00f73..16f0d07a96998 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -374,7 +374,7 @@ class ConnectInvalidPipelineSuite extends PipelineTest { .getMessage .contains( s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString}" + - s" is not a streaming view and must be referenced using read." + s" is not a batch view and must be referenced using read." 
) ) } diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala index 3257a39b2a95f..d5eafc411ade2 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.pipelines.graph.{ GraphIdentifierManager, GraphRegistrationContext, PersistedView, + QueryContext, QueryOrigin, Table, TemporaryView, @@ -86,10 +87,12 @@ class TestGraphRegistrationContext( identifier = tableIdentifier, destinationIdentifier = tableIdentifier, func = query.get, + queryContext = QueryContext( + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)) + ), sqlConf = sqlConf, once = false, - currentCatalog = catalog.orElse(Some(defaultCatalog)), - currentDatabase = database.orElse(Some(defaultDatabase)), comment = comment, origin = baseOrigin ) @@ -138,10 +141,12 @@ class TestGraphRegistrationContext( identifier = viewIdentifier, destinationIdentifier = viewIdentifier, func = query, + queryContext = QueryContext( + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)) + ), sqlConf = sqlConf, once = false, - currentCatalog = catalog.orElse(Some(defaultCatalog)), - currentDatabase = database.orElse(Some(defaultDatabase)), comment = comment, origin = origin ) @@ -165,10 +170,12 @@ class TestGraphRegistrationContext( identifier = flowIdentifier, destinationIdentifier = flowDestinationIdentifier, func = query, + queryContext = QueryContext( + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)) + ), sqlConf = Map.empty, once = once, - currentCatalog = catalog.orElse(Some(defaultCatalog)), - currentDatabase = database.orElse(Some(defaultDatabase)), comment = None, origin = QueryOrigin() ) From af7842dbafb60157381ece810487c195bb8390bc Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Wed, 28 May 2025 00:01:29 -0700 Subject: [PATCH 10/32] 1 --- .../scala/org/apache/spark/sql/pipelines/graph/Flow.scala | 2 +- .../org/apache/spark/sql/pipelines/graph/elements.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index ff87a1e5cc70f..354730ed0c068 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -93,7 +93,7 @@ trait FlowFunction extends Logging { * Holds the [[DataFrame]] returned by a [[FlowFunction]] along with the inputs used to * construct it. 
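// Illustrative aside, not part of the patch: a minimal sketch of how the result documented
// here can be inspected once a flow function has been invoked. It uses only members visible
// in this diff (`failure`, `batchInputs`, `streamingInputs`); the helper name
// `describeResult` is an assumption made for the example.
def describeResult(result: FlowFunctionResult): String = {
  if (result.failure.nonEmpty) {
    // `failure` surfaces any analysis error wrapped in the `dataFrame` Try.
    s"analysis failed: ${result.failure.map(_.getMessage).mkString("; ")}"
  } else {
    val batch = result.batchInputs.map(_.input.identifier.unquotedString)
    val streaming = result.streamingInputs.map(_.input.identifier.unquotedString)
    s"batch inputs: [${batch.mkString(", ")}]; streaming inputs: [${streaming.mkString(", ")}]"
  }
}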
* @param batchInputs the complete inputs read by the flow - * @param sreamingInputs the incremental inputs read by the flow + * @param streamingInputs the incremental inputs read by the flow * @param usedExternalInputs the identifiers of the external inputs read by the flow * @param dataFrame the [[DataFrame]] expression executed by the flow if the flow can be resolved */ diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index 348c6ecd344a6..73607d8350e21 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -126,9 +126,8 @@ case class Table( baseOrigin: QueryOrigin, isStreamingTableOpt: Option[Boolean], format: Option[String] -) extends GraphElement - with Output - with TableInput { +) extends TableInput + with Output { override val origin: QueryOrigin = baseOrigin.copy( objectType = Some("table"), From e0e4d132df786b501c2831349bd0f0baa3c177da Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Wed, 28 May 2025 13:37:36 -0700 Subject: [PATCH 11/32] error formatting --- .../resources/error/error-conditions.json | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 35621b509489a..3aaf77ee194f8 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1322,12 +1322,6 @@ ], "sqlState" : "42713" }, - "DUPLICATE_FLOW_SQL_CONF": { - "message": [ - "Found duplicate sql conf for dataset '': '' is defined by both '' and ''" - ], - "sqlState": "42710" - }, "DUPLICATED_MAP_KEY" : { "message" : [ "Duplicate map key was found, please check the input data.", @@ -1378,6 +1372,12 @@ }, "sqlState" : "42734" }, + "DUPLICATE_FLOW_SQL_CONF" : { + "message" : [ + "Found duplicate sql conf for dataset '': '' is defined by both '' and ''" + ], + "sqlState" : "42710" + }, "DUPLICATE_KEY" : { "message" : [ "Found duplicate keys ." @@ -1949,6 +1949,12 @@ ], "sqlState" : "42818" }, + "INCOMPATIBLE_BATCH_VIEW_READ" : { + "message" : [ + "View is not a batch view and must be referenced using read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + ], + "sqlState" : "42000" + }, "INCOMPATIBLE_COLUMN_TYPE" : { "message" : [ " can only be performed on tables with compatible column types. The column of the table is type which is not compatible with at the same column of the first table.." @@ -2025,17 +2031,11 @@ ], "sqlState" : "42613" }, - "INCOMPATIBLE_BATCH_VIEW_READ": { - "message": [ - "View is not a batch view and must be referenced using read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." - ], - "sqlState": "42000" - }, - "INCOMPATIBLE_STREAMING_VIEW_READ": { - "message": [ + "INCOMPATIBLE_STREAMING_VIEW_READ" : { + "message" : [ "View is a streaming view and must be referenced using readStream. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." 
], - "sqlState": "42000" + "sqlState" : "42000" }, "INCOMPATIBLE_VIEW_SCHEMA_CHANGE" : { "message" : [ @@ -6589,10 +6589,10 @@ ], "sqlState" : "P0001" }, - "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE": { - "message": [ + "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE" : { + "message" : [ "Table '' has a user-specified schema that is incompatible with the schema", - " inferred from its query.", + "inferred from its query.", "", "", "Declared schema:", @@ -6601,7 +6601,7 @@ "Inferred schema:", "" ], - "sqlState": "42000" + "sqlState" : "42000" }, "VARIABLE_ALREADY_EXISTS" : { "message" : [ From 22cbef2c5754159face5f2da36cd7502297dbf9b Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Thu, 29 May 2025 09:55:53 -0700 Subject: [PATCH 12/32] comments --- .../src/main/resources/error/error-conditions.json | 4 ++-- .../spark/sql/pipelines/graph/DataflowGraph.scala | 10 +--------- .../sql/pipelines/graph/DataflowGraphTransformer.scala | 2 +- .../org/apache/spark/sql/pipelines/graph/Flow.scala | 2 +- .../spark/sql/pipelines/graph/FlowAnalysis.scala | 4 ++-- .../sql/pipelines/graph/FlowAnalysisContext.scala | 2 +- .../apache/spark/sql/pipelines/graph/elements.scala | 5 ++++- .../pipelines/graph/ConnectInvalidPipelineSuite.scala | 4 ++-- 8 files changed, 14 insertions(+), 19 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3aaf77ee194f8..6da4b1be6a1d8 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1951,7 +1951,7 @@ }, "INCOMPATIBLE_BATCH_VIEW_READ" : { "message" : [ - "View is not a batch view and must be referenced using read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + "View is a batch view and must be referenced using SparkSession#read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." ], "sqlState" : "42000" }, @@ -2033,7 +2033,7 @@ }, "INCOMPATIBLE_STREAMING_VIEW_READ" : { "message" : [ - "View is a streaming view and must be referenced using readStream. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + "View is a streaming view and must be referenced using SparkSession#readStream. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." ], "sqlState" : "42000" }, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index 37b3032c78109..49ba752bb3195 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.StructType * It manages the relationships between logical flows, tables, and views, providing * operations for graph traversal, validation, and transformation. */ -class DataflowGraph(val flows: Seq[Flow], val tables: Seq[Table], val views: Seq[View]) +case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) extends GraphOperations with GraphValidations { @@ -130,14 +130,6 @@ class DataflowGraph(val flows: Seq[Flow], val tables: Seq[Table], val views: Seq }.toMap } - /** Returns a copy of this [[DataflowGraph]] with optionally replaced components. 
*/ - def copy( - flows: Seq[Flow] = flows, - tables: Seq[Table] = tables, - views: Seq[View] = views): DataflowGraph = { - new DataflowGraph(flows, tables, views) - } - /** * Used to reanalyze the flow's DF for a given table. This is done by finding all upstream * flows (until a table is reached) for the specified source and reanalyzing all upstream diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala index 5ae3b4e9a389c..863f995ffd5d2 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -97,7 +97,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { flows.groupBy(_.destinationIdentifier) } - def transformTables(transformer: Table => Table): DataflowGraphTransformer = { + def transformTables(transformer: Table => Table): DataflowGraphTransformer = synchronized { tables = tables.map(transformer) tableMap = computeTableMap() this diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index 354730ed0c068..aa64ae61b59eb 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -101,7 +101,7 @@ case class FlowFunctionResult( requestedInputs: Set[TableIdentifier], batchInputs: Set[ResolvedInput], streamingInputs: Set[ResolvedInput], - usedExternalInputs: Set[String], + usedExternalInputs: Set[TableIdentifier], dataFrame: Try[DataFrame], sqlConf: Map[String, String], analysisWarnings: Seq[AnalysisWarning] = Nil) { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index bdc685741624f..df4771fae3360 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -280,7 +280,7 @@ object FlowAnalysis { name: String): DataFrame = { val spark = context.spark - context.externalInputs += name + context.externalInputs += inputIdentifier.identifier spark.read.table(inputIdentifier.identifier.quotedString) } @@ -298,7 +298,7 @@ object FlowAnalysis { streamReader: DataStreamReader, name: String): DataFrame = { - context.externalInputs += name + context.externalInputs += inputIdentifier.identifier streamReader.table(inputIdentifier.identifier.quotedString) } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala index 68385c6cedee2..0e04e17f7f7b0 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -46,7 +46,7 @@ private[pipelines] case class FlowAnalysisContext( shouldLowerCaseNames: Boolean = false, analysisWarnings: mutable.Buffer[AnalysisWarning] = new ListBuffer[AnalysisWarning], spark: SparkSession, - externalInputs: mutable.HashSet[String] = mutable.HashSet.empty + externalInputs: 
mutable.HashSet[TableIdentifier] = mutable.HashSet.empty ) { /** Map from [[Input]] name to the actual [[Input]] */ diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index 73607d8350e21..2dfcaf5f7856a 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -74,7 +74,10 @@ trait Input extends GraphElement { def load(readOptions: InputReadOptions): DataFrame } -/** Represents a node in a [[DataflowGraph]] that can be written to by a [[Flow]]. */ +/** + * Represents a node in a [[DataflowGraph]] that can be written to by a [[Flow]]. + * Must be backed by a file source. + */ sealed trait Output { /** diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index 16f0d07a96998..01b2a91bb9329 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -374,7 +374,7 @@ class ConnectInvalidPipelineSuite extends PipelineTest { .getMessage .contains( s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString}" + - s" is not a batch view and must be referenced using read." + s" is a batch view and must be referenced using SparkSession#read." ) ) } @@ -392,7 +392,7 @@ class ConnectInvalidPipelineSuite extends PipelineTest { .getMessage .contains( s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString} " + - s"is a streaming view and must be referenced using readStream" + s"is a streaming view and must be referenced using SparkSession#readStream" ) ) } From 71c69aff03b8e7769ccddad8e7d025ef9b16dfff Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 30 May 2025 09:42:21 -0700 Subject: [PATCH 13/32] 1 --- .../sql/pipelines/graph/CoreDataflowNodeProcessor.scala | 6 ++---- .../apache/spark/sql/pipelines/graph/DataflowGraph.scala | 5 ----- .../apache/spark/sql/pipelines/graph/GraphValidations.scala | 2 +- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 9866ad3f61bd9..a9a90ca509d09 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -87,10 +87,8 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { isStreamingTableOpt = Option(resolvedFlowsToTable.exists(f => f.df.isStreaming)) ) - // Table will be virtual in either of the following scenarios: - // 1. If table is present in context.fullRefreshTables - // 2. If table has any virtual inputs (flows or tables) - // 3. If the table pre-existing metadata is different from current metadata + // We mark all tables as virtual to ensure resolution uses incoming flows + // rather than previously materialized tables. 
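// Illustrative aside, not part of the patch: a standalone, simplified sketch of the
// conf-fusing rule applied by FlowResolver earlier in this file, where sql confs contributed
// by a flow and its upstream views are merged and a key bound to two different values is
// rejected (the situation surfaced to users via the DUPLICATE_FLOW_SQL_CONF error condition
// added in this series). The helper name `fuseConfs` and its shape are assumptions; the real
// implementation lives in FlowResolver.
def fuseConfs(confsBySource: Seq[(String, Map[String, String])]): Map[String, String] = {
  // Track, for every conf key, which source set it and to what value.
  val fused = scala.collection.mutable.Map.empty[String, (String, String)]
  confsBySource.foreach { case (source, confs) =>
    confs.foreach { case (key, value) =>
      fused.get(key) match {
        case Some((otherSource, otherValue)) if otherValue != value =>
          throw new IllegalArgumentException(
            s"Conf '$key' is defined by both '$otherSource' and '$source' with different values")
        case _ =>
          fused(key) = (source, value)
      }
    }
  }
  fused.map { case (key, (_, value)) => key -> value }.toMap
}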
val virtualTableInput = VirtualTableInput( identifier = table.identifier, specifiedSchema = table.specifiedSchema, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index 49ba752bb3195..7099ca9c9dbac 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -37,11 +37,6 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) /** Returns a [[Output]] given its identifier */ lazy val output: Map[TableIdentifier, Output] = mapUnique(tables, "output")(_.identifier) - /** - * Returns a [[TableInput]], if one is available, that can be read from by downstream flows. - */ - def tableInput(identifier: TableIdentifier): Option[TableInput] = table.get(identifier) - /** * Returns [[Flow]]s in this graph that need to get planned and potentially executed when * executing the graph. Flows that write to logical views are excluded. diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 6a3a9e18ddf96..7ab7e0522ff1e 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -148,7 +148,7 @@ trait GraphValidations extends Logging { } protected def validateUserSpecifiedSchemas(): Unit = { - flows.flatMap(f => tableInput(f.identifier)).foreach { t: TableInput => + flows.flatMap(f => table.get(f.identifier)).foreach { t: TableInput => // The output inferred schema of a table is the declared schema merged with the // schema of all incoming flows. This must be equivalent to the declared schema. val inferredSchema = SchemaInferenceUtils From 41a060025605eeba483f9440cd31e1801864d051 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 30 May 2025 15:00:13 -0700 Subject: [PATCH 14/32] fix --- .../resources/error/error-conditions.json | 45 ++++ .../sql/pipelines/graph/DataflowGraph.scala | 7 +- .../graph/DataflowGraphTransformer.scala | 12 +- .../spark/sql/pipelines/graph/Flow.scala | 14 +- .../sql/pipelines/graph/FlowAnalysis.scala | 17 +- .../sql/pipelines/graph/GraphErrors.scala | 35 ++- .../pipelines/graph/GraphValidations.scala | 5 +- .../sql/pipelines/graph/PipelinesErrors.scala | 13 +- .../sql/pipelines/graph/QueryOrigin.scala | 3 - .../sql/pipelines/util/InputReadInfo.scala | 9 +- .../pipelines/util/SchemaInferenceUtils.scala | 9 +- .../sql/pipelines/utils/PipelineTest.scala | 210 ++---------------- 12 files changed, 143 insertions(+), 236 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 6da4b1be6a1d8..cc461221290c4 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -82,6 +82,14 @@ ], "sqlState" : "XX000" }, + "APPEND_ONCE_FROM_BATCH_QUERY" : { + "message" : [ + "Creating a streaming table from a batch query prevents incremental loading of new data from source. Offending table: '
'.", + "Please use the stream() operator. Example usage:", + "CREATE STREAMING TABLE ... AS SELECT ... FROM stream() ..." + ], + "sqlState" : "42000" + }, "ARITHMETIC_OVERFLOW" : { "message" : [ ". If necessary set to \"false\" to bypass this error." @@ -3137,6 +3145,12 @@ }, "sqlState" : "KD002" }, + "INVALID_NAME_IN_USE_COMMAND" : { + "message" : [ + "Invalid name '' in command. Reason: " + ], + "sqlState" : "42000" + }, "INVALID_NON_DETERMINISTIC_EXPRESSIONS" : { "message" : [ "The operator expects a deterministic expression, but the actual expression is ." @@ -3402,6 +3416,12 @@ ], "sqlState" : "22023" }, + "INVALID_RESETTABLE_DEPENDENCY" : { + "message" : [ + "Tables are resettable but have a non-resettable downstream dependency ''. `reset` will fail as Spark Streaming does not support deleted source data. You can either remove the =false property from '' or add it to its upstream dependencies." + ], + "sqlState" : "42000" + }, "INVALID_RESET_COMMAND_FORMAT" : { "message" : [ "Expected format is 'RESET' or 'RESET key'. If you want to include special characters in key, please use quotes, e.g., RESET `key`." @@ -4587,6 +4607,12 @@ ], "sqlState" : "42K03" }, + "PERSISTED_VIEW_READS_FROM_TEMPORARY_VIEW" : { + "message" : [ + "Persisted view cannot reference temporary view that will not be available outside the pipeline scope. Either make the persisted view temporary or persist the temporary view." + ], + "sqlState" : "42K0F" + }, "PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : { "message" : [ "Non-grouping expression is provided as an argument to the |> AGGREGATE pipe operator but does not contain any aggregate function; please update it to include an aggregate function and then retry the query again." @@ -5443,6 +5469,19 @@ ], "sqlState" : "42KD9" }, + "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA" : { + "message" : [ + "Failed to infer the schema for table from its upstream flows.", + "Please modify the flows that write to this table to make their schemas compatible.", + "", + "Inferred schema so far:", + "", + "", + "Incompatible schema:", + "" + ], + "sqlState" : "42KD9" + }, "UNBOUND_SQL_PARAMETER" : { "message" : [ "Found the unbound parameter: . Please, fix `args` and provide a mapping of the parameter to either a SQL literal or collection constructor functions such as `map()`, `array()`, `struct()`." @@ -5608,6 +5647,12 @@ ], "sqlState" : "42883" }, + "UNRESOLVED_TABLE_PATH" : { + "message" : [ + "Storage path for table cannot be resolved." + ], + "sqlState" : "22KD1" + }, "UNRESOLVED_USING_COLUMN_FOR_JOIN" : { "message" : [ "USING column cannot be resolved on the side of the join. The -side columns: []." diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index 7099ca9c9dbac..49c9611a483ff 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -192,12 +192,15 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) validateEveryDatasetHasFlow() validateTablesAreResettable() validateAppendOnceFlows() + // Ensures that all flows are resolved and have a valid schema. inferredSchema }.failed - /** Enforce every dataset has at least once input flow. For example its possible to define + /** + * Enforce every dataset has at least one input flow. 
For example its possible to define * streaming tables without a query; such tables should still have at least one flow - * writing to it. */ + * writing to it. + */ def validateEveryDatasetHasFlow(): Unit = { (tables.map(_.identifier) ++ views.map(_.identifier)).foreach { identifier => if (!flows.exists(_.destinationIdentifier == identifier)) { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala index 863f995ffd5d2..8448ed5f10d21 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.ThreadUtils * Assumptions: * 1. Each output will have at-least 1 flow to it. * 2. Each flow may or may not have a destination table. If a flow does not have a destination - * table, the destination is a view. + * table, the destination is a temporary view. * * The way graph is structured is that flows, tables and sinks all are graph elements or nodes. * While we expose transformation functions for each of these entities, we also expose a way to @@ -66,8 +66,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { // Failed flows are flows that are failed to resolve or its inputs are not available or its // destination failed to resolve. private var failedFlows: Seq[ResolutionCompletedFlow] = Seq.empty - // We define a dataset is failed to resolve if: - // 1. It is a destination of a flow that is unresolved. + // We define a dataset is failed to resolve if it is a destination of a flow that is unresolved. private var failedTables: Seq[Table] = Seq.empty private val parallelism = 10 @@ -341,7 +340,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { object DataflowGraphTransformer { /** - * Exception thrown when a node in the graph fails to be transformed because at least one of its + * Exception thrown when transforming a node in the graph fails because at least one of its * dependencies weren't yet transformed. * * @param datasetIdentifier The identifier for an untransformed dependency table identifier in the @@ -353,6 +352,11 @@ object DataflowGraphTransformer { extends Exception with NoStackTrace + /** + * Exception thrown when transforming a node in the graph fails with a non-retryable error. + * + * @param failedNode The failed node that could not be transformed. + */ case class TransformNodeFailedException(failedNode: ResolutionFailedFlow) extends Exception with NoStackTrace diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index aa64ae61b59eb..978dc09812d28 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -29,9 +29,7 @@ import org.apache.spark.sql.types.StructType /** * Contains the catalog and database context information for query execution. */ -case class QueryContext( - currentCatalog: Option[String], - currentDatabase: Option[String]) +case class QueryContext(currentCatalog: Option[String], currentDatabase: Option[String]) /** * A [[Flow]] is a node of data transformation in a dataflow graph. 
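For reference, the QueryContext case class shown just above is only the catalog/database pair that a flow's query is resolved against; a minimal, hypothetical instance looks like:

    import org.apache.spark.sql.pipelines.graph.QueryContext

    // Resolve unqualified names in this flow's query against `main.sales`.
    val ctx = QueryContext(currentCatalog = Some("main"), currentDatabase = Some("sales"))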
It describes the movement @@ -45,9 +43,7 @@ trait Flow extends GraphElement with Logging { val identifier: TableIdentifier /** - * The dataset that this Flow represents a write to. Since the DataflowGraph doesn't have a first- - * class concept of views, writing to a destination that isn't a Table or a Sink represents a - * view. + * The dataset that this Flow represents a write to. */ val destinationIdentifier: TableIdentifier @@ -65,7 +61,7 @@ trait Flow extends GraphElement with Logging { def sqlConf: Map[String, String] } -/** A wrapper for a resolved internal input that includes the identifier used in SubqueryAlias */ +/** A wrapper for a resolved internal input that includes the alias provided by the user. */ case class ResolvedInput(input: Input, aliasIdentifier: AliasIdentifier) /** A wrapper for the lambda function that defines a [[Flow]]. */ @@ -90,12 +86,12 @@ trait FlowFunction extends Logging { } /** - * Holds the [[DataFrame]] returned by a [[FlowFunction]] along with the inputs used to + * Holds the DataFrame returned by a [[FlowFunction]] along with the inputs used to * construct it. * @param batchInputs the complete inputs read by the flow * @param streamingInputs the incremental inputs read by the flow * @param usedExternalInputs the identifiers of the external inputs read by the flow - * @param dataFrame the [[DataFrame]] expression executed by the flow if the flow can be resolved + * @param dataFrame the DataFrame expression executed by the flow if the flow can be resolved */ case class FlowFunctionResult( requestedInputs: Set[TableIdentifier], diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index df4771fae3360..05cba8d8d415d 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -28,7 +28,18 @@ import org.apache.spark.sql.pipelines.{AnalysisWarning, Language} import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier} import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions} + object FlowAnalysis { + /** + * Creates a [[FlowFunction]] that attempts to analyze the provided LogicalPlan + * using the existing resolved inputs. + * - If all upstream inputs have been resolved, then analysis succeeds and the + * function returns a [[FlowFunctionResult]] containing the dataframe. + * - If any upstream inputs are unresolved, then the function throws an exception. + * + * @param plan The user-supplied LogicalPlan defining a flow. + * @return A FlowFunction that attempts to analyze the provided LogicalPlan. 
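As a sketch of how callers consume this result (assuming `dataFrame` is a `Try[DataFrame]`, which the `dataFrame.failed` usage elsewhere in this patch suggests; the helper below is hypothetical):

    import scala.util.{Failure, Success}

    // Hypothetical helper: summarize whether a flow's plan could be analyzed.
    def describe(result: FlowFunctionResult): String = result.dataFrame match {
      case Success(df) => s"resolved with schema ${df.schema.simpleString}"
      case Failure(e)  => s"unresolved: ${e.getMessage}"
    }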
+ */ def createFlowFunctionFromLogicalPlan(plan: LogicalPlan): FlowFunction = { new FlowFunction { override def call( @@ -105,9 +116,7 @@ object FlowAnalysis { context, name = IdentifierHelper.toQuotedString(u.multipartIdentifier), spark.readStream, - streamingReadOptions = StreamingReadOptions( - apiLanguage = Language.Sql() - ) + streamingReadOptions = StreamingReadOptions() ).queryExecution.analyzed // Batch read on another dataset in the pipeline @@ -115,7 +124,7 @@ object FlowAnalysis { readBatchInput( context, name = IdentifierHelper.toQuotedString(u.multipartIdentifier), - batchReadOptions = BatchReadOptions(apiLanguage = Language.Sql()) + batchReadOptions = BatchReadOptions() ).queryExecution.analyzed } Dataset.ofRows(spark, resolvedPlan) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala index 1a01a8df3f911..16f38b2897a8c 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala @@ -26,12 +26,15 @@ import org.apache.spark.sql.types.StructType /** Collection of errors that can be thrown during graph resolution / analysis. */ object GraphErrors { + /** + * Throws when a dataset is marked as internal but is not defined in the graph. + * + * @param datasetName the name of the dataset that is not defined + */ def pipelineLocalDatasetNotDefinedError(datasetName: String): SparkException = { - // TODO: this should be an internal error, as we never expect this to happen - new SparkException( - errorClass = "PIPELINE_LOCAL_DATASET_NOT_DEFINED", - messageParameters = Map("datasetName" -> datasetName), - cause = null + SparkException.internalError( + s"Failed to read dataset '$datasetName'. This dataset was expected to be " + + s"defined and created by the pipeline." ) } @@ -54,6 +57,11 @@ object GraphErrors { ) } + /** + * Throws when a table path is unresolved, i.e. the table identifier does not exist in the catalog. + * + * @param identifier the unresolved table identifier + */ def unresolvedTablePath(identifier: TableIdentifier): SparkException = { new SparkException( errorClass = "UNRESOLVED_TABLE_PATH", @@ -62,6 +70,11 @@ object GraphErrors { ) } + /** + * Throws an error if the user-specified schema and the inferred schema are not compatible. + * + * @param tableIdentifier the identifier of the table that was not found + */ def incompatibleUserSpecifiedAndInferredSchemasError( tableIdentifier: TableIdentifier, datasetType: DatasetType, @@ -92,6 +105,12 @@ object GraphErrors { ) } + /** + * Throws if the latest inferred schema for a pipeline table is not compatible with + * the table's existing schema. + * + * @param tableIdentifier the identifier of the table that was not found + */ def unableToInferSchemaError( tableIdentifier: TableIdentifier, inferredSchema: StructType, @@ -109,6 +128,12 @@ object GraphErrors { ) } + /** + * Throws an error when a persisted view is trying to read from a temporary view. 
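For illustration (hypothetical identifiers; the PERSISTED_VIEW_READS_FROM_TEMPORARY_VIEW condition added earlier in this patch is presumably what the exception carries):

    import org.apache.spark.sql.catalyst.TableIdentifier

    val err = GraphErrors.persistedViewReadsFromTemporaryView(
      persistedViewIdentifier = TableIdentifier("sales_summary"),
      temporaryViewIdentifier = TableIdentifier("tmp_orders"))
    // `err` is an AnalysisException; callers are expected to throw it at graph validation time.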
+ * + * @param persistedViewIdentifier the identifier of the persisted view + * @param temporaryViewIdentifier the identifier of the temporary view + */ def persistedViewReadsFromTemporaryView( persistedViewIdentifier: TableIdentifier, temporaryViewIdentifier: TableIdentifier): AnalysisException = { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 7ab7e0522ff1e..56ca4d0a8a840 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -30,10 +30,7 @@ trait GraphValidations extends Logging { this: DataflowGraph => /** - * Validate multi query table correctness. Exposed for Python unit testing, which currently cannot - * run anything which invokes the flow function as there's no persistent Python to run it. - * - * @return the multi-query tables by destination + * Validate multi query table correctness. */ protected[pipelines] def validateMultiQueryTables(): Map[TableIdentifier, Seq[Flow]] = { val multiQueryTables = flowsTo.filter(_._2.size > 1) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala index 1b083a0c8fdf4..aaeece7007aa8 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.pipelines.graph +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -38,8 +39,12 @@ case class UnresolvedDatasetException(identifier: TableIdentifier) * @param name The name of the table * @param cause The cause of the failure */ -case class LoadTableException(name: String, override val cause: Option[Throwable]) - extends AnalysisException(s"Failed to load table '$name'", cause = cause) +case class LoadTableException(name: String, cause: Option[Throwable]) + extends SparkException( + errorClass = "INTERNAL_ERROR", + messageParameters = Map("message" -> s"Failed to load table '$name'"), + cause = cause.orNull + ) /** * Exception raised when a pipeline has one or more flows that cannot be resolved @@ -70,8 +75,8 @@ case class UnresolvedPipelineException( .sorted .mkString(", ")} | - |To view the exceptions that were raised while resolving these flows, look for FlowProgress - |logs with status FAILED that precede this log.""".stripMargin + |To view the exceptions that were raised while resolving these flows, look for flow + |failures that precede this log.""".stripMargin ) /** A validation error that can either be thrown as an exception or logged as a warning. 
*/ diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala index 3467f3d88d630..9daa760678daa 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -29,8 +29,6 @@ import org.apache.spark.sql.pipelines.logging.SourceCodeLocation * * @param language The language used by the user to define the query. * @param fileName The file name of the user code that defines the query. - * @param cellNumber The cell number of the user code that defines the query. - * Cell numbers are 1-indexed. * @param sqlText The SQL text of the query. * @param line The line number of the query in the user code. * Line numbers are 1-indexed. @@ -41,7 +39,6 @@ import org.apache.spark.sql.pipelines.logging.SourceCodeLocation case class QueryOrigin( language: Option[Language] = None, fileName: Option[String] = None, - cellNumber: Option[Int] = None, sqlText: Option[String] = None, line: Option[Int] = None, startPosition: Option[Int] = None, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala index 6e3ef2c6a7da8..d7f652792041b 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala @@ -24,29 +24,24 @@ import org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions /** * Generic options for a read of an input. */ -sealed trait InputReadOptions { - // The language of the public API that called this function. - def apiLanguage: Language -} +sealed trait InputReadOptions /** * Options for a batch read of an input. * * @param apiLanguage The language of the public API that called this function. */ -final case class BatchReadOptions(apiLanguage: Language) extends InputReadOptions +final case class BatchReadOptions() extends InputReadOptions /** * Options for a streaming read of an input. * - * @param apiLanguage The language of the public API that called this function. * @param userOptions Holds the user defined read options. * @param droppedUserOptions Holds the options that were specified by the user but * not actually used. This is a bug but we are preserving this behavior * for now to avoid making a backwards incompatible change. */ final case class StreamingReadOptions( - apiLanguage: Language, userOptions: CaseInsensitiveMap[String] = EmptyUserOptions, droppedUserOptions: CaseInsensitiveMap[String] = EmptyUserOptions ) extends InputReadOptions diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala index d5bd4887ce248..4777772342d7d 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala @@ -30,8 +30,13 @@ object SchemaInferenceUtils { /** * Given a set of flows that write to the same destination and possibly a user-specified schema, - * we infer the resulting schema. - * + * we infer the schema of the destination dataset. The logic is as follows: + * 1. 
If there are no incoming flows, return the user-specified schema (if provided) + * or an empty schema. + * 2. If there are incoming flows, we merge the schemas of all flows that write to + * the same destination. + * 3. If a user-specified schema is provided, we merge it with the inferred schema. + * The user-specified schema will take precedence over the inferred schema. * Returns an error if encountered during schema inference or merging the inferred schema with * the user-specified one. */ diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala index d13fd1902621e..4cf8039b8aad7 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala @@ -44,9 +44,8 @@ abstract class PipelineTest with BeforeAndAfterAll with BeforeAndAfterEach with Matchers -// with SQLImplicits with SparkErrorTestMixin - with TargetCatalogAndSchemaMixin + with TargetCatalogAndDatabaseMixin with Logging { final protected val storageRoot = createTempDir() @@ -65,32 +64,23 @@ abstract class PipelineTest var conf = new SparkConf() .set("spark.sql.shuffle.partitions", "2") .set("spark.sql.session.timeZone", "UTC") - - if (schemaInPipelineSpec.isDefined) { - conf = conf.set("pipelines.schema", schemaInPipelineSpec.get) - } - - if (Option(System.getenv("ENABLE_SPARK_UI")).exists(s => java.lang.Boolean.valueOf(s))) { - conf = conf.set("spark.ui.enabled", "true") - } - conf } /** Returns the dataset name in the event log. */ protected def eventLogName( name: String, catalog: Option[String] = catalogInPipelineSpec, - schema: Option[String] = schemaInPipelineSpec, + database: Option[String] = databaseInPipelineSpec, isView: Boolean = false ): String = { - fullyQualifiedIdentifier(name, catalog, schema, isView).unquotedString + fullyQualifiedIdentifier(name, catalog, database, isView).unquotedString } /** Returns the fully qualified identifier. */ protected def fullyQualifiedIdentifier( name: String, catalog: Option[String] = catalogInPipelineSpec, - schema: Option[String] = schemaInPipelineSpec, + database: Option[String] = databaseInPipelineSpec, isView: Boolean = false ): TableIdentifier = { if (isView) { @@ -98,15 +88,12 @@ abstract class PipelineTest } else { TableIdentifier( catalog = catalog, - database = schema, + database = database, table = name ) } } -// /** Returns the [[PipelineApiConf]] constructed from the current spark session */ -// def pipelineApiConf: PipelineApiConf = PipelineApiConf.instance - /** * Runs the given function with the given spark conf, and resets the conf after the function * completes. @@ -151,11 +138,11 @@ abstract class PipelineTest super.beforeEach() initializeSparkBeforeEachTest() cleanupMetastore(spark) - (catalogInPipelineSpec, schemaInPipelineSpec) match { + (catalogInPipelineSpec, databaseInPipelineSpec) match { case (Some(catalog), Some(schema)) => - sql(s"CREATE SCHEMA IF NOT EXISTS `$catalog`.`$schema`") + sql(s"CREATE DATABASE IF NOT EXISTS `$catalog`.`$schema`") case _ => - schemaInPipelineSpec.foreach(s => sql(s"CREATE SCHEMA IF NOT EXISTS `$s`")) + databaseInPipelineSpec.foreach(s => sql(s"CREATE DATABASE IF NOT EXISTS `$s`")) } } @@ -186,19 +173,7 @@ abstract class PipelineTest * after any setup and before any clean-up done for a test. 
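To make the SchemaInferenceUtils behavior described above concrete, a hedged sketch with hypothetical schemas (the exact merge helper is not shown in this hunk):

    import org.apache.spark.sql.types._

    // Two flows writing to the same destination.
    val flowA = new StructType().add("id", LongType).add("name", StringType)
    val flowB = new StructType().add("id", LongType).add("email", StringType)
    // Conceptually, the inferred destination schema is the union of the compatible fields:
    //   id LONG, name STRING, email STRING
    // A user-specified schema, if present, is merged in last and wins on any conflict.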
*/ private def runWithInstrumentation(testFunc: => Any): Any = { - try { - testFunc - } catch { - case e: TestFailedDueToTimeoutException => -// val stackTraces = StackTraceReporter.dumpAllStackTracesToString() -// logInfo( -// s""" -// |Triggering thread dump since test failed with a timeout exception: -// |$stackTraces -// |""".stripMargin -// ) - throw e - } + testFunc } /** @@ -244,67 +219,6 @@ abstract class PipelineTest } } - /** - * Returns a [[Seq]] of JARs generated by compiling this test. - * - * Includes a "delta-pipelines-repo" to ensure the export_test, which compiles differently, - * still succeeds. - */ - protected def getTestJars: Seq[String] = - getUniqueAbsoluteTestJarPaths.map(_.getName) :+ "delta-pipelines-repo" - - /** - * Returns a [[Seq]] of absolute paths of all JAR files found in the - * current directory. See [[getUniqueAbsoluteTestJarPaths]]. - */ - protected def getTestJarPaths: Seq[String] = - getUniqueAbsoluteTestJarPaths.map(_.getAbsolutePath) - - /** - * Returns a sequence of JARs found in the current directory. In a bazel test, - * the current directory includes all jars that are required to run the test - * (its run files). This allows us to include these jars in the class path - * for the graph loading class loader. - * - * Because dependent jars can be included multiple times in this list, we deduplicate - * by file name (ignoring the path). - */ - private def getUniqueAbsoluteTestJarPaths: Seq[File] = - Files - .walk(Paths.get(".")) - .iterator() - .asScala - .map(_.toFile) - .filter( - f => - f.isFile && - // This filters JARs to match 2 main cases: - // - JARs built by Bazel that are usually suffixed with deploy.jar; - // - classpath.jar that Scala test template can also create if the classpath is too long. - f.getName.matches("classpath.jar|.*deploy.jar") - ) - .toSeq - .groupBy(_.getName) - .flatMap(_._2.headOption) - .toSeq - -// /** -// * Returns a [[DataFrame]] given the path to json encoded data stored in the project's -// * test resources. Schema is parsed from first line. -// */ -// protected def jsonData(path: String): DataFrame = { -// val contents = loadResource(path) -// val data = contents.tail -// val schema = contents.head -// jsonData(schema, data) -// } -// -// /** Returns a [[DataFrame]] given the string representation of it schema and data. */ -// protected def jsonData(schemaString: String, data: Seq[String]): DataFrame = { -// val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] -// spark.read.schema(schema).json(data.toDS()) -// } - /** Loads a package resources as a Seq of lines. */ protected def loadResource(path: String): Seq[String] = { val stream = Thread.currentThread.getContextClassLoader.getResourceAsStream(path) @@ -354,65 +268,6 @@ abstract class PipelineTest ignoreFieldCase: Boolean = false ) -// /** -// * Runs the query and makes sure the answer matches the expected result. -// * -// * @param validateSchema Whether or not the exact schema fields are validated. This validates -// * the schema types and field names, but does not validate field -// * nullability. -// */ -// protected def checkAnswer( -// df: => DataFrame, -// expectedAnswer: DataFrame, -// validateSchema: Boolean = false, -// validationArgs: ValidationArgs = ValidationArgs(), -// checkPlan: Option[SparkPlan => Unit] = None -// ): Unit = { -// // Evaluate `df` so we get a constant DF. 
-// val dfByVal = df -// val actualSchema = dfByVal.schema -// val expectedSchema = expectedAnswer.schema -// -// def transformSchema(original: StructType): StructType = { -// var result = original -// if (validationArgs.ignoreFieldOrder) { -// result = StructType(result.fields.sortBy(_.name)) -// } -// if (validationArgs.ignoreFieldCase) { -// result = StructType(result.fields.map { field => -// field.copy(name = field.name.toLowerCase(Locale.ROOT)) -// }) -// } -// result -// } -// -// def transformDataFrame(original: DataFrame): DataFrame = { -// var result = original -// if (validationArgs.ignoreFieldOrder) { -// result = result.select( -// result.columns.sorted.map { columnName => -// result.col(UnresolvedAttribute.quoted(columnName).name) -// }: _* -// ) -// } -// result -// } -// -// if (validateSchema) { -// assert( -// transformSchema(actualSchema.asNullable) == transformSchema(expectedSchema.asNullable), -// s"Expected and actual schemas are different:\n" + -// s"Expected: $expectedSchema\n" + -// s"Actual: $actualSchema" -// ) -// } -// checkAnswerAndPlan( -// transformDataFrame(dfByVal), -// transformDataFrame(expectedAnswer).collect().toIndexedSeq, -// checkPlan -// ) -// } - /** * Evaluates a dataset to make sure that the result of calling collect matches the given * expected answer. @@ -514,55 +369,26 @@ abstract class PipelineTest def eval[T](col: TypedColumn[Any, T]): T = { spark.range(1).select(col).head() } - -// /** -// * Helper class to create a SQLPipeline that is eligible to resolve flows parallely. -// */ -// class SQLPipelineWithParallelResolve( -// queries: Seq[String], -// notebookPath: Option[String] = None, -// catalog: Option[String] = catalogInPipelineSpec, -// schema: Option[String] = schemaInPipelineSpec -// ) extends SQLPipeline(queries, notebookPath, catalog, schema) { -// override def eligibleForResolvingFlowsParallely = true -// } - -// /** -// * Helper method to create a [[SQLPipelineWithParallelResolve]] with catalog and schema set to -// * the test's catalog and schema. -// */ -// protected def createSqlParallelPipeline( -// queries: Seq[String], -// notebookPath: Option[String] = None, -// catalog: Option[String] = catalogInPipelineSpec, -// schema: Option[String] = schemaInPipelineSpec -// ): SQLPipelineWithParallelResolve = { -// new SQLPipelineWithParallelResolve( -// queries = queries, -// notebookPath = notebookPath, -// catalog = catalog, -// schema = schema -// ) -// } } /** * A trait that provides a way to specify the target catalog and schema for a test. */ -trait TargetCatalogAndSchemaMixin { +trait TargetCatalogAndDatabaseMixin { protected def catalogInPipelineSpec: Option[String] = Option( TestGraphRegistrationContext.DEFAULT_CATALOG ) - protected def schemaInPipelineSpec: Option[String] = Option( + protected def databaseInPipelineSpec: Option[String] = Option( TestGraphRegistrationContext.DEFAULT_DATABASE ) } object PipelineTest extends Logging { + /** System schemas per-catalog that's can't be directly deleted. */ - protected val systemSchemas: Set[String] = Set("default", "information_schema") + protected val systemDatabases: Set[String] = Set("default", "information_schema") /** System catalogs that are read-only and cannot be modified/dropped. */ private val systemCatalogs: Set[String] = Set("samples") @@ -583,17 +409,17 @@ object PipelineTest extends Logging { /** * Try to drop the schema in the catalog and return whether it is successfully dropped. 
*/ - private def dropSchemaIfPossible( + private def dropDatabaseIfPossible( spark: SparkSession, catalogName: String, - schemaName: String): Boolean = { + databaseName: String): Boolean = { try { - spark.sql(s"DROP SCHEMA IF EXISTS `$catalogName`.`$schemaName` CASCADE") + spark.sql(s"DROP DATABASE IF EXISTS `$catalogName`.`$databaseName` CASCADE") true } catch { case NonFatal(e) => logInfo( - s"Failed to drop schema $schemaName in catalog $catalogName, ex:${e.getMessage}" + s"Failed to drop database $databaseName in catalog $catalogName, ex:${e.getMessage}" ) false } @@ -613,7 +439,7 @@ object PipelineTest extends Logging { val schemas = spark.sql(s"SHOW SCHEMAS IN `$catalog`").collect().map(_.getString(0)) schemas.foreach { schema => - if (systemSchemas.contains(schema) || !dropSchemaIfPossible(spark, catalog, schema)) { + if (systemDatabases.contains(schema) || !dropDatabaseIfPossible(spark, catalog, schema)) { spark .sql(s"SHOW tables in `$catalog`.`$schema`") .collect() From d350dbf98adb83d311760cc455e6c92309296ac9 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Fri, 30 May 2025 15:04:59 -0700 Subject: [PATCH 15/32] fix 2 --- .../apache/spark/sql/pipelines/graph/FlowAnalysis.scala | 2 +- .../apache/spark/sql/pipelines/graph/GraphErrors.scala | 3 ++- .../apache/spark/sql/pipelines/graph/QueryOrigin.scala | 1 - .../apache/spark/sql/pipelines/util/InputReadInfo.scala | 1 - .../apache/spark/sql/pipelines/utils/PipelineTest.scala | 8 +++----- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 05cba8d8d415d..7e2e97f2b5d74 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession} -import org.apache.spark.sql.pipelines.{AnalysisWarning, Language} +import org.apache.spark.sql.pipelines.AnalysisWarning import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier} import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala index 16f38b2897a8c..b360a0807bb0d 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala @@ -58,7 +58,8 @@ object GraphErrors { } /** - * Throws when a table path is unresolved, i.e. the table identifier does not exist in the catalog. + * Throws when a table path is unresolved, i.e. the table identifier + * does not exist in the catalog. 
* * @param identifier the unresolved table identifier */ diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala index 9daa760678daa..c9af65a0972c7 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -58,7 +58,6 @@ case class QueryOrigin( fileName = other.fileName.orElse(fileName), sqlText = other.sqlText.orElse(sqlText), line = other.line.orElse(line), - cellNumber = other.cellNumber.orElse(cellNumber), startPosition = other.startPosition.orElse(startPosition), objectType = other.objectType.orElse(objectType), objectName = other.objectName.orElse(objectName) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala index d7f652792041b..82ede5c1bd228 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.pipelines.util import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.pipelines.Language import org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions /** diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala index 4cf8039b8aad7..3be6510053d44 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql.pipelines.utils -import java.io.{BufferedReader, File, FileNotFoundException, InputStreamReader} -import java.nio.file.{Files, Paths} +import java.io.{BufferedReader, FileNotFoundException, InputStreamReader} +import java.nio.file.Files import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters._ import scala.util.{Failure, Try} import scala.util.control.NonFatal import org.scalactic.source import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag} -import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.matchers.should.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} @@ -61,7 +59,7 @@ abstract class PipelineTest * all spark sessions created in tests. 
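As an aside on the QueryOrigin change above: the `orElse` copy logic shown there (assumed here to back a merge-style method) lets fields from the other origin win whenever they are defined, for example:

    // Hypothetical origins for illustration.
    val fromRegistration = QueryOrigin(fileName = Some("pipeline.sql"), line = Some(10))
    val fromAnalysis     = QueryOrigin(line = Some(42), sqlText = Some("SELECT * FROM src"))
    // Combining them keeps fileName from the first origin but takes line and sqlText from the
    // second: fileName = Some("pipeline.sql"), line = Some(42), sqlText = Some("SELECT * FROM src").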
*/ protected def sparkConf: SparkConf = { - var conf = new SparkConf() + new SparkConf() .set("spark.sql.shuffle.partitions", "2") .set("spark.sql.session.timeZone", "UTC") } From 777189feeae14fe08513865c09d0b691765489ae Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 1 Jun 2025 07:38:20 -0700 Subject: [PATCH 16/32] remove shim exclusion in spark core test dep --- sql/pipelines/pom.xml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index b18ee7adc8fb9..105fcf7e70375 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -65,12 +65,6 @@ ${project.version} test-jar test - - - org.apache.spark - spark-connect-shims_${scala.binary.version} - - org.apache.spark From 4ac428c532f8b340b4d18b73b1f53ddd4c711712 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Mon, 2 Jun 2025 14:59:51 -0700 Subject: [PATCH 17/32] 1 --- sql/pipelines/pom.xml | 14 +++++++------- .../graph/CoreDataflowNodeProcessor.scala | 5 +++-- .../spark/sql/pipelines/graph/elements.scala | 9 --------- .../sql/pipelines/utils/PipelineTest.scala | 17 ----------------- .../utils/TestGraphRegistrationContext.scala | 5 ----- 5 files changed, 10 insertions(+), 40 deletions(-) diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 105fcf7e70375..bd7d0b54746ee 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -48,6 +48,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} @@ -59,13 +66,6 @@ - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.spark spark-catalyst_${scala.binary.version} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index a9a90ca509d09..c79b80c21bbf8 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -183,8 +183,9 @@ private class FlowResolver(rawGraph: DataflowGraph) extends Logging { } convertResolvedToTypedFlow(flowToResolve, maybeNewFuncResult) - // If flow failed due to unresolved dataset, throw a retryable exception, otherwise just - // return the failed flow. + // If the flow failed due to an UnresolvedDatasetException, it means that one of the + // flow's inputs wasn't available. After other flows are resolved, these inputs + // may become available, so throw a retryable exception in this case. case f => f.dataFrame.failed.toOption.collectFirst { case e: UnresolvedDatasetException => e diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index 2dfcaf5f7856a..79550634084d8 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -113,7 +113,6 @@ sealed trait TableInput extends Input { * @param normalizedPath Normalized storage location for the table based on the user-specified table * path (if not defined, we will normalize a managed storage path for it). * @param properties Table Properties to set in table metadata. 
- * @param sqlText For SQL-defined pipelines, the original string of the SELECT query. * @param comment User-specified comment that can be placed on the table. * @param isStreamingTableOpt if the table is a streaming table, will be None until we have resolved * flows into table @@ -124,7 +123,6 @@ case class Table( partitionCols: Option[Seq[String]], normalizedPath: Option[String], properties: Map[String, String] = Map.empty, - sqlText: Option[String], comment: Option[String], baseOrigin: QueryOrigin, isStreamingTableOpt: Option[Boolean], @@ -239,9 +237,6 @@ trait View extends GraphElement { /** Properties of this view */ val properties: Map[String, String] - /** (SQL-specific) The raw query that defines the [[View]]. */ - val sqlText: Option[String] - /** User-specified comment that can be placed on the [[View]]. */ val comment: Option[String] } @@ -251,13 +246,11 @@ trait View extends GraphElement { * * @param identifier The identifier of this view within the graph. * @param properties Properties of the view - * @param sqlText Raw SQL query that defines the view. * @param comment when defining a view */ case class TemporaryView( identifier: TableIdentifier, properties: Map[String, String], - sqlText: Option[String], comment: Option[String], origin: QueryOrigin ) extends View {} @@ -267,13 +260,11 @@ case class TemporaryView( * * @param identifier The identifier of this view within the graph. * @param properties Properties of the view - * @param sqlText Raw SQL query that defines the view. * @param comment when defining a view */ case class PersistedView( identifier: TableIdentifier, properties: Map[String, String], - sqlText: Option[String], comment: Option[String], origin: QueryOrigin ) extends View {} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala index 3be6510053d44..981fa3cdcae85 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala @@ -92,23 +92,6 @@ abstract class PipelineTest } } - /** - * Runs the given function with the given spark conf, and resets the conf after the function - * completes. - */ - def withSparkConfs[T](confs: Map[String, String])(f: => T): T = { - val originalConfs = confs.keys.map(k => k -> spark.conf.getOption(k)).toMap - confs.foreach { case (k, v) => spark.conf.set(k, v) } - try f - finally originalConfs.foreach { - case (k, v) => - v match { - case Some(v) => spark.conf.set(k, v) - case None => spark.conf.unset(k) - } - } - } - /** * This exists temporarily for compatibility with tests that become invalid when multiple * executors are available. 
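With `sqlText` removed from the view elements above, a view definition reduces to its identifier, properties, comment, and origin; for example (hypothetical identifier, using the `QueryOrigin.empty` default referenced in the test helpers below):

    import org.apache.spark.sql.catalyst.TableIdentifier

    val tmpView = TemporaryView(
      identifier = TableIdentifier("tmp_orders"),
      properties = Map.empty,
      comment = None,
      origin = QueryOrigin.empty)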
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala index d5eafc411ade2..3449c5155c754 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -55,7 +55,6 @@ class TestGraphRegistrationContext( query: Option[FlowFunction] = None, sqlConf: Map[String, String] = Map.empty, comment: Option[String] = None, - sqlText: Option[String] = None, specifiedSchema: Option[StructType] = None, partitionCols: Option[Seq[String]] = None, properties: Map[String, String] = Map.empty, @@ -70,7 +69,6 @@ class TestGraphRegistrationContext( Table( identifier = GraphIdentifierManager.parseTableIdentifier(name, spark), comment = comment, - sqlText = sqlText, specifiedSchema = specifiedSchema, partitionCols = partitionCols, properties = properties, @@ -105,7 +103,6 @@ class TestGraphRegistrationContext( query: FlowFunction, sqlConf: Map[String, String] = Map.empty, comment: Option[String] = None, - sqlText: Option[String] = None, origin: QueryOrigin = QueryOrigin.empty, viewType: ViewType = LocalTempView, catalog: Option[String] = None, @@ -121,7 +118,6 @@ class TestGraphRegistrationContext( TemporaryView( identifier = viewIdentifier, comment = comment, - sqlText = sqlText, origin = origin, properties = Map.empty ) @@ -129,7 +125,6 @@ class TestGraphRegistrationContext( PersistedView( identifier = viewIdentifier, comment = comment, - sqlText = sqlText, origin = origin, properties = Map.empty ) From d887aea11e24fcff339428e258dad9aaab097d15 Mon Sep 17 00:00:00 2001 From: Aakash Japi Date: Mon, 2 Jun 2025 16:00:02 -0700 Subject: [PATCH 18/32] remove param --- .../org/apache/spark/sql/pipelines/util/InputReadInfo.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala index 82ede5c1bd228..070927aea295f 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala @@ -27,8 +27,6 @@ sealed trait InputReadOptions /** * Options for a batch read of an input. - * - * @param apiLanguage The language of the public API that called this function. 
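Call sites elsewhere in this patch series now construct read options without a language argument, e.g.:

    val batchOpts     = BatchReadOptions()
    val streamingOpts = StreamingReadOptions() // userOptions / droppedUserOptions default to empty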
*/ final case class BatchReadOptions() extends InputReadOptions From b4cbf083076b2c3922b11d341399bdc35561e709 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 21 May 2025 17:23:45 -0700 Subject: [PATCH 19/32] Remove validateAppendOnceFlows --- .../sql/pipelines/graph/DataflowGraph.scala | 2 -- .../spark/sql/pipelines/graph/Flow.scala | 6 ------ .../pipelines/graph/GraphValidations.scala | 20 ------------------- 3 files changed, 28 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index 49c9611a483ff..adf4f8c061653 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -191,8 +191,6 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) validatePersistedViewSources() validateEveryDatasetHasFlow() validateTablesAreResettable() - validateAppendOnceFlows() - // Ensures that all flows are resolved and have a valid schema. inferredSchema }.failed diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index 978dc09812d28..67ec95dba0f8e 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -192,11 +192,5 @@ class AppendOnceFlow( val funcResult: FlowFunctionResult ) extends ResolvedFlow { - /** - * Whether the flow was declared as once or not in UnresolvedFlow. If false, then it means the - * flow is created from batch query. - */ - val definedAsOnce: Boolean = flow.once - override val once = true } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 56ca4d0a8a840..8a5b003a18246 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -124,26 +124,6 @@ trait GraphValidations extends Logging { } } - /** - * Validate if we have any append only flows writing into a streaming table but was created - * from a batch query. 
- */ - protected def validateAppendOnceFlows(): Seq[GraphValidationWarning] = { - flows - .filter { - case af: AppendOnceFlow => !af.definedAsOnce - case _ => false - } - .groupBy(_.destinationIdentifier) - .flatMap { - case (destination, flows) => - table - .get(destination) - .map(t => AppendOnceFlowCreatedFromBatchQueryException(t, flows.map(_.identifier))) - } - .toSeq - } - protected def validateUserSpecifiedSchemas(): Unit = { flows.flatMap(f => table.get(f.identifier)).foreach { t: TableInput => // The output inferred schema of a table is the declared schema merged with the From 6b20fa0f6575b213de8a1a4d7f87fa80e78b319d Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 28 May 2025 13:58:35 -0700 Subject: [PATCH 20/32] Clean up PipelinesErrors --- .../pipelines/graph/GraphValidations.scala | 17 ++++-- .../sql/pipelines/graph/PipelinesErrors.scala | 53 ++----------------- 2 files changed, 18 insertions(+), 52 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 8a5b003a18246..98bb8afb94c16 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -86,12 +86,12 @@ trait GraphValidations extends Logging { * Validate that all tables are resettable. This is a best-effort check that will only catch * upstream tables that are resettable but have a non-resettable downstream dependency. */ - protected def validateTablesAreResettable(): Seq[GraphValidationWarning] = { + protected def validateTablesAreResettable(): Unit = { validateTablesAreResettable(tables) } /** Validate that all specified tables are resettable. 
*/ - protected def validateTablesAreResettable(tables: Seq[Table]): Seq[GraphValidationWarning] = { + protected def validateTablesAreResettable(tables: Seq[Table]): Unit = { val tableLookup = mapUnique(tables, "table")(_.identifier) val nonResettableTables = tables.filter(t => !PipelinesTableProperties.resetAllowed.fromMap(t.properties)) @@ -120,7 +120,18 @@ trait GraphValidations extends Logging { .reverse .map { case (nameForEvent, tables) => - InvalidResettableDependencyException(nameForEvent, tables) + throw new AnalysisException( + "INVALID_RESETTABLE_DEPENDENCY", + Map( + "downstreamTable" -> nameForEvent, + "upstreamResettableTables" -> tables + .map(_.displayName) + .sorted + .map(t => s"'$t'") + .mkString(", "), + "resetAllowedKey" -> PipelinesTableProperties.resetAllowed.key + ) + ) } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala index aaeece7007aa8..1a4cec1eb18e7 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.pipelines.graph -import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -39,12 +37,10 @@ case class UnresolvedDatasetException(identifier: TableIdentifier) * @param name The name of the table * @param cause The cause of the failure */ -case class LoadTableException(name: String, cause: Option[Throwable]) - extends SparkException( - errorClass = "INTERNAL_ERROR", - messageParameters = Map("message" -> s"Failed to load table '$name'"), - cause = cause.orNull - ) +case class LoadTableException(name: String, override val cause: Option[Throwable]) + extends AnalysisException(s"Failed to load table '$name'", cause = cause) + + /** * Exception raised when a pipeline has one or more flows that cannot be resolved @@ -79,13 +75,6 @@ case class UnresolvedPipelineException( |failures that precede this log.""".stripMargin ) -/** A validation error that can either be thrown as an exception or logged as a warning. */ -trait GraphValidationWarning extends Logging { - - /** The exception to throw when this validation fails. */ - protected def exception: AnalysisException -} - /** * Raised when there's a circular dependency in the current pipeline. That is, a downstream * table is referenced while creating a upstream table. @@ -99,37 +88,3 @@ case class CircularDependencyException( s"Circular dependencies are not supported in a pipeline. Please remove the dependency " + s"between '${upstreamDataset.unquotedString}' and '${downstreamTable.unquotedString}'." ) - -/** - * Raised when some tables in the current pipeline are not resettable due to some non-resettable - * downstream dependencies. 
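To illustrate the table property involved (sketch only; the key comes from the PipelinesTableProperties.resetAllowed property referenced above, and the concrete tables are hypothetical):

    // A downstream table that opts out of reset:
    val downstreamProps = Map(PipelinesTableProperties.resetAllowed.key -> "false")
    // An upstream table left with default properties is treated as resettable, so the
    // INVALID_RESETTABLE_DEPENDENCY validation above would flag this pair.
    val upstreamProps = Map.empty[String, String]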
- */ -case class InvalidResettableDependencyException(originName: String, tables: Seq[Table]) - extends GraphValidationWarning { - override def exception: AnalysisException = new AnalysisException( - "INVALID_RESETTABLE_DEPENDENCY", - Map( - "downstreamTable" -> originName, - "upstreamResettableTables" -> tables - .map(_.displayName) - .sorted - .map(t => s"'$t'") - .mkString(", "), - "resetAllowedKey" -> PipelinesTableProperties.resetAllowed.key - ) - ) -} - -/** - * Warn if the append once flows was declared from batch query if there was a run before. - * Throw an exception if not. - * @param table the streaming destination that contains Append Once flows declared with batch query. - * @param flows the append once flows that are declared with batch query. - */ -case class AppendOnceFlowCreatedFromBatchQueryException(table: Table, flows: Seq[TableIdentifier]) - extends GraphValidationWarning { - override def exception: AnalysisException = new AnalysisException( - "APPEND_ONCE_FROM_BATCH_QUERY", - Map("table" -> table.displayName) - ) -} From 73a1df392604cf7b4c5ce69c5ad0b8048817a94c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 2 Jun 2025 15:45:45 -0700 Subject: [PATCH 21/32] Remove weird Returns comments for lazy vals in DataflowGraph --- .../sql/pipelines/graph/DataflowGraph.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index adf4f8c061653..585ba6295f239 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -34,11 +34,11 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) extends GraphOperations with GraphValidations { - /** Returns a [[Output]] given its identifier */ + /** Map of [[Output]]s by their identifiers */ lazy val output: Map[TableIdentifier, Output] = mapUnique(tables, "output")(_.identifier) /** - * Returns [[Flow]]s in this graph that need to get planned and potentially executed when + * [[Flow]]s in this graph that need to get planned and potentially executed when * executing the graph. Flows that write to logical views are excluded. */ lazy val materializedFlows: Seq[ResolvedFlow] = { @@ -47,14 +47,14 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) ) } - /** Returns the identifiers of [[materializedFlows]]. */ + /** The identifiers of [[materializedFlows]]. */ val materializedFlowIdentifiers: Set[TableIdentifier] = materializedFlows.map(_.identifier).toSet - /** Returns a [[Table]] given its identifier */ + /** Map of [[Table]]s by their identifiers */ lazy val table: Map[TableIdentifier, Table] = mapUnique(tables, "table")(_.identifier) - /** Returns a [[Flow]] given its identifier */ + /** Map of [[Flow]]s by their identifier */ lazy val flow: Map[TableIdentifier, Flow] = { // Better error message than using mapUnique. 
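// Illustrative aside (hypothetical `graph` and `id` values): the lazy maps documented above
// are how callers navigate the graph, e.g.
//   val maybeTable: Option[Table] = graph.table.get(id)
//   val writers: Seq[Flow]        = graph.flowsTo.getOrElse(id, Seq.empty)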
val flowsByIdentifier = flows.groupBy(_.identifier) @@ -89,20 +89,20 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) flowsByIdentifier.view.mapValues(_.head).toMap } - /** Returns a [[View]] given its identifier */ + /** Map of [[View]]s by their identifiers */ lazy val view: Map[TableIdentifier, View] = mapUnique(views, "view")(_.identifier) - /** Returns the [[PersistedView]]s of the graph */ + /** The [[PersistedView]]s of the graph */ lazy val persistedViews: Seq[PersistedView] = views.collect { case v: PersistedView => v } - /** Returns all the [[Input]]s in the current DataflowGraph. */ + /** All the [[Input]]s in the current DataflowGraph. */ lazy val inputIdentifiers: Set[TableIdentifier] = { (flows ++ tables).map(_.identifier).toSet } - /** Returns the [[Flow]]s that write to a given destination. */ + /** The [[Flow]]s that write to a given destination. */ lazy val flowsTo: Map[TableIdentifier, Seq[Flow]] = flows.groupBy(_.destinationIdentifier) lazy val resolvedFlows: Seq[ResolvedFlow] = { @@ -155,7 +155,7 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) } /** - * Returns a map of the inferred schema of each table, computed by merging the analyzed schemas + * A map of the inferred schema of each table, computed by merging the analyzed schemas * of all flows writing to that table. */ lazy val inferredSchema: Map[TableIdentifier, StructType] = { From c09a644fee361838a7b27b248bb80cd09ab4117c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 2 Jun 2025 15:51:12 -0700 Subject: [PATCH 22/32] Fix LoadTableException again --- .../spark/sql/pipelines/graph/PipelinesErrors.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala index 1a4cec1eb18e7..4bed25f2aa1c7 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.pipelines.graph +import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -37,10 +38,12 @@ case class UnresolvedDatasetException(identifier: TableIdentifier) * @param name The name of the table * @param cause The cause of the failure */ -case class LoadTableException(name: String, override val cause: Option[Throwable]) - extends AnalysisException(s"Failed to load table '$name'", cause = cause) - - +case class LoadTableException(name: String, cause: Option[Throwable]) + extends SparkException( + errorClass = "INTERNAL_ERROR", + messageParameters = Map("message" -> s"Failed to load table '$name'"), + cause = cause.orNull + ) /** * Exception raised when a pipeline has one or more flows that cannot be resolved From a714d3577cfd6930c018a40029928cae21243bda Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 2 Jun 2025 17:24:29 -0700 Subject: [PATCH 23/32] Add SchemaInferenceUtilsSuite --- .../util/SchemaInferenceUtilsSuite.scala | 273 ++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala new file mode 100644 index 0000000000000..41d5bbe14a6b1 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.types._ + +class SchemaInferenceUtilsSuite extends SparkFunSuite { + + test("determineColumnChanges - adding new columns") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("age", IntegerType) + .add("email", StringType, nullable = true, "Email address") + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 2 changes - adding 'age' and 'email' columns + assert(changes.length === 2) + + // Verify the changes are of the correct type and have the right properties + val ageChange = changes + .find { + case addCol: TableChange.AddColumn => addCol.fieldNames().sameElements(Array("age")) + case _ => false + } + .get + .asInstanceOf[TableChange.AddColumn] + + val emailChange = changes + .find { + case addCol: TableChange.AddColumn => addCol.fieldNames().sameElements(Array("email")) + case _ => false + } + .get + .asInstanceOf[TableChange.AddColumn] + + // Verify age column properties + assert(ageChange.dataType() === IntegerType) + assert(ageChange.isNullable() === true) // Default nullable is true + assert(ageChange.comment() === null) + + // Verify email column properties + assert(emailChange.dataType() === StringType) + assert(emailChange.isNullable() === true) + assert(emailChange.comment() === "Email address") + } + + test("determineColumnChanges - updating column types") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("amount", DoubleType) + .add("timestamp", TimestampType) + + val targetSchema = new StructType() + .add("id", LongType, nullable = false) // Changed type from Int to Long + .add("amount", DecimalType(10, 2)) // Changed type from Double to Decimal + .add("timestamp", TimestampType) // No change + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 2 changes - updating 'id' and 'amount' column types + assert(changes.length === 2) + + // Verify the changes are of the correct type + val idChange = changes + .find { + case update: TableChange.UpdateColumnType => update.fieldNames().sameElements(Array("id")) + case _ => 
false + } + .get + .asInstanceOf[TableChange.UpdateColumnType] + + val amountChange = changes + .find { + case update: TableChange.UpdateColumnType => + update.fieldNames().sameElements(Array("amount")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnType] + + // Verify the new data types + assert(idChange.newDataType() === LongType) + assert(amountChange.newDataType() === DecimalType(10, 2)) + } + + test("determineColumnChanges - updating nullability and comments") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType, nullable = true) + .add("description", StringType, nullable = true, "Item description") + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = true) // Changed nullability + .add("name", StringType, nullable = false) // Changed nullability + .add("description", StringType, nullable = true, "Product description") // Changed comment + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 3 changes - updating nullability for 'id' and 'name', and comment for + // 'description' + assert(changes.length === 3) + + // Verify the nullability changes + val idNullabilityChange = changes + .find { + case update: TableChange.UpdateColumnNullability => + update.fieldNames().sameElements(Array("id")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnNullability] + + val nameNullabilityChange = changes + .find { + case update: TableChange.UpdateColumnNullability => + update.fieldNames().sameElements(Array("name")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnNullability] + + // Verify the comment change + val descriptionCommentChange = changes + .find { + case update: TableChange.UpdateColumnComment => + update.fieldNames().sameElements(Array("description")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnComment] + + // Verify the new nullability values + assert(idNullabilityChange.nullable() === true) + assert(nameNullabilityChange.nullable() === false) + + // Verify the new comment + assert(descriptionCommentChange.newComment() === "Product description") + } + + test("determineColumnChanges - complex changes") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("old_field", BooleanType) + + val targetSchema = new StructType() + .add("id", LongType, nullable = true) // Changed type and nullability + // Added comment and changed nullability + .add("name", StringType, nullable = false, "Full name") + .add("new_field", StringType) // New field + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have these changes: + // 1. Update id type + // 2. Update id nullability + // 3. Update name nullability + // 4. Update name comment + // 5. Add new_field + // 6. 
Remove old_field + assert(changes.length === 6) + + // Count the types of changes + val typeChanges = changes.collect { case _: TableChange.UpdateColumnType => 1 }.size + val nullabilityChanges = changes.collect { + case _: TableChange.UpdateColumnNullability => 1 + }.size + val commentChanges = changes.collect { case _: TableChange.UpdateColumnComment => 1 }.size + val addColumnChanges = changes.collect { case _: TableChange.AddColumn => 1 }.size + + assert(typeChanges === 1) + assert(nullabilityChanges === 2) + assert(commentChanges === 1) + assert(addColumnChanges === 1) + } + + test("determineColumnChanges - no changes") { + val schema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("timestamp", TimestampType) + + // Same schema, no changes expected + val changes = SchemaInferenceUtils.diffSchemas(schema, schema) + assert(changes.isEmpty) + } + + test("determineColumnChanges - deleting columns") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("age", IntegerType) + .add("email", StringType) + .add("phone", StringType) + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + // age, email, and phone columns are removed + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 3 changes - deleting 'age', 'email', and 'phone' columns + assert(changes.length === 3) + + // Verify all changes are DeleteColumn operations + val deleteChanges = changes.collect { case dc: TableChange.DeleteColumn => dc } + assert(deleteChanges.length === 3) + + // Verify the specific columns being deleted + val columnNames = deleteChanges.map(_.fieldNames()(0)).toSet + assert(columnNames === Set("age", "email", "phone")) + } + + test("determineColumnChanges - mixed additions and deletions") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("first_name", StringType) + .add("last_name", StringType) + .add("age", IntegerType) + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("full_name", StringType) // New column + .add("email", StringType) // New column + .add("age", IntegerType) // Unchanged + // first_name and last_name are removed + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 4 changes: + // - 2 additions (full_name, email) + // - 2 deletions (first_name, last_name) + assert(changes.length === 4) + + // Count the types of changes + val addChanges = changes.collect { case ac: TableChange.AddColumn => ac } + val deleteChanges = changes.collect { case dc: TableChange.DeleteColumn => dc } + + assert(addChanges.length === 2) + assert(deleteChanges.length === 2) + + // Verify the specific columns being added and deleted + val addedColumnNames = addChanges.map(_.fieldNames()(0)).toSet + val deletedColumnNames = deleteChanges.map(_.fieldNames()(0)).toSet + + assert(addedColumnNames === Set("full_name", "email")) + assert(deletedColumnNames === Set("first_name", "last_name")) + } +} From 7e40fa8636dceec65e889a36126946f0c859f069 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 2 Jun 2025 20:42:05 -0700 Subject: [PATCH 24/32] wenchen feedback and try to fix build errors --- .../main/resources/error/error-conditions.json | 6 ------ sql/pipelines/pom.xml | 6 ------ .../spark/sql/pipelines/graph/Flow.scala | 4 ++-- .../sql/pipelines/graph/GraphErrors.scala | 18 
------------------ .../sql/pipelines/graph/GraphValidations.scala | 10 +++++++--- .../spark/sql/pipelines/graph/elements.scala | 4 ++-- 6 files changed, 11 insertions(+), 37 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index cc461221290c4..6ce5d2240529d 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4607,12 +4607,6 @@ ], "sqlState" : "42K03" }, - "PERSISTED_VIEW_READS_FROM_TEMPORARY_VIEW" : { - "message" : [ - "Persisted view cannot reference temporary view that will not be available outside the pipeline scope. Either make the persisted view temporary or persist the temporary view." - ], - "sqlState" : "42K0F" - }, "PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : { "message" : [ "Non-grouping expression is provided as an argument to the |> AGGREGATE pipe operator but does not contain any aggregate function; please update it to include an aggregate function and then retry the query again." diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index bd7d0b54746ee..254fd1c16cfb0 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -59,12 +59,6 @@ org.apache.spark spark-sql_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-connect-shims_${scala.binary.version} - - org.apache.spark diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index 67ec95dba0f8e..2378b6f8d96a6 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -68,14 +68,14 @@ case class ResolvedInput(input: Input, aliasIdentifier: AliasIdentifier) trait FlowFunction extends Logging { /** - * This function defines the transformations performed by a flow, expressed as a [[DataFrame]]. + * This function defines the transformations performed by a flow, expressed as a DataFrame. * * @param allInputs the set of identifiers for all the [[Input]]s defined in the * [[DataflowGraph]]. * @param availableInputs the list of all [[Input]]s available to this flow * @param configuration the spark configurations that apply to this flow. * @param queryContext The context of the query being evaluated. - * @return the inputs actually used, and the [[DataFrame]] expression for the flow + * @return the inputs actually used, and the DataFrame expression for the flow */ def call( allInputs: Set[TableIdentifier], diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala index b360a0807bb0d..53db669e687d2 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala @@ -128,22 +128,4 @@ object GraphErrors { cause = Option(cause.orNull) ) } - - /** - * Throws an error when a persisted view is trying to read from a temporary view. 
- * - * @param persistedViewIdentifier the identifier of the persisted view - * @param temporaryViewIdentifier the identifier of the temporary view - */ - def persistedViewReadsFromTemporaryView( - persistedViewIdentifier: TableIdentifier, - temporaryViewIdentifier: TableIdentifier): AnalysisException = { - new AnalysisException( - "PERSISTED_VIEW_READS_FROM_TEMPORARY_VIEW", - Map( - "persistedViewName" -> persistedViewIdentifier.toString, - "temporaryViewName" -> temporaryViewIdentifier.toString - ) - ) - } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 98bb8afb94c16..99142432f9cec 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -252,9 +252,13 @@ trait GraphValidations extends Logging { .flatMap(view.get) .foreach { case tempView: TemporaryView => - throw GraphErrors.persistedViewReadsFromTemporaryView( - persistedViewIdentifier = persistedView.identifier, - temporaryViewIdentifier = tempView.identifier + throw new AnalysisException( + errorClass = "INVALID_TEMP_OBJ_REFERENCE", + messageParameters = Map( + "persistedViewName" -> persistedView.identifier.toString, + "temporaryViewName" -> tempView.identifier.toString + ), + cause = null ) case _ => } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index 79550634084d8..2f03aa665b012 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -67,9 +67,9 @@ trait GraphElement { trait Input extends GraphElement { /** - * Returns a [[DataFrame]] that is a result of loading data from this [[Input]]. + * Returns a DataFrame that is a result of loading data from this [[Input]]. * @param readOptions Type of input. Used to determine streaming/batch - * @return Streaming or batch [[DataFrame]] of this Input's data. + * @return Streaming or batch DataFrame of this Input's data. */ def load(readOptions: InputReadOptions): DataFrame } From ab007fe40ad47ee6343afe84a0713931a69b466e Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 3 Jun 2025 07:03:46 -0700 Subject: [PATCH 25/32] another docs fix --- .../scala/org/apache/spark/sql/pipelines/graph/elements.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index 2f03aa665b012..770776b29cf08 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -82,7 +82,7 @@ sealed trait Output { /** * Normalized storage location used for storing materializations for this [[Output]]. - * If [[None]], it means this [[Output]] has not been normalized yet. + * If None, it means this [[Output]] has not been normalized yet. 
*/ def normalizedPath: Option[String] From 551516764185635e2ec2b976fc7afeb37ffe5c88 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 3 Jun 2025 07:16:47 -0700 Subject: [PATCH 26/32] remove direct spark-core test jar dependency --- sql/pipelines/pom.xml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 254fd1c16cfb0..179fe0d3ed75b 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -48,13 +48,6 @@ spark-core_${scala.binary.version} ${project.version} - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.spark spark-sql_${scala.binary.version} From 54690c7d9c8fab3011faab503159035ddc7497ae Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 3 Jun 2025 09:19:57 -0700 Subject: [PATCH 27/32] fix SparkThrowableSuite --- .../src/main/resources/error/error-conditions.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 6ce5d2240529d..993ffd888e0e7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5457,12 +5457,6 @@ ], "sqlState" : "58030" }, - "UNABLE_TO_INFER_SCHEMA" : { - "message" : [ - "Unable to infer schema for . It must be specified manually." - ], - "sqlState" : "42KD9" - }, "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA" : { "message" : [ "Failed to infer the schema for table from its upstream flows.", @@ -5476,6 +5470,12 @@ ], "sqlState" : "42KD9" }, + "UNABLE_TO_INFER_SCHEMA" : { + "message" : [ + "Unable to infer schema for . It must be specified manually." + ], + "sqlState" : "42KD9" + }, "UNBOUND_SQL_PARAMETER" : { "message" : [ "Found the unbound parameter: . Please, fix `args` and provide a mapping of the parameter to either a SQL literal or collection constructor functions such as `map()`, `array()`, `struct()`." From f4e3ecd8a123312f865e9162b27be020a27f158e Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 3 Jun 2025 11:19:01 -0700 Subject: [PATCH 28/32] try more to fix the build --- sql/pipelines/pom.xml | 5 ----- .../spark/sql/pipelines/graph/FlowAnalysisContext.scala | 2 +- .../org/apache/spark/sql/pipelines/graph/QueryOrigin.scala | 2 +- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 179fe0d3ed75b..4ab7db079e455 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -43,11 +43,6 @@ ${project.version} test - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala index 0e04e17f7f7b0..fb96c6cb5bb1d 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.pipelines.AnalysisWarning /** - * A context used when evaluating a [[Flow]]'s query into a concrete [[DataFrame]]. + * A context used when evaluating a [[Flow]]'s query into a concrete DataFrame. 
* * @param allInputs Set of identifiers for all [[Input]]s defined in the DataflowGraph. * @param availableInputs Inputs available to be referenced with `read` or `readStream`. diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala index c9af65a0972c7..042b4d9626fd6 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -80,7 +80,7 @@ case class QueryOrigin( ) } - /** Generates a [[SourceCodeLocation]] using the details present in the query origin. */ + /** Generates a SourceCodeLocation using the details present in the query origin. */ def toSourceCodeLocation: SourceCodeLocation = SourceCodeLocation( path = fileName, // QueryOrigin tracks line numbers using a 1-indexed numbering scheme whereas SourceCodeLocation From 5927325ffc417f0b7759c1c465f73b6cb5bbc80d Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 3 Jun 2025 13:29:42 -0700 Subject: [PATCH 29/32] more docs fix --- .../spark/sql/pipelines/graph/GraphIdentifierManager.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala index 2561198aa2b5a..414d9d0effea4 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala @@ -112,7 +112,6 @@ object GraphIdentifierManager { * * @param rawTableIdentifier the raw table identifier * @return the parsed table identifier - * @throws AnalysisException if the table identifier is not allowed */ @throws[AnalysisException] def parseAndQualifyTableIdentifier( @@ -138,7 +137,6 @@ object GraphIdentifierManager { * * @param rawViewIdentifier the raw view identifier * @return the parsed view identifier - * @throws AnalysisException if the view identifier is not allowed */ @throws[AnalysisException] def parseAndValidateTemporaryViewIdentifier( @@ -163,7 +161,6 @@ object GraphIdentifierManager { * @param currentCatalog the catalog * @param currentDatabase the schema * @return the parsed view identifier - * @throws AnalysisException if the view identifier is not allowed */ def parseAndValidatePersistedViewIdentifier( rawViewIdentifier: TableIdentifier, @@ -188,7 +185,6 @@ object GraphIdentifierManager { * * @param rawFlowIdentifier the raw flow identifier * @return the parsed flow identifier - * @throws AnalysisException if the flow identifier is not allowed */ @throws[AnalysisException] def parseAndQualifyFlowIdentifier( @@ -240,7 +236,6 @@ object IdentifierHelper { * * @param nameParts the dataset name parts. * @return the table identifier constructed from the name parts. - * @throws UnsupportedOperationException if the name parts have more than 3 parts. */ @throws[UnsupportedOperationException] def toTableIdentifier(nameParts: Seq[String]): TableIdentifier = { @@ -265,7 +260,6 @@ object IdentifierHelper { * * @param table the logical plan. * @return the table identifier constructed from the logical plan. - * @throws SparkException if the table identifier cannot be resolved. 
*/ def toTableIdentifier(table: LogicalPlan): TableIdentifier = { val parts = table match { From f310b4f3ad2052f830d7641c1f951da7e9008d4b Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 4 Jun 2025 07:54:58 -0700 Subject: [PATCH 30/32] fix structured logging checks --- .../spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index c79b80c21bbf8..2df434fc73d43 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -219,9 +219,9 @@ private class FlowResolver(rawGraph: DataflowGraph) extends Logging { case _: UnresolvedFlow => new CompleteFlow(flow, funcResult) } if (!funcResult.resolved) { - logError(s"Failed to resolve ${flow.displayName}: ${funcResult.failure.mkString("\n\n\n")}") + logError(log"Failed to resolve ${flow.displayName}: ${funcResult.failure.mkString("\n\n\n")}") } else { - logInfo(s"Successfully resolved ${flow.displayName}") + logInfo(log"Successfully resolved ${flow.displayName}") } typedFlow } From 55ad0abd21242cca4560019d6fb459817b32876e Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 4 Jun 2025 08:42:04 -0700 Subject: [PATCH 31/32] more --- .../sql/pipelines/graph/CoreDataflowNodeProcessor.scala | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 2df434fc73d43..d33924c2e1c37 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -21,7 +21,6 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import scala.jdk.CollectionConverters._ -import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ @@ -116,7 +115,7 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { } } -private class FlowResolver(rawGraph: DataflowGraph) extends Logging { +private class FlowResolver(rawGraph: DataflowGraph) { /** Helper used to track which confs were set by which flows. 
*/ private case class FlowConf(key: String, value: String, flowIdentifier: TableIdentifier) @@ -218,11 +217,6 @@ private class FlowResolver(rawGraph: DataflowGraph) extends Logging { new StreamingFlow(flow, funcResult, mustBeAppend = mustBeAppend) case _: UnresolvedFlow => new CompleteFlow(flow, funcResult) } - if (!funcResult.resolved) { - logError(log"Failed to resolve ${flow.displayName}: ${funcResult.failure.mkString("\n\n\n")}") - } else { - logInfo(log"Successfully resolved ${flow.displayName}") - } typedFlow } } From 6bf48db91ae2a0042c1fc95e574ba9aebf8519e0 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 4 Jun 2025 09:43:44 -0700 Subject: [PATCH 32/32] simplify build files for pipelines module --- project/SparkBuild.scala | 53 ---------------------------------------- sql/pipelines/pom.xml | 46 ---------------------------------- 2 files changed, 99 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e3db94eb39b3d..77001e6bdf227 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -418,8 +418,6 @@ object SparkBuild extends PomBuild { enable(HiveThriftServer.settings)(hiveThriftServer) - enable(SparkDeclarativePipelines.settings)(pipelines) - enable(SparkConnectCommon.settings)(connectCommon) enable(SparkConnect.settings)(connect) enable(SparkConnectClient.settings)(connectClient) @@ -886,57 +884,6 @@ object SparkConnectClient { ) } -object SparkDeclarativePipelines { - import BuildCommons.protoVersion - - lazy val settings = Seq( - // For some reason the resolution from the imported Maven build does not work for some - // of these dependendencies that we need to shade later on. - libraryDependencies ++= { - val guavaVersion = - SbtPomKeys.effectivePom.value.getProperties.get( - "connect.guava.version").asInstanceOf[String] - val guavaFailureAccessVersion = - SbtPomKeys.effectivePom.value.getProperties.get( - "guava.failureaccess.version").asInstanceOf[String] - Seq( - "com.google.guava" % "guava" % guavaVersion, - "com.google.guava" % "failureaccess" % guavaFailureAccessVersion, - "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" - ) - }, - - (assembly / logLevel) := Level.Info, - - // Exclude `scala-library` from assembly. - (assembly / assemblyPackageScala / assembleArtifact) := false, - - // SPARK-46733: Include `spark-connect-*.jar`, `unused-*.jar`,`guava-*.jar`, - // `failureaccess-*.jar`, `annotations-*.jar`, `grpc-*.jar`, `protobuf-*.jar`, - // `gson-*.jar`, `error_prone_annotations-*.jar`, `j2objc-annotations-*.jar`, - // `animal-sniffer-annotations-*.jar`, `perfmark-api-*.jar`, - // `proto-google-common-protos-*.jar` in assembly. - // This needs to be consistent with the content of `maven-shade-plugin`. - (assembly / assemblyExcludedJars) := { - val cp = (assembly / fullClasspath).value - val validPrefixes = Set("spark-connect", "unused-", "guava-", "failureaccess-", - "annotations-", "grpc-", "protobuf-", "gson", "error_prone_annotations", - "j2objc-annotations", "animal-sniffer-annotations", "perfmark-api", - "proto-google-common-protos") - cp filterNot { v => - validPrefixes.exists(v.data.getName.startsWith) - } - }, - - (assembly / assemblyMergeStrategy) := { - case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard - // Drop all proto files that are not needed as artifacts of the build. 
- case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard - case _ => MergeStrategy.first - } - ) -} - object SparkProtobuf { import BuildCommons.protoVersion diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 4ab7db079e455..a04993299ce7c 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -119,51 +119,5 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - kr.motd.maven - os-maven-plugin - 1.7.0 - true - - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.6.1 - - com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} - - src/main/protobuf - grpc-java - io.grpc:protoc-gen-grpc-java:${io.grpc.version}:exe:${os.detected.classifier} - - - - - generate-sources - - compile - compile-custom - test-compile - - - - - - - net.alchim31.maven - scala-maven-plugin - 4.9.2 - - - - compile - testCompile - - - - - -