@@ -148,16 +148,15 @@ def lfp_timestamps(self):
148
148
self ._lfp_timestamps = np .hstack ([s .times for s in self .lfp_analog_signals ])
149
149
return self ._lfp_timestamps
150
150
151
- def extract_spike_waveforms (self , spikes , channel , n_wf = 500 , wf_win = (- 32 , 32 )):
151
+ def extract_spike_waveforms (self , spikes , channel_ind , n_wf = 500 , wf_win = (- 32 , 32 )):
152
152
"""
153
153
:param spikes: spike times (in second) to extract waveforms
154
- :param channel : channel (name, not indices ) to extract waveforms
154
+ :param channel_ind : channel indices (of meta['channels_ids'] ) to extract waveforms
155
155
:param n_wf: number of spikes per unit to extract the waveforms
156
156
:param wf_win: number of sample pre and post a spike
157
157
:return: waveforms (sample x channel x spike)
158
158
"""
159
- channel_ind = [np .where (self .ap_meta ['channels_ids' ] == chn )[0 ][0 ] for chn in channel ]
160
- channel_bit_volts = self .ap_meta ['channels_gains' ][channel_ind ]
159
+ channel_bit_volts = np .array (self .ap_meta ['channels_gains' ])[channel_ind ]
161
160
162
161
# ignore spikes at the beginning or end of raw data
163
162
spikes = spikes [np .logical_and (spikes > (- wf_win [0 ] / self .ap_meta ['sample_rate' ]),
@@ -174,4 +173,4 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32)):
174
173
for spk in spike_indices ])
175
174
return spike_wfs
176
175
else : # if no spike found, return NaN of size (sample x channel x 1)
177
- return np .full ((len (range (* wf_win )), len (channel ), 1 ), np .nan )
176
+ return np .full ((len (range (* wf_win )), len (channel_ind ), 1 ), np .nan )
0 commit comments