diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 8dddca5077a67..993ffd888e0e7 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -82,6 +82,14 @@
],
"sqlState" : "XX000"
},
+ "APPEND_ONCE_FROM_BATCH_QUERY" : {
+ "message" : [
+ "Creating a streaming table from a batch query prevents incremental loading of new data from source. Offending table: '
'.",
+ "Please use the stream() operator. Example usage:",
+ "CREATE STREAMING TABLE ... AS SELECT ... FROM stream() ..."
+ ],
+ "sqlState" : "42000"
+ },
"ARITHMETIC_OVERFLOW" : {
"message" : [
". If necessary set to \"false\" to bypass this error."
@@ -1372,6 +1380,12 @@
},
"sqlState" : "42734"
},
+ "DUPLICATE_FLOW_SQL_CONF" : {
+ "message" : [
+ "Found duplicate sql conf for dataset '': '' is defined by both '' and ''"
+ ],
+ "sqlState" : "42710"
+ },
"DUPLICATE_KEY" : {
"message" : [
"Found duplicate keys ."
@@ -1943,6 +1957,12 @@
],
"sqlState" : "42818"
},
+ "INCOMPATIBLE_BATCH_VIEW_READ" : {
+ "message" : [
+ "View is a batch view and must be referenced using SparkSession#read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false."
+ ],
+ "sqlState" : "42000"
+ },
"INCOMPATIBLE_COLUMN_TYPE" : {
"message" : [
" can only be performed on tables with compatible column types. The column of the table is type which is not compatible with at the same column of the first table.."
@@ -2019,6 +2039,12 @@
],
"sqlState" : "42613"
},
+ "INCOMPATIBLE_STREAMING_VIEW_READ" : {
+ "message" : [
+ "View is a streaming view and must be referenced using SparkSession#readStream. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false."
+ ],
+ "sqlState" : "42000"
+ },
"INCOMPATIBLE_VIEW_SCHEMA_CHANGE" : {
"message" : [
"The SQL query of view has an incompatible schema change and column cannot be resolved. Expected columns named but got .",
@@ -3119,6 +3145,12 @@
},
"sqlState" : "KD002"
},
+ "INVALID_NAME_IN_USE_COMMAND" : {
+ "message" : [
+ "Invalid name '' in command. Reason: "
+ ],
+ "sqlState" : "42000"
+ },
"INVALID_NON_DETERMINISTIC_EXPRESSIONS" : {
"message" : [
"The operator expects a deterministic expression, but the actual expression is ."
@@ -3384,6 +3416,12 @@
],
"sqlState" : "22023"
},
+ "INVALID_RESETTABLE_DEPENDENCY" : {
+ "message" : [
+ "Tables are resettable but have a non-resettable downstream dependency ''. `reset` will fail as Spark Streaming does not support deleted source data. You can either remove the =false property from '' or add it to its upstream dependencies."
+ ],
+ "sqlState" : "42000"
+ },
"INVALID_RESET_COMMAND_FORMAT" : {
"message" : [
"Expected format is 'RESET' or 'RESET key'. If you want to include special characters in key, please use quotes, e.g., RESET `key`."
@@ -5419,6 +5457,19 @@
],
"sqlState" : "58030"
},
+ "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA" : {
+ "message" : [
+ "Failed to infer the schema for table from its upstream flows.",
+ "Please modify the flows that write to this table to make their schemas compatible.",
+ "",
+ "Inferred schema so far:",
+ "",
+ "",
+ "Incompatible schema:",
+ ""
+ ],
+ "sqlState" : "42KD9"
+ },
"UNABLE_TO_INFER_SCHEMA" : {
"message" : [
"Unable to infer schema for . It must be specified manually."
@@ -5590,6 +5641,12 @@
],
"sqlState" : "42883"
},
+ "UNRESOLVED_TABLE_PATH" : {
+ "message" : [
+ "Storage path for table cannot be resolved."
+ ],
+ "sqlState" : "22KD1"
+ },
"UNRESOLVED_USING_COLUMN_FOR_JOIN" : {
"message" : [
"USING column cannot be resolved on the side of the join. The -side columns: []."
@@ -6571,6 +6628,20 @@
],
"sqlState" : "P0001"
},
+ "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE" : {
+ "message" : [
+ "Table '' has a user-specified schema that is incompatible with the schema",
+ "inferred from its query.",
+ "",
+ "",
+ "Declared schema:",
+ "",
+ "",
+ "Inferred schema:",
+ ""
+ ],
+ "sqlState" : "42000"
+ },
"VARIABLE_ALREADY_EXISTS" : {
"message" : [
"Cannot create the variable because it already exists.",
diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml
index 7d796a83af69d..a04993299ce7c 100644
--- a/sql/pipelines/pom.xml
+++ b/sql/pipelines/pom.xml
@@ -16,7 +16,8 @@
~ limitations under the License.
-->
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
@@ -44,11 +45,76 @@
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql-api_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang.modules</groupId>
+      <artifactId>scala-parallel-collections_${scala.binary.version}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>net.bytebuddy</groupId>
+      <artifactId>byte-buddy</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>net.bytebuddy</groupId>
+      <artifactId>byte-buddy-agent</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala
new file mode 100644
index 0000000000000..35b8185c255e1
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines
+
+/** Represents a warning generated as part of graph analysis. */
+sealed trait AnalysisWarning
+
+object AnalysisWarning {
+
+ /**
+ * Warning that some streaming reader options are being dropped
+ *
+ * @param sourceName Source for which reader options are being dropped.
+ * @param droppedOptions Set of reader options that are being dropped for a specific source.
+ */
+ case class StreamingReaderOptionsDropped(sourceName: String, droppedOptions: Seq[String])
+ extends AnalysisWarning
+}
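A small usage sketch (hypothetical consumer code, not part of this PR) showing how callers can pattern-match on the sealed warning type:

```scala
import org.apache.spark.sql.pipelines.AnalysisWarning

// Hypothetical consumer: render warnings collected during graph analysis as log lines.
object AnalysisWarningRenderer {
  def describe(warning: AnalysisWarning): String = warning match {
    case AnalysisWarning.StreamingReaderOptionsDropped(sourceName, droppedOptions) =>
      s"Dropped streaming reader options [${droppedOptions.mkString(", ")}] for source $sourceName"
  }
}
```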
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala
new file mode 100644
index 0000000000000..c627850b667be
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines
+
+sealed trait Language {}
+
+object Language {
+ case class Python() extends Language {}
+ case class Sql() extends Language {}
+}
+
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
new file mode 100644
index 0000000000000..d33924c2e1c37
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{
+ TransformNodeFailedException,
+ TransformNodeRetryableException
+}
+
+/**
+ * Processor responsible for analyzing each flow and sorting the nodes in
+ * topological order.
+ */
+class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) {
+
+ private val flowResolver = new FlowResolver(rawGraph)
+
+ // Map of input identifier to resolved [[Input]].
+ private val resolvedInputs = new ConcurrentHashMap[TableIdentifier, Input]()
+ // Map & queue of resolved flows identifiers
+ // queue is there to track the topological order while map is used to store the id -> flow
+ // mapping
+ private val resolvedFlowNodesMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]()
+ private val resolvedFlowNodesQueue = new ConcurrentLinkedQueue[ResolvedFlow]()
+
+ private def processUnresolvedFlow(flow: UnresolvedFlow): ResolvedFlow = {
+ val resolvedFlow = flowResolver.attemptResolveFlow(
+ flow,
+ rawGraph.inputIdentifiers,
+ resolvedInputs.asScala.toMap
+ )
+ resolvedFlowNodesQueue.add(resolvedFlow)
+ resolvedFlowNodesMap.put(flow.identifier, resolvedFlow)
+ resolvedFlow
+ }
+
+ /**
+ * Processes a node of the graph, re-arranging nodes if they are not topologically sorted.
+ * Takes care of resolving the flows and virtualizing tables (i.e. removing tables to
+ * ensure resolution is internally consistent) if needed for the nodes.
+ * @param node The node to process
+ * @param upstreamNodes Upstream nodes for the node
+ * @return The resolved nodes generated by processing this element.
+ */
+ def processNode(node: GraphElement, upstreamNodes: Seq[GraphElement]): Seq[GraphElement] = {
+ node match {
+ case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow))
+ case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow))
+ case table: Table =>
+ // Ensure all upstreamNodes for a table are flows
+ val flowsToTable = upstreamNodes.map {
+ case flow: Flow => flow
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Unsupported upstream node type for table ${table.displayName}: " +
+ s"${upstreamNodes.getClass}"
+ )
+ }
+ val resolvedFlowsToTable = flowsToTable.map { flow =>
+ resolvedFlowNodesMap.get(flow.identifier)
+ }
+
+ // Determine whether the table is a streaming table (ST) or a materialized view (MV)
+ // based on the resolved flows that write to it.
+ val tableWithType = table.copy(
+ isStreamingTableOpt = Option(resolvedFlowsToTable.exists(f => f.df.isStreaming))
+ )
+
+ // We mark all tables as virtual to ensure resolution uses incoming flows
+ // rather than previously materialized tables.
+ val virtualTableInput = VirtualTableInput(
+ identifier = table.identifier,
+ specifiedSchema = table.specifiedSchema,
+ incomingFlowIdentifiers = flowsToTable.map(_.identifier).toSet,
+ availableFlows = resolvedFlowsToTable
+ )
+ resolvedInputs.put(table.identifier, virtualTableInput)
+ Seq(tableWithType)
+ case view: View =>
+ // For view, add the flow to resolvedInputs and return empty.
+ require(upstreamNodes.size == 1, "Found multiple flows to view")
+ upstreamNodes.head match {
+ case f: Flow =>
+ resolvedInputs.put(view.identifier, resolvedFlowNodesMap.get(f.destinationIdentifier))
+ Seq(view)
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Unsupported upstream node type for view ${view.displayName}: " +
+ s"${upstreamNodes.getClass}"
+ )
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported node type: ${node.getClass}")
+ }
+ }
+}
+
+private class FlowResolver(rawGraph: DataflowGraph) {
+
+ /** Helper used to track which confs were set by which flows. */
+ private case class FlowConf(key: String, value: String, flowIdentifier: TableIdentifier)
+
+ /** Attempts resolving a single flow using the map of resolved inputs. */
+ def attemptResolveFlow(
+ flowToResolve: UnresolvedFlow,
+ allInputs: Set[TableIdentifier],
+ availableResolvedInputs: Map[TableIdentifier, Input]): ResolvedFlow = {
+ val flowFunctionResult = flowToResolve.func.call(
+ allInputs = allInputs,
+ availableInputs = availableResolvedInputs.values.toList,
+ configuration = flowToResolve.sqlConf,
+ queryContext = flowToResolve.queryContext
+ )
+ val result =
+ flowFunctionResult match {
+ case f if f.dataFrame.isSuccess =>
+ // Merge confs from any upstream views into confs for this flow.
+ val allFConfs =
+ (flowToResolve +:
+ f.inputs.toSeq
+ .map(availableResolvedInputs(_))
+ .filter {
+ // An input that is a Flow implies that the upstream dataset is a view.
+ case _: Flow => true
+ // We stop in all other cases.
+ case _ => false
+ }).collect {
+ case g: Flow =>
+ g.sqlConf.toSeq.map { case (k, v) => FlowConf(k, v, g.identifier) }
+ }.flatten
+
+ allFConfs
+ .groupBy(_.key) // Key name -> Seq[FlowConf]
+ .filter(_._2.length > 1) // Entries where key was set more than once
+ .find(_._2.map(_.value).toSet.size > 1) // Entry where key was set with diff vals
+ .foreach {
+ case (key, confs) =>
+ val sortedByVal = confs.sortBy(_.value)
+ throw new AnalysisException(
+ "DUPLICATE_FLOW_SQL_CONF",
+ Map(
+ "key" -> key,
+ "datasetName" -> flowToResolve.displayName,
+ "flowName1" -> sortedByVal.head.flowIdentifier.unquotedString,
+ "flowName2" -> sortedByVal.last.flowIdentifier.unquotedString
+ )
+ )
+ }
+
+ val newSqlConf = allFConfs.map(fc => fc.key -> fc.value).toMap
+ // if the new sql confs are different from the original sql confs the flow was resolved
+ // with, resolve again.
+ val maybeNewFuncResult = if (newSqlConf != flowToResolve.sqlConf) {
+ flowToResolve.func.call(
+ allInputs = allInputs,
+ availableInputs = availableResolvedInputs.values.toList,
+ configuration = newSqlConf,
+ queryContext = flowToResolve.queryContext
+ )
+ } else {
+ f
+ }
+ convertResolvedToTypedFlow(flowToResolve, maybeNewFuncResult)
+
+ // If the flow failed due to an UnresolvedDatasetException, it means that one of the
+ // flow's inputs wasn't available. After other flows are resolved, these inputs
+ // may become available, so throw a retryable exception in this case.
+ case f =>
+ f.dataFrame.failed.toOption.collectFirst {
+ case e: UnresolvedDatasetException => e
+ case _ => None
+ } match {
+ case Some(e: UnresolvedDatasetException) =>
+ throw TransformNodeRetryableException(
+ e.identifier,
+ new ResolutionFailedFlow(flowToResolve, flowFunctionResult)
+ )
+ case _ =>
+ throw TransformNodeFailedException(
+ new ResolutionFailedFlow(flowToResolve, flowFunctionResult)
+ )
+ }
+ }
+ result
+ }
+
+ private def convertResolvedToTypedFlow(
+ flow: UnresolvedFlow,
+ funcResult: FlowFunctionResult): ResolvedFlow = {
+ val typedFlow = flow match {
+ case f: UnresolvedFlow if f.once => new AppendOnceFlow(flow, funcResult)
+ case f: UnresolvedFlow if funcResult.dataFrame.get.isStreaming =>
+ // If there's more than 1 flow to this flow's destination, we should not allow it
+ // to be planned with an output mode other than Append, as the other flows will
+ // then get their results overwritten.
+ val mustBeAppend = rawGraph.flowsTo(f.destinationIdentifier).size > 1
+ new StreamingFlow(flow, funcResult, mustBeAppend = mustBeAppend)
+ case _: UnresolvedFlow => new CompleteFlow(flow, funcResult)
+ }
+ typedFlow
+ }
+}
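The duplicate-conf check in `FlowResolver` only fails when the same key is set to different values by different flows; a standalone sketch of that rule with simplified types (none of the names below are from the PR):

```scala
// Standalone sketch (simplified types, not the graph classes above) of the duplicate
// SQL-conf check performed in FlowResolver: a key set by more than one flow is only an
// error when the values actually differ.
object ConfConflictExample {
  case class SimpleConf(key: String, value: String, flowName: String)

  // Returns the first key that was set with conflicting values, if any.
  def findConflict(confs: Seq[SimpleConf]): Option[String] =
    confs
      .groupBy(_.key)                         // key -> all settings of that key
      .filter(_._2.length > 1)                // keys set more than once
      .find(_._2.map(_.value).toSet.size > 1) // keys whose values disagree
      .map(_._1)

  def main(args: Array[String]): Unit = {
    val confs = Seq(
      SimpleConf("spark.sql.shuffle.partitions", "10", "flow_a"),
      SimpleConf("spark.sql.shuffle.partitions", "10", "flow_b"), // same value: allowed
      SimpleConf("spark.sql.ansi.enabled", "true", "flow_a"),
      SimpleConf("spark.sql.ansi.enabled", "false", "flow_c"))    // conflicting values
    assert(findConflict(confs).contains("spark.sql.ansi.enabled"))
  }
}
```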
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala
new file mode 100644
index 0000000000000..585ba6295f239
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.util.Try
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique
+import org.apache.spark.sql.pipelines.util.SchemaMergingUtils
+import org.apache.spark.sql.types.StructType
+
+/**
+ * DataflowGraph represents the core graph structure for Spark declarative pipelines.
+ * It manages the relationships between logical flows, tables, and views, providing
+ * operations for graph traversal, validation, and transformation.
+ */
+case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View])
+ extends GraphOperations
+ with GraphValidations {
+
+ /** Map of [[Output]]s by their identifiers */
+ lazy val output: Map[TableIdentifier, Output] = mapUnique(tables, "output")(_.identifier)
+
+ /**
+ * [[Flow]]s in this graph that need to get planned and potentially executed when
+ * executing the graph. Flows that write to logical views are excluded.
+ */
+ lazy val materializedFlows: Seq[ResolvedFlow] = {
+ resolvedFlows.filter(
+ f => output.contains(f.destinationIdentifier)
+ )
+ }
+
+ /** The identifiers of [[materializedFlows]]. */
+ val materializedFlowIdentifiers: Set[TableIdentifier] = materializedFlows.map(_.identifier).toSet
+
+ /** Map of [[Table]]s by their identifiers */
+ lazy val table: Map[TableIdentifier, Table] =
+ mapUnique(tables, "table")(_.identifier)
+
+ /** Map of [[Flow]]s by their identifier */
+ lazy val flow: Map[TableIdentifier, Flow] = {
+ // Better error message than using mapUnique.
+ val flowsByIdentifier = flows.groupBy(_.identifier)
+ flowsByIdentifier
+ .find(_._2.size > 1)
+ .foreach {
+ case (flowIdentifier, flows) =>
+ // We don't expect this to ever actually be hit; graph registration should validate
+ // that flow names are unique.
+ throw new AnalysisException(
+ errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW",
+ messageParameters = Map(
+ "flowName" -> flowIdentifier.unquotedString,
+ "datasetNames" -> flows.map(_.destinationIdentifier).mkString(",")
+ )
+ )
+ }
+ // Flows with non-default names shouldn't conflict with table names
+ flows
+ .filterNot(f => f.identifier == f.destinationIdentifier)
+ .filter(f => table.contains(f.identifier))
+ .foreach { f =>
+ throw new AnalysisException(
+ "FLOW_NAME_CONFLICTS_WITH_TABLE",
+ Map(
+ "flowName" -> f.identifier.toString(),
+ "target" -> f.destinationIdentifier.toString(),
+ "tableName" -> f.identifier.toString()
+ )
+ )
+ }
+ flowsByIdentifier.view.mapValues(_.head).toMap
+ }
+
+ /** Map of [[View]]s by their identifiers */
+ lazy val view: Map[TableIdentifier, View] = mapUnique(views, "view")(_.identifier)
+
+ /** The [[PersistedView]]s of the graph */
+ lazy val persistedViews: Seq[PersistedView] = views.collect {
+ case v: PersistedView => v
+ }
+
+ /** All the [[Input]]s in the current DataflowGraph. */
+ lazy val inputIdentifiers: Set[TableIdentifier] = {
+ (flows ++ tables).map(_.identifier).toSet
+ }
+
+ /** The [[Flow]]s that write to a given destination. */
+ lazy val flowsTo: Map[TableIdentifier, Seq[Flow]] = flows.groupBy(_.destinationIdentifier)
+
+ lazy val resolvedFlows: Seq[ResolvedFlow] = {
+ flows.collect { case f: ResolvedFlow => f }
+ }
+
+ lazy val resolvedFlow: Map[TableIdentifier, ResolvedFlow] = {
+ resolvedFlows.map { f =>
+ f.identifier -> f
+ }.toMap
+ }
+
+ lazy val resolutionFailedFlows: Seq[ResolutionFailedFlow] = {
+ flows.collect { case f: ResolutionFailedFlow => f }
+ }
+
+ lazy val resolutionFailedFlow: Map[TableIdentifier, ResolutionFailedFlow] = {
+ resolutionFailedFlows.map { f =>
+ f.identifier -> f
+ }.toMap
+ }
+
+ /**
+ * Reanalyzes the DataFrame of the flow that writes into a given table. This is done by
+ * finding all upstream flows for the specified source (stopping once a table is reached)
+ * and reanalyzing those upstream flows.
+ *
+ * @param srcFlow The flow that writes into the table that we will start from when finding
+ * upstream flows
+ * @return The reanalyzed flow
+ */
+ protected[graph] def reanalyzeFlow(srcFlow: Flow): ResolvedFlow = {
+ val upstreamDatasetIdentifiers = dfsInternal(
+ flowNodes(srcFlow.identifier).output,
+ downstream = false,
+ stopAtMaterializationPoints = true
+ )
+ val upstreamFlows =
+ resolvedFlows
+ .filter(f => upstreamDatasetIdentifiers.contains(f.destinationIdentifier))
+ .map(_.flow)
+ val upstreamViews = upstreamDatasetIdentifiers.flatMap(identifier => view.get(identifier)).toSeq
+
+ val subgraph = new DataflowGraph(
+ flows = upstreamFlows,
+ views = upstreamViews,
+ tables = Seq(table(srcFlow.destinationIdentifier))
+ )
+ subgraph.resolve().resolvedFlow(srcFlow.identifier)
+ }
+
+ /**
+ * A map of the inferred schema of each table, computed by merging the analyzed schemas
+ * of all flows writing to that table.
+ */
+ lazy val inferredSchema: Map[TableIdentifier, StructType] = {
+ flowsTo.view.mapValues { flows =>
+ flows
+ .map { flow =>
+ resolvedFlow(flow.identifier).schema
+ }
+ .reduce(SchemaMergingUtils.mergeSchemas)
+ }.toMap
+ }
+
+ /** Ensure that the [[DataflowGraph]] is valid and throws errors if not. */
+ def validate(): DataflowGraph = {
+ validationFailure.toOption match {
+ case Some(exception) => throw exception
+ case None => this
+ }
+ }
+
+ /**
+ * Validate the current [[DataflowGraph]] and cache the validation failure.
+ *
+ * To add more validations, add them in a helper function that throws an exception if the
+ * validation fails, and invoke the helper function here.
+ */
+ private lazy val validationFailure: Try[Throwable] = Try {
+ validateSuccessfulFlowAnalysis()
+ validateUserSpecifiedSchemas()
+ // Connecting the graph sorts it topologically
+ validateGraphIsTopologicallySorted()
+ validateMultiQueryTables()
+ validatePersistedViewSources()
+ validateEveryDatasetHasFlow()
+ validateTablesAreResettable()
+ inferredSchema
+ }.failed
+
+ /**
+ * Enforces that every dataset has at least one input flow. For example, it is possible to
+ * define streaming tables without a query; such tables should still have at least one flow
+ * writing to them.
+ */
+ def validateEveryDatasetHasFlow(): Unit = {
+ (tables.map(_.identifier) ++ views.map(_.identifier)).foreach { identifier =>
+ if (!flows.exists(_.destinationIdentifier == identifier)) {
+ throw new AnalysisException(
+ "PIPELINE_DATASET_WITHOUT_FLOW",
+ Map("identifier" -> identifier.quotedString)
+ )
+ }
+ }
+ }
+
+ /** Returns true iff all [[Flow]]s are successfully analyzed. */
+ def resolved: Boolean =
+ flows.forall(f => resolvedFlow.contains(f.identifier))
+
+ def resolve(): DataflowGraph =
+ DataflowGraphTransformer.withDataflowGraphTransformer(this) { transformer =>
+ val coreDataflowNodeProcessor =
+ new CoreDataflowNodeProcessor(rawGraph = this)
+ transformer
+ .transformDownNodes(coreDataflowNodeProcessor.processNode)
+ .getDataflowGraph
+ }
+}
+
+object DataflowGraph {
+ protected[graph] def mapUnique[K, A](input: Seq[A], tpe: String)(f: A => K): Map[K, A] = {
+ val grouped = input.groupBy(f)
+ grouped.filter(_._2.length > 1).foreach {
+ case (name, _) =>
+ throw new AnalysisException(
+ errorClass = "DUPLICATE_GRAPH_ELEMENT",
+ messageParameters = Map("graphElementType" -> tpe, "graphElementName" -> name.toString)
+ )
+ }
+ grouped.view.mapValues(_.head).toMap
+ }
+}
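Caller-side usage is a two-step `resolve()` then `validate()`; a hypothetical sketch (the surrounding object and method names are illustrative only):

```scala
import org.apache.spark.sql.pipelines.graph.DataflowGraph

// Hypothetical caller-side sketch: a graph produced by pipeline registration is first
// resolved (flows analyzed, schemas inferred) and then validated before execution.
object GraphPreparation {
  def prepareForExecution(rawGraph: DataflowGraph): DataflowGraph = {
    val resolved = rawGraph.resolve() // analyzes flows via CoreDataflowNodeProcessor
    resolved.validate()               // throws the cached validation failure, if any
  }
}
```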
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala
new file mode 100644
index 0000000000000..8448ed5f10d21
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import java.util.concurrent.{
+ ConcurrentHashMap,
+ ConcurrentLinkedDeque,
+ ConcurrentLinkedQueue,
+ ExecutionException,
+ Future
+}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+import scala.util.control.NoStackTrace
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Resolves the [[DataflowGraph]] by processing each node in the graph. This class exposes
+ * visitor functionality to resolve/analyze graph nodes.
+ * We only expose simple visitor abilities to transform the different entities of the graph.
+ * For advanced transformations we also expose a mechanism to walk the graph entity by entity.
+ *
+ * Assumptions:
+ * 1. Each output will have at least one flow writing to it.
+ * 2. Each flow may or may not have a destination table. If a flow does not have a destination
+ * table, the destination is a temporary view.
+ *
+ * The graph is structured so that flows, tables and sinks are all graph elements or nodes.
+ * While we expose transformation functions for each of these entities, we also expose a way
+ * to walk over the graph as a whole.
+ *
+ * All usages should go through DataflowGraphTransformer.withDataflowGraphTransformer so that
+ * the transformer is always closed.
+ * @param graph any dataflow graph
+ */
+class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable {
+ import DataflowGraphTransformer._
+
+ private var tables: Seq[Table] = graph.tables
+ private var tableMap: Map[TableIdentifier, Table] = computeTableMap()
+ private var flows: Seq[Flow] = graph.flows
+ private var flowsTo: Map[TableIdentifier, Seq[Flow]] = computeFlowsTo()
+ private var views: Seq[View] = graph.views
+ private var viewMap: Map[TableIdentifier, View] = computeViewMap()
+
+ // Failed-analysis nodes.
+ // Failed flows are flows that failed to resolve, whose inputs are unavailable, or whose
+ // destination failed to resolve.
+ private var failedFlows: Seq[ResolutionCompletedFlow] = Seq.empty
+ // A dataset is considered failed to resolve if it is the destination of an unresolved flow.
+ private var failedTables: Seq[Table] = Seq.empty
+
+ private val parallelism = 10
+
+ // Executor used to resolve nodes in parallel. It is lazily initialized to avoid creating it
+ // in scenarios where it's not required. To track whether the lazy val was evaluated, we use a
+ // separate variable so we know whether we need to shut down the executor.
+ private var fixedPoolExecutorInitialized = false
+ lazy private val fixedPoolExecutor = {
+ fixedPoolExecutorInitialized = true
+ ThreadUtils.newDaemonFixedThreadPool(
+ parallelism,
+ prefix = "data-flow-graph-transformer-"
+ )
+ }
+ private val selfExecutor = ThreadUtils.sameThreadExecutorService()
+
+ private def computeTableMap(): Map[TableIdentifier, Table] = synchronized {
+ tables.map(table => table.identifier -> table).toMap
+ }
+
+ private def computeViewMap(): Map[TableIdentifier, View] = synchronized {
+ views.map(view => view.identifier -> view).toMap
+ }
+
+ private def computeFlowsTo(): Map[TableIdentifier, Seq[Flow]] = synchronized {
+ flows.groupBy(_.destinationIdentifier)
+ }
+
+ def transformTables(transformer: Table => Table): DataflowGraphTransformer = synchronized {
+ tables = tables.map(transformer)
+ tableMap = computeTableMap()
+ this
+ }
+
+ private def defaultOnFailedDependentTables(
+ failedTableDependencies: Map[TableIdentifier, Seq[Table]]): Unit = {
+ require(
+ failedTableDependencies.isEmpty,
+ "Dependency failure happened and some tables were not resolved"
+ )
+ }
+
+ /**
+ * Example graph: [Flow1, Flow 2] -> ST -> Flow3 -> MV
+ * Order of processing: Flow1, Flow2, ST, Flow3, MV.
+ * @param transformer function that transforms any graph entity.
+ * transformer(
+ * nodeToTransform: GraphElement, upstreamNodes: Seq[GraphElement]
+ * ) => transformedNodes: Seq[GraphElement]
+ * @return this
+ */
+ def transformDownNodes(
+ transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement],
+ disableParallelism: Boolean = false): DataflowGraphTransformer = {
+ val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor
+ val batchSize = if (disableParallelism) 1 else parallelism
+ // List of resolved tables, sinks and flows
+ val resolvedFlows = new ConcurrentLinkedQueue[ResolutionCompletedFlow]()
+ val resolvedTables = new ConcurrentLinkedQueue[Table]()
+ val resolvedViews = new ConcurrentLinkedQueue[View]()
+ // Flow identifier to a list of transformed flows mapping to track resolved flows
+ val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, Seq[Flow]]()
+ val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]()
+ val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]()
+ val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]()
+
+ var futures = ArrayBuffer[Future[Unit]]()
+ val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]()
+ toBeResolvedFlows.addAll(flows.asJava)
+
+ while (futures.nonEmpty || toBeResolvedFlows.peekFirst() != null) {
+ val (done, notDone) = futures.partition(_.isDone)
+ // Explicitly call future.get() to propagate exceptions one by one if any
+ try {
+ done.foreach(_.get())
+ } catch {
+ case exn: ExecutionException =>
+ // Computation threw the exception that is the cause of exn
+ throw exn.getCause
+ }
+ futures = notDone
+ val flowOpt = {
+ // We only schedule [[batchSize]] number of flows in parallel.
+ if (futures.size < batchSize) {
+ Option(toBeResolvedFlows.pollFirst())
+ } else {
+ None
+ }
+ }
+ if (flowOpt.isDefined) {
+ val flow = flowOpt.get
+ futures.append(
+ executor.submit(
+ () =>
+ try {
+ try {
+ // Note: Flows don't need their inputs passed, so for now we send an empty Seq.
+ val result = transformer(flow, Seq.empty)
+ require(
+ result.forall(_.isInstanceOf[ResolvedFlow]),
+ "transformer must return a Seq[Flow]"
+ )
+
+ val transformedFlows = result.map(_.asInstanceOf[ResolvedFlow])
+ resolvedFlowsMap.put(flow.identifier, transformedFlows)
+ resolvedFlows.addAll(transformedFlows.asJava)
+ } catch {
+ case e: TransformNodeRetryableException =>
+ val datasetIdentifier = e.datasetIdentifier
+ failedDependentFlows.compute(
+ datasetIdentifier,
+ (_, flows) => {
+ // Don't add the input flow back but the failed flow object
+ // back which has relevant failure information.
+ val failedFlow = e.failedNode
+ if (flows == null) {
+ Seq(failedFlow)
+ } else {
+ flows :+ failedFlow
+ }
+ }
+ )
+ // Between the time the flow started and finished resolving, perhaps the
+ // dependent dataset was resolved
+ resolvedFlowDestinationsMap.computeIfPresent(
+ datasetIdentifier,
+ (_, resolved) => {
+ if (resolved) {
+ // Check if the dataset that the flow is dependent on has been resolved
+ // and if so, remove all dependent flows from the failedDependentFlows and
+ // add them to the toBeResolvedFlows queue for retry.
+ failedDependentFlows.computeIfPresent(
+ datasetIdentifier,
+ (_, toRetryFlows) => {
+ toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
+ null
+ }
+ )
+ }
+ resolved
+ }
+ )
+ case other: Throwable => throw other
+ }
+ // If all flows to this particular destination are resolved, move to the destination
+ // node transformer
+ if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination =>
+ resolvedFlowsMap.containsKey(flowToDestination.identifier)
+ })) {
+ // If multiple flows completed in parallel, ensure we resolve the destination only
+ // once by electing a leader via computeIfAbsent
+ var isCurrentThreadLeader = false
+ resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => {
+ isCurrentThreadLeader = true
+ // Set initial value as false as flow destination is not resolved yet.
+ false
+ })
+ if (isCurrentThreadLeader) {
+ if (tableMap.contains(flow.destinationIdentifier)) {
+ val transformed =
+ transformer(
+ tableMap(flow.destinationIdentifier),
+ flowsTo(flow.destinationIdentifier)
+ )
+ resolvedTables.addAll(
+ transformed.collect { case t: Table => t }.asJava
+ )
+ resolvedFlows.addAll(
+ transformed.collect { case f: ResolvedFlow => f }.asJava
+ )
+ } else {
+ if (viewMap.contains(flow.destinationIdentifier)) {
+ resolvedViews.addAll {
+ val transformed =
+ transformer(
+ viewMap(flow.destinationIdentifier),
+ flowsTo(flow.destinationIdentifier)
+ )
+ transformed.map(_.asInstanceOf[View]).asJava
+ }
+ } else {
+ throw new IllegalArgumentException(
+ s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" +
+ s" in flow: ${flow.displayName} at transformDownNodes"
+ )
+ }
+ }
+ // Set flow destination as resolved now.
+ resolvedFlowDestinationsMap.computeIfPresent(
+ flow.destinationIdentifier,
+ (_, _) => {
+ // If there are any other node failures dependent on this destination, retry
+ // them
+ failedDependentFlows.computeIfPresent(
+ flow.destinationIdentifier,
+ (_, toRetryFlows) => {
+ toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
+ null
+ }
+ )
+ true
+ }
+ )
+ }
+ }
+ } catch {
+ case ex: TransformNodeFailedException => failedFlowsQueue.add(ex.failedNode)
+ }
+ )
+ )
+ }
+ }
+
+ // Mutate the failed-analysis entities.
+ // A table is considered failed to analyze if:
+ // - It is not marked as resolved in resolvedFlowDestinationsMap
+ failedTables = tables.filterNot { table =>
+ resolvedFlowDestinationsMap.getOrDefault(table.identifier, false)
+ }
+
+ // We maintain the topological sort order of successful flows always
+ val (resolvedFlowsWithResolvedDest, resolvedFlowsWithFailedDest) =
+ resolvedFlows.asScala.toSeq.partition(flow => {
+ resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false)
+ })
+
+ // A flow is considered failed to analyze if:
+ // - It is non-retryable
+ // - It is retryable but could not be retried, i.e. the dependent dataset is still unresolved
+ // - It might be resolvable but it writes into a destination that failed to analyze
+ // Note: because we transform nodes top-down, all downstream nodes of any pruned node
+ // will also be pruned.
+ failedFlows =
+ // All transformed flows that write to a destination that is failed to analyze.
+ resolvedFlowsWithFailedDest ++
+ // All failed flows thrown by TransformNodeFailedException
+ failedFlowsQueue.asScala.toSeq ++
+ // All flows that have not been transformed and resolved yet
+ failedDependentFlows.values().asScala.flatten.toSeq
+
+ // Mutate the resolved entities
+ flows = resolvedFlowsWithResolvedDest
+ flowsTo = computeFlowsTo()
+ tables = resolvedTables.asScala.toSeq
+ views = resolvedViews.asScala.toSeq
+ tableMap = computeTableMap()
+ viewMap = computeViewMap()
+ this
+ }
+
+ def getDataflowGraph: DataflowGraph = {
+ graph.copy(
+ // Returns all flows (resolved and failed) in topological order.
+ // The relative order between flows and failed flows doesn't matter here.
+ // Flows that resolved but were marked failed due to a destination failure are at the
+ // front of failedFlows, and thus by definition topologically sorted in the combined
+ // sequence too.
+ flows = flows ++ failedFlows,
+ tables = tables ++ failedTables
+ )
+ }
+
+ override def close(): Unit = {
+ if (fixedPoolExecutorInitialized) {
+ fixedPoolExecutor.shutdown()
+ }
+ }
+}
+
+object DataflowGraphTransformer {
+
+ /**
+ * Exception thrown when transforming a node in the graph fails because at least one of its
+ * dependencies has not yet been transformed.
+ *
+ * @param datasetIdentifier The identifier of an untransformed upstream dependency in the
+ * dataflow graph.
+ */
+ case class TransformNodeRetryableException(
+ datasetIdentifier: TableIdentifier,
+ failedNode: ResolutionFailedFlow)
+ extends Exception
+ with NoStackTrace
+
+ /**
+ * Exception thrown when transforming a node in the graph fails with a non-retryable error.
+ *
+ * @param failedNode The failed node that could not be transformed.
+ */
+ case class TransformNodeFailedException(failedNode: ResolutionFailedFlow)
+ extends Exception
+ with NoStackTrace
+
+ /**
+ * Autocloseable wrapper around DataflowGraphTransformer to ensure that the transformer is closed
+ * without clients needing to remember to close it. It takes in the same arguments as
+ * [[DataflowGraphTransformer]] constructor. It exposes the DataflowGraphTransformer instance
+ * within the callable scope.
+ */
+ def withDataflowGraphTransformer[T](graph: DataflowGraph)(f: DataflowGraphTransformer => T): T = {
+ val dataflowGraphTransformer = new DataflowGraphTransformer(graph)
+ try {
+ f(dataflowGraphTransformer)
+ } finally {
+ dataflowGraphTransformer.close()
+ }
+ }
+}
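A sketch of the loan pattern exposed by `withDataflowGraphTransformer`, assuming a `process` function such as `CoreDataflowNodeProcessor.processNode` (the wrapper object below is illustrative only):

```scala
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, DataflowGraphTransformer, GraphElement}

// Sketch of the loan pattern: the transformer (and its lazily created thread pool) is
// always closed, even if the transformation throws. The `process` function stands in for
// CoreDataflowNodeProcessor.processNode.
object SequentialResolution {
  def resolveSequentially(
      graph: DataflowGraph,
      process: (GraphElement, Seq[GraphElement]) => Seq[GraphElement]): DataflowGraph = {
    DataflowGraphTransformer.withDataflowGraphTransformer(graph) { transformer =>
      transformer
        .transformDownNodes(process, disableParallelism = true) // deterministic, single-threaded
        .getDataflowGraph
    }
  }
}
```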
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
new file mode 100644
index 0000000000000..2378b6f8d96a6
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.util.Try
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
+import org.apache.spark.sql.classic.DataFrame
+import org.apache.spark.sql.pipelines.AnalysisWarning
+import org.apache.spark.sql.pipelines.util.InputReadOptions
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Contains the catalog and database context information for query execution.
+ */
+case class QueryContext(currentCatalog: Option[String], currentDatabase: Option[String])
+
+/**
+ * A [[Flow]] is a node of data transformation in a dataflow graph. It describes the movement
+ * of data into a particular dataset.
+ */
+trait Flow extends GraphElement with Logging {
+
+ /** The [[FlowFunction]] containing the user's query. */
+ def func: FlowFunction
+
+ val identifier: TableIdentifier
+
+ /**
+ * The dataset that this Flow represents a write to.
+ */
+ val destinationIdentifier: TableIdentifier
+
+ /**
+ * Whether this is a ONCE flow. ONCE flows should run only once per full refresh.
+ */
+ def once: Boolean = false
+
+ /** The current query context (catalog and database) when the query is defined. */
+ def queryContext: QueryContext
+
+ /** The comment associated with this flow */
+ def comment: Option[String]
+
+ def sqlConf: Map[String, String]
+}
+
+/** A wrapper for a resolved internal input that includes the alias provided by the user. */
+case class ResolvedInput(input: Input, aliasIdentifier: AliasIdentifier)
+
+/** A wrapper for the lambda function that defines a [[Flow]]. */
+trait FlowFunction extends Logging {
+
+ /**
+ * This function defines the transformations performed by a flow, expressed as a DataFrame.
+ *
+ * @param allInputs the set of identifiers for all the [[Input]]s defined in the
+ * [[DataflowGraph]].
+ * @param availableInputs the list of all [[Input]]s available to this flow
+ * @param configuration the spark configurations that apply to this flow.
+ * @param queryContext The context of the query being evaluated.
+ * @return the inputs actually used, and the DataFrame expression for the flow
+ */
+ def call(
+ allInputs: Set[TableIdentifier],
+ availableInputs: Seq[Input],
+ configuration: Map[String, String],
+ queryContext: QueryContext
+ ): FlowFunctionResult
+}
+
+/**
+ * Holds the DataFrame returned by a [[FlowFunction]] along with the inputs used to
+ * construct it.
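+ * @param requestedInputs the identifiers of all the inputs requested when evaluating the flow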
+ * @param batchInputs the complete inputs read by the flow
+ * @param streamingInputs the incremental inputs read by the flow
+ * @param usedExternalInputs the identifiers of the external inputs read by the flow
+ * @param dataFrame the DataFrame expression executed by the flow if the flow can be resolved
+ */
+case class FlowFunctionResult(
+ requestedInputs: Set[TableIdentifier],
+ batchInputs: Set[ResolvedInput],
+ streamingInputs: Set[ResolvedInput],
+ usedExternalInputs: Set[TableIdentifier],
+ dataFrame: Try[DataFrame],
+ sqlConf: Map[String, String],
+ analysisWarnings: Seq[AnalysisWarning] = Nil) {
+
+ /**
+ * Returns the names of all the [[Input]]s used when resolving this [[Flow]]. If the
+ * flow failed to resolve, we return all the datasets that were requested when evaluating the
+ * flow.
+ */
+ def inputs: Set[TableIdentifier] = {
+ (batchInputs ++ streamingInputs).map(_.input.identifier)
+ }
+
+ /** Returns errors that occurred when attempting to analyze this [[Flow]]. */
+ def failure: Seq[Throwable] = {
+ dataFrame.failed.toOption.toSeq
+ }
+
+ /** Whether this [[Flow]] is successfully analyzed. */
+ final def resolved: Boolean = failure.isEmpty // don't override this, override failure
+}
+
+/** A [[Flow]] whose output schema and dependencies aren't known. */
+case class UnresolvedFlow(
+ identifier: TableIdentifier,
+ destinationIdentifier: TableIdentifier,
+ func: FlowFunction,
+ queryContext: QueryContext,
+ sqlConf: Map[String, String],
+ comment: Option[String] = None,
+ override val once: Boolean,
+ override val origin: QueryOrigin
+) extends Flow
+
+/**
+ * A [[Flow]] whose flow function has been invoked, meaning either:
+ * - Its output schema and dependencies are known.
+ * - It failed to resolve.
+ */
+trait ResolutionCompletedFlow extends Flow {
+ def flow: UnresolvedFlow
+ def funcResult: FlowFunctionResult
+
+ val identifier: TableIdentifier = flow.identifier
+ val destinationIdentifier: TableIdentifier = flow.destinationIdentifier
+ def func: FlowFunction = flow.func
+ def queryContext: QueryContext = flow.queryContext
+ def comment: Option[String] = flow.comment
+ def sqlConf: Map[String, String] = funcResult.sqlConf
+ def origin: QueryOrigin = flow.origin
+}
+
+/** A [[Flow]] whose flow function has failed to resolve. */
+class ResolutionFailedFlow(val flow: UnresolvedFlow, val funcResult: FlowFunctionResult)
+ extends ResolutionCompletedFlow {
+ assert(!funcResult.resolved)
+
+ def failure: Seq[Throwable] = funcResult.failure
+}
+
+/** A [[Flow]] whose flow function has successfully resolved. */
+trait ResolvedFlow extends ResolutionCompletedFlow with Input {
+ assert(funcResult.resolved)
+
+ /** The logical plan for this flow's query. */
+ def df: DataFrame = funcResult.dataFrame.get
+
+ /** Returns the schema of the output of this [[Flow]]. */
+ def schema: StructType = df.schema
+ override def load(readOptions: InputReadOptions): DataFrame = df
+ def inputs: Set[TableIdentifier] = funcResult.inputs
+}
+
+/** A [[Flow]] that represents stateful movement of data to some target. */
+class StreamingFlow(
+ val flow: UnresolvedFlow,
+ val funcResult: FlowFunctionResult,
+ val mustBeAppend: Boolean = false
+) extends ResolvedFlow {}
+
+/** A [[Flow]] that declares exactly what data should be in the target table. */
+class CompleteFlow(
+ val flow: UnresolvedFlow,
+ val funcResult: FlowFunctionResult,
+ val mustBeAppend: Boolean = false
+) extends ResolvedFlow {}
+
+/** A [[Flow]] that reads its source(s) completely and appends data to the target, just once. */
+class AppendOnceFlow(
+ val flow: UnresolvedFlow,
+ val funcResult: FlowFunctionResult
+) extends ResolvedFlow {
+
+ override val once = true
+}
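A hypothetical helper (not part of this PR) that summarizes how the concrete flow types above update their destinations:

```scala
import org.apache.spark.sql.pipelines.graph.{AppendOnceFlow, CompleteFlow, ResolvedFlow, StreamingFlow}

// Hypothetical helper: describe how a resolved flow writes to its destination.
object FlowDescriptions {
  def describe(flow: ResolvedFlow): String = flow match {
    case f: AppendOnceFlow => s"${f.identifier} appends its full input once per full refresh"
    case f: StreamingFlow  => s"${f.identifier} appends incrementally (mustBeAppend=${f.mustBeAppend})"
    case f: CompleteFlow   => s"${f.identifier} fully recomputes its destination"
    case f                 => s"${f.identifier} is a resolved flow of type ${f.getClass.getSimpleName}"
  }
}
```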
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
new file mode 100644
index 0000000000000..7e2e97f2b5d74
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.util.Try
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession}
+import org.apache.spark.sql.pipelines.AnalysisWarning
+import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier}
+import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions}
+
+
+object FlowAnalysis {
+ /**
+ * Creates a [[FlowFunction]] that attempts to analyze the provided LogicalPlan
+ * using the existing resolved inputs.
+ * - If all upstream inputs have been resolved, then analysis succeeds and the
+ * function returns a [[FlowFunctionResult]] containing the dataframe.
+ * - If any upstream inputs are unresolved, then the function throws an exception.
+ *
+ * @param plan The user-supplied LogicalPlan defining a flow.
+ * @return A FlowFunction that attempts to analyze the provided LogicalPlan.
+ */
+ def createFlowFunctionFromLogicalPlan(plan: LogicalPlan): FlowFunction = {
+ new FlowFunction {
+ override def call(
+ allInputs: Set[TableIdentifier],
+ availableInputs: Seq[Input],
+ confs: Map[String, String],
+ queryContext: QueryContext
+ ): FlowFunctionResult = {
+ val ctx = FlowAnalysisContext(
+ allInputs = allInputs,
+ availableInputs = availableInputs,
+ queryContext = queryContext,
+ spark = SparkSession.active
+ )
+ val df = try {
+ confs.foreach { case (k, v) => ctx.setConf(k, v) }
+ Try(FlowAnalysis.analyze(ctx, plan))
+ } finally {
+ ctx.restoreOriginalConf()
+ }
+ FlowFunctionResult(
+ requestedInputs = ctx.requestedInputs.toSet,
+ batchInputs = ctx.batchInputs.toSet,
+ streamingInputs = ctx.streamingInputs.toSet,
+ usedExternalInputs = ctx.externalInputs.toSet,
+ dataFrame = df,
+ sqlConf = confs,
+ analysisWarnings = ctx.analysisWarnings.toList
+ )
+ }
+ }
+ }
+
+ /**
+ * Constructs an analyzed [[DataFrame]] from a [[LogicalPlan]] by resolving Pipelines specific
+ * TVFs and datasets that cannot be resolved directly by Catalyst.
+ *
+ * This function shouldn't rely on any singletons, as that would break concurrent access to
+ * graph analysis, nor on any thread-local variables, as graph analysis and this function may
+ * run on different threads in the Python REPL.
+ *
+ * @param plan The [[LogicalPlan]] defining a flow.
+ * @return An analyzed [[DataFrame]].
+ */
+ private def analyze(
+ context: FlowAnalysisContext,
+ plan: LogicalPlan
+ ): DataFrame = {
+ // Users can define CTEs within their CREATE statements. For example,
+ //
+ // CREATE STREAMING TABLE a
+ // WITH b AS (
+ // SELECT * FROM STREAM upstream
+ // )
+ // SELECT * FROM b
+ //
+ // The relation defined using the WITH keyword is not included in the children of the main
+ // plan so the specific analysis we do below will not be applied to those relations.
+ // Instead, we call an analyzer rule to inline all of the CTE relations in the main plan before
+ // we do analysis. This rule would be called during analysis anyways, but we just call it
+ // earlier so we only need to apply analysis to a single logical plan.
+ val planWithInlinedCTEs = CTESubstitution(plan)
+
+ val spark = context.spark
+ // Traverse the user's query plan and recursively resolve nodes that reference Pipelines
+ // features that the Spark analyzer is unable to resolve
+ val resolvedPlan = planWithInlinedCTEs transformWithSubqueries {
+ // Streaming read on another dataset
+ // This branch will be hit for the following kinds of queries:
+ // - SELECT ... FROM STREAM(t1)
+ // - SELECT ... FROM STREAM t1
+ case u: UnresolvedRelation if u.isStreaming =>
+ readStreamInput(
+ context,
+ name = IdentifierHelper.toQuotedString(u.multipartIdentifier),
+ spark.readStream,
+ streamingReadOptions = StreamingReadOptions()
+ ).queryExecution.analyzed
+
+ // Batch read on another dataset in the pipeline
+ case u: UnresolvedRelation =>
+ readBatchInput(
+ context,
+ name = IdentifierHelper.toQuotedString(u.multipartIdentifier),
+ batchReadOptions = BatchReadOptions()
+ ).queryExecution.analyzed
+ }
+ Dataset.ofRows(spark, resolvedPlan)
+
+ }
+
+ /**
+ * Internal helper to reference the batch dataset (i.e., non-streaming dataset) with the given
+ * name.
+ * 1. The dataset can be a table, view, or a named flow.
+ * 2. The dataset can be a dataset defined in the same DataflowGraph or a table in the external
+ * catalog.
+ * All the public APIs that read from a dataset should call this function to read the dataset.
+ *
+ * @param name the name of the Dataset to be read.
+ * @param batchReadOptions Options for this batch read
+ * @return batch DataFrame that represents data from the specified Dataset.
+ */
+ final private def readBatchInput(
+ context: FlowAnalysisContext,
+ name: String,
+ batchReadOptions: BatchReadOptions
+ ): DataFrame = {
+ GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match {
+ case inputIdentifier: InternalDatasetIdentifier =>
+ readGraphInput(context, inputIdentifier, batchReadOptions)
+
+ case inputIdentifier: ExternalDatasetIdentifier =>
+ readExternalBatchInput(
+ context,
+ inputIdentifier = inputIdentifier,
+ name = name
+ )
+ }
+ }
+
+ /**
+ * Internal helper to reference the streaming dataset with the given name.
+ * 1. The dataset can be a table, view, or a named flow.
+ * 2. The dataset can be a dataset defined in the same DataflowGraph or a table in the external
+ * catalog.
+ * All the public APIs that read from a dataset should call this function to read the dataset.
+ *
+ * @param name the name of the Dataset to be read.
+ * @param streamReader The [[DataStreamReader]] that may hold read options specified by the user.
+ * @param streamingReadOptions Options for this streaming read.
+ * @return streaming DataFrame that represents data from the specified Dataset.
+ */
+ final private def readStreamInput(
+ context: FlowAnalysisContext,
+ name: String,
+ streamReader: DataStreamReader,
+ streamingReadOptions: StreamingReadOptions
+ ): DataFrame = {
+ GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match {
+ case inputIdentifier: InternalDatasetIdentifier =>
+ readGraphInput(
+ context,
+ inputIdentifier,
+ streamingReadOptions
+ )
+
+ case inputIdentifier: ExternalDatasetIdentifier =>
+ readExternalStreamInput(
+ context,
+ inputIdentifier = inputIdentifier,
+ streamReader = streamReader,
+ name = name
+ )
+ }
+ }
+
+ /**
+ * Internal helper to reference dataset defined in the same [[DataflowGraph]].
+ *
+ * @param inputIdentifier The identifier of the Dataset to be read.
+ * @param readOptions Options for this read (may be either streaming or batch options)
+ * @return streaming or batch DataFrame that represents data from the specified Dataset.
+ */
+ final private def readGraphInput(
+ ctx: FlowAnalysisContext,
+ inputIdentifier: InternalDatasetIdentifier,
+ readOptions: InputReadOptions
+ ): DataFrame = {
+ val datasetIdentifier = inputIdentifier.identifier
+
+ ctx.requestedInputs += datasetIdentifier
+
+ val i = if (!ctx.allInputs.contains(datasetIdentifier)) {
+ // Dataset not defined in the dataflow graph
+ throw GraphErrors.pipelineLocalDatasetNotDefinedError(datasetIdentifier.unquotedString)
+ } else if (!ctx.availableInput.contains(datasetIdentifier)) {
+ // Dataset defined in the dataflow graph but not yet resolved
+ throw UnresolvedDatasetException(datasetIdentifier)
+ } else {
+ // Dataset is resolved, so we can read from it
+ ctx.availableInput(datasetIdentifier)
+ }
+
+ val inputDF = i.load(readOptions)
+ i match {
+ // If the referenced input is a [[Flow]], because the query plans will be fused
+ // together, we also need to fuse their confs.
+ case f: Flow => f.sqlConf.foreach { case (k, v) => ctx.setConf(k, v) }
+ case _ =>
+ }
+
+ val incompatibleViewReadCheck =
+ ctx.spark.conf.get("pipelines.incompatibleViewCheck.enabled", "true").toBoolean
+
+ // Wrap the DF in an alias so that columns in the DF can be referenced with
+ // the following in the query:
+ // - <catalog>.<database>.<table>.<column>
+ // - <database>.<table>.<column>
+ // - <table>.<column>
+ val aliasIdentifier = AliasIdentifier(
+ name = datasetIdentifier.table,
+ qualifier = Seq(datasetIdentifier.catalog, datasetIdentifier.database).flatten
+ )
+
+ readOptions match {
+ case sro: StreamingReadOptions =>
+ if (!inputDF.isStreaming && incompatibleViewReadCheck) {
+ throw new AnalysisException(
+ "INCOMPATIBLE_BATCH_VIEW_READ",
+ Map("datasetIdentifier" -> datasetIdentifier.toString)
+ )
+ }
+
+ if (sro.droppedUserOptions.nonEmpty) {
+ ctx.analysisWarnings += AnalysisWarning.StreamingReaderOptionsDropped(
+ sourceName = datasetIdentifier.unquotedString,
+ droppedOptions = sro.droppedUserOptions.keys.toSeq
+ )
+ }
+ ctx.streamingInputs += ResolvedInput(i, aliasIdentifier)
+ case _ =>
+ if (inputDF.isStreaming && incompatibleViewReadCheck) {
+ throw new AnalysisException(
+ "INCOMPATIBLE_STREAMING_VIEW_READ",
+ Map("datasetIdentifier" -> datasetIdentifier.toString)
+ )
+ }
+ ctx.batchInputs += ResolvedInput(i, aliasIdentifier)
+ }
+ Dataset.ofRows(
+ ctx.spark,
+ SubqueryAlias(identifier = aliasIdentifier, child = inputDF.queryExecution.logical)
+ )
+ }
+
+ /**
+ * Internal helper to reference a batch dataset (i.e., a non-streaming dataset) defined in an
+ * external catalog or as a path.
+ *
+ * @param inputIdentifier The identifier of the dataset to be read.
+ * @return batch DataFrame that represents data from the specified dataset.
+ */
+ final private def readExternalBatchInput(
+ context: FlowAnalysisContext,
+ inputIdentifier: ExternalDatasetIdentifier,
+ name: String): DataFrame = {
+
+ val spark = context.spark
+ context.externalInputs += inputIdentifier.identifier
+ spark.read.table(inputIdentifier.identifier.quotedString)
+ }
+
+ /**
+ * Internal helper to reference a streaming dataset defined in an external catalog or as a path.
+ *
+ * @param inputIdentifier The identifier of the dataset to be read.
+ * @param streamReader The [[DataStreamReader]] that may hold additional read options specified by
+ * the user.
+ * @return streaming DataFrame that represents data from the specified dataset.
+ */
+ final private def readExternalStreamInput(
+ context: FlowAnalysisContext,
+ inputIdentifier: ExternalDatasetIdentifier,
+ streamReader: DataStreamReader,
+ name: String): DataFrame = {
+
+ context.externalInputs += inputIdentifier.identifier
+ streamReader.table(inputIdentifier.identifier.quotedString)
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala
new file mode 100644
index 0000000000000..fb96c6cb5bb1d
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.pipelines.AnalysisWarning
+
+/**
+ * A context used when evaluating a [[Flow]]'s query into a concrete DataFrame.
+ *
+ * @param allInputs Set of identifiers for all [[Input]]s defined in the DataflowGraph.
+ * @param availableInputs Inputs available to be referenced with `read` or `readStream`.
+ * @param queryContext The context of the query being evaluated.
+ * @param requestedInputs A mutable buffer populated with names of all inputs that were
+ * requested.
+ * @param spark The Spark session used to evaluate the flow's query.
+ * @param externalInputs The names of external inputs that were used to evaluate
+ * the flow's query.
+ */
+private[pipelines] case class FlowAnalysisContext(
+ allInputs: Set[TableIdentifier],
+ availableInputs: Seq[Input],
+ queryContext: QueryContext,
+ batchInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty,
+ streamingInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty,
+ requestedInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty,
+ shouldLowerCaseNames: Boolean = false,
+ analysisWarnings: mutable.Buffer[AnalysisWarning] = new ListBuffer[AnalysisWarning],
+ spark: SparkSession,
+ externalInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty
+) {
+
+ /** Map from [[Input]] identifier to the actual [[Input]]. */
+ val availableInput: Map[TableIdentifier, Input] =
+ availableInputs.map(i => i.identifier -> i).toMap
+
+ /** The confs set in this context that should be undone when exiting this context. */
+ private val confsToRestore = mutable.HashMap[String, Option[String]]()
+
+ /** Sets a Spark conf within this context that will be undone by `restoreOriginalConf`. */
+ def setConf(key: String, value: String): Unit = {
+ if (!confsToRestore.contains(key)) {
+ confsToRestore.put(key, spark.conf.getOption(key))
+ }
+ spark.conf.set(key, value)
+ }
+
+ /** Restores the Spark conf to its state when this context was created by undoing confs set. */
+ def restoreOriginalConf(): Unit = confsToRestore.foreach {
+ case (k, Some(v)) => spark.conf.set(k, v)
+ case (k, None) => spark.conf.unset(k)
+ }
+}
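+
+// A minimal usage sketch of the conf-scoping contract above (not part of this change; the
+// `QueryContext` arguments and the surrounding `spark` session are illustrative assumptions):
+//
+//   val ctx = FlowAnalysisContext(
+//     allInputs = Set.empty,
+//     availableInputs = Seq.empty,
+//     queryContext = QueryContext(currentCatalog = Some("main"), currentDatabase = Some("default")),
+//     spark = spark
+//   )
+//   ctx.setConf("spark.sql.shuffle.partitions", "1") // remembers the original value
+//   try {
+//     // ... evaluate the flow's query under the scoped conf ...
+//   } finally {
+//     ctx.restoreOriginalConf()                      // restores or unsets the conf
+//   }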
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala
new file mode 100644
index 0000000000000..bb3d290c3d3e1
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.pipelines.common.DatasetType
+
+object GraphElementTypeUtils {
+
+ /**
+ * Helper function to obtain the DatasetType. This function should be called with all
+ * the flows that are writing into a table and all the flows should be resolved.
+ */
+ def getDatasetTypeForMaterializedViewOrStreamingTable(
+ flowsToTable: Seq[ResolvedFlow]): DatasetType = {
+ val isStreamingTable = flowsToTable.exists(f => f.df.isStreaming || f.once)
+ if (isStreamingTable) {
+ DatasetType.STREAMING_TABLE
+ } else {
+ DatasetType.MATERIALIZED_VIEW
+ }
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala
new file mode 100644
index 0000000000000..53db669e687d2
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.pipelines.common.DatasetType
+import org.apache.spark.sql.types.StructType
+
+/** Collection of errors that can be thrown during graph resolution / analysis. */
+object GraphErrors {
+
+ /**
+ * Throws when a dataset is marked as internal but is not defined in the graph.
+ *
+ * @param datasetName the name of the dataset that is not defined
+ */
+ def pipelineLocalDatasetNotDefinedError(datasetName: String): SparkException = {
+ SparkException.internalError(
+ s"Failed to read dataset '$datasetName'. This dataset was expected to be " +
+ s"defined and created by the pipeline."
+ )
+ }
+
+ /**
+ * Throws when the catalog or schema name in the "USE CATALOG | SCHEMA" command is invalid
+ *
+ * @param command string "USE CATALOG" or "USE SCHEMA"
+ * @param name the invalid catalog or schema name
+ * @param reason the reason why the name is invalid
+ */
+ def invalidNameInUseCommandError(
+ command: String,
+ name: String,
+ reason: String
+ ): SparkException = {
+ new SparkException(
+ errorClass = "INVALID_NAME_IN_USE_COMMAND",
+ messageParameters = Map("command" -> command, "name" -> name, "reason" -> reason),
+ cause = null
+ )
+ }
+
+ /**
+ * Throws when a table path is unresolved, i.e. the table identifier
+ * does not exist in the catalog.
+ *
+ * @param identifier the unresolved table identifier
+ */
+ def unresolvedTablePath(identifier: TableIdentifier): SparkException = {
+ new SparkException(
+ errorClass = "UNRESOLVED_TABLE_PATH",
+ messageParameters = Map("identifier" -> identifier.toString),
+ cause = null
+ )
+ }
+
+ /**
+ * Throws an error if the user-specified schema and the inferred schema are not compatible.
+ *
+ * @param tableIdentifier the identifier of the table whose specified and inferred schemas conflict
+ */
+ def incompatibleUserSpecifiedAndInferredSchemasError(
+ tableIdentifier: TableIdentifier,
+ datasetType: DatasetType,
+ specifiedSchema: StructType,
+ inferredSchema: StructType,
+ cause: Option[Throwable] = None
+ ): AnalysisException = {
+ val streamingTableHint =
+ if (datasetType == DatasetType.STREAMING_TABLE) {
+ s""""
+ |Streaming tables are stateful and remember data that has already been
+ |processed. If you want to recompute the table from scratch, please full refresh
+ |the table.
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ new AnalysisException(
+ errorClass = "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE",
+ messageParameters = Map(
+ "tableName" -> tableIdentifier.unquotedString,
+ "streamingTableHint" -> streamingTableHint,
+ "specifiedSchema" -> specifiedSchema.treeString,
+ "inferredDataSchema" -> inferredSchema.treeString
+ ),
+ cause = Option(cause.orNull)
+ )
+ }
+
+ /**
+ * Throws if the latest inferred schema for a pipeline table is not compatible with
+ * the table's existing schema.
+ *
+ * @param tableIdentifier the identifier of the table whose schema could not be inferred
+ */
+ def unableToInferSchemaError(
+ tableIdentifier: TableIdentifier,
+ inferredSchema: StructType,
+ incompatibleSchema: StructType,
+ cause: Option[Throwable] = None
+ ): AnalysisException = {
+ new AnalysisException(
+ errorClass = "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA",
+ messageParameters = Map(
+ "tableName" -> tableIdentifier.unquotedString,
+ "inferredDataSchema" -> inferredSchema.treeString,
+ "incompatibleDataSchema" -> incompatibleSchema.treeString
+ ),
+ cause = Option(cause.orNull)
+ )
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala
new file mode 100644
index 0000000000000..414d9d0effea4
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.execution.datasources.DataSource
+
+/**
+ * Responsible for properly qualifying the identifiers of datasets defined in or referenced by
+ * the dataflow graph.
+ */
+object GraphIdentifierManager {
+
+ import IdentifierHelper._
+
+ def parseTableIdentifier(name: String, spark: SparkSession): TableIdentifier = {
+ toTableIdentifier(spark.sessionState.sqlParser.parseMultipartIdentifier(name))
+ }
+
+ /**
+ * Fully qualifies (if needed) the user-specified identifier used to reference a dataset, and
+ * categorizes the dataset being referenced (i.e., a dataset defined in this pipeline or a
+ * dataset external to this pipeline).
+ *
+ * Returns the qualified identifier, typed to indicate whether the dataset is internal or
+ * external to the pipeline.
+ *
+ * @param rawInputName the user-specified name used to reference the dataset.
+ */
+ def parseAndQualifyInputIdentifier(
+ context: FlowAnalysisContext,
+ rawInputName: String): DatasetIdentifier = {
+ resolveDatasetReadInsideQueryDefinition(context = context, rawInputName = rawInputName)
+ }
+
+ /**
+ * Resolves dataset reads that happen inside the dataset's query definition (e.g., inside
+ * the @materialized_view() decorator in Python).
+ */
+ private def resolveDatasetReadInsideQueryDefinition(
+ context: FlowAnalysisContext,
+ rawInputName: String
+ ): DatasetIdentifier = {
+ // After the identifier is parsed, we first check whether we're referencing a
+ // single-part-name dataset (e.g., a temp view). If so, we don't fully qualify the identifier
+ // and read from it directly, because single-part-name datasets always mask
+ // fully qualified datasets that have the same name. For example, if there is a view named
+ // "a" and also a table named "catalog.schema.a" defined in the graph, "SELECT * FROM a"
+ // always reads the view "a". To read the table, the user needs to reference it with a
+ // fully or partially qualified name (e.g., "SELECT * FROM catalog.schema.a" or "SELECT *
+ // FROM schema.a").
+ val inputIdentifier = parseTableIdentifier(rawInputName, context.spark)
+
+ /** Return whether we're referencing a dataset that is part of the pipeline. */
+ def isInternalDataset(identifier: TableIdentifier): Boolean = {
+ context.allInputs.contains(identifier)
+ }
+
+ if (isSinglePartIdentifier(inputIdentifier) &&
+ isInternalDataset(inputIdentifier)) {
+ // reading a single-part-name dataset defined in the dataflow graph (e.g., a view)
+ InternalDatasetIdentifier(identifier = inputIdentifier)
+ } else if (isPathIdentifier(context.spark, inputIdentifier)) {
+ // path-based reference, always read as external dataset
+ ExternalDatasetIdentifier(identifier = inputIdentifier)
+ } else {
+ val fullyQualifiedInputIdentifier = fullyQualifyIdentifier(
+ maybeFullyQualifiedIdentifier = inputIdentifier,
+ currentCatalog = context.queryContext.currentCatalog,
+ currentDatabase = context.queryContext.currentDatabase
+ )
+ assertIsFullyQualifiedForRead(identifier = fullyQualifiedInputIdentifier)
+
+ if (isInternalDataset(fullyQualifiedInputIdentifier)) {
+ InternalDatasetIdentifier(identifier = fullyQualifiedInputIdentifier)
+ } else {
+ ExternalDatasetIdentifier(fullyQualifiedInputIdentifier)
+ }
+ }
+ }
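+
+ // A sketch of the resolution precedence described above (hypothetical `ctx` whose graph
+ // defines a view `a` and a table `main.default.a`, with current catalog/schema `main`.`default`):
+ // parseAndQualifyInputIdentifier(ctx, "a") // InternalDatasetIdentifier(`a`) -- the view wins
+ // parseAndQualifyInputIdentifier(ctx, "main.default.a") // InternalDatasetIdentifier(`main`.`default`.`a`)
+ // parseAndQualifyInputIdentifier(ctx, "other_catalog.other_schema.t") // ExternalDatasetIdentifier(...)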
+
+ /**
+ * @param rawDatasetIdentifier the dataset identifier specified by the user.
+ */
+ @throws[AnalysisException]
+ private def parseAndValidatePipelineDatasetIdentifier(
+ rawDatasetIdentifier: TableIdentifier): InternalDatasetIdentifier = {
+ InternalDatasetIdentifier(identifier = rawDatasetIdentifier)
+ }
+
+ /**
+ * Parses the table identifier from the raw table identifier and fully qualifies it.
+ *
+ * @param rawTableIdentifier the raw table identifier
+ * @return the parsed table identifier
+ */
+ @throws[AnalysisException]
+ def parseAndQualifyTableIdentifier(
+ rawTableIdentifier: TableIdentifier,
+ currentCatalog: Option[String],
+ currentDatabase: Option[String]
+ ): InternalDatasetIdentifier = {
+ val pipelineDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier(
+ rawDatasetIdentifier = rawTableIdentifier
+ )
+ val fullyQualifiedTableIdentifier = fullyQualifyIdentifier(
+ maybeFullyQualifiedIdentifier = pipelineDatasetIdentifier.identifier,
+ currentCatalog = currentCatalog,
+ currentDatabase = currentDatabase
+ )
+ // assert the identifier is properly fully qualified
+ assertIsFullyQualifiedForCreate(fullyQualifiedTableIdentifier)
+ InternalDatasetIdentifier(identifier = fullyQualifiedTableIdentifier)
+ }
+
+ /**
+ * Parses and validates the view identifier from the raw view identifier for temporary views.
+ *
+ * @param rawViewIdentifier the raw view identifier
+ * @return the parsed view identifier
+ */
+ @throws[AnalysisException]
+ def parseAndValidateTemporaryViewIdentifier(
+ rawViewIdentifier: TableIdentifier): TableIdentifier = {
+ val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier(
+ rawDatasetIdentifier = rawViewIdentifier
+ )
+ // Temporary views are not persisted to the catalog in use, therefore should not be qualified.
+ if (!isSinglePartIdentifier(internalDatasetIdentifier.identifier)) {
+ throw new AnalysisException(
+ "MULTIPART_TEMPORARY_VIEW_NAME_NOT_SUPPORTED",
+ Map("viewName" -> rawViewIdentifier.unquotedString)
+ )
+ }
+ internalDatasetIdentifier.identifier
+ }
+
+ /**
+ * Parses and validates the view identifier from the raw view identifier for persisted views.
+ *
+ * @param rawViewIdentifier the raw view identifier
+ * @param currentCatalog the catalog
+ * @param currentDatabase the schema
+ * @return the parsed view identifier
+ */
+ def parseAndValidatePersistedViewIdentifier(
+ rawViewIdentifier: TableIdentifier,
+ currentCatalog: Option[String],
+ currentDatabase: Option[String]): TableIdentifier = {
+ val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier(
+ rawDatasetIdentifier = rawViewIdentifier
+ )
+ // Persisted views have fully qualified names
+ val fullyQualifiedViewIdentifier = fullyQualifyIdentifier(
+ maybeFullyQualifiedIdentifier = internalDatasetIdentifier.identifier,
+ currentCatalog = currentCatalog,
+ currentDatabase = currentDatabase
+ )
+ // assert the identifier is properly fully qualified
+ assertIsFullyQualifiedForCreate(fullyQualifiedViewIdentifier)
+ fullyQualifiedViewIdentifier
+ }
+
+ /**
+ * Parses the flow identifier from the raw flow identifier and fully qualifies it.
+ *
+ * @param rawFlowIdentifier the raw flow identifier
+ * @return the parsed flow identifier
+ */
+ @throws[AnalysisException]
+ def parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier: TableIdentifier,
+ currentCatalog: Option[String],
+ currentDatabase: Option[String]
+ ): InternalDatasetIdentifier = {
+ val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier(
+ rawDatasetIdentifier = rawFlowIdentifier
+ )
+
+ val fullyQualifiedFlowIdentifier = fullyQualifyIdentifier(
+ maybeFullyQualifiedIdentifier = internalDatasetIdentifier.identifier,
+ currentCatalog = currentCatalog,
+ currentDatabase = currentDatabase
+ )
+
+ // assert the identifier is properly fully qualified
+ assertIsFullyQualifiedForCreate(fullyQualifiedFlowIdentifier)
+ InternalDatasetIdentifier(identifier = fullyQualifiedFlowIdentifier)
+ }
+
+ /** Represents the identifier for a dataset that is defined or referenced in a pipeline. */
+ sealed trait DatasetIdentifier
+
+ /** Represents the identifier for a dataset that is defined by the current pipeline. */
+ case class InternalDatasetIdentifier private (
+ identifier: TableIdentifier
+ ) extends DatasetIdentifier
+
+ /** Represents the identifier for a dataset that is external to the current pipeline. */
+ case class ExternalDatasetIdentifier(identifier: TableIdentifier) extends DatasetIdentifier
+}
+
+object IdentifierHelper {
+
+ /**
+ * Returns the quoted string for the name parts.
+ *
+ * @param nameParts the dataset name parts.
+ * @return the quoted string for the name parts.
+ */
+ def toQuotedString(nameParts: Seq[String]): String = {
+ toTableIdentifier(nameParts).quotedString
+ }
+
+ /**
+ * Returns the table identifier constructed from the name parts.
+ *
+ * @param nameParts the dataset name parts.
+ * @return the table identifier constructed from the name parts.
+ */
+ @throws[UnsupportedOperationException]
+ def toTableIdentifier(nameParts: Seq[String]): TableIdentifier = {
+ nameParts.length match {
+ case 1 => TableIdentifier(tableName = nameParts.head)
+ case 2 => TableIdentifier(table = nameParts(1), database = Option(nameParts.head))
+ case 3 =>
+ TableIdentifier(
+ table = nameParts(2),
+ database = Option(nameParts(1)),
+ catalog = Option(nameParts.head)
+ )
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"4+ part table identifier ${nameParts.mkString(".")} is not supported."
+ )
+ }
+ }
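+
+ // Illustrative mapping for the cases above (a sketch, using the standard
+ // `TableIdentifier(table, database, catalog)` constructors):
+ // toTableIdentifier(Seq("t")) == TableIdentifier("t")
+ // toTableIdentifier(Seq("db", "t")) == TableIdentifier("t", Some("db"))
+ // toTableIdentifier(Seq("cat", "db", "t")) == TableIdentifier("t", Some("db"), Some("cat"))
+ // toTableIdentifier(Seq("a", "b", "c", "d")) // throws UnsupportedOperationException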
+
+ /**
+ * Returns the table identifier constructed from the logical plan.
+ *
+ * @param table the logical plan.
+ * @return the table identifier constructed from the logical plan.
+ */
+ def toTableIdentifier(table: LogicalPlan): TableIdentifier = {
+ val parts = table match {
+ case r: ResolvedIdentifier => r.identifier.namespace.toSeq :+ r.identifier.name
+ case u: UnresolvedIdentifier => u.nameParts
+ case u: UnresolvedRelation => u.multipartIdentifier
+ case _ =>
+ throw new UnsupportedOperationException(s"Unable to resolve name for $table.")
+ }
+ toTableIdentifier(parts)
+ }
+
+ /** Return whether the input identifier is a single-part identifier. */
+ def isSinglePartIdentifier(identifier: TableIdentifier): Boolean = {
+ identifier.database.isEmpty && identifier.catalog.isEmpty
+ }
+
+ /**
+ * Return true if the identifier should be resolved as a path-based reference
+ * (i.e., `datasource`.`path`).
+ */
+ def isPathIdentifier(spark: SparkSession, identifier: TableIdentifier): Boolean = {
+ if (identifier.nameParts.length != 2) {
+ return false
+ }
+ val Seq(datasource, path) = identifier.nameParts
+ val sqlConf = spark.sessionState.conf
+
+ def isDatasourceValid = {
+ try {
+ DataSource.lookupDataSource(datasource, sqlConf)
+ true
+ } catch {
+ case _: ClassNotFoundException => false
+ }
+ }
+
+ // Whether the provided datasource is valid.
+ isDatasourceValid
+ }
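+
+ // For example (a sketch; `parquet` is a built-in data source, `not_a_source` is not):
+ // isPathIdentifier(spark, toTableIdentifier(Seq("parquet", "/tmp/events"))) // true
+ // isPathIdentifier(spark, toTableIdentifier(Seq("not_a_source", "/tmp/events"))) // false
+ // isPathIdentifier(spark, toTableIdentifier(Seq("cat", "db", "t"))) // false, not two parts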
+
+ /** Fully qualifies provided identifier with provided catalog & schema. */
+ def fullyQualifyIdentifier(
+ maybeFullyQualifiedIdentifier: TableIdentifier,
+ currentCatalog: Option[String],
+ currentDatabase: Option[String]
+ ): TableIdentifier = {
+ maybeFullyQualifiedIdentifier.copy(
+ database = maybeFullyQualifiedIdentifier.database.orElse(currentDatabase),
+ catalog = maybeFullyQualifiedIdentifier.catalog.orElse(currentCatalog)
+ )
+ }
+
+ /** Assert whether the identifier is properly fully qualified when creating a dataset. */
+ def assertIsFullyQualifiedForCreate(identifier: TableIdentifier): Unit = {
+ assert(
+ identifier.catalog.isDefined && identifier.database.isDefined,
+ s"Dataset identifier $identifier is not properly fully qualified, expect a " +
+ s"three-part-name ..
"
+ )
+ }
+
+ /** Assert whether the identifier is properly qualified when reading a dataset in a pipeline. */
+ def assertIsFullyQualifiedForRead(identifier: TableIdentifier): Unit = {
+ assert(
+ identifier.catalog.isDefined && identifier.database.isDefined,
+ s"Failed to reference dataset $identifier, expect a " +
+ s"three-part-name ..
"
+ )
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala
new file mode 100644
index 0000000000000..c0b5a360afea7
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+/**
+ * @param identifier The identifier of the flow.
+ * @param inputs The identifiers of nodes used as inputs to this flow.
+ * @param output The identifier of the output that this flow writes to.
+ */
+case class FlowNode(
+ identifier: TableIdentifier,
+ inputs: Set[TableIdentifier],
+ output: TableIdentifier)
+
+trait GraphOperations {
+ this: DataflowGraph =>
+
+ /** The set of all outputs in the graph. */
+ private lazy val destinationSet: Set[TableIdentifier] =
+ flows.map(_.destinationIdentifier).toSet
+
+ /** A map from flow identifier to [[FlowNode]], which contains the input/output nodes. */
+ lazy val flowNodes: Map[TableIdentifier, FlowNode] = {
+ flows.map { f =>
+ val identifier = f.identifier
+ val n =
+ FlowNode(
+ identifier = identifier,
+ inputs = resolvedFlow(f.identifier).inputs,
+ output = f.destinationIdentifier
+ )
+ identifier -> n
+ }.toMap
+ }
+
+ /** Map from dataset identifier to all reachable upstream destinations, including itself. */
+ private lazy val upstreamDestinations =
+ mutable.HashMap
+ .empty[TableIdentifier, Set[TableIdentifier]]
+ .withDefault(key => dfsInternal(startDestination = key, downstream = false))
+
+ /** Map from dataset identifier to all reachable downstream destinations, including itself. */
+ private lazy val downstreamDestinations =
+ mutable.HashMap
+ .empty[TableIdentifier, Set[TableIdentifier]]
+ .withDefault(key => dfsInternal(startDestination = key, downstream = true))
+
+ /**
+ * Performs a DFS starting from `startDestination` and returns the set of nodes (datasets) reached.
+ * @param startDestination The identifier of the node to start from.
+ * @param downstream if true, traverse output edges (search downstream)
+ * if false, traverse input edges (search upstream).
+ * @param stopAtMaterializationPoints If true, stop when we reach a materialization point (table).
+ * If false, keep going until the end.
+ */
+ protected def dfsInternal(
+ startDestination: TableIdentifier,
+ downstream: Boolean,
+ stopAtMaterializationPoints: Boolean = false): Set[TableIdentifier] = {
+ assert(
+ destinationSet.contains(startDestination),
+ s"$startDestination is not a valid start node"
+ )
+ val visited = new mutable.HashSet[TableIdentifier]
+
+ // Same semantics as a stack. Need to be able to push/pop items.
+ var nextNodes = List[TableIdentifier]()
+ nextNodes = startDestination :: nextNodes
+
+ while (nextNodes.nonEmpty) {
+ val currNode = nextNodes.head
+ nextNodes = nextNodes.tail
+
+ if (!visited.contains(currNode)
+ // When we stop at materialization points, skip non-start nodes that are materialized.
+ && !(stopAtMaterializationPoints && table.contains(currNode)
+ && currNode != startDestination)) {
+ visited.add(currNode)
+ val neighbors = if (downstream) {
+ flowNodes.values.filter(_.inputs.contains(currNode)).map(_.output)
+ } else {
+ flowNodes.values.filter(_.output == currNode).flatMap(_.inputs)
+ }
+ nextNodes = neighbors.toList ++ nextNodes
+ }
+ }
+ visited.toSet
+ }
+
+ /**
+ * An implementation of DFS that takes in a sequence of start nodes and returns the
+ * "reachability set" of nodes from the start nodes.
+ *
+ * @param downstream Walks the graph via the output edges if true, otherwise via the input
+ * edges.
+ * @return A map from each visited node to its origin(s) in `datasetIdentifiers`, e.g.
+ * Let graph = a -> b c -> d (partitioned graph)
+ *
+ * reachabilitySet(Seq("a", "c"), downstream = true)
+ * -> ["a" -> ["a"], "b" -> ["a"], "c" -> ["c"], "d" -> ["c"]]
+ */
+ private def reachabilitySet(
+ datasetIdentifiers: Seq[TableIdentifier],
+ downstream: Boolean): Map[TableIdentifier, Set[TableIdentifier]] = {
+ // Seq of the form "node identifier" -> Set("dependency1, "dependency2")
+ val deps = datasetIdentifiers.map(n => n -> reachabilitySet(n, downstream))
+
+ // Invert deps so that we get a map from each dependency to node identifier
+ val finalMap = mutable.HashMap
+ .empty[TableIdentifier, Set[TableIdentifier]]
+ .withDefaultValue(Set.empty[TableIdentifier])
+ deps.foreach {
+ case (start, reachableNodes) =>
+ reachableNodes.foreach(n => finalMap.put(n, finalMap(n) + start))
+ }
+ finalMap.toMap
+ }
+
+ /**
+ * Returns all datasets that can be reached from `destinationIdentifier`.
+ */
+ private def reachabilitySet(
+ destinationIdentifier: TableIdentifier,
+ downstream: Boolean): Set[TableIdentifier] = {
+ if (downstream) downstreamDestinations(destinationIdentifier)
+ else upstreamDestinations(destinationIdentifier)
+ }
+
+ /** Returns the set of flows reachable from `flowIdentifier` via output (child) edges. */
+ def downstreamFlows(flowIdentifier: TableIdentifier): Set[TableIdentifier] = {
+ assert(flowNodes.contains(flowIdentifier), s"$flowIdentifier is not a valid start flow")
+ val downstreamDatasets = reachabilitySet(flowNodes(flowIdentifier).output, downstream = true)
+ flowNodes.values.filter(_.inputs.exists(downstreamDatasets.contains)).map(_.identifier).toSet
+ }
+
+ /** Returns the set of flows reachable from `flowIdentifier` via input (parent) edges. */
+ def upstreamFlows(flowIdentifier: TableIdentifier): Set[TableIdentifier] = {
+ assert(flowNodes.contains(flowIdentifier), s"$flowIdentifier is not a valid start flow")
+ val upstreamDatasets =
+ flowNodes(flowIdentifier).inputs.flatMap(reachabilitySet(_, downstream = false))
+ flowNodes.values.filter(e => upstreamDatasets.contains(e.output)).map(_.identifier).toSet
+ }
+
+ /** Returns the set of datasets reachable from `datasetIdentifier` via input (parent) edges. */
+ def upstreamDatasets(datasetIdentifier: TableIdentifier): Set[TableIdentifier] =
+ reachabilitySet(datasetIdentifier, downstream = false) - datasetIdentifier
+
+ /**
+ * Traverses the graph upstream starting from the specified `datasetIdentifiers` to return the
+ * reachable nodes. The returned map's key set consists of all datasets reachable from
+ * `datasetIdentifiers`, excluding `datasetIdentifiers` themselves. For each entry, the value
+ * is the set of identifiers from `datasetIdentifiers` that can reach the key.
+ */
+ def upstreamDatasets(
+ datasetIdentifiers: Seq[TableIdentifier]): Map[TableIdentifier, Set[TableIdentifier]] =
+ reachabilitySet(datasetIdentifiers, downstream = false) -- datasetIdentifiers
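+
+ // Worked example (a sketch reusing the partitioned graph from the scaladoc above, where
+ // flows produce a -> b and c -> d):
+ // upstreamDatasets(b) == Set(a)
+ // upstreamDatasets(Seq(b, d)) == Map(a -> Set(b), c -> Set(d))
+ // downstreamFlows / upstreamFlows walk the same edges but operate on flow identifiers.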
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
new file mode 100644
index 0000000000000..0e2ba42b15e59
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+/**
+ * A mutable context for registering tables, views, and flows in a dataflow graph.
+ *
+ * @param defaultCatalog The pipeline's default catalog.
+ * @param defaultDatabase The pipeline's default schema.
+ * @param defaultSqlConf The default SQL configurations applied to every registered flow,
+ * which individual flows may override.
+ */
+class GraphRegistrationContext(
+ val defaultCatalog: String,
+ val defaultDatabase: String,
+ val defaultSqlConf: Map[String, String]) {
+ import GraphRegistrationContext._
+
+ protected val tables = new mutable.ListBuffer[Table]
+ protected val views = new mutable.ListBuffer[View]
+ protected val flows = new mutable.ListBuffer[UnresolvedFlow]
+
+ def registerTable(tableDef: Table): Unit = {
+ tables += tableDef
+ }
+
+ def registerView(viewDef: View): Unit = {
+ views += viewDef
+ }
+
+ def registerFlow(flowDef: UnresolvedFlow): Unit = {
+ flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf)
+ }
+
+ def toDataflowGraph: DataflowGraph = {
+ val qualifiedTables = tables.toSeq.map { t =>
+ t.copy(
+ identifier = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = t.identifier,
+ currentCatalog = Some(defaultCatalog),
+ currentDatabase = Some(defaultDatabase)
+ )
+ .identifier
+ )
+ }
+
+ val validatedViews = views.toSeq.collect {
+ case v: TemporaryView =>
+ v.copy(
+ identifier = GraphIdentifierManager
+ .parseAndValidateTemporaryViewIdentifier(
+ rawViewIdentifier = v.identifier
+ )
+ )
+ case v: PersistedView =>
+ v.copy(
+ identifier = GraphIdentifierManager
+ .parseAndValidatePersistedViewIdentifier(
+ rawViewIdentifier = v.identifier,
+ currentCatalog = Some(defaultCatalog),
+ currentDatabase = Some(defaultDatabase)
+ )
+ )
+ }
+
+ val qualifiedFlows = flows.toSeq.map { f =>
+ val isImplicitFlow = f.identifier == f.destinationIdentifier
+ val flowWritesToView =
+ validatedViews
+ .filter(_.isInstanceOf[TemporaryView])
+ .exists(_.identifier == f.destinationIdentifier)
+
+ // If the flow is created implicitly as part of defining a view, then we do not
+ // qualify the flow identifier and the flow destination. This is because temporary views
+ // are not permitted to have multi-part names.
+ if (isImplicitFlow && flowWritesToView) {
+ f
+ } else {
+ f.copy(
+ identifier = GraphIdentifierManager
+ .parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier = f.identifier,
+ currentCatalog = Some(defaultCatalog),
+ currentDatabase = Some(defaultDatabase)
+ )
+ .identifier,
+ destinationIdentifier = GraphIdentifierManager
+ .parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier = f.destinationIdentifier,
+ currentCatalog = Some(defaultCatalog),
+ currentDatabase = Some(defaultDatabase)
+ )
+ .identifier
+ )
+ }
+ }
+
+ assertNoDuplicates(
+ qualifiedTables = qualifiedTables,
+ validatedViews = validatedViews,
+ qualifiedFlows = qualifiedFlows
+ )
+
+ new DataflowGraph(
+ tables = qualifiedTables,
+ views = validatedViews,
+ flows = qualifiedFlows
+ )
+ }
+
+ private def assertNoDuplicates(
+ qualifiedTables: Seq[Table],
+ validatedViews: Seq[View],
+ qualifiedFlows: Seq[UnresolvedFlow]): Unit = {
+
+ (qualifiedTables.map(_.identifier) ++ validatedViews.map(_.identifier))
+ .foreach { identifier =>
+ assertDatasetIdentifierIsUnique(
+ identifier = identifier,
+ tables = qualifiedTables,
+ views = validatedViews
+ )
+ }
+
+ qualifiedFlows.foreach { flow =>
+ assertFlowIdentifierIsUnique(
+ flow = flow,
+ datasetType = TableType,
+ flows = qualifiedFlows
+ )
+ }
+ }
+
+ private def assertDatasetIdentifierIsUnique(
+ identifier: TableIdentifier,
+ tables: Seq[Table],
+ views: Seq[View]): Unit = {
+
+ // We need to check for duplicates in both tables and views, as they can have the same name.
+ val allDatasets = tables.map(t => t.identifier -> TableType) ++ views.map(
+ v => v.identifier -> ViewType
+ )
+
+ val grouped = allDatasets.groupBy { case (id, _) => id }
+
+ grouped(identifier).toList match {
+ case (_, firstType) :: (_, secondType) :: _ =>
+ // Sort the types in lexicographic order to ensure consistent error messages.
+ val sortedTypes = Seq(firstType.toString, secondType.toString).sorted
+ throw new AnalysisException(
+ errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET",
+ messageParameters = Map(
+ "datasetName" -> identifier.quotedString,
+ "datasetType1" -> sortedTypes.head,
+ "datasetType2" -> sortedTypes.last
+ )
+ )
+ case _ => // No duplicates found.
+ }
+ }
+
+ private def assertFlowIdentifierIsUnique(
+ flow: UnresolvedFlow,
+ datasetType: DatasetType,
+ flows: Seq[UnresolvedFlow]): Unit = {
+ flows.groupBy(i => i.identifier).get(flow.identifier).filter(_.size > 1).foreach {
+ duplicateFlows =>
+ val duplicateFlow = duplicateFlows.filter(_ != flow).head
+ throw new AnalysisException(
+ errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW",
+ messageParameters = Map(
+ "flowName" -> flow.identifier.unquotedString,
+ "datasetNames" -> Set(
+ flow.destinationIdentifier.quotedString,
+ duplicateFlow.destinationIdentifier.quotedString
+ ).mkString(",")
+ )
+ )
+ }
+ }
+}
+
+object GraphRegistrationContext {
+ sealed trait DatasetType
+
+ private object TableType extends DatasetType {
+ override def toString: String = "TABLE"
+ }
+
+ private object ViewType extends DatasetType {
+ override def toString: String = "VIEW"
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
new file mode 100644
index 0000000000000..99142432f9cec
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique
+import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils
+
+/** Validations performed on a [[DataflowGraph]]. */
+trait GraphValidations extends Logging {
+ this: DataflowGraph =>
+
+ /**
+ * Validate multi query table correctness.
+ */
+ protected[pipelines] def validateMultiQueryTables(): Map[TableIdentifier, Seq[Flow]] = {
+ val multiQueryTables = flowsTo.filter(_._2.size > 1)
+ // Non-streaming tables do not support multiflow.
+ multiQueryTables
+ .find {
+ case (dest, flows) =>
+ flows.exists(f => !resolvedFlow(f.identifier).df.isStreaming) &&
+ table.contains(dest)
+ }
+ .foreach {
+ case (dest, flows) =>
+ throw new AnalysisException(
+ "MATERIALIZED_VIEW_WITH_MULTIPLE_QUERIES",
+ Map(
+ "tableName" -> dest.unquotedString,
+ "queries" -> flows.map(_.identifier).mkString(",")
+ )
+ )
+ }
+
+ multiQueryTables
+ }
+
+ /** Throws an exception if the flows in this graph are not topologically sorted. */
+ protected[graph] def validateGraphIsTopologicallySorted(): Unit = {
+ val visitedNodes = mutable.Set.empty[TableIdentifier] // Set of visited nodes
+ val visitedEdges = mutable.Set.empty[TableIdentifier] // Set of visited edges
+ flows.foreach { f =>
+ // Unvisited inputs of the current flow
+ val unvisitedInputNodes =
+ resolvedFlow(f.identifier).inputs -- visitedNodes
+ unvisitedInputNodes.headOption match {
+ case None =>
+ visitedEdges.add(f.identifier)
+ if (flowsTo(f.destinationIdentifier).map(_.identifier).forall(visitedEdges.contains)) {
+ // A node is marked visited if all its inputs are visited
+ visitedNodes.add(f.destinationIdentifier)
+ }
+ case Some(unvisitedInput) =>
+ throw new AnalysisException(
+ "PIPELINE_GRAPH_NOT_TOPOLOGICALLY_SORTED",
+ Map(
+ "flowName" -> f.identifier.unquotedString,
+ "inputName" -> unvisitedInput.unquotedString
+ )
+ )
+ }
+ }
+ }
+
+ /**
+ * Validate that all tables are resettable. This is a best-effort check that will only catch
+ * upstream tables that are resettable but have a non-resettable downstream dependency.
+ */
+ protected def validateTablesAreResettable(): Unit = {
+ validateTablesAreResettable(tables)
+ }
+
+ /** Validate that all specified tables are resettable. */
+ protected def validateTablesAreResettable(tables: Seq[Table]): Unit = {
+ val tableLookup = mapUnique(tables, "table")(_.identifier)
+ val nonResettableTables =
+ tables.filter(t => !PipelinesTableProperties.resetAllowed.fromMap(t.properties))
+ val upstreamResettableTables = upstreamDatasets(nonResettableTables.map(_.identifier))
+ .collect {
+ // Filter for upstream datasets that are tables with downstream streaming tables
+ case (upstreamDataset, nonResettableDownstreams) if table.contains(upstreamDataset) =>
+ nonResettableDownstreams
+ .filter(
+ t => flowsTo(t).exists(f => resolvedFlow(f.identifier).df.isStreaming)
+ )
+ .map(id => (tableLookup(upstreamDataset), tableLookup(id).displayName))
+ }
+ .flatten
+ .toSeq
+ .filter {
+ case (t, _) => PipelinesTableProperties.resetAllowed.fromMap(t.properties)
+ } // Filter for resettable
+
+ upstreamResettableTables
+ .groupBy(_._2) // Group-by non-resettable downstream tables
+ .view
+ .mapValues(_.map(_._1))
+ .toSeq
+ .sortBy(_._2.size) // Output errors from largest to smallest
+ .reverse
+ .map {
+ case (nameForEvent, tables) =>
+ throw new AnalysisException(
+ "INVALID_RESETTABLE_DEPENDENCY",
+ Map(
+ "downstreamTable" -> nameForEvent,
+ "upstreamResettableTables" -> tables
+ .map(_.displayName)
+ .sorted
+ .map(t => s"'$t'")
+ .mkString(", "),
+ "resetAllowedKey" -> PipelinesTableProperties.resetAllowed.key
+ )
+ )
+ }
+ }
+
+ protected def validateUserSpecifiedSchemas(): Unit = {
+ flows.flatMap(f => table.get(f.identifier)).foreach { t: TableInput =>
+ // The output inferred schema of a table is the declared schema merged with the
+ // schema of all incoming flows. This must be equivalent to the declared schema.
+ val inferredSchema = SchemaInferenceUtils
+ .inferSchemaFromFlows(
+ flowsTo(t.identifier).map(f => resolvedFlow(f.identifier)),
+ userSpecifiedSchema = t.specifiedSchema
+ )
+
+ t.specifiedSchema.foreach { ss =>
+ // Check the inferred schema matches the specified schema. Used to catch errors where the
+ // inferred user-facing schema has columns that are not in the specified one.
+ if (inferredSchema != ss) {
+ val datasetType = GraphElementTypeUtils
+ .getDatasetTypeForMaterializedViewOrStreamingTable(
+ flowsTo(t.identifier).map(f => resolvedFlow(f.identifier))
+ )
+ throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError(
+ t.identifier,
+ datasetType,
+ ss,
+ inferredSchema
+ )
+ }
+ }
+ }
+ }
+
+ /**
+ * Validates that all flows are resolved. If there are unresolved flows,
+ * detects a possible cyclic dependency and throws the appropriate exception.
+ */
+ protected def validateSuccessfulFlowAnalysis(): Unit = {
+ // all failed flows with their errors
+ val flowAnalysisFailures = resolutionFailedFlows.flatMap(
+ f => f.failure.headOption.map(err => (f.identifier, err))
+ )
+ // only proceed if there are unresolved flows
+ if (flowAnalysisFailures.nonEmpty) {
+ val failedFlowIdentifiers = flowAnalysisFailures.map(_._1).toSet
+ // used to collect the subgraph of only the unresolved flows
+ // maps every unresolved flow to the set of unresolved flows writing to one of its inputs
+ val failedFlowsSubgraph = mutable.Map[TableIdentifier, Seq[TableIdentifier]]()
+ val (downstreamFailures, directFailures) = flowAnalysisFailures.partition {
+ case (flowIdentifier, _) =>
+ // If a failed flow writes to any of the requested datasets, we mark this flow as a
+ // downstream failure
+ val failedFlowsWritingToRequestedDatasets =
+ resolutionFailedFlow(flowIdentifier).funcResult.requestedInputs
+ .flatMap(d => flowsTo.getOrElse(d, Seq()))
+ .map(_.identifier)
+ .intersect(failedFlowIdentifiers)
+ .toSeq
+ failedFlowsSubgraph += (flowIdentifier -> failedFlowsWritingToRequestedDatasets)
+ failedFlowsWritingToRequestedDatasets.nonEmpty
+ }
+ // if there are flows that failed due to unresolved upstream flows, check for a cycle
+ if (failedFlowsSubgraph.nonEmpty) {
+ detectCycle(failedFlowsSubgraph.toMap).foreach {
+ case (upstream, downstream) =>
+ val upstreamDataset = flow(upstream).destinationIdentifier
+ val downstreamDataset = flow(downstream).destinationIdentifier
+ throw CircularDependencyException(
+ downstreamDataset,
+ upstreamDataset
+ )
+ }
+ }
+ // otherwise report what flows failed directly vs. depending on a failed flow
+ throw UnresolvedPipelineException(
+ this,
+ directFailures.map { case (id, value) => (id, value) }.toMap,
+ downstreamFailures.map { case (id, value) => (id, value) }.toMap
+ )
+ }
+ }
+
+ /**
+ * Generic method to detect a cycle in directed graph via DFS traversal.
+ * The graph is given as a reverse adjacency map, that is, a map from
+ * each node to its ancestors.
+ * @return the start and end node of a cycle if found, None otherwise
+ */
+ private def detectCycle(ancestors: Map[TableIdentifier, Seq[TableIdentifier]])
+ : Option[(TableIdentifier, TableIdentifier)] = {
+ var cycle: Option[(TableIdentifier, TableIdentifier)] = None
+ val visited = mutable.Set[TableIdentifier]()
+ def visit(f: TableIdentifier, currentPath: List[TableIdentifier]): Unit = {
+ if (cycle.isEmpty && !visited.contains(f)) {
+ if (currentPath.contains(f)) {
+ cycle = Option((currentPath.head, f))
+ } else {
+ ancestors(f).foreach(visit(_, f :: currentPath))
+ visited += f
+ }
+ }
+ }
+ ancestors.keys.foreach(visit(_, Nil))
+ cycle
+ }
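+
+ // For example (a sketch): for ancestors Map(a -> Seq(b), b -> Seq(a)) the traversal finds the
+ // cycle a -> b -> a and returns one of its closing edges, e.g. Some((b, a)); for an acyclic
+ // map such as Map(a -> Seq(b), b -> Seq.empty) it returns None.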
+
+ /** Validates that persisted views don't read from invalid sources */
+ protected[graph] def validatePersistedViewSources(): Unit = {
+ val viewToFlowMap = ViewHelpers.persistedViewIdentifierToFlow(graph = this)
+
+ persistedViews
+ .foreach { persistedView =>
+ val flow = viewToFlowMap(persistedView.identifier)
+ val funcResult = resolvedFlow(flow.identifier).funcResult
+ val inputIdentifiers = (funcResult.batchInputs ++ funcResult.streamingInputs)
+ .map(_.input.identifier)
+
+ inputIdentifiers
+ .flatMap(view.get)
+ .foreach {
+ case tempView: TemporaryView =>
+ throw new AnalysisException(
+ errorClass = "INVALID_TEMP_OBJ_REFERENCE",
+ messageParameters = Map(
+ "persistedViewName" -> persistedView.identifier.toString,
+ "temporaryViewName" -> tempView.identifier.toString
+ ),
+ cause = null
+ )
+ case _ =>
+ }
+ }
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
new file mode 100644
index 0000000000000..4bed25f2aa1c7
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+/**
+ * Exception raised when a flow tries to read from a dataset that exists but is unresolved
+ *
+ * @param identifier The identifier of the dataset
+ */
+case class UnresolvedDatasetException(identifier: TableIdentifier)
+ extends AnalysisException(
+ s"Failed to read dataset '${identifier.unquotedString}'. Dataset is defined in the " +
+ s"pipeline but could not be resolved."
+ )
+
+/**
+ * Exception raised when a flow fails to read from a table defined within the pipeline
+ *
+ * @param name The name of the table
+ * @param cause The cause of the failure
+ */
+case class LoadTableException(name: String, cause: Option[Throwable])
+ extends SparkException(
+ errorClass = "INTERNAL_ERROR",
+ messageParameters = Map("message" -> s"Failed to load table '$name'"),
+ cause = cause.orNull
+ )
+
+/**
+ * Exception raised when a pipeline has one or more flows that cannot be resolved
+ *
+ * @param directFailures Mapping between the name of flows that failed to resolve (due to an
+ * error in that flow) and the error that occurred when attempting to
+ * resolve them
+ * @param downstreamFailures Mapping between the name of flows that failed to resolve (because they
+ * failed to read from other unresolved flows) and the error that occurred
+ * when attempting to resolve them
+ */
+case class UnresolvedPipelineException(
+ graph: DataflowGraph,
+ directFailures: Map[TableIdentifier, Throwable],
+ downstreamFailures: Map[TableIdentifier, Throwable],
+ additionalHint: Option[String] = None)
+ extends AnalysisException(
+ s"""
+ |Failed to resolve flows in the pipeline.
+ |
+ |A flow can fail to resolve because the flow itself contains errors or because it reads
+ |from an upstream flow which failed to resolve.
+ |${additionalHint.getOrElse("")}
+ |Flows with errors: ${directFailures.keys.map(_.unquotedString).toSeq.sorted.mkString(", ")}
+ |Flows that failed due to upstream errors: ${downstreamFailures.keys
+ .map(_.unquotedString)
+ .toSeq
+ .sorted
+ .mkString(", ")}
+ |
+ |To view the exceptions that were raised while resolving these flows, look for flow
+ |failures that precede this log.""".stripMargin
+ )
+
+/**
+ * Raised when there's a circular dependency in the current pipeline. That is, a downstream
+ * table is referenced while creating an upstream table.
+ */
+case class CircularDependencyException(
+ downstreamTable: TableIdentifier,
+ upstreamDataset: TableIdentifier)
+ extends AnalysisException(
+ s"The downstream table '${downstreamTable.unquotedString}' is referenced when " +
+ s"creating the upstream table or view '${upstreamDataset.unquotedString}'. " +
+ s"Circular dependencies are not supported in a pipeline. Please remove the dependency " +
+ s"between '${upstreamDataset.unquotedString}' and '${downstreamTable.unquotedString}'."
+ )
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala
new file mode 100644
index 0000000000000..c8627a9a0a2e4
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import java.util.Locale
+
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+/**
+ * Interface for validating and accessing Pipeline-specific table properties.
+ */
+object PipelinesTableProperties {
+
+ /** Prefix used for all table properties. */
+ val pipelinesPrefix = "pipelines."
+
+ /** Map of keys to property entries. */
+ private val entries: mutable.HashMap[String, PipelineTableProperty[_]] = mutable.HashMap.empty
+
+ /** Whether the table should be reset when a Reset is triggered. */
+ val resetAllowed: PipelineTableProperty[Boolean] =
+ buildProp("pipelines.reset.allowed", default = true)
+
+ /**
+ * Validates that all known pipeline properties are valid and can be parsed.
+ * Also warns the user if they try to set an unknown pipelines property, or any
+ * property that looks suspiciously similar to a known pipeline property.
+ *
+ * @param rawProps Raw table properties to validate and canonicalize
+ * @param warnFunction Function to warn users of potential issues. Fatal errors are still thrown.
+ * @return a map of table properties, with canonical-case keys for valid pipelines properties,
+ * and excluding invalid pipelines properties.
+ */
+ def validateAndCanonicalize(
+ rawProps: Map[String, String],
+ warnFunction: String => Unit): Map[String, String] = {
+ rawProps.flatMap {
+ case (originalCaseKey, value) =>
+ val k = originalCaseKey.toLowerCase(Locale.ROOT)
+
+ // Make sure all pipelines properties are valid
+ if (k.startsWith(pipelinesPrefix)) {
+ entries.get(k) match {
+ case Some(prop) =>
+ prop.fromMap(rawProps) // Make sure the value can be retrieved
+ Option(prop.key -> value) // Canonicalize the case
+ case None =>
+ warnFunction(s"Unknown pipelines table property '$originalCaseKey'.")
+ // exclude this property - the pipeline won't recognize it anyway,
+ // and setting it would allow users to use `pipelines.xyz` for their custom
+ // purposes - which we don't want to encourage.
+ None
+ }
+ } else {
+ // Make sure they weren't trying to set a pipelines property
+ val similarProp = entries.get(s"${pipelinesPrefix}$k")
+ similarProp.foreach { c =>
+ warnFunction(
+ s"You are trying to set table property '$originalCaseKey', which has a similar name" +
+ s" to Pipelines table property '${c.key}'. If you are trying to set the Pipelines " +
+ s"table property, please include the correct prefix."
+ )
+ }
+ Option(originalCaseKey -> value)
+ }
+ }
+ }
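+
+ // For example (a sketch of the expected behavior):
+ // validateAndCanonicalize(
+ //   Map("PIPELINES.RESET.ALLOWED" -> "false", "reset.allowed" -> "true"),
+ //   warnFunction = msg => println(msg))
+ // // == Map("pipelines.reset.allowed" -> "false", "reset.allowed" -> "true"), after warning
+ // // that 'reset.allowed' looks similar to 'pipelines.reset.allowed'.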
+
+ /** Registers a pipelines table property with the specified key and default value. */
+ private def buildProp[T](
+ key: String,
+ default: String,
+ fromString: String => T): PipelineTableProperty[T] = {
+ val prop = PipelineTableProperty(key, default, fromString)
+ entries.put(key.toLowerCase(Locale.ROOT), prop)
+ prop
+ }
+
+ private def buildProp(key: String, default: Boolean): PipelineTableProperty[Boolean] =
+ buildProp(key, default.toString, _.toBoolean)
+}
+
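+/**
+ * A single pipelines table property, with a default value and a parser for its string form.
+ *
+ * A minimal illustrative sketch:
+ * {{{
+ *   val props = Map("pipelines.reset.allowed" -> "false")
+ *   PipelinesTableProperties.resetAllowed.fromMap(props)     // false
+ *   PipelinesTableProperties.resetAllowed.fromMap(Map.empty) // true (the default)
+ * }}}
+ */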
+case class PipelineTableProperty[T](key: String, default: String, fromString: String => T) {
+ def fromMap(rawProps: Map[String, String]): T = parseFromString(rawProps.getOrElse(key, default))
+
+ private def parseFromString(value: String): T = {
+ try {
+ fromString(value)
+ } catch {
+ case NonFatal(_) =>
+ throw new IllegalArgumentException(
+ s"Could not parse value '$value' for table property '$key'"
+ )
+ }
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala
new file mode 100644
index 0000000000000..042b4d9626fd6
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.util.control.{NonFatal, NoStackTrace}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.pipelines.Language
+import org.apache.spark.sql.pipelines.logging.SourceCodeLocation
+
+/**
+ * Records information used to track the provenance of a given query to user code.
+ *
+ * @param language The language used by the user to define the query.
+ * @param fileName The file name of the user code that defines the query.
+ * @param sqlText The SQL text of the query.
+ * @param line The line number of the query in the user code.
+ * Line numbers are 1-indexed.
+ * @param startPosition The start position of the query in the user code.
+ * @param objectType The type of the object that the query is associated with (e.g. table, view).
+ * @param objectName The name of the object that the query is associated with.
+ */
+case class QueryOrigin(
+ language: Option[Language] = None,
+ fileName: Option[String] = None,
+ sqlText: Option[String] = None,
+ line: Option[Int] = None,
+ startPosition: Option[Int] = None,
+ objectType: Option[String] = None,
+ objectName: Option[String] = None
+) {
+
+ /**
+ * Merges this origin with another one.
+ *
+ * The result has fields set to the value in the other origin if it is defined, or if not, then
+ * the value in this origin.
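+   *
+   * Illustrative example:
+   * {{{
+   *   val base = QueryOrigin(fileName = Some("pipeline.sql"), line = Some(10))
+   *   val update = QueryOrigin(line = Some(42), sqlText = Some("SELECT 1"))
+   *   base.merge(update)
+   *   // fileName = Some("pipeline.sql"), line = Some(42), sqlText = Some("SELECT 1")
+   * }}}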
+ */
+ def merge(other: QueryOrigin): QueryOrigin = {
+ QueryOrigin(
+ language = other.language.orElse(language),
+ fileName = other.fileName.orElse(fileName),
+ sqlText = other.sqlText.orElse(sqlText),
+ line = other.line.orElse(line),
+ startPosition = other.startPosition.orElse(startPosition),
+ objectType = other.objectType.orElse(objectType),
+ objectName = other.objectName.orElse(objectName)
+ )
+ }
+
+ /**
+ * Merge values from the catalyst origin.
+ *
+ * The result has fields set to the value in the other origin if it is defined, or if not, then
+ * the value in this origin.
+ */
+ def merge(other: Origin): QueryOrigin = {
+ merge(
+ QueryOrigin(
+ sqlText = other.sqlText,
+ line = other.line,
+ startPosition = other.startPosition
+ )
+ )
+ }
+
+ /** Generates a SourceCodeLocation using the details present in the query origin. */
+ def toSourceCodeLocation: SourceCodeLocation = SourceCodeLocation(
+ path = fileName,
+ // QueryOrigin tracks line numbers using a 1-indexed numbering scheme whereas SourceCodeLocation
+ // tracks them using a 0-indexed numbering scheme.
+ lineNumber = line.map(_ - 1),
+ columnNumber = startPosition,
+ endingLineNumber = None,
+ endingColumnNumber = None
+ )
+}
+
+object QueryOrigin extends Logging {
+
+ /** An empty QueryOrigin without any provenance information. */
+ val empty: QueryOrigin = QueryOrigin()
+
+ /**
+ * An exception that wraps [[QueryOrigin]] and lets us store it in errors as suppressed
+ * exceptions.
+ */
+ private case class QueryOriginWrapper(origin: QueryOrigin) extends Exception with NoStackTrace
+
+ implicit class ExceptionHelpers(t: Throwable) {
+
+ /**
+ * Stores `origin` inside the given throwable using suppressed exceptions.
+ *
+ * We rely on suppressed exceptions since that lets us preserve the original exception class
+ * and type.
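+     *
+     * A minimal sketch of the round trip:
+     * {{{
+     *   val err = new RuntimeException("boom")
+     *     .addOrigin(QueryOrigin(fileName = Some("pipeline.sql"), line = Some(3)))
+     *   QueryOrigin.getOrigin(err)
+     *   // == Some(QueryOrigin(fileName = Some("pipeline.sql"), line = Some(3)))
+     * }}}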
+ */
+ def addOrigin(origin: QueryOrigin): Throwable = {
+      try {
+        // Only add the origin if one is not already present. `getOrigin` also handles the case
+        // where the throwable is `null`.
+ if (getOrigin(t).isEmpty) {
+ t.addSuppressed(QueryOriginWrapper(origin))
+ }
+ } catch {
+ case NonFatal(e) => logError("Failed to add pipeline context", e)
+ }
+ t
+ }
+ }
+
+  /**
+   * Returns the [[QueryOrigin]] stored as a suppressed exception in the given throwable.
+   *
+   * @return Some(origin) if the origin is recorded as part of the given throwable, `None`
+   *         otherwise.
+   */
+ def getOrigin(t: Throwable): Option[QueryOrigin] = {
+ try {
+ // Wrap in an `Option(_)` first to handle `null` throwables.
+ Option(t).flatMap { ex =>
+ ex.getSuppressed.collectFirst {
+ case QueryOriginWrapper(context) => context
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ logError("Failed to get pipeline context", e)
+ None
+ }
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala
new file mode 100644
index 0000000000000..9f05219c383dd
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+object ViewHelpers {
+
+  /** Map from each persisted view's identifier to the single flow that writes to it. */
+ def persistedViewIdentifierToFlow(graph: DataflowGraph): Map[TableIdentifier, Flow] = {
+ graph.persistedViews.map { v =>
+ require(
+ graph.flowsTo.get(v.identifier).isDefined,
+ s"No flows to view ${v.identifier} were found"
+ )
+ val flowsToView = graph.flowsTo(v.identifier)
+ require(
+ flowsToView.size == 1,
+ s"Expected a single flow to the view, found ${flowsToView.size} flows to ${v.identifier}"
+ )
+ (v.identifier, flowsToView.head)
+ }.toMap
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
new file mode 100644
index 0000000000000..770776b29cf08
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import java.util
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.pipelines.common.DatasetType
+import org.apache.spark.sql.pipelines.util.{
+ BatchReadOptions,
+ InputReadOptions,
+ SchemaInferenceUtils,
+ StreamingReadOptions
+}
+import org.apache.spark.sql.types.StructType
+
+/** An element in a [[DataflowGraph]]. */
+trait GraphElement {
+
+ /**
+ * Contains provenance to tie back this GraphElement to the user code that defined it.
+ *
+ * This must be set when a [[GraphElement]] is directly created by some user code.
+ * Subsequently, this initial origin must be propagated as is without modification.
+ * If this [[GraphElement]] is copied or converted to a different type, then this origin must be
+ * copied as is.
+ */
+ def origin: QueryOrigin
+
+ protected def spark: SparkSession = SparkSession.getActiveSession.get
+
+ /** Returns the unique identifier for this [[GraphElement]]. */
+ def identifier: TableIdentifier
+
+ /**
+ * Returns a user-visible name for the element.
+ */
+ def displayName: String = identifier.unquotedString
+}
+
+/**
+ * Specifies an input that can be referenced by another Dataset's query.
+ */
+trait Input extends GraphElement {
+
+ /**
+ * Returns a DataFrame that is a result of loading data from this [[Input]].
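+   *
+   * Illustrative usage (a sketch; `input` here stands for any concrete [[Input]]):
+   * {{{
+   *   val batchDF = input.load(BatchReadOptions())
+   *   val streamingDF = input.load(StreamingReadOptions())
+   * }}}
+   *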
+   * @param readOptions Read options that determine whether the input is loaded as a streaming
+   *                    or batch DataFrame.
+ * @return Streaming or batch DataFrame of this Input's data.
+ */
+ def load(readOptions: InputReadOptions): DataFrame
+}
+
+/**
+ * Represents a node in a [[DataflowGraph]] that can be written to by a [[Flow]].
+ * Must be backed by a file source.
+ */
+sealed trait Output {
+
+ /**
+ * Normalized storage location used for storing materializations for this [[Output]].
+ * If None, it means this [[Output]] has not been normalized yet.
+ */
+ def normalizedPath: Option[String]
+
+  /** Returns whether the storage location for this [[Output]] has been normalized. */
+ final def normalized: Boolean = normalizedPath.isDefined
+
+ /**
+   * Returns the normalized storage location for this [[Output]], throwing if the storage
+   * location has not been normalized.
+ */
+ @throws[SparkException]
+ def path: String
+}
+
+/** A type of [[Input]] where data is loaded from a table. */
+sealed trait TableInput extends Input {
+
+ /** The user-specified schema for this table. */
+ def specifiedSchema: Option[StructType]
+}
+
+/**
+ * A table representing a materialized dataset in a [[DataflowGraph]].
+ *
+ * @param identifier The identifier of this table within the graph.
+ * @param specifiedSchema The user-specified schema for this table.
+ * @param partitionCols What columns the table should be partitioned by when materialized.
+ * @param normalizedPath Normalized storage location for the table based on the user-specified
+ *                       table path (if not defined, we will normalize a managed storage path
+ *                       for it).
+ * @param properties Table Properties to set in table metadata.
+ * @param comment User-specified comment that can be placed on the table.
+ * @param isStreamingTableOpt Whether the table is a streaming table; None until the flows into
+ *                            the table have been resolved.
+ */
+case class Table(
+ identifier: TableIdentifier,
+ specifiedSchema: Option[StructType],
+ partitionCols: Option[Seq[String]],
+ normalizedPath: Option[String],
+ properties: Map[String, String] = Map.empty,
+ comment: Option[String],
+ baseOrigin: QueryOrigin,
+ isStreamingTableOpt: Option[Boolean],
+ format: Option[String]
+) extends TableInput
+ with Output {
+
+ override val origin: QueryOrigin = baseOrigin.copy(
+ objectType = Some("table"),
+ objectName = Some(identifier.unquotedString)
+ )
+
+ // Load this table's data from underlying storage.
+ override def load(readOptions: InputReadOptions): DataFrame = {
+ try {
+ lazy val tableName = identifier.quotedString
+
+ val df = readOptions match {
+ case sro: StreamingReadOptions =>
+ spark.readStream.options(sro.userOptions).table(tableName)
+ case _: BatchReadOptions =>
+ spark.read.table(tableName)
+ case _ =>
+ throw new IllegalArgumentException("Unhandled `InputReadOptions` type when loading table")
+ }
+
+ df
+ } catch {
+ case NonFatal(e) => throw LoadTableException(displayName, Option(e))
+ }
+ }
+
+  /** Returns the normalized storage location for this [[Table]]. */
+ override def path: String = {
+ if (!normalized) {
+ throw GraphErrors.unresolvedTablePath(identifier)
+ }
+ normalizedPath.get
+ }
+
+ /**
+   * Returns whether this table is a streaming table. This property is not set until the flows
+   * into the table have been resolved; calling it before then throws an exception.
+ */
+ def isStreamingTable: Boolean = isStreamingTableOpt.getOrElse {
+ throw new IllegalStateException(
+ "Cannot identify whether the table is streaming table or not. You may need to resolve the " +
+ "flows into table."
+ )
+ }
+
+ /**
+   * Returns the [[DatasetType]] of the table.
+ */
+ def datasetType: DatasetType = {
+ if (isStreamingTable) {
+ DatasetType.STREAMING_TABLE
+ } else {
+ DatasetType.MATERIALIZED_VIEW
+ }
+ }
+}
+
+/**
+ * A type of [[TableInput]] that returns an empty DataFrame whose schema is taken from the
+ * user-specified schema or inferred from the [[Flow]]s that write to the table.
+ */
+case class VirtualTableInput(
+ identifier: TableIdentifier,
+ specifiedSchema: Option[StructType],
+ incomingFlowIdentifiers: Set[TableIdentifier],
+ availableFlows: Seq[ResolvedFlow] = Nil
+) extends TableInput
+ with Logging {
+ override def origin: QueryOrigin = QueryOrigin()
+
+ assert(availableFlows.forall(_.destinationIdentifier == identifier))
+ override def load(readOptions: InputReadOptions): DataFrame = {
+ // Infer the schema for this virtual table
+ def getFinalSchema: StructType = {
+ specifiedSchema match {
+ // This is not a backing table, and we have a user-specified schema, so use it directly.
+ case Some(ss) => ss
+ // Otherwise infer the schema from a combination of the incoming flows and the
+ // user-specified schema, if provided.
+ case _ =>
+ SchemaInferenceUtils.inferSchemaFromFlows(availableFlows, specifiedSchema)
+ }
+ }
+
+    // Create an empty streaming or batch DataFrame based on the read options.
+ def createEmptyDF(schema: StructType): DataFrame = readOptions match {
+ case _: StreamingReadOptions =>
+ MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark.sqlContext)
+ .toDF()
+ case _ => spark.createDataFrame(new util.ArrayList[Row](), schema)
+ }
+
+    createEmptyDF(getFinalSchema)
+ }
+}
+
+/**
+ * Represents a view in the [[DataflowGraph]].
+ */
+trait View extends GraphElement {
+
+ /** Returns the unique identifier for this [[View]]. */
+ val identifier: TableIdentifier
+
+ /** Properties of this view */
+ val properties: Map[String, String]
+
+ /** User-specified comment that can be placed on the [[View]]. */
+ val comment: Option[String]
+}
+
+/**
+ * Represents a temporary [[View]] in a [[DataflowGraph]].
+ *
+ * @param identifier The identifier of this view within the graph.
+ * @param properties Properties of the view.
+ * @param comment Optional user-specified comment attached to the view.
+ */
+case class TemporaryView(
+ identifier: TableIdentifier,
+ properties: Map[String, String],
+ comment: Option[String],
+ origin: QueryOrigin
+) extends View
+
+/**
+ * Represents a persisted [[View]] in a [[DataflowGraph]].
+ *
+ * @param identifier The identifier of this view within the graph.
+ * @param properties Properties of the view.
+ * @param comment Optional user-specified comment attached to the view.
+ */
+case class PersistedView(
+ identifier: TableIdentifier,
+ properties: Map[String, String],
+ comment: Option[String],
+ origin: QueryOrigin
+) extends View
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala
new file mode 100644
index 0000000000000..070927aea295f
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.util
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions
+
+/**
+ * Generic options for a read of an input.
+ */
+sealed trait InputReadOptions
+
+/**
+ * Options for a batch read of an input.
+ */
+final case class BatchReadOptions() extends InputReadOptions
+
+/**
+ * Options for a streaming read of an input.
+ *
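+ * Illustrative example (the option shown is just an arbitrary streaming read option):
+ * {{{
+ *   StreamingReadOptions(userOptions = CaseInsensitiveMap(Map("maxFilesPerTrigger" -> "1")))
+ * }}}
+ *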
+ * @param userOptions Holds the user defined read options.
+ * @param droppedUserOptions Holds the options that were specified by the user but
+ * not actually used. This is a bug but we are preserving this behavior
+ * for now to avoid making a backwards incompatible change.
+ */
+final case class StreamingReadOptions(
+ userOptions: CaseInsensitiveMap[String] = EmptyUserOptions,
+ droppedUserOptions: CaseInsensitiveMap[String] = EmptyUserOptions
+) extends InputReadOptions
+
+object StreamingReadOptions {
+ val EmptyUserOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map())
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala
new file mode 100644
index 0000000000000..4777772342d7d
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.util
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.connector.catalog.TableChange
+import org.apache.spark.sql.pipelines.common.DatasetType
+import org.apache.spark.sql.pipelines.graph.{GraphElementTypeUtils, GraphErrors, ResolvedFlow}
+import org.apache.spark.sql.types.{StructField, StructType}
+
+object SchemaInferenceUtils {
+
+ /**
+ * Given a set of flows that write to the same destination and possibly a user-specified schema,
+ * we infer the schema of the destination dataset. The logic is as follows:
+ * 1. If there are no incoming flows, return the user-specified schema (if provided)
+ * or an empty schema.
+ * 2. If there are incoming flows, we merge the schemas of all flows that write to
+ * the same destination.
+ * 3. If a user-specified schema is provided, we merge it with the inferred schema.
+ * The user-specified schema will take precedence over the inferred schema.
+   * Throws an error if schema inference fails or if the inferred schema cannot be merged with
+   * the user-specified one.
+ */
+ def inferSchemaFromFlows(
+ flows: Seq[ResolvedFlow],
+ userSpecifiedSchema: Option[StructType]): StructType = {
+ if (flows.isEmpty) {
+ return userSpecifiedSchema.getOrElse(new StructType())
+ }
+
+ require(
+ flows.forall(_.destinationIdentifier == flows.head.destinationIdentifier),
+ "Expected all flows to have the same destination"
+ )
+
+ val inferredSchema = flows.map(_.schema).fold(new StructType()) { (schemaSoFar, schema) =>
+ try {
+ SchemaMergingUtils.mergeSchemas(schemaSoFar, schema)
+ } catch {
+ case NonFatal(e) =>
+ throw GraphErrors.unableToInferSchemaError(
+ flows.head.destinationIdentifier,
+ schemaSoFar,
+ schema,
+ cause = Option(e)
+ )
+ }
+ }
+
+ val identifier = flows.head.destinationIdentifier
+ val datasetType = GraphElementTypeUtils.getDatasetTypeForMaterializedViewOrStreamingTable(flows)
+ // We merge the inferred schema with the user-specified schema to pick up any schema metadata
+ // that is provided by the user, e.g., comments or column masks.
+ mergeInferredAndUserSchemasIfNeeded(
+ identifier,
+ datasetType,
+ inferredSchema,
+ userSpecifiedSchema
+ )
+ }
+
+ private def mergeInferredAndUserSchemasIfNeeded(
+ tableIdentifier: TableIdentifier,
+ datasetType: DatasetType,
+ inferredSchema: StructType,
+ userSpecifiedSchema: Option[StructType]): StructType = {
+ userSpecifiedSchema match {
+ case Some(userSpecifiedSchema) =>
+ try {
+ // Merge the inferred schema with the user-provided schema hint
+ SchemaMergingUtils.mergeSchemas(userSpecifiedSchema, inferredSchema)
+ } catch {
+ case NonFatal(e) =>
+ throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError(
+ tableIdentifier,
+ datasetType,
+ userSpecifiedSchema,
+ inferredSchema,
+ cause = Option(e)
+ )
+ }
+ case None => inferredSchema
+ }
+ }
+
+ /**
+ * Determines the column changes needed to transform the current schema into the target schema.
+ *
+ * This function compares the current schema with the target schema and produces a sequence of
+ * TableChange objects representing:
+ * 1. New columns that need to be added
+ * 2. Existing columns that need type updates
+ *
+ * @param currentSchema The current schema of the table
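+   * A minimal sketch (types from org.apache.spark.sql.types):
+   * {{{
+   *   val current = new StructType().add("id", IntegerType)
+   *   val target = new StructType().add("id", LongType).add("name", StringType)
+   *   SchemaInferenceUtils.diffSchemas(current, target)
+   *   // => add column "name", then update the type of "id" to LongType
+   * }}}
+   *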
+ * @param targetSchema The target schema that we want the table to have
+ * @return A sequence of TableChange objects representing the necessary changes
+ */
+ def diffSchemas(currentSchema: StructType, targetSchema: StructType): Seq[TableChange] = {
+ val changes = scala.collection.mutable.ArrayBuffer.empty[TableChange]
+
+ // Helper function to get a map of field name to field
+ def getFieldMap(schema: StructType): Map[String, StructField] = {
+ schema.fields.map(field => field.name -> field).toMap
+ }
+
+ val currentFields = getFieldMap(currentSchema)
+ val targetFields = getFieldMap(targetSchema)
+
+ // Find columns to add (in target but not in current)
+ val columnsToAdd = targetFields.keySet.diff(currentFields.keySet)
+ columnsToAdd.foreach { columnName =>
+ val field = targetFields(columnName)
+ changes += TableChange.addColumn(
+ Array(columnName),
+ field.dataType,
+ field.nullable,
+ field.getComment().orNull
+ )
+ }
+
+ // Find columns to delete (in current but not in target)
+ val columnsToDelete = currentFields.keySet.diff(targetFields.keySet)
+ columnsToDelete.foreach { columnName =>
+ changes += TableChange.deleteColumn(Array(columnName), false)
+ }
+
+ // Find columns with type changes (in both but with different types)
+ val commonColumns = currentFields.keySet.intersect(targetFields.keySet)
+ commonColumns.foreach { columnName =>
+ val currentField = currentFields(columnName)
+ val targetField = targetFields(columnName)
+
+ // If data types are different, add a type update change
+ if (currentField.dataType != targetField.dataType) {
+ changes += TableChange.updateColumnType(Array(columnName), targetField.dataType)
+ }
+
+ // If nullability is different, add a nullability update change
+ if (currentField.nullable != targetField.nullable) {
+ changes += TableChange.updateColumnNullability(Array(columnName), targetField.nullable)
+ }
+
+ // If comments are different, add a comment update change
+ val currentComment = currentField.getComment().orNull
+ val targetComment = targetField.getComment().orNull
+ if (currentComment != targetComment) {
+ changes += TableChange.updateColumnComment(Array(columnName), targetComment)
+ }
+ }
+
+ changes.toSeq
+ }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala
new file mode 100644
index 0000000000000..d15e7ac6425cc
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.util
+
+import org.apache.spark.sql.types.StructType
+
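+/**
+ * Helpers for merging schemas using Spark's built-in [[StructType]] merge rules.
+ *
+ * Illustrative example (types from org.apache.spark.sql.types; a sketch, not exhaustive):
+ * {{{
+ *   val a = new StructType().add("x", IntegerType)
+ *   val b = new StructType().add("x", IntegerType).add("y", StringType)
+ *   SchemaMergingUtils.mergeSchemas(a, b) // => StructType with fields x (int) and y (string)
+ * }}}
+ */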
+object SchemaMergingUtils {
+ def mergeSchemas(tableSchema: StructType, dataSchema: StructType): StructType = {
+ StructType.merge(tableSchema, dataSchema).asInstanceOf[StructType]
+ }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
new file mode 100644
index 0000000000000..01b2a91bb9329
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+/**
+ * Test suite for resolving the flows in a [[DataflowGraph]]. These examples are all
+ * semantically correct but contain logical errors that should be found when connect is called
+ * and thrown when validate() is called.
+ */
+class ConnectInvalidPipelineSuite extends PipelineTest {
+
+  import originalSpark.implicits._
+
+  test("Missing source") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("b", query = readFlowFunc("a"))
+ }
+
+ val dfg = new P().resolveToDataflowGraph()
+ assert(!dfg.resolved, "Pipeline should not have resolved properly")
+ val ex = intercept[UnresolvedPipelineException] {
+ dfg.validate()
+ }
+ assert(ex.getMessage.contains("Failed to resolve flows in the pipeline"))
+ assertAnalysisException(
+ ex.directFailures(fullyQualifiedIdentifier("b", isView = true)),
+ "TABLE_OR_VIEW_NOT_FOUND"
+ )
+ }
+
+ test("Correctly differentiate between upstream and downstream errors") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(spark.range(5).toDF()))
+ registerView("b", query = readFlowFunc("nonExistentFlow"))
+ registerView("c", query = readFlowFunc("b"))
+ registerView("d", query = dfFlowFunc(spark.range(5).toDF()))
+ registerView("e", query = sqlFlowFunc(spark, "SELECT nonExistentColumn FROM RANGE(5)"))
+ registerView("f", query = readFlowFunc("e"))
+ }
+
+ val dfg = new P().resolveToDataflowGraph()
+ assert(!dfg.resolved, "Pipeline should not have resolved properly")
+ val ex = intercept[UnresolvedPipelineException] {
+ dfg.validate()
+ }
+ assert(ex.getMessage.contains("Failed to resolve flows in the pipeline"))
+ assert(
+ ex.getMessage.contains(
+ s"Flows with errors: " +
+ s"${fullyQualifiedIdentifier("b", isView = true).unquotedString}," +
+ s" ${fullyQualifiedIdentifier("e", isView = true).unquotedString}"
+ )
+ )
+ assert(
+ ex.getMessage.contains(
+ s"Flows that failed due to upstream errors: " +
+ s"${fullyQualifiedIdentifier("c", isView = true).unquotedString}, " +
+ s"${fullyQualifiedIdentifier("f", isView = true).unquotedString}"
+ )
+ )
+ assert(
+ ex.directFailures.keySet == Set(
+ fullyQualifiedIdentifier("b", isView = true),
+ fullyQualifiedIdentifier("e", isView = true)
+ )
+ )
+ assert(
+ ex.downstreamFailures.keySet == Set(
+ fullyQualifiedIdentifier("c", isView = true),
+ fullyQualifiedIdentifier("f", isView = true)
+ )
+ )
+ assertAnalysisException(
+ ex.directFailures(fullyQualifiedIdentifier("b", isView = true)),
+ "TABLE_OR_VIEW_NOT_FOUND"
+ )
+ assert(
+ ex.directFailures(fullyQualifiedIdentifier("e", isView = true))
+ .isInstanceOf[AnalysisException]
+ )
+ assert(
+ ex.directFailures(fullyQualifiedIdentifier("e", isView = true))
+ .getMessage
+ .contains("nonExistentColumn")
+ )
+ assert(
+ ex.downstreamFailures(fullyQualifiedIdentifier("c", isView = true))
+ .isInstanceOf[UnresolvedDatasetException]
+ )
+ assert(
+ ex.downstreamFailures(fullyQualifiedIdentifier("c", isView = true))
+ .getMessage
+ .contains(
+ s"Failed to read dataset " +
+ s"'${fullyQualifiedIdentifier("b", isView = true).unquotedString}'. " +
+ s"Dataset is defined in the pipeline but could not be resolved"
+ )
+ )
+ assert(
+ ex.downstreamFailures(fullyQualifiedIdentifier("f", isView = true))
+ .isInstanceOf[UnresolvedDatasetException]
+ )
+ assert(
+ ex.downstreamFailures(fullyQualifiedIdentifier("f", isView = true))
+ .getMessage
+ .contains(
+ s"Failed to read dataset " +
+ s"'${fullyQualifiedIdentifier("e", isView = true).unquotedString}'. " +
+ s"Dataset is defined in the pipeline but could not be resolved"
+ )
+ )
+ }
+
+ test("correctly identify direct and downstream errors for multi-flow pipelines") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerTable("a")
+ registerFlow("a", "a", dfFlowFunc(spark.range(5).toDF()))
+ registerFlow("a", "a_2", sqlFlowFunc(spark, "SELECT non_existent_col FROM RANGE(5)"))
+ registerTable("b", query = Option(readFlowFunc("a")))
+ }
+ val ex = intercept[UnresolvedPipelineException] { new P().resolveToDataflowGraph().validate() }
+ assert(ex.directFailures.keySet == Set(fullyQualifiedIdentifier("a_2")))
+ assert(ex.downstreamFailures.keySet == Set(fullyQualifiedIdentifier("b")))
+  }
+
+ test("Missing attribute in the schema") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("z")))
+ registerView("b", query = sqlFlowFunc(spark, "SELECT x FROM a"))
+ }
+
+ val dfg = new P().resolveToDataflowGraph()
+ val ex = intercept[UnresolvedPipelineException] {
+ dfg.validate()
+ }.directFailures(fullyQualifiedIdentifier("b", isView = true)).getMessage
+ verifyUnresolveColumnError(ex, "x", Seq("z"))
+ }
+
+ test("Joining on a column with different names") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("b", query = dfFlowFunc(Seq("a", "b", "c").toDF("y")))
+ registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a JOIN b USING (x)"))
+ }
+
+ val dfg = new P().resolveToDataflowGraph()
+ val ex = intercept[UnresolvedPipelineException] {
+ dfg.validate()
+ }
+ assert(
+ ex.directFailures(fullyQualifiedIdentifier("c", isView = true))
+ .getMessage
+ .contains("USING column `x` cannot be resolved on the right side")
+ )
+ }
+
+ test("Writing to one table by unioning flows with different schemas") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("b", query = dfFlowFunc(Seq(true, false).toDF("x")))
+ registerView("c", query = sqlFlowFunc(spark, "SELECT x FROM a UNION SELECT x FROM b"))
+ }
+
+ val dfg = new P().resolveToDataflowGraph()
+ assert(!dfg.resolved)
+ val ex = intercept[UnresolvedPipelineException] {
+ dfg.validate()
+ }
+ assert(
+ ex.directFailures(fullyQualifiedIdentifier("c", isView = true))
+ .getMessage
+ .contains("compatible column types") ||
+ ex.directFailures(fullyQualifiedIdentifier("c", isView = true))
+ .getMessage
+ .contains("Failed to merge incompatible data types")
+ )
+ }
+
+ test("Self reference") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = readFlowFunc("a"))
+ }
+ val e = intercept[CircularDependencyException] {
+ new P().resolveToDataflowGraph().validate()
+ }
+ assert(e.upstreamDataset == fullyQualifiedIdentifier("a", isView = true))
+ assert(e.downstreamTable == fullyQualifiedIdentifier("a", isView = true))
+ }
+
+ test("Cyclic graph - simple") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = readFlowFunc("b"))
+ registerView("b", query = readFlowFunc("a"))
+ }
+ val e = intercept[CircularDependencyException] {
+ new P().resolveToDataflowGraph().validate()
+ }
+ val cycle = Set(
+ fullyQualifiedIdentifier("a", isView = true),
+ fullyQualifiedIdentifier("b", isView = true)
+ )
+ assert(e.upstreamDataset != e.downstreamTable)
+ assert(cycle.contains(e.upstreamDataset))
+ assert(cycle.contains(e.downstreamTable))
+ }
+
+ test("Cyclic graph") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("b", query = sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d"))
+ registerView("c", query = readFlowFunc("b"))
+ registerView("d", query = readFlowFunc("c"))
+ }
+ val cycle =
+ Set(
+ fullyQualifiedIdentifier("b", isView = true),
+ fullyQualifiedIdentifier("c", isView = true),
+ fullyQualifiedIdentifier("d", isView = true)
+ )
+ val e = intercept[CircularDependencyException] {
+ new P().resolveToDataflowGraph().validate()
+ }
+ assert(e.upstreamDataset != e.downstreamTable)
+ assert(cycle.contains(e.upstreamDataset))
+ assert(cycle.contains(e.downstreamTable))
+ }
+
+ test("Cyclic graph with materialized nodes") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x"))))
+ registerTable(
+ "b",
+ query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d"))
+ )
+ registerTable("c", query = Option(readFlowFunc("b")))
+ registerTable("d", query = Option(readFlowFunc("c")))
+ }
+ val cycle =
+ Set(
+ fullyQualifiedIdentifier("b"),
+ fullyQualifiedIdentifier("c"),
+ fullyQualifiedIdentifier("d")
+ )
+ val e = intercept[CircularDependencyException] {
+ new P().resolveToDataflowGraph().validate()
+ }
+ assert(e.upstreamDataset != e.downstreamTable)
+ assert(cycle.contains(e.upstreamDataset))
+ assert(cycle.contains(e.downstreamTable))
+ }
+
+ test("Cyclic graph - second query makes it cyclic") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x"))))
+ registerTable("b")
+ registerFlow("b", "b", readFlowFunc("a"))
+ registerFlow("b", "b2", readFlowFunc("d"))
+ registerTable("c", query = Option(readFlowFunc("b")))
+ registerTable("d", query = Option(readFlowFunc("c")))
+ }
+ val cycle =
+ Set(
+ fullyQualifiedIdentifier("b"),
+ fullyQualifiedIdentifier("c"),
+ fullyQualifiedIdentifier("d")
+ )
+ val e = intercept[CircularDependencyException] {
+ new P().resolveToDataflowGraph().validate()
+ }
+ assert(e.upstreamDataset != e.downstreamTable)
+ assert(cycle.contains(e.upstreamDataset))
+ assert(cycle.contains(e.downstreamTable))
+ }
+
+ test("Cyclic graph - all named queries") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x"))))
+ registerTable("b")
+ registerFlow("b", "`b-name`", sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d"))
+ registerTable("c")
+ registerFlow("c", "`c-name`", readFlowFunc("b"))
+ registerTable("d")
+ registerFlow("d", "`d-name`", readFlowFunc("c"))
+ }
+ val cycle =
+ Set(
+ fullyQualifiedIdentifier("b"),
+ fullyQualifiedIdentifier("c"),
+ fullyQualifiedIdentifier("d")
+ )
+ val e = intercept[CircularDependencyException] {
+ new P().resolveToDataflowGraph().validate()
+ }
+ assert(e.upstreamDataset != e.downstreamTable)
+ assert(cycle.contains(e.upstreamDataset))
+ assert(cycle.contains(e.downstreamTable))
+ }
+
+ test("view-table conf conflict") {
+ val p = new TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "a-val"))
+ registerTable("b", query = Option(readFlowFunc("a")), sqlConf = Map("x" -> "b-val"))
+ }
+ val ex = intercept[AnalysisException] { p.resolveToDataflowGraph() }
+ assert(
+ ex.getMessage.contains(
+ s"Found duplicate sql conf for dataset " +
+ s"'${fullyQualifiedIdentifier("b").unquotedString}':"
+ )
+ )
+ assert(
+ ex.getMessage.contains(
+ s"'x' is defined by both " +
+ s"'${fullyQualifiedIdentifier("a", isView = true).unquotedString}' " +
+ s"and '${fullyQualifiedIdentifier("b").unquotedString}'"
+ )
+ )
+ }
+
+ test("view-view conf conflict") {
+ val p = new TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "a-val"))
+ registerView("b", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "b-val"))
+ registerTable(
+ "c",
+ query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM b")),
+ sqlConf = Map("y" -> "c-val")
+ )
+ }
+ val ex = intercept[AnalysisException] { p.resolveToDataflowGraph() }
+ assert(
+ ex.getMessage.contains(
+ s"Found duplicate sql conf for dataset " +
+ s"'${fullyQualifiedIdentifier("c").unquotedString}':"
+ )
+ )
+ assert(
+ ex.getMessage.contains(
+ s"'x' is defined by both " +
+ s"'${fullyQualifiedIdentifier("a", isView = true).unquotedString}' " +
+ s"and '${fullyQualifiedIdentifier("b", isView = true).unquotedString}'"
+ )
+ )
+ }
+
+ test("reading a complete view incrementally") {
+ val p = new TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1).toDF()))
+ registerTable("b", query = Option(readStreamFlowFunc("a")))
+ }
+ val ex = intercept[UnresolvedPipelineException] { p.resolveToDataflowGraph().validate() }
+ assert(
+ ex.directFailures(fullyQualifiedIdentifier("b"))
+ .getMessage
+ .contains(
+ s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString}" +
+ s" is a batch view and must be referenced using SparkSession#read."
+ )
+ )
+ }
+
+ test("reading an incremental view completely") {
+ val p = new TestGraphRegistrationContext(spark) {
+ val mem = MemoryStream[Int]
+ mem.addData(1)
+ registerView("a", query = dfFlowFunc(mem.toDF()))
+ registerTable("b", query = Option(readFlowFunc("a")))
+ }
+ val ex = intercept[UnresolvedPipelineException] { p.resolveToDataflowGraph().validate() }
+ assert(
+ ex.directFailures(fullyQualifiedIdentifier("b"))
+ .getMessage
+ .contains(
+ s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString} " +
+ s"is a streaming view and must be referenced using SparkSession#readStream"
+ )
+ )
+ }
+
+ test("Inferred schema that isn't a subset of user-specified schema") {
+ val graph1 = new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "a",
+ query = Option(dfFlowFunc(Seq(1, 2).toDF("incorrect-col-name"))),
+ specifiedSchema = Option(new StructType().add("x", IntegerType))
+ )
+ }.resolveToDataflowGraph()
+ val ex1 = intercept[AnalysisException] { graph1.validate() }
+ assert(
+ ex1.getMessage.contains(
+ s"'${fullyQualifiedIdentifier("a").unquotedString}' " +
+ s"has a user-specified schema that is incompatible"
+ )
+ )
+ assert(ex1.getMessage.contains("incorrect-col-name"))
+
+ val graph2 = new TestGraphRegistrationContext(spark) {
+ registerTable("a", specifiedSchema = Option(new StructType().add("x", IntegerType)))
+ registerFlow("a", "a", query = dfFlowFunc(Seq(true, false).toDF("x")), once = true)
+ }.resolveToDataflowGraph()
+ val ex2 = intercept[AnalysisException] { graph2.validate() }
+ assert(
+ ex2.getMessage.contains(
+ s"'${fullyQualifiedIdentifier("a").unquotedString}' " +
+ s"has a user-specified schema that is incompatible"
+ )
+ )
+ assert(ex2.getMessage.contains("boolean") && ex2.getMessage.contains("integer"))
+
+ val streamingTableHint = "please full refresh"
+ assert(!ex1.getMessage.contains(streamingTableHint))
+ assert(ex2.getMessage.contains(streamingTableHint))
+ }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
new file mode 100644
index 0000000000000..f8b5133ff167d
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
@@ -0,0 +1,455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.Union
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Test suite for resolving the flows in a [[DataflowGraph]]. These examples are all
+ * semantically and logically correct, so connect should not result in any errors.
+ */
+class ConnectValidPipelineSuite extends PipelineTest {
+
+ import originalSpark.implicits._
+
+ test("Extra simple") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF("y")))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val outSchema = new StructType().add("y", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), outSchema)
+ }
+
+ test("Simple") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("b", query = sqlFlowFunc(spark, "SELECT x as y FROM a"))
+ }
+ val p = new P().resolveToDataflowGraph()
+ verifyFlowSchema(
+ p,
+ fullyQualifiedIdentifier("a", isView = true),
+ new StructType().add("x", IntegerType, false)
+ )
+ val outSchema = new StructType().add("y", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), outSchema)
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("b", isView = true)).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("Dependencies") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("c", query = sqlFlowFunc(spark, "SELECT y as z FROM b"))
+ registerView("b", query = sqlFlowFunc(spark, "SELECT x as y FROM a"))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schemaAB = new StructType().add("y", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), schemaAB)
+ val schemaBC = new StructType().add("z", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaBC)
+ }
+
+ test("Multi-hop schema merging") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView(
+ "b",
+ query = sqlFlowFunc(spark, """SELECT * FROM VALUES ((1)) OUTER JOIN d ON false""")
+ )
+ registerView("e", query = readFlowFunc("b"))
+ registerView("d", query = dfFlowFunc(Seq(1).toDF("y")))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schemaE = new StructType().add("col1", IntegerType, false).add("y", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), schemaE)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("e", isView = true), schemaE)
+ }
+
+ test("Cross product join merges schema") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("b", query = dfFlowFunc(Seq(4, 5, 6).toDF("y")))
+ registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a CROSS JOIN b"))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schemaC = new StructType().add("x", IntegerType, false).add("y", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaC)
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true),
+ fullyQualifiedIdentifier("b", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("Real join merges schema") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq((1, "a"), (2, "b"), (3, "c")).toDF("x", "y")))
+ registerView("b", query = dfFlowFunc(Seq((2, "m"), (3, "n"), (4, "o")).toDF("x", "z")))
+ registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a JOIN b USING (x)"))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schemaC = new StructType()
+ .add("x", IntegerType, false)
+ .add("y", StringType)
+ .add("z", StringType)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaC)
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true),
+ fullyQualifiedIdentifier("b", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("Union of streaming and batch Dataframes") {
+ class P extends TestGraphRegistrationContext(spark) {
+ val ints = MemoryStream[Int]
+ ints.addData(1, 2, 3, 4)
+ registerView("a", query = dfFlowFunc(ints.toDF()))
+ registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF()))
+ registerView(
+ "c",
+ query = FlowAnalysis.createFlowFunctionFromLogicalPlan(
+ Union(
+ Seq(
+ UnresolvedRelation(
+ TableIdentifier("a"),
+ extraOptions = CaseInsensitiveStringMap.empty(),
+ isStreaming = true
+ ),
+ UnresolvedRelation(TableIdentifier("b"))
+ )
+ )
+ )
+ )
+ }
+
+ val p = new P().resolveToDataflowGraph()
+ verifyFlowSchema(
+ p,
+ fullyQualifiedIdentifier("c", isView = true),
+ new StructType().add("value", IntegerType, false)
+ )
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true),
+ fullyQualifiedIdentifier("b", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("Union of two streaming Dataframes") {
+ class P extends TestGraphRegistrationContext(spark) {
+ val ints1 = MemoryStream[Int]
+ ints1.addData(1, 2, 3, 4)
+ val ints2 = MemoryStream[Int]
+ ints2.addData(1, 2, 3, 4)
+ registerView("a", query = dfFlowFunc(ints1.toDF()))
+ registerView("b", query = dfFlowFunc(ints2.toDF()))
+ registerView(
+ "c",
+ query = FlowAnalysis.createFlowFunctionFromLogicalPlan(
+ Union(
+ Seq(
+ UnresolvedRelation(
+ TableIdentifier("a"),
+ extraOptions = CaseInsensitiveStringMap.empty(),
+ isStreaming = true
+ ),
+ UnresolvedRelation(
+ TableIdentifier("b"),
+ extraOptions = CaseInsensitiveStringMap.empty(),
+ isStreaming = true
+ )
+ )
+ )
+ )
+ )
+ }
+
+ val p = new P().resolveToDataflowGraph()
+ verifyFlowSchema(
+ p,
+ fullyQualifiedIdentifier("c", isView = true),
+ new StructType().add("value", IntegerType, false)
+ )
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true),
+ fullyQualifiedIdentifier("b", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("MultipleInputs") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerView("b", query = dfFlowFunc(Seq(4, 5, 6).toDF("y")))
+ registerView(
+ "c",
+ query = sqlFlowFunc(spark, "SELECT x AS z FROM a UNION SELECT y AS z FROM b")
+ )
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schema = new StructType().add("z", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schema)
+ }
+
+ test("Connect retains and fuses confs") {
+ // a -> b \
+ // d
+ // c /
+ val p = new TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1).toDF("x")), Map("a" -> "a-val"))
+ registerView("b", query = readFlowFunc("a"), Map("b" -> "b-val"))
+ registerView("c", query = dfFlowFunc(Seq(2).toDF("x")), Map("c" -> "c-val"))
+ registerTable(
+ "d",
+ query = Option(sqlFlowFunc(spark, "SELECT * FROM b UNION SELECT * FROM c")),
+ Map("d" -> "d-val")
+ )
+ }
+ val graph = p.resolveToDataflowGraph()
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("d", isView = false))
+ .sqlConf == Map("a" -> "a-val", "b" -> "b-val", "c" -> "c-val", "d" -> "d-val")
+ )
+ }
+
+ test("Confs aren't fused past materialization points") {
+ val p = new TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1).toDF("x")), Map("a" -> "a-val"))
+ registerTable("b", query = Option(readFlowFunc("a")), Map("b" -> "b-val"))
+ registerView("c", query = dfFlowFunc(Seq(2).toDF("x")), sqlConf = Map("c" -> "c-val"))
+ registerTable(
+ "d",
+ query = Option(sqlFlowFunc(spark, "SELECT * FROM b UNION SELECT * FROM c")),
+ Map("d" -> "d-val")
+ )
+ }
+ val graph = p.resolveToDataflowGraph()
+ assert(graph.flow(fullyQualifiedIdentifier("a", isView = true)).sqlConf == Map("a" -> "a-val"))
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("b"))
+ .sqlConf == Map("a" -> "a-val", "b" -> "b-val")
+ )
+ assert(graph.flow(fullyQualifiedIdentifier("c", isView = true)).sqlConf == Map("c" -> "c-val"))
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("d"))
+ .sqlConf == Map("c" -> "c-val", "d" -> "d-val")
+ )
+ }
+
+ test("Setting the same conf with the same value is totally cool") {
+ val p = new TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")), Map("key" -> "val"))
+ registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")), Map("key" -> "val"))
+ registerTable(
+ "c",
+ query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM b")),
+ Map("key" -> "val")
+ )
+ }
+ val graph = p.resolveToDataflowGraph()
+ assert(graph.flow(fullyQualifiedIdentifier("c")).sqlConf == Map("key" -> "val"))
+ }
+
+ test("Named query only") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+ registerTable("b")
+ registerFlow("b", "`b-query`", readFlowFunc("a"))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schema = new StructType().add("x", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("a", isView = true), schema)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("b-query"), schema)
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("b-query")).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("Default query and named query") {
+ class P extends TestGraphRegistrationContext(spark) {
+ val mem = MemoryStream[Int]
+ registerView("a", query = dfFlowFunc(mem.toDF()))
+ registerTable("b")
+ registerFlow("b", "b", dfFlowFunc(mem.toDF().select($"value" as "y")))
+ registerFlow("b", "b2", readStreamFlowFunc("a"))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schema = new StructType().add("value", IntegerType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("a", isView = true), schema)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("b2"), schema)
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("b2")).inputs == Set(
+ fullyQualifiedIdentifier("a", isView = true)
+ ),
+ "Flow did not have the expected inputs"
+ )
+ verifyFlowSchema(
+ p,
+ fullyQualifiedIdentifier("b"),
+ new StructType().add("y", IntegerType, false)
+ )
+ assert(
+ p.resolvedFlow(fullyQualifiedIdentifier("b")).inputs == Set.empty,
+ "Flow did not have the expected inputs"
+ )
+ }
+
+ test("Multi-query table with 2 complete queries") {
+ class P extends TestGraphRegistrationContext(spark) {
+ registerTable("a")
+ registerFlow("a", "a", query = dfFlowFunc(spark.range(5).toDF()))
+ registerFlow("a", "a2", query = dfFlowFunc(spark.range(6).toDF()))
+ }
+ val p = new P().resolveToDataflowGraph()
+ val schema = new StructType().add("id", LongType, false)
+ verifyFlowSchema(p, fullyQualifiedIdentifier("a"), schema)
+ }
+
+ test("Correct types of flows after connection") {
+ val graph = new TestGraphRegistrationContext(spark) {
+ val mem = MemoryStream[Int]
+ mem.addData(1, 2)
+ registerView("complete-view", query = dfFlowFunc(Seq(1, 2).toDF("x")))
+ registerView("incremental-view", query = dfFlowFunc(mem.toDF()))
+ registerTable("`complete-table`", query = Option(readFlowFunc("complete-view")))
+ registerTable("`incremental-table`")
+ registerFlow(
+ "`incremental-table`",
+ "`incremental-table`",
+ FlowAnalysis.createFlowFunctionFromLogicalPlan(
+ UnresolvedRelation(
+ TableIdentifier("incremental-view"),
+ extraOptions = CaseInsensitiveStringMap.empty(),
+ isStreaming = true
+ )
+ )
+ )
+ registerFlow(
+ "`incremental-table`",
+ "`append-once`",
+ dfFlowFunc(Seq(1, 2).toDF("x")),
+ once = true
+ )
+ }.resolveToDataflowGraph()
+
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("complete-view", isView = true))
+ .isInstanceOf[CompleteFlow]
+ )
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("incremental-view", isView = true))
+ .isInstanceOf[StreamingFlow]
+ )
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("complete-table"))
+ .isInstanceOf[CompleteFlow]
+ )
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("incremental-table"))
+ .isInstanceOf[StreamingFlow]
+ )
+ assert(
+ graph
+ .flow(fullyQualifiedIdentifier("append-once"))
+ .isInstanceOf[AppendOnceFlow]
+ )
+ }
+
+ test("Pipeline level default spark confs are applied with correct precedence") {
+ val P = new TestGraphRegistrationContext(
+ spark,
+ Map("default.conf" -> "value")
+ ) {
+ registerTable(
+ "a",
+ query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x"))),
+ sqlConf = Map("other.conf" -> "value")
+ )
+ registerTable(
+ "b",
+ query = Option(sqlFlowFunc(spark, "SELECT x as y FROM a")),
+ sqlConf = Map("default.conf" -> "other-value")
+ )
+ }
+ val p = P.resolveToDataflowGraph()
+
+ assert(
+ p.flow(fullyQualifiedIdentifier("a")).sqlConf == Map(
+ "default.conf" -> "value",
+ "other.conf" -> "value"
+ )
+ )
+
+ assert(
+ p.flow(fullyQualifiedIdentifier("b")).sqlConf == Map(
+ "default.conf" -> "other-value"
+ )
+ )
+ }
+
+ /** Verifies the [[DataflowGraph]] has the specified [[Flow]] with the specified schema. */
+ private def verifyFlowSchema(
+ pipeline: DataflowGraph,
+ identifier: TableIdentifier,
+ expected: StructType): Unit = {
+ assert(
+ pipeline.flow.contains(identifier),
+ s"Flow ${identifier.unquotedString} not found," +
+ s" all flow names: ${pipeline.flow.keys.map(_.unquotedString)}"
+ )
+ assert(
+ pipeline.resolvedFlow.contains(identifier),
+ s"Flow ${identifier.unquotedString} has not been resolved"
+ )
+ assert(
+ pipeline.resolvedFlow(identifier).schema == expected,
+ s"Flow ${identifier.unquotedString} has the wrong schema"
+ )
+ }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala
new file mode 100644
index 0000000000000..41d5bbe14a6b1
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.connector.catalog.TableChange
+import org.apache.spark.sql.types._
+
+class SchemaInferenceUtilsSuite extends SparkFunSuite {
+
+ test("determineColumnChanges - adding new columns") {
+ val currentSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType)
+
+ val targetSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType)
+ .add("age", IntegerType)
+ .add("email", StringType, nullable = true, "Email address")
+
+ val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema)
+
+ // Should have 2 changes - adding 'age' and 'email' columns
+ assert(changes.length === 2)
+
+ // Verify the changes are of the correct type and have the right properties
+ val ageChange = changes
+ .find {
+ case addCol: TableChange.AddColumn => addCol.fieldNames().sameElements(Array("age"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.AddColumn]
+
+ val emailChange = changes
+ .find {
+ case addCol: TableChange.AddColumn => addCol.fieldNames().sameElements(Array("email"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.AddColumn]
+
+ // Verify age column properties
+ assert(ageChange.dataType() === IntegerType)
+ assert(ageChange.isNullable() === true) // Default nullable is true
+ assert(ageChange.comment() === null)
+
+ // Verify email column properties
+ assert(emailChange.dataType() === StringType)
+ assert(emailChange.isNullable() === true)
+ assert(emailChange.comment() === "Email address")
+ }
+
+ test("determineColumnChanges - updating column types") {
+ val currentSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("amount", DoubleType)
+ .add("timestamp", TimestampType)
+
+ val targetSchema = new StructType()
+ .add("id", LongType, nullable = false) // Changed type from Int to Long
+ .add("amount", DecimalType(10, 2)) // Changed type from Double to Decimal
+ .add("timestamp", TimestampType) // No change
+
+ val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema)
+
+ // Should have 2 changes - updating 'id' and 'amount' column types
+ assert(changes.length === 2)
+
+ // Verify the changes are of the correct type
+ val idChange = changes
+ .find {
+ case update: TableChange.UpdateColumnType => update.fieldNames().sameElements(Array("id"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.UpdateColumnType]
+
+ val amountChange = changes
+ .find {
+ case update: TableChange.UpdateColumnType =>
+ update.fieldNames().sameElements(Array("amount"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.UpdateColumnType]
+
+ // Verify the new data types
+ assert(idChange.newDataType() === LongType)
+ assert(amountChange.newDataType() === DecimalType(10, 2))
+ }
+
+ test("determineColumnChanges - updating nullability and comments") {
+ val currentSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType, nullable = true)
+ .add("description", StringType, nullable = true, "Item description")
+
+ val targetSchema = new StructType()
+ .add("id", IntegerType, nullable = true) // Changed nullability
+ .add("name", StringType, nullable = false) // Changed nullability
+ .add("description", StringType, nullable = true, "Product description") // Changed comment
+
+ val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema)
+
+ // Should have 3 changes - updating nullability for 'id' and 'name', and comment for
+ // 'description'
+ assert(changes.length === 3)
+
+ // Verify the nullability changes
+ val idNullabilityChange = changes
+ .find {
+ case update: TableChange.UpdateColumnNullability =>
+ update.fieldNames().sameElements(Array("id"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.UpdateColumnNullability]
+
+ val nameNullabilityChange = changes
+ .find {
+ case update: TableChange.UpdateColumnNullability =>
+ update.fieldNames().sameElements(Array("name"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.UpdateColumnNullability]
+
+ // Verify the comment change
+ val descriptionCommentChange = changes
+ .find {
+ case update: TableChange.UpdateColumnComment =>
+ update.fieldNames().sameElements(Array("description"))
+ case _ => false
+ }
+ .get
+ .asInstanceOf[TableChange.UpdateColumnComment]
+
+ // Verify the new nullability values
+ assert(idNullabilityChange.nullable() === true)
+ assert(nameNullabilityChange.nullable() === false)
+
+ // Verify the new comment
+ assert(descriptionCommentChange.newComment() === "Product description")
+ }
+
+ test("determineColumnChanges - complex changes") {
+ val currentSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType)
+ .add("old_field", BooleanType)
+
+ val targetSchema = new StructType()
+ .add("id", LongType, nullable = true) // Changed type and nullability
+ // Added comment and changed nullability
+ .add("name", StringType, nullable = false, "Full name")
+ .add("new_field", StringType) // New field
+
+ val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema)
+
+ // Should have these changes:
+ // 1. Update id type
+ // 2. Update id nullability
+ // 3. Update name nullability
+ // 4. Update name comment
+ // 5. Add new_field
+ // 6. Remove old_field
+ assert(changes.length === 6)
+
+ // Count the types of changes
+ val typeChanges = changes.collect { case _: TableChange.UpdateColumnType => 1 }.size
+ val nullabilityChanges = changes.collect {
+ case _: TableChange.UpdateColumnNullability => 1
+ }.size
+ val commentChanges = changes.collect { case _: TableChange.UpdateColumnComment => 1 }.size
+ val addColumnChanges = changes.collect { case _: TableChange.AddColumn => 1 }.size
+
+ assert(typeChanges === 1)
+ assert(nullabilityChanges === 2)
+ assert(commentChanges === 1)
+ assert(addColumnChanges === 1)
+ }
+
+ test("determineColumnChanges - no changes") {
+ val schema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType)
+ .add("timestamp", TimestampType)
+
+ // Same schema, no changes expected
+ val changes = SchemaInferenceUtils.diffSchemas(schema, schema)
+ assert(changes.isEmpty)
+ }
+
+ test("determineColumnChanges - deleting columns") {
+ val currentSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType)
+ .add("age", IntegerType)
+ .add("email", StringType)
+ .add("phone", StringType)
+
+ val targetSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType)
+ // age, email, and phone columns are removed
+
+ val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema)
+
+ // Should have 3 changes - deleting 'age', 'email', and 'phone' columns
+ assert(changes.length === 3)
+
+ // Verify all changes are DeleteColumn operations
+ val deleteChanges = changes.collect { case dc: TableChange.DeleteColumn => dc }
+ assert(deleteChanges.length === 3)
+
+ // Verify the specific columns being deleted
+ val columnNames = deleteChanges.map(_.fieldNames()(0)).toSet
+ assert(columnNames === Set("age", "email", "phone"))
+ }
+
+ test("determineColumnChanges - mixed additions and deletions") {
+ val currentSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("first_name", StringType)
+ .add("last_name", StringType)
+ .add("age", IntegerType)
+
+ val targetSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("full_name", StringType) // New column
+ .add("email", StringType) // New column
+ .add("age", IntegerType) // Unchanged
+ // first_name and last_name are removed
+
+ val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema)
+
+ // Should have 4 changes:
+ // - 2 additions (full_name, email)
+ // - 2 deletions (first_name, last_name)
+ assert(changes.length === 4)
+
+ // Count the types of changes
+ val addChanges = changes.collect { case ac: TableChange.AddColumn => ac }
+ val deleteChanges = changes.collect { case dc: TableChange.DeleteColumn => dc }
+
+ assert(addChanges.length === 2)
+ assert(deleteChanges.length === 2)
+
+ // Verify the specific columns being added and deleted
+ val addedColumnNames = addChanges.map(_.fieldNames()(0)).toSet
+ val deletedColumnNames = deleteChanges.map(_.fieldNames()(0)).toSet
+
+ assert(addedColumnNames === Set("full_name", "email"))
+ assert(deletedColumnNames === Set("first_name", "last_name"))
+ }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
new file mode 100644
index 0000000000000..981fa3cdcae85
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.utils
+
+import java.io.{BufferedReader, FileNotFoundException, InputStreamReader}
+import java.nio.file.Files
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Try}
+import scala.util.control.NonFatal
+
+import org.scalactic.source
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag}
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, QueryTest, Row, TypedColumn}
+import org.apache.spark.sql.SparkSession.{clearActiveSession, setActiveSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession, SQLContext}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, createTempDir}
+
+abstract class PipelineTest
+ extends SparkFunSuite
+ with BeforeAndAfterAll
+ with BeforeAndAfterEach
+ with Matchers
+ with SparkErrorTestMixin
+ with TargetCatalogAndDatabaseMixin
+ with Logging {
+
+ final protected val storageRoot = createTempDir()
+
+ var spark: SparkSession = createAndInitializeSpark()
+ val originalSpark: SparkSession = spark.cloneSession()
+
+ implicit def sqlContext: SQLContext = spark.sqlContext
+ def sql(text: String): DataFrame = spark.sql(text)
+
+ /**
+   * Spark confs for [[originalSpark]]. Confs set here become the default Spark confs for all
+   * Spark sessions created in tests.
+ */
+ protected def sparkConf: SparkConf = {
+ new SparkConf()
+ .set("spark.sql.shuffle.partitions", "2")
+ .set("spark.sql.session.timeZone", "UTC")
+ }
+
+ /** Returns the dataset name in the event log. */
+ protected def eventLogName(
+ name: String,
+ catalog: Option[String] = catalogInPipelineSpec,
+ database: Option[String] = databaseInPipelineSpec,
+ isView: Boolean = false
+ ): String = {
+ fullyQualifiedIdentifier(name, catalog, database, isView).unquotedString
+ }
+
+ /** Returns the fully qualified identifier. */
+ protected def fullyQualifiedIdentifier(
+ name: String,
+ catalog: Option[String] = catalogInPipelineSpec,
+ database: Option[String] = databaseInPipelineSpec,
+ isView: Boolean = false
+ ): TableIdentifier = {
+ if (isView) {
+ TableIdentifier(name)
+ } else {
+ TableIdentifier(
+ catalog = catalog,
+ database = database,
+ table = name
+ )
+ }
+ }
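+
+  // Illustrative sketch (values taken from the defaults above, no additional API assumed): with
+  // the default catalog/database this resolves as, e.g.,
+  //   fullyQualifiedIdentifier("t") ==
+  //     TableIdentifier(table = "t", database = Some("test_db"), catalog = Some("spark_catalog"))
+  // while views stay single-part:
+  //   fullyQualifiedIdentifier("v", isView = true) == TableIdentifier("v")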
+
+ /**
+ * This exists temporarily for compatibility with tests that become invalid when multiple
+ * executors are available.
+ */
+ protected def master = "local[*]"
+
+  /** Creates and returns an initialized Spark session. */
+ def createAndInitializeSpark(): SparkSession = {
+ val newSparkSession = SparkSession
+ .builder()
+ .config(sparkConf)
+ .master(master)
+ .getOrCreate()
+ newSparkSession
+ }
+
+  /** Sets up the Spark session before each test. */
+ protected def initializeSparkBeforeEachTest(): Unit = {
+ clearActiveSession()
+ spark = originalSpark.newSession()
+ setActiveSession(spark)
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ initializeSparkBeforeEachTest()
+ cleanupMetastore(spark)
+ (catalogInPipelineSpec, databaseInPipelineSpec) match {
+ case (Some(catalog), Some(schema)) =>
+ sql(s"CREATE DATABASE IF NOT EXISTS `$catalog`.`$schema`")
+ case _ =>
+ databaseInPipelineSpec.foreach(s => sql(s"CREATE DATABASE IF NOT EXISTS `$s`"))
+ }
+ }
+
+ override def afterEach(): Unit = {
+ cleanupMetastore(spark)
+ super.afterEach()
+ }
+
+ override def afterAll(): Unit = {
+ spark.stop()
+ }
+
+ protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])(
+ testFun: A => Unit): Unit = {
+ namedGridTest(testNamePrefix, testTags: _*)(params.map(a => a.toString -> a).toMap)(testFun)
+ }
+
+ override def test(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */ )(
+ implicit pos: source.Position): Unit = super.test(testName, testTags: _*) {
+ runWithInstrumentation(testFun)
+ }
+
+ /**
+ * Adds custom instrumentation for tests.
+ *
+   * This instrumentation runs after `beforeEach` and before `afterEach`, which lets us inspect
+   * the state of a test and its environment after any setup and before any clean-up is done for
+   * the test.
+ */
+ private def runWithInstrumentation(testFunc: => Any): Any = {
+ testFunc
+ }
+
+ /**
+ * Creates individual tests for all items in [[params]].
+ *
+   * The full test name will be "<testNamePrefix> (<paramName> = <param>)", where <param> is
+   * one item in [[params]].
+ *
+ * @param testNamePrefix The test name prefix.
+ * @param paramName A descriptive name for the parameter.
+ * @param testTags Extra tags for the test.
+ * @param params The list of parameters for which to generate tests.
+ * @param testFun The actual test function. This function will be called with one argument of
+ * type [[A]].
+ * @tparam A The type of the params.
+ */
+ protected def gridTest[A](testNamePrefix: String, paramName: String, testTags: Tag*)(
+ params: Seq[A])(testFun: A => Unit): Unit =
+ namedGridTest(testNamePrefix, testTags: _*)(
+ params.map(a => s"$paramName = $a" -> a).toMap
+ )(testFun)
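+
+  // Illustrative sketch (hypothetical values, not part of this suite): each entry in the Seq
+  // produces its own test, named e.g. "read data (format = parquet)".
+  //
+  //   gridTest("read data", "format")(Seq("parquet", "json", "orc")) { format =>
+  //     assert(Seq("parquet", "json", "orc").contains(format))
+  //   }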
+
+ /**
+   * Specialized version of gridTest where the params are the two boolean values `true`
+   * and `false`.
+ */
+ protected def booleanGridTest(testNamePrefix: String, paramName: String, testTags: Tag*)(
+ testFun: Boolean => Unit): Unit = {
+ gridTest(testNamePrefix, paramName, testTags: _*)(Seq(true, false))(testFun)
+ }
+
+ protected def namedGridTest[A](testNamePrefix: String, testTags: Tag*)(params: Map[String, A])(
+ testFun: A => Unit): Unit = {
+ for (param <- params) {
+ test(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2))
+ }
+ }
+
+ protected def namedGridIgnore[A](testNamePrefix: String, testTags: Tag*)(params: Map[String, A])(
+ testFun: A => Unit): Unit = {
+ for (param <- params) {
+ ignore(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2))
+ }
+ }
+
+  /** Loads a package resource as a Seq of lines. */
+ protected def loadResource(path: String): Seq[String] = {
+ val stream = Thread.currentThread.getContextClassLoader.getResourceAsStream(path)
+ if (stream == null) {
+ throw new FileNotFoundException(path)
+ }
+ val reader = new BufferedReader(new InputStreamReader(stream))
+ val data = new ArrayBuffer[String]
+ var line = reader.readLine()
+ while (line != null) {
+ data.append(line)
+ line = reader.readLine()
+ }
+ data.toSeq
+ }
+
+ private def checkAnswerAndPlan(
+ df: => DataFrame,
+ expectedAnswer: Seq[Row],
+ checkPlan: Option[SparkPlan => Unit]): Unit = {
+ QueryTest.checkAnswer(df, expectedAnswer)
+
+    // To help with test development, you can dump the plan to the log by setting the
+    // environment variable DUMP_PLAN=true when running the test.
+ if (Option(System.getenv("DUMP_PLAN")).exists(s => java.lang.Boolean.valueOf(s))) {
+ log.info(s"Spark plan:\n${df.queryExecution.executedPlan}")
+ }
+ checkPlan.foreach(_.apply(df.queryExecution.executedPlan))
+ }
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ *
+ * @param df the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ checkAnswerAndPlan(df, expectedAnswer, None)
+ }
+
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(df, Seq(expectedAnswer))
+ }
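+
+  // Illustrative sketch (hypothetical query, not part of this suite):
+  //   checkAnswer(spark.sql("SELECT 1 AS x"), Row(1))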
+
+ case class ValidationArgs(
+ ignoreFieldOrder: Boolean = false,
+ ignoreFieldCase: Boolean = false
+ )
+
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer.
+ */
+ protected def checkDataset[T](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ val result = getResult(ds)
+
+ if (!QueryTest.compare(result.toSeq, expectedAnswer)) {
+ fail(s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ """.stripMargin)
+ }
+ }
+
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer, after sorting.
+ */
+ protected def checkDatasetUnorderly[T: Ordering](result: Array[T], expectedAnswer: T*): Unit = {
+ if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+ fail(s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ """.stripMargin)
+ }
+ }
+
+ protected def checkDatasetUnorderly[T: Ordering](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ val result = getResult(ds)
+ if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+ fail(s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ """.stripMargin)
+ }
+ }
+
+ private def getResult[T](ds: => Dataset[T]): Array[T] = {
+ ds
+
+ try ds.collect()
+ catch {
+ case NonFatal(e) =>
+ fail(
+ s"""
+ |Exception collecting dataset as objects
+ |${ds.queryExecution}
+ """.stripMargin,
+ e
+ )
+ }
+ }
+
+  /** Holds the parsed rows of a test along with their original JSON. */
+ private case class TestSequence(json: Seq[String], rows: Seq[Row]) {
+ require(json.size == rows.size)
+ }
+
+  /**
+   * Helper method to verify an unresolved column error message. We expect three elements to be
+   * present in the message: the error class, the unresolved column name, and the list of
+   * suggested columns. There are three significant differences between different versions of DBR:
+   * - The error class changed in DBR 11.3 from `MISSING_COLUMN` to
+   *   `UNRESOLVED_COLUMN.WITH_SUGGESTION`.
+   * - Name parts in suggested columns are escaped with backticks starting from DBR 11.3,
+   *   e.g. table.column => `table`.`column`.
+   * - Starting from DBR 13.1, the qualification of suggested columns matches the unresolved
+   *   column, i.e. if the unresolved column is a single-part identifier then the suggested
+   *   columns will be as well. E.g. for unresolved column `x`, suggested columns will omit the
+   *   catalog/schema or `LIVE` qualifier. For this reason we verify only the last part of each
+   *   suggested column name.
+   */
+ protected def verifyUnresolveColumnError(
+ errorMessage: String,
+ unresolved: String,
+ suggested: Seq[String]): Unit = {
+ assert(errorMessage.contains(unresolved))
+ assert(
+ errorMessage.contains("[UNRESOLVED_COLUMN.WITH_SUGGESTION]") ||
+ errorMessage.contains("[MISSING_COLUMN]")
+ )
+ suggested.foreach { x =>
+ if (errorMessage.contains("[UNRESOLVED_COLUMN.WITH_SUGGESTION]")) {
+ assert(errorMessage.contains(s"`$x`"))
+ } else {
+ assert(errorMessage.contains(x))
+ }
+ }
+ }
+
+ /** Evaluates the given column and returns the result. */
+ def eval(column: Column): Any = {
+ spark.range(1).select(column).collect().head.get(0)
+ }
+
+  /** Evaluates a column as part of a query and returns the result. */
+ def eval[T](col: TypedColumn[Any, T]): T = {
+ spark.range(1).select(col).head()
+ }
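+
+  // Illustrative sketch (hypothetical expression, not part of this trait): `eval` evaluates a
+  // Column against a single-row DataFrame, e.g.
+  //
+  //   import org.apache.spark.sql.functions.lit
+  //   eval(lit(1) + lit(2)) // returns 3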
+}
+
+/**
+ * A trait that provides a way to specify the target catalog and schema for a test.
+ */
+trait TargetCatalogAndDatabaseMixin {
+
+ protected def catalogInPipelineSpec: Option[String] = Option(
+ TestGraphRegistrationContext.DEFAULT_CATALOG
+ )
+
+ protected def databaseInPipelineSpec: Option[String] = Option(
+ TestGraphRegistrationContext.DEFAULT_DATABASE
+ )
+}
+
+object PipelineTest extends Logging {
+
+  /** Per-catalog system schemas that cannot be directly dropped. */
+ protected val systemDatabases: Set[String] = Set("default", "information_schema")
+
+ /** System catalogs that are read-only and cannot be modified/dropped. */
+ private val systemCatalogs: Set[String] = Set("samples")
+
+  /** Catalogs that cannot be dropped but whose schemas or tables can be cleaned up. */
+ private val undroppableCatalogs: Set[String] = Set(
+ "hive_metastore",
+ "spark_catalog",
+ "system",
+ "main"
+ )
+
+ /** Creates a temporary directory. */
+ protected def createTempDir(): String = {
+ Files.createTempDirectory(getClass.getSimpleName).normalize.toString
+ }
+
+ /**
+   * Tries to drop the database in the catalog and returns whether it was successfully dropped.
+ */
+ private def dropDatabaseIfPossible(
+ spark: SparkSession,
+ catalogName: String,
+ databaseName: String): Boolean = {
+ try {
+ spark.sql(s"DROP DATABASE IF EXISTS `$catalogName`.`$databaseName` CASCADE")
+ true
+ } catch {
+ case NonFatal(e) =>
+ logInfo(
+ s"Failed to drop database $databaseName in catalog $catalogName, ex:${e.getMessage}"
+ )
+ false
+ }
+ }
+
+  /** Cleans up resources created in the metastore by tests. */
+ def cleanupMetastore(spark: SparkSession): Unit = synchronized {
+    // Some tests stop the Spark session and manage the cleanup themselves, so there is no need
+    // to clean up if no active Spark session is found.
+ if (spark.sparkContext.isStopped) {
+ return
+ }
+ val catalogs =
+ spark.sql(s"SHOW CATALOGS").collect().map(_.getString(0)).filterNot(systemCatalogs.contains)
+ catalogs.foreach { catalog =>
+ if (undroppableCatalogs.contains(catalog)) {
+ val schemas =
+ spark.sql(s"SHOW SCHEMAS IN `$catalog`").collect().map(_.getString(0))
+ schemas.foreach { schema =>
+ if (systemDatabases.contains(schema) || !dropDatabaseIfPossible(spark, catalog, schema)) {
+ spark
+ .sql(s"SHOW tables in `$catalog`.`$schema`")
+ .collect()
+ .map(_.getString(0))
+ .foreach { table =>
+ Try(spark.sql(s"DROP table IF EXISTS `$catalog`.`$schema`.`$table`")) match {
+ case Failure(e) =>
+ logInfo(
+ s"Failed to drop table $table in schema $schema in catalog $catalog, " +
+ s"ex:${e.getMessage}"
+ )
+ case _ =>
+ }
+ }
+ }
+ }
+ } else {
+ Try(spark.sql(s"DROP CATALOG IF EXISTS `$catalog` CASCADE")) match {
+ case Failure(e) =>
+ logInfo(s"Failed to drop catalog $catalog, ex:${e.getMessage}")
+ case _ =>
+ }
+ }
+ }
+ spark.sessionState.catalog.reset()
+ }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala
new file mode 100644
index 0000000000000..d891d6529de96
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.utils
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.AnalysisException
+
+/**
+ * Collection of helper methods to simplify working with exceptions in tests.
+ */
+trait SparkErrorTestMixin {
+ this: SparkFunSuite =>
+
+ /**
+   * Asserts that the given exception is an [[AnalysisException]] with the specific error class.
+ */
+ def assertAnalysisException(ex: Throwable, errorClass: String): Unit = {
+ ex match {
+ case t: AnalysisException =>
+ assert(
+ t.getCondition == errorClass,
+ s"Expected analysis exception with error class $errorClass, but got ${t.getCondition}"
+ )
+ case _ => fail(s"Expected analysis exception but got ${ex.getClass}")
+ }
+ }
+
+ /**
+   * Asserts that the given exception is an [[AnalysisException]] with the specific error class
+ * and metadata.
+ */
+ def assertAnalysisException(
+ ex: Throwable,
+ errorClass: String,
+ metadata: Map[String, String]
+ ): Unit = {
+ ex match {
+ case t: AnalysisException =>
+ assert(
+ t.getCondition == errorClass,
+ s"Expected analysis exception with error class $errorClass, but got ${t.getCondition}"
+ )
+ assert(
+ t.getMessageParameters.asScala.toMap == metadata,
+ s"Expected analysis exception with metadata $metadata, but got " +
+ s"${t.getMessageParameters}"
+ )
+ case _ => fail(s"Expected analysis exception but got ${ex.getClass}")
+ }
+ }
+
+ /**
+ * Asserts that the given exception is a [[SparkException]] with the specific error class.
+ */
+ def assertSparkException(ex: Throwable, errorClass: String): Unit = {
+ ex match {
+ case t: SparkException =>
+ assert(
+ t.getCondition == errorClass,
+ s"Expected spark exception with error class $errorClass, but got ${t.getCondition}"
+ )
+ case _ => fail(s"Expected spark exception but got ${ex.getClass}")
+ }
+ }
+
+ /**
+ * Asserts that the given exception is a [[SparkException]] with the specific error class
+ * and metadata.
+ */
+ def assertSparkException(
+ ex: Throwable,
+ errorClass: String,
+ metadata: Map[String, String]
+ ): Unit = {
+ ex match {
+ case t: SparkException =>
+ assert(
+ t.getCondition == errorClass,
+ s"Expected spark exception with error class $errorClass, but got ${t.getCondition}"
+ )
+ assert(
+ t.getMessageParameters.asScala.toMap == metadata,
+ s"Expected spark exception with metadata $metadata, but got ${t.getMessageParameters}"
+ )
+ case _ => fail(s"Expected spark exception but got ${ex.getClass}")
+ }
+ }
+}
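+
+// Illustrative usage sketch (hypothetical test body; the error condition is shown only for
+// illustration and is not asserted anywhere in this trait):
+//
+//   val ex = intercept[AnalysisException] {
+//     spark.sql("SELECT missing_col FROM VALUES (1) AS t(x)")
+//   }
+//   assertAnalysisException(ex, "UNRESOLVED_COLUMN.WITH_SUGGESTION")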
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
new file mode 100644
index 0000000000000..3449c5155c754
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.utils
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{LocalTempView, UnresolvedRelation, ViewType}
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
+import org.apache.spark.sql.pipelines.graph.{
+ DataflowGraph,
+ FlowAnalysis,
+ FlowFunction,
+ GraphIdentifierManager,
+ GraphRegistrationContext,
+ PersistedView,
+ QueryContext,
+ QueryOrigin,
+ Table,
+ TemporaryView,
+ UnresolvedFlow
+}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * A test class to simplify the creation of pipelines and datasets for unit testing.
+ */
+class TestGraphRegistrationContext(
+ val spark: SparkSession,
+ val sqlConf: Map[String, String] = Map.empty)
+ extends GraphRegistrationContext(
+ defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG,
+ defaultDatabase = TestGraphRegistrationContext.DEFAULT_DATABASE,
+ defaultSqlConf = sqlConf
+ ) {
+
+ // scalastyle:off
+ // Disable scalastyle to ignore argument count.
+ def registerTable(
+ name: String,
+ query: Option[FlowFunction] = None,
+ sqlConf: Map[String, String] = Map.empty,
+ comment: Option[String] = None,
+ specifiedSchema: Option[StructType] = None,
+ partitionCols: Option[Seq[String]] = None,
+ properties: Map[String, String] = Map.empty,
+ baseOrigin: QueryOrigin = QueryOrigin.empty,
+ format: Option[String] = None,
+ catalog: Option[String] = None,
+ database: Option[String] = None
+ ): Unit = {
+ // scalastyle:on
+ val tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
+ registerTable(
+ Table(
+ identifier = GraphIdentifierManager.parseTableIdentifier(name, spark),
+ comment = comment,
+ specifiedSchema = specifiedSchema,
+ partitionCols = partitionCols,
+ properties = properties,
+ baseOrigin = baseOrigin,
+ format = format.orElse(Some("parquet")),
+ normalizedPath = None,
+ isStreamingTableOpt = None
+ )
+ )
+
+ if (query.isDefined) {
+ registerFlow(
+ new UnresolvedFlow(
+ identifier = tableIdentifier,
+ destinationIdentifier = tableIdentifier,
+ func = query.get,
+ queryContext = QueryContext(
+ currentCatalog = catalog.orElse(Some(defaultCatalog)),
+ currentDatabase = database.orElse(Some(defaultDatabase))
+ ),
+ sqlConf = sqlConf,
+ once = false,
+ comment = comment,
+ origin = baseOrigin
+ )
+ )
+ }
+ }
+
+ def registerView(
+ name: String,
+ query: FlowFunction,
+ sqlConf: Map[String, String] = Map.empty,
+ comment: Option[String] = None,
+ origin: QueryOrigin = QueryOrigin.empty,
+ viewType: ViewType = LocalTempView,
+ catalog: Option[String] = None,
+ database: Option[String] = None
+ ): Unit = {
+
+ val viewIdentifier = GraphIdentifierManager
+ .parseAndValidateTemporaryViewIdentifier(rawViewIdentifier = TableIdentifier(name))
+
+ registerView(
+ viewType match {
+ case LocalTempView =>
+ TemporaryView(
+ identifier = viewIdentifier,
+ comment = comment,
+ origin = origin,
+ properties = Map.empty
+ )
+ case _ =>
+ PersistedView(
+ identifier = viewIdentifier,
+ comment = comment,
+ origin = origin,
+ properties = Map.empty
+ )
+ }
+ )
+
+ registerFlow(
+ new UnresolvedFlow(
+ identifier = viewIdentifier,
+ destinationIdentifier = viewIdentifier,
+ func = query,
+ queryContext = QueryContext(
+ currentCatalog = catalog.orElse(Some(defaultCatalog)),
+ currentDatabase = database.orElse(Some(defaultDatabase))
+ ),
+ sqlConf = sqlConf,
+ once = false,
+ comment = comment,
+ origin = origin
+ )
+ )
+ }
+
+ def registerFlow(
+ destinationName: String,
+ name: String,
+ query: FlowFunction,
+ once: Boolean = false,
+ catalog: Option[String] = None,
+ database: Option[String] = None
+ ): Unit = {
+ val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
+ val flowDestinationIdentifier =
+ GraphIdentifierManager.parseTableIdentifier(destinationName, spark)
+
+ registerFlow(
+ new UnresolvedFlow(
+ identifier = flowIdentifier,
+ destinationIdentifier = flowDestinationIdentifier,
+ func = query,
+ queryContext = QueryContext(
+ currentCatalog = catalog.orElse(Some(defaultCatalog)),
+ currentDatabase = database.orElse(Some(defaultDatabase))
+ ),
+ sqlConf = Map.empty,
+ once = once,
+ comment = None,
+ origin = QueryOrigin()
+ )
+ )
+ }
+
+ /**
+ * Creates a flow function from a logical plan that reads from a table with the given name.
+ */
+ def readFlowFunc(name: String): FlowFunction = {
+ FlowAnalysis.createFlowFunctionFromLogicalPlan(UnresolvedRelation(TableIdentifier(name)))
+ }
+
+ /**
+ * Creates a flow function from a logical plan that reads a stream from a table with the given
+ * name.
+ */
+ def readStreamFlowFunc(name: String): FlowFunction = {
+ FlowAnalysis.createFlowFunctionFromLogicalPlan(
+ UnresolvedRelation(
+ TableIdentifier(name),
+ extraOptions = CaseInsensitiveStringMap.empty(),
+ isStreaming = true
+ )
+ )
+ }
+
+ /**
+ * Creates a flow function from a logical plan parsed from the given SQL text.
+ */
+ def sqlFlowFunc(spark: SparkSession, sql: String): FlowFunction = {
+ FlowAnalysis.createFlowFunctionFromLogicalPlan(spark.sessionState.sqlParser.parsePlan(sql))
+ }
+
+ /**
+   * Creates a flow function from the logical plan of the given DataFrame. This is meant for
+ * DataFrames that don't read from tables within the pipeline.
+ */
+ def dfFlowFunc(df: DataFrame): FlowFunction = {
+ FlowAnalysis.createFlowFunctionFromLogicalPlan(df.logicalPlan)
+ }
+
+ /**
+ * Generates a dataflow graph from this pipeline definition and resolves it.
+   * @return the resolved [[DataflowGraph]].
+ */
+ def resolveToDataflowGraph(): DataflowGraph = toDataflowGraph.resolve()
+}
+
+object TestGraphRegistrationContext {
+ val DEFAULT_CATALOG = "spark_catalog"
+ val DEFAULT_DATABASE = "test_db"
+}
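+
+// Illustrative sketch mirroring the suites above (no additional API assumed): register a view
+// and a table that reads from it, then resolve the dataflow graph.
+//
+//   val graph = new TestGraphRegistrationContext(spark) {
+//     registerView("src", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
+//     registerTable("dest", query = Option(readFlowFunc("src")))
+//   }.resolveToDataflowGraph()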