Skip to content

Commit 70a813b

Browse files
Merge pull request #9 from ttngu207/main
keep `_timeseries` data as memmap int16 type, apply bitvolt conversion at LFP/Waveform extraction step & Bugfix in channel matching for SpikeGLX
2 parents f76086c + 93ea01a commit 70a813b

File tree

8 files changed

+286
-481
lines changed

8 files changed

+286
-481
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ephys pipeline.
1212

1313
## The Pipeline Architecture
1414

15-
![ephys pipeline diagram](images/attached_ephys_element.png)
15+
![ephys pipeline diagram](images/attached_ephys_element.svg)
1616

1717
As the diagram depicts, the ephys element starts immediately downstream from ***Session***,
1818
and also requires some notion of ***Location*** as a dependency for ***InsertionLocation***.

elements_ephys/ephys.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,11 @@ def make(self, key):
228228
spikeglx_rec_dir = (root_dir / spikeglx_meta_fp).parent
229229
spikeglx_recording = spikeglx.SpikeGLX(spikeglx_rec_dir)
230230

231-
lfp = spikeglx_recording.lf_timeseries[:, :-1].T # exclude the sync channel
231+
lfp_chn_ind = spikeglx_recording.lfmeta.recording_channels[-1::-self._skip_chn_counts]
232+
233+
# Extract LFP data at specified channels and convert to uV
234+
lfp = spikeglx_recording.lf_timeseries[:, lfp_chn_ind] # (sample x channel)
235+
lfp = (lfp * spikeglx_recording.get_channel_bit_volts('lf')[lfp_chn_ind]).T # (channel x sample)
232236

233237
self.insert1(dict(key,
234238
lfp_sampling_rate=spikeglx_recording.lfmeta.meta['imSampRate'],
@@ -237,7 +241,7 @@ def make(self, key):
237241

238242
q_electrodes = probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording & key
239243
electrodes = []
240-
for recorded_site in np.arange(lfp.shape[0]):
244+
for recorded_site in lfp_chn_ind:
241245
shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap['data'][recorded_site]
242246
electrodes.append((q_electrodes
243247
& {'shank': shank,
@@ -247,12 +251,17 @@ def make(self, key):
247251
chn_lfp = list(zip(electrodes, lfp))
248252
self.Electrode().insert((
249253
{**key, **electrode, 'lfp': d}
250-
for electrode, d in chn_lfp[-1::-self._skip_chn_counts]), ignore_extra_fields=True)
254+
for electrode, d in chn_lfp), ignore_extra_fields=True)
255+
251256
elif acq_software == 'OpenEphys':
252257
sess_dir = pathlib.Path(get_session_directory(key))
253258
loaded_oe = openephys.OpenEphys(sess_dir)
254259
oe_probe = loaded_oe.probes[probe_sn]
255-
lfp = oe_probe.lfp_timeseries
260+
261+
lfp_chn_ind = np.arange(len(oe_probe.lfp_meta['channels_ids']))[-1::-self._skip_chn_counts]
262+
263+
lfp = oe_probe.lfp_timeseries[:, lfp_chn_ind] # (sample x channel)
264+
lfp = (lfp * np.array(oe_probe.lfp_meta['channels_gains'])[lfp_chn_ind]).T # (channel x sample)
256265
lfp_timestamps = oe_probe.lfp_timestamps
257266

258267
self.insert1(dict(key,
@@ -262,13 +271,13 @@ def make(self, key):
262271

263272
q_electrodes = probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording & key
264273
electrodes = []
265-
for chn_idx in oe_probe.lfp_meta['channels_ids']:
274+
for chn_idx in np.array(oe_probe.lfp_meta['channels_ids'])[lfp_chn_ind]:
266275
electrodes.append((q_electrodes & {'electrode': chn_idx}).fetch1('KEY'))
267276

268277
chn_lfp = list(zip(electrodes, lfp))
269278
self.Electrode().insert((
270279
{**key, **electrode, 'lfp': d}
271-
for electrode, d in chn_lfp[-1::-self._skip_chn_counts]), ignore_extra_fields=True)
280+
for electrode, d in chn_lfp), ignore_extra_fields=True)
272281

273282
else:
274283
raise NotImplementedError(f'LFP extraction from acquisition software of type {acq_software} is not yet implemented')

elements_ephys/readers/openephys.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,10 @@ def __init__(self, processor):
120120
def ap_timeseries(self):
121121
"""
122122
AP data concatenated across recordings. Shape: (sample x channel)
123-
Channels' gains (bit_volts) applied - unit: uV
123+
Data are stored as int16 - to convert to microvolts, multiply with self.ap_meta['channels_gains']
124124
"""
125125
if self._ap_timeseries is None:
126126
self._ap_timeseries = np.hstack([s.signal for s in self.ap_analog_signals]).T
127-
self._ap_timeseries *= self.ap_meta['channels_gains']
128127
return self._ap_timeseries
129128

130129
@property
@@ -137,11 +136,10 @@ def ap_timestamps(self):
137136
def lfp_timeseries(self):
138137
"""
139138
LFP data concatenated across recordings. Shape: (sample x channel)
140-
Channels' gains (bit_volts) applied - unit: uV
139+
Data are stored as int16 - to convert to microvolts, multiply with self.lfp_meta['channels_gains']
141140
"""
142141
if self._lfp_timeseries is None:
143142
self._lfp_timeseries = np.hstack([s.signal for s in self.lfp_analog_signals]).T
144-
self._lfp_timeseries *= self.lfp_meta['channels_gains']
145143
return self._lfp_timeseries
146144

147145
@property
@@ -150,15 +148,15 @@ def lfp_timestamps(self):
150148
self._lfp_timestamps = np.hstack([s.times for s in self.lfp_analog_signals])
151149
return self._lfp_timestamps
152150

153-
def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32)):
151+
def extract_spike_waveforms(self, spikes, channel_ind, n_wf=500, wf_win=(-32, 32)):
154152
"""
155153
:param spikes: spike times (in second) to extract waveforms
156-
:param channel: channel (name, not indices) to extract waveforms
154+
:param channel_ind: channel indices (of meta['channels_ids']) to extract waveforms
157155
:param n_wf: number of spikes per unit to extract the waveforms
158156
:param wf_win: number of sample pre and post a spike
159157
:return: waveforms (sample x channel x spike)
160158
"""
161-
channel_ind = [np.where(self.ap_meta['channels_ids'] == chn)[0][0] for chn in channel]
159+
channel_bit_volts = np.array(self.ap_meta['channels_gains'])[channel_ind]
162160

163161
# ignore spikes at the beginning or end of raw data
164162
spikes = spikes[np.logical_and(spikes > (-wf_win[0] / self.ap_meta['sample_rate']),
@@ -171,7 +169,8 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32)):
171169
spike_indices = np.searchsorted(self.ap_timestamps, spikes, side="left")
172170
# waveform at each spike: (sample x channel x spike)
173171
spike_wfs = np.dstack([self.ap_timeseries[int(spk + wf_win[0]):int(spk + wf_win[-1]), channel_ind]
172+
* channel_bit_volts
174173
for spk in spike_indices])
175174
return spike_wfs
176175
else: # if no spike found, return NaN of size (sample x channel x 1)
177-
return np.full((len(range(*wf_win)), len(channel), 1), np.nan)
176+
return np.full((len(range(*wf_win)), len(channel_ind), 1), np.nan)

elements_ephys/readers/spikeglx.py

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ def __init__(self, root_dir):
3737

3838
self.root_dir = pathlib.Path(root_dir)
3939

40-
meta_filepath = next(pathlib.Path(root_dir).glob('*.ap.meta'))
40+
try:
41+
meta_filepath = next(pathlib.Path(root_dir).glob('*.ap.meta'))
42+
except StopIteration:
43+
raise FileNotFoundError(f'No SpikeGLX file (.ap.meta) found at: {root_dir}')
44+
4145
self.root_name = meta_filepath.name.replace('.ap.meta', '')
4246

4347
@property
@@ -50,11 +54,11 @@ def apmeta(self):
5054
def ap_timeseries(self):
5155
"""
5256
AP data: (sample x channel)
53-
Channels' gains (bit_volts) applied - unit: uV
57+
Data are stored as np.memmap with dtype: int16
58+
- to convert to microvolts, multiply with self.get_channel_bit_volts('ap')
5459
"""
5560
if self._ap_timeseries is None:
5661
self._ap_timeseries = self._read_bin(self.root_dir / (self.root_name + '.ap.bin'))
57-
self._ap_timeseries *= self.get_channel_bit_volts('ap')
5862
return self._ap_timeseries
5963

6064
@property
@@ -67,16 +71,16 @@ def lfmeta(self):
6771
def lf_timeseries(self):
6872
"""
6973
LFP data: (sample x channel)
70-
Channels' gains (bit_volts) applied - unit: uV
74+
Data are stored as np.memmap with dtype: int16
75+
- to convert to microvolts, multiply with self.get_channel_bit_volts('lf')
7176
"""
7277
if self._lf_timeseries is None:
7378
self._lf_timeseries = self._read_bin(self.root_dir / (self.root_name + '.lf.bin'))
74-
self._lf_timeseries *= self.get_channel_bit_volts('lf')
7579
return self._lf_timeseries
7680

7781
def get_channel_bit_volts(self, band='ap'):
7882
"""
79-
Extract the AP and LF channels' int16 to microvolts
83+
Extract the recorded AP and LF channels' int16 to microvolts - no Sync (SY) channels
8084
Following the steps specified in: https://billkarsh.github.io/SpikeGLX/Support/SpikeGLX_Datafile_Tools.zip
8185
dataVolts = dataInt * Vmax / Imax / gain
8286
"""
@@ -86,11 +90,13 @@ def get_channel_bit_volts(self, band='ap'):
8690
imax = IMAX[self.apmeta.probe_model]
8791
imroTbl_data = self.apmeta.imroTbl['data']
8892
imroTbl_idx = 3
93+
chn_ind = self.apmeta.get_recording_channels_indices(exclude_sync=True)
8994

9095
elif band == 'lf':
9196
imax = IMAX[self.lfmeta.probe_model]
9297
imroTbl_data = self.lfmeta.imroTbl['data']
9398
imroTbl_idx = 4
99+
chn_ind = self.lfmeta.get_recording_channels_indices(exclude_sync=True)
94100
else:
95101
raise ValueError(f'Unsupported band: {band} - Must be "ap" or "lf"')
96102

@@ -102,25 +108,26 @@ def get_channel_bit_volts(self, band='ap'):
102108
# 3A, 3B1, 3B2 (NP 1.0)
103109
chn_gains = [c[imroTbl_idx] for c in imroTbl_data]
104110

105-
return vmax / imax / np.array(chn_gains) * 1e6 # convert to uV as well
111+
chn_gains = np.array(chn_gains)[chn_ind]
112+
113+
return vmax / imax / chn_gains * 1e6 # convert to uV as well
106114

107115
def _read_bin(self, fname):
108116
nchan = self.apmeta.meta['nSavedChans']
109117
dtype = np.dtype((np.int16, nchan))
110118
return np.memmap(fname, dtype, 'r')
111119

112-
def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32), bit_volts=1):
120+
def extract_spike_waveforms(self, spikes, channel_ind, n_wf=500, wf_win=(-32, 32)):
113121
"""
114122
:param spikes: spike times (in second) to extract waveforms
115-
:param channel: channel (name, not indices) to extract waveforms
123+
:param channel_ind: channel indices (of shankmap) to extract the waveforms from
116124
:param n_wf: number of spikes per unit to extract the waveforms
117125
:param wf_win: number of sample pre and post a spike
118-
:param bit_volts: scalar required to convert int16 values into microvolts (default of 1)
119-
:return: waveforms (sample x channel x spike)
126+
:return: waveforms (in uV) - shape: (sample x channel x spike)
120127
"""
128+
channel_bit_volts = self.get_channel_bit_volts('ap')[channel_ind]
121129

