diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 8f41e4fb..afce09b5 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1587,7 +1587,7 @@ def shared(paxs): return [pax for pax in paxs if not pax._panel_hidden and pax._panel_share] # Internal axis sharing, share stacks of panels and main axes with each other - # NOTE: This is called on the main axes whenver a panel is created. + # NOTE: This is called on the main axes whenever a panel is created. # NOTE: This block is why, even though we have figure-wide share[xy], we # still need the axes-specific _share[xy]_override attribute. if not self._panel_side: # this is a main axes @@ -3254,7 +3254,6 @@ def _is_panel_group_member(self, other: "Axes") -> bool: and self._panel_parent is other._panel_parent ): return True - # Not in the same panel group return False diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 8fc9742e..912facea 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -9,6 +9,8 @@ import matplotlib.ticker as mticker import numpy as np +from packaging import version + from .. import constructor from .. import scale as pscale from .. import ticker as pticker @@ -16,6 +18,7 @@ from ..internals import ic # noqa: F401 from ..internals import _not_none, _pop_rc, _version_mpl, docstring, labels, warnings from . import plot, shared +import matplotlib.axis as maxis __all__ = ["CartesianAxes"] @@ -373,7 +376,6 @@ def _apply_axis_sharing(self): Enforce the "shared" axis labels and axis tick labels. If this is not called at drawtime, "shared" labels can be inadvertantly turned off. """ - # X axis # NOTE: Critical to apply labels to *shared* axes attributes rather # than testing extents or we end up sharing labels with twin axes. # NOTE: Similar to how _align_super_labels() calls _apply_title_above() this @@ -381,27 +383,166 @@ def _apply_axis_sharing(self): # NOTE: The "panel sharing group" refers to axes and panels *above* the # bottommost or to the *right* of the leftmost panel. But the sharing level # used for the leftmost and bottommost is the *figure* sharing level. - axis = self.xaxis - if self._sharex is not None and axis.get_visible(): - level = 3 if self._panel_sharex_group else self.figure._sharex - if level > 0: - labels._transfer_label(axis.label, self._sharex.xaxis.label) - axis.label.set_visible(False) - if level > 2: - # WARNING: Cannot set NullFormatter because shared axes share the - # same Ticker(). Instead use approach copied from mpl subplots(). - axis.set_tick_params(which="both", labelbottom=False, labeltop=False) - # Y axis - axis = self.yaxis - if self._sharey is not None and axis.get_visible(): - level = 3 if self._panel_sharey_group else self.figure._sharey - if level > 0: - labels._transfer_label(axis.label, self._sharey.yaxis.label) - axis.label.set_visible(False) - if level > 2: - axis.set_tick_params(which="both", labelleft=False, labelright=False) + + # Get border axes once for efficiency + border_axes = self.figure._get_border_axes() + + # Apply X axis sharing + self._apply_axis_sharing_for_axis("x", border_axes) + + # Apply Y axis sharing + self._apply_axis_sharing_for_axis("y", border_axes) + + def _apply_axis_sharing_for_axis( + self, + axis_name: str, + border_axes: dict[str, plot.PlotAxes], + ) -> None: + """ + Apply axis sharing for a specific axis (x or y). + + Parameters + ---------- + axis_name : str + Either 'x' or 'y' + border_axes : dict + Dictionary from _get_border_axes() containing border information + """ + if axis_name == "x": + axis = self.xaxis + shared_axis = self._sharex + panel_group = self._panel_sharex_group + sharing_level = self.figure._sharex + label_params = ["labeltop", "labelbottom"] + border_sides = ["top", "bottom"] + else: # axis_name == 'y' + axis = self.yaxis + shared_axis = self._sharey + panel_group = self._panel_sharey_group + sharing_level = self.figure._sharey + label_params = ["labelleft", "labelright"] + border_sides = ["left", "right"] + + if shared_axis is None or not axis.get_visible(): + return + + level = 3 if panel_group else sharing_level + + # Handle axis label sharing (level > 0) + if level > 0: + shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") + labels._transfer_label(axis.label, shared_axis_obj.label) + axis.label.set_visible(False) + + # Handle tick label sharing (level > 2) + if level > 2: + label_visibility = self._determine_tick_label_visibility( + axis, + shared_axis, + axis_name, + label_params, + border_sides, + border_axes, + ) + axis.set_tick_params(which="both", **label_visibility) + # Turn minor ticks off axis.set_minor_formatter(mticker.NullFormatter()) + def _determine_tick_label_visibility( + self, + axis: maxis.Axis, + shared_axis: maxis.Axis, + axis_name: str, + label_params: list[str], + border_sides: list[str], + border_axes: dict[str, list[plot.PlotAxes]], + ) -> dict[str, bool]: + """ + Determine which tick labels should be visible based on sharing rules and borders. + + Parameters + ---------- + axis : matplotlib axis + The current axis object + shared_axis : Axes + The axes this one shares with + axis_name : str + Either 'x' or 'y' + label_params : list + List of label parameter names (e.g., ['labeltop', 'labelbottom']) + border_sides : list + List of border side names (e.g., ['top', 'bottom']) + border_axes : dict + Dictionary from _get_border_axes() + + Returns + ------- + dict + Dictionary of label visibility parameters + """ + ticks = axis.get_tick_params() + shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") + sharing_ticks = shared_axis_obj.get_tick_params() + + label_visibility = {} + + def _convert_label_param(label_param: str) -> str: + # Deal with logic not being consistent + # in prior mpl versions + if version.parse(str(_version_mpl)) <= version.parse("3.9"): + if label_param == "labeltop" and axis_name == "x": + label_param = "labelright" + elif label_param == "labelbottom" and axis_name == "x": + label_param = "labelleft" + return label_param + + for label_param, border_side in zip(label_params, border_sides): + # Check if user has explicitly set label location via format() + label_visibility[label_param] = False + has_panel = False + for panel in self._panel_dict[border_side]: + # Check if the panel is a colorbar + colorbars = [ + values + for key, values in self._colorbar_dict.items() + if border_side in key # key is tuple (side, top | center | lower) + ] + if not panel in colorbars: + # Skip colorbar as their + # yaxis is not shared + has_panel = True + break + # When we have a panel, let the panel have + # the labels and turn-off for this axis + side. + if has_panel: + continue + is_border = self in border_axes.get(border_side, []) + is_panel = ( + self in shared_axis._panel_dict[border_side] + and self == shared_axis._panel_dict[border_side][-1] + ) + + # Use automatic border detection logic + # if we are a panel we "push" the labels outwards + if is_border or is_panel: + # Deal with mpl version for label_param + label_param = _convert_label_param(label_param) + is_this_tick_on = ticks[label_param] + is_parent_tick_on = sharing_ticks[label_param] + # Only turn on the labels for the current axis + # if the axis it is sharing with is a main + # and we are not panel + # For shared axes we turn them on if either or are on, but turn off the parent + if is_this_tick_on or is_parent_tick_on: + # Note: we set the current axis to visible + # as we are dealing with borders + # or panels + getattr(shared_axis, f"{axis_name}axis").set_tick_params( + **{label_param: False} + ) + label_visibility[label_param] = True + return label_visibility + def _add_alt(self, sx, **kwargs): """ Add an alternate axes. diff --git a/ultraplot/axes/shared.py b/ultraplot/axes/shared.py index 8fd252d0..57d5abe0 100644 --- a/ultraplot/axes/shared.py +++ b/ultraplot/axes/shared.py @@ -12,6 +12,13 @@ from ..utils import _fontsize_to_pt, _not_none, units from ..axes import Axes +try: + # From python 3.12 + from typing import override +except ImportError: + # From Python 3.5 + from typing_extensions import override + class _SharedAxes(object): """ @@ -186,10 +193,11 @@ def _update_ticks( for lab in obj.get_ticklabels(): lab.update(kwtext_extra) - # Override matplotlib defaults to handle multiple axis sharing + @override def sharex(self, other): return self._share_axis_with(other, which="x") + @override def sharey(self, other): self._share_axis_with(other, which="y") diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 645ac79a..b319a82e 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -35,7 +35,7 @@ labels, warnings, ) -from .utils import units +from .utils import units, _get_subplot_layout, _Crawler __all__ = [ "Figure", @@ -905,7 +905,9 @@ def _get_align_axes(self, side): axs = [ax for ax in axs if ax.get_visible()] return axs - def _get_border_axes(self) -> dict[str, list[paxes.Axes]]: + def _get_border_axes( + self, *, same_type=False, force_recalculate=False + ) -> dict[str, list[paxes.Axes]]: """ Identifies axes located on the outer boundaries of the GridSpec layout. @@ -913,7 +915,15 @@ def _get_border_axes(self) -> dict[str, list[paxes.Axes]]: containing a list of axes on that border. """ - border_axes = dict(top=[], bottom=[], left=[], right=[]) + if hasattr(self, "_cached_border_axes") and not force_recalculate: + return self._cached_border_axes + + border_axes = dict( + left=[], + right=[], + top=[], + bottom=[], + ) gs = self.gridspec if gs is None: return border_axes @@ -931,68 +941,28 @@ def _get_border_axes(self) -> dict[str, list[paxes.Axes]]: # Reconstruct the grid based on axis locations. Note that # spanning axes will fit into one of the boxes. Check # this with unittest to see how empty axes are handles - grid = np.zeros((gs.nrows, gs.ncols)) - for axi in all_axes: - # Infer coordinate from grdispec - spec = axi.get_subplotspec() - spans = spec._get_rows_columns() - rowspans = spans[:2] - colspans = spans[-2:] - - grid[ - rowspans[0] : rowspans[1] + 1, - colspans[0] : colspans[1] + 1, - ] = axi.number - directions = { - "left": (0, -1), - "right": (0, 1), - "top": (-1, 0), - "bottom": (1, 0), - } - - def is_border(pos, grid, target, direction): - x, y = pos - # Check if we are at an edge of the grid (out-of-bounds). - if x < 0: - return True - elif x > grid.shape[0] - 1: - return True - - if y < 0: - return True - elif y > grid.shape[1] - 1: - return True - - # Check if we reached a plot or an internal edge - if grid[x, y] != target and grid[x, y] > 0: - return False - if grid[x, y] == 0: - return True - dx, dy = direction - new_pos = (x + dx, y + dy) - return is_border(new_pos, grid, target, direction) - - from itertools import product - + grid, grid_axis_type, seen_axis_type = _get_subplot_layout( + gs, + all_axes, + same_type=same_type, + ) + # We check for all axes is they are a border or not + # Note we could also write the crawler in a way where + # it find the borders by moving around in the grid, without spawning on each axis point. We may change + # this in the future for axi in all_axes: - spec = axi.get_subplotspec() - spans = spec._get_rows_columns() - rowspan = spans[:2] - colspan = spans[-2:] - # Check all cardinal directions. When we find a - # border for any starting conditions we break and - # consider it a border. This could mean that for some - # partial overlaps we consider borders that should - # not be borders -- we are conservative in this - # regard - for direction, d in directions.items(): - xs = range(rowspan[0], rowspan[1] + 1) - ys = range(colspan[0], colspan[1] + 1) - for x, y in product(xs, ys): - pos = (x, y) - if is_border(pos=pos, grid=grid, target=axi.number, direction=d): - border_axes[direction].append(axi) - break + axis_type = seen_axis_type.get(type(axi), 1) + crawler = _Crawler( + ax=axi, + grid=grid, + target=axi.number, + axis_type=axis_type, + grid_axis_type=grid_axis_type, + ) + for direction, is_border in crawler.find_edges(): + if is_border: + border_axes[direction].append(axi) + self._cached_border_axes = border_axes return border_axes def _get_align_coord(self, side, axs, includepanels=False): @@ -1257,6 +1227,9 @@ def _add_subplot(self, *args, **kwargs): if ax.number: self._subplot_dict[ax.number] = ax + # Invalidate border axes cache + if hasattr(self, "_cached_border_axes"): + delattr(self, "_cached_border_axes") return ax def _unshare_axes(self): @@ -1284,6 +1257,10 @@ def _share_labels_with_others(self, *, which="both"): # Note: this action performs it for all the axes in # the figure. We use the stale here to only perform # it once as it is an expensive action. + # The axis will be a border if it is either + # (a) on the edge + # (b) not next to a subplot + # (c) not next to a subplot of the same kind border_axes = self._get_border_axes() # Recode: recoded = {} @@ -1302,10 +1279,7 @@ def _share_labels_with_others(self, *, which="both"): # Turn the ticks on or off depending on the position sides = recoded.get(axi, []) turn_on_or_off = default.copy() - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind + for side in sides: sidelabel = f"label{side}" is_label_on = axi._is_ticklabel_on(sidelabel) @@ -1318,18 +1292,7 @@ def _share_labels_with_others(self, *, which="both"): if isinstance(axi, paxes.GeoAxes): axi._toggle_gridliner_labels(**turn_on_or_off) else: - # TODO: we need to replace the - # _apply_axis_sharing with something that is - # more profound. Currently, it removes the - # ticklabels in all directions independent - # of the position of the subplot. This means - # that for top right subplots, the labels - # will always be off. Furthermore, - # this is handled in the draw sequence - # which is not necessary, and we should - # add it to _add_subplot of the figure class - continue - # axi.tick_params(which=which, **turn_on_or_off) + axi._apply_axis_sharing() def _toggle_axis_sharing( self, @@ -1971,7 +1934,8 @@ def format( ax.number = store_old_number # When we apply formatting to all axes, we need # to potentially adjust the labels. - if len(axs) == len(self.axes): + + if len(axs) == len(self.axes) and self._get_sharing_level() > 0: self._share_labels_with_others() # Warn unused keyword argument(s) @@ -1985,6 +1949,53 @@ def format( f"Ignoring unused projection-specific format() keyword argument(s): {kw}" # noqa: E501 ) + def _share_labels_with_others(self, *, which="both"): + """ + Helpers function to ensure the labels + are shared for rectilinear GeoAxes. + """ + # Turn all labels off + # Note: this action performs it for all the axes in + # the figure. We use the stale here to only perform + # it once as it is an expensive action. + border_axes = self._get_border_axes(same_type=False) + # Recode: + recoded = {} + for direction, axes in border_axes.items(): + for axi in axes: + recoded[axi] = recoded.get(axi, []) + [direction] + + # We turn off the tick labels when the scale and + # ticks are shared (level > 0) + are_ticks_on = False + default = dict( + labelleft=are_ticks_on, + labelright=are_ticks_on, + labeltop=are_ticks_on, + labelbottom=are_ticks_on, + ) + for axi in self._iter_axes(hidden=False, panels=False, children=False): + # Turn the ticks on or off depending on the position + sides = recoded.get(axi, []) + turn_on_or_off = default.copy() + # The axis will be a border if it is either + # (a) on the edge + # (b) not next to a subplot + # (c) not next to a subplot of the same kind + for side in sides: + sidelabel = f"label{side}" + is_label_on = axi._is_ticklabel_on(sidelabel) + if is_label_on: + # When we are a border an the labels are on + # we keep them on + assert sidelabel in turn_on_or_off + turn_on_or_off[sidelabel] = True + + if isinstance(axi, paxes.GeoAxes): + axi._toggle_gridliner_labels(**turn_on_or_off) + else: + axi.tick_params(which=which, **turn_on_or_off) + @docstring._concatenate_inherited @docstring._snippet_manager def colorbar( diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 8a191be5..0d642505 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -195,6 +195,25 @@ def _get_rows_columns(self, ncols=None): row2, col2 = divmod(self.num2, ncols) return row1, row2, col1, col2 + def _get_grid_span(self, hidden=False) -> (int, int, int, int): + """ + Retrieve the location of the subplot within the + gridspec. When hidden is False we only consider + the main plots, not the panels or colorbars. + """ + gs = self.get_gridspec() + nrows, ncols = gs.nrows_total, gs.ncols_total + if not hidden: + nrows, ncols = gs.nrows, gs.ncols + # Use num1 or num2 + decoded = gs._decode_indices(self.num1) + x, y = np.unravel_index(decoded, (nrows, ncols)) + span = self._get_rows_columns() + + xspan = span[1] - span[0] + 1 # inclusive + yspan = span[3] - span[2] + 1 # inclusive + return (x, x + xspan, y, y + yspan) + def get_position(self, figure, return_all=False): # Silent override. Older matplotlib versions can create subplots # with negative heights and widths that crash on instantiation. diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index d8521e61..fdd02157 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -259,51 +259,108 @@ def test_sharing_labels_top_right(): assert i == j -@pytest.mark.skip("Need to fix sharing labels for odd layouts") -def test_sharing_labels_top_right_odd_layout(): +@pytest.mark.parametrize( + "layout, share, tick_loc, y_visible_indices, x_visible_indices", + [ + # Test case 1: Irregular layout with share=3 (default) + ( + [ + [1, 2, 0], + [1, 2, 5], + [3, 4, 5], + [3, 4, 0], + ], + 3, # default sharing level + {"xticklabelloc": "t", "yticklabelloc": "r"}, + [1, 3, 4], # y-axis labels visible indices + [0, 1, 4], # x-axis labels visible indices + ), + # Test case 2: Irregular layout with share=1 + ( + [ + [1, 0, 2], + [0, 3, 0], + [4, 0, 5], + ], + 1, # share only labels, not tick labels + {"xticklabelloc": "t", "yticklabelloc": "r"}, + [0, 1, 2, 3, 4], # all y-axis labels visible + [0, 1, 2, 3, 4], # all x-axis labels visible + ), + ], +) +def test_sharing_labels_with_layout( + layout, share, tick_loc, y_visible_indices, x_visible_indices +): + """ + Test if tick labels are correctly visible or hidden based on layout and sharing. + + Parameters + ---------- + layout : list of list of int + The layout configuration for the subplots + share : int + The sharing level (0-4) + tick_loc : dict + Tick label location settings + y_visible_indices : list + Indices in the axes array where y-tick labels should be visible + x_visible_indices : list + Indices in the axes array where x-tick labels should be visible + """ - # Helper function to check if the labels - # on an axis direction is visible - def check_state(numbers: list, state: bool, which: str): + # Helper function to check if the labels on an axis direction are visible + def check_state(ax, numbers, state, which): for number in numbers: for label in getattr(ax[number], f"get_{which}ticklabels")(): - assert label.get_visible() == state - - layout = [ - [1, 2, 0], - [1, 2, 5], - [3, 4, 5], - [3, 4, 0], - ] - fig, ax = uplt.subplots(layout) - ax.format( - xticklabelloc="t", - yticklabelloc="r", - ) + assert label.get_visible() == state, ( + f"Expected {which}-tick label visibility to be {state} " + f"for axis {number}, but got {not state}" + ) - # these correspond to the indices of the axis - # in the axes array (so the grid number minus 1) - check_state([0, 2], False, which="y") - check_state([1, 3, 4], True, which="y") - check_state([2, 3], False, which="x") - check_state([0, 1, 4], True, which="x") - uplt.close(fig) + # Create figure with the specified layout and sharing level + fig, ax = uplt.subplots(layout, share=share) - layout = [ - [1, 0, 2], - [0, 3, 0], - [4, 0, 5], - ] + # Format axes with the specified tick label locations + ax.format(**tick_loc) + + # Calculate the indices where labels should be hidden + all_indices = list(range(len(ax))) + y_hidden_indices = [i for i in all_indices if i not in y_visible_indices] + x_hidden_indices = [i for i in all_indices if i not in x_visible_indices] + + # Check that labels are visible or hidden as expected + check_state(ax, y_visible_indices, True, which="y") + check_state(ax, y_hidden_indices, False, which="y") + check_state(ax, x_visible_indices, True, which="x") + check_state(ax, x_hidden_indices, False, which="x") - fig, ax = uplt.subplots(layout, hspace=0.2, wspace=0.2, share=1) - ax.format( - xticklabelloc="t", - yticklabelloc="r", - ) - # these correspond to the indices of the axis - # in the axes array (so the grid number minus 1) - check_state([0, 3], True, which="y") - check_state([1, 2, 4], True, which="y") - check_state([0, 1, 2], True, which="x") - check_state([3, 4], True, which="x") uplt.close(fig) + + +@pytest.mark.mpl_image_compare +def test_alt_axes_y_shared(): + layout = [[1, 2], [3, 4]] + fig, ax = uplt.subplots(ncols=2, nrows=2) + + for axi in ax: + alt = axi.alty() + alt.set_ylabel("Alt Y") + assert alt.get_ylabel() == "Alt Y" + assert alt.get_xlabel() == "" + axi.set_ylabel("Y") + return fig + + +@pytest.mark.mpl_image_compare +def test_alt_axes_x_shared(): + layout = [[1, 2], [3, 4]] + fig, ax = uplt.subplots(ncols=2, nrows=2) + + for axi in ax: + alt = axi.altx() + alt.set_xlabel("Alt X") + assert alt.get_xlabel() == "Alt X" + assert alt.get_ylabel() == "" + axi.set_xlabel("X") + return fig diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 66d0fbad..de62695d 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -241,8 +241,49 @@ def test_lon0_shifts(): uplt.close(fig) -def test_sharing_cartopy(): - +@pytest.mark.parametrize( + "layout, expectations", + [ + ( + # layout 1: 3x3 grid with unique IDs + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ], + # expectations: per element ID (1-9), four booleans: [top, right, bottom, left] + [ + [True, False, False, True], # 1 + [True, False, False, False], # 2 + [True, False, True, False], # 3 + [False, False, False, True], # 4 + [False, False, False, False], # 5 + [False, False, True, False], # 6 + [False, True, False, True], # 7 + [False, True, False, False], # 8 + [False, True, True, False], # 9 + ], + ), + ( + # layout 2: shared IDs (merged subplots?) + [ + [1, 2, 0], + [1, 2, 5], + [3, 4, 5], + [3, 4, 0], + ], + # expectations for IDs 1–5: [top, right, bottom, left] + [ + [True, False, False, True], # 1 + [True, False, True, False], # 2 + [False, True, False, True], # 3 + [False, True, True, False], # 4 + [True, True, True, True], # 5 + ], + ), + ], +) +def test_sharing_cartopy(layout, expectations): def are_labels_on(ax, which=["top", "bottom", "right", "left"]) -> tuple[bool]: gl = ax.gridlines_major @@ -252,50 +293,14 @@ def are_labels_on(ax, which=["top", "bottom", "right", "left"]) -> tuple[bool]: on[idx] = True return on - n = 3 settings = dict(land=True, ocean=True, labels="both") - fig, ax = uplt.subplots(ncols=n, nrows=n, share="all", proj="cyl") + fig, ax = uplt.subplots(layout, share="all", proj="cyl") ax.format(**settings) - - expectations = ( - [True, False, False, True], - [True, False, False, False], - [True, False, True, False], - [False, False, False, True], - [False, False, False, False], - [False, False, True, False], - [False, True, False, True], - [False, True, False, False], - [False, True, True, False], - ) for axi in ax: state = are_labels_on(axi) expectation = expectations[axi.number - 1] for i, j in zip(state, expectation): assert i == j - - layout = [ - [1, 2, 0], - [1, 2, 5], - [3, 4, 5], - [3, 4, 0], - ] - - fig, ax = uplt.subplots(layout, share="all", proj="cyl") - ax.format(**settings) - fig.canvas.draw() # need a draw to trigger ax.draw for sharing - - expectations = ( - [True, False, False, True], # top left - [True, False, True, False], # top right - [False, True, False, True], # bottom left - [False, True, True, False], # bottom right - [True, True, True, False], # right plot (5) - ) - for axi in ax: - state = are_labels_on(axi) - expectation = expectations[axi.number - 1] - assert all([i == j for i, j in zip(state, expectation)]) uplt.close(fig) diff --git a/ultraplot/tests/test_imshow.py b/ultraplot/tests/test_imshow.py index 08c9ade2..b7f77619 100644 --- a/ultraplot/tests/test_imshow.py +++ b/ultraplot/tests/test_imshow.py @@ -81,7 +81,6 @@ def test_inbounds_data(rng): ylabel="ylabel", suptitle="Default vmin/vmax restricted to in-bounds data", ) - fig.show() return fig diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 05e609f4..207ca0d6 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -141,34 +141,60 @@ def test_aligned_outer_guides(): return fig +@pytest.mark.parametrize( + "test_case,refwidth,kwargs,setup_func,ref", + [ + ( + "simple", + 1.5, + {"ncols": 2}, + None, + None, + ), + ( + "funky_layout", + 1.5, + {"array": [[1, 1, 2, 2], [0, 3, 3, 0]]}, + lambda fig, axs: ( + axs[1].panel_axes("left"), + axs.format(xlocator=0.2, ylocator=0.2), + ), + 3, + ), + ( + "with_panels", + 2.0, + {"array": [[1, 1, 2], [3, 4, 5], [3, 4, 6]], "hratios": (2, 1, 1)}, + lambda fig, axs: ( + axs[2].panel_axes("right", width=0.5), + axs[0].panel_axes("bottom", width=0.5), + axs[3].panel_axes("left", width=0.5), + ), + None, + ), + ], +) @pytest.mark.mpl_image_compare -def test_reference_aspect(): +def test_reference_aspect(test_case, refwidth, kwargs, setup_func, ref): """ Rigorous test of reference aspect ratio accuracy. """ - # A simple test - refwidth = 1.5 - fig, axs = uplt.subplots(ncols=2, refwidth=refwidth) - fig.auto_layout() - assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0]) + # Add ref and refwidth to kwargs + subplot_kwargs = kwargs.copy() + subplot_kwargs["refwidth"] = refwidth + if ref is not None: + subplot_kwargs["ref"] = ref - # A test with funky layout - refwidth = 1.5 - fig, axs = uplt.subplots([[1, 1, 2, 2], [0, 3, 3, 0]], ref=3, refwidth=refwidth) - axs[1].panel_axes("left") - axs.format(xlocator=0.2, ylocator=0.2) - fig.auto_layout() - assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0]) + # Create subplots + fig, axs = uplt.subplots(**subplot_kwargs) - # A test with panels - refwidth = 2.0 - fig, axs = uplt.subplots( - [[1, 1, 2], [3, 4, 5], [3, 4, 6]], hratios=(2, 1, 1), refwidth=refwidth - ) - axs[2].panel_axes("right", width=0.5) - axs[0].panel_axes("bottom", width=0.5) - axs[3].panel_axes("left", width=0.5) + # Run setup function if provided + if setup_func is not None: + setup_func(fig, axs) + + # Apply auto layout fig.auto_layout() + # Assert reference width accuracy assert np.isclose(refwidth, axs[fig._refnum - 1]._get_size_inches()[0]) return fig @@ -215,3 +241,69 @@ def test_axis_sharing(share): assert ax[2].get_ylabel() == "D" return fig + + +@pytest.mark.parametrize( + "layout", + [ + [[1, 2], [3, 4]], # simple 2x2 + [[1, 0, 2], [0, 3, 0], [4, 0, 5]], # complex 3x3 with independent plots + [[0, 0, 1, 1, 0, 0], [0, 2, 2, 3, 3, 0]], # 1 spanning 2 different plot + ], +) +@pytest.mark.mpl_image_compare +def test_label_sharing_top_right(layout): + fig, ax = uplt.subplots(layout) + ax.format( + xticklabelloc="t", + yticklabelloc="r", + xlabel="xlabel", + ylabel="ylabel", + title="Test Title", + ) + fig.canvas.draw() # force redraw tick labels + for axi in ax: + assert axi._is_ticklabel_on("labelleft") == False + assert axi._is_ticklabel_on("labelbottom") == False + + for side, axs in fig._get_border_axes().items(): + for axi in axs: + if side == "top": + assert axi._is_ticklabel_on("labeltop") == True + if side == "right": + assert axi._is_ticklabel_on("labelright") == True + + return fig + + +@pytest.mark.parametrize("layout", [[[1, 2], [3, 4]]]) +@pytest.mark.mpl_image_compare +def test_panel_sharing_top_right(layout): + fig, ax = uplt.subplots(layout) + for dir in "left right top bottom".split(): + pax = ax[0].panel(dir) + fig.canvas.draw() # force redraw tick labels + for dir, paxs in ax[0]._panel_dict.items(): + # Since we are sharing some of the ticks + # should be hidden depending on where the panel is + # in the grid + for pax in paxs: + match dir: + case "left": + assert pax._is_ticklabel_on("labelleft") + assert pax._is_ticklabel_on("labelbottom") + case "top": + assert pax._is_ticklabel_on("labeltop") == False + assert pax._is_ticklabel_on("labelbottom") == False + assert pax._is_ticklabel_on("labelleft") + case "right": + print(pax._is_ticklabel_on("labelright")) + assert pax._is_ticklabel_on("labelright") == False + assert pax._is_ticklabel_on("labelbottom") + case "bottom": + assert pax._is_ticklabel_on("labelleft") + assert pax._is_ticklabel_on("labelbottom") == False + + # The sharing axis is not showing any ticks + assert ax[0]._is_ticklabel_on(dir) == False + return fig diff --git a/ultraplot/utils.py b/ultraplot/utils.py index 5e3a7b0f..1b1b97a9 100644 --- a/ultraplot/utils.py +++ b/ultraplot/utils.py @@ -7,9 +7,12 @@ import functools import re from numbers import Integral, Real +from dataclasses import dataclass +from typing import Generator import matplotlib.colors as mcolors import matplotlib.font_manager as mfonts +from matplotlib.gridspec import GridSpec import numpy as np from matplotlib import rcParams as rc_matplotlib @@ -904,6 +907,192 @@ def units( return result[0] if singleton else result +def _get_subplot_layout( + gs: "GridSpec", + all_axes: list["paxes.Axes"], + same_type=True, +) -> tuple[np.ndarray[int, int], np.ndarray[int, int], dict[type, int]]: + """ + Helper function to determine the grid layout of axes in a + GridSpec. It returns a grid of axis numbers and a grid of + axis types. This function is used internally to determine + the layout of axes in a GridSpec. + """ + grid = np.zeros((gs.nrows, gs.ncols)) + grid_axis_type = np.zeros((gs.nrows, gs.ncols)) + # Collect grouper based on kinds of axes. This + # would allow us to share labels across types + seen_axis_types = {type(axi) for axi in all_axes} + seen_axis_types = {type: idx for idx, type in enumerate(seen_axis_types)} + + for axi in all_axes: + # Infer coordinate from grdispec + spec = axi.get_subplotspec() + spans = spec._get_grid_span() + rowspan = spans[:2] + colspan = spans[-2:] + + x, y, xspan, yspan = spans + grid[ + slice(*rowspan), + slice(*colspan), + ] = axi.number + + # Allow grouping of mixed types + axis_type = 1 + if not same_type: + axis_type = seen_axis_types.get(type(axi), 1) + + grid_axis_type[ + slice(*rowspan), + slice(*colspan), + ] = axis_type + return grid, grid_axis_type, seen_axis_types + + +@dataclass +class _Crawler: + """ + A crawler is used to find edges of axes in a grid layout. + This is useful for determining whether to turn shared labels + on or depending on the position of an axis in the gridspec. + It crawls over the grid in all four cardinal directions and + checks whether it reaches a border of the grid or an axis of + a different type. It was created as adding colorbars will + change the underlying gridspec and therefore we cannot rely + on the original gridspec to determine whether an axis is a + border or not. + """ + + ax: object + grid: np.ndarray[int, int] + grid_axis_type: np.ndarray[int, int] + # The axis number + target: int + # The kind of axis, e.g. 1 for CartesianAxes, 2 for + # PolarAxes, etc. + axis_type: int + directions = { + "left": (0, -1), + "right": (0, 1), + "top": (-1, 0), + "bottom": (1, 0), + } + + def find_edges(self) -> Generator[tuple[str, bool], None, None]: + """ + Check all cardinal directions. When we find a + border for any starting conditions we break and + consider it a border. This could mean that for some + partial overlaps we consider borders that should + not be borders -- we are conservative in this + regard. + """ + for direction, d in self.directions.items(): + yield self.find_edge_for(direction, d) + + def find_edge_for( + self, + direction: str, + d: tuple[int, int], + ) -> tuple[str, bool]: + from itertools import product + + """ + Setup search for a specific direction. + """ + + # Retrieve where the axis is in the grid + spec = self.ax.get_subplotspec() + spans = spec._get_grid_span() + rowspan = spans[:2] + colspan = spans[-2:] + xs = range(*rowspan) + ys = range(*colspan) + is_border = False + for x, y in product(xs, ys): + pos = (x, y) + if self.is_border(pos, d): + is_border = True + break + return direction, is_border + + def is_border( + self, + pos: tuple[int, int], + direction: tuple[int, int], + ) -> bool: + """ + Recursively move over the grid by following the direction. + """ + x, y = pos + # Check if we are at an edge of the grid (out-of-bounds). + if x < 0: + return True + elif x > self.grid.shape[0] - 1: + return True + + if y < 0: + return True + elif y > self.grid.shape[1] - 1: + return True + + if self.grid[x, y] == 0 or self.grid_axis_type[x, y] != self.axis_type: + return True + + # Check if we reached a plot or an internal edge + if self.grid[x, y] != self.target and self.grid[x, y] > 0: + return self._check_ranges(direction, other=self.grid[x, y]) + + dx, dy = direction + pos = (x + dx, y + dy) + return self.is_border(pos, direction) + + def _check_ranges( + self, + direction: tuple[int, int], + other: int, + ) -> bool: + """ + Helper function to determined whether a subplot + is enclosed or enclosed another subplot. This is + key to know where a border is, e.g. + + 1 2 + 1 3 + + Implies that 1 cannot share y with 2 and 3, but 2, and 3 + can share x. + """ + this_spec = self.ax.get_subplotspec() + other_spec = self.ax.figure._subplot_dict[other].get_subplotspec() + + # Get the row and column spans of both axes + this_span = this_spec._get_grid_span() + this_rowspan = this_span[:2] + this_colspan = this_span[-2:] + + other_span = other_spec._get_grid_span() + other_rowspan = other_span[:2] + other_colspan = other_span[-2:] + + # We can share labels if the ranges are the same + # in the direction we are moving + dy, dx = direction # note columns are x and rows are y + if dx == 0: + # Check the y range + this_start, this_stop = this_colspan + other_start, other_stop = other_colspan + if dy == 0: + # Check the x range + this_start, this_stop = this_rowspan + other_start, other_stop = other_rowspan + + if this_start == other_start and this_stop == other_stop: + return False # not a border + return True + + # Deprecations shade, saturate = warnings._rename_objs( "0.6.0",