Skip to content

Commit afce18f

Browse files
Avoid in-place multiplication of a large value to an array with small integer dtype (#8867)
* Avoid inplace multiplication * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_plot.py * Update test_plot.py * Update dataarray_plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ffb30a8 commit afce18f

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

xarray/plot/dataarray_plot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,9 +1848,10 @@ def _center_pixels(x):
18481848
# missing data transparent. We therefore add an alpha channel if
18491849
# there isn't one, and set it to transparent where data is masked.
18501850
if z.shape[-1] == 3:
1851-
alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype)
1851+
safe_dtype = np.promote_types(z.dtype, np.uint8)
1852+
alpha = np.ma.ones(z.shape[:2] + (1,), dtype=safe_dtype)
18521853
if np.issubdtype(z.dtype, np.integer):
1853-
alpha *= 255
1854+
alpha[:] = 255
18541855
z = np.ma.concatenate((z, alpha), axis=2)
18551856
else:
18561857
z = z.copy()

xarray/tests/test_plot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2040,15 +2040,17 @@ def test_normalize_rgb_one_arg_error(self) -> None:
20402040
for vmin2, vmax2 in ((-1.2, -1), (2, 2.1)):
20412041
da.plot.imshow(vmin=vmin2, vmax=vmax2)
20422042

2043-
def test_imshow_rgb_values_in_valid_range(self) -> None:
2044-
da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3)))
2043+
@pytest.mark.parametrize("dtype", [np.uint8, np.int8, np.int16])
2044+
def test_imshow_rgb_values_in_valid_range(self, dtype) -> None:
2045+
da = DataArray(np.arange(75, dtype=dtype).reshape((5, 5, 3)))
20452046
_, ax = plt.subplots()
20462047
out = da.plot.imshow(ax=ax).get_array()
20472048
assert out is not None
2048-
dtype = out.dtype
2049-
assert dtype is not None
2050-
assert dtype == np.uint8
2049+
actual_dtype = out.dtype
2050+
assert actual_dtype is not None
2051+
assert actual_dtype == np.uint8
20512052
assert (out[..., :3] == da.values).all() # Compare without added alpha
2053+
assert (out[..., -1] == 255).all() # Compare alpha
20522054

20532055
@pytest.mark.filterwarnings("ignore:Several dimensions of this array")
20542056
def test_regression_rgb_imshow_dim_size_one(self) -> None:

0 commit comments

Comments
 (0)