122130
data = self.ap_timeseries
123-
channel_idx = [np.where(self.apmeta.recording_channels == chn)[0][0] for chn in channel]
124131

125132
spikes = np.round(spikes * self.apmeta.meta['imSampRate']).astype(int) # convert to sample
126133
# ignore spikes at the beginning or end of raw data
@@ -130,10 +137,12 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32), b
130137
spikes = spikes[:n_wf]
131138
if len(spikes) > 0:
132139
# waveform at each spike: (sample x channel x spike)
133-
spike_wfs = np.dstack([data[int(spk + wf_win[0]):int(spk + wf_win[-1]), channel_idx] for spk in spikes])
134-
return spike_wfs * bit_volts
140+
spike_wfs = np.dstack([data[int(spk + wf_win[0]):int(spk + wf_win[-1]), channel_ind]
141+
* channel_bit_volts
142+
for spk in spikes])
143+
return spike_wfs
135144
else: # if no spike found, return NaN of size (sample x channel x 1)
136-
return np.full((len(range(*wf_win)), len(channel), 1), np.nan)
145+
return np.full((len(range(*wf_win)), len(channel_ind), 1), np.nan)
137146

138147

139148
class SpikeGLXMeta:
@@ -177,7 +186,8 @@ def __init__(self, meta_filepath):
177186
self.shankmap = self._parse_shankmap(self.meta['~snsShankMap']) if '~snsShankMap' in self.meta else None
178187
self.imroTbl = self._parse_imrotbl(self.meta['~imroTbl']) if '~imroTbl' in self.meta else None
179188

