Skip to content

Commit 55563d1

Browse files
committed
Self joins
1 parent 275e2c0 commit 55563d1

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

quixstreams/dataframe/dataframe.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,7 +1623,11 @@ def _sink_callback(
16231623
# uses apply without returning to make this operation terminal
16241624
self.apply(_sink_callback, metadata=True)
16251625

1626-
def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
1626+
def concat(
1627+
self,
1628+
other: "StreamingDataFrame",
1629+
stream_id: Optional[str] = None,
1630+
) -> "StreamingDataFrame":
16271631
"""
16281632
Concatenate two StreamingDataFrames together and return a new one.
16291633
The transformations applied on this new StreamingDataFrame will update data
@@ -1650,7 +1654,7 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
16501654
self._registry.require_time_alignment()
16511655

16521656
return self.__dataframe_clone__(
1653-
*self.topics, *other.topics, stream=merged_stream
1657+
*self.topics, *other.topics, stream=merged_stream, stream_id=stream_id
16541658
)
16551659

16561660
def join_latest(
@@ -1726,11 +1730,11 @@ def right_func(value, key, timestamp, headers):
17261730
retention_ms=retention_ms,
17271731
)
17281732

1733+
right = right.update(right_func, metadata=True).filter(lambda value: False)
17291734
left = self.apply(left_func, metadata=True).filter(
17301735
lambda value: value is not DISCARDED
17311736
)
1732-
right = right.update(right_func, metadata=True).filter(lambda value: False)
1733-
return left.concat(right)
1737+
return left.concat(right, stream_id=f"{self.stream_id}-{right.stream_id}")
17341738

17351739
def ensure_topics_copartitioned(self, *topics: Topic):
17361740
topics = topics or self._topics
@@ -1766,7 +1770,11 @@ def _add_update(
17661770
self._stream = self._stream.add_update(func, metadata=metadata) # type: ignore[call-overload]
17671771
return self
17681772

1769-
def register_store(self, store_type: Optional[StoreTypes] = None):
1773+
def register_store(
1774+
self,
1775+
stream_id: Optional[str] = None,
1776+
store_type: Optional[StoreTypes] = None,
1777+
):
17701778
"""
17711779
Register the default store for the current stream_id in StateStoreManager.
17721780
"""
@@ -1776,7 +1784,7 @@ def register_store(self, store_type: Optional[StoreTypes] = None):
17761784
changelog_topic_config = self._topic_manager.derive_topic_config(self._topics)
17771785

17781786
self._processing_context.state_manager.register_store(
1779-
stream_id=self.stream_id,
1787+
stream_id=stream_id or self.stream_id,
17801788
store_type=store_type,
17811789
changelog_config=changelog_topic_config,
17821790
)
@@ -1938,8 +1946,10 @@ def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
19381946
return wrapper
19391947

19401948

1941-
def _get_transaction(sdf: StreamingDataFrame) -> PartitionTransaction:
1949+
def _get_transaction(
1950+
sdf: StreamingDataFrame, stream_id: Optional[str] = None
1951+
) -> PartitionTransaction:
19421952
return sdf.processing_context.checkpoint.get_store_transaction(
1943-
stream_id=sdf.stream_id,
1953+
stream_id=stream_id or sdf.stream_id,
19441954
partition=message_context().partition,
19451955
)

tests/test_quixstreams/test_dataframe/test_dataframe.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,3 +2904,19 @@ def test_retention_ms(
29042904

29052905
assert publish_left(timestamp=4) == []
29062906
assert publish_left(timestamp=5) == [({"left": 4, "right": 2}, b"key", 5, None)]
2907+
2908+
def test_self_join(self, create_topic, create_sdf, assign_partition, publish):
2909+
topic = create_topic()
2910+
sdf = create_sdf(topic)
2911+
merger = lambda left, right: {"left": left, "right": right}
2912+
joined_sdf = sdf.join_latest(sdf, merger=merger)
2913+
assign_partition(sdf)
2914+
2915+
_publish = partial(publish, joined_sdf, topic, key=b"key")
2916+
2917+
assert _publish(value=1, timestamp=1) == [
2918+
({"left": 1, "right": 1}, b"key", 1, None)
2919+
]
2920+
assert _publish(value=2, timestamp=2) == [
2921+
({"left": 2, "right": 2}, b"key", 2, None)
2922+
]

0 commit comments

Comments
 (0)