
Commit c17c6b9

Szehon Ho (szehon-ho) authored and committed
Send DML metrics from job to V2Write
1 parent 5e6e8f1 commit c17c6b9

7 files changed, +181 −12 lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
Lines changed: 16 additions & 0 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.write;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.metric.CustomTaskMetric;
 
 /**
  * An interface that defines how to write the data to data source for batch processing.
@@ -104,4 +105,19 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
    * clean up the data left by data writers.
    */
   void abort(WriterCommitMessage[] messages);
+
+
+  /**
+   * Whether this batch write requests execution metrics. Returns the row-level operation command
+   * this batch write is part of, if requested; returns null if not requested.
+   */
+  default RowLevelOperation.Command requestExecMetrics() {
+    return null;
+  }
+
+  /**
+   * Provides an array of query execution metrics to the batch write prior to commit.
+   * @param metrics an array of execution metrics
+   */
+  default void execMetrics(CustomTaskMetric[] metrics) {}
 }
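For context, a connector could wire the two new hooks up roughly as follows. This is a minimal sketch: the class name, the metric handling, and the commit-property storage are illustrative assumptions, not part of this commit.

  import scala.collection.mutable

  import org.apache.spark.sql.connector.metric.CustomTaskMetric
  import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo,
    RowLevelOperation, WriterCommitMessage}

  // Hypothetical connector write that asks Spark for MERGE execution metrics and
  // stores them alongside the table commit.
  class MetricsAwareBatchWrite extends BatchWrite {
    private val commitProperties = mutable.Map.empty[String, String]

    // Declare that this write belongs to a MERGE, so the exec node forwards metrics.
    override def requestExecMetrics(): RowLevelOperation.Command = RowLevelOperation.Command.MERGE

    // Called on the driver prior to commit with the collected execution metrics.
    override def execMetrics(metrics: Array[CustomTaskMetric]): Unit =
      metrics.foreach(m => commitProperties += (m.name() -> m.value().toString))

    override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
      throw new UnsupportedOperationException("illustration only")

    override def commit(messages: Array[WriterCommitMessage]): Unit = {
      // Persist commitProperties together with the table commit (connector-specific).
    }

    override def abort(messages: Array[WriterCommitMessage]): Unit = ()
  }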

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
Lines changed: 0 additions & 1 deletion

@@ -86,5 +86,4 @@ default CustomMetric[] supportedCustomMetrics() {
   default CustomTaskMetric[] reportDriverMetrics() {
     return new CustomTaskMetric[]{};
   }
-
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
Lines changed: 42 additions & 0 deletions

@@ -23,6 +23,7 @@ import java.util
 import java.util.OptionalLong
 
 import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 
 import com.google.common.base.Objects
 
@@ -144,6 +145,8 @@ abstract class InMemoryBaseTable(
   // The key `Seq[Any]` is the partition values, value is a set of splits, each with a set of rows.
   val dataMap: mutable.Map[Seq[Any], Seq[BufferedRows]] = mutable.Map.empty
 
+  val commits: ListBuffer[Commit] = ListBuffer[Commit]()
+
   def data: Array[BufferedRows] = dataMap.values.flatten.toArray
 
   def rows: Seq[InternalRow] = dataMap.values.flatten.flatMap(_.rows).toSeq
@@ -498,13 +501,32 @@ abstract class InMemoryBaseTable(
       options: CaseInsensitiveStringMap)
     extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
 
+    var setFilters = Array.empty[Filter]
+
+    override def reportDriverMetrics(): Array[CustomTaskMetric] =
+      Array(new CustomTaskMetric {
+        override def name(): String = "numSplits"
+        override def value(): Long = 1L
+      })
+
+    override def supportedCustomMetrics(): Array[CustomMetric] = {
+      Array(new CustomMetric {
+        override def name(): String = "numSplits"
+        override def description(): String = "number of splits in the scan"
+        override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+          taskMetrics.sum.toString
+        }
+      })
+    }
+
     override def filterAttributes(): Array[NamedReference] = {
       val scanFields = readSchema.fields.map(_.name).toSet
       partitioning.flatMap(_.references)
         .filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
     }
 
     override def filter(filters: Array[Filter]): Unit = {
+      this.setFilters = filters
       if (partitioning.length == 1 && partitioning.head.references().length == 1) {
         val ref = partitioning.head.references().head
         filters.foreach {
@@ -575,6 +597,9 @@ abstract class InMemoryBaseTable(
   }
 
   protected abstract class TestBatchWrite extends BatchWrite {
+
+    var commitProperties: mutable.Map[String, String] = mutable.Map.empty[String, String]
+
     override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
       BufferedRowsWriterFactory
     }
@@ -583,8 +608,11 @@ abstract class InMemoryBaseTable(
   }
 
   class Append(val info: LogicalWriteInfo) extends TestBatchWrite {
+
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       withData(messages.map(_.asInstanceOf[BufferedRows]))
+      commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
+      commitProperties.clear()
     }
   }
 
@@ -593,13 +621,17 @@ abstract class InMemoryBaseTable(
       val newData = messages.map(_.asInstanceOf[BufferedRows])
       dataMap --= newData.flatMap(_.rows.map(getKey))
       withData(newData)
+      commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
+      commitProperties.clear()
     }
   }
 
   class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
      dataMap.clear()
      withData(messages.map(_.asInstanceOf[BufferedRows]))
+      commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
+      commitProperties.clear()
     }
   }
 
@@ -747,6 +779,14 @@ private class BufferedRowsReader(
 
   override def close(): Unit = {}
 
+  override def currentMetricsValues(): Array[CustomTaskMetric] =
+    Array[CustomTaskMetric](
+      new CustomTaskMetric {
+        override def name(): String = "numSplits"
+        override def value(): Long = 1
+      }
+    )
+
   private def extractFieldValue(
       field: StructField,
       schema: StructType,
@@ -841,6 +881,8 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
   override def value(): Long = value
 }
 
+case class Commit(id: Long, properties: Map[String, String])
+
 sealed trait Operation
 case object Write extends Operation
 case object Delete extends Operation
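For intuition, the new "numSplits" metric is reported once per task (currentMetricsValues) and once from the driver (reportDriverMetrics), and the matching CustomMetric above folds the per-task values into the displayed string by summing them. A small, hypothetical illustration of that aggregation contract (not part of this commit):

  import org.apache.spark.sql.connector.metric.CustomMetric

  // Hypothetical stand-in mirroring the "numSplits" metric defined above: Spark hands the
  // per-task values to aggregateTaskMetrics and displays the returned string in the SQL UI.
  val numSplits: CustomMetric = new CustomMetric {
    override def name(): String = "numSplits"
    override def description(): String = "number of splits in the scan"
    override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = taskMetrics.sum.toString
  }

  // Three tasks each reporting one split aggregate to "3".
  assert(numSplits.aggregateTaskMetrics(Array(1L, 1L, 1L)) == "3")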

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
Lines changed: 62 additions & 7 deletions

@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.connector.catalog
 
+import java.time.Instant
 import java.util
 
+import scala.collection.mutable.ListBuffer
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.connector.catalog.constraints.Constraint
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
 import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.RowLevelOperation.Command
@@ -46,6 +50,8 @@ class InMemoryRowLevelOperationTable(
     constraints)
   with SupportsRowLevelOperations {
 
+  private val _scans = ListBuffer.empty[Scan]
+
   private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
   private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
   private final val SUPPORTS_DELTAS = "supports-deltas"
@@ -68,6 +74,16 @@ class InMemoryRowLevelOperationTable(
     }
   }
 
+  class InMemoryRowLevelOperationScanBuilder(tableSchema: StructType,
+      options: CaseInsensitiveStringMap)
+    extends InMemoryScanBuilder(tableSchema, options) {
+    override def build: Scan = {
+      val scan = super.build
+      _scans += scan
+      scan
+    }
+  }
+
   case class PartitionBasedOperation(command: Command) extends RowLevelOperation {
     var configuredScan: InMemoryBatchScan = _
 
@@ -101,7 +117,7 @@ class InMemoryRowLevelOperationTable(
           SortDirection.ASCENDING.defaultNullOrdering()))
       }
 
-      override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)
+      override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan, command)
 
       override def description: String = "InMemoryWrite"
     }
@@ -111,9 +127,46 @@ class InMemoryRowLevelOperationTable(
     override def description(): String = "InMemoryPartitionReplaceOperation"
   }
 
-  private case class PartitionBasedReplaceData(scan: InMemoryBatchScan) extends TestBatchWrite {
+  abstract class RowLevelOperationBatchWrite(command: Command) extends TestBatchWrite {
+    override def requestExecMetrics(): Command = command
+
+    override def execMetrics(metrics: Array[CustomTaskMetric]): Unit = {
+      metrics.foreach(m => commitProperties += (m.name() -> m.value().toString))
+    }
+
+    override def commit(messages: Array[WriterCommitMessage]): Unit = {
+      assert(_scans.size <= 2, "Expected at most two scans in row-level operations")
+      assert(_scans.count { case s: InMemoryBatchScan => s.setFilters.nonEmpty } <= 1,
+        "Expected at most one scan with runtime filters in row-level operations")
+      assert(_scans.count { case s: InMemoryBatchScan => s.setFilters.isEmpty } <= 1,
+        "Expected at most one scan without runtime filters in row-level operations")
+
+      _scans.foreach {
+        case s: InMemoryBatchScan =>
+          val prefix = if (s.setFilters.isEmpty) {
+            ""
+          } else {
+            "secondScan."
+          }
+          s.reportDriverMetrics().foreach { metric =>
+            commitProperties += (prefix + metric.name() -> metric.value().toString)
+          }
+        case _ =>
+      }
+      _scans.clear()
+      doCommit(messages)
+      commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
+      commitProperties.clear()
+    }
+
+    def doCommit(messages: Array[WriterCommitMessage]): Unit
+  }
+
+  private case class PartitionBasedReplaceData(scan: InMemoryBatchScan,
+      command: RowLevelOperation.Command)
+    extends RowLevelOperationBatchWrite(command) {
 
-    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+    override def doCommit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       val newData = messages.map(_.asInstanceOf[BufferedRows])
       val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
       val readPartitions = readRows.map(r => getKey(r, schema)).distinct
@@ -134,7 +187,7 @@ class InMemoryRowLevelOperationTable(
     override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)
 
     override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-      new InMemoryScanBuilder(schema, options)
+      new InMemoryRowLevelOperationScanBuilder(schema, options)
     }
 
     override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
@@ -155,7 +208,7 @@ class InMemoryRowLevelOperationTable(
       )
     }
 
-      override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite
+      override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite(command)
     }
   }
 }
@@ -165,12 +218,14 @@ class InMemoryRowLevelOperationTable(
    }
  }
 
-  private object TestDeltaBatchWrite extends DeltaBatchWrite {
+  private case class TestDeltaBatchWrite(command: Command)
+    extends RowLevelOperationBatchWrite(command) with DeltaBatchWrite {
+
    override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
      DeltaBufferedRowsWriterFactory
    }
 
-    override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    override def doCommit(messages: Array[WriterCommitMessage]): Unit = {
      val newData = messages.map(_.asInstanceOf[BufferedRows])
      withDeletes(newData)
      withData(newData)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
Lines changed: 21 additions & 3 deletions

@@ -30,10 +30,11 @@ import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUt
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege}
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.connector.metric.CustomMetric
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, RowLevelOperation, Write, WriterCommitMessage}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -398,7 +399,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
 /**
  * The base physical plan for writing data into data source v2.
  */
-trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
+trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper {
   def query: SparkPlan
   def writingTask: WritingSparkTask[_] = DataWritingSparkTask
 
@@ -422,6 +423,22 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
         tempRdd
       }
     }
+
+    val metricsOpt = batchWrite.requestExecMetrics() match {
+      case RowLevelOperation.Command.MERGE =>
+        collectFirst(query) {
+          case m: MergeRowsExec => m.metrics
+        }
+      case _ => None
+    }
+    metricsOpt.foreach { metrics =>
+      batchWrite.execMetrics(
+        metrics.map {
+          case (k, v) => V2ExecMetric(k, v.value)
+        }.toArray
+      )
+    }
+
     // introduce a local var to avoid serializing the whole class
     val task = writingTask
     val writerFactory = batchWrite.createBatchWriterFactory(
@@ -729,3 +746,4 @@ private[v2] case class DataWritingSparkTaskResult(
  */
 private[sql] case class StreamWriterCommitProgress(numOutputRows: Long)
 
+private[v2] case class V2ExecMetric(name: String, value: Long) extends CustomTaskMetric
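In short, when the incoming BatchWrite reports that it is part of a MERGE, the exec node snapshots the SQLMetric values of the MergeRowsExec node found in the physical plan and hands them to the write as plain name/value carriers; mixing in AdaptiveSparkPlanHelper presumably lets collectFirst look through the AQE plan wrapper when adaptive execution is on. A minimal sketch of the conversion step, assuming a metrics map shaped like MergeRowsExec.metrics (illustrative helper, not part of the commit):

  import org.apache.spark.sql.connector.metric.CustomTaskMetric
  import org.apache.spark.sql.execution.metric.SQLMetric

  // Hypothetical helper: snapshot each named SQLMetric into an immutable name/value
  // carrier, the same role V2ExecMetric plays above.
  def toExecMetrics(metrics: Map[String, SQLMetric]): Array[CustomTaskMetric] =
    metrics.map { case (metricName, metric) =>
      new CustomTaskMetric {
        override def name(): String = metricName
        override def value(): Long = metric.value
      }
    }.toArray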

sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
Lines changed: 39 additions & 0 deletions

@@ -2045,6 +2045,45 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
+  test("V2 write metrics for merge") {
+
+    Seq("true", "false").foreach { aqeEnabled: String =>
+      withTempView("source") {
+        withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
+          createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+            """{ "pk": 1, "salary": 100, "dep": "hr" }
+              |{ "pk": 2, "salary": 200, "dep": "software" }
+              |{ "pk": 3, "salary": 300, "dep": "hr" }
+              |{ "pk": 4, "salary": 400, "dep": "marketing" }
+              |{ "pk": 5, "salary": 500, "dep": "executive" }
+              |""".stripMargin)
+
+          val sourceDF = Seq(1, 2, 6, 10).toDF("pk")
+          sourceDF.createOrReplaceTempView("source")
+
+          sql(
+            s"""MERGE INTO $tableNameAsString t
+               |USING source s
+               |ON t.pk = s.pk
+               |WHEN MATCHED AND salary < 200 THEN
+               | DELETE
+               |WHEN NOT MATCHED AND s.pk < 10 THEN
+               | INSERT (pk, salary, dep) VALUES (s.pk, -1, "dummy")
+               |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
+               | DELETE
+               |""".stripMargin
+          )
+
+          val table = catalog.loadTable(ident)
+          // scalastyle:off println
+          println(table)
+          // scalastyle:on println
+          sql(s"DROP TABLE $tableNameAsString")
+        }
+      }
+    }
+  }
+
   private def findMergeExec(query: String): MergeRowsExec = {
     val plan = executeAndKeepPlan {
       sql(query)
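The new test currently only prints the loaded table for inspection. Once the expected metric names settle, it could instead assert against the commits recorded by the in-memory table. A hypothetical follow-up, building on the test's local "table" value and assuming the catalog returns an InMemoryRowLevelOperationTable as in the other suites in this package:

  // Hypothetical assertions (not in this commit): check that scan driver metrics and
  // forwarded MERGE execution metrics ended up in the last commit's properties.
  val inMemTable = table.asInstanceOf[InMemoryRowLevelOperationTable]
  assert(inMemTable.commits.nonEmpty)
  val props = inMemTable.commits.last.properties
  // "numSplits" comes from the scan without runtime filters; a runtime-filtered scan
  // would add a "secondScan.numSplits" entry per RowLevelOperationBatchWrite above.
  assert(props.contains("numSplits"))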

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Identifier, TableChange, TableInfo}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Identifier, TableCatalog, TableChange, TableInfo}
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER
 import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils}