@@ -1033,8 +1033,8 @@ def get_data(self, picks=None, start=0, stop=None,
1033
1033
return data
1034
1034
1035
1035
@verbose
1036
- def apply_function (self , fun , picks = None , dtype = None ,
1037
- n_jobs = 1 , * args , ** kwargs ):
1036
+ def apply_function (self , fun , picks = None , dtype = None , n_jobs = 1 ,
1037
+ channel_wise = True , * args , ** kwargs ):
1038
1038
"""Apply a function to a subset of channels.
1039
1039
1040
1040
The function "fun" is applied to the channels defined in "picks". The
@@ -1059,15 +1059,23 @@ def apply_function(self, fun, picks=None, dtype=None,
1059
1059
fun : function
1060
1060
A function to be applied to the channels. The first argument of
1061
1061
fun has to be a timeseries (numpy.ndarray). The function must
1062
- return an numpy.ndarray with the same size as the input.
1062
+ operate on an array of shape ``(n_times,)`` if
1063
+ ``channel_wise=True`` and ``(len(picks), n_times)`` otherwise.
1064
+ The function must return an ndarray shaped like its input.
1063
1065
picks : array-like of int (default: None)
1064
1066
Indices of channels to apply the function to. If None, all data
1065
1067
channels are used.
1066
1068
dtype : numpy.dtype (default: None)
1067
1069
Data type to use for raw data after applying the function. If None
1068
1070
the data type is not modified.
1069
1071
n_jobs: int (default: 1)
1070
- Number of jobs to run in parallel.
1072
+ Number of jobs to run in parallel. Ignored if `channel_wise` is
1073
+ False.
1074
+ channel_wise: bool (default: True)
1075
+ Whether to apply the function to each channel individually. If
1076
+ False, the function will be applied to all channels at once.
1077
+
1078
+ .. versionadded:: 0.18
1071
1079
*args :
1072
1080
Additional positional arguments to pass to fun (first pos. argument
1073
1081
of fun is the timeseries of a channel).
@@ -1094,18 +1102,23 @@ def apply_function(self, fun, picks=None, dtype=None,
1094
1102
if dtype is not None and dtype != self ._data .dtype :
1095
1103
self ._data = self ._data .astype (dtype )
1096
1104
1097
- if n_jobs == 1 :
1098
- # modify data inplace to save memory
1099
- for idx in picks :
1100
- self ._data [idx , :] = _check_fun (fun , data_in [idx , :],
1101
- * args , ** kwargs )
1105
+ if channel_wise :
1106
+ if n_jobs == 1 :
1107
+ # modify data inplace to save memory
1108
+ for idx in picks :
1109
+ self ._data [idx , :] = _check_fun (fun , data_in [idx , :],
1110
+ * args , ** kwargs )
1111
+ else :
1112
+ # use parallel function
1113
+ parallel , p_fun , _ = parallel_func (_check_fun , n_jobs )
1114
+ data_picks_new = parallel (
1115
+ p_fun (fun , data_in [p ], * args , ** kwargs ) for p in picks )
1116
+ for pp , p in enumerate (picks ):
1117
+ self ._data [p , :] = data_picks_new [pp ]
1102
1118
else :
1103
- # use parallel function
1104
- parallel , p_fun , _ = parallel_func (_check_fun , n_jobs )
1105
- data_picks_new = parallel (p_fun (fun , data_in [p ], * args , ** kwargs )
1106
- for p in picks )
1107
- for pp , p in enumerate (picks ):
1108
- self ._data [p , :] = data_picks_new [pp ]
1119
+ self ._data [picks , :] = _check_fun (
1120
+ fun , data_in [picks , :], * args , ** kwargs )
1121
+
1109
1122
return self
1110
1123
1111
1124
@verbose
0 commit comments