diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8dddca5077a67..993ffd888e0e7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -82,6 +82,14 @@ ], "sqlState" : "XX000" }, + "APPEND_ONCE_FROM_BATCH_QUERY" : { + "message" : [ + "Creating a streaming table from a batch query prevents incremental loading of new data from source. Offending table: ''.", + "Please use the stream() operator. Example usage:", + "CREATE STREAMING TABLE ... AS SELECT ... FROM stream() ..." + ], + "sqlState" : "42000" + }, "ARITHMETIC_OVERFLOW" : { "message" : [ ". If necessary set to \"false\" to bypass this error." @@ -1372,6 +1380,12 @@ }, "sqlState" : "42734" }, + "DUPLICATE_FLOW_SQL_CONF" : { + "message" : [ + "Found duplicate sql conf for dataset '': '' is defined by both '' and ''" + ], + "sqlState" : "42710" + }, "DUPLICATE_KEY" : { "message" : [ "Found duplicate keys ." @@ -1943,6 +1957,12 @@ ], "sqlState" : "42818" }, + "INCOMPATIBLE_BATCH_VIEW_READ" : { + "message" : [ + "View is a batch view and must be referenced using SparkSession#read. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + ], + "sqlState" : "42000" + }, "INCOMPATIBLE_COLUMN_TYPE" : { "message" : [ " can only be performed on tables with compatible column types. The column of the table is type which is not compatible with at the same column of the first table.." @@ -2019,6 +2039,12 @@ ], "sqlState" : "42613" }, + "INCOMPATIBLE_STREAMING_VIEW_READ" : { + "message" : [ + "View is a streaming view and must be referenced using SparkSession#readStream. This check can be disabled by setting Spark conf pipelines.incompatibleViewCheck.enabled = false." + ], + "sqlState" : "42000" + }, "INCOMPATIBLE_VIEW_SCHEMA_CHANGE" : { "message" : [ "The SQL query of view has an incompatible schema change and column cannot be resolved. Expected columns named but got .", @@ -3119,6 +3145,12 @@ }, "sqlState" : "KD002" }, + "INVALID_NAME_IN_USE_COMMAND" : { + "message" : [ + "Invalid name '' in command. Reason: " + ], + "sqlState" : "42000" + }, "INVALID_NON_DETERMINISTIC_EXPRESSIONS" : { "message" : [ "The operator expects a deterministic expression, but the actual expression is ." @@ -3384,6 +3416,12 @@ ], "sqlState" : "22023" }, + "INVALID_RESETTABLE_DEPENDENCY" : { + "message" : [ + "Tables are resettable but have a non-resettable downstream dependency ''. `reset` will fail as Spark Streaming does not support deleted source data. You can either remove the =false property from '' or add it to its upstream dependencies." + ], + "sqlState" : "42000" + }, "INVALID_RESET_COMMAND_FORMAT" : { "message" : [ "Expected format is 'RESET' or 'RESET key'. If you want to include special characters in key, please use quotes, e.g., RESET `key`." @@ -5419,6 +5457,19 @@ ], "sqlState" : "58030" }, + "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA" : { + "message" : [ + "Failed to infer the schema for table from its upstream flows.", + "Please modify the flows that write to this table to make their schemas compatible.", + "", + "Inferred schema so far:", + "", + "", + "Incompatible schema:", + "" + ], + "sqlState" : "42KD9" + }, "UNABLE_TO_INFER_SCHEMA" : { "message" : [ "Unable to infer schema for . It must be specified manually." 
@@ -5590,6 +5641,12 @@ ], "sqlState" : "42883" }, + "UNRESOLVED_TABLE_PATH" : { + "message" : [ + "Storage path for table cannot be resolved." + ], + "sqlState" : "22KD1" + }, "UNRESOLVED_USING_COLUMN_FOR_JOIN" : { "message" : [ "USING column cannot be resolved on the side of the join. The -side columns: []." @@ -6571,6 +6628,20 @@ ], "sqlState" : "P0001" }, + "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE" : { + "message" : [ + "Table '' has a user-specified schema that is incompatible with the schema", + "inferred from its query.", + "", + "", + "Declared schema:", + "", + "", + "Inferred schema:", + "" + ], + "sqlState" : "42000" + }, "VARIABLE_ALREADY_EXISTS" : { "message" : [ "Cannot create the variable because it already exists.", diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index 7d796a83af69d..a04993299ce7c 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -16,7 +16,8 @@ ~ limitations under the License. --> - + 4.0.0 org.apache.spark @@ -44,11 +45,76 @@ org.apache.spark - spark-core_${scala.binary.version} + spark-sql_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} ${project.version} test-jar test + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + + + + org.scala-lang.modules + scala-parallel-collections_${scala.binary.version} + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala new file mode 100644 index 0000000000000..35b8185c255e1 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines + +/** Represents a warning generated as part of graph analysis. */ +sealed trait AnalysisWarning + +object AnalysisWarning { + + /** + * Warning that some streaming reader options are being dropped + * + * @param sourceName Source for which reader options are being dropped. + * @param droppedOptions Set of reader options that are being dropped for a specific source. 
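+   *
+   * Example construction (the values shown are illustrative only):
+   * {{{
+   *   AnalysisWarning.StreamingReaderOptionsDropped(
+   *     sourceName = "events", droppedOptions = Seq("maxFilesPerTrigger"))
+   * }}}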
+ */ + case class StreamingReaderOptionsDropped(sourceName: String, droppedOptions: Seq[String]) + extends AnalysisWarning +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala new file mode 100644 index 0000000000000..c627850b667be --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/Language.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines + +sealed trait Language {} + +object Language { + case class Python() extends Language {} + case class Sql() extends Language {} +} + diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala new file mode 100644 index 0000000000000..d33924c2e1c37 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ + TransformNodeFailedException, + TransformNodeRetryableException +} + +/** + * Processor that is responsible for analyzing each flow and sort the nodes in + * topological order + */ +class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { + + private val flowResolver = new FlowResolver(rawGraph) + + // Map of input identifier to resolved [[Input]]. 
+ private val resolvedInputs = new ConcurrentHashMap[TableIdentifier, Input]() + // Map & queue of resolved flows identifiers + // queue is there to track the topological order while map is used to store the id -> flow + // mapping + private val resolvedFlowNodesMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]() + private val resolvedFlowNodesQueue = new ConcurrentLinkedQueue[ResolvedFlow]() + + private def processUnresolvedFlow(flow: UnresolvedFlow): ResolvedFlow = { + val resolvedFlow = flowResolver.attemptResolveFlow( + flow, + rawGraph.inputIdentifiers, + resolvedInputs.asScala.toMap + ) + resolvedFlowNodesQueue.add(resolvedFlow) + resolvedFlowNodesMap.put(flow.identifier, resolvedFlow) + resolvedFlow + } + + /** + * Processes the node of the graph, re-arranging them if they are not topologically sorted. + * Takes care of resolving the flows and virtualizing tables (i.e. removing tables to + * ensure resolution is internally consistent) if needed for the nodes. + * @param node The node to process + * @param upstreamNodes Upstream nodes for the node + * @return The resolved nodes generated by processing this element. + */ + def processNode(node: GraphElement, upstreamNodes: Seq[GraphElement]): Seq[GraphElement] = { + node match { + case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow)) + case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow)) + case table: Table => + // Ensure all upstreamNodes for a table are flows + val flowsToTable = upstreamNodes.map { + case flow: Flow => flow + case _ => + throw new IllegalArgumentException( + s"Unsupported upstream node type for table ${table.displayName}: " + + s"${upstreamNodes.getClass}" + ) + } + val resolvedFlowsToTable = flowsToTable.map { flow => + resolvedFlowNodesMap.get(flow.identifier) + } + + // Assign isStreamingTable (MV or ST) to the table based on the resolvedFlowsToTable + val tableWithType = table.copy( + isStreamingTableOpt = Option(resolvedFlowsToTable.exists(f => f.df.isStreaming)) + ) + + // We mark all tables as virtual to ensure resolution uses incoming flows + // rather than previously materialized tables. + val virtualTableInput = VirtualTableInput( + identifier = table.identifier, + specifiedSchema = table.specifiedSchema, + incomingFlowIdentifiers = flowsToTable.map(_.identifier).toSet, + availableFlows = resolvedFlowsToTable + ) + resolvedInputs.put(table.identifier, virtualTableInput) + Seq(tableWithType) + case view: View => + // For view, add the flow to resolvedInputs and return empty. + require(upstreamNodes.size == 1, "Found multiple flows to view") + upstreamNodes.head match { + case f: Flow => + resolvedInputs.put(view.identifier, resolvedFlowNodesMap.get(f.destinationIdentifier)) + Seq(view) + case _ => + throw new IllegalArgumentException( + s"Unsupported upstream node type for view ${view.displayName}: " + + s"${upstreamNodes.getClass}" + ) + } + case _ => + throw new IllegalArgumentException(s"Unsupported node type: ${node.getClass}") + } + } +} + +private class FlowResolver(rawGraph: DataflowGraph) { + + /** Helper used to track which confs were set by which flows. */ + private case class FlowConf(key: String, value: String, flowIdentifier: TableIdentifier) + + /** Attempts resolving a single flow using the map of resolved inputs. 
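+   *
+   * Illustrative call shape (this mirrors how the processor above invokes it):
+   * {{{
+   *   val resolved: ResolvedFlow = flowResolver.attemptResolveFlow(
+   *     flowToResolve = flow,
+   *     allInputs = rawGraph.inputIdentifiers,
+   *     availableResolvedInputs = resolvedInputs.asScala.toMap)
+   * }}}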
*/ + def attemptResolveFlow( + flowToResolve: UnresolvedFlow, + allInputs: Set[TableIdentifier], + availableResolvedInputs: Map[TableIdentifier, Input]): ResolvedFlow = { + val flowFunctionResult = flowToResolve.func.call( + allInputs = allInputs, + availableInputs = availableResolvedInputs.values.toList, + configuration = flowToResolve.sqlConf, + queryContext = flowToResolve.queryContext + ) + val result = + flowFunctionResult match { + case f if f.dataFrame.isSuccess => + // Merge confs from any upstream views into confs for this flow. + val allFConfs = + (flowToResolve +: + f.inputs.toSeq + .map(availableResolvedInputs(_)) + .filter { + // Input is a flow implies that the upstream table is a View. + case _: Flow => true + // We stop in all other cases. + case _ => false + }).collect { + case g: Flow => + g.sqlConf.toSeq.map { case (k, v) => FlowConf(k, v, g.identifier) } + }.flatten + + allFConfs + .groupBy(_.key) // Key name -> Seq[FlowConf] + .filter(_._2.length > 1) // Entries where key was set more than once + .find(_._2.map(_.value).toSet.size > 1) // Entry where key was set with diff vals + .foreach { + case (key, confs) => + val sortedByVal = confs.sortBy(_.value) + throw new AnalysisException( + "DUPLICATE_FLOW_SQL_CONF", + Map( + "key" -> key, + "datasetName" -> flowToResolve.displayName, + "flowName1" -> sortedByVal.head.flowIdentifier.unquotedString, + "flowName2" -> sortedByVal.last.flowIdentifier.unquotedString + ) + ) + } + + val newSqlConf = allFConfs.map(fc => fc.key -> fc.value).toMap + // if the new sql confs are different from the original sql confs the flow was resolved + // with, resolve again. + val maybeNewFuncResult = if (newSqlConf != flowToResolve.sqlConf) { + flowToResolve.func.call( + allInputs = allInputs, + availableInputs = availableResolvedInputs.values.toList, + configuration = newSqlConf, + queryContext = flowToResolve.queryContext + ) + } else { + f + } + convertResolvedToTypedFlow(flowToResolve, maybeNewFuncResult) + + // If the flow failed due to an UnresolvedDatasetException, it means that one of the + // flow's inputs wasn't available. After other flows are resolved, these inputs + // may become available, so throw a retryable exception in this case. + case f => + f.dataFrame.failed.toOption.collectFirst { + case e: UnresolvedDatasetException => e + case _ => None + } match { + case Some(e: UnresolvedDatasetException) => + throw TransformNodeRetryableException( + e.identifier, + new ResolutionFailedFlow(flowToResolve, flowFunctionResult) + ) + case _ => + throw TransformNodeFailedException( + new ResolutionFailedFlow(flowToResolve, flowFunctionResult) + ) + } + } + result + } + + private def convertResolvedToTypedFlow( + flow: UnresolvedFlow, + funcResult: FlowFunctionResult): ResolvedFlow = { + val typedFlow = flow match { + case f: UnresolvedFlow if f.once => new AppendOnceFlow(flow, funcResult) + case f: UnresolvedFlow if funcResult.dataFrame.get.isStreaming => + // If there's more than 1 flow to this flow's destination, we should not allow it + // to be planned with an output mode other than Append, as the other flows will + // then get their results overwritten. 
+ val mustBeAppend = rawGraph.flowsTo(f.destinationIdentifier).size > 1 + new StreamingFlow(flow, funcResult, mustBeAppend = mustBeAppend) + case _: UnresolvedFlow => new CompleteFlow(flow, funcResult) + } + typedFlow + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala new file mode 100644 index 0000000000000..585ba6295f239 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.Try + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique +import org.apache.spark.sql.pipelines.util.SchemaMergingUtils +import org.apache.spark.sql.types.StructType + +/** + * DataflowGraph represents the core graph structure for Spark declarative pipelines. + * It manages the relationships between logical flows, tables, and views, providing + * operations for graph traversal, validation, and transformation. + */ +case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) + extends GraphOperations + with GraphValidations { + + /** Map of [[Output]]s by their identifiers */ + lazy val output: Map[TableIdentifier, Output] = mapUnique(tables, "output")(_.identifier) + + /** + * [[Flow]]s in this graph that need to get planned and potentially executed when + * executing the graph. Flows that write to logical views are excluded. + */ + lazy val materializedFlows: Seq[ResolvedFlow] = { + resolvedFlows.filter( + f => output.contains(f.destinationIdentifier) + ) + } + + /** The identifiers of [[materializedFlows]]. */ + val materializedFlowIdentifiers: Set[TableIdentifier] = materializedFlows.map(_.identifier).toSet + + /** Map of [[Table]]s by their identifiers */ + lazy val table: Map[TableIdentifier, Table] = + mapUnique(tables, "table")(_.identifier) + + /** Map of [[Flow]]s by their identifier */ + lazy val flow: Map[TableIdentifier, Flow] = { + // Better error message than using mapUnique. + val flowsByIdentifier = flows.groupBy(_.identifier) + flowsByIdentifier + .find(_._2.size > 1) + .foreach { + case (flowIdentifier, flows) => + // We don't expect this to ever actually be hit, graph registration should validate for + // unique flow names. 
+ throw new AnalysisException( + errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW", + messageParameters = Map( + "flowName" -> flowIdentifier.unquotedString, + "datasetNames" -> flows.map(_.destinationIdentifier).mkString(",") + ) + ) + } + // Flows with non-default names shouldn't conflict with table names + flows + .filterNot(f => f.identifier == f.destinationIdentifier) + .filter(f => table.contains(f.identifier)) + .foreach { f => + throw new AnalysisException( + "FLOW_NAME_CONFLICTS_WITH_TABLE", + Map( + "flowName" -> f.identifier.toString(), + "target" -> f.destinationIdentifier.toString(), + "tableName" -> f.identifier.toString() + ) + ) + } + flowsByIdentifier.view.mapValues(_.head).toMap + } + + /** Map of [[View]]s by their identifiers */ + lazy val view: Map[TableIdentifier, View] = mapUnique(views, "view")(_.identifier) + + /** The [[PersistedView]]s of the graph */ + lazy val persistedViews: Seq[PersistedView] = views.collect { + case v: PersistedView => v + } + + /** All the [[Input]]s in the current DataflowGraph. */ + lazy val inputIdentifiers: Set[TableIdentifier] = { + (flows ++ tables).map(_.identifier).toSet + } + + /** The [[Flow]]s that write to a given destination. */ + lazy val flowsTo: Map[TableIdentifier, Seq[Flow]] = flows.groupBy(_.destinationIdentifier) + + lazy val resolvedFlows: Seq[ResolvedFlow] = { + flows.collect { case f: ResolvedFlow => f } + } + + lazy val resolvedFlow: Map[TableIdentifier, ResolvedFlow] = { + resolvedFlows.map { f => + f.identifier -> f + }.toMap + } + + lazy val resolutionFailedFlows: Seq[ResolutionFailedFlow] = { + flows.collect { case f: ResolutionFailedFlow => f } + } + + lazy val resolutionFailedFlow: Map[TableIdentifier, ResolutionFailedFlow] = { + resolutionFailedFlows.map { f => + f.identifier -> f + }.toMap + } + + /** + * Used to reanalyze the flow's DF for a given table. This is done by finding all upstream + * flows (until a table is reached) for the specified source and reanalyzing all upstream + * flows. + * + * @param srcFlow The flow that writes into the table that we will start from when finding + * upstream flows + * @return The reanalyzed flow + */ + protected[graph] def reanalyzeFlow(srcFlow: Flow): ResolvedFlow = { + val upstreamDatasetIdentifiers = dfsInternal( + flowNodes(srcFlow.identifier).output, + downstream = false, + stopAtMaterializationPoints = true + ) + val upstreamFlows = + resolvedFlows + .filter(f => upstreamDatasetIdentifiers.contains(f.destinationIdentifier)) + .map(_.flow) + val upstreamViews = upstreamDatasetIdentifiers.flatMap(identifier => view.get(identifier)).toSeq + + val subgraph = new DataflowGraph( + flows = upstreamFlows, + views = upstreamViews, + tables = Seq(table(srcFlow.destinationIdentifier)) + ) + subgraph.resolve().resolvedFlow(srcFlow.identifier) + } + + /** + * A map of the inferred schema of each table, computed by merging the analyzed schemas + * of all flows writing to that table. + */ + lazy val inferredSchema: Map[TableIdentifier, StructType] = { + flowsTo.view.mapValues { flows => + flows + .map { flow => + resolvedFlow(flow.identifier).schema + } + .reduce(SchemaMergingUtils.mergeSchemas) + }.toMap + } + + /** Ensure that the [[DataflowGraph]] is valid and throws errors if not. */ + def validate(): DataflowGraph = { + validationFailure.toOption match { + case Some(exception) => throw exception + case None => this + } + } + + /** + * Validate the current [[DataflowGraph]] and cache the validation failure. 
+ * + * To add more validations, add them in a helper function that throws an exception if the + * validation fails, and invoke the helper function here. + */ + private lazy val validationFailure: Try[Throwable] = Try { + validateSuccessfulFlowAnalysis() + validateUserSpecifiedSchemas() + // Connecting the graph sorts it topologically + validateGraphIsTopologicallySorted() + validateMultiQueryTables() + validatePersistedViewSources() + validateEveryDatasetHasFlow() + validateTablesAreResettable() + inferredSchema + }.failed + + /** + * Enforce every dataset has at least one input flow. For example its possible to define + * streaming tables without a query; such tables should still have at least one flow + * writing to it. + */ + def validateEveryDatasetHasFlow(): Unit = { + (tables.map(_.identifier) ++ views.map(_.identifier)).foreach { identifier => + if (!flows.exists(_.destinationIdentifier == identifier)) { + throw new AnalysisException( + "PIPELINE_DATASET_WITHOUT_FLOW", + Map("identifier" -> identifier.quotedString) + ) + } + } + } + + /** Returns true iff all [[Flow]]s are successfully analyzed. */ + def resolved: Boolean = + flows.forall(f => resolvedFlow.contains(f.identifier)) + + def resolve(): DataflowGraph = + DataflowGraphTransformer.withDataflowGraphTransformer(this) { transformer => + val coreDataflowNodeProcessor = + new CoreDataflowNodeProcessor(rawGraph = this) + transformer + .transformDownNodes(coreDataflowNodeProcessor.processNode) + .getDataflowGraph + } +} + +object DataflowGraph { + protected[graph] def mapUnique[K, A](input: Seq[A], tpe: String)(f: A => K): Map[K, A] = { + val grouped = input.groupBy(f) + grouped.filter(_._2.length > 1).foreach { + case (name, _) => + throw new AnalysisException( + errorClass = "DUPLICATE_GRAPH_ELEMENT", + messageParameters = Map("graphElementType" -> tpe, "graphElementName" -> name.toString) + ) + } + grouped.view.mapValues(_.head).toMap + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala new file mode 100644 index 0000000000000..8448ed5f10d21 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.{ + ConcurrentHashMap, + ConcurrentLinkedDeque, + ConcurrentLinkedQueue, + ExecutionException, + Future +} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.util.control.NoStackTrace + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.util.ThreadUtils + +/** + * Resolves the [[DataflowGraph]] by processing each node in the graph. This class exposes visitor + * functionality to resolve/analyze graph nodes. + * We only expose simple visitor abilities to transform different entities of the + * graph. + * For advanced transformations we also expose a mechanism to walk the graph over entity by entity. + * + * Assumptions: + * 1. Each output will have at-least 1 flow to it. + * 2. Each flow may or may not have a destination table. If a flow does not have a destination + * table, the destination is a temporary view. + * + * The way graph is structured is that flows, tables and sinks all are graph elements or nodes. + * While we expose transformation functions for each of these entities, we also expose a way to + * process to walk over the graph. + * + * Constructor is private as all usages should be via + * DataflowGraphTransformer.withDataflowGraphTransformer. + * @param graph: Any Dataflow Graph + */ +class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { + import DataflowGraphTransformer._ + + private var tables: Seq[Table] = graph.tables + private var tableMap: Map[TableIdentifier, Table] = computeTableMap() + private var flows: Seq[Flow] = graph.flows + private var flowsTo: Map[TableIdentifier, Seq[Flow]] = computeFlowsTo() + private var views: Seq[View] = graph.views + private var viewMap: Map[TableIdentifier, View] = computeViewMap() + + // Fail analysis nodes + // Failed flows are flows that are failed to resolve or its inputs are not available or its + // destination failed to resolve. + private var failedFlows: Seq[ResolutionCompletedFlow] = Seq.empty + // We define a dataset is failed to resolve if it is a destination of a flow that is unresolved. + private var failedTables: Seq[Table] = Seq.empty + + private val parallelism = 10 + + // Executor used to resolve nodes in parallel. It is lazily initialized to avoid creating it + // for scenarios its not required. To track if the lazy val was evaluated or not we use a + // separate variable so we know if we need to shutdown the executor or not. 
+ private var fixedPoolExecutorInitialized = false + lazy private val fixedPoolExecutor = { + fixedPoolExecutorInitialized = true + ThreadUtils.newDaemonFixedThreadPool( + parallelism, + prefix = "data-flow-graph-transformer-" + ) + } + private val selfExecutor = ThreadUtils.sameThreadExecutorService() + + private def computeTableMap(): Map[TableIdentifier, Table] = synchronized { + tables.map(table => table.identifier -> table).toMap + } + + private def computeViewMap(): Map[TableIdentifier, View] = synchronized { + views.map(view => view.identifier -> view).toMap + } + + private def computeFlowsTo(): Map[TableIdentifier, Seq[Flow]] = synchronized { + flows.groupBy(_.destinationIdentifier) + } + + def transformTables(transformer: Table => Table): DataflowGraphTransformer = synchronized { + tables = tables.map(transformer) + tableMap = computeTableMap() + this + } + + private def defaultOnFailedDependentTables( + failedTableDependencies: Map[TableIdentifier, Seq[Table]]): Unit = { + require( + failedTableDependencies.isEmpty, + "Dependency failure happened and some tables were not resolved" + ) + } + + /** + * Example graph: [Flow1, Flow 2] -> ST -> Flow3 -> MV + * Order of processing: Flow1, Flow2, ST, Flow3, MV. + * @param transformer function that transforms any graph entity. + * transformer( + * nodeToTransform: GraphElement, upstreamNodes: Seq[GraphElement] + * ) => transformedNodes: Seq[GraphElement] + * @return this + */ + def transformDownNodes( + transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], + disableParallelism: Boolean = false): DataflowGraphTransformer = { + val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor + val batchSize = if (disableParallelism) 1 else parallelism + // List of resolved tables, sinks and flows + val resolvedFlows = new ConcurrentLinkedQueue[ResolutionCompletedFlow]() + val resolvedTables = new ConcurrentLinkedQueue[Table]() + val resolvedViews = new ConcurrentLinkedQueue[View]() + // Flow identifier to a list of transformed flows mapping to track resolved flows + val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, Seq[Flow]]() + val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]() + val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]() + val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]() + + var futures = ArrayBuffer[Future[Unit]]() + val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]() + toBeResolvedFlows.addAll(flows.asJava) + + while (futures.nonEmpty || toBeResolvedFlows.peekFirst() != null) { + val (done, notDone) = futures.partition(_.isDone) + // Explicitly call future.get() to propagate exceptions one by one if any + try { + done.foreach(_.get()) + } catch { + case exn: ExecutionException => + // Computation threw the exception that is the cause of exn + throw exn.getCause + } + futures = notDone + val flowOpt = { + // We only schedule [[batchSize]] number of flows in parallel. + if (futures.size < batchSize) { + Option(toBeResolvedFlows.pollFirst()) + } else { + None + } + } + if (flowOpt.isDefined) { + val flow = flowOpt.get + futures.append( + executor.submit( + () => + try { + try { + // Note: Flow don't need their inputs passed, so for now we send empty Seq. 
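+                   // (The flow transformer resolves its inputs through the processor's map of
+                   // already-resolved inputs inside FlowFunction.call, so upstream nodes are
+                   // not required at this point.)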
+ val result = transformer(flow, Seq.empty) + require( + result.forall(_.isInstanceOf[ResolvedFlow]), + "transformer must return a Seq[Flow]" + ) + + val transformedFlows = result.map(_.asInstanceOf[ResolvedFlow]) + resolvedFlowsMap.put(flow.identifier, transformedFlows) + resolvedFlows.addAll(transformedFlows.asJava) + } catch { + case e: TransformNodeRetryableException => + val datasetIdentifier = e.datasetIdentifier + failedDependentFlows.compute( + datasetIdentifier, + (_, flows) => { + // Don't add the input flow back but the failed flow object + // back which has relevant failure information. + val failedFlow = e.failedNode + if (flows == null) { + Seq(failedFlow) + } else { + flows :+ failedFlow + } + } + ) + // Between the time the flow started and finished resolving, perhaps the + // dependent dataset was resolved + resolvedFlowDestinationsMap.computeIfPresent( + datasetIdentifier, + (_, resolved) => { + if (resolved) { + // Check if the dataset that the flow is dependent on has been resolved + // and if so, remove all dependent flows from the failedDependentFlows and + // add them to the toBeResolvedFlows queue for retry. + failedDependentFlows.computeIfPresent( + datasetIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) + null + } + ) + } + resolved + } + ) + case other: Throwable => throw other + } + // If all flows to this particular destination are resolved, move to the destination + // node transformer + if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination => + resolvedFlowsMap.containsKey(flowToDestination.identifier) + })) { + // If multiple flows completed in parallel, ensure we resolve the destination only + // once by electing a leader via computeIfAbsent + var isCurrentThreadLeader = false + resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => { + isCurrentThreadLeader = true + // Set initial value as false as flow destination is not resolved yet. + false + }) + if (isCurrentThreadLeader) { + if (tableMap.contains(flow.destinationIdentifier)) { + val transformed = + transformer( + tableMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + resolvedTables.addAll( + transformed.collect { case t: Table => t }.asJava + ) + resolvedFlows.addAll( + transformed.collect { case f: ResolvedFlow => f }.asJava + ) + } else { + if (viewMap.contains(flow.destinationIdentifier)) { + resolvedViews.addAll { + val transformed = + transformer( + viewMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + transformed.map(_.asInstanceOf[View]).asJava + } + } else { + throw new IllegalArgumentException( + s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" + + s" in flow: ${flow.displayName} at transformDownNodes" + ) + } + } + // Set flow destination as resolved now. 
+ resolvedFlowDestinationsMap.computeIfPresent( + flow.destinationIdentifier, + (_, _) => { + // If there are any other node failures dependent on this destination, retry + // them + failedDependentFlows.computeIfPresent( + flow.destinationIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) + null + } + ) + true + } + ) + } + } + } catch { + case ex: TransformNodeFailedException => failedFlowsQueue.add(ex.failedNode) + } + ) + ) + } + } + + // Mutate the fail analysis entities + // A table is failed to analyze if: + // - It does not exist in the resolvedFlowDestinationsMap + failedTables = tables.filterNot { table => + resolvedFlowDestinationsMap.getOrDefault(table.identifier, false) + } + + // We maintain the topological sort order of successful flows always + val (resolvedFlowsWithResolvedDest, resolvedFlowsWithFailedDest) = + resolvedFlows.asScala.toSeq.partition(flow => { + resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false) + }) + + // A flow is failed to analyze if: + // - It is non-retryable + // - It is retryable but could not be retried, i.e. the dependent dataset is still unresolved + // - It might be resolvable but it writes into a destination that is failed to analyze + // To note: because we are transform down nodes, all downstream nodes of any pruned nodes + // will also be pruned + failedFlows = + // All transformed flows that write to a destination that is failed to analyze. + resolvedFlowsWithFailedDest ++ + // All failed flows thrown by TransformNodeFailedException + failedFlowsQueue.asScala.toSeq ++ + // All flows that have not been transformed and resolved yet + failedDependentFlows.values().asScala.flatten.toSeq + + // Mutate the resolved entities + flows = resolvedFlowsWithResolvedDest + flowsTo = computeFlowsTo() + tables = resolvedTables.asScala.toSeq + views = resolvedViews.asScala.toSeq + tableMap = computeTableMap() + viewMap = computeViewMap() + this + } + + def getDataflowGraph: DataflowGraph = { + graph.copy( + // Returns all flows (resolved and failed) in topological order. + // The relative order between flows and failed flows doesn't matter here. + // For failed flows that were resolved but were marked failed due to destination failure, + // they will be front of the list in failedFlows and thus by definition topologically sorted + // in the combined sequence too. + flows = flows ++ failedFlows, + tables = tables ++ failedTables + ) + } + + override def close(): Unit = { + if (fixedPoolExecutorInitialized) { + fixedPoolExecutor.shutdown() + } + } +} + +object DataflowGraphTransformer { + + /** + * Exception thrown when transforming a node in the graph fails because at least one of its + * dependencies weren't yet transformed. + * + * @param datasetIdentifier The identifier for an untransformed dependency table identifier in the + * dataflow graph. + */ + case class TransformNodeRetryableException( + datasetIdentifier: TableIdentifier, + failedNode: ResolutionFailedFlow) + extends Exception + with NoStackTrace + + /** + * Exception thrown when transforming a node in the graph fails with a non-retryable error. + * + * @param failedNode The failed node that could not be transformed. + */ + case class TransformNodeFailedException(failedNode: ResolutionFailedFlow) + extends Exception + with NoStackTrace + + /** + * Autocloseable wrapper around DataflowGraphTransformer to ensure that the transformer is closed + * without clients needing to remember to close it. 
It takes in the same arguments as + * [[DataflowGraphTransformer]] constructor. It exposes the DataflowGraphTransformer instance + * within the callable scope. + */ + def withDataflowGraphTransformer[T](graph: DataflowGraph)(f: DataflowGraphTransformer => T): T = { + val dataflowGraphTransformer = new DataflowGraphTransformer(graph) + try { + f(dataflowGraphTransformer) + } finally { + dataflowGraphTransformer.close() + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala new file mode 100644 index 0000000000000..2378b6f8d96a6 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.Try + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.classic.DataFrame +import org.apache.spark.sql.pipelines.AnalysisWarning +import org.apache.spark.sql.pipelines.util.InputReadOptions +import org.apache.spark.sql.types.StructType + +/** + * Contains the catalog and database context information for query execution. + */ +case class QueryContext(currentCatalog: Option[String], currentDatabase: Option[String]) + +/** + * A [[Flow]] is a node of data transformation in a dataflow graph. It describes the movement + * of data into a particular dataset. + */ +trait Flow extends GraphElement with Logging { + + /** The [[FlowFunction]] containing the user's query. */ + def func: FlowFunction + + val identifier: TableIdentifier + + /** + * The dataset that this Flow represents a write to. + */ + val destinationIdentifier: TableIdentifier + + /** + * Whether this is a ONCE flow. ONCE flows should run only once per full refresh. + */ + def once: Boolean = false + + /** The current query context (catalog and database) when the query is defined. */ + def queryContext: QueryContext + + /** The comment associated with this flow */ + def comment: Option[String] + + def sqlConf: Map[String, String] +} + +/** A wrapper for a resolved internal input that includes the alias provided by the user. */ +case class ResolvedInput(input: Input, aliasIdentifier: AliasIdentifier) + +/** A wrapper for the lambda function that defines a [[Flow]]. */ +trait FlowFunction extends Logging { + + /** + * This function defines the transformations performed by a flow, expressed as a DataFrame. + * + * @param allInputs the set of identifiers for all the [[Input]]s defined in the + * [[DataflowGraph]]. 
+ * @param availableInputs the list of all [[Input]]s available to this flow + * @param configuration the spark configurations that apply to this flow. + * @param queryContext The context of the query being evaluated. + * @return the inputs actually used, and the DataFrame expression for the flow + */ + def call( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + configuration: Map[String, String], + queryContext: QueryContext + ): FlowFunctionResult +} + +/** + * Holds the DataFrame returned by a [[FlowFunction]] along with the inputs used to + * construct it. + * @param batchInputs the complete inputs read by the flow + * @param streamingInputs the incremental inputs read by the flow + * @param usedExternalInputs the identifiers of the external inputs read by the flow + * @param dataFrame the DataFrame expression executed by the flow if the flow can be resolved + */ +case class FlowFunctionResult( + requestedInputs: Set[TableIdentifier], + batchInputs: Set[ResolvedInput], + streamingInputs: Set[ResolvedInput], + usedExternalInputs: Set[TableIdentifier], + dataFrame: Try[DataFrame], + sqlConf: Map[String, String], + analysisWarnings: Seq[AnalysisWarning] = Nil) { + + /** + * Returns the names of all of the [[Input]]s used when resolving this [[Flow]]. If the + * flow failed to resolve, we return the all the datasets that were requested when evaluating the + * flow. + */ + def inputs: Set[TableIdentifier] = { + (batchInputs ++ streamingInputs).map(_.input.identifier) + } + + /** Returns errors that occurred when attempting to analyze this [[Flow]]. */ + def failure: Seq[Throwable] = { + dataFrame.failed.toOption.toSeq + } + + /** Whether this [[Flow]] is successfully analyzed. */ + final def resolved: Boolean = failure.isEmpty // don't override this, override failure +} + +/** A [[Flow]] whose output schema and dependencies aren't known. */ +case class UnresolvedFlow( + identifier: TableIdentifier, + destinationIdentifier: TableIdentifier, + func: FlowFunction, + queryContext: QueryContext, + sqlConf: Map[String, String], + comment: Option[String] = None, + override val once: Boolean, + override val origin: QueryOrigin +) extends Flow + +/** + * A [[Flow]] whose flow function has been invoked, meaning either: + * - Its output schema and dependencies are known. + * - It failed to resolve. + */ +trait ResolutionCompletedFlow extends Flow { + def flow: UnresolvedFlow + def funcResult: FlowFunctionResult + + val identifier: TableIdentifier = flow.identifier + val destinationIdentifier: TableIdentifier = flow.destinationIdentifier + def func: FlowFunction = flow.func + def queryContext: QueryContext = flow.queryContext + def comment: Option[String] = flow.comment + def sqlConf: Map[String, String] = funcResult.sqlConf + def origin: QueryOrigin = flow.origin +} + +/** A [[Flow]] whose flow function has failed to resolve. */ +class ResolutionFailedFlow(val flow: UnresolvedFlow, val funcResult: FlowFunctionResult) + extends ResolutionCompletedFlow { + assert(!funcResult.resolved) + + def failure: Seq[Throwable] = funcResult.failure +} + +/** A [[Flow]] whose flow function has successfully resolved. */ +trait ResolvedFlow extends ResolutionCompletedFlow with Input { + assert(funcResult.resolved) + + /** The logical plan for this flow's query. */ + def df: DataFrame = funcResult.dataFrame.get + + /** Returns the schema of the output of this [[Flow]]. 
*/ + def schema: StructType = df.schema + override def load(readOptions: InputReadOptions): DataFrame = df + def inputs: Set[TableIdentifier] = funcResult.inputs +} + +/** A [[Flow]] that represents stateful movement of data to some target. */ +class StreamingFlow( + val flow: UnresolvedFlow, + val funcResult: FlowFunctionResult, + val mustBeAppend: Boolean = false +) extends ResolvedFlow {} + +/** A [[Flow]] that declares exactly what data should be in the target table. */ +class CompleteFlow( + val flow: UnresolvedFlow, + val funcResult: FlowFunctionResult, + val mustBeAppend: Boolean = false +) extends ResolvedFlow {} + +/** A [[Flow]] that reads source[s] completely and appends data to the target, just once. + */ +class AppendOnceFlow( + val flow: UnresolvedFlow, + val funcResult: FlowFunctionResult +) extends ResolvedFlow { + + override val once = true +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala new file mode 100644 index 0000000000000..7e2e97f2b5d74 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.Try + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession} +import org.apache.spark.sql.pipelines.AnalysisWarning +import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier} +import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions} + + +object FlowAnalysis { + /** + * Creates a [[FlowFunction]] that attempts to analyze the provided LogicalPlan + * using the existing resolved inputs. + * - If all upstream inputs have been resolved, then analysis succeeds and the + * function returns a [[FlowFunctionResult]] containing the dataframe. + * - If any upstream inputs are unresolved, then the function throws an exception. + * + * @param plan The user-supplied LogicalPlan defining a flow. + * @return A FlowFunction that attempts to analyze the provided LogicalPlan. 
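+   *
+   * A sketch of how the returned function is typically driven (the names below are
+   * illustrative; the real call site is the flow resolver):
+   * {{{
+   *   val flowFunction = FlowAnalysis.createFlowFunctionFromLogicalPlan(parsedPlan)
+   *   val result = flowFunction.call(allInputs, availableInputs, configuration, queryContext)
+   *   result.dataFrame match {
+   *     case scala.util.Success(df) => // all upstream inputs resolved; df is fully analyzed
+   *     case scala.util.Failure(e)  => // e.g. UnresolvedDatasetException for a missing input
+   *   }
+   * }}}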
+ */ + def createFlowFunctionFromLogicalPlan(plan: LogicalPlan): FlowFunction = { + new FlowFunction { + override def call( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + confs: Map[String, String], + queryContext: QueryContext + ): FlowFunctionResult = { + val ctx = FlowAnalysisContext( + allInputs = allInputs, + availableInputs = availableInputs, + queryContext = queryContext, + spark = SparkSession.active + ) + val df = try { + confs.foreach { case (k, v) => ctx.setConf(k, v) } + Try(FlowAnalysis.analyze(ctx, plan)) + } finally { + ctx.restoreOriginalConf() + } + FlowFunctionResult( + requestedInputs = ctx.requestedInputs.toSet, + batchInputs = ctx.batchInputs.toSet, + streamingInputs = ctx.streamingInputs.toSet, + usedExternalInputs = ctx.externalInputs.toSet, + dataFrame = df, + sqlConf = confs, + analysisWarnings = ctx.analysisWarnings.toList + ) + } + } + } + + /** + * Constructs an analyzed [[DataFrame]] from a [[LogicalPlan]] by resolving Pipelines specific + * TVFs and datasets that cannot be resolved directly by Catalyst. + * + * This function shouldn't call any singleton as it will break concurrent access to graph + * analysis; or any thread local variables as graph analysis and this function will use + * different threads in python repl. + * + * @param plan The [[LogicalPlan]] defining a flow. + * @return An analyzed [[DataFrame]]. + */ + private def analyze( + context: FlowAnalysisContext, + plan: LogicalPlan + ): DataFrame = { + // Users can define CTEs within their CREATE statements. For example, + // + // CREATE STREAMING TABLE a + // WITH b AS ( + // SELECT * FROM STREAM upstream + // ) + // SELECT * FROM b + // + // The relation defined using the WITH keyword is not included in the children of the main + // plan so the specific analysis we do below will not be applied to those relations. + // Instead, we call an analyzer rule to inline all of the CTE relations in the main plan before + // we do analysis. This rule would be called during analysis anyways, but we just call it + // earlier so we only need to apply analysis to a single logical plan. + val planWithInlinedCTEs = CTESubstitution(plan) + + val spark = context.spark + // Traverse the user's query plan and recursively resolve nodes that reference Pipelines + // features that the Spark analyzer is unable to resolve + val resolvedPlan = planWithInlinedCTEs transformWithSubqueries { + // Streaming read on another dataset + // This branch will be hit for the following kinds of queries: + // - SELECT ... FROM STREAM(t1) + // - SELECT ... FROM STREAM t1 + case u: UnresolvedRelation if u.isStreaming => + readStreamInput( + context, + name = IdentifierHelper.toQuotedString(u.multipartIdentifier), + spark.readStream, + streamingReadOptions = StreamingReadOptions() + ).queryExecution.analyzed + + // Batch read on another dataset in the pipeline + case u: UnresolvedRelation => + readBatchInput( + context, + name = IdentifierHelper.toQuotedString(u.multipartIdentifier), + batchReadOptions = BatchReadOptions() + ).queryExecution.analyzed + } + Dataset.ofRows(spark, resolvedPlan) + + } + + /** + * Internal helper to reference the batch dataset (i.e., non-streaming dataset) with the given + * name. + * 1. The dataset can be a table, view, or a named flow. + * 2. The dataset can be a dataset defined in the same DataflowGraph or a table in the external + * catalog. + * All the public APIs that read from a dataset should call this function to read the dataset. 
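+   *
+   * Illustrative calls (dataset names are hypothetical):
+   * {{{
+   *   // a dataset defined in this DataflowGraph
+   *   readBatchInput(context, "silver_orders", BatchReadOptions())
+   *   // a table resolved from the external catalog
+   *   readBatchInput(context, "prod.sales.orders", BatchReadOptions())
+   * }}}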
+ * + * @param name the name of the Dataset to be read. + * @param batchReadOptions Options for this batch read + * @return batch DataFrame that represents data from the specified Dataset. + */ + final private def readBatchInput( + context: FlowAnalysisContext, + name: String, + batchReadOptions: BatchReadOptions + ): DataFrame = { + GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match { + case inputIdentifier: InternalDatasetIdentifier => + readGraphInput(context, inputIdentifier, batchReadOptions) + + case inputIdentifier: ExternalDatasetIdentifier => + readExternalBatchInput( + context, + inputIdentifier = inputIdentifier, + name = name + ) + } + } + + /** + * Internal helper to reference the streaming dataset with the given name. + * 1. The dataset can be a table, view, or a named flow. + * 2. The dataset can be a dataset defined in the same DataflowGraph or a table in the external + * catalog. + * All the public APIs that read from a dataset should call this function to read the dataset. + * + * @param name the name of the Dataset to be read. + * @param streamReader The [[DataStreamReader]] that may hold read options specified by the user. + * @param streamingReadOptions Options for this streaming read. + * @return streaming DataFrame that represents data from the specified Dataset. + */ + final private def readStreamInput( + context: FlowAnalysisContext, + name: String, + streamReader: DataStreamReader, + streamingReadOptions: StreamingReadOptions + ): DataFrame = { + GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match { + case inputIdentifier: InternalDatasetIdentifier => + readGraphInput( + context, + inputIdentifier, + streamingReadOptions + ) + + case inputIdentifier: ExternalDatasetIdentifier => + readExternalStreamInput( + context, + inputIdentifier = inputIdentifier, + streamReader = streamReader, + name = name + ) + } + } + + /** + * Internal helper to reference dataset defined in the same [[DataflowGraph]]. + * + * @param inputIdentifier The identifier of the Dataset to be read. + * @param readOptions Options for this read (may be either streaming or batch options) + * @return streaming or batch DataFrame that represents data from the specified Dataset. + */ + final private def readGraphInput( + ctx: FlowAnalysisContext, + inputIdentifier: InternalDatasetIdentifier, + readOptions: InputReadOptions + ): DataFrame = { + val datasetIdentifier = inputIdentifier.identifier + + ctx.requestedInputs += datasetIdentifier + + val i = if (!ctx.allInputs.contains(datasetIdentifier)) { + // Dataset not defined in the dataflow graph + throw GraphErrors.pipelineLocalDatasetNotDefinedError(datasetIdentifier.unquotedString) + } else if (!ctx.availableInput.contains(datasetIdentifier)) { + // Dataset defined in the dataflow graph but not yet resolved + throw UnresolvedDatasetException(datasetIdentifier) + } else { + // Dataset is resolved, so we can read from it + ctx.availableInput(datasetIdentifier) + } + + val inputDF = i.load(readOptions) + i match { + // If the referenced input is a [[Flow]], because the query plans will be fused + // together, we also need to fuse their confs. + case f: Flow => f.sqlConf.foreach { case (k, v) => ctx.setConf(k, v) } + case _ => + } + + val incompatibleViewReadCheck = + ctx.spark.conf.get("pipelines.incompatibleViewCheck.enabled", "true").toBoolean + + // Wrap the DF in an alias so that columns in the DF can be referenced with + // the following in the query: + // - ... + // - .. + // - . 
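+    // (i.e. as `catalog.database.table.column`, `database.table.column`, or `table.column`,
+    // matching the alias qualifier constructed below)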
+ val aliasIdentifier = AliasIdentifier( + name = datasetIdentifier.table, + qualifier = Seq(datasetIdentifier.catalog, datasetIdentifier.database).flatten + ) + + readOptions match { + case sro: StreamingReadOptions => + if (!inputDF.isStreaming && incompatibleViewReadCheck) { + throw new AnalysisException( + "INCOMPATIBLE_BATCH_VIEW_READ", + Map("datasetIdentifier" -> datasetIdentifier.toString) + ) + } + + if (sro.droppedUserOptions.nonEmpty) { + ctx.analysisWarnings += AnalysisWarning.StreamingReaderOptionsDropped( + sourceName = datasetIdentifier.unquotedString, + droppedOptions = sro.droppedUserOptions.keys.toSeq + ) + } + ctx.streamingInputs += ResolvedInput(i, aliasIdentifier) + case _ => + if (inputDF.isStreaming && incompatibleViewReadCheck) { + throw new AnalysisException( + "INCOMPATIBLE_STREAMING_VIEW_READ", + Map("datasetIdentifier" -> datasetIdentifier.toString) + ) + } + ctx.batchInputs += ResolvedInput(i, aliasIdentifier) + } + Dataset.ofRows( + ctx.spark, + SubqueryAlias(identifier = aliasIdentifier, child = inputDF.queryExecution.logical) + ) + } + + /** + * Internal helper to reference batch dataset (i.e., non-streaming dataset) defined in an external + * catalog or as a path. + * + * @param inputIdentifier The identifier of the dataset to be read. + * @return streaming or batch DataFrame that represents data from the specified Dataset. + */ + final private def readExternalBatchInput( + context: FlowAnalysisContext, + inputIdentifier: ExternalDatasetIdentifier, + name: String): DataFrame = { + + val spark = context.spark + context.externalInputs += inputIdentifier.identifier + spark.read.table(inputIdentifier.identifier.quotedString) + } + + /** + * Internal helper to reference dataset defined in an external catalog or as a path. + * + * @param inputIdentifier The identifier of the dataset to be read. + * @param streamReader The [[DataStreamReader]] that may hold additional read options specified by + * the user. + * @return streaming or batch DataFrame that represents data from the specified Dataset. + */ + final private def readExternalStreamInput( + context: FlowAnalysisContext, + inputIdentifier: ExternalDatasetIdentifier, + streamReader: DataStreamReader, + name: String): DataFrame = { + + context.externalInputs += inputIdentifier.identifier + streamReader.table(inputIdentifier.identifier.quotedString) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala new file mode 100644 index 0000000000000..fb96c6cb5bb1d --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.pipelines.AnalysisWarning + +/** + * A context used when evaluating a [[Flow]]'s query into a concrete DataFrame. + * + * @param allInputs Set of identifiers for all [[Input]]s defined in the DataflowGraph. + * @param availableInputs Inputs available to be referenced with `read` or `readStream`. + * @param queryContext The context of the query being evaluated. + * @param requestedInputs A mutable buffer populated with names of all inputs that were + * requested. + * @param spark the spark session to be used. + * @param externalInputs The names of external inputs that were used to evaluate + * the flow's query. + */ +private[pipelines] case class FlowAnalysisContext( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + queryContext: QueryContext, + batchInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, + streamingInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, + requestedInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty, + shouldLowerCaseNames: Boolean = false, + analysisWarnings: mutable.Buffer[AnalysisWarning] = new ListBuffer[AnalysisWarning], + spark: SparkSession, + externalInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty +) { + + /** Map from [[Input]] name to the actual [[Input]] */ + val availableInput: Map[TableIdentifier, Input] = + availableInputs.map(i => i.identifier -> i).toMap + + /** The confs set in this context that should be undone when exiting this context. */ + private val confsToRestore = mutable.HashMap[String, Option[String]]() + + /** Sets a Spark conf within this context that will be undone by `restoreOriginalConf`. */ + def setConf(key: String, value: String): Unit = { + if (!confsToRestore.contains(key)) { + confsToRestore.put(key, spark.conf.getOption(key)) + } + spark.conf.set(key, value) + } + + /** Restores the Spark conf to its state when this context was creating by undoing confs set. */ + def restoreOriginalConf(): Unit = confsToRestore.foreach { + case (k, Some(v)) => spark.conf.set(k, v) + case (k, None) => spark.conf.unset(k) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala new file mode 100644 index 0000000000000..bb3d290c3d3e1 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphElementTypeUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
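The `setConf`/`restoreOriginalConf` pair above scopes per-flow SQL confs to a single analysis pass so they do not leak into the shared session. A minimal sketch of the same save-and-restore idea in isolation (the conf key is only an example):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    val key = "spark.sql.shuffle.partitions"
    val original = spark.conf.getOption(key) // remember the pre-existing value, if any
    spark.conf.set(key, "1")                 // analyze one flow under its own conf
    try {
      // ... evaluate the flow's query here ...
    } finally {
      // Undo the change, as restoreOriginalConf does for every key it touched.
      original match {
        case Some(v) => spark.conf.set(key, v)
        case None => spark.conf.unset(key)
      }
    }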
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.pipelines.common.DatasetType + +object GraphElementTypeUtils { + + /** + * Helper function to obtain the DatasetType. This function should be called with all + * the flows that are writing into a table and all the flows should be resolved. + */ + def getDatasetTypeForMaterializedViewOrStreamingTable( + flowsToTable: Seq[ResolvedFlow]): DatasetType = { + val isStreamingTable = flowsToTable.exists(f => f.df.isStreaming || f.once) + if (isStreamingTable) { + DatasetType.STREAMING_TABLE + } else { + DatasetType.MATERIALIZED_VIEW + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala new file mode 100644 index 0000000000000..53db669e687d2 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphErrors.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.common.DatasetType +import org.apache.spark.sql.types.StructType + +/** Collection of errors that can be thrown during graph resolution / analysis. */ +object GraphErrors { + + /** + * Throws when a dataset is marked as internal but is not defined in the graph. + * + * @param datasetName the name of the dataset that is not defined + */ + def pipelineLocalDatasetNotDefinedError(datasetName: String): SparkException = { + SparkException.internalError( + s"Failed to read dataset '$datasetName'. This dataset was expected to be " + + s"defined and created by the pipeline." + ) + } + + /** + * Throws when the catalog or schema name in the "USE CATALOG | SCHEMA" command is invalid + * + * @param command string "USE CATALOG" or "USE SCHEMA" + * @param name the invalid catalog or schema name + * @param reason the reason why the name is invalid + */ + def invalidNameInUseCommandError( + command: String, + name: String, + reason: String + ): SparkException = { + new SparkException( + errorClass = "INVALID_NAME_IN_USE_COMMAND", + messageParameters = Map("command" -> command, "name" -> name, "reason" -> reason), + cause = null + ) + } + + /** + * Throws when a table path is unresolved, i.e. the table identifier + * does not exist in the catalog. 
+ * + * @param identifier the unresolved table identifier + */ + def unresolvedTablePath(identifier: TableIdentifier): SparkException = { + new SparkException( + errorClass = "UNRESOLVED_TABLE_PATH", + messageParameters = Map("identifier" -> identifier.toString), + cause = null + ) + } + + /** + * Throws an error if the user-specified schema and the inferred schema are not compatible. + * + * @param tableIdentifier the identifier of the table that was not found + */ + def incompatibleUserSpecifiedAndInferredSchemasError( + tableIdentifier: TableIdentifier, + datasetType: DatasetType, + specifiedSchema: StructType, + inferredSchema: StructType, + cause: Option[Throwable] = None + ): AnalysisException = { + val streamingTableHint = + if (datasetType == DatasetType.STREAMING_TABLE) { + s"""" + |Streaming tables are stateful and remember data that has already been + |processed. If you want to recompute the table from scratch, please full refresh + |the table. + """.stripMargin + } else { + "" + } + + new AnalysisException( + errorClass = "USER_SPECIFIED_AND_INFERRED_SCHEMA_NOT_COMPATIBLE", + messageParameters = Map( + "tableName" -> tableIdentifier.unquotedString, + "streamingTableHint" -> streamingTableHint, + "specifiedSchema" -> specifiedSchema.treeString, + "inferredDataSchema" -> inferredSchema.treeString + ), + cause = Option(cause.orNull) + ) + } + + /** + * Throws if the latest inferred schema for a pipeline table is not compatible with + * the table's existing schema. + * + * @param tableIdentifier the identifier of the table that was not found + */ + def unableToInferSchemaError( + tableIdentifier: TableIdentifier, + inferredSchema: StructType, + incompatibleSchema: StructType, + cause: Option[Throwable] = None + ): AnalysisException = { + new AnalysisException( + errorClass = "UNABLE_TO_INFER_PIPELINE_TABLE_SCHEMA", + messageParameters = Map( + "tableName" -> tableIdentifier.unquotedString, + "inferredDataSchema" -> inferredSchema.treeString, + "incompatibleDataSchema" -> incompatibleSchema.treeString + ), + cause = Option(cause.orNull) + ) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala new file mode 100644 index 0000000000000..414d9d0effea4 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
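To show how the schema errors defined above are raised, a short sketch that builds two incompatible schemas and throws the inference error (the table and column names are invented):

    import org.apache.spark.sql.catalyst.TableIdentifier
    import org.apache.spark.sql.types.{LongType, StringType, StructType, TimestampType}

    val target = TableIdentifier("events", Some("sales"), Some("main"))

    // Schema accumulated from the flows merged so far...
    val inferredSoFar = new StructType().add("id", LongType).add("ts", TimestampType)
    // ...and the schema of a flow that cannot be merged into it.
    val incompatible = new StructType().add("id", StringType)

    throw GraphErrors.unableToInferSchemaError(
      tableIdentifier = target,
      inferredSchema = inferredSoFar,
      incompatibleSchema = incompatible
    )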
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.execution.datasources.DataSource
+
+/**
+ * Responsible for properly qualifying the identifiers of datasets defined in or referenced by
+ * the dataflow graph.
+ */
+object GraphIdentifierManager {
+
+  import IdentifierHelper._
+
+  def parseTableIdentifier(name: String, spark: SparkSession): TableIdentifier = {
+    toTableIdentifier(spark.sessionState.sqlParser.parseMultipartIdentifier(name))
+  }
+
+  /**
+   * Fully qualifies (if needed) the user-specified identifier used to reference a dataset, and
+   * categorizes the dataset being referenced (i.e., a dataset defined in this pipeline or a
+   * dataset external to this pipeline).
+   *
+   * Returns the qualified identifier, categorized as either an internal or an external dataset.
+   *
+   * @param rawInputName the user-specified name when referencing datasets.
+   */
+  def parseAndQualifyInputIdentifier(
+      context: FlowAnalysisContext,
+      rawInputName: String): DatasetIdentifier = {
+    resolveDatasetReadInsideQueryDefinition(context = context, rawInputName = rawInputName)
+  }
+
+  /**
+   * Resolves dataset reads that happen inside the dataset query definition (i.e., inside
+   * the @materialized_view() annotation in Python).
+   */
+  private def resolveDatasetReadInsideQueryDefinition(
+      context: FlowAnalysisContext,
+      rawInputName: String
+  ): DatasetIdentifier = {
+    // After the identifier is pre-processed, we first check whether we're referencing a
+    // single-part-name dataset (e.g., a temp view). If so, don't fully qualify the identifier
+    // and directly read from it, because single-part-name datasets always mask
+    // fully-qualified datasets that have the same name. For example, if there's a view named
+    // "a" and also a table named "catalog.schema.a" defined in the graph, "SELECT * FROM a"
+    // would always read the view "a". To read table "a", the user would need to use a fully or
+    // partially qualified name (e.g., "SELECT * FROM catalog.schema.a" or "SELECT * FROM
+    // schema.a").
+    val inputIdentifier = parseTableIdentifier(rawInputName, context.spark)
+
+    /** Return whether we're referencing a dataset that is part of the pipeline.
*/ + def isInternalDataset(identifier: TableIdentifier): Boolean = { + context.allInputs.contains(identifier) + } + + if (isSinglePartIdentifier(inputIdentifier) && + isInternalDataset(inputIdentifier)) { + // reading a single-part-name dataset defined in the dataflow graph (e.g., a view) + InternalDatasetIdentifier(identifier = inputIdentifier) + } else if (isPathIdentifier(context.spark, inputIdentifier)) { + // path-based reference, always read as external dataset + ExternalDatasetIdentifier(identifier = inputIdentifier) + } else { + val fullyQualifiedInputIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = inputIdentifier, + currentCatalog = context.queryContext.currentCatalog, + currentDatabase = context.queryContext.currentDatabase + ) + assertIsFullyQualifiedForRead(identifier = fullyQualifiedInputIdentifier) + + if (isInternalDataset(fullyQualifiedInputIdentifier)) { + InternalDatasetIdentifier(identifier = fullyQualifiedInputIdentifier) + } else { + ExternalDatasetIdentifier(fullyQualifiedInputIdentifier) + } + } + } + + /** + * @param rawDatasetIdentifier the dataset identifier specified by the user. + */ + @throws[AnalysisException] + private def parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier: TableIdentifier): InternalDatasetIdentifier = { + InternalDatasetIdentifier(identifier = rawDatasetIdentifier) + } + + /** + * Parses the table identifier from the raw table identifier and fully qualifies it. + * + * @param rawTableIdentifier the raw table identifier + * @return the parsed table identifier + */ + @throws[AnalysisException] + def parseAndQualifyTableIdentifier( + rawTableIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String] + ): InternalDatasetIdentifier = { + val pipelineDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawTableIdentifier + ) + val fullyQualifiedTableIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = pipelineDatasetIdentifier.identifier, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase + ) + // assert the identifier is properly fully qualified + assertIsFullyQualifiedForCreate(fullyQualifiedTableIdentifier) + InternalDatasetIdentifier(identifier = fullyQualifiedTableIdentifier) + } + + /** + * Parses and validates the view identifier from the raw view identifier for temporary views. + * + * @param rawViewIdentifier the raw view identifier + * @return the parsed view identifier + */ + @throws[AnalysisException] + def parseAndValidateTemporaryViewIdentifier( + rawViewIdentifier: TableIdentifier): TableIdentifier = { + val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawViewIdentifier + ) + // Temporary views are not persisted to the catalog in use, therefore should not be qualified. + if (!isSinglePartIdentifier(internalDatasetIdentifier.identifier)) { + throw new AnalysisException( + "MULTIPART_TEMPORARY_VIEW_NAME_NOT_SUPPORTED", + Map("viewName" -> rawViewIdentifier.unquotedString) + ) + } + internalDatasetIdentifier.identifier + } + + /** + * Parses and validates the view identifier from the raw view identifier for persisted views. 
+ * + * @param rawViewIdentifier the raw view identifier + * @param currentCatalog the catalog + * @param currentDatabase the schema + * @return the parsed view identifier + */ + def parseAndValidatePersistedViewIdentifier( + rawViewIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String]): TableIdentifier = { + val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawViewIdentifier + ) + // Persisted views have fully qualified names + val fullyQualifiedViewIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = internalDatasetIdentifier.identifier, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase + ) + // assert the identifier is properly fully qualified + assertIsFullyQualifiedForCreate(fullyQualifiedViewIdentifier) + fullyQualifiedViewIdentifier + } + + /** + * Parses the flow identifier from the raw flow identifier and fully qualify it. + * + * @param rawFlowIdentifier the raw flow identifier + * @return the parsed flow identifier + */ + @throws[AnalysisException] + def parseAndQualifyFlowIdentifier( + rawFlowIdentifier: TableIdentifier, + currentCatalog: Option[String], + currentDatabase: Option[String] + ): InternalDatasetIdentifier = { + val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier( + rawDatasetIdentifier = rawFlowIdentifier + ) + + val fullyQualifiedFlowIdentifier = fullyQualifyIdentifier( + maybeFullyQualifiedIdentifier = internalDatasetIdentifier.identifier, + currentCatalog = currentCatalog, + currentDatabase = currentDatabase + ) + + // assert the identifier is properly fully qualified + assertIsFullyQualifiedForCreate(fullyQualifiedFlowIdentifier) + InternalDatasetIdentifier(identifier = fullyQualifiedFlowIdentifier) + } + + /** Represents the identifier for a dataset that is defined or referenced in a pipeline. */ + sealed trait DatasetIdentifier + + /** Represents the identifier for a dataset that is defined by the current pipeline. */ + case class InternalDatasetIdentifier private ( + identifier: TableIdentifier + ) extends DatasetIdentifier + + /** Represents the identifier for a dataset that is external to the current pipeline. */ + case class ExternalDatasetIdentifier(identifier: TableIdentifier) extends DatasetIdentifier +} + +object IdentifierHelper { + + /** + * Returns the quoted string for the name parts. + * + * @param nameParts the dataset name parts. + * @return the quoted string for the name parts. + */ + def toQuotedString(nameParts: Seq[String]): String = { + toTableIdentifier(nameParts).quotedString + } + + /** + * Returns the table identifier constructed from the name parts. + * + * @param nameParts the dataset name parts. + * @return the table identifier constructed from the name parts. + */ + @throws[UnsupportedOperationException] + def toTableIdentifier(nameParts: Seq[String]): TableIdentifier = { + nameParts.length match { + case 1 => TableIdentifier(tableName = nameParts.head) + case 2 => TableIdentifier(table = nameParts(1), database = Option(nameParts.head)) + case 3 => + TableIdentifier( + table = nameParts(2), + database = Option(nameParts(1)), + catalog = Option(nameParts.head) + ) + case _ => + throw new UnsupportedOperationException( + s"4+ part table identifier ${nameParts.mkString(".")} is not supported." + ) + } + } + + /** + * Returns the table identifier constructed from the logical plan. + * + * @param table the logical plan. 
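To make the qualification rules concrete, a small sketch using the helpers in this file (the catalog and schema names are hypothetical): a partially qualified table name picks up the pipeline's current catalog, while a temporary view must remain single-part.

    import org.apache.spark.sql.catalyst.TableIdentifier

    // "sales.orders" parses to TableIdentifier(table = "orders", database = Some("sales")).
    val twoPart = IdentifierHelper.toTableIdentifier(Seq("sales", "orders"))

    // Filling in the missing catalog yields the three-part name main.sales.orders.
    val qualified = GraphIdentifierManager
      .parseAndQualifyTableIdentifier(
        rawTableIdentifier = twoPart,
        currentCatalog = Some("main"),
        currentDatabase = Some("sales")
      )
      .identifier

    // Temporary views are never qualified; a multipart name would be rejected here.
    val tempView = GraphIdentifierManager.parseAndValidateTemporaryViewIdentifier(
      rawViewIdentifier = TableIdentifier("staging_orders")
    )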
+   * @return the table identifier constructed from the logical plan.
+   */
+  def toTableIdentifier(table: LogicalPlan): TableIdentifier = {
+    val parts = table match {
+      case r: ResolvedIdentifier => r.identifier.namespace.toSeq :+ r.identifier.name
+      case u: UnresolvedIdentifier => u.nameParts
+      case u: UnresolvedRelation => u.multipartIdentifier
+      case _ =>
+        throw new UnsupportedOperationException(s"Unable to resolve name for $table.")
+    }
+    toTableIdentifier(parts)
+  }
+
+  /** Return whether the input identifier is a single-part identifier. */
+  def isSinglePartIdentifier(identifier: TableIdentifier): Boolean = {
+    identifier.database.isEmpty && identifier.catalog.isEmpty
+  }
+
+  /**
+   * Return true if the identifier should be resolved as a path-based reference
+   * (i.e., `datasource`.`path`).
+   */
+  def isPathIdentifier(spark: SparkSession, identifier: TableIdentifier): Boolean = {
+    if (identifier.nameParts.length != 2) {
+      return false
+    }
+    val Seq(datasource, path) = identifier.nameParts
+    val sqlConf = spark.sessionState.conf
+
+    def isDatasourceValid = {
+      try {
+        DataSource.lookupDataSource(datasource, sqlConf)
+        true
+      } catch {
+        case _: ClassNotFoundException => false
+      }
+    }
+
+    // Whether the provided datasource is valid.
+    isDatasourceValid
+  }
+
+  /** Fully qualifies provided identifier with provided catalog & schema. */
+  def fullyQualifyIdentifier(
+      maybeFullyQualifiedIdentifier: TableIdentifier,
+      currentCatalog: Option[String],
+      currentDatabase: Option[String]
+  ): TableIdentifier = {
+    maybeFullyQualifiedIdentifier.copy(
+      database = maybeFullyQualifiedIdentifier.database.orElse(currentDatabase),
+      catalog = maybeFullyQualifiedIdentifier.catalog.orElse(currentCatalog)
+    )
+  }
+
+  /** Assert whether the identifier is properly fully qualified when creating a dataset. */
+  def assertIsFullyQualifiedForCreate(identifier: TableIdentifier): Unit = {
+    assert(
+      identifier.catalog.isDefined && identifier.database.isDefined,
+      s"Dataset identifier $identifier is not properly fully qualified, expect a " +
+        s"three-part-name <catalog>.<schema>.<table>"
+    )
+  }
+
+  /** Assert whether the identifier is properly qualified when reading a dataset in a pipeline. */
+  def assertIsFullyQualifiedForRead(identifier: TableIdentifier): Unit = {
+    assert(
+      identifier.catalog.isDefined && identifier.database.isDefined,
+      s"Failed to reference dataset $identifier, expect a " +
+        s"three-part-name <catalog>.<schema>.<table>
" + ) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala new file mode 100644 index 0000000000000..c0b5a360afea7 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * @param identifier The identifier of the flow. + * @param inputs The identifiers of nodes used as inputs to this flow. + * @param output The identifier of the output that this flow writes to. + */ +case class FlowNode( + identifier: TableIdentifier, + inputs: Set[TableIdentifier], + output: TableIdentifier) + +trait GraphOperations { + this: DataflowGraph => + + /** The set of all outputs in the graph. */ + private lazy val destinationSet: Set[TableIdentifier] = + flows.map(_.destinationIdentifier).toSet + + /** A map from flow identifier to [[FlowNode]], which contains the input/output nodes. */ + lazy val flowNodes: Map[TableIdentifier, FlowNode] = { + flows.map { f => + val identifier = f.identifier + val n = + FlowNode( + identifier = identifier, + inputs = resolvedFlow(f.identifier).inputs, + output = f.destinationIdentifier + ) + identifier -> n + }.toMap + } + + /** Map from dataset identifier to all reachable upstream destinations, including itself. */ + private lazy val upstreamDestinations = + mutable.HashMap + .empty[TableIdentifier, Set[TableIdentifier]] + .withDefault(key => dfsInternal(startDestination = key, downstream = false)) + + /** Map from dataset identifier to all reachable downstream destinations, including itself. */ + private lazy val downstreamDestinations = + mutable.HashMap + .empty[TableIdentifier, Set[TableIdentifier]] + .withDefault(key => dfsInternal(startDestination = key, downstream = true)) + + /** + * Performs a DFS starting from `startNode` and returns the set of nodes (datasets) reached. + * @param startDestination The identifier of the node to start from. + * @param downstream if true, traverse output edges (search downstream) + * if false, traverse input edges (search upstream). + * @param stopAtMaterializationPoints If true, stop when we reach a materialization point (table). + * If false, keep going until the end. 
+ */ + protected def dfsInternal( + startDestination: TableIdentifier, + downstream: Boolean, + stopAtMaterializationPoints: Boolean = false): Set[TableIdentifier] = { + assert( + destinationSet.contains(startDestination), + s"$startDestination is not a valid start node" + ) + val visited = new mutable.HashSet[TableIdentifier] + + // Same semantics as a stack. Need to be able to push/pop items. + var nextNodes = List[TableIdentifier]() + nextNodes = startDestination :: nextNodes + + while (nextNodes.nonEmpty) { + val currNode = nextNodes.head + nextNodes = nextNodes.tail + + if (!visited.contains(currNode) + // When we stop at materialization points, skip non-start nodes that are materialized. + && !(stopAtMaterializationPoints && table.contains(currNode) + && currNode != startDestination)) { + visited.add(currNode) + val neighbors = if (downstream) { + flowNodes.values.filter(_.inputs.contains(currNode)).map(_.output) + } else { + flowNodes.values.filter(_.output == currNode).flatMap(_.inputs) + } + nextNodes = neighbors.toList ++ nextNodes + } + } + visited.toSet + } + + /** + * An implementation of DFS that takes in a sequence of start nodes and returns the + * "reachability set" of nodes from the start nodes. + * + * @param downstream Walks the graph via the input edges if true, otherwise via the output + * edges. + * @return A map from visited nodes to its origin[s] in `datasetIdentifiers`, e.g. + * Let graph = a -> b c -> d (partitioned graph) + * + * reachabilitySet(Seq("a", "c"), downstream = true) + * -> ["a" -> ["a"], "b" -> ["a"], "c" -> ["c"], "d" -> ["c"]] + */ + private def reachabilitySet( + datasetIdentifiers: Seq[TableIdentifier], + downstream: Boolean): Map[TableIdentifier, Set[TableIdentifier]] = { + // Seq of the form "node identifier" -> Set("dependency1, "dependency2") + val deps = datasetIdentifiers.map(n => n -> reachabilitySet(n, downstream)) + + // Invert deps so that we get a map from each dependency to node identifier + val finalMap = mutable.HashMap + .empty[TableIdentifier, Set[TableIdentifier]] + .withDefaultValue(Set.empty[TableIdentifier]) + deps.foreach { + case (start, reachableNodes) => + reachableNodes.foreach(n => finalMap.put(n, finalMap(n) + start)) + } + finalMap.toMap + } + + /** + * Returns all datasets that can be reached from `destinationIdentifier`. + */ + private def reachabilitySet( + destinationIdentifier: TableIdentifier, + downstream: Boolean): Set[TableIdentifier] = { + if (downstream) downstreamDestinations(destinationIdentifier) + else upstreamDestinations(destinationIdentifier) + } + + /** Returns the set of flows reachable from `flowIdentifier` via output (child) edges. */ + def downstreamFlows(flowIdentifier: TableIdentifier): Set[TableIdentifier] = { + assert(flowNodes.contains(flowIdentifier), s"$flowIdentifier is not a valid start flow") + val downstreamDatasets = reachabilitySet(flowNodes(flowIdentifier).output, downstream = true) + flowNodes.values.filter(_.inputs.exists(downstreamDatasets.contains)).map(_.identifier).toSet + } + + /** Returns the set of flows reachable from `flowIdentifier` via input (parent) edges. 
*/ + def upstreamFlows(flowIdentifier: TableIdentifier): Set[TableIdentifier] = { + assert(flowNodes.contains(flowIdentifier), s"$flowIdentifier is not a valid start flow") + val upstreamDatasets = + flowNodes(flowIdentifier).inputs.flatMap(reachabilitySet(_, downstream = false)) + flowNodes.values.filter(e => upstreamDatasets.contains(e.output)).map(_.identifier).toSet + } + + /** Returns the set of datasets reachable from `datasetIdentifier` via input (parent) edges. */ + def upstreamDatasets(datasetIdentifier: TableIdentifier): Set[TableIdentifier] = + reachabilitySet(datasetIdentifier, downstream = false) - datasetIdentifier + + /** + * Traverses the graph upstream starting from the specified `datasetIdentifiers` to return the + * reachable nodes. The return map's keyset consists of all datasets reachable from + * `datasetIdentifiers`. For each entry in the response map, the value of that element refers + * to which of `datasetIdentifiers` was able to reach the key. If multiple of `datasetIdentifiers` + * could reach that key, one is picked arbitrarily. + */ + def upstreamDatasets( + datasetIdentifiers: Seq[TableIdentifier]): Map[TableIdentifier, Set[TableIdentifier]] = + reachabilitySet(datasetIdentifiers, downstream = false) -- datasetIdentifiers +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala new file mode 100644 index 0000000000000..0e2ba42b15e59 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * A mutable context for registering tables, views, and flows in a dataflow graph. + * + * @param defaultCatalog The pipeline's default catalog. + * @param defaultDatabase The pipeline's default schema. 
+ */ +class GraphRegistrationContext( + val defaultCatalog: String, + val defaultDatabase: String, + val defaultSqlConf: Map[String, String]) { + import GraphRegistrationContext._ + + protected val tables = new mutable.ListBuffer[Table] + protected val views = new mutable.ListBuffer[View] + protected val flows = new mutable.ListBuffer[UnresolvedFlow] + + def registerTable(tableDef: Table): Unit = { + tables += tableDef + } + + def registerView(viewDef: View): Unit = { + views += viewDef + } + + def registerFlow(flowDef: UnresolvedFlow): Unit = { + flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf) + } + + def toDataflowGraph: DataflowGraph = { + val qualifiedTables = tables.toSeq.map { t => + t.copy( + identifier = GraphIdentifierManager + .parseAndQualifyTableIdentifier( + rawTableIdentifier = t.identifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + .identifier + ) + } + + val validatedViews = views.toSeq.collect { + case v: TemporaryView => + v.copy( + identifier = GraphIdentifierManager + .parseAndValidateTemporaryViewIdentifier( + rawViewIdentifier = v.identifier + ) + ) + case v: PersistedView => + v.copy( + identifier = GraphIdentifierManager + .parseAndValidatePersistedViewIdentifier( + rawViewIdentifier = v.identifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + ) + } + + val qualifiedFlows = flows.toSeq.map { f => + val isImplicitFlow = f.identifier == f.destinationIdentifier + val flowWritesToView = + validatedViews + .filter(_.isInstanceOf[TemporaryView]) + .exists(_.identifier == f.destinationIdentifier) + + // If the flow is created implicitly as part of defining a view, then we do not + // qualify the flow identifier and the flow destination. This is because views are + // not permitted to have multipart + if (isImplicitFlow && flowWritesToView) { + f + } else { + f.copy( + identifier = GraphIdentifierManager + .parseAndQualifyFlowIdentifier( + rawFlowIdentifier = f.identifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + .identifier, + destinationIdentifier = GraphIdentifierManager + .parseAndQualifyFlowIdentifier( + rawFlowIdentifier = f.destinationIdentifier, + currentCatalog = Some(defaultCatalog), + currentDatabase = Some(defaultDatabase) + ) + .identifier + ) + } + } + + assertNoDuplicates( + qualifiedTables = qualifiedTables, + validatedViews = validatedViews, + qualifiedFlows = qualifiedFlows + ) + + new DataflowGraph( + tables = qualifiedTables, + views = validatedViews, + flows = qualifiedFlows + ) + } + + private def assertNoDuplicates( + qualifiedTables: Seq[Table], + validatedViews: Seq[View], + qualifiedFlows: Seq[UnresolvedFlow]): Unit = { + + (qualifiedTables.map(_.identifier) ++ validatedViews.map(_.identifier)) + .foreach { identifier => + assertDatasetIdentifierIsUnique( + identifier = identifier, + tables = qualifiedTables, + views = validatedViews + ) + } + + qualifiedFlows.foreach { flow => + assertFlowIdentifierIsUnique( + flow = flow, + datasetType = TableType, + flows = qualifiedFlows + ) + } + } + + private def assertDatasetIdentifierIsUnique( + identifier: TableIdentifier, + tables: Seq[Table], + views: Seq[View]): Unit = { + + // We need to check for duplicates in both tables and views, as they can have the same name. 
+ val allDatasets = tables.map(t => t.identifier -> TableType) ++ views.map( + v => v.identifier -> ViewType + ) + + val grouped = allDatasets.groupBy { case (id, _) => id } + + grouped(identifier).toList match { + case (_, firstType) :: (_, secondType) :: _ => + // Sort the types in lexicographic order to ensure consistent error messages. + val sortedTypes = Seq(firstType.toString, secondType.toString).sorted + throw new AnalysisException( + errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET", + messageParameters = Map( + "datasetName" -> identifier.quotedString, + "datasetType1" -> sortedTypes.head, + "datasetType2" -> sortedTypes.last + ) + ) + case _ => // No duplicates found. + } + } + + private def assertFlowIdentifierIsUnique( + flow: UnresolvedFlow, + datasetType: DatasetType, + flows: Seq[UnresolvedFlow]): Unit = { + flows.groupBy(i => i.identifier).get(flow.identifier).filter(_.size > 1).foreach { + duplicateFlows => + val duplicateFlow = duplicateFlows.filter(_ != flow).head + throw new AnalysisException( + errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW", + messageParameters = Map( + "flowName" -> flow.identifier.unquotedString, + "datasetNames" -> Set( + flow.destinationIdentifier.quotedString, + duplicateFlow.destinationIdentifier.quotedString + ).mkString(",") + ) + ) + } + } +} + +object GraphRegistrationContext { + sealed trait DatasetType + + private object TableType extends DatasetType { + override def toString: String = "TABLE" + } + + private object ViewType extends DatasetType { + override def toString: String = "VIEW" + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala new file mode 100644 index 0000000000000..99142432f9cec --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique +import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils + +/** Validations performed on a [[DataflowGraph]]. */ +trait GraphValidations extends Logging { + this: DataflowGraph => + + /** + * Validate multi query table correctness. + */ + protected[pipelines] def validateMultiQueryTables(): Map[TableIdentifier, Seq[Flow]] = { + val multiQueryTables = flowsTo.filter(_._2.size > 1) + // Non-streaming tables do not support multiflow. 
+ multiQueryTables + .find { + case (dest, flows) => + flows.exists(f => !resolvedFlow(f.identifier).df.isStreaming) && + table.contains(dest) + } + .foreach { + case (dest, flows) => + throw new AnalysisException( + "MATERIALIZED_VIEW_WITH_MULTIPLE_QUERIES", + Map( + "tableName" -> dest.unquotedString, + "queries" -> flows.map(_.identifier).mkString(",") + ) + ) + } + + multiQueryTables + } + + /** Throws an exception if the flows in this graph are not topologically sorted. */ + protected[graph] def validateGraphIsTopologicallySorted(): Unit = { + val visitedNodes = mutable.Set.empty[TableIdentifier] // Set of visited nodes + val visitedEdges = mutable.Set.empty[TableIdentifier] // Set of visited edges + flows.foreach { f => + // Unvisited inputs of the current flow + val unvisitedInputNodes = + resolvedFlow(f.identifier).inputs -- visitedNodes + unvisitedInputNodes.headOption match { + case None => + visitedEdges.add(f.identifier) + if (flowsTo(f.destinationIdentifier).map(_.identifier).forall(visitedEdges.contains)) { + // A node is marked visited if all its inputs are visited + visitedNodes.add(f.destinationIdentifier) + } + case Some(unvisitedInput) => + throw new AnalysisException( + "PIPELINE_GRAPH_NOT_TOPOLOGICALLY_SORTED", + Map( + "flowName" -> f.identifier.unquotedString, + "inputName" -> unvisitedInput.unquotedString + ) + ) + } + } + } + + /** + * Validate that all tables are resettable. This is a best-effort check that will only catch + * upstream tables that are resettable but have a non-resettable downstream dependency. + */ + protected def validateTablesAreResettable(): Unit = { + validateTablesAreResettable(tables) + } + + /** Validate that all specified tables are resettable. */ + protected def validateTablesAreResettable(tables: Seq[Table]): Unit = { + val tableLookup = mapUnique(tables, "table")(_.identifier) + val nonResettableTables = + tables.filter(t => !PipelinesTableProperties.resetAllowed.fromMap(t.properties)) + val upstreamResettableTables = upstreamDatasets(nonResettableTables.map(_.identifier)) + .collect { + // Filter for upstream datasets that are tables with downstream streaming tables + case (upstreamDataset, nonResettableDownstreams) if table.contains(upstreamDataset) => + nonResettableDownstreams + .filter( + t => flowsTo(t).exists(f => resolvedFlow(f.identifier).df.isStreaming) + ) + .map(id => (tableLookup(upstreamDataset), tableLookup(id).displayName)) + } + .flatten + .toSeq + .filter { + case (t, _) => PipelinesTableProperties.resetAllowed.fromMap(t.properties) + } // Filter for resettable + + upstreamResettableTables + .groupBy(_._2) // Group-by non-resettable downstream tables + .view + .mapValues(_.map(_._1)) + .toSeq + .sortBy(_._2.size) // Output errors from largest to smallest + .reverse + .map { + case (nameForEvent, tables) => + throw new AnalysisException( + "INVALID_RESETTABLE_DEPENDENCY", + Map( + "downstreamTable" -> nameForEvent, + "upstreamResettableTables" -> tables + .map(_.displayName) + .sorted + .map(t => s"'$t'") + .mkString(", "), + "resetAllowedKey" -> PipelinesTableProperties.resetAllowed.key + ) + ) + } + } + + protected def validateUserSpecifiedSchemas(): Unit = { + flows.flatMap(f => table.get(f.identifier)).foreach { t: TableInput => + // The output inferred schema of a table is the declared schema merged with the + // schema of all incoming flows. This must be equivalent to the declared schema. 
+      val inferredSchema = SchemaInferenceUtils
+        .inferSchemaFromFlows(
+          flowsTo(t.identifier).map(f => resolvedFlow(f.identifier)),
+          userSpecifiedSchema = t.specifiedSchema
+        )
+
+      t.specifiedSchema.foreach { ss =>
+        // Check the inferred schema matches the specified schema. Used to catch errors where the
+        // inferred user-facing schema has columns that are not in the specified one.
+        if (inferredSchema != ss) {
+          val datasetType = GraphElementTypeUtils
+            .getDatasetTypeForMaterializedViewOrStreamingTable(
+              flowsTo(t.identifier).map(f => resolvedFlow(f.identifier))
+            )
+          throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError(
+            t.identifier,
+            datasetType,
+            ss,
+            inferredSchema
+          )
+        }
+      }
+    }
+  }
+
+  /**
+   * Validates that all flows are resolved. If there are unresolved flows,
+   * detects a possible cyclic dependency and throws the appropriate exception.
+   */
+  protected def validateSuccessfulFlowAnalysis(): Unit = {
+    // All failed flows with their errors.
+    val flowAnalysisFailures = resolutionFailedFlows.flatMap(
+      f => f.failure.headOption.map(err => (f.identifier, err))
+    )
+    // Only proceed if there are unresolved flows.
+    if (flowAnalysisFailures.nonEmpty) {
+      val failedFlowIdentifiers = flowAnalysisFailures.map(_._1).toSet
+      // Used to collect the subgraph of only the unresolved flows; maps every unresolved flow
+      // to the set of unresolved flows writing to one of its inputs.
+      val failedFlowsSubgraph = mutable.Map[TableIdentifier, Seq[TableIdentifier]]()
+      val (downstreamFailures, directFailures) = flowAnalysisFailures.partition {
+        case (flowIdentifier, _) =>
+          // If a failed flow writes to any of the requested datasets, we mark this flow as a
+          // downstream failure.
+          val failedFlowsWritingToRequestedDatasets =
+            resolutionFailedFlow(flowIdentifier).funcResult.requestedInputs
+              .flatMap(d => flowsTo.getOrElse(d, Seq()))
+              .map(_.identifier)
+              .intersect(failedFlowIdentifiers)
+              .toSeq
+          failedFlowsSubgraph += (flowIdentifier -> failedFlowsWritingToRequestedDatasets)
+          failedFlowsWritingToRequestedDatasets.nonEmpty
+      }
+      // If there are flows that failed due to unresolved upstream flows, check for a cycle.
+      if (failedFlowsSubgraph.nonEmpty) {
+        detectCycle(failedFlowsSubgraph.toMap).foreach {
+          case (upstream, downstream) =>
+            val upstreamDataset = flow(upstream).destinationIdentifier
+            val downstreamDataset = flow(downstream).destinationIdentifier
+            throw CircularDependencyException(
+              downstreamDataset,
+              upstreamDataset
+            )
+        }
+      }
+      // Otherwise report which flows failed directly vs. due to depending on a failed flow.
+      throw UnresolvedPipelineException(
+        this,
+        directFailures.map { case (id, value) => (id, value) }.toMap,
+        downstreamFailures.map { case (id, value) => (id, value) }.toMap
+      )
+    }
+  }
+
+  /**
+   * Generic method to detect a cycle in a directed graph via DFS traversal.
+   * The graph is given as a reverse adjacency map, that is, a map from
+   * each node to its ancestors.
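For a concrete picture of the reverse-adjacency input this traversal works on, a sketch with hypothetical flow identifiers; `flow_a` and `flow_b` each depend on the other, so the cycle between them would be detected and reported as a `CircularDependencyException`:

    import org.apache.spark.sql.catalyst.TableIdentifier

    val a = TableIdentifier("flow_a")
    val b = TableIdentifier("flow_b")
    val c = TableIdentifier("flow_c")

    // Each unresolved flow mapped to the unresolved flows that feed its inputs.
    val ancestors: Map[TableIdentifier, Seq[TableIdentifier]] = Map(
      a -> Seq(b),
      b -> Seq(a),
      c -> Seq.empty
    )
    // A DFS over this map finds the a <-> b loop and returns the pair of flows forming it.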
+ * @return the start and end node of a cycle if found, None otherwise + */ + private def detectCycle(ancestors: Map[TableIdentifier, Seq[TableIdentifier]]) + : Option[(TableIdentifier, TableIdentifier)] = { + var cycle: Option[(TableIdentifier, TableIdentifier)] = None + val visited = mutable.Set[TableIdentifier]() + def visit(f: TableIdentifier, currentPath: List[TableIdentifier]): Unit = { + if (cycle.isEmpty && !visited.contains(f)) { + if (currentPath.contains(f)) { + cycle = Option((currentPath.head, f)) + } else { + ancestors(f).foreach(visit(_, f :: currentPath)) + visited += f + } + } + } + ancestors.keys.foreach(visit(_, Nil)) + cycle + } + + /** Validates that persisted views don't read from invalid sources */ + protected[graph] def validatePersistedViewSources(): Unit = { + val viewToFlowMap = ViewHelpers.persistedViewIdentifierToFlow(graph = this) + + persistedViews + .foreach { persistedView => + val flow = viewToFlowMap(persistedView.identifier) + val funcResult = resolvedFlow(flow.identifier).funcResult + val inputIdentifiers = (funcResult.batchInputs ++ funcResult.streamingInputs) + .map(_.input.identifier) + + inputIdentifiers + .flatMap(view.get) + .foreach { + case tempView: TemporaryView => + throw new AnalysisException( + errorClass = "INVALID_TEMP_OBJ_REFERENCE", + messageParameters = Map( + "persistedViewName" -> persistedView.identifier.toString, + "temporaryViewName" -> tempView.identifier.toString + ), + cause = null + ) + case _ => + } + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala new file mode 100644 index 0000000000000..4bed25f2aa1c7 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * Exception raised when a flow tries to read from a dataset that exists but is unresolved + * + * @param identifier The identifier of the dataset + */ +case class UnresolvedDatasetException(identifier: TableIdentifier) + extends AnalysisException( + s"Failed to read dataset '${identifier.unquotedString}'. Dataset is defined in the " + + s"pipeline but could not be resolved." 
+ ) + +/** + * Exception raised when a flow fails to read from a table defined within the pipeline + * + * @param name The name of the table + * @param cause The cause of the failure + */ +case class LoadTableException(name: String, cause: Option[Throwable]) + extends SparkException( + errorClass = "INTERNAL_ERROR", + messageParameters = Map("message" -> s"Failed to load table '$name'"), + cause = cause.orNull + ) + +/** + * Exception raised when a pipeline has one or more flows that cannot be resolved + * + * @param directFailures Mapping between the name of flows that failed to resolve (due to an + * error in that flow) and the error that occurred when attempting to + * resolve them + * @param downstreamFailures Mapping between the name of flows that failed to resolve (because they + * failed to read from other unresolved flows) and the error that occurred + * when attempting to resolve them + */ +case class UnresolvedPipelineException( + graph: DataflowGraph, + directFailures: Map[TableIdentifier, Throwable], + downstreamFailures: Map[TableIdentifier, Throwable], + additionalHint: Option[String] = None) + extends AnalysisException( + s""" + |Failed to resolve flows in the pipeline. + | + |A flow can fail to resolve because the flow itself contains errors or because it reads + |from an upstream flow which failed to resolve. + |${additionalHint.getOrElse("")} + |Flows with errors: ${directFailures.keys.map(_.unquotedString).toSeq.sorted.mkString(", ")} + |Flows that failed due to upstream errors: ${downstreamFailures.keys + .map(_.unquotedString) + .toSeq + .sorted + .mkString(", ")} + | + |To view the exceptions that were raised while resolving these flows, look for flow + |failures that precede this log.""".stripMargin + ) + +/** + * Raised when there's a circular dependency in the current pipeline. That is, a downstream + * table is referenced while creating a upstream table. + */ +case class CircularDependencyException( + downstreamTable: TableIdentifier, + upstreamDataset: TableIdentifier) + extends AnalysisException( + s"The downstream table '${downstreamTable.unquotedString}' is referenced when " + + s"creating the upstream table or view '${upstreamDataset.unquotedString}'. " + + s"Circular dependencies are not supported in a pipeline. Please remove the dependency " + + s"between '${upstreamDataset.unquotedString}' and '${downstreamTable.unquotedString}'." + ) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala new file mode 100644 index 0000000000000..c8627a9a0a2e4 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesTableProperties.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.Locale + +import scala.collection.mutable +import scala.util.control.NonFatal + +/** + * Interface for validating and accessing Pipeline-specific table properties. + */ +object PipelinesTableProperties { + + /** Prefix used for all table properties. */ + val pipelinesPrefix = "pipelines." + + /** Map of keys to property entries. */ + private val entries: mutable.HashMap[String, PipelineTableProperty[_]] = mutable.HashMap.empty + + /** Whether the table should be reset when a Reset is triggered. */ + val resetAllowed: PipelineTableProperty[Boolean] = + buildProp("pipelines.reset.allowed", default = true) + + /** + * Validates that all known pipeline properties are valid and can be parsed. + * Also warn the user if they try to set an unknown pipelines property, or any + * property that looks suspiciously similar to a known pipeline property. + * + * @param rawProps Raw table properties to validate and canonicalize + * @param warnFunction Function to warn users of potential issues. Fatal errors are still thrown. + * @return a map of table properties, with canonical-case keys for valid pipelines properties, + * and excluding invalid pipelines properties. + */ + def validateAndCanonicalize( + rawProps: Map[String, String], + warnFunction: String => Unit): Map[String, String] = { + rawProps.flatMap { + case (originalCaseKey, value) => + val k = originalCaseKey.toLowerCase(Locale.ROOT) + + // Make sure all pipelines properties are valid + if (k.startsWith(pipelinesPrefix)) { + entries.get(k) match { + case Some(prop) => + prop.fromMap(rawProps) // Make sure the value can be retrieved + Option(prop.key -> value) // Canonicalize the case + case None => + warnFunction(s"Unknown pipelines table property '$originalCaseKey'.") + // exclude this property - the pipeline won't recognize it anyway, + // and setting it would allow users to use `pipelines.xyz` for their custom + // purposes - which we don't want to encourage. + None + } + } else { + // Make sure they weren't trying to set a pipelines property + val similarProp = entries.get(s"${pipelinesPrefix}$k") + similarProp.foreach { c => + warnFunction( + s"You are trying to set table property '$originalCaseKey', which has a similar name" + + s" to Pipelines table property '${c.key}'. If you are trying to set the Pipelines " + + s"table property, please include the correct prefix." + ) + } + Option(originalCaseKey -> value) + } + } + } + + /** Registers a pipelines table property with the specified key and default value. 
*/ + private def buildProp[T]( + key: String, + default: String, + fromString: String => T): PipelineTableProperty[T] = { + val prop = PipelineTableProperty(key, default, fromString) + entries.put(key.toLowerCase(Locale.ROOT), prop) + prop + } + + private def buildProp(key: String, default: Boolean): PipelineTableProperty[Boolean] = + buildProp(key, default.toString, _.toBoolean) +} + +case class PipelineTableProperty[T](key: String, default: String, fromString: String => T) { + def fromMap(rawProps: Map[String, String]): T = parseFromString(rawProps.getOrElse(key, default)) + + private def parseFromString(value: String): T = { + try { + fromString(value) + } catch { + case NonFatal(_) => + throw new IllegalArgumentException( + s"Could not parse value '$value' for table property '$key'" + ) + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala new file mode 100644 index 0000000000000..042b4d9626fd6 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.util.control.{NonFatal, NoStackTrace} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.pipelines.Language +import org.apache.spark.sql.pipelines.logging.SourceCodeLocation + +/** + * Records information used to track the provenance of a given query to user code. + * + * @param language The language used by the user to define the query. + * @param fileName The file name of the user code that defines the query. + * @param sqlText The SQL text of the query. + * @param line The line number of the query in the user code. + * Line numbers are 1-indexed. + * @param startPosition The start position of the query in the user code. + * @param objectType The type of the object that the query is associated with. (Table, View, etc) + * @param objectName The name of the object that the query is associated with. + */ +case class QueryOrigin( + language: Option[Language] = None, + fileName: Option[String] = None, + sqlText: Option[String] = None, + line: Option[Int] = None, + startPosition: Option[Int] = None, + objectType: Option[String] = None, + objectName: Option[String] = None +) { + + /** + * Merges this origin with another one. + * + * The result has fields set to the value in the other origin if it is defined, or if not, then + * the value in this origin. 
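To illustrate how the property helpers defined above behave, a short sketch with made-up raw properties: unknown `pipelines.*` keys are warned about and dropped, valid keys are canonicalized, and `resetAllowed` falls back to its default when unset:

    val rawProps = Map(
      "PIPELINES.RESET.ALLOWED" -> "false", // canonicalized to pipelines.reset.allowed
      "pipelines.not.a.real.key" -> "x",    // unknown pipelines.* key, warned about and dropped
      "delta.appendOnly" -> "true"          // non-pipelines properties pass through unchanged
    )

    val canonical = PipelinesTableProperties.validateAndCanonicalize(
      rawProps,
      warnFunction = msg => println(s"WARN: $msg")
    )

    // Typed access; the declared default (true) is used only when the key is absent.
    val resettable = PipelinesTableProperties.resetAllowed.fromMap(canonical)
    assert(!resettable)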
+ */ + def merge(other: QueryOrigin): QueryOrigin = { + QueryOrigin( + language = other.language.orElse(language), + fileName = other.fileName.orElse(fileName), + sqlText = other.sqlText.orElse(sqlText), + line = other.line.orElse(line), + startPosition = other.startPosition.orElse(startPosition), + objectType = other.objectType.orElse(objectType), + objectName = other.objectName.orElse(objectName) + ) + } + + /** + * Merge values from the catalyst origin. + * + * The result has fields set to the value in the other origin if it is defined, or if not, then + * the value in this origin. + */ + def merge(other: Origin): QueryOrigin = { + merge( + QueryOrigin( + sqlText = other.sqlText, + line = other.line, + startPosition = other.startPosition + ) + ) + } + + /** Generates a SourceCodeLocation using the details present in the query origin. */ + def toSourceCodeLocation: SourceCodeLocation = SourceCodeLocation( + path = fileName, + // QueryOrigin tracks line numbers using a 1-indexed numbering scheme whereas SourceCodeLocation + // tracks them using a 0-indexed numbering scheme. + lineNumber = line.map(_ - 1), + columnNumber = startPosition, + endingLineNumber = None, + endingColumnNumber = None + ) +} + +object QueryOrigin extends Logging { + + /** An empty QueryOrigin without any provenance information. */ + val empty: QueryOrigin = QueryOrigin() + + /** + * An exception that wraps [[QueryOrigin]] and lets us store it in errors as suppressed + * exceptions. + */ + private case class QueryOriginWrapper(origin: QueryOrigin) extends Exception with NoStackTrace + + implicit class ExceptionHelpers(t: Throwable) { + + /** + * Stores `origin` inside the given throwable using suppressed exceptions. + * + * We rely on suppressed exceptions since that lets us preserve the original exception class + * and type. + */ + def addOrigin(origin: QueryOrigin): Throwable = { + // Only try to add the query context if one has not already been added. + try { + // Do not add the origin again if one is already present. + // This also handles the case where the throwable is `null`. + if (getOrigin(t).isEmpty) { + t.addSuppressed(QueryOriginWrapper(origin)) + } + } catch { + case NonFatal(e) => logError("Failed to add pipeline context", e) + } + t + } + } + + /** Returns the [[QueryOrigin]] stored as a suppressed exception in the given throwable. + * + * @return Some(origin) if the origin is recorded as part of the given throwable, `None` + * otherwise. + */ + def getOrigin(t: Throwable): Option[QueryOrigin] = { + try { + // Wrap in an `Option(_)` first to handle `null` throwables. + Option(t).flatMap { ex => + ex.getSuppressed.collectFirst { + case QueryOriginWrapper(context) => context + } + } + } catch { + case NonFatal(e) => + logError("Failed to get pipeline context", e) + None + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala new file mode 100644 index 0000000000000..9f05219c383dd --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/ViewHelpers.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.catalyst.TableIdentifier + +object ViewHelpers { + + /** Map of view identifier to corresponding unresolved flow */ + def persistedViewIdentifierToFlow(graph: DataflowGraph): Map[TableIdentifier, Flow] = { + graph.persistedViews.map { v => + require( + graph.flowsTo.get(v.identifier).isDefined, + s"No flows to view ${v.identifier} were found" + ) + val flowsToView = graph.flowsTo(v.identifier) + require( + flowsToView.size == 1, + s"Expected a single flow to the view, found ${flowsToView.size} flows to ${v.identifier}" + ) + (v.identifier, flowsToView.head) + }.toMap + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala new file mode 100644 index 0000000000000..770776b29cf08 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util + +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.classic.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.common.DatasetType +import org.apache.spark.sql.pipelines.util.{ + BatchReadOptions, + InputReadOptions, + SchemaInferenceUtils, + StreamingReadOptions +} +import org.apache.spark.sql.types.StructType + +/** An element in a [[DataflowGraph]]. */ +trait GraphElement { + + /** + * Contains provenance to tie back this GraphElement to the user code that defined it. + * + * This must be set when a [[GraphElement]] is directly created by some user code. + * Subsequently, this initial origin must be propagated as is without modification. 
+ * If this [[GraphElement]] is copied or converted to a different type, then this origin must be + * copied as is. + */ + def origin: QueryOrigin + + protected def spark: SparkSession = SparkSession.getActiveSession.get + + /** Returns the unique identifier for this [[GraphElement]]. */ + def identifier: TableIdentifier + + /** + * Returns a user-visible name for the element. + */ + def displayName: String = identifier.unquotedString +} + +/** + * Specifies an input that can be referenced by another Dataset's query. + */ +trait Input extends GraphElement { + + /** + * Returns a DataFrame that is a result of loading data from this [[Input]]. + * @param readOptions Type of input. Used to determine streaming/batch + * @return Streaming or batch DataFrame of this Input's data. + */ + def load(readOptions: InputReadOptions): DataFrame +} + +/** + * Represents a node in a [[DataflowGraph]] that can be written to by a [[Flow]]. + * Must be backed by a file source. + */ +sealed trait Output { + + /** + * Normalized storage location used for storing materializations for this [[Output]]. + * If None, it means this [[Output]] has not been normalized yet. + */ + def normalizedPath: Option[String] + + /** Return whether the storage location for this [[Output]] has been normalized. */ + final def normalized: Boolean = normalizedPath.isDefined + + /** + * Return the normalized storage location for this [[Output]] and throw if the + * storage location has not been normalized. + */ + @throws[SparkException] + def path: String +} + +/** A type of [[Input]] where data is loaded from a table. */ +sealed trait TableInput extends Input { + + /** The user-specified schema for this table. */ + def specifiedSchema: Option[StructType] +} + +/** + * A table representing a materialized dataset in a [[DataflowGraph]]. + * + * @param identifier The identifier of this table within the graph. + * @param specifiedSchema The user-specified schema for this table. + * @param partitionCols What columns the table should be partitioned by when materialized. + * @param normalizedPath Normalized storage location for the table based on the user-specified table + * path (if not defined, we will normalize a managed storage path for it). + * @param properties Table Properties to set in table metadata. + * @param comment User-specified comment that can be placed on the table. + * @param isStreamingTableOpt if the table is a streaming table, will be None until we have resolved + * flows into table + */ +case class Table( + identifier: TableIdentifier, + specifiedSchema: Option[StructType], + partitionCols: Option[Seq[String]], + normalizedPath: Option[String], + properties: Map[String, String] = Map.empty, + comment: Option[String], + baseOrigin: QueryOrigin, + isStreamingTableOpt: Option[Boolean], + format: Option[String] +) extends TableInput + with Output { + + override val origin: QueryOrigin = baseOrigin.copy( + objectType = Some("table"), + objectName = Some(identifier.unquotedString) + ) + + // Load this table's data from underlying storage. 
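+ // Streaming reads go through spark.readStream (propagating the user-supplied reader options),
+ // while batch reads go through spark.read; the InputReadOptions subtype decides which path is taken.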
+ override def load(readOptions: InputReadOptions): DataFrame = { + try { + lazy val tableName = identifier.quotedString + + val df = readOptions match { + case sro: StreamingReadOptions => + spark.readStream.options(sro.userOptions).table(tableName) + case _: BatchReadOptions => + spark.read.table(tableName) + case _ => + throw new IllegalArgumentException("Unhandled `InputReadOptions` type when loading table") + } + + df + } catch { + case NonFatal(e) => throw LoadTableException(displayName, Option(e)) + } + } + + /** Returns the normalized storage location of this [[Table]]. */ + override def path: String = { + if (!normalized) { + throw GraphErrors.unresolvedTablePath(identifier) + } + normalizedPath.get + } + + /** + * Returns whether this table is a streaming table. This property is not set until the flows + * writing into the table have been resolved; calling it before resolution throws. + */ + def isStreamingTable: Boolean = isStreamingTableOpt.getOrElse { + throw new IllegalStateException( + "Cannot identify whether the table is streaming table or not. You may need to resolve the " + + "flows into table." + ) + } + + /** + * Returns the [[DatasetType]] of this table. + */ + def datasetType: DatasetType = { + if (isStreamingTable) { + DatasetType.STREAMING_TABLE + } else { + DatasetType.MATERIALIZED_VIEW + } + } +}
+ +/** + * A type of [[TableInput]] whose schema is either specified by the user or inferred from the + * [[Flow]]s that write to the table. + */ +case class VirtualTableInput( + identifier: TableIdentifier, + specifiedSchema: Option[StructType], + incomingFlowIdentifiers: Set[TableIdentifier], + availableFlows: Seq[ResolvedFlow] = Nil +) extends TableInput + with Logging { + override def origin: QueryOrigin = QueryOrigin() + + assert(availableFlows.forall(_.destinationIdentifier == identifier)) + override def load(readOptions: InputReadOptions): DataFrame = { + // Infer the schema for this virtual table. + def getFinalSchema: StructType = { + specifiedSchema match { + // This is not a backing table, and we have a user-specified schema, so use it directly. + case Some(ss) => ss + // Otherwise infer the schema from a combination of the incoming flows and the + // user-specified schema, if provided. + case _ => + SchemaInferenceUtils.inferSchemaFromFlows(availableFlows, specifiedSchema) + } + } + + // Create an empty streaming or batch DataFrame based on the read type. + def createEmptyDF(schema: StructType): DataFrame = readOptions match { + case _: StreamingReadOptions => + MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark.sqlContext) + .toDF() + case _ => spark.createDataFrame(new util.ArrayList[Row](), schema) + } + + val df = createEmptyDF(getFinalSchema) + df + } +}
+ +/** + * Represents a view in the [[DataflowGraph]]. + */ +trait View extends GraphElement { + + /** Returns the unique identifier for this [[View]]. */ + val identifier: TableIdentifier + + /** Properties of this view. */ + val properties: Map[String, String] + + /** User-specified comment that can be placed on the [[View]]. */ + val comment: Option[String] +}
+ +/** + * Represents a temporary [[View]] in a [[DataflowGraph]]. + * + * @param identifier The identifier of this view within the graph. + * @param properties Properties of the view + * @param comment Comment provided when defining the view + */ +case class TemporaryView( + identifier: TableIdentifier, + properties: Map[String, String], + comment: Option[String], + origin: QueryOrigin +) extends View {} + +/** + * Represents a persisted [[View]] in a [[DataflowGraph]]. + * + * @param identifier The identifier of this view within the graph. + * @param properties Properties of the view + * @param comment Comment provided when defining the view + */ +case class PersistedView( + identifier: TableIdentifier, + properties: Map[String, String], + comment: Option[String], + origin: QueryOrigin +) extends View {}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala new file mode 100644 index 0000000000000..070927aea295f --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions + +/** + * Generic options for a read of an input. + */ +sealed trait InputReadOptions + +/** + * Options for a batch read of an input. + */ +final case class BatchReadOptions() extends InputReadOptions + +/** + * Options for a streaming read of an input. + * + * @param userOptions Holds the user defined read options. + * @param droppedUserOptions Holds the options that were specified by the user but + * not actually used. This is a bug but we are preserving this behavior + * for now to avoid making a backwards incompatible change. + */ +final case class StreamingReadOptions( + userOptions: CaseInsensitiveMap[String] = EmptyUserOptions, + droppedUserOptions: CaseInsensitiveMap[String] = EmptyUserOptions +) extends InputReadOptions + +object StreamingReadOptions { + val EmptyUserOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map()) +}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala new file mode 100644 index 0000000000000..4777772342d7d --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtils.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.pipelines.common.DatasetType +import org.apache.spark.sql.pipelines.graph.{GraphElementTypeUtils, GraphErrors, ResolvedFlow} +import org.apache.spark.sql.types.{StructField, StructType} + + +object SchemaInferenceUtils { + + /** + * Given a set of flows that write to the same destination and possibly a user-specified schema, + * we infer the schema of the destination dataset. The logic is as follows: + * 1. If there are no incoming flows, return the user-specified schema (if provided) + * or an empty schema. + * 2. If there are incoming flows, we merge the schemas of all flows that write to + * the same destination. + * 3. If a user-specified schema is provided, we merge it with the inferred schema. + * The user-specified schema will take precedence over the inferred schema. + * Returns an error if encountered during schema inference or merging the inferred schema with + * the user-specified one. + */ + def inferSchemaFromFlows( + flows: Seq[ResolvedFlow], + userSpecifiedSchema: Option[StructType]): StructType = { + if (flows.isEmpty) { + return userSpecifiedSchema.getOrElse(new StructType()) + } + + require( + flows.forall(_.destinationIdentifier == flows.head.destinationIdentifier), + "Expected all flows to have the same destination" + ) + + val inferredSchema = flows.map(_.schema).fold(new StructType()) { (schemaSoFar, schema) => + try { + SchemaMergingUtils.mergeSchemas(schemaSoFar, schema) + } catch { + case NonFatal(e) => + throw GraphErrors.unableToInferSchemaError( + flows.head.destinationIdentifier, + schemaSoFar, + schema, + cause = Option(e) + ) + } + } + + val identifier = flows.head.destinationIdentifier + val datasetType = GraphElementTypeUtils.getDatasetTypeForMaterializedViewOrStreamingTable(flows) + // We merge the inferred schema with the user-specified schema to pick up any schema metadata + // that is provided by the user, e.g., comments or column masks. 
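+ // For example, a comment declared by the user on a column is carried into the final schema, while
+ // a type conflict between the declared and inferred schemas surfaces below as an incompatible
+ // user-specified/inferred schema error.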
+ mergeInferredAndUserSchemasIfNeeded( + identifier, + datasetType, + inferredSchema, + userSpecifiedSchema + ) + } + + private def mergeInferredAndUserSchemasIfNeeded( + tableIdentifier: TableIdentifier, + datasetType: DatasetType, + inferredSchema: StructType, + userSpecifiedSchema: Option[StructType]): StructType = { + userSpecifiedSchema match { + case Some(userSpecifiedSchema) => + try { + // Merge the inferred schema with the user-provided schema hint + SchemaMergingUtils.mergeSchemas(userSpecifiedSchema, inferredSchema) + } catch { + case NonFatal(e) => + throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError( + tableIdentifier, + datasetType, + userSpecifiedSchema, + inferredSchema, + cause = Option(e) + ) + } + case None => inferredSchema + } + } + + /** + * Determines the column changes needed to transform the current schema into the target schema. + * + * This function compares the current schema with the target schema and produces a sequence of + * TableChange objects representing: + * 1. New columns that need to be added + * 2. Existing columns that need type updates + * + * @param currentSchema The current schema of the table + * @param targetSchema The target schema that we want the table to have + * @return A sequence of TableChange objects representing the necessary changes + */ + def diffSchemas(currentSchema: StructType, targetSchema: StructType): Seq[TableChange] = { + val changes = scala.collection.mutable.ArrayBuffer.empty[TableChange] + + // Helper function to get a map of field name to field + def getFieldMap(schema: StructType): Map[String, StructField] = { + schema.fields.map(field => field.name -> field).toMap + } + + val currentFields = getFieldMap(currentSchema) + val targetFields = getFieldMap(targetSchema) + + // Find columns to add (in target but not in current) + val columnsToAdd = targetFields.keySet.diff(currentFields.keySet) + columnsToAdd.foreach { columnName => + val field = targetFields(columnName) + changes += TableChange.addColumn( + Array(columnName), + field.dataType, + field.nullable, + field.getComment().orNull + ) + } + + // Find columns to delete (in current but not in target) + val columnsToDelete = currentFields.keySet.diff(targetFields.keySet) + columnsToDelete.foreach { columnName => + changes += TableChange.deleteColumn(Array(columnName), false) + } + + // Find columns with type changes (in both but with different types) + val commonColumns = currentFields.keySet.intersect(targetFields.keySet) + commonColumns.foreach { columnName => + val currentField = currentFields(columnName) + val targetField = targetFields(columnName) + + // If data types are different, add a type update change + if (currentField.dataType != targetField.dataType) { + changes += TableChange.updateColumnType(Array(columnName), targetField.dataType) + } + + // If nullability is different, add a nullability update change + if (currentField.nullable != targetField.nullable) { + changes += TableChange.updateColumnNullability(Array(columnName), targetField.nullable) + } + + // If comments are different, add a comment update change + val currentComment = currentField.getComment().orNull + val targetComment = targetField.getComment().orNull + if (currentComment != targetComment) { + changes += TableChange.updateColumnComment(Array(columnName), targetComment) + } + } + + changes.toSeq + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala new file mode 100644 index 0000000000000..d15e7ac6425cc --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SchemaMergingUtils.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import org.apache.spark.sql.types.StructType + +object SchemaMergingUtils { + def mergeSchemas(tableSchema: StructType, dataSchema: StructType): StructType = { + StructType.merge(tableSchema, dataSchema).asInstanceOf[StructType] + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala new file mode 100644 index 0000000000000..01b2a91bb9329 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} +import org.apache.spark.sql.types.{IntegerType, StructType} + +/** + * Test suite for resolving the flows in a [[DataflowGraph]]. These + * examples are all semantically correct but contain logical errors which should be found + * when connect is called and thrown when validate() is called. 
+ */ +class ConnectInvalidPipelineSuite extends PipelineTest { + + import originalSpark.implicits._ + test("Missing source") { + class P extends TestGraphRegistrationContext(spark) { + registerView("b", query = readFlowFunc("a")) + } + + val dfg = new P().resolveToDataflowGraph() + assert(!dfg.resolved, "Pipeline should not have resolved properly") + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert(ex.getMessage.contains("Failed to resolve flows in the pipeline")) + assertAnalysisException( + ex.directFailures(fullyQualifiedIdentifier("b", isView = true)), + "TABLE_OR_VIEW_NOT_FOUND" + ) + } + + test("Correctly differentiate between upstream and downstream errors") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(spark.range(5).toDF())) + registerView("b", query = readFlowFunc("nonExistentFlow")) + registerView("c", query = readFlowFunc("b")) + registerView("d", query = dfFlowFunc(spark.range(5).toDF())) + registerView("e", query = sqlFlowFunc(spark, "SELECT nonExistentColumn FROM RANGE(5)")) + registerView("f", query = readFlowFunc("e")) + } + + val dfg = new P().resolveToDataflowGraph() + assert(!dfg.resolved, "Pipeline should not have resolved properly") + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert(ex.getMessage.contains("Failed to resolve flows in the pipeline")) + assert( + ex.getMessage.contains( + s"Flows with errors: " + + s"${fullyQualifiedIdentifier("b", isView = true).unquotedString}," + + s" ${fullyQualifiedIdentifier("e", isView = true).unquotedString}" + ) + ) + assert( + ex.getMessage.contains( + s"Flows that failed due to upstream errors: " + + s"${fullyQualifiedIdentifier("c", isView = true).unquotedString}, " + + s"${fullyQualifiedIdentifier("f", isView = true).unquotedString}" + ) + ) + assert( + ex.directFailures.keySet == Set( + fullyQualifiedIdentifier("b", isView = true), + fullyQualifiedIdentifier("e", isView = true) + ) + ) + assert( + ex.downstreamFailures.keySet == Set( + fullyQualifiedIdentifier("c", isView = true), + fullyQualifiedIdentifier("f", isView = true) + ) + ) + assertAnalysisException( + ex.directFailures(fullyQualifiedIdentifier("b", isView = true)), + "TABLE_OR_VIEW_NOT_FOUND" + ) + assert( + ex.directFailures(fullyQualifiedIdentifier("e", isView = true)) + .isInstanceOf[AnalysisException] + ) + assert( + ex.directFailures(fullyQualifiedIdentifier("e", isView = true)) + .getMessage + .contains("nonExistentColumn") + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("c", isView = true)) + .isInstanceOf[UnresolvedDatasetException] + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains( + s"Failed to read dataset " + + s"'${fullyQualifiedIdentifier("b", isView = true).unquotedString}'. " + + s"Dataset is defined in the pipeline but could not be resolved" + ) + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("f", isView = true)) + .isInstanceOf[UnresolvedDatasetException] + ) + assert( + ex.downstreamFailures(fullyQualifiedIdentifier("f", isView = true)) + .getMessage + .contains( + s"Failed to read dataset " + + s"'${fullyQualifiedIdentifier("e", isView = true).unquotedString}'. 
" + + s"Dataset is defined in the pipeline but could not be resolved" + ) + ) + } + + test("correctly identify direct and downstream errors for multi-flow pipelines") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a") + registerFlow("a", "a", dfFlowFunc(spark.range(5).toDF())) + registerFlow("a", "a_2", sqlFlowFunc(spark, "SELECT non_existent_col FROM RANGE(5)")) + registerTable("b", query = Option(readFlowFunc("a"))) + } + val ex = intercept[UnresolvedPipelineException] { new P().resolveToDataflowGraph().validate() } + assert(ex.directFailures.keySet == Set(fullyQualifiedIdentifier("a_2"))) + assert(ex.downstreamFailures.keySet == Set(fullyQualifiedIdentifier("b"))) + + } + + test("Missing attribute in the schema") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("z"))) + registerView("b", query = sqlFlowFunc(spark, "SELECT x FROM a")) + } + + val dfg = new P().resolveToDataflowGraph() + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + }.directFailures(fullyQualifiedIdentifier("b", isView = true)).getMessage + verifyUnresolveColumnError(ex, "x", Seq("z")) + } + + test("Joining on a column with different names") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq("a", "b", "c").toDF("y"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a JOIN b USING (x)")) + } + + val dfg = new P().resolveToDataflowGraph() + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert( + ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains("USING column `x` cannot be resolved on the right side") + ) + } + + test("Writing to one table by unioning flows with different schemas") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq(true, false).toDF("x"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT x FROM a UNION SELECT x FROM b")) + } + + val dfg = new P().resolveToDataflowGraph() + assert(!dfg.resolved) + val ex = intercept[UnresolvedPipelineException] { + dfg.validate() + } + assert( + ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains("compatible column types") || + ex.directFailures(fullyQualifiedIdentifier("c", isView = true)) + .getMessage + .contains("Failed to merge incompatible data types") + ) + } + + test("Self reference") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = readFlowFunc("a")) + } + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset == fullyQualifiedIdentifier("a", isView = true)) + assert(e.downstreamTable == fullyQualifiedIdentifier("a", isView = true)) + } + + test("Cyclic graph - simple") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = readFlowFunc("b")) + registerView("b", query = readFlowFunc("a")) + } + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + val cycle = Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ) + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph") 
{ + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d")) + registerView("c", query = readFlowFunc("b")) + registerView("d", query = readFlowFunc("c")) + } + val cycle = + Set( + fullyQualifiedIdentifier("b", isView = true), + fullyQualifiedIdentifier("c", isView = true), + fullyQualifiedIdentifier("d", isView = true) + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph with materialized nodes") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x")))) + registerTable( + "b", + query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d")) + ) + registerTable("c", query = Option(readFlowFunc("b"))) + registerTable("d", query = Option(readFlowFunc("c"))) + } + val cycle = + Set( + fullyQualifiedIdentifier("b"), + fullyQualifiedIdentifier("c"), + fullyQualifiedIdentifier("d") + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph - second query makes it cyclic") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x")))) + registerTable("b") + registerFlow("b", "b", readFlowFunc("a")) + registerFlow("b", "b2", readFlowFunc("d")) + registerTable("c", query = Option(readFlowFunc("b"))) + registerTable("d", query = Option(readFlowFunc("c"))) + } + val cycle = + Set( + fullyQualifiedIdentifier("b"), + fullyQualifiedIdentifier("c"), + fullyQualifiedIdentifier("d") + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("Cyclic graph - all named queries") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x")))) + registerTable("b") + registerFlow("b", "`b-name`", sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM d")) + registerTable("c") + registerFlow("c", "`c-name`", readFlowFunc("b")) + registerTable("d") + registerFlow("d", "`d-name`", readFlowFunc("c")) + } + val cycle = + Set( + fullyQualifiedIdentifier("b"), + fullyQualifiedIdentifier("c"), + fullyQualifiedIdentifier("d") + ) + val e = intercept[CircularDependencyException] { + new P().resolveToDataflowGraph().validate() + } + assert(e.upstreamDataset != e.downstreamTable) + assert(cycle.contains(e.upstreamDataset)) + assert(cycle.contains(e.downstreamTable)) + } + + test("view-table conf conflict") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "a-val")) + registerTable("b", query = Option(readFlowFunc("a")), sqlConf = Map("x" -> "b-val")) + } + val ex = intercept[AnalysisException] { p.resolveToDataflowGraph() } + assert( + ex.getMessage.contains( + s"Found duplicate sql conf for dataset " + + s"'${fullyQualifiedIdentifier("b").unquotedString}':" + ) + ) + 
assert( + ex.getMessage.contains( + s"'x' is defined by both " + + s"'${fullyQualifiedIdentifier("a", isView = true).unquotedString}' " + + s"and '${fullyQualifiedIdentifier("b").unquotedString}'" + ) + ) + } + + test("view-view conf conflict") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "a-val")) + registerView("b", query = dfFlowFunc(Seq(1).toDF()), sqlConf = Map("x" -> "b-val")) + registerTable( + "c", + query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM b")), + sqlConf = Map("y" -> "c-val") + ) + } + val ex = intercept[AnalysisException] { p.resolveToDataflowGraph() } + assert( + ex.getMessage.contains( + s"Found duplicate sql conf for dataset " + + s"'${fullyQualifiedIdentifier("c").unquotedString}':" + ) + ) + assert( + ex.getMessage.contains( + s"'x' is defined by both " + + s"'${fullyQualifiedIdentifier("a", isView = true).unquotedString}' " + + s"and '${fullyQualifiedIdentifier("b", isView = true).unquotedString}'" + ) + ) + } + + test("reading a complete view incrementally") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF())) + registerTable("b", query = Option(readStreamFlowFunc("a"))) + } + val ex = intercept[UnresolvedPipelineException] { p.resolveToDataflowGraph().validate() } + assert( + ex.directFailures(fullyQualifiedIdentifier("b")) + .getMessage + .contains( + s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString}" + + s" is a batch view and must be referenced using SparkSession#read." + ) + ) + } + + test("reading an incremental view completely") { + val p = new TestGraphRegistrationContext(spark) { + val mem = MemoryStream[Int] + mem.addData(1) + registerView("a", query = dfFlowFunc(mem.toDF())) + registerTable("b", query = Option(readFlowFunc("a"))) + } + val ex = intercept[UnresolvedPipelineException] { p.resolveToDataflowGraph().validate() } + assert( + ex.directFailures(fullyQualifiedIdentifier("b")) + .getMessage + .contains( + s"View ${fullyQualifiedIdentifier("a", isView = true).quotedString} " + + s"is a streaming view and must be referenced using SparkSession#readStream" + ) + ) + } + + test("Inferred schema that isn't a subset of user-specified schema") { + val graph1 = new TestGraphRegistrationContext(spark) { + registerTable( + "a", + query = Option(dfFlowFunc(Seq(1, 2).toDF("incorrect-col-name"))), + specifiedSchema = Option(new StructType().add("x", IntegerType)) + ) + }.resolveToDataflowGraph() + val ex1 = intercept[AnalysisException] { graph1.validate() } + assert( + ex1.getMessage.contains( + s"'${fullyQualifiedIdentifier("a").unquotedString}' " + + s"has a user-specified schema that is incompatible" + ) + ) + assert(ex1.getMessage.contains("incorrect-col-name")) + + val graph2 = new TestGraphRegistrationContext(spark) { + registerTable("a", specifiedSchema = Option(new StructType().add("x", IntegerType))) + registerFlow("a", "a", query = dfFlowFunc(Seq(true, false).toDF("x")), once = true) + }.resolveToDataflowGraph() + val ex2 = intercept[AnalysisException] { graph2.validate() } + assert( + ex2.getMessage.contains( + s"'${fullyQualifiedIdentifier("a").unquotedString}' " + + s"has a user-specified schema that is incompatible" + ) + ) + assert(ex2.getMessage.contains("boolean") && ex2.getMessage.contains("integer")) + + val streamingTableHint = "please full refresh" + assert(!ex1.getMessage.contains(streamingTableHint)) + assert(ex2.getMessage.contains(streamingTableHint)) + } +} 
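A minimal sketch of the resolve-then-validate pattern these tests exercise, assuming the test helpers used above (`TestGraphRegistrationContext`, `registerView`, `readFlowFunc`, `resolveToDataflowGraph`) and a `spark` session provided by the test harness:

val graph = new TestGraphRegistrationContext(spark) {
  // "b" reads from "a", which is never registered, so resolution records a direct failure for "b".
  registerView("b", query = readFlowFunc("a"))
}.resolveToDataflowGraph()

// Resolution itself does not throw; validate() raises UnresolvedPipelineException, which separates
// flows that failed directly from flows that failed only because an upstream flow failed.
val ex = intercept[UnresolvedPipelineException] { graph.validate() }
ex.directFailures.keys.foreach(id => println(s"failed to resolve: ${id.unquotedString}"))
ex.downstreamFailures.keys.foreach(id => println(s"failed due to an upstream error: ${id.unquotedString}"))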
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala new file mode 100644 index 0000000000000..f8b5133ff167d --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Test suite for resolving the flows in a [[DataflowGraph]]. These + * examples are all semantically correct and logically correct and connect should not result in any + * errors. 
+ */ +class ConnectValidPipelineSuite extends PipelineTest { + + import originalSpark.implicits._ + + test("Extra simple") { + class P extends TestGraphRegistrationContext(spark) { + registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF("y"))) + } + val p = new P().resolveToDataflowGraph() + val outSchema = new StructType().add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), outSchema) + } + + test("Simple") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = sqlFlowFunc(spark, "SELECT x as y FROM a")) + } + val p = new P().resolveToDataflowGraph() + verifyFlowSchema( + p, + fullyQualifiedIdentifier("a", isView = true), + new StructType().add("x", IntegerType, false) + ) + val outSchema = new StructType().add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), outSchema) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Dependencies") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT y as z FROM b")) + registerView("b", query = sqlFlowFunc(spark, "SELECT x as y FROM a")) + } + val p = new P().resolveToDataflowGraph() + val schemaAB = new StructType().add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), schemaAB) + val schemaBC = new StructType().add("z", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaBC) + } + + test("Multi-hop schema merging") { + class P extends TestGraphRegistrationContext(spark) { + registerView( + "b", + query = sqlFlowFunc(spark, """SELECT * FROM VALUES ((1)) OUTER JOIN d ON false""") + ) + registerView("e", query = readFlowFunc("b")) + registerView("d", query = dfFlowFunc(Seq(1).toDF("y"))) + } + val p = new P().resolveToDataflowGraph() + val schemaE = new StructType().add("col1", IntegerType, false).add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("b", isView = true), schemaE) + verifyFlowSchema(p, fullyQualifiedIdentifier("e", isView = true), schemaE) + } + + test("Cross product join merges schema") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq(4, 5, 6).toDF("y"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a CROSS JOIN b")) + } + val p = new P().resolveToDataflowGraph() + val schemaC = new StructType().add("x", IntegerType, false).add("y", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaC) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Real join merges schema") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq((1, "a"), (2, "b"), (3, "c")).toDF("x", "y"))) + registerView("b", query = dfFlowFunc(Seq((2, "m"), (3, "n"), (4, "o")).toDF("x", "z"))) + registerView("c", query = sqlFlowFunc(spark, "SELECT * FROM a JOIN b USING (x)")) + } + val p = new 
P().resolveToDataflowGraph() + val schemaC = new StructType() + .add("x", IntegerType, false) + .add("y", StringType) + .add("z", StringType) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schemaC) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Union of streaming and batch Dataframes") { + class P extends TestGraphRegistrationContext(spark) { + val ints = MemoryStream[Int] + ints.addData(1, 2, 3, 4) + registerView("a", query = dfFlowFunc(ints.toDF())) + registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF())) + registerView( + "c", + query = FlowAnalysis.createFlowFunctionFromLogicalPlan( + Union( + Seq( + UnresolvedRelation( + TableIdentifier("a"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ), + UnresolvedRelation(TableIdentifier("b")) + ) + ) + ) + ) + } + + val p = new P().resolveToDataflowGraph() + verifyFlowSchema( + p, + fullyQualifiedIdentifier("c", isView = true), + new StructType().add("value", IntegerType, false) + ) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Union of two streaming Dataframes") { + class P extends TestGraphRegistrationContext(spark) { + val ints1 = MemoryStream[Int] + ints1.addData(1, 2, 3, 4) + val ints2 = MemoryStream[Int] + ints2.addData(1, 2, 3, 4) + registerView("a", query = dfFlowFunc(ints1.toDF())) + registerView("b", query = dfFlowFunc(ints2.toDF())) + registerView( + "c", + query = FlowAnalysis.createFlowFunctionFromLogicalPlan( + Union( + Seq( + UnresolvedRelation( + TableIdentifier("a"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ), + UnresolvedRelation( + TableIdentifier("b"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ) + ) + ) + ) + ) + } + + val p = new P().resolveToDataflowGraph() + verifyFlowSchema( + p, + fullyQualifiedIdentifier("c", isView = true), + new StructType().add("value", IntegerType, false) + ) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("c", isView = true)).inputs == Set( + fullyQualifiedIdentifier("a", isView = true), + fullyQualifiedIdentifier("b", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("MultipleInputs") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerView("b", query = dfFlowFunc(Seq(4, 5, 6).toDF("y"))) + registerView( + "c", + query = sqlFlowFunc(spark, "SELECT x AS z FROM a UNION SELECT y AS z FROM b") + ) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("z", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("c", isView = true), schema) + } + + test("Connect retains and fuses confs") { + // a -> b \ + // d + // c / + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF("x")), Map("a" -> "a-val")) + registerView("b", query = readFlowFunc("a"), Map("b" -> "b-val")) + registerView("c", query = dfFlowFunc(Seq(2).toDF("x")), Map("c" -> "c-val")) + registerTable( + "d", + query = Option(sqlFlowFunc(spark, "SELECT * FROM b UNION SELECT * FROM c")), + Map("d" -> "d-val") + ) + } + val graph = 
p.resolveToDataflowGraph() + assert( + graph + .flow(fullyQualifiedIdentifier("d", isView = false)) + .sqlConf == Map("a" -> "a-val", "b" -> "b-val", "c" -> "c-val", "d" -> "d-val") + ) + } + + test("Confs aren't fused past materialization points") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1).toDF("x")), Map("a" -> "a-val")) + registerTable("b", query = Option(readFlowFunc("a")), Map("b" -> "b-val")) + registerView("c", query = dfFlowFunc(Seq(2).toDF("x")), sqlConf = Map("c" -> "c-val")) + registerTable( + "d", + query = Option(sqlFlowFunc(spark, "SELECT * FROM b UNION SELECT * FROM c")), + Map("d" -> "d-val") + ) + } + val graph = p.resolveToDataflowGraph() + assert(graph.flow(fullyQualifiedIdentifier("a", isView = true)).sqlConf == Map("a" -> "a-val")) + assert( + graph + .flow(fullyQualifiedIdentifier("b")) + .sqlConf == Map("a" -> "a-val", "b" -> "b-val") + ) + assert(graph.flow(fullyQualifiedIdentifier("c", isView = true)).sqlConf == Map("c" -> "c-val")) + assert( + graph + .flow(fullyQualifiedIdentifier("d")) + .sqlConf == Map("c" -> "c-val", "d" -> "d-val") + ) + } + + test("Setting the same conf with the same value is totally cool") { + val p = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")), Map("key" -> "val")) + registerView("b", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")), Map("key" -> "val")) + registerTable( + "c", + query = Option(sqlFlowFunc(spark, "SELECT * FROM a UNION SELECT * FROM b")), + Map("key" -> "val") + ) + } + val graph = p.resolveToDataflowGraph() + assert(graph.flow(fullyQualifiedIdentifier("c")).sqlConf == Map("key" -> "val")) + } + + test("Named query only") { + class P extends TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerTable("b") + registerFlow("b", "`b-query`", readFlowFunc("a")) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("x", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("a", isView = true), schema) + verifyFlowSchema(p, fullyQualifiedIdentifier("b-query"), schema) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b-query")).inputs == Set( + fullyQualifiedIdentifier("a", isView = true) + ), + "Flow did not have the expected inputs" + ) + } + + test("Default query and named query") { + class P extends TestGraphRegistrationContext(spark) { + val mem = MemoryStream[Int] + registerView("a", query = dfFlowFunc(mem.toDF())) + registerTable("b") + registerFlow("b", "b", dfFlowFunc(mem.toDF().select($"value" as "y"))) + registerFlow("b", "b2", readStreamFlowFunc("a")) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("value", IntegerType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("a", isView = true), schema) + verifyFlowSchema(p, fullyQualifiedIdentifier("b2"), schema) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b2")).inputs == Set( + fullyQualifiedIdentifier("a", isView = true) + ), + "Flow did not have the expected inputs" + ) + verifyFlowSchema( + p, + fullyQualifiedIdentifier("b"), + new StructType().add("y", IntegerType, false) + ) + assert( + p.resolvedFlow(fullyQualifiedIdentifier("b")).inputs == Set.empty, + "Flow did not have the expected inputs" + ) + } + + test("Multi-query table with 2 complete queries") { + class P extends TestGraphRegistrationContext(spark) { + registerTable("a") + registerFlow("a", "a", query = dfFlowFunc(spark.range(5).toDF())) + 
registerFlow("a", "a2", query = dfFlowFunc(spark.range(6).toDF())) + } + val p = new P().resolveToDataflowGraph() + val schema = new StructType().add("id", LongType, false) + verifyFlowSchema(p, fullyQualifiedIdentifier("a"), schema) + } + + test("Correct types of flows after connection") { + val graph = new TestGraphRegistrationContext(spark) { + val mem = MemoryStream[Int] + mem.addData(1, 2) + registerView("complete-view", query = dfFlowFunc(Seq(1, 2).toDF("x"))) + registerView("incremental-view", query = dfFlowFunc(mem.toDF())) + registerTable("`complete-table`", query = Option(readFlowFunc("complete-view"))) + registerTable("`incremental-table`") + registerFlow( + "`incremental-table`", + "`incremental-table`", + FlowAnalysis.createFlowFunctionFromLogicalPlan( + UnresolvedRelation( + TableIdentifier("incremental-view"), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ) + ) + ) + registerFlow( + "`incremental-table`", + "`append-once`", + dfFlowFunc(Seq(1, 2).toDF("x")), + once = true + ) + }.resolveToDataflowGraph() + + assert( + graph + .flow(fullyQualifiedIdentifier("complete-view", isView = true)) + .isInstanceOf[CompleteFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("incremental-view", isView = true)) + .isInstanceOf[StreamingFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("complete-table")) + .isInstanceOf[CompleteFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("incremental-table")) + .isInstanceOf[StreamingFlow] + ) + assert( + graph + .flow(fullyQualifiedIdentifier("append-once")) + .isInstanceOf[AppendOnceFlow] + ) + } + + test("Pipeline level default spark confs are applied with correct precedence") { + val P = new TestGraphRegistrationContext( + spark, + Map("default.conf" -> "value") + ) { + registerTable( + "a", + query = Option(dfFlowFunc(Seq(1, 2, 3).toDF("x"))), + sqlConf = Map("other.conf" -> "value") + ) + registerTable( + "b", + query = Option(sqlFlowFunc(spark, "SELECT x as y FROM a")), + sqlConf = Map("default.conf" -> "other-value") + ) + } + val p = P.resolveToDataflowGraph() + + assert( + p.flow(fullyQualifiedIdentifier("a")).sqlConf == Map( + "default.conf" -> "value", + "other.conf" -> "value" + ) + ) + + assert( + p.flow(fullyQualifiedIdentifier("b")).sqlConf == Map( + "default.conf" -> "other-value" + ) + ) + } + + /** Verifies the [[DataflowGraph]] has the specified [[Flow]] with the specified schema. */ + private def verifyFlowSchema( + pipeline: DataflowGraph, + identifier: TableIdentifier, + expected: StructType): Unit = { + assert( + pipeline.flow.contains(identifier), + s"Flow ${identifier.unquotedString} not found," + + s" all flow names: ${pipeline.flow.keys.map(_.unquotedString)}" + ) + assert( + pipeline.resolvedFlow.contains(identifier), + s"Flow ${identifier.unquotedString} has not been resolved" + ) + assert( + pipeline.resolvedFlow(identifier).schema == expected, + s"Flow ${identifier.unquotedString} has the wrong schema" + ) + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala new file mode 100644 index 0000000000000..41d5bbe14a6b1 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/util/SchemaInferenceUtilsSuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.types._ + +class SchemaInferenceUtilsSuite extends SparkFunSuite { + + test("determineColumnChanges - adding new columns") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("age", IntegerType) + .add("email", StringType, nullable = true, "Email address") + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 2 changes - adding 'age' and 'email' columns + assert(changes.length === 2) + + // Verify the changes are of the correct type and have the right properties + val ageChange = changes + .find { + case addCol: TableChange.AddColumn => addCol.fieldNames().sameElements(Array("age")) + case _ => false + } + .get + .asInstanceOf[TableChange.AddColumn] + + val emailChange = changes + .find { + case addCol: TableChange.AddColumn => addCol.fieldNames().sameElements(Array("email")) + case _ => false + } + .get + .asInstanceOf[TableChange.AddColumn] + + // Verify age column properties + assert(ageChange.dataType() === IntegerType) + assert(ageChange.isNullable() === true) // Default nullable is true + assert(ageChange.comment() === null) + + // Verify email column properties + assert(emailChange.dataType() === StringType) + assert(emailChange.isNullable() === true) + assert(emailChange.comment() === "Email address") + } + + test("determineColumnChanges - updating column types") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("amount", DoubleType) + .add("timestamp", TimestampType) + + val targetSchema = new StructType() + .add("id", LongType, nullable = false) // Changed type from Int to Long + .add("amount", DecimalType(10, 2)) // Changed type from Double to Decimal + .add("timestamp", TimestampType) // No change + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 2 changes - updating 'id' and 'amount' column types + assert(changes.length === 2) + + // Verify the changes are of the correct type + val idChange = changes + .find { + case update: TableChange.UpdateColumnType => update.fieldNames().sameElements(Array("id")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnType] + + val amountChange = changes + .find { + case update: TableChange.UpdateColumnType => + update.fieldNames().sameElements(Array("amount")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnType] + + // Verify the new data types + assert(idChange.newDataType() === LongType) + assert(amountChange.newDataType() === 
DecimalType(10, 2)) + } + + test("determineColumnChanges - updating nullability and comments") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType, nullable = true) + .add("description", StringType, nullable = true, "Item description") + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = true) // Changed nullability + .add("name", StringType, nullable = false) // Changed nullability + .add("description", StringType, nullable = true, "Product description") // Changed comment + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 3 changes - updating nullability for 'id' and 'name', and comment for + // 'description' + assert(changes.length === 3) + + // Verify the nullability changes + val idNullabilityChange = changes + .find { + case update: TableChange.UpdateColumnNullability => + update.fieldNames().sameElements(Array("id")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnNullability] + + val nameNullabilityChange = changes + .find { + case update: TableChange.UpdateColumnNullability => + update.fieldNames().sameElements(Array("name")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnNullability] + + // Verify the comment change + val descriptionCommentChange = changes + .find { + case update: TableChange.UpdateColumnComment => + update.fieldNames().sameElements(Array("description")) + case _ => false + } + .get + .asInstanceOf[TableChange.UpdateColumnComment] + + // Verify the new nullability values + assert(idNullabilityChange.nullable() === true) + assert(nameNullabilityChange.nullable() === false) + + // Verify the new comment + assert(descriptionCommentChange.newComment() === "Product description") + } + + test("determineColumnChanges - complex changes") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("old_field", BooleanType) + + val targetSchema = new StructType() + .add("id", LongType, nullable = true) // Changed type and nullability + // Added comment and changed nullability + .add("name", StringType, nullable = false, "Full name") + .add("new_field", StringType) // New field + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have these changes: + // 1. Update id type + // 2. Update id nullability + // 3. Update name nullability + // 4. Update name comment + // 5. Add new_field + // 6. 
Remove old_field + assert(changes.length === 6) + + // Count the types of changes + val typeChanges = changes.collect { case _: TableChange.UpdateColumnType => 1 }.size + val nullabilityChanges = changes.collect { + case _: TableChange.UpdateColumnNullability => 1 + }.size + val commentChanges = changes.collect { case _: TableChange.UpdateColumnComment => 1 }.size + val addColumnChanges = changes.collect { case _: TableChange.AddColumn => 1 }.size + + assert(typeChanges === 1) + assert(nullabilityChanges === 2) + assert(commentChanges === 1) + assert(addColumnChanges === 1) + } + + test("determineColumnChanges - no changes") { + val schema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("timestamp", TimestampType) + + // Same schema, no changes expected + val changes = SchemaInferenceUtils.diffSchemas(schema, schema) + assert(changes.isEmpty) + } + + test("determineColumnChanges - deleting columns") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + .add("age", IntegerType) + .add("email", StringType) + .add("phone", StringType) + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType) + // age, email, and phone columns are removed + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 3 changes - deleting 'age', 'email', and 'phone' columns + assert(changes.length === 3) + + // Verify all changes are DeleteColumn operations + val deleteChanges = changes.collect { case dc: TableChange.DeleteColumn => dc } + assert(deleteChanges.length === 3) + + // Verify the specific columns being deleted + val columnNames = deleteChanges.map(_.fieldNames()(0)).toSet + assert(columnNames === Set("age", "email", "phone")) + } + + test("determineColumnChanges - mixed additions and deletions") { + val currentSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("first_name", StringType) + .add("last_name", StringType) + .add("age", IntegerType) + + val targetSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("full_name", StringType) // New column + .add("email", StringType) // New column + .add("age", IntegerType) // Unchanged + // first_name and last_name are removed + + val changes = SchemaInferenceUtils.diffSchemas(currentSchema, targetSchema) + + // Should have 4 changes: + // - 2 additions (full_name, email) + // - 2 deletions (first_name, last_name) + assert(changes.length === 4) + + // Count the types of changes + val addChanges = changes.collect { case ac: TableChange.AddColumn => ac } + val deleteChanges = changes.collect { case dc: TableChange.DeleteColumn => dc } + + assert(addChanges.length === 2) + assert(deleteChanges.length === 2) + + // Verify the specific columns being added and deleted + val addedColumnNames = addChanges.map(_.fieldNames()(0)).toSet + val deletedColumnNames = deleteChanges.map(_.fieldNames()(0)).toSet + + assert(addedColumnNames === Set("full_name", "email")) + assert(deletedColumnNames === Set("first_name", "last_name")) + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala new file mode 100644 index 0000000000000..981fa3cdcae85 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.utils + +import java.io.{BufferedReader, FileNotFoundException, InputStreamReader} +import java.nio.file.Files + +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Try} +import scala.util.control.NonFatal + +import org.scalactic.source +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, QueryTest, Row, TypedColumn} +import org.apache.spark.sql.SparkSession.{clearActiveSession, setActiveSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession, SQLContext} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, createTempDir} + +abstract class PipelineTest + extends SparkFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach + with Matchers + with SparkErrorTestMixin + with TargetCatalogAndDatabaseMixin + with Logging { + + final protected val storageRoot = createTempDir() + + var spark: SparkSession = createAndInitializeSpark() + val originalSpark: SparkSession = spark.cloneSession() + + implicit def sqlContext: SQLContext = spark.sqlContext + def sql(text: String): DataFrame = spark.sql(text) + + /** + * Spark confs for [[originalSpark]]. Spark confs set here will be the default spark confs for + * all spark sessions created in tests. + */ + protected def sparkConf: SparkConf = { + new SparkConf() + .set("spark.sql.shuffle.partitions", "2") + .set("spark.sql.session.timeZone", "UTC") + } + + /** Returns the dataset name in the event log. */ + protected def eventLogName( + name: String, + catalog: Option[String] = catalogInPipelineSpec, + database: Option[String] = databaseInPipelineSpec, + isView: Boolean = false + ): String = { + fullyQualifiedIdentifier(name, catalog, database, isView).unquotedString + } + + /** Returns the fully qualified identifier. */ + protected def fullyQualifiedIdentifier( + name: String, + catalog: Option[String] = catalogInPipelineSpec, + database: Option[String] = databaseInPipelineSpec, + isView: Boolean = false + ): TableIdentifier = { + if (isView) { + TableIdentifier(name) + } else { + TableIdentifier( + catalog = catalog, + database = database, + table = name + ) + } + } + + /** + * This exists temporarily for compatibility with tests that become invalid when multiple + * executors are available. + */ + protected def master = "local[*]" + + /** Creates and returns an initialized Spark session.
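+   * The builder below picks up [[sparkConf]], so a subclass can adjust session defaults by
+   * overriding that method. A minimal sketch (the conf key shown is only an example):
+   * {{{
+   *   override protected def sparkConf: SparkConf =
+   *     super.sparkConf.set("spark.sql.ansi.enabled", "true")
+   * }}}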
*/ + def createAndInitializeSpark(): SparkSession = { + val newSparkSession = SparkSession + .builder() + .config(sparkConf) + .master(master) + .getOrCreate() + newSparkSession + } + + /** Sets up the Spark session before each test. */ + protected def initializeSparkBeforeEachTest(): Unit = { + clearActiveSession() + spark = originalSpark.newSession() + setActiveSession(spark) + } + + override def beforeEach(): Unit = { + super.beforeEach() + initializeSparkBeforeEachTest() + cleanupMetastore(spark) + (catalogInPipelineSpec, databaseInPipelineSpec) match { + case (Some(catalog), Some(schema)) => + sql(s"CREATE DATABASE IF NOT EXISTS `$catalog`.`$schema`") + case _ => + databaseInPipelineSpec.foreach(s => sql(s"CREATE DATABASE IF NOT EXISTS `$s`")) + } + } + + override def afterEach(): Unit = { + cleanupMetastore(spark) + super.afterEach() + } + + override def afterAll(): Unit = { + spark.stop() + } + + protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])( + testFun: A => Unit): Unit = { + namedGridTest(testNamePrefix, testTags: _*)(params.map(a => a.toString -> a).toMap)(testFun) + } + + override def test(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */ )( + implicit pos: source.Position): Unit = super.test(testName, testTags: _*) { + runWithInstrumentation(testFun) + } + + /** + * Adds custom instrumentation for tests. + * + * This instrumentation runs after `beforeEach` and + * before `afterEach` which lets us instrument the state of a test and its environment + * after any setup and before any clean-up done for a test. + */ + private def runWithInstrumentation(testFunc: => Any): Any = { + testFunc + } + + /** + * Creates individual tests for all items in [[params]]. + * + * The full test name will be "<testNamePrefix> (<paramName> = <param>)" where <param> is one + * item in [[params]]. + * + * @param testNamePrefix The test name prefix. + * @param paramName A descriptive name for the parameter. + * @param testTags Extra tags for the test. + * @param params The list of parameters for which to generate tests. + * @param testFun The actual test function. This function will be called with one argument of + * type [[A]]. + * @tparam A The type of the params. + */ + protected def gridTest[A](testNamePrefix: String, paramName: String, testTags: Tag*)( + params: Seq[A])(testFun: A => Unit): Unit = + namedGridTest(testNamePrefix, testTags: _*)( + params.map(a => s"$paramName = $a" -> a).toMap + )(testFun) + + /** + * Specialized version of gridTest where the params are two boolean values - [[true]] and + * [[false]]. + */ + protected def booleanGridTest(testNamePrefix: String, paramName: String, testTags: Tag*)( + testFun: Boolean => Unit): Unit = { + gridTest(testNamePrefix, paramName, testTags: _*)(Seq(true, false))(testFun) + } + + protected def namedGridTest[A](testNamePrefix: String, testTags: Tag*)(params: Map[String, A])( + testFun: A => Unit): Unit = { + for (param <- params) { + test(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2)) + } + } + + protected def namedGridIgnore[A](testNamePrefix: String, testTags: Tag*)(params: Map[String, A])( + testFun: A => Unit): Unit = { + for (param <- params) { + ignore(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2)) + } + } + + /** Loads a package resource as a Seq of lines.
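+   * For example, a suite could read a fixture file shipped on the test classpath (the path
+   * below is illustrative):
+   * {{{
+   *   val expectedLines: Seq[String] = loadResource("test-data/expected-output.txt")
+   * }}}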
*/ + protected def loadResource(path: String): Seq[String] = { + val stream = Thread.currentThread.getContextClassLoader.getResourceAsStream(path) + if (stream == null) { + throw new FileNotFoundException(path) + } + val reader = new BufferedReader(new InputStreamReader(stream)) + val data = new ArrayBuffer[String] + var line = reader.readLine() + while (line != null) { + data.append(line) + line = reader.readLine() + } + data.toSeq + } + + private def checkAnswerAndPlan( + df: => DataFrame, + expectedAnswer: Seq[Row], + checkPlan: Option[SparkPlan => Unit]): Unit = { + QueryTest.checkAnswer(df, expectedAnswer) + + // To help with test development, you can dump the plan to the log by passing + // `--test_env=DUMP_PLAN=true` to `bazel test`. + if (Option(System.getenv("DUMP_PLAN")).exists(s => java.lang.Boolean.valueOf(s))) { + log.info(s"Spark plan:\n${df.queryExecution.executedPlan}") + } + checkPlan.foreach(_.apply(df.queryExecution.executedPlan)) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + checkAnswerAndPlan(df, expectedAnswer, None) + } + + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + case class ValidationArgs( + ignoreFieldOrder: Boolean = false, + ignoreFieldCase: Boolean = false + ) + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer. + */ + protected def checkDataset[T](ds: => Dataset[T], expectedAnswer: T*): Unit = { + val result = getResult(ds) + + if (!QueryTest.compare(result.toSeq, expectedAnswer)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer, after sort. + */ + protected def checkDatasetUnorderly[T: Ordering](result: Array[T], expectedAnswer: T*): Unit = { + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + protected def checkDatasetUnorderly[T: Ordering](ds: => Dataset[T], expectedAnswer: T*): Unit = { + val result = getResult(ds) + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + private def getResult[T](ds: => Dataset[T]): Array[T] = { + ds + + try ds.collect() + catch { + case NonFatal(e) => + fail( + s""" + |Exception collecting dataset as objects + |${ds.queryExecution} + """.stripMargin, + e + ) + } + } + + /** Holds a parsed version along with the original json of a test. */ + private case class TestSequence(json: Seq[String], rows: Seq[Row]) { + require(json.size == rows.size) + } + + /** + * Helper method to verify unresolved column error message. We expect three elements to be present + * in the message: error class, unresolved column name, list of suggested columns. 
There are three + * significant differences between different versions of DBR: + * - Error class changed in DBR 11.3 from `MISSING_COLUMN` to `UNRESOLVED_COLUMN.WITH_SUGGESTION` + * - Name parts in suggested columns are escaped with backticks starting from DBR 11.3, + * e.g. table.column => `table`.`column` + * - Starting from DBR 13.1 suggested columns qualification matches unresolved column, i.e. if + * unresolved column is a single-part identifier then suggested column will be as well. E.g. + * for unresolved column `x` suggested columns will omit catalog/schema or `LIVE` qualifier. For + * this reason we verify only last part of suggested column name. + */ + protected def verifyUnresolveColumnError( + errorMessage: String, + unresolved: String, + suggested: Seq[String]): Unit = { + assert(errorMessage.contains(unresolved)) + assert( + errorMessage.contains("[UNRESOLVED_COLUMN.WITH_SUGGESTION]") || + errorMessage.contains("[MISSING_COLUMN]") + ) + suggested.foreach { x => + if (errorMessage.contains("[UNRESOLVED_COLUMN.WITH_SUGGESTION]")) { + assert(errorMessage.contains(s"`$x`")) + } else { + assert(errorMessage.contains(x)) + } + } + } + + /** Evaluates the given column and returns the result. */ + def eval(column: Column): Any = { + spark.range(1).select(column).collect().head.get(0) + } + + /** Evaluates a column as part of a query and returns the result */ + def eval[T](col: TypedColumn[Any, T]): T = { + spark.range(1).select(col).head() + } +} + +/** + * A trait that provides a way to specify the target catalog and schema for a test. + */ +trait TargetCatalogAndDatabaseMixin { + + protected def catalogInPipelineSpec: Option[String] = Option( + TestGraphRegistrationContext.DEFAULT_CATALOG + ) + + protected def databaseInPipelineSpec: Option[String] = Option( + TestGraphRegistrationContext.DEFAULT_DATABASE + ) +} + +object PipelineTest extends Logging { + + /** System schemas per-catalog that's can't be directly deleted. */ + protected val systemDatabases: Set[String] = Set("default", "information_schema") + + /** System catalogs that are read-only and cannot be modified/dropped. */ + private val systemCatalogs: Set[String] = Set("samples") + + /** Catalogs that cannot be dropped but schemas or tables under it can be cleaned up. */ + private val undroppableCatalogs: Set[String] = Set( + "hive_metastore", + "spark_catalog", + "system", + "main" + ) + + /** Creates a temporary directory. */ + protected def createTempDir(): String = { + Files.createTempDirectory(getClass.getSimpleName).normalize.toString + } + + /** + * Try to drop the schema in the catalog and return whether it is successfully dropped. + */ + private def dropDatabaseIfPossible( + spark: SparkSession, + catalogName: String, + databaseName: String): Boolean = { + try { + spark.sql(s"DROP DATABASE IF EXISTS `$catalogName`.`$databaseName` CASCADE") + true + } catch { + case NonFatal(e) => + logInfo( + s"Failed to drop database $databaseName in catalog $catalogName, ex:${e.getMessage}" + ) + false + } + } + + /** Cleanup resources created in the metastore by tests. 
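+   * Suites built on [[PipelineTest]] call this from their per-test hooks, for example:
+   * {{{
+   *   override def afterEach(): Unit = {
+   *     cleanupMetastore(spark)
+   *     super.afterEach()
+   *   }
+   * }}}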
*/ + def cleanupMetastore(spark: SparkSession): Unit = synchronized { + // Some tests stop the Spark session and manage the cleanup themselves, so there is no need + // to clean up if no active Spark session is found. + if (spark.sparkContext.isStopped) { + return + } + val catalogs = + spark.sql(s"SHOW CATALOGS").collect().map(_.getString(0)).filterNot(systemCatalogs.contains) + catalogs.foreach { catalog => + if (undroppableCatalogs.contains(catalog)) { + val schemas = + spark.sql(s"SHOW SCHEMAS IN `$catalog`").collect().map(_.getString(0)) + schemas.foreach { schema => + if (systemDatabases.contains(schema) || !dropDatabaseIfPossible(spark, catalog, schema)) { + spark + .sql(s"SHOW tables in `$catalog`.`$schema`") + .collect() + .map(_.getString(0)) + .foreach { table => + Try(spark.sql(s"DROP table IF EXISTS `$catalog`.`$schema`.`$table`")) match { + case Failure(e) => + logInfo( + s"Failed to drop table $table in schema $schema in catalog $catalog, " + + s"ex:${e.getMessage}" + ) + case _ => + } + } + } + } + } else { + Try(spark.sql(s"DROP CATALOG IF EXISTS `$catalog` CASCADE")) match { + case Failure(e) => + logInfo(s"Failed to drop catalog $catalog, ex:${e.getMessage}") + case _ => + } + } + } + spark.sessionState.catalog.reset() + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala new file mode 100644 index 0000000000000..d891d6529de96 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/SparkErrorTestMixin.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.utils + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.AnalysisException + +/** + * Collection of helper methods to simplify working with exceptions in tests. + */ +trait SparkErrorTestMixin { + this: SparkFunSuite => + + /** + * Asserts that the given exception is an [[AnalysisException]] with the specific error class. + */ + def assertAnalysisException(ex: Throwable, errorClass: String): Unit = { + ex match { + case t: AnalysisException => + assert( + t.getCondition == errorClass, + s"Expected analysis exception with error class $errorClass, but got ${t.getCondition}" + ) + case _ => fail(s"Expected analysis exception but got ${ex.getClass}") + } + } + + /** + * Asserts that the given exception is an [[AnalysisException]] with the specific error class + * and metadata.
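+   * A typical call site intercepts the exception first; the error condition and message
+   * parameters below are illustrative only:
+   * {{{
+   *   val ex = intercept[AnalysisException] {
+   *     spark.sql("SELECT * FROM nonexistent_table").collect()
+   *   }
+   *   assertAnalysisException(
+   *     ex, "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`nonexistent_table`"))
+   * }}}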
+ */ + def assertAnalysisException( + ex: Throwable, + errorClass: String, + metadata: Map[String, String] + ): Unit = { + ex match { + case t: AnalysisException => + assert( + t.getCondition == errorClass, + s"Expected analysis exception with error class $errorClass, but got ${t.getCondition}" + ) + assert( + t.getMessageParameters.asScala.toMap == metadata, + s"Expected analysis exception with metadata $metadata, but got " + + s"${t.getMessageParameters}" + ) + case _ => fail(s"Expected analysis exception but got ${ex.getClass}") + } + } + + /** + * Asserts that the given exception is a [[SparkException]] with the specific error class. + */ + def assertSparkException(ex: Throwable, errorClass: String): Unit = { + ex match { + case t: SparkException => + assert( + t.getCondition == errorClass, + s"Expected spark exception with error class $errorClass, but got ${t.getCondition}" + ) + case _ => fail(s"Expected spark exception but got ${ex.getClass}") + } + } + + /** + * Asserts that the given exception is a [[SparkException]] with the specific error class + * and metadata. + */ + def assertSparkException( + ex: Throwable, + errorClass: String, + metadata: Map[String, String] + ): Unit = { + ex match { + case t: SparkException => + assert( + t.getCondition == errorClass, + s"Expected spark exception with error class $errorClass, but got ${t.getCondition}" + ) + assert( + t.getMessageParameters.asScala.toMap == metadata, + s"Expected spark exception with metadata $metadata, but got ${t.getMessageParameters}" + ) + case _ => fail(s"Expected spark exception but got ${ex.getClass}") + } + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala new file mode 100644 index 0000000000000..3449c5155c754 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.utils + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{LocalTempView, UnresolvedRelation, ViewType} +import org.apache.spark.sql.classic.{DataFrame, SparkSession} +import org.apache.spark.sql.pipelines.graph.{ + DataflowGraph, + FlowAnalysis, + FlowFunction, + GraphIdentifierManager, + GraphRegistrationContext, + PersistedView, + QueryContext, + QueryOrigin, + Table, + TemporaryView, + UnresolvedFlow +} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A test class to simplify the creation of pipelines and datasets for unit testing. 
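+ * A typical test registers a few datasets in an anonymous subclass and then resolves the graph,
+ * mirroring the suites above (dataset names are illustrative):
+ * {{{
+ *   val graph = new TestGraphRegistrationContext(spark) {
+ *     registerView("src", query = dfFlowFunc(spark.range(3).toDF()))
+ *     registerTable("dst", query = Option(readFlowFunc("src")))
+ *   }.resolveToDataflowGraph()
+ * }}}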
+ */ +class TestGraphRegistrationContext( + val spark: SparkSession, + val sqlConf: Map[String, String] = Map.empty) + extends GraphRegistrationContext( + defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG, + defaultDatabase = TestGraphRegistrationContext.DEFAULT_DATABASE, + defaultSqlConf = sqlConf + ) { + + // scalastyle:off + // Disable scalastyle to ignore argument count. + def registerTable( + name: String, + query: Option[FlowFunction] = None, + sqlConf: Map[String, String] = Map.empty, + comment: Option[String] = None, + specifiedSchema: Option[StructType] = None, + partitionCols: Option[Seq[String]] = None, + properties: Map[String, String] = Map.empty, + baseOrigin: QueryOrigin = QueryOrigin.empty, + format: Option[String] = None, + catalog: Option[String] = None, + database: Option[String] = None + ): Unit = { + // scalastyle:on + val tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark) + registerTable( + Table( + identifier = GraphIdentifierManager.parseTableIdentifier(name, spark), + comment = comment, + specifiedSchema = specifiedSchema, + partitionCols = partitionCols, + properties = properties, + baseOrigin = baseOrigin, + format = format.orElse(Some("parquet")), + normalizedPath = None, + isStreamingTableOpt = None + ) + ) + + if (query.isDefined) { + registerFlow( + new UnresolvedFlow( + identifier = tableIdentifier, + destinationIdentifier = tableIdentifier, + func = query.get, + queryContext = QueryContext( + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)) + ), + sqlConf = sqlConf, + once = false, + comment = comment, + origin = baseOrigin + ) + ) + } + } + + def registerView( + name: String, + query: FlowFunction, + sqlConf: Map[String, String] = Map.empty, + comment: Option[String] = None, + origin: QueryOrigin = QueryOrigin.empty, + viewType: ViewType = LocalTempView, + catalog: Option[String] = None, + database: Option[String] = None + ): Unit = { + + val viewIdentifier = GraphIdentifierManager + .parseAndValidateTemporaryViewIdentifier(rawViewIdentifier = TableIdentifier(name)) + + registerView( + viewType match { + case LocalTempView => + TemporaryView( + identifier = viewIdentifier, + comment = comment, + origin = origin, + properties = Map.empty + ) + case _ => + PersistedView( + identifier = viewIdentifier, + comment = comment, + origin = origin, + properties = Map.empty + ) + } + ) + + registerFlow( + new UnresolvedFlow( + identifier = viewIdentifier, + destinationIdentifier = viewIdentifier, + func = query, + queryContext = QueryContext( + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)) + ), + sqlConf = sqlConf, + once = false, + comment = comment, + origin = origin + ) + ) + } + + def registerFlow( + destinationName: String, + name: String, + query: FlowFunction, + once: Boolean = false, + catalog: Option[String] = None, + database: Option[String] = None + ): Unit = { + val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark) + val flowDestinationIdentifier = + GraphIdentifierManager.parseTableIdentifier(destinationName, spark) + + registerFlow( + new UnresolvedFlow( + identifier = flowIdentifier, + destinationIdentifier = flowDestinationIdentifier, + func = query, + queryContext = QueryContext( + currentCatalog = catalog.orElse(Some(defaultCatalog)), + currentDatabase = database.orElse(Some(defaultDatabase)) + ), + sqlConf = Map.empty, + once = once, + comment = None, + 
origin = QueryOrigin() + ) + ) + } + + /** + * Creates a flow function from a logical plan that reads from a table with the given name. + */ + def readFlowFunc(name: String): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan(UnresolvedRelation(TableIdentifier(name))) + } + + /** + * Creates a flow function from a logical plan that reads a stream from a table with the given + * name. + */ + def readStreamFlowFunc(name: String): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan( + UnresolvedRelation( + TableIdentifier(name), + extraOptions = CaseInsensitiveStringMap.empty(), + isStreaming = true + ) + ) + } + + /** + * Creates a flow function from a logical plan parsed from the given SQL text. + */ + def sqlFlowFunc(spark: SparkSession, sql: String): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan(spark.sessionState.sqlParser.parsePlan(sql)) + } + + /** + * Creates a flow function from a logical plan from the given DataFrame. This is meant for + * DataFrames that don't read from tables within the pipeline. + */ + def dfFlowFunc(df: DataFrame): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan(df.logicalPlan) + } + + /** + * Generates a dataflow graph from this pipeline definition and resolves it. + * @return + */ + def resolveToDataflowGraph(): DataflowGraph = toDataflowGraph.resolve() +} + +object TestGraphRegistrationContext { + val DEFAULT_CATALOG = "spark_catalog" + val DEFAULT_DATABASE = "test_db" +}
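Taken together, the test utilities above are typically composed as in the following sketch. It assumes only the PipelineTest base class and TestGraphRegistrationContext introduced in this patch; the suite name, dataset names, SQL text, and the expected schema are illustrative rather than part of the change.

package org.apache.spark.sql.pipelines.utils

import org.apache.spark.sql.types.{LongType, StructType}

class ExampleGraphSuite extends PipelineTest {
  test("SQL-defined flow resolves against an upstream table") {
    // Register a DataFrame-backed source table and a SQL-defined downstream table,
    // then resolve the whole graph, as the suites above do.
    val graph = new TestGraphRegistrationContext(spark) {
      registerTable("numbers", query = Option(dfFlowFunc(spark.range(3).toDF())))
      registerTable(
        "doubled",
        query = Option(sqlFlowFunc(spark, "SELECT id * 2 AS id FROM numbers")))
    }.resolveToDataflowGraph()

    // spark.range produces a non-nullable `id` column, which the downstream flow preserves.
    val expected = new StructType().add("id", LongType, nullable = false)
    assert(graph.resolvedFlow(fullyQualifiedIdentifier("doubled")).schema == expected)
  }
}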