Skip to content

Commit f9e5fc2

Browse files
author
Thinh Nguyen
committed
keep _timeseries data as memmap int16 type, apply bitvolt conversion only when needed (at LFP or waveform extraction)
1 parent 625c630 commit f9e5fc2

File tree

3 files changed

+73
-44
lines changed

3 files changed

+73
-44
lines changed

elements_ephys/ephys.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,11 @@ def make(self, key):
228228
spikeglx_rec_dir = (root_dir / spikeglx_meta_fp).parent
229229
spikeglx_recording = spikeglx.SpikeGLX(spikeglx_rec_dir)
230230

231-
lfp = spikeglx_recording.lf_timeseries[:, :-1].T # exclude the sync channel
231+
lfp_chn_ind = spikeglx_recording.lfmeta.recording_channels[-1::-self.skip_chn_counts]
232+
233+
# Extract LFP data at specified channels and convert to uV
234+
lfp = spikeglx_recording.lf_timeseries[:, lfp_chn_ind] # (sample x channel)
235+
lfp = (lfp * spikeglx_recording.get_channel_bit_volts('lf')[lfp_chn_ind]).T # (channel x sample)
232236

233237
self.insert1(dict(key,
234238
lfp_sampling_rate=spikeglx_recording.lfmeta.meta['imSampRate'],
@@ -237,7 +241,7 @@ def make(self, key):
237241

238242
q_electrodes = probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording & key
239243
electrodes = []
240-
for recorded_site in np.arange(lfp.shape[0]):
244+
for recorded_site in lfp_chn_ind:
241245
shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap['data'][recorded_site]
242246
electrodes.append((q_electrodes
243247
& {'shank': shank,
@@ -247,12 +251,17 @@ def make(self, key):
247251
chn_lfp = list(zip(electrodes, lfp))
248252
self.Electrode().insert((
249253
{**key, **electrode, 'lfp': d}
250-
for electrode, d in chn_lfp[-1::-self._skip_chn_counts]), ignore_extra_fields=True)
254+
for electrode, d in chn_lfp), ignore_extra_fields=True)
255+
251256
elif acq_software == 'OpenEphys':
252257
sess_dir = pathlib.Path(get_session_directory(key))
253258
loaded_oe = openephys.OpenEphys(sess_dir)
254259
oe_probe = loaded_oe.probes[probe_sn]
255-
lfp = oe_probe.lfp_timeseries
260+
261+
lfp_chn_ind = np.arange(len(oe_probe.lfp_meta['channels_ids']))[-1::-self.skip_chn_counts]
262+
263+
lfp = oe_probe.lfp_timeseries[:, lfp_chn_ind] # (sample x channel)
264+
lfp = (lfp * oe_probe.lfp_meta['channels_gains'][lfp_chn_ind]).T # (channel x sample)
256265
lfp_timestamps = oe_probe.lfp_timestamps
257266

258267
self.insert1(dict(key,
@@ -262,13 +271,13 @@ def make(self, key):
262271

263272
q_electrodes = probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording & key
264273
electrodes = []
265-
for chn_idx in oe_probe.lfp_meta['channels_ids']:
274+
for chn_idx in oe_probe.lfp_meta['channels_ids'][lfp_chn_ind]:
266275
electrodes.append((q_electrodes & {'electrode': chn_idx}).fetch1('KEY'))
267276

268277
chn_lfp = list(zip(electrodes, lfp))
269278
self.Electrode().insert((
270279
{**key, **electrode, 'lfp': d}
271-
for electrode, d in chn_lfp[-1::-self._skip_chn_counts]), ignore_extra_fields=True)
280+
for electrode, d in chn_lfp), ignore_extra_fields=True)
272281

273282
else:
274283
raise NotImplementedError(f'LFP extraction from acquisition software of type {acq_software} is not yet implemented')

elements_ephys/readers/openephys.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,10 @@ def __init__(self, processor):
120120
def ap_timeseries(self):
121121
"""
122122
AP data concatenated across recordings. Shape: (sample x channel)
123-
Channels' gains (bit_volts) applied - unit: uV
123+
Data are stored as int16 - to convert to microvolts, multiply with self.ap_meta['channels_gains']
124124
"""
125125
if self._ap_timeseries is None:
126126
self._ap_timeseries = np.hstack([s.signal for s in self.ap_analog_signals]).T
127-
self._ap_timeseries *= self.ap_meta['channels_gains']
128127
return self._ap_timeseries
129128

130129
@property
@@ -137,11 +136,10 @@ def ap_timestamps(self):
137136
def lfp_timeseries(self):
138137
"""
139138
LFP data concatenated across recordings. Shape: (sample x channel)
140-
Channels' gains (bit_volts) applied - unit: uV
139+
Data are stored as int16 - to convert to microvolts, multiply with self.lfp_meta['channels_gains']
141140
"""
142141
if self._lfp_timeseries is None:
143142
self._lfp_timeseries = np.hstack([s.signal for s in self.lfp_analog_signals]).T
144-
self._lfp_timeseries *= self.lfp_meta['channels_gains']
145143
return self._lfp_timeseries
146144

147145
@property
@@ -159,6 +157,7 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32)):
159157
:return: waveforms (sample x channel x spike)
160158
"""
161159
channel_ind = [np.where(self.ap_meta['channels_ids'] == chn)[0][0] for chn in channel]
160+
channel_bit_volts = self.ap_meta['channels_gains'][channel_ind]
162161

163162
# ignore spikes at the beginning or end of raw data
164163
spikes = spikes[np.logical_and(spikes > (-wf_win[0] / self.ap_meta['sample_rate']),
@@ -171,6 +170,7 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32)):
171170
spike_indices = np.searchsorted(self.ap_timestamps, spikes, side="left")
172171
# waveform at each spike: (sample x channel x spike)
173172
spike_wfs = np.dstack([self.ap_timeseries[int(spk + wf_win[0]):int(spk + wf_win[-1]), channel_ind]
173+
* channel_bit_volts
174174
for spk in spike_indices])
175175
return spike_wfs
176176
else: # if no spike found, return NaN of size (sample x channel x 1)

elements_ephys/readers/spikeglx.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ def apmeta(self):
5050
def ap_timeseries(self):
5151
"""
5252
AP data: (sample x channel)
53-
Channels' gains (bit_volts) applied - unit: uV
53+
Data are stored as np.memmap with dtype: int16
54+
- to convert to microvolts, multiply with self.get_channel_bit_volts('ap')
5455
"""
5556
if self._ap_timeseries is None:
5657
self._ap_timeseries = self._read_bin(self.root_dir / (self.root_name + '.ap.bin'))
57-
self._ap_timeseries *= self.get_channel_bit_volts('ap')
5858
return self._ap_timeseries
5959

6060
@property
@@ -67,16 +67,16 @@ def lfmeta(self):
6767
def lf_timeseries(self):
6868
"""
6969
LFP data: (sample x channel)
70-
Channels' gains (bit_volts) applied - unit: uV
70+
Data are stored as np.memmap with dtype: int16
71+
- to convert to microvolts, multiply with self.get_channel_bit_volts('lf')
7172
"""
7273
if self._lf_timeseries is None:
7374
self._lf_timeseries = self._read_bin(self.root_dir / (self.root_name + '.lf.bin'))
74-
self._lf_timeseries *= self.get_channel_bit_volts('lf')
7575
return self._lf_timeseries
7676

7777
def get_channel_bit_volts(self, band='ap'):
7878
"""
79-
Extract the AP and LF channels' int16 to microvolts
79+
Extract the recorded AP and LF channels' int16 to microvolts - no Sync (SY) channels
8080
Following the steps specified in: https://billkarsh.github.io/SpikeGLX/Support/SpikeGLX_Datafile_Tools.zip
8181
dataVolts = dataInt * Vmax / Imax / gain
8282
"""
@@ -86,11 +86,13 @@ def get_channel_bit_volts(self, band='ap'):
8686
imax = IMAX[self.apmeta.probe_model]
8787
imroTbl_data = self.apmeta.imroTbl['data']
8888
imroTbl_idx = 3
89+
chn_ind = self.apmeta.get_recording_channels_indices(exclude_sync=True)
8990

9091
elif band == 'lf':
9192
imax = IMAX[self.lfmeta.probe_model]
9293
imroTbl_data = self.lfmeta.imroTbl['data']
9394
imroTbl_idx = 4
95+
chn_ind = self.lfmeta.get_recording_channels_indices(exclude_sync=True)
9496
else:
9597
raise ValueError(f'Unsupported band: {band} - Must be "ap" or "lf"')
9698

@@ -102,25 +104,26 @@ def get_channel_bit_volts(self, band='ap'):
102104
# 3A, 3B1, 3B2 (NP 1.0)
103105
chn_gains = [c[imroTbl_idx] for c in imroTbl_data]
104106

105-
return vmax / imax / np.array(chn_gains) * 1e6 # convert to uV as well
107+
chn_gains = np.array(chn_gains)[chn_ind]
108+
109+
return vmax / imax / chn_gains * 1e6 # convert to uV as well
106110

107111
def _read_bin(self, fname):
108112
nchan = self.apmeta.meta['nSavedChans']
109113
dtype = np.dtype((np.int16, nchan))
110114
return np.memmap(fname, dtype, 'r')
111115

112-
def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32), bit_volts=1):
116+
def extract_spike_waveforms(self, spikes, channel_ind, n_wf=500, wf_win=(-32, 32)):
113117
"""
114118
:param spikes: spike times (in second) to extract waveforms
115-
:param channel: channel (name, not indices) to extract waveforms
119+
:param channel_ind: channel indices (of shankmap) to extract the waveforms from
116120
:param n_wf: number of spikes per unit to extract the waveforms
117121
:param wf_win: number of sample pre and post a spike
118-
:param bit_volts: scalar required to convert int16 values into microvolts (default of 1)
119-
:return: waveforms (sample x channel x spike)
122+
:return: waveforms (in uV) - shape: (sample x channel x spike)
120123
"""
124+
channel_bit_volts = self.get_channel_bit_volts('ap')[channel_ind]
121125

122126
data = self.ap_timeseries
123-
channel_idx = [np.where(self.apmeta.recording_channels == chn)[0][0] for chn in channel]
124127

125128
spikes = np.round(spikes * self.apmeta.meta['imSampRate']).astype(int) # convert to sample
126129
# ignore spikes at the beginning or end of raw data
@@ -130,10 +133,12 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32), b
130133
spikes = spikes[:n_wf]
131134
if len(spikes) > 0:
132135
# waveform at each spike: (sample x channel x spike)
133-
spike_wfs = np.dstack([data[int(spk + wf_win[0]):int(spk + wf_win[-1]), channel_idx] for spk in spikes])
134-
return spike_wfs * bit_volts
136+
spike_wfs = np.dstack([data[int(spk + wf_win[0]):int(spk + wf_win[-1]), channel_ind]
137+
* channel_bit_volts
138+
for spk in spikes])
139+
return spike_wfs
135140
else: # if no spike found, return NaN of size (sample x channel x 1)
136-
return np.full((len(range(*wf_win)), len(channel), 1), np.nan)
141+
return np.full((len(range(*wf_win)), len(channel_ind), 1), np.nan)
137142

138143

139144
class SpikeGLXMeta:
@@ -177,7 +182,9 @@ def __init__(self, meta_filepath):
177182
self.shankmap = self._parse_shankmap(self.meta['~snsShankMap']) if '~snsShankMap' in self.meta else None
178183
self.imroTbl = self._parse_imrotbl(self.meta['~imroTbl']) if '~imroTbl' in self.meta else None
179184

180-
self._recording_channels = None
185+
# Channels being recorded, exclude Sync channels - basically a 1-1 mapping to shankmap
186+
self.recording_channels = [int(v[0]) for k, v in self.chanmap.items()
187+
if k != 'shape' and not k.startswith('SY')]
181188

182189
@staticmethod
183190
def _parse_chanmap(raw):
@@ -208,6 +215,9 @@ def _parse_chanmap(raw):
208215
@staticmethod
209216
def _parse_shankmap(raw):
210217
"""
218+
The shankmap contains details on the shank info
219+
for each electrode sites of the sites being recorded only
220+
211221
https://github.com/billkarsh/SpikeGLX/blob/master/Markdown/UserManual.md#shank-map
212222
Parse shank map header structure. Converts:
213223
@@ -234,6 +244,10 @@ def _parse_shankmap(raw):
234244
@staticmethod
235245
def _parse_imrotbl(raw):
236246
"""
247+
The imro table contains info for all electrode sites (no sync)
248+
for a particular electrode configuration (all 384 sites)
249+
Note: not all of these 384 sites are necessarily recorded
250+
237251
https://github.com/billkarsh/SpikeGLX/blob/master/Markdown/UserManual.md#imro-per-channel-settings
238252
Parse imro tbl structure. Converts:
239253
@@ -257,8 +271,17 @@ def _parse_imrotbl(raw):
257271

258272
return res
259273

260-
@property
261-
def recording_channels(self):
274+
def get_recording_channels_indices(self, exclude_sync=False):
275+
"""
276+
The indices of recorded channels (in chanmap) with respect to the channels listed in the imro table
277+
"""
278+
recorded_chns_ind = [int(v[0]) for k, v in self.chanmap.items()
279+
if k != 'shape' and (not k.startswith('SY') if exclude_sync else True)]
280+
orig_chns_ind = self.get_original_chans()
281+
_, _, chns_ind = np.intersect1d(orig_chns_ind, recorded_chns_ind, return_indices=True)
282+
return chns_ind
283+
284+
def get_original_chans(self):
262285
"""
263286
Because you can selectively save channels, the
264287
ith channel in the file isn't necessarily the ith acquired channel.
@@ -267,23 +290,20 @@ def recording_channels(self):
267290
Credit to https://billkarsh.github.io/SpikeGLX/Support/SpikeGLX_Datafile_Tools.zip
268291
OriginalChans() function
269292
"""
270-
if self._recording_channels is None:
271-
if self.meta['snsSaveChanSubset'] == 'all':
272-
# output = int32, 0 to nSavedChans - 1
273-
self._recording_channels = np.arange(0, int(self.meta['nSavedChans']))
274-
else:
275-
# parse the snsSaveChanSubset string
276-
# split at commas
277-
chStrList = self.meta['snsSaveChanSubset'].split(sep=',')
278-
self._recording_channels = np.arange(0, 0) # creates an empty array of int32
279-
for sL in chStrList:
280-
currList = sL.split(sep=':')
281-
# each set of continuous channels specified by chan1:chan2 inclusive
282-
newChans = np.arange(int(currList[0]), int(currList[min(1, len(currList))]) + 1)
283-
284-
self._recording_channels = np.append(self._recording_channels, newChans)
285-
return self._recording_channels
286-
293+
if self.meta['snsSaveChanSubset'] == 'all':
294+
# output = int32, 0 to nSavedChans - 1
295+
chans = np.arange(0, int(self.meta['nSavedChans']))
296+
else:
297+
# parse the snsSaveChanSubset string
298+
# split at commas
299+
chStrList = self.meta['snsSaveChanSubset'].split(sep = ',')
300+
chans = np.arange(0, 0) # creates an empty array of int32
301+
for sL in chStrList:
302+
currList = sL.split(sep = ':')
303+
# each set of continuous channels specified by chan1:chan2 inclusive
304+
newChans = np.arange(int(currList[0]), int(currList[min(1, len(currList) - 1)]) + 1)
305+
chans = np.append(chans, newChans)
306+
return chans
287307

288308
# ============= HELPER FUNCTIONS =============
289309

0 commit comments

Comments
 (0)