Skip to content

Commit dba0a48

Browse files
author
Thinh Nguyen
committed
tweaks to LFP and waveform ingestion - do in small batches to mitigate memory issue
1 parent dcf8906 commit dba0a48

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

element_array_ephys/ephys.py

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def make(self, key):
267267
acq_software, probe_sn = (EphysRecording
268268
* ProbeInsertion & key).fetch1('acq_software', 'probe')
269269

270+
electrode_keys, lfp = [], []
271+
270272
if acq_software == 'SpikeGLX':
271273
spikeglx_meta_fp = (EphysRecording.EphysFile
272274
& key & 'file_path LIKE "%.ap.meta"').fetch1('file_path')
@@ -289,19 +291,13 @@ def make(self, key):
289291
q_electrodes = (probe.ProbeType.Electrode
290292
* probe.ElectrodeConfig.Electrode
291293
* EphysRecording & key)
292-
electrodes = []
293294
for recorded_site in lfp_channel_ind:
294295
shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap['data'][recorded_site]
295-
electrodes.append((q_electrodes
296+
electrode_keys.append((q_electrodes
296297
& {'shank': shank,
297298
'shank_col': shank_col,
298299
'shank_row': shank_row}).fetch1('KEY'))
299300

300-
channel_lfp = list(zip(electrodes, lfp))
301-
self.Electrode().insert((
302-
{**key, **electrode, 'lfp': d}
303-
for electrode, d in channel_lfp), ignore_extra_fields=True)
304-
305301
elif acq_software == 'OpenEphys':
306302
sess_dir = pathlib.Path(get_session_directory(key))
307303
loaded_oe = openephys.OpenEphys(sess_dir)
@@ -322,19 +318,17 @@ def make(self, key):
322318
q_electrodes = (probe.ProbeType.Electrode
323319
* probe.ElectrodeConfig.Electrode
324320
* EphysRecording & key)
325-
electrodes = []
326321
for channel_idx in np.array(oe_probe.lfp_meta['channels_ids'])[lfp_channel_ind]:
327-
electrodes.append((q_electrodes & {'electrode': channel_idx}).fetch1('KEY'))
328-
329-
channel_lfp = list(zip(electrodes, lfp))
330-
self.Electrode().insert((
331-
{**key, **electrode, 'lfp': d}
332-
for electrode, d in channel_lfp), ignore_extra_fields=True)
322+
electrode_keys.append((q_electrodes & {'electrode': channel_idx}).fetch1('KEY'))
333323

334324
else:
335325
raise NotImplementedError(f'LFP extraction from acquisition software'
336326
f' of type {acq_software} is not yet implemented')
337327

328+
# single insert in loop to mitigate potential memory issue
329+
for electrode_key, lfp_trace in zip(electrode_keys, lfp):
330+
self.Electrode.insert1({**key, **electrode_key, 'lfp': lfp_trace})
331+
338332

339333
# ------------ Clustering --------------
340334

@@ -585,20 +579,25 @@ def make(self, key):
585579
units = {u['unit']: u for u in (CuratedClustering.Unit & key).fetch(
586580
as_dict=True, order_by='unit')}
587581

588-
unit_waveforms, unit_peak_waveforms = [], []
589582
if is_qc:
590583
unit_waveforms = np.load(ks_dir / 'mean_waveforms.npy') # unit x channel x sample
591-
for unit_no, unit_waveform in zip(ks.data['cluster_ids'], unit_waveforms):
592-
if unit_no in units:
593-
for channel, channel_waveform in zip(ks.data['channel_map'],
594-
unit_waveform):
595-
unit_waveforms.append({
596-
**units[unit_no], **channel2electrodes[channel],
597-
'waveform_mean': channel_waveform})
598-
if channel2electrodes[channel]['electrode'] == units[unit_no]['electrode']:
599-
unit_peak_waveforms.append({
600-
**units[unit_no],
601-
'peak_chn_waveform_mean': channel_waveform})
584+
585+
def yield_unit_waveforms():
586+
for unit_no, unit_waveform in zip(ks.data['cluster_ids'], unit_waveforms):
587+
unit_peak_waveform = {}
588+
unit_electrode_waveforms = []
589+
if unit_no in units:
590+
for channel, channel_waveform in zip(ks.data['channel_map'],
591+
unit_waveform):
592+
unit_electrode_waveforms.append({
593+
**units[unit_no], **channel2electrodes[channel],
594+
'waveform_mean': channel_waveform})
595+
if channel2electrodes[channel]['electrode'] == units[unit_no]['electrode']:
596+
unit_peak_waveform = {
597+
**units[unit_no],
598+
'peak_chn_waveform_mean': channel_waveform}
599+
yield unit_peak_waveform, unit_electrode_waveforms
600+
602601
else:
603602
if acq_software == 'SpikeGLX':
604603
ephys_root_dir = get_ephys_root_data_dir()
@@ -610,22 +609,31 @@ def make(self, key):
610609
loaded_oe = openephys.OpenEphys(sess_dir)
611610
npx_recording = loaded_oe.probes[probe_sn]
612611

613-
for unit_dict in units.values():
614-
spikes = unit_dict['spike_times']
615-
waveforms = npx_recording.extract_spike_waveforms(
616-
spikes, ks.data['channel_map']) # (sample x channel x spike)
617-
waveforms = waveforms.transpose((1, 2, 0)) # (channel x spike x sample)
618-
for channel, channel_waveform in zip(ks.data['channel_map'], waveforms):
619-
unit_waveforms.append({**unit_dict, **channel2electrodes[channel],
620-
'waveform_mean': channel_waveform.mean(axis=0),
621-
'waveforms': channel_waveform})
622-
if channel2electrodes[channel]['electrode'] == unit_dict['electrode']:
623-
unit_peak_waveforms.append({
624-
**unit_dict,
625-
'peak_chn_waveform_mean': channel_waveform.mean(axis=0)})
626-
627-
self.insert(unit_peak_waveforms, ignore_extra_fields=True)
628-
self.Electrode.insert(unit_waveforms, ignore_extra_fields=True)
612+
def yield_unit_waveforms():
613+
for unit_dict in units.values():
614+
unit_peak_waveform = {}
615+
unit_electrode_waveforms = []
616+
617+
spikes = unit_dict['spike_times']
618+
waveforms = npx_recording.extract_spike_waveforms(
619+
spikes, ks.data['channel_map']) # (sample x channel x spike)
620+
waveforms = waveforms.transpose((1, 2, 0)) # (channel x spike x sample)
621+
for channel, channel_waveform in zip(ks.data['channel_map'], waveforms):
622+
unit_electrode_waveforms.append({
623+
**unit_dict, **channel2electrodes[channel],
624+
'waveform_mean': channel_waveform.mean(axis=0),
625+
'waveforms': channel_waveform})
626+
if channel2electrodes[channel]['electrode'] == unit_dict['electrode']:
627+
unit_peak_waveform = {
628+
**unit_dict,
629+
'peak_chn_waveform_mean': channel_waveform.mean(axis=0)}
630+
631+
yield unit_peak_waveform, unit_electrode_waveforms
632+
633+
# insert waveform on a per-unit basis to mitigate potential memory issue
634+
for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
635+
self.insert1(unit_peak_waveform, ignore_extra_fields=True)
636+
self.Electrode.insert(unit_electrode_waveforms, ignore_extra_fields=True)
629637

630638

631639
# ----------- Quality Control ----------

0 commit comments

Comments
 (0)