Skip to content

Commit dafcde2

Browse files
authored
Use compute instead of load in plot (#9818)
* Only compute this array, load computes in place * Add test * Update whats-new.rst * Update whats-new.rst
1 parent caf62d3 commit dafcde2

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ Deprecations
3636

3737
Bug fixes
3838
~~~~~~~~~
39-
39+
- Fix unintended load on datasets when calling :py:meth:`DataArray.plot.scatter` (:pull:`9818`).
40+
By `Jimmy Westling <https://github.com/illviljan>`_.
4041

4142
Documentation
4243
~~~~~~~~~~~~~

xarray/plot/dataarray_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def newplotfunc(
946946
# Remove any nulls, .where(m, drop=True) doesn't work when m is
947947
# a dask array, so load the array to memory.
948948
# It will have to be loaded to memory at some point anyway:
949-
darray = darray.load()
949+
darray = darray.compute()
950950
darray = darray.where(darray.notnull(), drop=True)
951951
else:
952952
size_ = kwargs.pop("_size", linewidth)

xarray/tests/test_plot.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
assert_no_warnings,
3434
requires_cartopy,
3535
requires_cftime,
36+
requires_dask,
3637
requires_matplotlib,
3738
requires_seaborn,
3839
)
@@ -3326,6 +3327,24 @@ def test_datarray_scatter(
33263327
)
33273328

33283329

3330+
@requires_dask
3331+
@requires_matplotlib
3332+
@pytest.mark.parametrize(
3333+
"plotfunc",
3334+
["scatter"],
3335+
)
3336+
def test_dataarray_not_loading_inplace(plotfunc: str) -> None:
3337+
ds = xr.tutorial.scatter_example_dataset()
3338+
ds = ds.chunk()
3339+
3340+
with figure_context():
3341+
getattr(ds.A.plot, plotfunc)(x="x")
3342+
3343+
from dask.array import Array
3344+
3345+
assert isinstance(ds.A.data, Array)
3346+
3347+
33293348
@requires_matplotlib
33303349
def test_assert_valid_xy() -> None:
33313350
ds = xr.tutorial.scatter_example_dataset()

0 commit comments

Comments
 (0)