Skip to content

Commit 84a57e6

Browse files
cvanelterenCopilot
andauthored
Hotfix get_border_axes (#236)
* simplified logic and basing reference on original grid * rm print statements * spelling * Update ultraplot/figure.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * do more comprehensive checking * rm debug * add some reasoning * fixed * add a stale check * compound tests and check for the failure * update calls * update tests * formatting restored to defaults * don't adjust labels when not sharing * tests pass -- some expected failures * rm debug statements * fix test * Update ultraplot/tests/test_geographic.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/figure.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/axes/geo.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Revert changes * rm duplicate * fixes --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b853a3f commit 84a57e6

File tree

5 files changed

+118
-73
lines changed

5 files changed

+118
-73
lines changed

ultraplot/axes/geo.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,9 @@ def __share_axis_setup(
585585
if level > 1 and limits:
586586
self._share_limits_with(other, which=which)
587587

588+
if level >= 1 and labels:
589+
self._share_labels_with_others()
590+
588591
@override
589592
def _sharey_setup(self, sharey, *, labels=True, limits=True):
590593
"""
@@ -666,6 +669,17 @@ def _apply_axis_sharing(self):
666669
target_axis=self._lataxis,
667670
)
668671

672+
# This block is apart of the draw sequence as the
673+
# gridliner object is created late in the
674+
# build chain.
675+
if not self.stale:
676+
return
677+
if self.figure._get_sharing_level() == 0:
678+
return
679+
# Share labels with all levels higher or equal
680+
# to 1.
681+
self._share_labels_with_others()
682+
669683
def _get_gridliner_labels(
670684
self,
671685
bottom=None,
@@ -719,9 +733,11 @@ def _handle_axis_sharing(
719733
target_axis.set_view_interval(*source_axis.get_view_interval())
720734
target_axis.set_minor_locator(source_axis.get_minor_locator())
721735

722-
if not self.stale:
723-
return
724-
736+
def _share_labels_with_others(self):
737+
"""
738+
Helpers function to ensure the labels
739+
are shared for rectilinear GeoAxes.
740+
"""
725741
# Turn all labels off
726742
# Note: this action performs it for all the axes in
727743
# the figure. We use the stale here to only perform
@@ -735,10 +751,7 @@ def _handle_axis_sharing(
735751

736752
# We turn off the tick labels when the scale and
737753
# ticks are shared (level >= 3)
738-
are_ticks_on = True
739-
if self.figure._get_sharing_level() >= 3:
740-
are_ticks_on = False
741-
754+
are_ticks_on = False
742755
default = dict(
743756
left=are_ticks_on,
744757
right=are_ticks_on,
@@ -751,11 +764,14 @@ def _handle_axis_sharing(
751764
# sharing that is specific for the GeoAxes.
752765
if not isinstance(axi, GeoAxes):
753766
continue
754-
767+
gridlabels = self._get_gridliner_labels(
768+
bottom=True, top=True, left=True, right=True
769+
)
755770
sides = recoded.get(axi, [])
756771
tmp = default.copy()
757772
for side in sides:
758-
tmp[side] = True
773+
if side in gridlabels and gridlabels[side]:
774+
tmp[side] = True
759775
axi._toggle_gridliner_labels(**tmp)
760776
self.stale = False
761777

@@ -1430,6 +1446,8 @@ def _get_gridliner_labels(
14301446
):
14311447
if side != True:
14321448
continue
1449+
if self.gridlines_major is None:
1450+
continue
14331451
sides[dir] = getattr(self.gridlines_major, f"{dir}_label_artists")
14341452
return sides
14351453

ultraplot/figure.py

Lines changed: 70 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -914,66 +914,84 @@ def _get_border_axes(self) -> dict[str, list[paxes.Axes]]:
914914
"""
915915

916916
gs = self.gridspec
917-
all_axes = self.axes
917+
918+
# Skip colorbars or panels etc
919+
all_axes = [axi for axi in self.axes if axi.number is not None]
918920

919921
# Handle empty cases
920922
nrows, ncols = gs.nrows, gs.ncols
923+
border_axes = dict(top=[], bottom=[], left=[], right=[])
921924
if nrows == 0 or ncols == 0 or not all_axes:
922-
return dict(top=[], bottom=[], left=[], right=[])
925+
return border_axes
926+
# We cannot use the gridspec on the axes as it
927+
# is modified when a colorbar is added. Use self.gridspec
928+
# as a reference.
929+
# Reconstruct the grid based on axis locations. Note that
930+
# spanning axes will fit into one of the boxes. Check
931+
# this with unittest to see how empty axes are handles
932+
grid = np.zeros((gs.nrows, gs.ncols))
933+
for axi in all_axes:
934+
# Infer coordinate from grdispec
935+
spec = axi.get_subplotspec()
936+
spans = spec._get_rows_columns()
937+
rowspans = spans[:2]
938+
colspans = spans[-2:]
939+
940+
grid[
941+
rowspans[0] : rowspans[1] + 1,
942+
colspans[0] : colspans[1] + 1,
943+
] = axi.number
944+
directions = {
945+
"left": (0, -1),
946+
"right": (0, 1),
947+
"top": (-1, 0),
948+
"bottom": (1, 0),
949+
}
923950

924-
# Find occupied grid cells and valid axes
925-
occupied_cells = set()
926-
axes_with_spec = []
951+
def is_border(pos, grid, target, direction):
952+
x, y = pos
953+
# Check if we are at an edge of the grid (out-of-bounds).
954+
if x < 0:
955+
return True
956+
elif x > grid.shape[0] - 1:
957+
return True
958+
959+
if y < 0:
960+
return True
961+
elif y > grid.shape[1] - 1:
962+
return True
963+
964+
# Check if we reached a plot or an internal edge
965+
if grid[x, y] != target and grid[x, y] > 0:
966+
return False
967+
if grid[x, y] == 0:
968+
return True
969+
dx, dy = direction
970+
new_pos = (x + dx, y + dy)
971+
return is_border(new_pos, grid, target, direction)
972+
973+
from itertools import product
927974

928975
for axi in all_axes:
929976
spec = axi.get_subplotspec()
930-
if spec is not None:
931-
axes_with_spec.append((axi, spec))
932-
r0, r1 = spec.rowspan.start, spec.rowspan.stop
933-
c0, c1 = spec.colspan.start, spec.colspan.stop
934-
for r in range(r0, r1):
935-
for c in range(c0, c1):
936-
occupied_cells.add((r, c))
937-
938-
if not axes_with_spec:
939-
return dict(top=[], bottom=[], left=[], right=[])
940-
941-
# Initialize border axes sets
942-
border_axes_sets = dict(top=set(), bottom=set(), left=set(), right=set())
943-
944-
# Check each axis against border criteria
945-
for axi, spec in axes_with_spec:
946-
r0, r1 = spec.rowspan.start, spec.rowspan.stop
947-
c0, c1 = spec.colspan.start, spec.colspan.stop
948-
949-
# Check top border
950-
if r0 == 0 or (
951-
r0 == 1 and any((0, c) not in occupied_cells for c in range(c0, c1))
952-
):
953-
border_axes_sets["top"].add(axi)
954-
955-
# Check bottom border
956-
if r1 == nrows or (
957-
r1 == nrows - 1
958-
and any((nrows - 1, c) not in occupied_cells for c in range(c0, c1))
959-
):
960-
border_axes_sets["bottom"].add(axi)
961-
962-
# Check left border
963-
if c0 == 0 or (
964-
c0 == 1 and any((r, 0) not in occupied_cells for r in range(r0, r1))
965-
):
966-
border_axes_sets["left"].add(axi)
967-
968-
# Check right border
969-
if c1 == ncols or (
970-
c1 == ncols - 1
971-
and any((r, ncols - 1) not in occupied_cells for r in range(r0, r1))
972-
):
973-
border_axes_sets["right"].add(axi)
974-
975-
# Convert sets to lists
976-
return {key: list(val) for key, val in border_axes_sets.items()}
977+
spans = spec._get_rows_columns()
978+
rowspan = spans[:2]
979+
colspan = spans[-2:]
980+
# Check all cardinal directions. When we find a
981+
# border for any starting conditions we break and
982+
# consider it a border. This could mean that for some
983+
# partial overlaps we consider borders that should
984+
# not be borders -- we are conservative in this
985+
# regard
986+
for direction, d in directions.items():
987+
xs = range(rowspan[0], rowspan[1] + 1)
988+
ys = range(colspan[0], colspan[1] + 1)
989+
for x, y in product(xs, ys):
990+
pos = (x, y)
991+
if is_border(pos=pos, grid=grid, target=axi.number, direction=d):
992+
border_axes[direction].append(axi)
993+
break
994+
return border_axes
977995

978996
def _get_align_coord(self, side, axs, includepanels=False):
979997
"""

ultraplot/tests/test_axes.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,14 @@ def test_sharing_labels_top_right():
260260

261261

262262
def test_sharing_labels_top_right_odd_layout():
263+
264+
# Helper function to check if the labels
265+
# on an axis direction is visible
266+
def check_state(numbers: list, state: bool, which: str):
267+
for number in numbers:
268+
for label in getattr(ax[number], f"get_{which}ticklabels")():
269+
assert label.get_visible() == state
270+
263271
layout = [
264272
[1, 2, 0],
265273
[1, 2, 5],
@@ -272,11 +280,6 @@ def test_sharing_labels_top_right_odd_layout():
272280
yticklabelloc="r",
273281
)
274282

275-
def check_state(numbers: list, state: bool, which: str):
276-
for number in numbers:
277-
for label in getattr(ax[number], f"get_{which}ticklabels")():
278-
assert label.get_visible() == state
279-
280283
# these correspond to the indices of the axis
281284
# in the axes array (so the grid number minus 1)
282285
check_state([0, 2], False, which="y")
@@ -291,15 +294,15 @@ def check_state(numbers: list, state: bool, which: str):
291294
[4, 0, 5],
292295
]
293296

294-
fig, ax = uplt.subplots(layout)
297+
fig, ax = uplt.subplots(layout, hspace=0.2, wspace=0.2, share=1)
295298
ax.format(
296299
xticklabelloc="t",
297300
yticklabelloc="r",
298301
)
299302
# these correspond to the indices of the axis
300303
# in the axes array (so the grid number minus 1)
301-
check_state([0, 3], False, which="y")
304+
check_state([0, 3], True, which="y")
302305
check_state([1, 2, 4], True, which="y")
303306
check_state([0, 1, 2], True, which="x")
304-
check_state([3, 4], False, which="x")
307+
check_state([3, 4], True, which="x")
305308
uplt.close(fig)

ultraplot/tests/test_format.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_patch_format():
5858
"""
5959
Test application of patch args on initialization.
6060
"""
61-
fig = uplt.figure(suptitle="Super title")
61+
fig = uplt.figure(suptitle="Super title", share=0)
6262
fig.subplot(
6363
121, proj="cyl", labels=True, land=True, latlines=20, abcloc="l", abc="[A]"
6464
)
@@ -82,7 +82,7 @@ def test_multi_formatting():
8282
Support formatting in multiple projections.
8383
"""
8484
# Mix Cartesian with a projection
85-
fig, axs = uplt.subplots(ncols=2, proj=("cart", "cyl"))
85+
fig, axs = uplt.subplots(ncols=2, proj=("cart", "cyl"), share=0)
8686
axs[0].pcolormesh(np.random.rand(5, 5))
8787

8888
# Warning is raised based on projection. Cart does not have lonlim, latllim or labels

ultraplot/tests/test_geographic.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ def are_labels_on(ax, which=["top", "bottom", "right", "left"]) -> tuple[bool]:
253253
n = 3
254254
settings = dict(land=True, ocean=True, labels="both")
255255
fig, ax = uplt.subplots(ncols=n, nrows=n, share="all", proj="cyl")
256+
# Add data and ensure the tests still hold
257+
# Adding a colorbar will change the underlying gridspec, the
258+
# labels should still be correctly treated.
259+
data = np.random.rand(10, 10)
260+
h = ax.imshow(data)[0]
261+
fig.colorbar(h, loc="r")
256262
ax.format(**settings)
257263
fig.canvas.draw() # need a draw to trigger ax.draw for sharing
258264

@@ -574,7 +580,7 @@ def assert_views_are_sharing(ax):
574580
assert_views_are_sharing(axi)
575581
# When we share the labels but not the limits,
576582
# we expect all ticks to be on
577-
if level < 3:
583+
if level == 0:
578584
assert s == 4
579585
else:
580586
assert s == 2
@@ -605,7 +611,7 @@ def test_cartesian_and_geo():
605611
ax[0].pcolormesh(np.random.rand(10, 10))
606612
ax[1].scatter(*np.random.rand(2, 100))
607613
ax[0]._apply_axis_sharing()
608-
assert mocked.call_count == 1
614+
assert mocked.call_count == 2
609615
return fig
610616

611617

0 commit comments

Comments
 (0)