Skip to content

Commit c65345a

Browse files
authored
BUG: Fix bug with pickling MNEBadsList (#12063)
1 parent 9d20815 commit c65345a

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

mne/_fiff/meas_info.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,13 @@ def __init__(self, *, bads, info):
963963
def extend(self, iterable):
964964
if not isinstance(iterable, list):
965965
iterable = list(iterable)
966-
_check_bads_info_compat(iterable, self._mne_info)
966+
# can happen during pickling
967+
try:
968+
info = self._mne_info
969+
except AttributeError:
970+
pass # can happen during pickling
971+
else:
972+
_check_bads_info_compat(iterable, info)
967973
return super().extend(iterable)
968974

969975
def append(self, x):
@@ -1551,6 +1557,7 @@ def __getstate__(self):
15511557
def __setstate__(self, state):
15521558
"""Set state (for pickling)."""
15531559
self._unlocked = state["_unlocked"]
1560+
self["bads"] = MNEBadsList(bads=self["bads"], info=self)
15541561

15551562
def __setitem__(self, key, val):
15561563
"""Attribute setter."""

mne/_fiff/tests/test_meas_info.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,21 +1070,27 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname):
10701070
apply_inverse(evoked, inv) # smoke test
10711071

10721072

1073+
@pytest.mark.parametrize("protocol", ("highest", "default"))
10731074
@pytest.mark.parametrize("fname_info", (raw_fname, "create_info"))
10741075
@pytest.mark.parametrize("unlocked", (True, False))
1075-
def test_pickle(fname_info, unlocked):
1076+
def test_pickle(fname_info, unlocked, protocol):
10761077
"""Test that Info can be (un)pickled."""
10771078
if fname_info == "create_info":
10781079
info = create_info(3, 1000.0, "eeg")
10791080
else:
10801081
info = read_info(fname_info)
1082+
protocol = getattr(pickle, f"{protocol.upper()}_PROTOCOL")
1083+
assert isinstance(info["bads"], MNEBadsList)
1084+
info["bads"] = info["ch_names"][:1]
10811085
assert not info._unlocked
10821086
info._unlocked = unlocked
1083-
data = pickle.dumps(info)
1087+
data = pickle.dumps(info, protocol=protocol)
10841088
info_un = pickle.loads(data) # nosec B301
10851089
assert isinstance(info_un, Info)
10861090
assert_object_equal(info, info_un)
10871091
assert info_un._unlocked == unlocked
1092+
assert isinstance(info_un["bads"], MNEBadsList)
1093+
assert info_un["bads"]._mne_info is info_un
10881094

10891095

10901096
def test_info_bad():

0 commit comments

Comments
 (0)