@@ -440,8 +440,6 @@ def key_source(self):
440
440
return Clustering ()
441
441
442
442
def make (self , key ):
443
- units = {u ['unit' ]: u for u in (Clustering .Unit & key ).fetch (as_dict = True , order_by = 'unit' )}
444
-
445
443
root_dir = pathlib .Path (get_ephys_root_data_dir ())
446
444
ks_dir = root_dir / (ClusteringTask & key ).fetch1 ('clustering_output_dir' )
447
445
ks = kilosort .Kilosort (ks_dir )
@@ -454,6 +452,9 @@ def make(self, key):
454
452
455
453
is_qc = (Clustering & key ).fetch1 ('quality_control' )
456
454
455
+ # Get all units
456
+ units = {u ['unit' ]: u for u in (Clustering .Unit & key ).fetch (as_dict = True , order_by = 'unit' )}
457
+
457
458
unit_waveforms , unit_peak_waveforms = [], []
458
459
if is_qc :
459
460
unit_wfs = np .load (ks_dir / 'mean_waveforms.npy' ) # unit x channel x sample
@@ -473,8 +474,8 @@ def make(self, key):
473
474
loaded_oe = openephys .OpenEphys (sess_dir )
474
475
npx_recording = loaded_oe .probes [probe_sn ]
475
476
476
- for unit_no , unit_dict in units .items ():
477
- spks = ( Clustering . Unit & unit_dict ). fetch1 ( 'unit_spike_times' )
477
+ for unit_dict in units .values ():
478
+ spks = unit_dict [ 'spike_times' ]
478
479
wfs = npx_recording .extract_spike_waveforms (spks , ks .data ['channel_map' ]) # (sample x channel x spike)
479
480
wfs = wfs .transpose ((1 , 2 , 0 )) # (channel x spike x sample)
480
481
for chn , chn_wf in zip (ks .data ['channel_map' ], wfs ):
0 commit comments