Commit a0f7865

HeartSaVioR authored and yhuang-db committed
[SPARK-52333][SS][PYTHON] Squeeze the protocol of retrieving timers for transformWithState in PySpark
### What changes were proposed in this pull request?

This PR proposes to squeeze the protocol of retrieving timers for transformWithState in PySpark, which helps substantially when dealing with a modest (not huge) number of timers. Here are the changes:

* StatefulProcessorHandleImpl.listTimers() and StatefulProcessorHandleImpl.getExpiredTimers() no longer require an additional request to notice that there is no further data to read.
* We inline the data into the proto message, making it easy to determine whether the iterator has been fully consumed. This is the same mechanism we applied for ListState & MapState; it brought a performance improvement there, and this change also proved helpful in our internal benchmark.

### Why are the changes needed?

To further optimize some timer operations. We benchmarked the change with listing 100 timers (PR for benchmarking: apache#50952), and we saw overall performance improvements.

> Before the fix

```
==================== SET IMPLICIT KEY latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  78.250    141.583    184.375      635.792    962743.500
==================== REGISTER latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  65.375    126.125    162.792      565.833    60809.333
==================== DELETE latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  68.500    130.000    170.292      610.083    156733.125
==================== LIST latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
 486.833    714.961    998.625     2695.417    167039.959
==================== SET IMPLICIT KEY latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  77.916    139.000    182.375      671.792    521809.958
==================== REGISTER latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  65.000    124.333    160.875      596.667    30860.208
==================== DELETE latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  67.125    127.916    170.250      740.051    64404.416
==================== LIST latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
 482.041    710.333   1050.333     2685.500    76762.583
==================== SET IMPLICIT KEY latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  78.208    139.959    181.459      722.459    713788.250
==================== REGISTER latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  65.209    125.125    159.625      636.666    27963.167
==================== DELETE latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  67.417    129.000    168.875      764.602    12991.667
==================== LIST latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
 479.000    709.584   1045.543     2776.541    92247.542
```

> After the fix

```
==================== SET IMPLICIT KEY latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  31.250     47.250     75.875      150.000    551557.750
==================== REGISTER latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  26.958     39.208     65.208      122.667    78609.292
==================== DELETE latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  23.500     41.125     64.542      125.958    52641.042
==================== LIST latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  93.125    118.542    156.500      284.625    19910.000
==================== SET IMPLICIT KEY latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  30.875     44.083     70.417      128.875    628912.209
==================== REGISTER latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  26.917     36.416     61.292      109.917    164584.666
==================== DELETE latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  23.333     38.375     59.542      113.839    114350.250
==================== LIST latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  94.125    115.208    148.917      246.292    36924.292
==================== SET IMPLICIT KEY latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  31.375     58.375     93.041      243.750    719545.583
==================== REGISTER latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  26.959     50.167     81.833      194.375    67609.583
==================== DELETE latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  24.208     50.834     83.000      211.018    20611.959
==================== LIST latency (micros) ======================
 perc:50    perc:95    perc:99    perc:99.9    perc:100
  95.291    132.375    183.875      427.584    36971.792
```

Worth noting: the improvement is not limited to the LIST operation; the other operations improve as well. It is not entirely clear why that happens, but reducing round-trips has proven to be the right direction.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51036 from HeartSaVioR/SPARK-52333.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
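To make the protocol shape concrete, here is a minimal sketch contrasting the two client-side loops. The `send`/`recv` helpers and the dict-shaped messages are hypothetical stand-ins for illustration, not the actual PySpark wire API:

```python
from typing import Any, Callable, Dict, Iterator

Msg = Dict[str, Any]

def list_timers_before(send: Callable[[Msg], None], recv: Callable[[], Msg]) -> Iterator[int]:
    # Old shape: after the data is exhausted, one extra round-trip is spent
    # just to observe the "no more data" status.
    while True:
        send({"op": "listTimers"})
        resp = recv()
        if resp["status"] != 0:  # terminating response carries no data
            break
        yield from resp["timers"]

def list_timers_after(send: Callable[[Msg], None], recv: Callable[[], Msg]) -> Iterator[int]:
    # New shape: timers are inlined in the response together with a
    # requireNextFetch flag, so the terminating round-trip disappears.
    require_next_fetch = True
    while require_next_fetch:
        send({"op": "listTimers"})
        resp = recv()
        yield from resp["timers"]
        require_next_fetch = resp["requireNextFetch"]
```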
1 parent c6e15b2 commit a0f7865

File tree

7 files changed: +375 -182 lines changed

python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Lines changed: 81 additions & 77 deletions
Large diffs are not rendered by default.

python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi

Lines changed: 75 additions & 1 deletion
```diff
@@ -366,6 +366,74 @@ class StateResponseWithMapIterator(google.protobuf.message.Message):
 
 global___StateResponseWithMapIterator = StateResponseWithMapIterator
 
+class TimerInfo(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    KEY_FIELD_NUMBER: builtins.int
+    TIMESTAMPMS_FIELD_NUMBER: builtins.int
+    key: builtins.bytes
+    timestampMs: builtins.int
+    def __init__(
+        self,
+        *,
+        key: builtins.bytes | None = ...,
+        timestampMs: builtins.int = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["_key", b"_key", "key", b"key"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_key", b"_key", "key", b"key", "timestampMs", b"timestampMs"
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_key", b"_key"]
+    ) -> typing_extensions.Literal["key"] | None: ...
+
+global___TimerInfo = TimerInfo
+
+class StateResponseWithTimer(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    STATUSCODE_FIELD_NUMBER: builtins.int
+    ERRORMESSAGE_FIELD_NUMBER: builtins.int
+    TIMER_FIELD_NUMBER: builtins.int
+    REQUIRENEXTFETCH_FIELD_NUMBER: builtins.int
+    statusCode: builtins.int
+    errorMessage: builtins.str
+    @property
+    def timer(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___TimerInfo
+    ]: ...
+    requireNextFetch: builtins.bool
+    def __init__(
+        self,
+        *,
+        statusCode: builtins.int = ...,
+        errorMessage: builtins.str = ...,
+        timer: collections.abc.Iterable[global___TimerInfo] | None = ...,
+        requireNextFetch: builtins.bool = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "errorMessage",
+            b"errorMessage",
+            "requireNextFetch",
+            b"requireNextFetch",
+            "statusCode",
+            b"statusCode",
+            "timer",
+            b"timer",
+        ],
+    ) -> None: ...
+
+global___StateResponseWithTimer = StateResponseWithTimer
+
 class StatefulProcessorCall(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
```

```diff
@@ -634,15 +702,21 @@ global___TimerValueRequest = TimerValueRequest
 class ExpiryTimerRequest(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    ITERATORID_FIELD_NUMBER: builtins.int
     EXPIRYTIMESTAMPMS_FIELD_NUMBER: builtins.int
+    iteratorId: builtins.str
     expiryTimestampMs: builtins.int
     def __init__(
         self,
         *,
+        iteratorId: builtins.str = ...,
         expiryTimestampMs: builtins.int = ...,
     ) -> None: ...
     def ClearField(
-        self, field_name: typing_extensions.Literal["expiryTimestampMs", b"expiryTimestampMs"]
+        self,
+        field_name: typing_extensions.Literal[
+            "expiryTimestampMs", b"expiryTimestampMs", "iteratorId", b"iteratorId"
+        ],
     ) -> None: ...
 
 global___ExpiryTimerRequest = ExpiryTimerRequest
```
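These stubs describe ordinary generated protobuf classes, so the new messages round-trip like any other. A minimal sketch using only the fields declared above (the key bytes are placeholders for the pickled grouping key):

```python
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

# Build a response carrying two inlined timers and signalling that no
# further fetch is required.
resp = stateMessage.StateResponseWithTimer(
    statusCode=0,
    timer=[
        stateMessage.TimerInfo(key=b"pickled-key-1", timestampMs=1000),
        stateMessage.TimerInfo(key=b"pickled-key-2", timestampMs=2000),
    ],
    requireNextFetch=False,
)

# Serialize and parse back, as would happen across the JVM/Python socket.
payload = resp.SerializeToString()
parsed = stateMessage.StateResponseWithTimer()
parsed.ParseFromString(payload)

assert [t.timestampMs for t in parsed.timer] == [1000, 2000]
assert parsed.requireNextFetch is False
```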

python/pyspark/sql/streaming/stateful_processor_api_client.py

Lines changed: 109 additions & 47 deletions
```diff
@@ -24,17 +24,13 @@
 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     StructType,
-    TYPE_CHECKING,
     Row,
 )
 from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
 from pyspark.serializers import CPickleSerializer
 from pyspark.errors import PySparkRuntimeError
 import uuid
 
-if TYPE_CHECKING:
-    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-
 __all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]
```

```diff
@@ -80,9 +76,11 @@ def __init__(
         self.utf8_deserializer = UTF8Deserializer()
         self.pickleSer = CPickleSerializer()
         self.serializer = ArrowStreamSerializer()
-        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # Dictionaries to store the mapping between iterator id and a tuple of data batch
         # and the index of the last row that was read.
-        self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
+        self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
+
         # statefulProcessorApiClient is initialized per batch per partition,
         # so we will have new timestamps for a new batch
         self._batch_timestamp = -1
```
```diff
@@ -222,76 +220,93 @@ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error deleting timer: " f"{response_message[1]}")
 
-    def get_list_timer_row(self, iterator_id: str) -> int:
+    def get_list_timer_row(self, iterator_id: str) -> Tuple[int, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.list_timer_iterator_cursors:
             # if the iterator is already in the dictionary, return the next row
-            pandas_df, index = self.list_timer_iterator_cursors[iterator_id]
+            data_batch, index, require_next_fetch = self.list_timer_iterator_cursors[iterator_id]
         else:
             list_call = stateMessage.ListTimers(iteratorId=iterator_id)
             state_call_command = stateMessage.TimerStateCallCommand(list=list_call)
             call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
             message = stateMessage.StateRequest(statefulProcessorCall=call)
 
             self._send_proto_message(message.SerializeToString())
-            response_message = self._receive_proto_message()
+            response_message = self._receive_proto_message_with_timers()
             status = response_message[0]
             if status == 0:
-                iterator = self._read_arrow_state()
-                # We need to exhaust the iterator here to make sure all the arrow batches are read,
-                # even though there is only one batch in the iterator. Otherwise, the stream might
-                # block further reads since it thinks there might still be some arrow batches left.
-                # We only need to read the first batch in the iterator because it's guaranteed that
-                # there would only be one batch sent from the JVM side.
-                data_batch = None
-                for batch in iterator:
-                    if data_batch is None:
-                        data_batch = batch
-                if data_batch is None:
-                    # TODO(SPARK-49233): Classify user facing errors.
-                    raise PySparkRuntimeError("Error getting map state entry.")
-                pandas_df = data_batch.to_pandas()
+                data_batch = list(map(lambda x: x.timestampMs, response_message[2]))
+                require_next_fetch = response_message[3]
                 index = 0
             else:
                 raise StopIteration()
+
+        is_last_row = False
         new_index = index + 1
-        if new_index < len(pandas_df):
+        if new_index < len(data_batch):
             # Update the index in the dictionary.
-            self.list_timer_iterator_cursors[iterator_id] = (pandas_df, new_index)
+            self.list_timer_iterator_cursors[iterator_id] = (
+                data_batch,
+                new_index,
+                require_next_fetch,
+            )
         else:
-            # If the index is at the end of the DataFrame, remove the state from the dictionary.
+            # If the index is at the end of the data batch, remove the state from the dictionary.
             self.list_timer_iterator_cursors.pop(iterator_id, None)
-        return pandas_df.at[index, "timestamp"].item()
+            is_last_row = True
+
+        is_last_row_from_iterator = is_last_row and not require_next_fetch
+        timestamp = data_batch[index]
+        return (timestamp, is_last_row_from_iterator)
 
     def get_expiry_timers_iterator(
-        self, expiry_timestamp: int
-    ) -> Iterator[list[Tuple[Tuple, int]]]:
+        self, iterator_id: str, expiry_timestamp: int
+    ) -> Tuple[Tuple, int, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        while True:
-            expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+        if iterator_id in self.expiry_timer_iterator_cursors:
+            # If the state is already in the dictionary, return the next row.
+            data_batch, index, require_next_fetch = self.expiry_timer_iterator_cursors[iterator_id]
+        else:
+            expiry_timer_call = stateMessage.ExpiryTimerRequest(
+                expiryTimestampMs=expiry_timestamp, iteratorId=iterator_id
+            )
             timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
             message = stateMessage.StateRequest(timerRequest=timer_request)
 
             self._send_proto_message(message.SerializeToString())
-            response_message = self._receive_proto_message()
+            response_message = self._receive_proto_message_with_timers()
             status = response_message[0]
-            if status == 1:
-                break
-            elif status == 0:
-                result_list = []
-                iterator = self._read_arrow_state()
-                for batch in iterator:
-                    batch_df = batch.to_pandas()
-                    for i in range(batch.num_rows):
-                        deserialized_key = self.pickleSer.loads(batch_df.at[i, "key"])
-                        timestamp = batch_df.at[i, "timestamp"].item()
-                        result_list.append((tuple(deserialized_key), timestamp))
-                yield result_list
+            if status == 0:
+                data_batch = list(
+                    map(
+                        lambda x: (self._deserialize_from_bytes(x.key), x.timestampMs),
+                        response_message[2],
+                    )
+                )
+                require_next_fetch = response_message[3]
+                index = 0
             else:
-                # TODO(SPARK-49233): Classify user facing errors.
-                raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}")
+                raise StopIteration()
+
+        is_last_row = False
+        new_index = index + 1
+        if new_index < len(data_batch):
+            # Update the index in the dictionary.
+            self.expiry_timer_iterator_cursors[iterator_id] = (
+                data_batch,
+                new_index,
+                require_next_fetch,
+            )
+        else:
+            # If the index is at the end of the data batch, remove the state from the dictionary.
+            self.expiry_timer_iterator_cursors.pop(iterator_id, None)
+            is_last_row = True
+
+        is_last_row_from_iterator = is_last_row and not require_next_fetch
+        key, timestamp = data_batch[index]
+        return (key, timestamp, is_last_row_from_iterator)
 
     def get_timestamps(self, time_mode: str) -> Tuple[int, int]:
         if time_mode.lower() == "none":
```
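Both methods above share one client-side cursor pattern: a dict keyed by iterator id caches `(data_batch, index, require_next_fetch)`, the cursor is dropped once the last row of a page is handed out, and the iterator counts as exhausted only when that last row coincides with `requireNextFetch` being false. A standalone sketch of the pattern, with a hypothetical `fetch_page` standing in for the proto round-trip:

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical standalone rendering of the cursor pattern used by
# get_list_timer_row / get_expiry_timers_iterator above.
cursors: Dict[str, Tuple[List[Any], int, bool]] = {}

def next_row(
    iterator_id: str, fetch_page: Callable[[], Tuple[List[Any], bool]]
) -> Tuple[Any, bool]:
    if iterator_id in cursors:
        batch, index, require_next_fetch = cursors[iterator_id]
    else:
        # One round-trip; the page carries the data plus the requireNextFetch flag.
        batch, require_next_fetch = fetch_page()
        if not batch:
            raise StopIteration()
        index = 0

    if index + 1 < len(batch):
        cursors[iterator_id] = (batch, index + 1, require_next_fetch)
        is_last_row = False
    else:
        cursors.pop(iterator_id, None)  # next call fetches a fresh page
        is_last_row = True

    # Exhausted only when the final row of the final page is handed out.
    return batch[index], is_last_row and not require_next_fetch
```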
```diff
@@ -461,6 +476,18 @@ def _receive_proto_message_with_map_pairs(self) -> Tuple[int, str, Any, bool]:
 
         return message.statusCode, message.errorMessage, message.kvPair, message.requireNextFetch
 
+    # The third return type is RepeatedScalarFieldContainer[TimerInfo], which is protobuf's
+    # container type. We simplify it to Any here to avoid unnecessary complexity.
+    def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        length = read_int(self.sockfile)
+        bytes = self.sockfile.read(length)
+        message = stateMessage.StateResponseWithTimer()
+        message.ParseFromString(bytes)
+
+        return message.statusCode, message.errorMessage, message.timer, message.requireNextFetch
+
     def _receive_str(self) -> str:
         return self.utf8_deserializer.loads(self.sockfile)
```
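The reader above consumes a length-prefixed protobuf frame. A self-contained sketch of that framing, assuming the 4-byte big-endian length prefix that PySpark's `read_int` conventionally reads (treat the prefix format as an assumption):

```python
import io
import struct

import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

# Emulate the server side: serialize a response and prepend a length prefix.
msg = stateMessage.StateResponseWithTimer(statusCode=0, requireNextFetch=False)
payload = msg.SerializeToString()
stream = io.BytesIO(struct.pack("!i", len(payload)) + payload)

# Emulate the client side, mirroring _receive_proto_message_with_timers above.
length = struct.unpack("!i", stream.read(4))[0]
parsed = stateMessage.StateResponseWithTimer()
parsed.ParseFromString(stream.read(length))
assert (parsed.statusCode, parsed.requireNextFetch) == (0, False)
```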

```diff
@@ -552,9 +579,44 @@ def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient):
         # same partition won't interfere with each other
         self.iterator_id = str(uuid.uuid4())
         self.stateful_processor_api_client = stateful_processor_api_client
+        self.iterator_fully_consumed = False
 
     def __iter__(self) -> Iterator[int]:
         return self
 
     def __next__(self) -> int:
-        return self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
+        if self.iterator_fully_consumed:
+            raise StopIteration()
+
+        ts, is_last_row = self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
+        if is_last_row:
+            self.iterator_fully_consumed = True
+
+        return ts
+
+
+class ExpiredTimerIterator:
+    def __init__(
+        self, stateful_processor_api_client: StatefulProcessorApiClient, expiry_timestamp: int
+    ):
+        # Generate a unique identifier for the iterator to make sure iterators on the
+        # same partition won't interfere with each other
+        self.iterator_id = str(uuid.uuid4())
+        self.stateful_processor_api_client = stateful_processor_api_client
+        self.expiry_timestamp = expiry_timestamp
+        self.iterator_fully_consumed = False
+
+    def __iter__(self) -> Iterator[Tuple[Tuple, int]]:
+        return self
+
+    def __next__(self) -> Tuple[Tuple, int]:
+        if self.iterator_fully_consumed:
+            raise StopIteration()
+
+        key, ts, is_last_row = self.stateful_processor_api_client.get_expiry_timers_iterator(
+            self.iterator_id, self.expiry_timestamp
+        )
+        if is_last_row:
+            self.iterator_fully_consumed = True
+
+        return (key, ts)
```
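Both wrapper classes turn the paged cursor protocol into plain Python iterators. A hedged usage sketch (constructing a real StatefulProcessorApiClient needs a live state-server socket, so `api_client`, `batch_ts`, and `handle_expired` are placeholders):

```python
# Hypothetical usage; names are stand-ins, not part of this PR.
for key_obj, expiry_ts in ExpiredTimerIterator(api_client, expiry_timestamp=batch_ts):
    handle_expired(key_obj, expiry_ts)
```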

python/pyspark/sql/streaming/stateful_processor_util.py

Lines changed: 13 additions & 17 deletions
```diff
@@ -28,6 +28,7 @@
     StatefulProcessorHandle,
     TimerValues,
 )
+from pyspark.sql.streaming.stateful_processor_api_client import ExpiredTimerIterator
 from pyspark.sql.types import Row
 
 if TYPE_CHECKING:
```
```diff
@@ -218,24 +219,19 @@ def _handle_expired_timers(
         )
 
         if self._time_mode.lower() == "processingtime":
-            expiry_list_iter = stateful_processor_api_client.get_expiry_timers_iterator(
-                batch_timestamp
-            )
+            expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, batch_timestamp)
         elif self._time_mode.lower() == "eventtime":
-            expiry_list_iter = stateful_processor_api_client.get_expiry_timers_iterator(
-                watermark_timestamp
-            )
+            expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, watermark_timestamp)
         else:
-            expiry_list_iter = iter([[]])
+            expiry_iter = iter([])  # type: ignore[assignment]
 
         # process with expiry timers, only timer related rows will be emitted
-        for expiry_list in expiry_list_iter:
-            for key_obj, expiry_timestamp in expiry_list:
-                stateful_processor_api_client.set_implicit_key(key_obj)
-                for pd in self._stateful_processor.handleExpiredTimer(
-                    key=key_obj,
-                    timerValues=TimerValues(batch_timestamp, watermark_timestamp),
-                    expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp),
-                ):
-                    yield pd
-                stateful_processor_api_client.delete_timer(expiry_timestamp)
+        for key_obj, expiry_timestamp in expiry_iter:
+            stateful_processor_api_client.set_implicit_key(key_obj)
+            for pd in self._stateful_processor.handleExpiredTimer(
+                key=key_obj,
+                timerValues=TimerValues(batch_timestamp, watermark_timestamp),
+                expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp),
+            ):
+                yield pd
+            stateful_processor_api_client.delete_timer(expiry_timestamp)
```
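The loop flattening above is behavior-preserving: the old code iterated a list-of-lists, while ExpiredTimerIterator yields (key, timestamp) pairs directly. A minimal sketch of the equivalence, with made-up sample data:

```python
from itertools import chain

# Old shape: an iterator over lists of (key, timestamp) pairs.
expiry_list_iter = iter([[(("k1",), 100), (("k2",), 200)], [(("k3",), 300)]])

# New shape: a flat iterator of (key, timestamp) pairs, as ExpiredTimerIterator yields.
flat = chain.from_iterable(expiry_list_iter)

assert list(flat) == [(("k1",), 100), (("k2",), 200), (("k3",), 300)]
```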
