Skip to content

Commit 62a4d0f

Browse files
Allow callables to .drop_vars (#8511)
* Allow callables to `.drop_vars` This can be used as a nice more general alternative to `.drop_indexes` or `.reset_coords(drop=True)` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * . --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b8b7857 commit 62a4d0f

File tree

5 files changed

+61
-21
lines changed

5 files changed

+61
-21
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ New Features
3030
- :py:meth:`~xarray.DataArray.rank` now operates on dask-backed arrays, assuming
3131
the core dim has exactly one chunk. (:pull:`8475`).
3232
By `Maximilian Roos <https://github.com/max-sixty>`_.
33+
- :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` allow passing a
34+
callable, similar to :py:meth:`Dataset.where` & :py:meth:`Dataset.sortby` & others.
35+
(:pull:`8511`).
36+
By `Maximilian Roos <https://github.com/max-sixty>`_.
3337

3438
Breaking changes
3539
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3041,16 +3041,17 @@ def T(self) -> Self:
30413041

30423042
def drop_vars(
30433043
self,
3044-
names: Hashable | Iterable[Hashable],
3044+
names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
30453045
*,
30463046
errors: ErrorOptions = "raise",
30473047
) -> Self:
30483048
"""Returns an array with dropped variables.
30493049
30503050
Parameters
30513051
----------
3052-
names : Hashable or iterable of Hashable
3053-
Name(s) of variables to drop.
3052+
names : Hashable or iterable of Hashable or Callable
3053+
Name(s) of variables to drop. If a Callable, this object is passed as its
3054+
only argument and its result is used.
30543055
errors : {"raise", "ignore"}, default: "raise"
30553056
If 'raise', raises a ValueError error if any of the variable
30563057
passed are not in the dataset. If 'ignore', any given names that are in the
@@ -3100,7 +3101,17 @@ def drop_vars(
31003101
[ 6, 7, 8],
31013102
[ 9, 10, 11]])
31023103
Dimensions without coordinates: x, y
3104+
3105+
>>> da.drop_vars(lambda x: x.coords)
3106+
<xarray.DataArray (x: 4, y: 3)>
3107+
array([[ 0, 1, 2],
3108+
[ 3, 4, 5],
3109+
[ 6, 7, 8],
3110+
[ 9, 10, 11]])
3111+
Dimensions without coordinates: x, y
31033112
"""
3113+
if callable(names):
3114+
names = names(self)
31043115
ds = self._to_temp_dataset().drop_vars(names, errors=errors)
31053116
return self._from_temp_dataset(ds)
31063117

xarray/core/dataset.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5741,16 +5741,17 @@ def _assert_all_in_dataset(
57415741

57425742
def drop_vars(
57435743
self,
5744-
names: Hashable | Iterable[Hashable],
5744+
names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
57455745
*,
57465746
errors: ErrorOptions = "raise",
57475747
) -> Self:
57485748
"""Drop variables from this dataset.
57495749
57505750
Parameters
57515751
----------
5752-
names : hashable or iterable of hashable
5753-
Name(s) of variables to drop.
5752+
names : Hashable or iterable of Hashable or Callable
5753+
Name(s) of variables to drop. If a Callable, this object is passed as its
5754+
only argument and its result is used.
57545755
errors : {"raise", "ignore"}, default: "raise"
57555756
If 'raise', raises a ValueError error if any of the variable
57565757
passed are not in the dataset. If 'ignore', any given names that are in the
@@ -5792,7 +5793,7 @@ def drop_vars(
57925793
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
57935794
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
57945795
5795-
# Drop the 'humidity' variable
5796+
Drop the 'humidity' variable
57965797
57975798
>>> dataset.drop_vars(["humidity"])
57985799
<xarray.Dataset>
@@ -5805,7 +5806,7 @@ def drop_vars(
58055806
temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
58065807
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
58075808
5808-
# Drop the 'humidity', 'temperature' variables
5809+
Drop the 'humidity', 'temperature' variables
58095810
58105811
>>> dataset.drop_vars(["humidity", "temperature"])
58115812
<xarray.Dataset>
@@ -5817,7 +5818,18 @@ def drop_vars(
58175818
Data variables:
58185819
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
58195820
5820-
# Attempt to drop non-existent variable with errors="ignore"
5821+
Drop all indexes
5822+
5823+
>>> dataset.drop_vars(lambda x: x.indexes)
5824+
<xarray.Dataset>
5825+
Dimensions: (time: 1, latitude: 2, longitude: 2)
5826+
Dimensions without coordinates: time, latitude, longitude
5827+
Data variables:
5828+
temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
5829+
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
5830+
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
5831+
5832+
Attempt to drop non-existent variable with errors="ignore"
58215833
58225834
>>> dataset.drop_vars(["pressure"], errors="ignore")
58235835
<xarray.Dataset>
@@ -5831,7 +5843,7 @@ def drop_vars(
58315843
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
58325844
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
58335845
5834-
# Attempt to drop non-existent variable with errors="raise"
5846+
Attempt to drop non-existent variable with errors="raise"
58355847
58365848
>>> dataset.drop_vars(["pressure"], errors="raise")
58375849
Traceback (most recent call last):
@@ -5851,36 +5863,38 @@ def drop_vars(
58515863
DataArray.drop_vars
58525864
58535865
"""
5866+
if callable(names):
5867+
names = names(self)
58545868
# the Iterable check is required for mypy
58555869
if is_scalar(names) or not isinstance(names, Iterable):
5856-
names = {names}
5870+
names_set = {names}
58575871
else:
5858-
names = set(names)
5872+
names_set = set(names)
58595873
if errors == "raise":
5860-
self._assert_all_in_dataset(names)
5874+
self._assert_all_in_dataset(names_set)
58615875

58625876
# GH6505
58635877
other_names = set()
5864-
for var in names:
5878+
for var in names_set:
58655879
maybe_midx = self._indexes.get(var, None)
58665880
if isinstance(maybe_midx, PandasMultiIndex):
58675881
idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim])
5868-
idx_other_names = idx_coord_names - set(names)
5882+
idx_other_names = idx_coord_names - set(names_set)
58695883
other_names.update(idx_other_names)
58705884
if other_names:
5871-
names |= set(other_names)
5885+
names_set |= set(other_names)
58725886
warnings.warn(
58735887
f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. "
58745888
f"Please also drop the following variables: {other_names!r} to avoid an error in the future.",
58755889
DeprecationWarning,
58765890
stacklevel=2,
58775891
)
58785892

