Skip to content

Commit 4e5ab65

Browse files
committed
Add on_overlap and merger params
1 parent 5d9c872 commit 4e5ab65

File tree

2 files changed

+148
-3
lines changed

2 files changed

+148
-3
lines changed

quixstreams/dataframe/dataframe.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
TypeVar,
2121
Union,
2222
cast,
23+
get_args,
2324
overload,
2425
)
2526

@@ -84,6 +85,8 @@
8485
FilterCallbackStateful = Callable[[Any, State], bool]
8586
FilterWithMetadataCallbackStateful = Callable[[Any, Any, int, Any, State], bool]
8687

88+
JoinOnOverlap = Literal["keep-left", "keep-right", "raise"]
89+
8790

8891
class StreamingDataFrame:
8992
"""
@@ -1647,14 +1650,56 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
16471650
*self.topics, *other.topics, stream=merged_stream
16481651
)
16491652

1650-
def join(self, right: "StreamingDataFrame") -> "StreamingDataFrame":
1653+
def join(
1654+
self,
1655+
right: "StreamingDataFrame",
1656+
on_overlap: JoinOnOverlap = "raise",
1657+
merger: Optional[Callable[[Any, Any], Any]] = None,
1658+
) -> "StreamingDataFrame":
1659+
if on_overlap not in get_args(JoinOnOverlap):
1660+
raise ValueError(
1661+
f"Invalid on_overlap value: {on_overlap}. "
1662+
f"Valid values are: {', '.join(get_args(JoinOnOverlap))}."
1663+
)
16511664
self.ensure_topics_copartitioned(*self.topics, *right.topics)
16521665
right.register_store(store_type=TimestampedStore)
16531666

1667+
if merger is None:
1668+
if on_overlap == "keep-left":
1669+
1670+
def merger(left_value, right_value):
1671+
"""
1672+
Merge two dictionaries, preferring values from the left dictionary
1673+
and preserving order with left columns coming first.
1674+
"""
1675+
right_value = {
1676+
key: value
1677+
for key, value in (right_value or {}).items()
1678+
if key not in left_value
1679+
}
1680+
return {**left_value, **right_value}
1681+
1682+
elif on_overlap == "keep-right":
1683+
1684+
def merger(left_value, right_value):
1685+
return {**left_value, **(right_value or {})}
1686+
elif on_overlap == "raise":
1687+
1688+
def merger(left_value, right_value):
1689+
right_value = right_value or {}
1690+
if overlapping_columns := left_value.keys() & right_value.keys():
1691+
overlapping_columns_str = ", ".join(sorted(overlapping_columns))
1692+
raise ValueError(
1693+
f"Overlapping columns: {overlapping_columns_str}."
1694+
"You need to provide either an on_overlap value of "
1695+
"'keep-left' or 'keep-right' or a custom merger function."
1696+
)
1697+
return {**left_value, **right_value}
1698+
16541699
def left_func(value, key, timestamp, headers):
16551700
right_tx = _get_transaction(right)
16561701
right_value = right_tx.get_last(timestamp=timestamp, prefix=key)
1657-
return {**value, **(right_value or {})}
1702+
return merger(value, right_value)
16581703

16591704
def right_func(value, key, timestamp, headers):
16601705
right_tx = _get_transaction(right)

tests/test_quixstreams/test_dataframe/test_dataframe.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import warnings
55
from collections import namedtuple
66
from datetime import timedelta
7-
from typing import Any
7+
from typing import Any, get_args
88
from unittest import mock
99

1010
import pytest
1111

1212
from quixstreams import State
13+
from quixstreams.dataframe.dataframe import JoinOnOverlap
1314
from quixstreams.dataframe.exceptions import (
1415
GroupByDuplicate,
1516
GroupByNestingLimit,
@@ -2710,3 +2711,102 @@ def test_mismatching_partitions_fails(self, create_topic, create_sdf):
27102711

27112712
with pytest.raises(TopicPartitionsMismatch):
27122713
left_sdf.join(right_sdf)
2714+
2715+
@pytest.mark.parametrize(
2716+
"on_overlap, left, right, expected",
2717+
[
2718+
(
2719+
"keep-left",
2720+
{"A": 1},
2721+
None,
2722+
{"A": 1},
2723+
),
2724+
(
2725+
"keep-left",
2726+
{"A": 1, "B": "left"},
2727+
{"B": "right", "C": 2},
2728+
{"A": 1, "B": "left", "C": 2},
2729+
),
2730+
(
2731+
"keep-right",
2732+
{"A": 1},
2733+
None,
2734+
{"A": 1},
2735+
),
2736+
(
2737+
"keep-right",
2738+
{"A": 1, "B": "left"},
2739+
{"B": "right", "C": 2},
2740+
{"A": 1, "B": "right", "C": 2},
2741+
),
2742+
(
2743+
"raise",
2744+
{"A": 1},
2745+
None,
2746+
{"A": 1},
2747+
),
2748+
(
2749+
"raise",
2750+
{"A": 1},
2751+
{"B": 2},
2752+
{"A": 1, "B": 2},
2753+
),
2754+
(
2755+
"raise",
2756+
{"A": 1, "B": "left B", "C": "left C"},
2757+
{"B": "right B", "C": "right C"},
2758+
ValueError("Overlapping columns: B, C."),
2759+
),
2760+
],
2761+
)
2762+
def test_on_overlap(
2763+
self,
2764+
create_topic,
2765+
create_sdf,
2766+
assign_partition,
2767+
publish,
2768+
on_overlap,
2769+
left,
2770+
right,
2771+
expected,
2772+
):
2773+
left_topic, right_topic = create_topic(), create_topic()
2774+
left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic)
2775+
joined_sdf = left_sdf.join(right_sdf, on_overlap=on_overlap)
2776+
assign_partition(right_sdf)
2777+
2778+
publish(joined_sdf, right_topic, value=right, key=b"key", timestamp=1)
2779+
2780+
if isinstance(expected, Exception):
2781+
with pytest.raises(expected.__class__, match=expected.args[0]):
2782+
publish(joined_sdf, left_topic, value=left, key=b"key", timestamp=2)
2783+
else:
2784+
joined_value = publish(
2785+
joined_sdf, left_topic, value=left, key=b"key", timestamp=2
2786+
)
2787+
assert joined_value == [(expected, b"key", 2, None)]
2788+
2789+
def test_on_overlap_invalid_value(self, create_topic, create_sdf):
2790+
left_topic, right_topic = create_topic(), create_topic()
2791+
left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic)
2792+
2793+
match = (
2794+
"Invalid on_overlap value: invalid. "
2795+
f"Valid values are: {', '.join(get_args(JoinOnOverlap))}."
2796+
)
2797+
with pytest.raises(ValueError, match=match):
2798+
left_sdf.join(right_sdf, on_overlap="invalid")
2799+
2800+
def test_custom_merger(self, create_topic, create_sdf, assign_partition, publish):
2801+
left_topic, right_topic = create_topic(), create_topic()
2802+
left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic)
2803+
2804+
def merger(left, right):
2805+
return {"left": left, "right": right}
2806+
2807+
joined_sdf = left_sdf.join(right_sdf, merger=merger)
2808+
assign_partition(right_sdf)
2809+
2810+
publish(joined_sdf, right_topic, value=1, key=b"key", timestamp=1)
2811+
joined_value = publish(joined_sdf, left_topic, value=2, key=b"key", timestamp=2)
2812+
assert joined_value == [({"left": 2, "right": 1}, b"key", 2, None)]

0 commit comments

Comments
 (0)