|
3 | 3 | import functools
|
4 | 4 | import itertools
|
5 | 5 | import warnings
|
| 6 | +from typing import Iterable |
6 | 7 |
|
7 | 8 | import numpy as np
|
8 | 9 |
|
@@ -470,39 +471,39 @@ def add_quiverkey(self, u, v, **kwargs):
|
470 | 471 | # self._adjust_fig_for_guide(self.quiverkey.text)
|
471 | 472 | return self
|
472 | 473 |
|
473 |
| - def set_axis_labels(self, x_var=None, y_var=None): |
| 474 | + def set_axis_labels(self, *axlabels): |
474 | 475 | """Set axis labels on the left column and bottom row of the grid."""
|
475 |
| - if x_var is not None: |
476 |
| - if x_var in self.data.coords: |
477 |
| - self._x_var = x_var |
478 |
| - self.set_xlabels(label_from_attrs(self.data[x_var])) |
479 |
| - else: |
480 |
| - # x_var is a string |
481 |
| - self.set_xlabels(x_var) |
482 |
| - |
483 |
| - if y_var is not None: |
484 |
| - if y_var in self.data.coords: |
485 |
| - self._y_var = y_var |
486 |
| - self.set_ylabels(label_from_attrs(self.data[y_var])) |
487 |
| - else: |
488 |
| - self.set_ylabels(y_var) |
| 476 | + from ..core.dataarray import DataArray |
| 477 | + |
| 478 | + for var, axis in zip(axlabels, ["x", "y", "z"]): |
| 479 | + if var is not None: |
| 480 | + if isinstance(var, DataArray): |
| 481 | + getattr(self, f"set_{axis}labels")(label_from_attrs(var)) |
| 482 | + else: |
| 483 | + getattr(self, f"set_{axis}labels")(var) |
| 484 | + |
489 | 485 | return self
|
490 | 486 |
|
491 |
| - def set_xlabels(self, label=None, **kwargs): |
492 |
| - """Label the x axis on the bottom row of the grid.""" |
| 487 | + def _set_labels( |
| 488 | + self, axis: str, axes: Iterable, label: None | str = None, **kwargs |
| 489 | + ): |
493 | 490 | if label is None:
|
494 |
| - label = label_from_attrs(self.data[self._x_var]) |
495 |
| - for ax in self._bottom_axes: |
496 |
| - ax.set_xlabel(label, **kwargs) |
| 491 | + label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")]) |
| 492 | + for ax in axes: |
| 493 | + getattr(ax, f"set_{axis}label")(label, **kwargs) |
497 | 494 | return self
|
498 | 495 |
|
499 |
| - def set_ylabels(self, label=None, **kwargs): |
| 496 | + def set_xlabels(self, label: None | str = None, **kwargs) -> None: |
| 497 | + """Label the x axis on the bottom row of the grid.""" |
| 498 | + self._set_labels("x", self._bottom_axes, label, **kwargs) |
| 499 | + |
| 500 | + def set_ylabels(self, label: None | str = None, **kwargs) -> None: |
500 | 501 | """Label the y axis on the left column of the grid."""
|
501 |
| - if label is None: |
502 |
| - label = label_from_attrs(self.data[self._y_var]) |
503 |
| - for ax in self._left_axes: |
504 |
| - ax.set_ylabel(label, **kwargs) |
505 |
| - return self |
| 502 | + self._set_labels("y", self._left_axes, label, **kwargs) |
| 503 | + |
| 504 | + def set_zlabels(self, label: None | str = None, **kwargs) -> None: |
| 505 | + """Label the z axis.""" |
| 506 | + self._set_labels("z", self._left_axes, label, **kwargs) |
506 | 507 |
|
507 | 508 | def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs):
|
508 | 509 | """
|
|
0 commit comments