Skip to content

Commit a5eb548

Browse files
authored
BUG: Fix constrained layout in psd.plot (#12103)
1 parent c7c8a29 commit a5eb548

File tree

8 files changed

+29
-48
lines changed

8 files changed

+29
-48
lines changed

doc/changes/devel.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Enhancements
3838
- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_)
3939
- Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_)
4040
- Add support for drawing MEG sensors in :ref:`mne coreg` (:gh:`12098` by `Eric Larson`_)
41-
- By default MNE-Python creates matplotlib figures with ``layout='constrained'`` rather than the default ``layout='tight'`` (:gh:`12050` by `Mathieu Scheltienne`_ and `Eric Larson`_)
41+
- By default MNE-Python creates matplotlib figures with ``layout='constrained'`` rather than the default ``layout='tight'`` (:gh:`12050`, :gh:`12103` by `Mathieu Scheltienne`_ and `Eric Larson`_)
4242
- Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_)
4343
- Add :class:`mne.viz.ui_events.UIEvent` linking for interactive colorbars, allowing users to link figures and change the colormap and limits interactively. This supports :func:`~mne.viz.plot_evoked_topomap`, :func:`~mne.viz.plot_ica_components`, :func:`~mne.viz.plot_tfr_topomap`, :func:`~mne.viz.plot_projs_topomap`, :meth:`~mne.Evoked.plot_image`, and :meth:`~mne.Epochs.plot_image` (:gh:`12057` by `Santeri Ruuskanen`_)
4444

mne/viz/_mpl_figure.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,38 +2331,16 @@ def _get_scale_bar_texts(self):
23312331
class MNELineFigure(MNEFigure):
23322332
"""Interactive figure for non-scrolling line plots."""
23332333

2334-
def __init__(self, inst, n_axes, figsize, *, layout=None, **kwargs):
2335-
super().__init__(figsize=figsize, inst=inst, layout=layout, **kwargs)
2336-
2337-
# AXES: default margins (inches)
2338-
l_margin = 0.8
2339-
r_margin = 0.2
2340-
b_margin = 0.65
2341-
t_margin = 0.35
2342-
# AXES: default margins (figure-relative coordinates)
2343-
left = self._inch_to_rel(l_margin)
2344-
right = 1 - self._inch_to_rel(r_margin)
2345-
bottom = self._inch_to_rel(b_margin, horiz=False)
2346-
top = 1 - self._inch_to_rel(t_margin, horiz=False)
2347-
# AXES: make subplots
2348-
axes = [self.add_subplot(n_axes, 1, 1)]
2349-
for ix in range(1, n_axes):
2350-
axes.append(self.add_subplot(n_axes, 1, ix + 1, sharex=axes[0]))
2351-
self.subplotpars.update(
2352-
left=left, bottom=bottom, top=top, right=right, hspace=0.4
2353-
)
2354-
# save useful things
2355-
self.mne.ax_list = axes
2356-
2357-
def _resize(self, event):
2358-
"""Handle resize event."""
2359-
old_width, old_height = self.mne.fig_size_px
2360-
new_width, new_height = self._get_size_px()
2361-
new_margins = _calc_new_margins(
2362-
self, old_width, old_height, new_width, new_height
2334+
def __init__(self, inst, n_axes, figsize, *, layout="constrained", **kwargs):
2335+
super().__init__(
2336+
figsize=figsize,
2337+
inst=inst,
2338+
layout=layout,
2339+
sharex=True,
2340+
**kwargs,
23632341
)
2364-
self.subplots_adjust(**new_margins)
2365-
self.mne.fig_size_px = (new_width, new_height)
2342+
for ix in range(n_axes):
2343+
self.add_subplot(n_axes, 1, ix + 1)
23662344

23672345

23682346
def _close_all():
@@ -2426,11 +2404,10 @@ def _line_figure(inst, axes=None, picks=None, **kwargs):
24262404
FigureClass=MNELineFigure,
24272405
figsize=figsize,
24282406
n_axes=n_axes,
2429-
layout=None,
24302407
**kwargs,
24312408
)
24322409
fig.mne.fig_size_px = fig._get_size_px() # can't do in __init__
2433-
axes = fig.mne.ax_list
2410+
axes = fig.axes
24342411
return fig, axes
24352412

24362413

