 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     StructType,
-    TYPE_CHECKING,
     Row,
 )
 from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
 from pyspark.serializers import CPickleSerializer
 from pyspark.errors import PySparkRuntimeError
 import uuid
 
-if TYPE_CHECKING:
-    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-
 __all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]
 
 
@@ -80,9 +76,11 @@ def __init__(
         self.utf8_deserializer = UTF8Deserializer()
         self.pickleSer = CPickleSerializer()
         self.serializer = ArrowStreamSerializer()
-        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # Dictionaries to store the mapping between iterator id and a tuple of data batch
         # and the index of the last row that was read.
-        self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
+        self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
+
         # statefulProcessorApiClient is initialized per batch per partition,
         # so we will have new timestamps for a new batch
         self._batch_timestamp = -1
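The two cursor dictionaries added here back a simple paging scheme: each iterator id maps to the most recently fetched batch, the read position within it, and a flag saying whether the server still has more rows to send. A minimal, self-contained sketch of that bookkeeping (illustrative only, not part of this patch; all names are hypothetical):

```python
from typing import Any, Callable, Dict, List, Tuple

# iterator id -> (cached batch, index of next row to return, require_next_fetch)
cursors: Dict[str, Tuple[List[Any], int, bool]] = {}

def next_row(
    iterator_id: str, fetch_batch: Callable[[str], Tuple[List[Any], bool]]
) -> Tuple[Any, bool]:
    if iterator_id in cursors:
        batch, index, require_next_fetch = cursors[iterator_id]
    else:
        # fetch_batch stands in for a round trip to the JVM state server.
        batch, require_next_fetch = fetch_batch(iterator_id)
        index = 0
    if index + 1 < len(batch):
        cursors[iterator_id] = (batch, index + 1, require_next_fetch)
        exhausted_batch = False
    else:
        # Drop the cursor once the cached batch is used up; the next call re-fetches.
        cursors.pop(iterator_id, None)
        exhausted_batch = True
    # The iterator is fully consumed only when the cached batch is exhausted
    # and the server reported that no further fetch is needed.
    return batch[index], exhausted_batch and not require_next_fetch

# Toy usage: two server "pages"; the second one is marked as the final fetch.
pages = iter([([10, 20], True), ([30], False)])
out = []
while True:
    value, done = next_row("timer-iter-1", lambda _id: next(pages))
    out.append(value)
    if done:
        break
assert out == [10, 20, 30]
```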
@@ -222,76 +220,93 @@ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error deleting timer: " f"{response_message[1]}")
 
-    def get_list_timer_row(self, iterator_id: str) -> int:
+    def get_list_timer_row(self, iterator_id: str) -> Tuple[int, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.list_timer_iterator_cursors:
             # if the iterator is already in the dictionary, return the next row
-            pandas_df, index = self.list_timer_iterator_cursors[iterator_id]
+            data_batch, index, require_next_fetch = self.list_timer_iterator_cursors[iterator_id]
         else:
             list_call = stateMessage.ListTimers(iteratorId=iterator_id)
             state_call_command = stateMessage.TimerStateCallCommand(list=list_call)
             call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
             message = stateMessage.StateRequest(statefulProcessorCall=call)
 
             self._send_proto_message(message.SerializeToString())
-            response_message = self._receive_proto_message()
+            response_message = self._receive_proto_message_with_timers()
             status = response_message[0]
             if status == 0:
-                iterator = self._read_arrow_state()
-                # We need to exhaust the iterator here to make sure all the arrow batches are read,
-                # even though there is only one batch in the iterator. Otherwise, the stream might
-                # block further reads since it thinks there might still be some arrow batches left.
-                # We only need to read the first batch in the iterator because it's guaranteed that
-                # there would only be one batch sent from the JVM side.
-                data_batch = None
-                for batch in iterator:
-                    if data_batch is None:
-                        data_batch = batch
-                if data_batch is None:
-                    # TODO(SPARK-49233): Classify user facing errors.
-                    raise PySparkRuntimeError("Error getting map state entry.")
-                pandas_df = data_batch.to_pandas()
+                data_batch = list(map(lambda x: x.timestampMs, response_message[2]))
+                require_next_fetch = response_message[3]
                 index = 0
             else:
                 raise StopIteration()
+
+        is_last_row = False
         new_index = index + 1
-        if new_index < len(pandas_df):
+        if new_index < len(data_batch):
             # Update the index in the dictionary.
-            self.list_timer_iterator_cursors[iterator_id] = (pandas_df, new_index)
+            self.list_timer_iterator_cursors[iterator_id] = (
+                data_batch,
+                new_index,
+                require_next_fetch,
+            )
         else:
-            # If the index is at the end of the DataFrame, remove the state from the dictionary.
+            # If the index is at the end of the data batch, remove the state from the dictionary.
             self.list_timer_iterator_cursors.pop(iterator_id, None)
-        return pandas_df.at[index, "timestamp"].item()
+            is_last_row = True
+
+        is_last_row_from_iterator = is_last_row and not require_next_fetch
+        timestamp = data_batch[index]
+        return (timestamp, is_last_row_from_iterator)
 
     def get_expiry_timers_iterator(
-        self, expiry_timestamp: int
-    ) -> Iterator[list[Tuple[Tuple, int]]]:
+        self, iterator_id: str, expiry_timestamp: int
+    ) -> Tuple[Tuple, int, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        while True:
-            expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+        if iterator_id in self.expiry_timer_iterator_cursors:
+            # If the state is already in the dictionary, return the next row.
+            data_batch, index, require_next_fetch = self.expiry_timer_iterator_cursors[iterator_id]
+        else:
+            expiry_timer_call = stateMessage.ExpiryTimerRequest(
+                expiryTimestampMs=expiry_timestamp, iteratorId=iterator_id
+            )
             timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
             message = stateMessage.StateRequest(timerRequest=timer_request)
 
             self._send_proto_message(message.SerializeToString())
-            response_message = self._receive_proto_message()
+            response_message = self._receive_proto_message_with_timers()
             status = response_message[0]
-            if status == 1:
-                break
-            elif status == 0:
-                result_list = []
-                iterator = self._read_arrow_state()
-                for batch in iterator:
-                    batch_df = batch.to_pandas()
-                    for i in range(batch.num_rows):
-                        deserialized_key = self.pickleSer.loads(batch_df.at[i, "key"])
-                        timestamp = batch_df.at[i, "timestamp"].item()
-                        result_list.append((tuple(deserialized_key), timestamp))
-                yield result_list
+            if status == 0:
+                data_batch = list(
+                    map(
+                        lambda x: (self._deserialize_from_bytes(x.key), x.timestampMs),
+                        response_message[2],
+                    )
+                )
+                require_next_fetch = response_message[3]
+                index = 0
             else:
-                # TODO(SPARK-49233): Classify user facing errors.
-                raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}")
+                raise StopIteration()
+
+        is_last_row = False
+        new_index = index + 1
+        if new_index < len(data_batch):
+            # Update the index in the dictionary.
+            self.expiry_timer_iterator_cursors[iterator_id] = (
+                data_batch,
+                new_index,
+                require_next_fetch,
+            )
+        else:
+            # If the index is at the end of the data batch, remove the state from the dictionary.
+            self.expiry_timer_iterator_cursors.pop(iterator_id, None)
+            is_last_row = True
+
+        is_last_row_from_iterator = is_last_row and not require_next_fetch
+        key, timestamp = data_batch[index]
+        return (key, timestamp, is_last_row_from_iterator)
 
     def get_timestamps(self, time_mode: str) -> Tuple[int, int]:
         if time_mode.lower() == "none":
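With this change, get_expiry_timers_iterator is no longer a generator that yields whole result lists; each call now returns a single (key, timestamp, is_last) triple for a given iterator id, re-fetching from the JVM as needed. A hedged caller-side sketch of that contract, using a canned stand-in for the API client (the fake client and its data are hypothetical; the real call site lives outside this file):

```python
from typing import List, Tuple

class FakeApiClient:
    """Stand-in for StatefulProcessorApiClient that serves canned expired timers."""

    def __init__(self, rows: List[Tuple[Tuple, int]]) -> None:
        self._rows = rows
        self._pos = 0

    def get_expiry_timers_iterator(
        self, iterator_id: str, expiry_timestamp: int
    ) -> Tuple[Tuple, int, bool]:
        key, ts = self._rows[self._pos]
        self._pos += 1
        # The last element flags whether this was the final row of the iterator.
        return key, ts, self._pos == len(self._rows)

client = FakeApiClient([(("user-1",), 1000), (("user-2",), 2000)])
expired = []
while True:
    key, ts, is_last = client.get_expiry_timers_iterator("expiry-iter-1", 5000)
    expired.append((key, ts))
    if is_last:
        break
assert expired == [(("user-1",), 1000), (("user-2",), 2000)]
```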
@@ -461,6 +476,18 @@ def _receive_proto_message_with_map_pairs(self) -> Tuple[int, str, Any, bool]:
 
         return message.statusCode, message.errorMessage, message.kvPair, message.requireNextFetch
 
+    # The third return type is RepeatedScalarFieldContainer[TimerInfo], which is protobuf's
+    # container type. We simplify it to Any here to avoid unnecessary complexity.
+    def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        length = read_int(self.sockfile)
+        bytes = self.sockfile.read(length)
+        message = stateMessage.StateResponseWithTimer()
+        message.ParseFromString(bytes)
+
+        return message.statusCode, message.errorMessage, message.timer, message.requireNextFetch
+
     def _receive_str(self) -> str:
         return self.utf8_deserializer.loads(self.sockfile)
 
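The new helper follows the same wire framing as the existing _receive_proto_message variants: a 4-byte length prefix followed by the serialized protobuf payload. A minimal sketch of that framing, assuming the big-endian integer prefix used by pyspark.serializers.write_int/read_int and substituting a placeholder byte string for a real StateResponseWithTimer message:

```python
import io
import struct

def write_framed(stream: io.BytesIO, payload: bytes) -> None:
    # 4-byte big-endian length prefix, then the raw payload.
    stream.write(struct.pack("!i", len(payload)))
    stream.write(payload)

def read_framed(stream: io.BytesIO) -> bytes:
    (length,) = struct.unpack("!i", stream.read(4))
    # These are the bytes that would be fed to message.ParseFromString(...).
    return stream.read(length)

buf = io.BytesIO()
write_framed(buf, b"\x08\x00")  # placeholder bytes, not a real protobuf message
buf.seek(0)
assert read_framed(buf) == b"\x08\x00"
```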
@@ -552,9 +579,44 @@ def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient):
         # same partition won't interfere with each other
         self.iterator_id = str(uuid.uuid4())
         self.stateful_processor_api_client = stateful_processor_api_client
+        self.iterator_fully_consumed = False
 
     def __iter__(self) -> Iterator[int]:
         return self
 
     def __next__(self) -> int:
-        return self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
+        if self.iterator_fully_consumed:
+            raise StopIteration()
+
+        ts, is_last_row = self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
+        if is_last_row:
+            self.iterator_fully_consumed = True
+
+        return ts
+
+
+class ExpiredTimerIterator:
+    def __init__(
+        self, stateful_processor_api_client: StatefulProcessorApiClient, expiry_timestamp: int
+    ):
+        # Generate a unique identifier for the iterator to make sure iterators on the
+        # same partition won't interfere with each other
+        self.iterator_id = str(uuid.uuid4())
+        self.stateful_processor_api_client = stateful_processor_api_client
+        self.expiry_timestamp = expiry_timestamp
+        self.iterator_fully_consumed = False
+
+    def __iter__(self) -> Iterator[Tuple[Tuple, int]]:
+        return self
+
+    def __next__(self) -> Tuple[Tuple, int]:
+        if self.iterator_fully_consumed:
+            raise StopIteration()
+
+        key, ts, is_last_row = self.stateful_processor_api_client.get_expiry_timers_iterator(
+            self.iterator_id, self.expiry_timestamp
+        )
+        if is_last_row:
+            self.iterator_fully_consumed = True
+
+        return (key, ts)
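Both iterator wrappers depend only on the client methods changed above, so their termination behavior can be exercised with a duck-typed stand-in. A small usage sketch for ListTimerIterator (the import path and the fake client are assumptions for illustration; ExpiredTimerIterator is driven the same way, with an expiry timestamp passed through):

```python
from typing import Tuple

# Assumed module path for the classes defined in this file.
from pyspark.sql.streaming.stateful_processor_api_client import ListTimerIterator

class _FakeClient:
    """Duck-typed stand-in for StatefulProcessorApiClient, used only in this sketch."""

    def __init__(self) -> None:
        self._timers = [100, 200, 300]
        self._i = 0

    def get_list_timer_row(self, iterator_id: str) -> Tuple[int, bool]:
        ts = self._timers[self._i]
        self._i += 1
        # The boolean signals that this was the last row of the iterator.
        return ts, self._i == len(self._timers)

# The wrapper raises StopIteration on its own once the client reports the last row.
assert list(ListTimerIterator(_FakeClient())) == [100, 200, 300]
```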