@@ -28,18 +28,16 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowPythonRunner
 import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets}
-import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorCustomMetric, StatefulOperatorCustomSumMetric, StatefulOperatorPartitioning, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, StateStoreWriter, TransformWithStateMetadataUtils, TransformWithStateVariableInfo, WatermarkSupport}
+import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, TransformWithStateExecBase, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, OperatorStateMetadata, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
 import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
@@ -83,10 +81,15 @@ case class TransformWithStateInPySparkExec(
     initialState: SparkPlan,
     initialStateGroupingAttrs: Seq[Attribute],
     initialStateSchema: StructType)
-  extends BinaryExecNode
-  with StateStoreWriter
-  with WatermarkSupport
-  with TransformWithStateMetadataUtils {
+  extends TransformWithStateExecBase(
+    groupingAttributes,
+    timeMode,
+    outputMode,
+    batchTimestampMs,
+    eventTimeWatermarkForEviction,
+    child,
+    initialStateGroupingAttrs,
+    initialState) {
 
   // NOTE: This is needed to comply with existing release of transformWithStateInPandas.
   override def shortName: String = if (
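For orientation, here is a minimal sketch of the shape the new parent presumably has, inferred only from the constructor arguments forwarded above and from the mixins this operator no longer extends directly; the actual declaration in Spark may differ, so treat the names and types below as assumptions.

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.streaming.{StateStoreWriter, TransformWithStateMetadataUtils, WatermarkSupport}
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

// Sketch only: a base physical node that presumably owns the plumbing shared by the
// transformWithState operators, parameterized by the same fields the subclass
// forwards in the hunk above.
abstract class TransformWithStateExecBase(
    groupingAttributes: Seq[Attribute],
    timeMode: TimeMode,
    outputMode: OutputMode,
    batchTimestampMs: Option[Long],
    eventTimeWatermarkForEviction: Option[Long],
    child: SparkPlan,
    initialStateGroupingAttrs: Seq[Attribute],
    initialState: SparkPlan)
  extends BinaryExecNode
  with StateStoreWriter
  with WatermarkSupport
  with TransformWithStateMetadataUtils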
@@ -115,17 +118,12 @@ case class TransformWithStateInPySparkExec(
 
   private val numOutputRows: SQLMetric = longMetric("numOutputRows")
 
-  // The keys that may have a watermark attribute.
-  override def keyExpressions: Seq[Attribute] = groupingAttributes
-
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
 
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
 
-  override def operatorStateMetadataVersion: Int = 2
-
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     // For Python, the user can explicitly set nullability on schema, so
@@ -146,37 +144,6 @@ case class TransformWithStateInPySparkExec(
   private val driverProcessorHandle: DriverStatefulProcessorHandleImpl =
     new DriverStatefulProcessorHandleImpl(timeMode, groupingKeyExprEncoder)
 
-  /**
-   * Distribute by grouping attributes - We need the underlying data and the initial state data
-   * to have the same grouping so that the data are co-located on the same task.
-   */
-  override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes,
-      getStateInfo, conf) ::
-      StatefulOperatorPartitioning.getCompatibleDistribution(
-        initialStateGroupingAttrs, getStateInfo, conf) ::
-      Nil
-  }
-
-  /**
-   * We need the initial state to also use the ordering as the data so that we can co-locate the
-   * keys from the underlying data and the initial state.
-   */
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
-    groupingAttributes.map(SortOrder(_, Ascending)),
-    initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
-
-  override def operatorStateMetadata(
-      stateSchemaPaths: List[List[String]]): OperatorStateMetadata = {
-    getOperatorStateMetadata(stateSchemaPaths, getStateInfo, shortName, timeMode, outputMode)
-  }
-
-  override def validateNewMetadata(
-      oldOperatorMetadata: OperatorStateMetadata,
-      newOperatorMetadata: OperatorStateMetadata): Unit = {
-    validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata)
-  }
-
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
@@ -208,60 +175,6 @@ case class TransformWithStateInPySparkExec(
       conf.stateStoreEncodingFormat)
   }
 
-  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
-    if (timeMode == ProcessingTime) {
-      // TODO SPARK-50180: check if we can return true only if actual timers are registered,
-      // or there is expired state
-      true
-    } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) {
-      eventTimeWatermarkForEviction.isDefined &&
-        newInputWatermark > eventTimeWatermarkForEviction.get
-    } else {
-      false
-    }
-  }
-
-  /**
-   * Controls watermark propagation to downstream modes. If timeMode is
-   * ProcessingTime, the output rows cannot be interpreted in eventTime, hence
-   * this node will not propagate watermark in this timeMode.
-   *
-   * For timeMode EventTime, output watermark is same as input Watermark because
-   * transformWithState does not allow users to set the event time column to be
-   * earlier than the watermark.
-   */
-  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
-    timeMode match {
-      case ProcessingTime =>
-        None
-      case _ =>
-        Some(inputWatermarkMs)
-    }
-  }
-
-  override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = {
-    Seq(
-      // metrics around state variables
-      StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value state variables"),
-      StatefulOperatorCustomSumMetric("numListStateVars", "Number of list state variables"),
-      StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"),
-      StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"),
-      // metrics around timers
-      StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"),
-      StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"),
-      StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"),
-      // metrics around TTL
-      StatefulOperatorCustomSumMetric("numValueStateWithTTLVars",
-        "Number of value state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numListStateWithTTLVars",
-        "Number of list state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numMapStateWithTTLVars",
-        "Number of map state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry",
-        "Number of values removed due to TTL expiry")
-    )
-  }
-
   /**
    * Produces the result of the query as an `RDD[InternalRow]`
    */
@@ -376,8 +289,6 @@ case class TransformWithStateInPySparkExec(
     }
   }
 
-  override def supportsSchemaEvolution: Boolean = true
-
   private def processDataWithPartition(
       store: StateStore,
       dataIterator: Iterator[InternalRow],
@@ -491,10 +402,6 @@ case class TransformWithStateInPySparkExec(
     } else {
       copy(child = newLeft)
     }
-
-  override def left: SparkPlan = child
-
-  override def right: SparkPlan = initialState
 }
 
 // scalastyle:off argcount
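Taken together, the members this diff removes from the subclass (left/right, keyExpressions, child distribution and ordering, operator metadata handling, watermark propagation, and the custom metrics) presumably now live once in TransformWithStateExecBase. Below is a hedged sketch of a few of them as they would read there; the bodies are copied from the deleted lines above, and only their new location is the assumption.

// Assumed placement: inside the TransformWithStateExecBase sketched after the
// class-declaration hunk; bodies copied from the lines removed in this diff.

// BinaryExecNode children: the streaming input and the initial-state plan.
override def left: SparkPlan = child
override def right: SparkPlan = initialState

// The keys that may have a watermark attribute.
override def keyExpressions: Seq[Attribute] = groupingAttributes

// Both children must be partitioned and sorted on the grouping keys so that input
// rows and initial-state rows for the same key are co-located in one task.
override def requiredChildDistribution: Seq[Distribution] = {
  StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes,
    getStateInfo, conf) ::
    StatefulOperatorPartitioning.getCompatibleDistribution(
      initialStateGroupingAttrs, getStateInfo, conf) ::
    Nil
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
  groupingAttributes.map(SortOrder(_, Ascending)),
  initialStateGroupingAttrs.map(SortOrder(_, Ascending)))

// No watermark is propagated in ProcessingTime mode, since the output cannot be
// interpreted in event time; in event-time mode the input watermark passes through.
override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
  timeMode match {
    case ProcessingTime => None
    case _ => Some(inputWatermarkMs)
  }
}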