Skip to content

Commit 5d9c872

Browse files
committed
Ensure both sides are copartitioned
1 parent be3c196 commit 5d9c872

File tree

2 files changed, with 12 additions and 4 deletions.

2 files changed, with 12 additions and 4 deletions.

quixstreams/dataframe/dataframe.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1648,7 +1648,7 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
16481648
)
16491649

16501650
def join(self, right: "StreamingDataFrame") -> "StreamingDataFrame":
1651-
# TODO: ensure copartitioning of left and right?
1651+
self.ensure_topics_copartitioned(*self.topics, *right.topics)
16521652
right.register_store(store_type=TimestampedStore)
16531653

16541654
def left_func(value, key, timestamp, headers):
@@ -1664,12 +1664,13 @@ def right_func(value, key, timestamp, headers):
16641664
right = right.update(right_func, metadata=True).filter(lambda value: False)
16651665
return left.concat(right)
16661666

1667-
def ensure_topics_copartitioned(self):
1668-
partitions_counts = set(t.broker_config.num_partitions for t in self._topics)
1667+
def ensure_topics_copartitioned(self, *topics: Topic):
1668+
topics = topics or self._topics
1669+
partitions_counts = set(t.broker_config.num_partitions for t in topics)
16691670
if len(partitions_counts) > 1:
16701671
msg = ", ".join(
16711672
f'"{t.name}" ({t.broker_config.num_partitions} partitions)'
1672-
for t in self._topics
1673+
for t in topics
16731674
)
16741675
raise TopicPartitionsMismatch(
16751676
f"The underlying topics must have the same number of partitions to use State; got {msg}"

tests/test_quixstreams/test_dataframe/test_dataframe.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2703,3 +2703,10 @@ def test_join(self, create_topic, create_sdf, assign_partition, publish):
27032703
joined_sdf, left_topic, value={"left": 2}, key=b"key", timestamp=2
27042704
)
27052705
assert joined_value == [({"left": 2, "right": 1}, b"key", 2, None)]
2706+
2707+
def test_mismatching_partitions_fails(self, create_topic, create_sdf):
2708+
left_topic, right_topic = create_topic(), create_topic(num_partitions=2)
2709+
left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic)
2710+
2711+
with pytest.raises(TopicPartitionsMismatch):
2712+
left_sdf.join(right_sdf)

0 commit comments

Comments
 (0)