Skip to content

Commit 41e8cae

Browse files
committed
Optional tolerance feature to isin per issue #5587
1 parent 07de257 commit 41e8cae

File tree

4 files changed

+71
-1
lines changed

4 files changed

+71
-1
lines changed

xarray/core/common.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1384,12 +1384,13 @@ def notnull(self, keep_attrs: bool = None):
13841384
keep_attrs=keep_attrs,
13851385
)
13861386

1387-
def isin(self, test_elements):
1387+
def isin(self, test_elements, tolerance=None):
13881388
"""Tests each value in the array for whether it is in test elements.
13891389
13901390
Parameters
13911391
----------
13921392
test_elements : array_like
1393+
tolerance : dtype
13931394
The values against which to test each value of `element`.
13941395
This argument is flattened if an array or array_like.
13951396
See numpy notes for behavior with non-array-like parameters.
@@ -1407,6 +1408,12 @@ def isin(self, test_elements):
14071408
array([ True, False, True])
14081409
Dimensions without coordinates: x
14091410
1411+
>>> array = xr.DataArray([1, 2, 3], dims="x")
1412+
>>> array.isin([1.1, 2.9], tolerance = 0.2)
1413+
<xarray.DataArray (x: 3)>
1414+
array([ True, False, True])
1415+
Dimensions without coordinates: x
1416+
14101417
See Also
14111418
--------
14121419
numpy.isin
@@ -1427,6 +1434,15 @@ def isin(self, test_elements):
14271434
# second argument
14281435
test_elements = test_elements.data
14291436

1437+
if tolerance:
1438+
# tolerance is neither None nor zero: use the tolerance-based comparison
1439+
return apply_ufunc(
1440+
duck_array_ops.isin_tolerance,
1441+
self,
1442+
kwargs=dict(test_elements=test_elements, tolerance=tolerance),
1443+
dask="allowed",
1444+
)
1445+
14301446
return apply_ufunc(
14311447
duck_array_ops.isin,
14321448
self,

xarray/core/duck_array_ops.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,38 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
467467
return np.where(isnull(array), np.nan, array.astype(dtype))
468468

469469

470+
def isin_tolerance(self, test_elements, tolerance):
    """Test whether each element of ``self`` lies within ``tolerance`` of
    at least one element of ``test_elements``.

    Parameters
    ----------
    self : array_like
        Values to test.
    test_elements : array_like
        Values to compare against. Inputs that do not expose a ``shape``
        attribute (e.g. plain lists) are converted with ``np.asarray``;
        anything that already has ``shape`` is used as-is.
    tolerance : scalar
        Acceptable absolute difference between an element of ``self`` and
        an element of ``test_elements``. Its sign is ignored and the
        comparison is strict: ``|a - b| < |tolerance|``.

    Returns
    -------
    array_like
        Boolean array with the same shape as ``self``.

    Notes
    -----
    The comparison is fully vectorized and therefore memory hungry for
    large inputs: it materializes an intermediate array of shape
    ``self.shape + test_elements.shape``.
    """
    # Callers may hand over plain sequences (``isin([1.1, 2.9], ...)``),
    # which have no ``.shape``. Only convert in that case so that objects
    # which already behave like arrays are passed through untouched.
    if not hasattr(self, "shape"):
        self = np.asarray(self)
    if not hasattr(test_elements, "shape"):
        test_elements = np.asarray(test_elements)

    # ``subtract.outer`` appends the axes of ``test_elements`` after those
    # of ``self``; reducing over exactly those trailing axes restores the
    # shape of ``self``.
    n_self_dims = len(self.shape)
    merge_axis = tuple(range(n_self_dims, n_self_dims + len(test_elements.shape)))

    distances = np.abs(np.subtract.outer(self, test_elements))
    return (distances < abs(tolerance)).any(merge_axis)
500+
501+
470502
def timedelta_to_numeric(value, datetime_unit="ns", dtype=float):
471503
"""Convert a timedelta-like object to numerical values.
472504

xarray/tests/test_duck_array_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
push,
2424
py_timedelta_to_float,
2525
stack,
26+
isin_tolerance,
2627
timedelta_to_numeric,
2728
where,
2829
)
@@ -892,3 +893,19 @@ def test_push_dask():
892893
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=None
893894
)
894895
np.testing.assert_equal(actual, expected)
896+
897+
898+
@pytest.mark.parametrize("shape", [(200, 1), (10, 10, 2), (4, 50)])
@pytest.mark.parametrize("tolerance", [1e-2, 1e-4, 1e-6])
def test_isin_tolerance(shape, tolerance):
    """Check ``isin_tolerance`` against a constructed alternating pattern.

    Every even-indexed element of ``arrayA`` gets a partner in ``arrayB``
    that is offset by half the tolerance (inside the margin); every
    odd-indexed slot of ``arrayB`` holds ``-99 + in_margin``, far outside
    the value range of ``arrayA``. The expected mask is therefore True at
    even flat indices and False at odd ones.
    """
    in_margin = tolerance / 2  # offset that stays inside the tolerance
    arrayA = np.arange(-10.0, 10.0, 0.1).reshape(shape)
    expected = (np.arange(arrayA.size) % 2 == 0).reshape(shape)

    # Build the test set: matched values where ``expected`` is True, the
    # far-away sentinel where it is False.
    arrayB = -99 * (~expected.flatten()) + (in_margin + arrayA * expected).flatten()

    # The check is deterministic, so a single evaluation is sufficient.
    with raise_if_dask_computes():
        actual = isin_tolerance(arrayA, arrayB, tolerance)
    np.testing.assert_equal(actual, expected)

xarray/tests/test_sparse.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,11 @@ def test_dataarray_property(prop):
454454
False,
455455
marks=xfail(reason="Missing implementation for np.isin"),
456456
),
457+
param(
458+
do("isin_tolerance", [1 - 1e-7, 2, 3 + 1e-7], tolerance=1e-6),
459+
False,
460+
marks=xfail(reason="Missing implementation for isin_tolerance"),
461+
),
457462
param(
458463
do("item", (1, 1)),
459464
False,

0 commit comments

Comments
 (0)