180-
self._recording_channels = None
189+
# Channels being recorded, exclude Sync channels - basically a 1-1 mapping to shankmap
190+
self.recording_channels = np.arange(len(self.imroTbl['data']))[self.get_recording_channels_indices(exclude_sync=True)]
181191

182192
@staticmethod
183193
def _parse_chanmap(raw):
@@ -208,6 +218,9 @@ def _parse_chanmap(raw):
208218
@staticmethod
209219
def _parse_shankmap(raw):
210220
"""
221+
The shankmap contains details on the shank info
222+
for each electrode sites of the sites being recorded only
223+
211224
https://github.com/billkarsh/SpikeGLX/blob/master/Markdown/UserManual.md#shank-map
212225
Parse shank map header structure. Converts:
213226
@@ -234,6 +247,10 @@ def _parse_shankmap(raw):
234247
@staticmethod
235248
def _parse_imrotbl(raw):
236249
"""
250+
The imro table contains info for all electrode sites (no sync)
251+
for a particular electrode configuration (all 384 sites)
252+
Note: not all of these 384 sites are necessarily recorded
253+
237254
https://github.com/billkarsh/SpikeGLX/blob/master/Markdown/UserManual.md#imro-per-channel-settings
238255
Parse imro tbl structure. Converts:
239256
@@ -257,8 +274,17 @@ def _parse_imrotbl(raw):
257274

258275
return res
259276

260-
@property
261-
def recording_channels(self):
277+
def get_recording_channels_indices(self, exclude_sync=False):
278+
"""
279+
The indices of recorded channels (in chanmap) with respect to the channels listed in the imro table
280+
"""
281+
recorded_chns_ind = [int(v[0]) for k, v in self.chanmap.items()
282+
if k != 'shape' and (not k.startswith('SY') if exclude_sync else True)]
283+
orig_chns_ind = self.get_original_chans()
284+
_, _, chns_ind = np.intersect1d(orig_chns_ind, recorded_chns_ind, return_indices=True)
285+
return chns_ind
286+
287+
def get_original_chans(self):
262288
"""
263289
Because you can selectively save channels, the
264290
ith channel in the file isn't necessarily the ith acquired channel.
@@ -267,22 +293,18 @@ def recording_channels(self):
267293
Credit to https://billkarsh.github.io/SpikeGLX/Support/SpikeGLX_Datafile_Tools.zip
268294
OriginalChans() function
269295
"""
270-
if self._recording_channels is None:
271-
if self.meta['snsSaveChanSubset'] == 'all':
272-
# output = int32, 0 to nSavedChans - 1
273-
self._recording_channels = np.arange(0, int(self.meta['nSavedChans']))
274-
else:
275-
# parse the snsSaveChanSubset string
276-
# split at commas
277-
chStrList = self.meta['snsSaveChanSubset'].split(sep=',')
278-
self._recording_channels = np.arange(0, 0) # creates an empty array of int32
279-
for sL in chStrList:
280-
currList = sL.split(sep=':')
281-
# each set of continuous channels specified by chan1:chan2 inclusive
282-
newChans = np.arange(int(currList[0]), int(currList[min(1, len(currList))]) + 1)
283-
284-
self._recording_channels = np.append(self._recording_channels, newChans)
285-
return self._recording_channels
296+
if self.meta['snsSaveChanSubset'] == 'all':
297+
# output = int32, 0 to nSavedChans - 1
298+
channels = np.arange(0, int(self.meta['nSavedChans']))
299+
else:
300+
# parse the channel list self.meta['snsSaveChanSubset']
301+
channels = np.arange(0) # empty array
302+
for channel_range in self.meta['snsSaveChanSubset'].split(','):
303+
# a block of contiguous channels specified as chan or chan1:chan2 inclusive
304+
ix = [int(r) for r in channel_range.split(':')]
305+
assert len(ix) in (1, 2), f"Invalid channel range spec '{channel_range}'"
306+
channels = np.append(np.r_[ix[0]:ix[-1] + 1])
307+
return channels
286308

287309

288310
# ============= HELPER FUNCTIONS =============

images/attached_ephys_element.svg

Lines changed: 1 addition & 1 deletion
Loading

0 commit comments

Comments
 (0)