Skip to content

Commit 42cf687

Browse files
massichcbrnr
authored andcommitted
[MRG] EDF stim_channels (#5841)
* TST: Add stim_channel behaviour * TST: update deprecation warning * Fix: stim_channle (single stim) * fix warnings and errors * fix: raise the proper error. ValueError - The impact of a user parameter is responsible to trigger it. * clean up * read multiple stim_channels * prefer list() to None to simplify the code * wip multiple stim channels * fix multi stim channels * clean up + add warning when 'auto' is in ch_names * docstring * check TAL ch names with their full name * maybe this TODO is more clear. * TST / ENH: Parametrize * fix: add support for int indexing * ups * fix lower-upper case comparison * update docstring * pep8:
1 parent faefe1e commit 42cf687

File tree

3 files changed

+129
-74
lines changed

3 files changed

+129
-74
lines changed

mne/io/edf/edf.py

Lines changed: 102 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,12 @@ class RawEDF(BaseRaw):
8787
Names of channels or list of indices that should be designated
8888
MISC channels. Values should correspond to the electrodes in the
8989
edf file. Default is None.
90-
stim_channel : False
91-
If False, there will be no stim channel added from a TAL channel.
92-
None is accepted as an alias for False.
90+
stim_channel : 'auto' | str | list of str | int | list of int
91+
It defaults to 'auto' where channels named 'status' or 'trigger'
92+
are set as 'stim'. When str (or list of str) channels matching this
93+
string are set as 'sitm'. The matching is not case sensitive, same for
94+
the default behavior. When int (or list of ints) the channel
95+
corresponding to this position is set as channels of type 'stim'.
9396
9497
.. warning:: 0.18 does not allow for stim channel synthesis from
9598
the TAL channels called 'EDF Annotations' or
@@ -155,7 +158,7 @@ class RawEDF(BaseRaw):
155158

156159
@verbose
157160
def __init__(self, input_fname, montage, eog=None, misc=None,
158-
stim_channel=None, exclude=(), preload=False,
161+
stim_channel='auto', exclude=(), preload=False,
159162
verbose=None): # noqa: D102
160163
logger.info('Extracting EDF parameters from %s...' % input_fname)
161164
input_fname = os.path.abspath(input_fname)
@@ -278,7 +281,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
278281
d_sidx = d_lims[ai][0]
279282
d_eidx = d_lims[ai + n_read - 1][1]
280283
if n_samps[ci] != buf_len:
281-
if ci == stim_channel:
284+
if ci in stim_channel:
282285
# Stim channel will be interpolated
283286
old = np.linspace(0, 1, n_samps[ci] + 1, True)
284287
new = np.linspace(0, 1, buf_len, False)
@@ -298,22 +301,21 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
298301

299302
# only try to read the stim channel if it's not None and it's
300303
# actually one of the requested channels
301-
_idx = np.arange(self.info['nchan'])[idx] # slice -> ints
302304
if stim_channel is None: # avoid NumPy comparison to None
303305
stim_channel_idx = np.array([], int)
304306
else:
305-
stim_channel_idx = np.where(_idx == stim_channel)[0]
307+
_idx = np.arange(self.info['nchan'])[idx] # slice -> ints
308+
stim_channel_idx = list()
309+
for stim_ch in stim_channel:
310+
stim_ch_idx = np.where(_idx == stim_ch)[0].tolist()
311+
if len(stim_ch_idx):
312+
stim_channel_idx.append(stim_ch_idx)
313+
stim_channel_idx = np.array(stim_channel_idx).ravel()
306314

307315
if subtype == 'bdf':
308-
# do not scale stim channel (see gh-5160)
309-
if stim_channel is None:
310-
stim_idx = [[]]
311-
else:
312-
stim_idx = np.where(np.arange(self.info['nchan']) ==
313-
stim_channel)
314-
cal[0, stim_idx[0]] = 1
315-
offsets[stim_idx[0], 0] = 0
316-
gains[0, stim_idx[0]] = 1
316+
cal[0, stim_channel_idx] = 1
317+
offsets[stim_channel_idx, 0] = 0
318+
gains[0, stim_channel_idx] = 1
317319
data *= cal.T[idx]
318320
data += offsets[idx]
319321
data *= gains.T[idx]
@@ -349,39 +351,45 @@ def _read_ch(fid, subtype, samp, dtype_byte, dtype=None):
349351
return ch_data
350352

351353

352-
def _get_info(fname, stim_channel, eog, misc, exclude, preload):
353-
"""Extract all the information from the EDF+, BDF or GDF file."""
354-
if stim_channel is not None:
355-
if isinstance(stim_channel, bool) and not stim_channel:
356-
warn('stim_channel parameter is deprecated and will be removed in'
357-
' 0.19.', DeprecationWarning)
358-
stim_channel = None
359-
else:
360-
_msg = ('The synthesis of the stim channel is not supported since'
361-
' 0.18. Please set `stim_channel` to False and use'
362-
' `mne.events_from_annotations` instead')
363-
raise RuntimeError(_msg)
354+
def _read_header(fname, exclude):
355+
"""Unify edf, bdf and gdf _read_header call.
364356
365-
if eog is None:
366-
eog = []
367-
if misc is None:
368-
misc = []
357+
Parameters
358+
----------
359+
fname : str
360+
Path to the EDF+, BDF, or GDF file.
361+
exclude : list of str
362+
Channel names to exclude. This can help when reading data with
363+
different sampling rates to avoid unnecessary resampling.
369364
370-
# Read header from file
365+
Returns
366+
-------
367+
(edf_info, orig_units) : tuple
368+
"""
371369
ext = os.path.splitext(fname)[1][1:].lower()
372370
logger.info('%s file detected' % ext.upper())
373371
if ext in ('bdf', 'edf'):
374-
edf_info, orig_units = _read_edf_header(fname, exclude)
372+
return _read_edf_header(fname, exclude)
375373
elif ext in ('gdf'):
376-
edf_info = _read_gdf_header(fname, stim_channel, exclude)
377-
378-
# orig_units not yet implemented for gdf
379-
orig_units = None
380-
374+
return _read_gdf_header(fname, exclude), None
381375
else:
382376
raise NotImplementedError(
383377
'Only GDF, EDF, and BDF files are supported, got %s.' % ext)
384378

379+
380+
def _get_info(fname, stim_channel, eog, misc, exclude, preload):
381+
"""Extract all the information from the EDF+, BDF or GDF file."""
382+
eog = eog if eog is not None else []
383+
misc = misc if misc is not None else []
384+
385+
edf_info, orig_units = _read_header(fname, exclude)
386+
387+
# XXX: `tal_ch_names` to pass to `_check_stim_channel` should be computed
388+
# from `edf_info['ch_names']` and `edf_info['tal_idx']` but 'tal_idx'
389+
# contains stim channels that are not TAL.
390+
stim_ch_idxs, stim_ch_names = _check_stim_channel(stim_channel,
391+
edf_info['ch_names'])
392+
385393
sel = edf_info['sel'] # selection of channels not excluded
386394
ch_names = edf_info['ch_names'] # of length len(sel)
387395
n_samps = edf_info['n_samps'][sel]
@@ -398,8 +406,6 @@ def _get_info(fname, stim_channel, eog, misc, exclude, preload):
398406
warn('Physical range is not defined in following channels:\n' +
399407
', '.join(ch_names[i] for i in bad_idx))
400408
physical_ranges[bad_idx] = 1
401-
stim_channel, stim_ch_name = \
402-
_check_stim_channel(stim_channel, ch_names, sel)
403409

404410
# Creates a list of dicts of eeg channels for raw.info
405411
logger.info('Setting channel info structure...')
@@ -429,7 +435,7 @@ def _get_info(fname, stim_channel, eog, misc, exclude, preload):
429435
chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE
430436
chan_info['kind'] = FIFF.FIFFV_MISC_CH
431437
pick_mask[idx] = False
432-
elif stim_channel == idx:
438+
elif idx in stim_ch_idxs:
433439
chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE
434440
chan_info['unit'] = FIFF.FIFF_UNIT_NONE
435441
chan_info['kind'] = FIFF.FIFFV_STIM_CH
@@ -438,7 +444,8 @@ def _get_info(fname, stim_channel, eog, misc, exclude, preload):
438444
ch_names[idx] = chan_info['ch_name']
439445
edf_info['units'][idx] = 1
440446
chs.append(chan_info)
441-
edf_info['stim_channel'] = stim_channel
447+
448+
edf_info['stim_channel'] = stim_ch_idxs if len(stim_ch_idxs) else None
442449

443450
if any(pick_mask):
444451
picks = [item for item, mask in zip(range(nchan), pick_mask) if mask]
@@ -449,12 +456,9 @@ def _get_info(fname, stim_channel, eog, misc, exclude, preload):
449456
# Info structure
450457
# -------------------------------------------------------------------------
451458

452-
# sfreq defined as the max sampling rate of eeg (stim_ch not included)
453-
if stim_channel is None:
454-
data_samps = n_samps
455-
else:
456-
data_samps = np.delete(n_samps, slice(stim_channel, stim_channel + 1))
457-
sfreq = data_samps.max() * \
459+
not_stim_ch = [x for x in range(n_samps.shape[0])
460+
if x not in stim_ch_idxs]
461+
sfreq = np.take(n_samps, not_stim_ch).max() * \
458462
edf_info['record_length'][1] / edf_info['record_length'][0]
459463
info = _empty_info(sfreq)
460464
info['meas_date'] = edf_info['meas_date']
@@ -633,7 +637,7 @@ def _read_edf_header(fname, exclude):
633637
return edf_info, orig_units
634638

635639

636-
def _read_gdf_header(fname, stim_channel, exclude):
640+
def _read_gdf_header(fname, exclude):
637641
"""Read GDF 1.x and GDF 2.x header info."""
638642
edf_info = dict()
639643
events = None
@@ -1007,27 +1011,57 @@ def _read_gdf_header(fname, stim_channel, exclude):
10071011
return edf_info
10081012

10091013

1010-
def _check_stim_channel(stim_channel, ch_names, sel):
1014+
def _check_stim_channel(stim_channel, ch_names,
1015+
tal_ch_names=['EDF Annotations', 'BDF Annotations']):
10111016
"""Check that the stimulus channel exists in the current datafile."""
1012-
if stim_channel is False:
1013-
return None, None
1017+
DEFAULT_STIM_CH_NAMES = ['status', 'trigger']
1018+
10141019
if stim_channel is None:
1015-
stim_channel = 'auto'
1020+
return [], []
10161021

1017-
if isinstance(stim_channel, str):
1022+
elif isinstance(stim_channel, str):
10181023
if stim_channel == 'auto':
1019-
if 'STATUS' in ch_names:
1020-
stim_channel_idx = ch_names.index('STATUS')
1021-
elif 'Status' in ch_names:
1022-
stim_channel_idx = ch_names.index('Status')
1024+
if 'auto' in ch_names:
1025+
warn(RuntimeWarning, "Using `stim_channel='auto'` when auto"
1026+
" also corresponds to a channel name is ambiguous."
1027+
" Please use `stim_channel=['auto']`.")
10231028
else:
1024-
stim_channel_idx = None
1029+
valid_stim_ch_names = DEFAULT_STIM_CH_NAMES
1030+
else:
1031+
valid_stim_ch_names = [stim_channel.lower()]
1032+
1033+
elif isinstance(stim_channel, int):
1034+
valid_stim_ch_names = [ch_names[stim_channel].lower()]
1035+
1036+
elif isinstance(stim_channel, list):
1037+
if all([isinstance(s, str) for s in stim_channel]):
1038+
valid_stim_ch_names = [s.lower() for s in stim_channel]
1039+
elif all([isinstance(s, int) for s in stim_channel]):
1040+
valid_stim_ch_names = [ch_names[s].lower() for s in stim_channel]
1041+
else:
1042+
raise ValueError('Invalid stim_channel')
10251043
else:
10261044
raise ValueError('Invalid stim_channel')
10271045

1028-
name = None if stim_channel_idx is None else ch_names[stim_channel_idx]
1029-
1030-
return stim_channel_idx, name
1046+
# Forbid the synthesis of stim channels from TAL Annotations
1047+
tal_ch_names_found = [ch for ch in valid_stim_ch_names
1048+
if ch in [t.lower() for t in tal_ch_names]]
1049+
if len(tal_ch_names_found):
1050+
_msg = ('The synthesis of the stim channel is not supported'
1051+
' since 0.18. Please remove {} from `stim_channel`'
1052+
' and use `mne.events_from_annotations` instead'
1053+
).format(tal_ch_names_found)
1054+
raise ValueError(_msg)
1055+
1056+
ch_names_low = [ch.lower() for ch in ch_names]
1057+
found = list(set(valid_stim_ch_names) & set(ch_names_low))
1058+
1059+
if not found:
1060+
return [], []
1061+
else:
1062+
stim_channel_idxs = [ch_names_low.index(f) for f in found]
1063+
names = [ch_names[idx] for idx in stim_channel_idxs]
1064+
return stim_channel_idxs, names
10311065

10321066

10331067
def _find_exclude_idx(ch_names, exclude):
@@ -1047,7 +1081,7 @@ def _find_tal_idx(ch_names):
10471081

10481082

10491083
def read_raw_edf(input_fname, montage=None, eog=None, misc=None,
1050-
stim_channel=None, exclude=(), preload=False, verbose=None):
1084+
stim_channel='auto', exclude=(), preload=False, verbose=None):
10511085
"""Reader function for EDF+, BDF, GDF conversion to FIF.
10521086
10531087
Parameters
@@ -1066,9 +1100,10 @@ def read_raw_edf(input_fname, montage=None, eog=None, misc=None,
10661100
Names of channels or list of indices that should be designated
10671101
MISC channels. Values should correspond to the electrodes in the
10681102
edf file. Default is None.
1069-
stim_channel : False
1070-
If False, there will be no stim channel added from a TAL channel.
1071-
None is accepted as an alias for False.
1103+
stim_channel : 'auto' | list of str
1104+
Channels appearing in this list are set as channels of type 'stim'.
1105+
It defaults to 'auto' where channels matching 'status' or 'trigger'
1106+
are set as 'stim'. The matching is not case sensitive.
10721107
10731108
.. warning:: 0.18 does not allow for stim channel synthesis from
10741109
the TAL channels called 'EDF Annotations' or
771 KB
Binary file not shown.

mne/io/edf/tests/test_edf.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from mne.io.edf.edf import _read_ch
2929
from mne.io.pick import channel_indices_by_type
3030
from mne.annotations import events_from_annotations, read_annotations
31+
from mne.io.meas_info import _kind_dict as _KIND_DICT
3132

3233
FILE = inspect.getfile(inspect.currentframe())
3334
data_dir = op.join(op.dirname(op.abspath(FILE)), 'data')
@@ -196,19 +197,15 @@ def test_to_data_frame():
196197
assert_array_equal(df.values[:, 0], raw._data[0] * 1e13)
197198

198199

199-
def test_read_raw_edf_deprecation():
200+
def test_read_raw_edf_stim_channel_input_parameters():
200201
"""Test edf raw reader deprecation."""
201202
_MSG = "`read_raw_edf` is not supposed to trigger a deprecation warning"
202203
with pytest.warns(None) as recwarn:
203204
read_raw_edf(edf_path)
204205
assert all([w.category != DeprecationWarning for w in recwarn.list]), _MSG
205206

206-
with pytest.deprecated_call(match="stim_channel .* removed in 0.19"):
207-
read_raw_edf(edf_path, stim_channel=False)
208-
209-
for invalid_stim_parameter in ['what ever', 'STATUS', 'EDF Annotations',
210-
'BDF Annotations', 0, -1]:
211-
with pytest.raises(RuntimeError,
207+
for invalid_stim_parameter in ['EDF Annotations', 'BDF Annotations']:
208+
with pytest.raises(ValueError,
212209
match="stim channel is not supported"):
213210
read_raw_edf(edf_path, stim_channel=invalid_stim_parameter)
214211

@@ -286,4 +283,27 @@ def test_load_generator(fname, recwarn):
286283
assert_array_equal(events, [[0, 0, 1], [120000, 0, 2]])
287284

288285

286+
@pytest.mark.parametrize('EXPECTED, test_input', [
287+
pytest.param({'stAtUs': 'stim', 'tRigGer': 'stim', 'sine 1 Hz': 'eeg'},
288+
'auto', id='auto'),
289+
pytest.param({'stAtUs': 'eeg', 'tRigGer': 'eeg', 'sine 1 Hz': 'eeg'},
290+
None, id='None'),
291+
pytest.param({'stAtUs': 'eeg', 'tRigGer': 'eeg', 'sine 1 Hz': 'stim'},
292+
'sine 1 Hz', id='single string'),
293+
pytest.param({'stAtUs': 'eeg', 'tRigGer': 'eeg', 'sine 1 Hz': 'stim'},
294+
2, id='single int'),
295+
pytest.param({'stAtUs': 'eeg', 'tRigGer': 'eeg', 'sine 1 Hz': 'stim'},
296+
-1, id='single int (revers indexing)'),
297+
pytest.param({'stAtUs': 'stim', 'tRigGer': 'stim', 'sine 1 Hz': 'eeg'},
298+
[0, 1], id='int list')])
299+
def test_edf_stim_ch_pick_up(test_input, EXPECTED):
300+
"""Test stim_channel."""
301+
TYPE_LUT = {v[0]: k for k, v in _KIND_DICT.items()}
302+
fname = op.join(data_dir, 'test_stim_channel.edf')
303+
304+
raw = read_raw_edf(fname, stim_channel=test_input)
305+
ch_types = {ch['ch_name']: TYPE_LUT[ch['kind']] for ch in raw.info['chs']}
306+
assert ch_types == EXPECTED
307+
308+
289309
run_tests_if_main()

0 commit comments

Comments
 (0)