Skip to content

Commit c18b1f5

Browse files
committed
Add how param with inner and left options
1 parent 4e5ab65 commit c18b1f5

File tree

2 files changed

+86
-8
lines changed

2 files changed

+86
-8
lines changed

quixstreams/dataframe/dataframe.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@
8585
FilterCallbackStateful = Callable[[Any, State], bool]
8686
FilterWithMetadataCallbackStateful = Callable[[Any, Any, int, Any, State], bool]
8787

88+
# Constants related to join
89+
# Sentinel returned by the join's left-side callback for records that must be
# dropped; a subsequent filter removes it by identity (`value is not DISCARDED`).
DISCARDED = object()
90+
# Accepted values for the `how` argument of `join`.
JoinHow = Literal["inner", "left"]
8891
# Accepted values for the `on_overlap` argument of `join`.
JoinOnOverlap = Literal["keep-left", "keep-right", "raise"]
8992

9093

@@ -1653,17 +1656,26 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
16531656
def join(
16541657
self,
16551658
right: "StreamingDataFrame",
1659+
how: JoinHow = "inner",
16561660
on_overlap: JoinOnOverlap = "raise",
16571661
merger: Optional[Callable[[Any, Any], Any]] = None,
16581662
) -> "StreamingDataFrame":
1663+
if how not in get_args(JoinHow):
1664+
raise ValueError(
1665+
f"Invalid how value: {how}. "
1666+
f"Valid values are: {', '.join(get_args(JoinHow))}."
1667+
)
16591668
if on_overlap not in get_args(JoinOnOverlap):
16601669
raise ValueError(
16611670
f"Invalid on_overlap value: {on_overlap}. "
16621671
f"Valid values are: {', '.join(get_args(JoinOnOverlap))}."
16631672
)
1673+
16641674
self.ensure_topics_copartitioned(*self.topics, *right.topics)
16651675
right.register_store(store_type=TimestampedStore)
16661676

1677+
is_inner_join = how == "inner"
1678+
16671679
if merger is None:
16681680
if on_overlap == "keep-left":
16691681

@@ -1699,13 +1711,17 @@ def merger(left_value, right_value):
16991711
def left_func(value, key, timestamp, headers):
    """Merge a left-side record with the right-side value stored for its key.

    The right-side value is fetched from the right dataframe's transaction
    with ``get_last(timestamp=timestamp, prefix=key)``.

    :param value: value of the left-side record.
    :param key: record key, used as the lookup prefix.
    :param timestamp: record timestamp, passed to the lookup.
    :param headers: record headers (not used here).
    :return: ``merger(value, right_value)``, or ``DISCARDED`` when this is an
        inner join and the looked-up right value is falsy (missing or empty).
    """
    matched = _get_transaction(right).get_last(timestamp=timestamp, prefix=key)
    # A left join always merges; an inner join merges only on a truthy match.
    if not is_inner_join or matched:
        return merger(value, matched)
    return DISCARDED
17031717

17041718
def right_func(value, key, timestamp, headers):
    """Store a right-side record so that left-side records can look it up.

    The value is written to the right dataframe's transaction under the
    record key (as prefix) and the record timestamp.

    :param value: value of the right-side record to store.
    :param key: record key, used as the storage prefix.
    :param timestamp: record timestamp to store the value under.
    :param headers: record headers (not used here).
    """
    _get_transaction(right).set(
        timestamp=timestamp,
        value=value,
        prefix=key,
    )
17071721

1708-
left = self.apply(left_func, metadata=True)
1722+
left = self.apply(left_func, metadata=True).filter(
1723+
lambda value: value is not DISCARDED
1724+
)
17091725
right = right.update(right_func, metadata=True).filter(lambda value: False)
17101726
return left.concat(right)
17111727

tests/test_quixstreams/test_dataframe/test_dataframe.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111

