Skip to content

Commit cd53a27

Browse files
hubertjbagramfort
authored andcommitted
[MRG] Adding channel_wise argument to Raw.apply_function (#5875)
* Adding `channel_wise` argument to `Raw.apply_function` * precising docstring and modifying test syntax * DOC: rst and order fixes * DOC: rst fixes * updating whats_new.rst * FIX: Line lengths
1 parent c6956e7 commit cd53a27

File tree

3 files changed

+50
-19
lines changed

3 files changed

+50
-19
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ Changelog
5353

5454
- Add ``extrapolate`` argument to :func:`mne.viz.plot_topomap` for better control of extrapolation points placement by `Mikołaj Magnuski`_
5555

56+
- Add ``channel_wise`` argument to :func:`mne.io.Raw.apply_function` to allow applying a function on multiple channels at once by `Hubert Banville`_
57+
5658
Bug
5759
~~~
5860

mne/io/base.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,8 +1033,8 @@ def get_data(self, picks=None, start=0, stop=None,
10331033
return data
10341034

10351035
@verbose
1036-
def apply_function(self, fun, picks=None, dtype=None,
1037-
n_jobs=1, *args, **kwargs):
1036+
def apply_function(self, fun, picks=None, dtype=None, n_jobs=1,
1037+
channel_wise=True, *args, **kwargs):
10381038
"""Apply a function to a subset of channels.
10391039
10401040
The function "fun" is applied to the channels defined in "picks". The
@@ -1059,15 +1059,23 @@ def apply_function(self, fun, picks=None, dtype=None,
10591059
fun : function
10601060
A function to be applied to the channels. The first argument of
10611061
fun has to be a timeseries (numpy.ndarray). The function must
1062-
return an numpy.ndarray with the same size as the input.
1062+
operate on an array of shape ``(n_times,)`` if
1063+
``channel_wise=True`` and ``(len(picks), n_times)`` otherwise.
1064+
The function must return an ndarray shaped like its input.
10631065
picks : array-like of int (default: None)
10641066
Indices of channels to apply the function to. If None, all data
10651067
channels are used.
10661068
dtype : numpy.dtype (default: None)
10671069
Data type to use for raw data after applying the function. If None
10681070
the data type is not modified.
10691071
n_jobs: int (default: 1)
1070-
Number of jobs to run in parallel.
1072+
Number of jobs to run in parallel. Ignored if `channel_wise` is
1073+
False.
1074+
channel_wise: bool (default: True)
1075+
Whether to apply the function to each channel individually. If
1076+
False, the function will be applied to all channels at once.
1077+
1078+
.. versionadded:: 0.18
10711079
*args :
10721080
Additional positional arguments to pass to fun (first pos. argument
10731081
of fun is the timeseries of a channel).
@@ -1094,18 +1102,23 @@ def apply_function(self, fun, picks=None, dtype=None,
10941102
if dtype is not None and dtype != self._data.dtype:
10951103
self._data = self._data.astype(dtype)
10961104

1097-
if n_jobs == 1:
1098-
# modify data inplace to save memory
1099-
for idx in picks:
1100-
self._data[idx, :] = _check_fun(fun, data_in[idx, :],
1101-
*args, **kwargs)
1105+
if channel_wise:
1106+
if n_jobs == 1:
1107+
# modify data inplace to save memory
1108+
for idx in picks:
1109+
self._data[idx, :] = _check_fun(fun, data_in[idx, :],
1110+
*args, **kwargs)
1111+
else:
1112+
# use parallel function
1113+
parallel, p_fun, _ = parallel_func(_check_fun, n_jobs)
1114+
data_picks_new = parallel(
1115+
p_fun(fun, data_in[p], *args, **kwargs) for p in picks)
1116+
for pp, p in enumerate(picks):
1117+
self._data[p, :] = data_picks_new[pp]
11021118
else:
1103-
# use parallel function
1104-
parallel, p_fun, _ = parallel_func(_check_fun, n_jobs)
1105-
data_picks_new = parallel(p_fun(fun, data_in[p], *args, **kwargs)
1106-
for p in picks)
1107-
for pp, p in enumerate(picks):
1108-
self._data[p, :] = data_picks_new[pp]
1119+
self._data[picks, :] = _check_fun(
1120+
fun, data_in[picks, :], *args, **kwargs)
1121+
11091122
return self
11101123

11111124
@verbose

mne/io/tests/test_apply_function.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def bad_2(x):
2020
return x[:-1] # bad shape
2121

2222

23+
def bad_3(x):
24+
"""Fail."""
25+
return x[0, :]
26+
27+
2328
def printer(x):
2429
"""Print."""
2530
logger.info('exec')
@@ -35,10 +40,21 @@ def test_apply_function_verbose():
3540
raw = RawArray(np.zeros((n_chan, n_times)),
3641
create_info(ch_names, 1., 'mag'))
3742
# test return types in both code paths (parallel / 1 job)
38-
pytest.raises(TypeError, raw.apply_function, bad_1)
39-
pytest.raises(ValueError, raw.apply_function, bad_2)
40-
pytest.raises(TypeError, raw.apply_function, bad_1, n_jobs=2)
41-
pytest.raises(ValueError, raw.apply_function, bad_2, n_jobs=2)
43+
with pytest.raises(TypeError, match='Return value must be an ndarray'):
44+
raw.apply_function(bad_1)
45+
with pytest.raises(ValueError, match='Return data must have shape'):
46+
raw.apply_function(bad_2)
47+
with pytest.raises(TypeError, match='Return value must be an ndarray'):
48+
raw.apply_function(bad_1, n_jobs=2)
49+
with pytest.raises(ValueError, match='Return data must have shape'):
50+
raw.apply_function(bad_2, n_jobs=2)
51+
52+
# test return type when `channel_wise=False`
53+
raw.apply_function(printer, channel_wise=False)
54+
with pytest.raises(TypeError, match='Return value must be an ndarray'):
55+
raw.apply_function(bad_1, channel_wise=False)
56+
with pytest.raises(ValueError, match='Return data must have shape'):
57+
raw.apply_function(bad_3, channel_wise=False)
4258

4359
# check our arguments
4460
with catch_logging() as sio:

0 commit comments

Comments
 (0)