[SPARK-52689][SQL] Send DML Metrics to V2Write #51377

Open · wants to merge 4 commits into base: master (changes from all commits)
MergeMetrics.java (new file)
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.metric;

import java.util.OptionalLong;

/**
* Execution metrics for a MERGE operation, reported to connectors that support row-level
* operations. Each accessor returns {@code OptionalLong.empty()} when the corresponding
* metric is not available.
*/
public interface MergeMetrics {

/**
* Returns the number of target rows copied unmodified because they did not match any action.
*/
OptionalLong numTargetRowsCopied();

/**
* Returns the number of target rows inserted.
*/
OptionalLong numTargetRowsInserted();

/**
* Returns the number of target rows deleted.
*/
OptionalLong numTargetRowsDeleted();

/**
* Returns the number of target rows updated.
*/
OptionalLong numTargetRowsUpdated();

/**
* Returns the number of target rows matched and updated by a matched clause.
*/
OptionalLong numTargetRowsMatchedUpdated();

/**
* Returns the number of target rows matched and deleted by a matched clause.
*/
OptionalLong numTargetRowsMatchedDeleted();

/**
* Returns the number of target rows updated by a not matched by source clause.
*/
OptionalLong numTargetRowsNotMatchedBySourceUpdated();

/**
* Returns the number of target rows deleted by a not matched by source clause.
*/
OptionalLong numTargetRowsNotMatchedBySourceDeleted();
}
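Every accessor returns an OptionalLong, so a connector can tell a metric that was not reported (empty) apart from a genuine zero. A minimal consumption sketch in Scala, assuming only the interface above; MergeMetricsFormatter and its format method are illustrative names, not part of this PR:

import java.util.OptionalLong

import org.apache.spark.sql.connector.metric.MergeMetrics

// Illustrative helper: keep only the metrics the engine actually reported.
object MergeMetricsFormatter {
  def format(metrics: MergeMetrics): Map[String, Long] = {
    def reported(name: String, value: OptionalLong): Option[(String, Long)] =
      if (value.isPresent) Some(name -> value.getAsLong) else None

    Seq(
      reported("numTargetRowsCopied", metrics.numTargetRowsCopied()),
      reported("numTargetRowsInserted", metrics.numTargetRowsInserted()),
      reported("numTargetRowsDeleted", metrics.numTargetRowsDeleted()),
      reported("numTargetRowsUpdated", metrics.numTargetRowsUpdated()),
      reported("numTargetRowsMatchedUpdated", metrics.numTargetRowsMatchedUpdated()),
      reported("numTargetRowsMatchedDeleted", metrics.numTargetRowsMatchedDeleted()),
      reported("numTargetRowsNotMatchedBySourceUpdated",
        metrics.numTargetRowsNotMatchedBySourceUpdated()),
      reported("numTargetRowsNotMatchedBySourceDeleted",
        metrics.numTargetRowsNotMatchedBySourceDeleted())
    ).flatten.toMap
  }
}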
BatchWrite.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connector.write;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.metric.MergeMetrics;

/**
* An interface that defines how to write the data to data source for batch processing.
@@ -104,4 +105,13 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
* clean up the data left by data writers.
*/
void abort(WriterCommitMessage[] messages);

/**
* Similar to {@link #commit(WriterCommitMessage[])}, but additionally passes the metrics
* collected during the MERGE execution to this batch write. The default implementation
* ignores the metrics and delegates to {@link #commit(WriterCommitMessage[])}.
*
* @param messages a set of commit messages from successful data writers
* @param metrics merge execution metrics
*/
default void commitMerge(WriterCommitMessage[] messages, MergeMetrics metrics) {
commit(messages);
}
}
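Connectors that never receive MERGE metrics, or that do not care about them, need no changes: the default commitMerge forwards straight to commit. A sketch of a connector that does override it, assuming the API added in this PR; MetricsAwareBatchWrite and recordMergeStats are illustrative names:

import org.apache.spark.sql.connector.metric.MergeMetrics
import org.apache.spark.sql.connector.write.{BatchWrite, WriterCommitMessage}

abstract class MetricsAwareBatchWrite extends BatchWrite {

  override def commitMerge(messages: Array[WriterCommitMessage], metrics: MergeMetrics): Unit = {
    // Stash the driver-side MERGE statistics before running the normal commit path.
    recordMergeStats(metrics)
    commit(messages)
  }

  // Illustrative hook; a real connector would persist these values in its commit metadata.
  protected def recordMergeStats(metrics: MergeMetrics): Unit
}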
InMemoryBaseTable.scala
@@ -23,6 +23,7 @@ import java.util
import java.util.OptionalLong

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import com.google.common.base.Objects

@@ -144,6 +145,8 @@ abstract class InMemoryBaseTable(
// The key `Seq[Any]` is the partition values, value is a set of splits, each with a set of rows.
val dataMap: mutable.Map[Seq[Any], Seq[BufferedRows]] = mutable.Map.empty

val commits: ListBuffer[Commit] = ListBuffer[Commit]()

def data: Array[BufferedRows] = dataMap.values.flatten.toArray

def rows: Seq[InternalRow] = dataMap.values.flatten.flatMap(_.rows).toSeq
@@ -575,6 +578,9 @@
}

protected abstract class TestBatchWrite extends BatchWrite {

var commitProperties: mutable.Map[String, String] = mutable.Map.empty[String, String]

override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
BufferedRowsWriterFactory
}
@@ -583,8 +589,11 @@
}

class Append(val info: LogicalWriteInfo) extends TestBatchWrite {

override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

@@ -593,13 +602,17 @@
val newData = messages.map(_.asInstanceOf[BufferedRows])
dataMap --= newData.flatMap(_.rows.map(getKey))
withData(newData)
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap.clear()
withData(messages.map(_.asInstanceOf[BufferedRows]))
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

@@ -841,6 +854,8 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
override def value(): Long = value
}

case class Commit(id: Long, properties: Map[String, String])

sealed trait Operation
case object Write extends Operation
case object Delete extends Operation
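The new Commit record and the commitProperties buffer in TestBatchWrite follow a simple pattern: properties accumulate while a write runs, are snapshotted into an immutable Commit when it commits, and the buffer is cleared for the next write. A self-contained sketch of that pattern (it mirrors the test code above and is not part of the diff):

import java.time.Instant

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

case class Commit(id: Long, properties: Map[String, String])

val commits: ListBuffer[Commit] = ListBuffer[Commit]()
val commitProperties: mutable.Map[String, String] = mutable.Map.empty

// While the write runs, metrics land in the mutable buffer...
commitProperties += "numTargetRowsUpdated" -> "3"

// ...and at commit time they are frozen into a Commit and the buffer is reset.
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()

assert(commits.last.properties("numTargetRowsUpdated") == "3")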
InMemoryRowLevelOperationTable.scala
@@ -17,13 +17,15 @@

package org.apache.spark.sql.connector.catalog

import java.time.Instant
import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.connector.metric.MergeMetrics
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
@@ -111,7 +113,35 @@ class InMemoryRowLevelOperationTable(
override def description(): String = "InMemoryPartitionReplaceOperation"
}

private case class PartitionBasedReplaceData(scan: InMemoryBatchScan) extends TestBatchWrite {
abstract class RowLevelOperationBatchWrite extends TestBatchWrite {

override def commitMerge(
messages: Array[WriterCommitMessage],
metrics: MergeMetrics): Unit = {
// Record the merge metrics as commit properties; -1 marks a metric that was not reported.
commitProperties += "numTargetRowsCopied" -> metrics.numTargetRowsCopied().orElse(-1).toString
commitProperties += "numTargetRowsInserted" ->
metrics.numTargetRowsInserted().orElse(-1).toString
commitProperties += "numTargetRowsDeleted" ->
metrics.numTargetRowsDeleted().orElse(-1).toString
commitProperties += "numTargetRowsUpdated" ->
metrics.numTargetRowsUpdated().orElse(-1).toString
commitProperties += "numTargetRowsMatchedDeleted" ->
metrics.numTargetRowsMatchedDeleted().orElse(-1).toString
commitProperties += "numTargetRowsMatchedUpdated" ->
metrics.numTargetRowsMatchedUpdated().orElse(-1).toString
commitProperties += "numTargetRowsNotMatchedBySourceUpdated" ->
metrics.numTargetRowsNotMatchedBySourceUpdated().orElse(-1).toString
commitProperties += "numTargetRowsNotMatchedBySourceDeleted" ->
metrics.numTargetRowsNotMatchedBySourceDeleted().orElse(-1).toString
commit(messages)
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

private case class PartitionBasedReplaceData(scan: InMemoryBatchScan)
extends RowLevelOperationBatchWrite {

override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val newData = messages.map(_.asInstanceOf[BufferedRows])
@@ -165,7 +195,7 @@
}
}

private object TestDeltaBatchWrite extends DeltaBatchWrite {
private object TestDeltaBatchWrite extends RowLevelOperationBatchWrite with DeltaBatchWrite {
override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
DeltaBufferedRowsWriterFactory
}
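The commit properties recorded above use -1 as a sentinel for a metric the engine did not report, relying on OptionalLong.orElse. A small self-contained sketch of that convention:

import java.util.OptionalLong

val reported = OptionalLong.of(42L)
val missing = OptionalLong.empty()

// orElse substitutes the sentinel only when the value is absent.
assert(reported.orElse(-1L).toString == "42")
assert(missing.orElse(-1L).toString == "-1")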
WriteToDataSourceV2Exec.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.datasources.v2

import java.util.OptionalLong

import scala.jdk.CollectionConverters._

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
@@ -30,10 +32,11 @@ import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUt
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.metric.{CustomMetric, MergeMetrics}
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{LongAccumulator, Utils}
@@ -398,7 +401,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
/**
* The base physical plan for writing data into data source v2.
*/
trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper {
def query: SparkPlan
def writingTask: WritingSparkTask[_] = DataWritingSparkTask

@@ -451,8 +454,12 @@
}
)

val mergeMetricsOpt = getMergeMetrics(query)
logInfo(log"Data source write support ${MDC(LogKeys.BATCH_WRITE, batchWrite)} is committing.")
batchWrite.commit(messages)
mergeMetricsOpt match {
case Some(metrics) => batchWrite.commitMerge(messages, metrics)
case None => batchWrite.commit(messages)
}
logInfo(log"Data source write support ${MDC(LogKeys.BATCH_WRITE, batchWrite)} committed.")
commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value))
} catch {
@@ -474,6 +481,30 @@

Nil
}

private def getMergeMetrics(query: SparkPlan): Option[MergeMetrics] = {
collectFirst(query) { case m: MergeRowsExec => m }.map { n =>
MergeMetricsImpl(
numTargetRowsCopied = metric(n.metrics, "numTargetRowsCopied"),
numTargetRowsDeleted = metric(n.metrics, "numTargetRowsDeleted"),
numTargetRowsUpdated = metric(n.metrics, "numTargetRowsUpdated"),
numTargetRowsInserted = metric(n.metrics, "numTargetRowsInserted"),
numTargetRowsMatchedDeleted = metric(n.metrics, "numTargetRowsMatchedDeleted"),
numTargetRowsMatchedUpdated = metric(n.metrics, "numTargetRowsMatchedUpdated"),
numTargetRowsNotMatchedBySourceDeleted =
metric(n.metrics, "numTargetRowsNotMatchedBySourceDeleted"),
numTargetRowsNotMatchedBySourceUpdated =
metric(n.metrics, "numTargetRowsNotMatchedBySourceUpdated")
)
}
}

private def metric(metrics: Map[String, SQLMetric], name: String): OptionalLong = {
metrics.get(name) match {
case Some(m) => OptionalLong.of(m.value)
case None => OptionalLong.empty()
}
}
}

trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable {
@@ -729,3 +760,12 @@ private[v2] case class DataWritingSparkTaskResult(
*/
private[sql] case class StreamWriterCommitProgress(numOutputRows: Long)

private case class MergeMetricsImpl(
override val numTargetRowsCopied: OptionalLong,
override val numTargetRowsDeleted: OptionalLong,
override val numTargetRowsUpdated: OptionalLong,
override val numTargetRowsInserted: OptionalLong,
override val numTargetRowsMatchedUpdated: OptionalLong,
override val numTargetRowsMatchedDeleted: OptionalLong,
override val numTargetRowsNotMatchedBySourceUpdated: OptionalLong,
override val numTargetRowsNotMatchedBySourceDeleted: OptionalLong) extends MergeMetrics
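getMergeMetrics looks for a MergeRowsExec node in the executed plan (collectFirst comes from AdaptiveSparkPlanHelper, so the lookup also works under adaptive execution) and converts its SQLMetric values into OptionalLongs; a metric the plan did not produce becomes OptionalLong.empty() rather than zero. A self-contained sketch of that name-based conversion, with a plain Map standing in for the node's metrics:

import java.util.OptionalLong

// The Map stands in for MergeRowsExec.metrics; the lookup mirrors the metric() helper above.
def metricValue(metrics: Map[String, Long], name: String): OptionalLong =
  metrics.get(name).map(v => OptionalLong.of(v)).getOrElse(OptionalLong.empty())

val planMetrics = Map("numTargetRowsUpdated" -> 7L, "numTargetRowsDeleted" -> 2L)

assert(metricValue(planMetrics, "numTargetRowsUpdated") == OptionalLong.of(7L))
// A metric the plan did not produce is reported as empty, not as zero.
assert(metricValue(planMetrics, "numTargetRowsCopied") == OptionalLong.empty())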