Skip to content

Commit 0e305ee

Browse files
The order of raw1.info["bads"] should not matter when concatenating with raw0.info["bads"] #11501 (#11502)
1 parent e6b0253 commit 0e305ee

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

doc/changes/latest.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ Bugs
5757
- Fix bug in :func:`mne.preprocessing.compute_maxwell_basis` where using ``int_order=0`` would raise an error (:gh:`11562` by `Eric Larson`_)
5858
- Fix :func:`mne.io.read_raw` for file names containing multiple dots (:gh:`11521` by `Clemens Brunner`_)
5959
- Fix bug in :func:`mne.export.export_raw` when exporting to EDF with a physical range set smaller than the data range (:gh:`11569` by `Mathieu Scheltienne`_)
60+
- Fix bug in :func:`mne.concatenate_raws` where two raws could not be merged if the order of the bad channel lists did not match (:gh:`11502` by `Moritz Gerster`_)
61+
6062
6163
API changes
6264
~~~~~~~~~~~

mne/io/base.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,22 +2503,27 @@ def _check_raw_compatibility(raw):
25032503
for ri in range(1, len(raw)):
25042504
if not isinstance(raw[ri], type(raw[0])):
25052505
raise ValueError(f'raw[{ri}] type must match')
2506-
for key in ('nchan', 'bads', 'sfreq'):
2506+
for key in ('nchan', 'sfreq'):
25072507
a, b = raw[ri].info[key], raw[0].info[key]
25082508
if a != b:
25092509
raise ValueError(
25102510
f'raw[{ri}].info[{key}] must match:\n'
25112511
f'{repr(a)} != {repr(b)}')
2512-
if not set(raw[ri].info['ch_names']) == set(raw[0].info['ch_names']):
2513-
raise ValueError('raw[%d][\'info\'][\'ch_names\'] must match' % ri)
2514-
if not all(raw[ri]._cals == raw[0]._cals):
2512+
for kind in ('bads', 'ch_names'):
2513+
set1 = set(raw[0].info[kind])
2514+
set2 = set(raw[ri].info[kind])
2515+
mismatch = set1.symmetric_difference(set2)
2516+
if mismatch:
2517+
raise ValueError(f'raw[{ri}][\'info\'][{kind}] do not match: '
2518+
f'{sorted(mismatch)}')
2519+
if any(raw[ri]._cals != raw[0]._cals):
25152520
raise ValueError('raw[%d]._cals must match' % ri)
25162521
if len(raw[0].info['projs']) != len(raw[ri].info['projs']):
25172522
raise ValueError('SSP projectors in raw files must be the same')
25182523
if not all(_proj_equal(p1, p2) for p1, p2 in
25192524
zip(raw[0].info['projs'], raw[ri].info['projs'])):
25202525
raise ValueError('SSP projectors in raw files must be the same')
2521-
if not all(r.orig_format == raw[0].orig_format for r in raw):
2526+
if any(r.orig_format != raw[0].orig_format for r in raw):
25222527
warn('raw files do not all have the same data format, could result in '
25232528
'precision mismatch. Setting raw.orig_format="unknown"')
25242529
raw[0].orig_format = 'unknown'

mne/io/fiff/tests/test_raw_fiff.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from mne.io.tests.test_raw import _test_concat, _test_raw_reader
2929
from mne import (concatenate_events, find_events, equalize_channels,
3030
compute_proj_raw, pick_types, pick_channels, create_info,
31-
pick_info)
31+
pick_info, make_fixed_length_epochs)
3232
from mne.utils import (requires_pandas, assert_object_equal, _dt_to_stamp,
3333
requires_mne, run_subprocess, _record_warnings,
3434
assert_and_remove_boundary_annot)
@@ -382,6 +382,54 @@ def test_concatenate_raws(on_mismatch):
382382
concatenate_raws(**kws)
383383

384384

385+
def _create_toy_data(n_channels=3, sfreq=250, seed=None):
386+
rng = np.random.default_rng(seed)
387+
data = rng.standard_normal(size=(n_channels, 50 * sfreq)) * 5e-6
388+
info = create_info(n_channels, sfreq, "eeg")
389+
return RawArray(data, info)
390+
391+
392+
def test_concatenate_raws_bads_order():
393+
"""Test concatenation of raw instances."""
394+
raw0 = _create_toy_data()
395+
raw1 = _create_toy_data()
396+
397+
# Test bad channel order
398+
raw0.info["bads"] = ["0", "1"]
399+
raw1.info["bads"] = ["1", "0"]
400+
401+
# raw0 is modified in-place and therefore copied
402+
raw_concat = concatenate_raws([raw0.copy(), raw1])
403+
404+
# Check data are equal
405+
data_concat = np.concatenate([raw0.get_data(), raw1.get_data()], 1)
406+
assert np.all(raw_concat.get_data() == data_concat)
407+
408+
# Check bad channels
409+
assert set(raw_concat.info["bads"]) == {"0", "1"}
410+
411+
# Bad channel mismatch raises
412+
raw2 = raw1.copy()
413+
raw2.info["bads"] = ["0", "2"]
414+
with pytest.raises(ValueError):
415+
concatenate_raws([raw0, raw2])
416+
417+
# Type mismatch raises
418+
epochs1 = make_fixed_length_epochs(raw1)
419+
with pytest.raises(ValueError):
420+
concatenate_raws([raw0, epochs1])
421+
422+
# Sample rate mismatch
423+
raw3 = _create_toy_data(sfreq=500)
424+
with pytest.raises(ValueError):
425+
concatenate_raws([raw0, raw3])
426+
427+
# Number of channels mismatch
428+
raw4 = _create_toy_data(n_channels=4)
429+
with pytest.raises(ValueError):
430+
concatenate_raws([raw0, raw4])
431+
432+
385433
@testing.requires_testing_data
386434
@pytest.mark.parametrize('mod', (
387435
'meg',

0 commit comments

Comments
 (0)