Skip to content

Fix step plots with hue #6944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ Bug fixes
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Harmonize returned multi-indexed indexes when applying ``concat`` along new dimension (:issue:`6881`, :pull:`6889`)
By `Fabian Hofmann <https://github.com/FabianHofmann>`_.
- Fix step plots with ``hue`` arg. (:pull:`6944`)
By `András Gunyhó <https://github.com/mgunyho>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
9 changes: 6 additions & 3 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,16 @@ def _resolve_intervals_1dplot(
if kwargs.get("drawstyle", "").startswith("steps-"):

remove_drawstyle = False

# Convert intervals to double points
if _valid_other_type(np.array([xval, yval]), [pd.Interval]):
x_is_interval = _valid_other_type(xval, [pd.Interval])
y_is_interval = _valid_other_type(yval, [pd.Interval])
if x_is_interval and y_is_interval:
raise TypeError("Can't step plot intervals against intervals.")
if _valid_other_type(xval, [pd.Interval]):
elif x_is_interval:
xval, yval = _interval_to_double_bound_points(xval, yval)
remove_drawstyle = True
if _valid_other_type(yval, [pd.Interval]):
elif y_is_interval:
yval, xval = _interval_to_double_bound_points(yval, xval)
remove_drawstyle = True

Expand Down
27 changes: 27 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,24 @@ def test_step_with_where(self, where):
hdl = self.darray[0, 0].plot.step(where=where)
assert hdl[0].get_drawstyle() == f"steps-{where}"

def test_step_with_hue(self):
hdl = self.darray[0].plot.step(hue="dim_2")
assert hdl[0].get_drawstyle() == "steps-pre"

@pytest.mark.parametrize("where", ["pre", "post", "mid"])
def test_step_with_hue_and_where(self, where):
hdl = self.darray[0].plot.step(hue="dim_2", where=where)
assert hdl[0].get_drawstyle() == f"steps-{where}"

def test_drawstyle_steps(self):
hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps")
assert hdl[0].get_drawstyle() == "steps"

@pytest.mark.parametrize("where", ["pre", "post", "mid"])
def test_drawstyle_steps_with_where(self, where):
hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}")
assert hdl[0].get_drawstyle() == f"steps-{where}"

def test_coord_with_interval_step(self):
"""Test step plot with intervals."""
bins = [-1, 0, 1, 2]
Expand All @@ -814,6 +832,15 @@ def test_coord_with_interval_step_y(self):
self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins")
assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2)

def test_coord_with_interval_step_x_and_y_raises_valueeerror(self):
"""Test that step plot with intervals both on x and y axes raises an error."""
arr = xr.DataArray(
[pd.Interval(0, 1), pd.Interval(1, 2)],
coords=[("x", [pd.Interval(0, 1), pd.Interval(1, 2)])],
)
with pytest.raises(TypeError, match="intervals against intervals"):
arr.plot.step()


class TestPlotHistogram(PlotTestCase):
@pytest.fixture(autouse=True)
Expand Down