@@ -285,11 +285,7 @@ def func(d: dict, state: State):
285
285
cast (ApplyCallbackStateful , func )
286
286
)
287
287
288
- stateful_func = _as_stateful (
289
- func = with_metadata_func ,
290
- processing_context = self ._processing_context ,
291
- stream_id = self .stream_id ,
292
- )
288
+ stateful_func = _as_stateful (with_metadata_func , self )
293
289
stream = self .stream .add_apply (stateful_func , expand = expand , metadata = True ) # type: ignore[call-overload]
294
290
else :
295
291
stream = self .stream .add_apply (
@@ -394,11 +390,7 @@ def func(values: list, state: State):
394
390
cast (UpdateCallbackStateful , func )
395
391
)
396
392
397
- stateful_func = _as_stateful (
398
- func = with_metadata_func ,
399
- processing_context = self ._processing_context ,
400
- stream_id = self .stream_id ,
401
- )
393
+ stateful_func = _as_stateful (with_metadata_func , self )
402
394
return self ._add_update (stateful_func , metadata = True )
403
395
else :
404
396
return self ._add_update (
@@ -496,11 +488,7 @@ def func(d: dict, state: State):
496
488
cast (FilterCallbackStateful , func )
497
489
)
498
490
499
- stateful_func = _as_stateful (
500
- func = with_metadata_func ,
501
- processing_context = self ._processing_context ,
502
- stream_id = self .stream_id ,
503
- )
491
+ stateful_func = _as_stateful (with_metadata_func , self )
504
492
stream = self .stream .add_filter (stateful_func , metadata = True )
505
493
else :
506
494
stream = self .stream .add_filter ( # type: ignore[call-overload]
@@ -1848,24 +1836,20 @@ def wrapper(
1848
1836
1849
1837
def _as_stateful (
1850
1838
func : Callable [[Any , Any , int , Any , State ], T ],
1851
- processing_context : ProcessingContext ,
1852
- stream_id : str ,
1839
+ sdf : StreamingDataFrame ,
1853
1840
) -> Callable [[Any , Any , int , Any ], T ]:
1854
1841
@functools .wraps (func )
1855
1842
def wrapper (value : Any , key : Any , timestamp : int , headers : Any ) -> Any :
1856
1843
# Pass a State object with an interface limited to the key updates only
1857
1844
# and prefix all the state keys by the message key
1858
- transaction = _get_transaction (processing_context , stream_id )
1859
- state = transaction .as_state (prefix = key )
1845
+ state = _get_transaction (sdf ).as_state (prefix = key )
1860
1846
return func (value , key , timestamp , headers , state )
1861
1847
1862
1848
return wrapper
1863
1849
1864
1850
1865
- def _get_transaction (
1866
- processing_context : ProcessingContext , stream_id : str
1867
- ) -> PartitionTransaction :
1868
- return processing_context .checkpoint .get_store_transaction (
1869
- stream_id = stream_id ,
1851
+ def _get_transaction (sdf : StreamingDataFrame ) -> PartitionTransaction :
1852
+ return sdf .processing_context .checkpoint .get_store_transaction (
1853
+ stream_id = sdf .stream_id ,
1870
1854
partition = message_context ().partition ,
1871
1855
)
0 commit comments