diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 079c5f51c95..d77422df5b4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,6 +53,8 @@ Bug fixes By `Michael Niklas `_. - Harmonize returned multi-indexed indexes when applying ``concat`` along new dimension (:issue:`6881`, :pull:`6889`) By `Fabian Hofmann `_. +- Fix step plots with ``hue`` arg. (:pull:`6944`) + By `András Gunyhó `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f004a2645c9..f106d56689c 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -564,13 +564,16 @@ def _resolve_intervals_1dplot( if kwargs.get("drawstyle", "").startswith("steps-"): remove_drawstyle = False + # Convert intervals to double points - if _valid_other_type(np.array([xval, yval]), [pd.Interval]): + x_is_interval = _valid_other_type(xval, [pd.Interval]) + y_is_interval = _valid_other_type(yval, [pd.Interval]) + if x_is_interval and y_is_interval: raise TypeError("Can't step plot intervals against intervals.") - if _valid_other_type(xval, [pd.Interval]): + elif x_is_interval: xval, yval = _interval_to_double_bound_points(xval, yval) remove_drawstyle = True - if _valid_other_type(yval, [pd.Interval]): + elif y_is_interval: yval, xval = _interval_to_double_bound_points(yval, xval) remove_drawstyle = True diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 08bf6af8a66..f37c2fd7508 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -796,6 +796,24 @@ def test_step_with_where(self, where): hdl = self.darray[0, 0].plot.step(where=where) assert hdl[0].get_drawstyle() == f"steps-{where}" + def test_step_with_hue(self): + hdl = self.darray[0].plot.step(hue="dim_2") + assert hdl[0].get_drawstyle() == "steps-pre" + + @pytest.mark.parametrize("where", ["pre", "post", "mid"]) + def test_step_with_hue_and_where(self, where): + hdl = self.darray[0].plot.step(hue="dim_2", where=where) + assert hdl[0].get_drawstyle() == f"steps-{where}" + + def test_drawstyle_steps(self): + hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps") + assert hdl[0].get_drawstyle() == "steps" + + @pytest.mark.parametrize("where", ["pre", "post", "mid"]) + def test_drawstyle_steps_with_where(self, where): + hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}") + assert hdl[0].get_drawstyle() == f"steps-{where}" + def test_coord_with_interval_step(self): """Test step plot with intervals.""" bins = [-1, 0, 1, 2] @@ -814,6 +832,15 @@ def test_coord_with_interval_step_y(self): self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + def test_coord_with_interval_step_x_and_y_raises_valueeerror(self): + """Test that step plot with intervals both on x and y axes raises an error.""" + arr = xr.DataArray( + [pd.Interval(0, 1), pd.Interval(1, 2)], + coords=[("x", [pd.Interval(0, 1), pd.Interval(1, 2)])], + ) + with pytest.raises(TypeError, match="intervals against intervals"): + arr.plot.step() + class TestPlotHistogram(PlotTestCase): @pytest.fixture(autouse=True)