
Commit 021dfb4

second attempt
1 parent c17c6b9 commit 021dfb4

10 files changed, +123 -106 lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/V2ExecMetric.java

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+package org.apache.spark.sql.connector.metric;
+
+public interface V2ExecMetric {
+  /**
+   * Returns the type of V2 exec metric (ie, firstScan, secondScan, merge, etc).
+   */
+  String metricType();
+  /**
+   * Returns the name of V2 exec metric.
+   */
+  String name();
+
+  /**
+   * Returns the value of V2 exec metric.
+   */
+  String value();
+}
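For reference, implementing the new interface only requires returning the three strings above. The Scala sketch below is illustrative only and not part of this commit; the class name and metric values are hypothetical.

```scala
// Hypothetical example (not part of this commit): a metric reported for one
// phase of a row-level operation.
import org.apache.spark.sql.connector.metric.V2ExecMetric

class NumSplitsExecMetric(execPhase: String, numSplits: Long) extends V2ExecMetric {
  override def metricType(): String = execPhase       // e.g. "firstScan", "secondScan", "merge"
  override def name(): String = "numSplits"
  override def value(): String = numSplits.toString   // values are carried as strings
}
```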

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java

Lines changed: 5 additions & 5 deletions
@@ -19,6 +19,7 @@
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.metric.CustomTaskMetric;
+import org.apache.spark.sql.connector.metric.V2ExecMetric;
 
 /**
  * An interface that defines how to write the data to data source for batch processing.
@@ -108,16 +109,15 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
 
 
   /**
-   * Whether this batch write requests execution metrics. Returns a row level operation command this batch write
-   * is part of, if requested. Return null if not requested.
+   * Whether this batch write requests execution metrics.
    */
-  default RowLevelOperation.Command requestExecMetrics() {
-    return null;
+  default boolean requestExecMetrics() {
+    return false;
   }
 
   /**
   * Provides an array of query execution metrics to the batch write prior to commit.
   * @param metrics an array of execution metrics
   */
-  default void execMetrics(CustomTaskMetric[] metrics) {}
+  default void execMetrics(V2ExecMetric[] metrics) {}
 }
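With the new signatures, a connector opts in by returning true from requestExecMetrics() and receives the metrics via execMetrics() before commit. The sketch below is a hypothetical writer, not code from this commit; the class name, constructor, and the "<metricType>.<name>" key format are assumptions layered on the API shown above.

```scala
// Hypothetical sketch of a BatchWrite that opts in to exec metrics.
import scala.collection.mutable

import org.apache.spark.sql.connector.metric.V2ExecMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}

class MetricsAwareBatchWrite(factory: DataWriterFactory) extends BatchWrite {
  // Collected metric values, keyed as "<metricType>.<name>".
  private val execMetricValues = mutable.Map.empty[String, String]

  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = factory

  // Opt in: the exec metrics are passed to execMetrics() prior to commit.
  override def requestExecMetrics(): Boolean = true

  override def execMetrics(metrics: Array[V2ExecMetric]): Unit = {
    metrics.foreach(m => execMetricValues += s"${m.metricType()}.${m.name()}" -> m.value())
  }

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // A real writer would persist execMetricValues with its commit metadata.
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = ()
}
```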

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala

Lines changed: 1 addition & 4 deletions
@@ -501,8 +501,6 @@ abstract class InMemoryBaseTable(
       options: CaseInsensitiveStringMap)
     extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
 
-    var setFilters = Array.empty[Filter]
-
     override def reportDriverMetrics(): Array[CustomTaskMetric] =
       Array(new CustomTaskMetric{
         override def name(): String = "numSplits"
@@ -526,7 +524,6 @@
     }
 
     override def filter(filters: Array[Filter]): Unit = {
-      this.setFilters = filters
       if (partitioning.length == 1 && partitioning.head.references().length == 1) {
         val ref = partitioning.head.references().head
         filters.foreach {
@@ -598,7 +595,7 @@
 
   protected abstract class TestBatchWrite extends BatchWrite {
 
-    var commitProperties: mutable.Map[String, String] = mutable.Map.empty[String, String]
+    val commitProperties: mutable.Map[String, String] = mutable.Map.empty[String, String]
 
     override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
       BufferedRowsWriterFactory

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala

Lines changed: 12 additions & 46 deletions
@@ -20,14 +20,12 @@ package org.apache.spark.sql.connector.catalog
 import java.time.Instant
 import java.util
 
-import scala.collection.mutable.ListBuffer
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.connector.catalog.constraints.Constraint
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
-import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.metric.V2ExecMetric
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
 import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.RowLevelOperation.Command
@@ -50,8 +48,6 @@
     constraints)
   with SupportsRowLevelOperations {
 
-  private val _scans = ListBuffer.empty[Scan]
-
   private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
   private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
   private final val SUPPORTS_DELTAS = "supports-deltas"
@@ -74,16 +70,6 @@
     }
   }
 
-  class InMemoryRowLevelOperationScanBuilder(tableSchema: StructType,
-      options: CaseInsensitiveStringMap)
-    extends InMemoryScanBuilder(tableSchema, options) {
-    override def build: Scan = {
-      val scan = super.build
-      _scans += scan
-      scan
-    }
-  }
-
   case class PartitionBasedOperation(command: Command) extends RowLevelOperation {
     var configuredScan: InMemoryBatchScan = _
 
@@ -117,7 +103,7 @@
           SortDirection.ASCENDING.defaultNullOrdering()))
       }
 
-      override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan, command)
+      override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)
 
       override def description: String = "InMemoryWrite"
     }
@@ -127,33 +113,14 @@
     override def description(): String = "InMemoryPartitionReplaceOperation"
   }
 
-  abstract class RowLevelOperationBatchWrite(command: Command) extends TestBatchWrite {
-    override def requestExecMetrics(): Command = command
+  abstract class RowLevelOperationBatchWrite extends TestBatchWrite {
+    override def requestExecMetrics(): Boolean = true
 
-    override def execMetrics(metrics: Array[CustomTaskMetric]): Unit = {
-      metrics.foreach(m => commitProperties += (m.name() -> m.value().toString))
+    override def execMetrics(metrics: Array[V2ExecMetric]): Unit = {
+      metrics.foreach(m => commitProperties += s"${m.metricType()}.${m.name()}" -> m.value())
     }
 
     override def commit(messages: Array[WriterCommitMessage]): Unit = {
-      assert(_scans.size <= 2, "Expected at most two scans in row-level operations")
-      assert(_scans.count{ case s: InMemoryBatchScan => s.setFilters.nonEmpty } <= 1,
-        "Expected at most one scan with runtime filters in row-level operations")
-      assert(_scans.count{ case s: InMemoryBatchScan => s.setFilters.isEmpty } <= 1,
-        "Expected at most one scan without runtime filters in row-level operations")
-
-      _scans.foreach{
-        case s: InMemoryBatchScan =>
-          val prefix = if (s.setFilters.isEmpty) {
-            ""
-          } else {
-            "secondScan."
-          }
-          s.reportDriverMetrics().foreach { metric =>
-            commitProperties += (prefix + metric.name() -> metric.value().toString)
-          }
-        case _ =>
-      }
-      _scans.clear()
       doCommit(messages)
       commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
       commitProperties.clear()
@@ -162,9 +129,8 @@
     def doCommit(messages: Array[WriterCommitMessage]): Unit
   }
 
-  private case class PartitionBasedReplaceData(scan: InMemoryBatchScan,
-      command: RowLevelOperation.Command)
-    extends RowLevelOperationBatchWrite(command) {
+  private case class PartitionBasedReplaceData(scan: InMemoryBatchScan)
+    extends RowLevelOperationBatchWrite {
 
     override def doCommit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       val newData = messages.map(_.asInstanceOf[BufferedRows])
@@ -187,7 +153,7 @@
     override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)
 
     override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-      new InMemoryRowLevelOperationScanBuilder(schema, options)
+      new InMemoryScanBuilder(schema, options)
    }
 
    override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
@@ -208,7 +174,7 @@
        )
      }
 
-      override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite(command)
+      override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite
    }
  }
 }
@@ -218,8 +184,8 @@
    }
  }
 
-  private case class TestDeltaBatchWrite(command: Command)
-    extends RowLevelOperationBatchWrite(command) with DeltaBatchWrite{
+  private object TestDeltaBatchWrite extends RowLevelOperationBatchWrite
+    with DeltaBatchWrite{
 
    override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
      DeltaBufferedRowsWriterFactory
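Net effect of this rewrite: instead of tracking scans in _scans and hand-prefixing metric names, the test write records each metric delivered to execMetrics under a "<metricType>.<name>" key in commitProperties. A small hypothetical illustration of the resulting key (the metric type and value here are made up):

```scala
// Hypothetical illustration of the commit-property keys produced by
// RowLevelOperationBatchWrite.execMetrics above.
import org.apache.spark.sql.connector.metric.V2ExecMetric

val metric = new V2ExecMetric {
  override def metricType(): String = "secondScan"
  override def name(): String = "numSplits"
  override def value(): String = "4"
}
val key = s"${metric.metricType()}.${metric.name()}"  // "secondScan.numSplits"
```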

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 9 additions & 8 deletions
@@ -182,7 +182,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
        case Some(r) => session.sharedState.cacheManager.uncacheQuery(session, r, cascade = true)
        case None => ()
      }
-      WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query), customMetrics) :: Nil
+      WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query), customMetrics,
+        relationOpt.map(_.table)) :: Nil
 
    case c @ CreateTable(ResolvedIdentifier(catalog, ident), columns, partitioning,
        tableSpec: TableSpec, ifNotExists) =>
@@ -275,7 +276,7 @@
      }
 
    case AppendData(r: DataSourceV2Relation, query, _, _, Some(write), _) =>
-      AppendDataExec(planLater(query), refreshCache(r), write) :: Nil
+      AppendDataExec(planLater(query), refreshCache(r), write, Some(r.table)) :: Nil
 
    case OverwriteByExpression(r @ DataSourceV2Relation(v1: SupportsWrite, _, _, _, _), _, _,
        _, _, Some(write), analyzedQuery) if v1.supports(TableCapability.V1_BATCH_WRITE) =>
@@ -290,10 +291,10 @@
 
    case OverwriteByExpression(
        r: DataSourceV2Relation, _, query, _, _, Some(write), _) =>
-      OverwriteByExpressionExec(planLater(query), refreshCache(r), write) :: Nil
+      OverwriteByExpressionExec(planLater(query), refreshCache(r), write, Some(r.table)) :: Nil
 
    case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _, _, Some(write)) =>
-      OverwritePartitionsDynamicExec(planLater(query), refreshCache(r), write) :: Nil
+      OverwritePartitionsDynamicExec(planLater(query), refreshCache(r), write, Some(r.table)) :: Nil
 
    case DeleteFromTableWithFilters(r: DataSourceV2Relation, filters) =>
      DeleteFromTableExec(r.table.asDeletable, filters.toArray, refreshCache(r)) :: Nil
@@ -332,15 +333,15 @@
        throw SparkException.internalError("Unexpected table relation: " + other)
      }
 
-    case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _,
+    case ReplaceData(o: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _,
        Some(write)) =>
      // use the original relation to refresh the cache
-      ReplaceDataExec(planLater(query), refreshCache(r), projections, write) :: Nil
+      ReplaceDataExec(planLater(query), refreshCache(r), projections, write, Some(o.table)) :: Nil
 
-    case WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections,
+    case WriteDelta(o: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections,
        Some(write)) =>
      // use the original relation to refresh the cache
-      WriteDeltaExec(planLater(query), refreshCache(r), projections, write) :: Nil
+      WriteDeltaExec(planLater(query), refreshCache(r), projections, write, Some(o.table)) :: Nil
 
    case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions,
        notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala

Lines changed: 0 additions & 1 deletion
@@ -73,7 +73,6 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
      case _ =>
        throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
    }
-
    val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
    o.copy(write = Some(write), query = newQuery)
