
Commit 9f4156d

huanliwang-db authored and yhuang-db committed
[SPARK-52391][SS] Refactor TransformWithStateExec to extract shared functions and variables into an abstract base class for Scala and Python
### What changes were proposed in this pull request?
Refactor the TWS Exec code to extract the common functions/variables and move them to an abstract base class so that they can be shared by both the Scala exec and the Python exec.

### Why are the changes needed?
Code elegance: less duplicated code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
No functionality changes; existing UTs should provide test coverage.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#51077 from huanliwang-db/refactor-tws.

Authored-by: huanliwang-db <huanli.wang@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
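For readers skimming the diff below, the shape of the refactoring is roughly as follows. This is a simplified sketch rather than the actual Spark classes: `SparkPlan`, `Attribute`, and the mixed-in traits are reduced to placeholder types, and only the base-class name and the moved member names (`keyExpressions`, `left`, `right`, `shortName`) are taken from the diff; everything else is illustrative.

```scala
// Simplified sketch of the extraction pattern (placeholder types, not Spark's).
// Shared distribution/ordering/watermark/metadata members move into one abstract
// base class so the Scala exec and the Python exec no longer duplicate them.
abstract class TransformWithStateExecBase(
    val groupingAttributes: Seq[String],   // stand-in for Seq[Attribute]
    val child: AnyRef,                     // stand-in for SparkPlan
    val initialStateGroupingAttrs: Seq[String],
    val initialState: AnyRef) {

  // Members the diff deletes from TransformWithStateInPySparkExec now live here.
  def keyExpressions: Seq[String] = groupingAttributes
  def left: AnyRef = child
  def right: AnyRef = initialState

  // Each concrete exec still provides its own name and execution logic.
  def shortName: String
}

// The Python exec keeps only what is Python-specific and passes the shared
// constructor arguments up to the base class, as the new extends clause does.
class PySparkExecSketch(
    groupingAttributes: Seq[String],
    child: AnyRef,
    initialStateGroupingAttrs: Seq[String],
    initialState: AnyRef)
  extends TransformWithStateExecBase(
    groupingAttributes, child, initialStateGroupingAttrs, initialState) {

  override def shortName: String = "sketch-only-name" // hypothetical value
}
```

In the real change the base class also receives `timeMode`, `outputMode`, `batchTimestampMs`, and `eventTimeWatermarkForEviction`, which is why the Python exec's watermark, metrics, and metadata overrides can be deleted wholesale in the hunks below.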
1 parent 67ee620 commit 9f4156d

3 files changed: +264 -292 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala

Lines changed: 13 additions & 106 deletions
@@ -28,18 +28,16 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowPythonRunner
 import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets}
-import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorCustomMetric, StatefulOperatorCustomSumMetric, StatefulOperatorPartitioning, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, StateStoreWriter, TransformWithStateMetadataUtils, TransformWithStateVariableInfo, WatermarkSupport}
+import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, TransformWithStateExecBase, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, OperatorStateMetadata, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
 import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
@@ -83,10 +81,15 @@ case class TransformWithStateInPySparkExec(
     initialState: SparkPlan,
     initialStateGroupingAttrs: Seq[Attribute],
     initialStateSchema: StructType)
-  extends BinaryExecNode
-  with StateStoreWriter
-  with WatermarkSupport
-  with TransformWithStateMetadataUtils {
+  extends TransformWithStateExecBase(
+    groupingAttributes,
+    timeMode,
+    outputMode,
+    batchTimestampMs,
+    eventTimeWatermarkForEviction,
+    child,
+    initialStateGroupingAttrs,
+    initialState) {
 
   // NOTE: This is needed to comply with existing release of transformWithStateInPandas.
   override def shortName: String = if (
@@ -115,17 +118,12 @@ case class TransformWithStateInPySparkExec(
 
   private val numOutputRows: SQLMetric = longMetric("numOutputRows")
 
-  // The keys that may have a watermark attribute.
-  override def keyExpressions: Seq[Attribute] = groupingAttributes
-
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
 
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
 
-  override def operatorStateMetadataVersion: Int = 2
-
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     // For Python, the user can explicitly set nullability on schema, so
@@ -146,37 +144,6 @@ case class TransformWithStateInPySparkExec(
   private val driverProcessorHandle: DriverStatefulProcessorHandleImpl =
     new DriverStatefulProcessorHandleImpl(timeMode, groupingKeyExprEncoder)
 
-  /**
-   * Distribute by grouping attributes - We need the underlying data and the initial state data
-   * to have the same grouping so that the data are co-located on the same task.
-   */
-  override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes,
-      getStateInfo, conf) ::
-      StatefulOperatorPartitioning.getCompatibleDistribution(
-        initialStateGroupingAttrs, getStateInfo, conf) ::
-      Nil
-  }
-
-  /**
-   * We need the initial state to also use the ordering as the data so that we can co-locate the
-   * keys from the underlying data and the initial state.
-   */
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
-    groupingAttributes.map(SortOrder(_, Ascending)),
-    initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
-
-  override def operatorStateMetadata(
-      stateSchemaPaths: List[List[String]]): OperatorStateMetadata = {
-    getOperatorStateMetadata(stateSchemaPaths, getStateInfo, shortName, timeMode, outputMode)
-  }
-
-  override def validateNewMetadata(
-      oldOperatorMetadata: OperatorStateMetadata,
-      newOperatorMetadata: OperatorStateMetadata): Unit = {
-    validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata)
-  }
-
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
@@ -208,60 +175,6 @@ case class TransformWithStateInPySparkExec(
       conf.stateStoreEncodingFormat)
   }
 
-  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
-    if (timeMode == ProcessingTime) {
-      // TODO SPARK-50180: check if we can return true only if actual timers are registered,
-      // or there is expired state
-      true
-    } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) {
-      eventTimeWatermarkForEviction.isDefined &&
-        newInputWatermark > eventTimeWatermarkForEviction.get
-    } else {
-      false
-    }
-  }
-
-  /**
-   * Controls watermark propagation to downstream modes. If timeMode is
-   * ProcessingTime, the output rows cannot be interpreted in eventTime, hence
-   * this node will not propagate watermark in this timeMode.
-   *
-   * For timeMode EventTime, output watermark is same as input Watermark because
-   * transformWithState does not allow users to set the event time column to be
-   * earlier than the watermark.
-   */
-  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
-    timeMode match {
-      case ProcessingTime =>
-        None
-      case _ =>
-        Some(inputWatermarkMs)
-    }
-  }
-
-  override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = {
-    Seq(
-      // metrics around state variables
-      StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value state variables"),
-      StatefulOperatorCustomSumMetric("numListStateVars", "Number of list state variables"),
-      StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"),
-      StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"),
-      // metrics around timers
-      StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"),
-      StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"),
-      StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"),
-      // metrics around TTL
-      StatefulOperatorCustomSumMetric("numValueStateWithTTLVars",
-        "Number of value state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numListStateWithTTLVars",
-        "Number of list state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numMapStateWithTTLVars",
-        "Number of map state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry",
-        "Number of values removed due to TTL expiry")
-    )
-  }
-
   /**
    * Produces the result of the query as an `RDD[InternalRow]`
    */
@@ -376,8 +289,6 @@ case class TransformWithStateInPySparkExec(
     }
   }
 
-  override def supportsSchemaEvolution: Boolean = true
-
   private def processDataWithPartition(
       store: StateStore,
       dataIterator: Iterator[InternalRow],
@@ -491,10 +402,6 @@ case class TransformWithStateInPySparkExec(
     } else {
      copy(child = newLeft)
     }
-
-  override def left: SparkPlan = child
-
-  override def right: SparkPlan = initialState
 }
 
 // scalastyle:off argcount