1212
from quixstreams import State
13-
from quixstreams.dataframe.dataframe import JoinOnOverlap
13+
from quixstreams.dataframe.dataframe import JoinHow, JoinOnOverlap
1414
from quixstreams.dataframe.exceptions import (
1515
GroupByDuplicate,
1616
GroupByNestingLimit,
@@ -2693,17 +2693,79 @@ def _publish(sdf, topic, value, key, timestamp):
26932693

26942694
return _publish
26952695

2696-
def test_join(self, create_topic, create_sdf, assign_partition, publish):
2696+
@pytest.mark.parametrize(
2697+
"how, left, right, expected",
2698+
[
2699+
(
2700+
"inner",
2701+
{"left": 1},
2702+
{"right": 2},
2703+
[({"left": 1, "right": 2}, b"key", 2, None)],
2704+
),
2705+
(
2706+
"inner",
2707+
{"left": 1},
2708+
None,
2709+
[],
2710+
),
2711+
(
2712+
"inner",
2713+
{"left": 1},
2714+
{},
2715+
[],
2716+
),
2717+
(
2718+
"left",
2719+
{"left": 1},
2720+
{"right": 2},
2721+
[({"left": 1, "right": 2}, b"key", 2, None)],
2722+
),
2723+
(
2724+
"left",
2725+
{"left": 1},
2726+
None,
2727+
[({"left": 1}, b"key", 2, None)],
2728+
),
2729+
(
2730+
"left",
2731+
{"left": 1},
2732+
{},
2733+
[({"left": 1}, b"key", 2, None)],
2734+
),
2735+
],
2736+
)
2737+
def test_how(
2738+
self,
2739+
create_topic,
2740+
create_sdf,
2741+
assign_partition,
2742+
publish,
2743+
how,
2744+
left,
2745+
right,
2746+
expected,
2747+
):
26972748
left_topic, right_topic = create_topic(), create_topic()
26982749
left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic)
2699-
joined_sdf = left_sdf.join(right_sdf)
2750+
joined_sdf = left_sdf.join(right_sdf, how=how)
27002751
assign_partition(right_sdf)
27012752

2702-
publish(joined_sdf, right_topic, value={"right": 1}, key=b"key", timestamp=1)
2753+
publish(joined_sdf, right_topic, value=right, key=b"key", timestamp=1)
27032754
joined_value = publish(
2704-
joined_sdf, left_topic, value={"left": 2}, key=b"key", timestamp=2
2755+
joined_sdf, left_topic, value=left, key=b"key", timestamp=2
27052756
)
2706-
assert joined_value == [({"left": 2, "right": 1}, b"key", 2, None)]
2757+
assert joined_value == expected
2758+
2759+
def test_how_invalid_value(self, create_topic, create_sdf):
    """`join` raises `ValueError` for a `how` value not listed in `JoinHow`."""
    # Create both topics first, then both dataframes, left before right.
    topics = [create_topic(), create_topic()]
    left_sdf, right_sdf = [create_sdf(topic) for topic in topics]

    valid_values = ", ".join(get_args(JoinHow))
    expected_message = (
        "Invalid how value: invalid. " + f"Valid values are: {valid_values}."
    )
    with pytest.raises(ValueError, match=expected_message):
        left_sdf.join(right_sdf, how="invalid")
27072769

27082770
def test_mismatching_partitions_fails(self, create_topic, create_sdf):
27092771
left_topic, right_topic = create_topic(), create_topic(num_partitions=2)
@@ -2772,7 +2834,7 @@ def test_on_overlap(
27722834
):
27732835
left_topic, right_topic = create_topic(), create_topic()
27742836
left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic)
2775-
joined_sdf = left_sdf.join(right_sdf, on_overlap=on_overlap)
2837+
joined_sdf = left_sdf.join(right_sdf, how="left", on_overlap=on_overlap)
27762838
assign_partition(right_sdf)
27772839

27782840
publish(joined_sdf, right_topic, value=right, key=b"key", timestamp=1)

0 commit comments

Comments
 (0)