5879-
assert_no_index_corrupted(self.xindexes, names)
5893+
assert_no_index_corrupted(self.xindexes, names_set)
58805894

5881-
variables = {k: v for k, v in self._variables.items() if k not in names}
5895+
variables = {k: v for k, v in self._variables.items() if k not in names_set}
58825896
coord_names = {k for k in self._coord_names if k in variables}
5883-
indexes = {k: v for k, v in self._indexes.items() if k not in names}
5897+
indexes = {k: v for k, v in self._indexes.items() if k not in names_set}
58845898
return self._replace_with_new_dims(
58855899
variables, coord_names=coord_names, indexes=indexes
58865900
)
@@ -5978,6 +5992,9 @@ def drop(
59785992
"dropping variables using `drop` is deprecated; use drop_vars.",
59795993
DeprecationWarning,
59805994
)
5995+
# for mypy
5996+
if is_scalar(labels):
5997+
labels = [labels]
59815998
return self.drop_vars(labels, errors=errors)
59825999
if dim is not None:
59836000
warnings.warn(

xarray/core/resample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _drop_coords(self) -> T_Xarray:
6363
obj = self._obj
6464
for k, v in obj.coords.items():
6565
if k != self._dim and self._dim in v.dims:
66-
obj = obj.drop_vars(k)
66+
obj = obj.drop_vars([k])
6767
return obj
6868

6969
def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray:
@@ -244,7 +244,7 @@ def map(
244244
# dimension, then we need to do so before we can rename the proxy
245245
# dimension we used.
246246
if self._dim in combined.coords:
247-
combined = combined.drop_vars(self._dim)
247+
combined = combined.drop_vars([self._dim])
248248

249249
if RESAMPLE_DIM in combined.dims:
250250
combined = combined.rename({RESAMPLE_DIM: self._dim})

xarray/tests/test_dataarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,6 +2652,14 @@ def test_drop_coordinates(self) -> None:
26522652
actual = renamed.drop_vars("foo", errors="ignore")
26532653
assert_identical(actual, renamed)
26542654

2655+
def test_drop_vars_callable(self) -> None:
2656+
A = DataArray(
2657+
np.random.randn(2, 3), dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4, 5]}
2658+
)
2659+
expected = A.drop_vars(["x", "y"])
2660+
actual = A.drop_vars(lambda x: x.indexes)
2661+
assert_identical(expected, actual)
2662+
26552663
def test_drop_multiindex_level(self) -> None:
26562664
# GH6505
26572665
expected = self.mda.drop_vars(["x", "level_1", "level_2"])

0 commit comments

Comments
 (0)