|
49 | 49 | from quixstreams.models.serializers import DeserializerType, SerializerType
|
50 | 50 | from quixstreams.sinks import BaseSink
|
51 | 51 | from quixstreams.state.base import State
|
| 52 | +from quixstreams.state.base.transaction import PartitionTransaction |
52 | 53 | from quixstreams.utils.printing import (
|
53 | 54 | DEFAULT_COLUMN_NAME,
|
54 | 55 | DEFAULT_LIVE,
|
@@ -1852,14 +1853,19 @@ def _as_stateful(
|
1852 | 1853 | ) -> Callable[[Any, Any, int, Any], T]:
|
1853 | 1854 | @functools.wraps(func)
|
1854 | 1855 | def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
|
1855 |
| - ctx = message_context() |
1856 |
| - transaction = processing_context.checkpoint.get_store_transaction( |
1857 |
| - stream_id=stream_id, |
1858 |
| - partition=ctx.partition, |
1859 |
| - ) |
1860 | 1856 | # Pass a State object with an interface limited to the key updates only
|
1861 | 1857 | # and prefix all the state keys by the message key
|
| 1858 | + transaction = _get_transaction(processing_context, stream_id) |
1862 | 1859 | state = transaction.as_state(prefix=key)
|
1863 | 1860 | return func(value, key, timestamp, headers, state)
|
1864 | 1861 |
|
1865 | 1862 | return wrapper
|
| 1863 | + |
| 1864 | + |
| 1865 | +def _get_transaction( |
| 1866 | + processing_context: ProcessingContext, stream_id: str |
| 1867 | +) -> PartitionTransaction: |
| 1868 | + return processing_context.checkpoint.get_store_transaction( |
| 1869 | + stream_id=stream_id, |
| 1870 | + partition=message_context().partition, |
| 1871 | + ) |
0 commit comments