Skip to content

Commit 63d7eb9

Browse files
Generalize set_(x, y, z)labels in facetgrids (#6918)
* Generalize set_xlabels * Update facetgrid.py * Add some typing and docstring fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fbaf815 commit 63d7eb9

File tree

1 file changed

+27
-26
lines changed

1 file changed

+27
-26
lines changed

xarray/plot/facetgrid.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import itertools
55
import warnings
6+
from typing import Iterable
67

78
import numpy as np
89

@@ -470,39 +471,39 @@ def add_quiverkey(self, u, v, **kwargs):
470471
# self._adjust_fig_for_guide(self.quiverkey.text)
471472
return self
472473

473-
def set_axis_labels(self, x_var=None, y_var=None):
474+
def set_axis_labels(self, *axlabels):
474475
"""Set axis labels on the left column and bottom row of the grid."""
475-
if x_var is not None:
476-
if x_var in self.data.coords:
477-
self._x_var = x_var
478-
self.set_xlabels(label_from_attrs(self.data[x_var]))
479-
else:
480-
# x_var is a string
481-
self.set_xlabels(x_var)
482-
483-
if y_var is not None:
484-
if y_var in self.data.coords:
485-
self._y_var = y_var
486-
self.set_ylabels(label_from_attrs(self.data[y_var]))
487-
else:
488-
self.set_ylabels(y_var)
476+
from ..core.dataarray import DataArray
477+
478+
for var, axis in zip(axlabels, ["x", "y", "z"]):
479+
if var is not None:
480+
if isinstance(var, DataArray):
481+
getattr(self, f"set_{axis}labels")(label_from_attrs(var))
482+
else:
483+
getattr(self, f"set_{axis}labels")(var)
484+
489485
return self
490486

491-
def set_xlabels(self, label=None, **kwargs):
492-
"""Label the x axis on the bottom row of the grid."""
487+
def _set_labels(
488+
self, axis: str, axes: Iterable, label: None | str = None, **kwargs
489+
):
493490
if label is None:
494-
label = label_from_attrs(self.data[self._x_var])
495-
for ax in self._bottom_axes:
496-
ax.set_xlabel(label, **kwargs)
491+
label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")])
492+
for ax in axes:
493+
getattr(ax, f"set_{axis}label")(label, **kwargs)
497494
return self
498495

499-
def set_ylabels(self, label=None, **kwargs):
496+
def set_xlabels(self, label: None | str = None, **kwargs) -> None:
497+
"""Label the x axis on the bottom row of the grid."""
498+
self._set_labels("x", self._bottom_axes, label, **kwargs)
499+
500+
def set_ylabels(self, label: None | str = None, **kwargs) -> None:
500501
"""Label the y axis on the left column of the grid."""
501-
if label is None:
502-
label = label_from_attrs(self.data[self._y_var])
503-
for ax in self._left_axes:
504-
ax.set_ylabel(label, **kwargs)
505-
return self
502+
self._set_labels("y", self._left_axes, label, **kwargs)
503+
504+
def set_zlabels(self, label: None | str = None, **kwargs) -> None:
505+
"""Label the z axis."""
506+
self._set_labels("z", self._left_axes, label, **kwargs)
506507

507508
def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs):
508509
"""

0 commit comments

Comments
 (0)