@@ -558,15 +558,18 @@ def make(self, key):
558
558
559
559
unit_waveforms , unit_peak_waveforms = [], []
560
560
if is_qc :
561
- unit_wfs = np .load (ks_dir / 'mean_waveforms.npy' ) # unit x channel x sample
562
- for unit_no , unit_wf in zip (ks .data ['cluster_ids' ], unit_wfs ):
561
+ unit_waveforms = np .load (ks_dir / 'mean_waveforms.npy' ) # unit x channel x sample
562
+ for unit_no , unit_waveform in zip (ks .data ['cluster_ids' ], unit_waveforms ):
563
563
if unit_no in units :
564
- for chn , chn_wf in zip (ks .data ['channel_map' ], unit_wf ):
565
- unit_waveforms .append ({** units [unit_no ], ** channel2electrodes [chn ],
566
- 'waveform_mean' : chn_wf })
567
- if channel2electrodes [chn ]['electrode' ] == units [unit_no ]['electrode' ]:
568
- unit_peak_waveforms .append ({** units [unit_no ],
569
- 'peak_chn_waveform_mean' : chn_wf })
564
+ for channel , channel_waveform in zip (ks .data ['channel_map' ],
565
+ unit_waveform ):
566
+ unit_waveforms .append ({
567
+ ** units [unit_no ], ** channel2electrodes [channel ],
568
+ 'waveform_mean' : channel_waveform })
569
+ if channel2electrodes [channel ]['electrode' ] == units [unit_no ]['electrode' ]:
570
+ unit_peak_waveforms .append ({
571
+ ** units [unit_no ],
572
+ 'peak_chn_waveform_mean' : channel_waveform })
570
573
else :
571
574
if acq_software == 'SpikeGLX' :
572
575
npx_meta_fp = root_dir / (EphysRecording .EphysFile & key
@@ -578,16 +581,18 @@ def make(self, key):
578
581
npx_recording = loaded_oe .probes [probe_sn ]
579
582
580
583
for unit_dict in units .values ():
581
- spks = unit_dict ['spike_times' ]
582
- wfs = npx_recording .extract_spike_waveforms (spks , ks .data ['channel_map' ]) # (sample x channel x spike)
583
- wfs = wfs .transpose ((1 , 2 , 0 )) # (channel x spike x sample)
584
- for chn , chn_wf in zip (ks .data ['channel_map' ], wfs ):
585
- unit_waveforms .append ({** unit_dict , ** channel2electrodes [chn ],
586
- 'waveform_mean' : chn_wf .mean (axis = 0 ),
587
- 'waveforms' : chn_wf })
588
- if channel2electrodes [chn ]['electrode' ] == unit_dict ['electrode' ]:
584
+ spikes = unit_dict ['spike_times' ]
585
+ waveforms = npx_recording .extract_spike_waveforms (
586
+ spikes , ks .data ['channel_map' ]) # (sample x channel x spike)
587
+ waveforms = waveforms .transpose ((1 , 2 , 0 )) # (channel x spike x sample)
588
+ for channel , channel_waveform in zip (ks .data ['channel_map' ], waveforms ):
589
+ unit_waveforms .append ({** unit_dict , ** channel2electrodes [channel ],
590
+ 'waveform_mean' : channel_waveform .mean (axis = 0 ),
591
+ 'waveforms' : channel_waveform })
592
+ if channel2electrodes [channel ]['electrode' ] == unit_dict ['electrode' ]:
589
593
unit_peak_waveforms .append ({
590
- ** unit_dict , 'peak_chn_waveform_mean' : chn_wf .mean (axis = 0 )})
594
+ ** unit_dict ,
595
+ 'peak_chn_waveform_mean' : channel_waveform .mean (axis = 0 )})
591
596
592
597
self .insert (unit_peak_waveforms , ignore_extra_fields = True )
593
598
self .Electrode .insert (unit_waveforms , ignore_extra_fields = True )
0 commit comments