mne/viz/epochs.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,11 @@ def plot_epochs_image(
302302

303303
# check for compatible `fig` / `axes`; instantiate figs if needed; add
304304
# fig(s) and axes into group_by
305+
needs_colorbar = colorbar and (axes is not None or fig is not None)
305306
group_by = _validate_fig_and_axes(
306-
fig, axes, group_by, evoked, colorbar, clear=clear
307+
fig, axes, group_by, evoked, colorbar=needs_colorbar, clear=clear
307308
)
309+
del fig, axes, needs_colorbar, clear
308310

309311
# prepare images in advance to get consistent vmin/vmax.
310312
# At the same time, create a subsetted epochs object for each group
@@ -649,20 +651,26 @@ def _plot_epochs_image(
649651
ax["evoked"].xaxis.set_major_locator(loc)
650652
ax["evoked"].yaxis.set_major_locator(AutoLocator())
651653

654+
fig = ax_im.get_figure()
655+
652656
# draw the colorbar
653657
if colorbar:
654658
from matplotlib.pyplot import colorbar as cbar
655659

656-
this_colorbar = cbar(im, cax=ax["colorbar"])
657-
this_colorbar.ax.set_ylabel(unit, rotation=270, labelpad=12)
660+
if "colorbar" in ax: # axes supplied by user
661+
this_colorbar = cbar(im, cax=ax["colorbar"])
662+
this_colorbar.ax.set_ylabel(unit, rotation=270, labelpad=12)
663+
else: # we created them
664+
this_colorbar = fig.colorbar(im, ax=ax_im)
665+
this_colorbar.ax.set_title(unit)
658666
if cmap[1]:
659667
ax_im.CB = DraggableColorbar(
660668
this_colorbar, im, kind="epochs_image", ch_type=unit
661669
)
662670

663671
# finish
664-
plt_show(show)
665-
return ax_im.get_figure()
672+
plt_show(show, fig=fig)
673+
return fig
666674

667675

668676
def plot_drop_log(

mne/viz/evoked.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2027,7 +2027,7 @@ def plot_evoked_joint(
20272027
if topomap_args.get("colorbar", True):
20282028
from matplotlib import ticker
20292029

2030-
cbar = fig.colorbar(map_ax[0].images[0], ax=map_ax, cax=cbar_ax)
2030+
cbar = fig.colorbar(map_ax[0].images[0], ax=map_ax, cax=cbar_ax, shrink=0.8)
20312031
cbar.ax.grid(False) # auto-removal deprecated as of 2021/10/05
20322032
if isinstance(contours, (list, np.ndarray)):
20332033
cbar.set_ticks(contours)

mne/viz/topomap.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,15 +304,13 @@ def _add_colorbar(
304304
im,
305305
cmap,
306306
*,
307-
side="right",
308307
title=None,
309308
format=None,
310-
size="5%",
311309
kind=None,
312310
ch_type=None,
313311
):
314312
"""Add a colorbar to an axis."""
315-
cbar = ax.figure.colorbar(im, format=format)
313+
cbar = ax.figure.colorbar(im, format=format, shrink=0.6)
316314
if cmap is not None and cmap[1]:
317315
ax.CB = DraggableColorbar(cbar, im, kind, ch_type)
318316
cax = cbar.ax
@@ -1712,7 +1710,6 @@ def plot_ica_components(
17121710
im,
17131711
cmap,
17141712
title="AU",
1715-
side="right",
17161713
format=cbar_fmt,
17171714
kind="ica_comp_topomap",
17181715
ch_type=ch_type,
@@ -2564,7 +2561,7 @@ def _plot_topomap_multi_cbar(
25642561
)
25652562

25662563
if colorbar:
2567-
cbar, cax = _add_colorbar(ax, im, cmap, title=None, size="10%", format=cbar_fmt)
2564+
cbar, cax = _add_colorbar(ax, im, cmap, title=None, format=cbar_fmt)
25682565
cbar.set_ticks(_vlim)
25692566
if unit is not None:
25702567
cbar.ax.set_ylabel(unit, fontsize=8)
@@ -3744,7 +3741,7 @@ def plot_bridged_electrodes(
37443741
if title is not None:
37453742
im.axes.set_title(title)
37463743
if colorbar:
3747-
cax = fig.colorbar(im)
3744+
cax = fig.colorbar(im, shrink=0.6)
37483745
cax.set_label(r"Electrical Distance ($\mu$$V^2$)")
37493746
return fig
37503747

tutorials/epochs/20_visualize_epochs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
We'll start by importing the modules we need, loading the continuous (raw)
1313
sample data, and cropping it to save memory:
1414
"""
15-
1615
# %%
1716

1817
import mne

tutorials/evoked/20_visualize_evoked.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
1111
As usual we'll start by importing the modules we need:
1212
"""
13+
1314
# %%
1415

1516
import numpy as np

tutorials/raw/40_visualize_raw.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
:ref:`example data <sample-dataset>`, and cropping the `~mne.io.Raw`
1414
object to just 60 seconds before loading it into RAM to save memory:
1515
"""
16-
1716
# %%
1817

1918
import os

0 commit comments

Comments
 (0)