Skip to content

Commit ce2f8b3

Browse files
committed
[JOIN] Refactor _as_stateful to accept sdf
1 parent 6f2e0be commit ce2f8b3

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

quixstreams/dataframe/dataframe.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,7 @@ def func(d: dict, state: State):
285285
cast(ApplyCallbackStateful, func)
286286
)
287287

288-
stateful_func = _as_stateful(
289-
func=with_metadata_func,
290-
processing_context=self._processing_context,
291-
stream_id=self.stream_id,
292-
)
288+
stateful_func = _as_stateful(with_metadata_func, self)
293289
stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) # type: ignore[call-overload]
294290
else:
295291
stream = self.stream.add_apply(
@@ -394,11 +390,7 @@ def func(values: list, state: State):
394390
cast(UpdateCallbackStateful, func)
395391
)
396392

397-
stateful_func = _as_stateful(
398-
func=with_metadata_func,
399-
processing_context=self._processing_context,
400-
stream_id=self.stream_id,
401-
)
393+
stateful_func = _as_stateful(with_metadata_func, self)
402394
return self._add_update(stateful_func, metadata=True)
403395
else:
404396
return self._add_update(
@@ -496,11 +488,7 @@ def func(d: dict, state: State):
496488
cast(FilterCallbackStateful, func)
497489
)
498490

499-
stateful_func = _as_stateful(
500-
func=with_metadata_func,
501-
processing_context=self._processing_context,
502-
stream_id=self.stream_id,
503-
)
491+
stateful_func = _as_stateful(with_metadata_func, self)
504492
stream = self.stream.add_filter(stateful_func, metadata=True)
505493
else:
506494
stream = self.stream.add_filter( # type: ignore[call-overload]
@@ -1848,24 +1836,20 @@ def wrapper(
18481836

18491837
def _as_stateful(
18501838
func: Callable[[Any, Any, int, Any, State], T],
1851-
processing_context: ProcessingContext,
1852-
stream_id: str,
1839+
sdf: StreamingDataFrame,
18531840
) -> Callable[[Any, Any, int, Any], T]:
18541841
@functools.wraps(func)
18551842
def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
18561843
# Pass a State object with an interface limited to the key updates only
18571844
# and prefix all the state keys by the message key
1858-
transaction = _get_transaction(processing_context, stream_id)
1859-
state = transaction.as_state(prefix=key)
1845+
state = _get_transaction(sdf).as_state(prefix=key)
18601846
return func(value, key, timestamp, headers, state)
18611847

18621848
return wrapper
18631849

18641850

1865-
def _get_transaction(
1866-
processing_context: ProcessingContext, stream_id: str
1867-
) -> PartitionTransaction:
1868-
return processing_context.checkpoint.get_store_transaction(
1869-
stream_id=stream_id,
1851+
def _get_transaction(sdf: StreamingDataFrame) -> PartitionTransaction:
1852+
return sdf.processing_context.checkpoint.get_store_transaction(
1853+
stream_id=sdf.stream_id,
18701854
partition=message_context().partition,
18711855
)

0 commit comments

Comments
 (0)