@@ -1040,18 +1040,16 @@ def make(self, key):
1040
1040
si_waveform_dir = output_dir / sorter_name / "waveform"
1041
1041
si_sorting_dir = output_dir / sorter_name / "spike_sorting"
1042
1042
1043
- if si_waveform_dir .exists ():
1044
-
1045
- # Read from spikeinterface outputs
1043
+ if si_waveform_dir .exists (): # Read from spikeinterface outputs
1046
1044
we : si .WaveformExtractor = si .load_waveforms (
1047
1045
si_waveform_dir , with_recording = False
1048
1046
)
1049
1047
si_sorting : si .sorters .BaseSorter = si .load_extractor (
1050
- si_sorting_dir / "si_sorting.pkl"
1048
+ si_sorting_dir / "si_sorting.pkl" , base_folder = output_dir
1051
1049
)
1052
1050
1053
1051
unit_peak_channel : dict [int , int ] = si .get_template_extremum_channel (
1054
- we , outputs = "id "
1052
+ we , outputs = "index "
1055
1053
) # {unit: peak_channel_id}
1056
1054
1057
1055
spike_count_dict : dict [int , int ] = si_sorting .count_num_spikes_per_unit ()
@@ -1061,7 +1059,8 @@ def make(self, key):
1061
1059
1062
1060
# reorder channel2electrode_map according to recording channel ids
1063
1061
channel2electrode_map = {
1064
- chn_id : channel2electrode_map [int (chn_id )] for chn_id in we .channel_ids
1062
+ chn_idx : channel2electrode_map [chn_idx ]
1063
+ for chn_idx in we .channel_ids_to_indices (we .channel_ids )
1065
1064
}
1066
1065
1067
1066
# Get unit id to quality label mapping
@@ -1090,7 +1089,7 @@ def make(self, key):
1090
1089
# Get channel depth
1091
1090
channel_depth_ind = np .array (
1092
1091
[
1093
- channel2electrode_map [unit_peak_channel [unit_id ]]["y_coord" ]
1092
+ we . get_probe (). contact_positions [unit_peak_channel [unit_id ]][1 ]
1094
1093
for unit_id in si_sorting .unit_ids
1095
1094
]
1096
1095
)
@@ -1132,7 +1131,6 @@ def make(self, key):
1132
1131
],
1133
1132
}
1134
1133
)
1135
-
1136
1134
else : # read from kilosort outputs
1137
1135
kilosort_dataset = kilosort .Kilosort (output_dir )
1138
1136
acq_software , sample_rate = (EphysRecording & key ).fetch1 (
@@ -1286,46 +1284,38 @@ def make(self, key):
1286
1284
) # {unit: peak_channel_index}
1287
1285
1288
1286
# reorder channel2electrode_map according to recording channel ids
1287
+ channel_indices = we .channel_ids_to_indices (we .channel_ids ).tolist ()
1289
1288
channel2electrode_map = {
1290
- chn_id : channel2electrode_map [int ( chn_id ) ] for chn_id in we . channel_ids
1289
+ chn_idx : channel2electrode_map [chn_idx ] for chn_idx in channel_indices
1291
1290
}
1292
1291
1293
- # Get mean waveform for each unit from all channels
1294
- mean_waveforms = we .get_all_templates (
1295
- mode = "average"
1296
- ) # (unit x sample x channel)
1297
-
1298
- unit_peak_waveform = []
1299
- unit_electrode_waveforms = []
1300
-
1301
- for unit in (CuratedClustering .Unit & key ).fetch ("KEY" , order_by = "unit" ):
1302
- unit_waveforms = we .get_template (
1303
- unit_id = unit ["unit" ], mode = "average" , force_dense = True
1304
- ) # (sample x channel)
1305
- peak_chn_idx = list (we .channel_ids ).index (
1306
- unit_id_to_peak_channel_map [unit ["unit" ]][0 ]
1307
- )
1308
- unit_peak_waveform .append (
1309
- {
1292
+ def yield_unit_waveforms ():
1293
+ for unit in (CuratedClustering .Unit & key ).fetch (
1294
+ "KEY" , order_by = "unit"
1295
+ ):
1296
+ # Get mean waveform for this unit from all channels - (sample x channel)
1297
+ unit_waveforms = we .get_template (
1298
+ unit_id = unit ["unit" ], mode = "average" , force_dense = True
1299
+ )
1300
+ peak_chn_idx = channel_indices .index (
1301
+ unit_id_to_peak_channel_map [unit ["unit" ]][0 ]
1302
+ )
1303
+ unit_peak_waveform = {
1310
1304
** unit ,
1311
1305
"peak_electrode_waveform" : unit_waveforms [:, peak_chn_idx ],
1312
1306
}
1313
- )
1314
- unit_electrode_waveforms .extend (
1315
- [
1307
+
1308
+ unit_electrode_waveforms = [
1316
1309
{
1317
1310
** unit ,
1318
- ** channel2electrode_map [c ],
1319
- "waveform_mean" : mean_waveforms [ unit [ "unit" ] - 1 , :, c_idx ],
1311
+ ** channel2electrode_map [chn_idx ],
1312
+ "waveform_mean" : unit_waveforms [ :, chn_idx ],
1320
1313
}
1321
- for c_idx , c in enumerate ( channel2electrode_map )
1314
+ for chn_idx in channel_indices
1322
1315
]
1323
- )
1324
1316
1325
- self .insert1 (key )
1326
- self .PeakWaveform .insert (unit_peak_waveform )
1327
- self .Waveform .insert (unit_electrode_waveforms )
1328
- else :
1317
+ yield unit_peak_waveform , unit_electrode_waveforms
1318
+ else : # read from kilosort outputs
1329
1319
kilosort_dataset = kilosort .Kilosort (output_dir )
1330
1320
1331
1321
acq_software , probe_serial_number = (
@@ -1340,10 +1330,6 @@ def make(self, key):
1340
1330
)
1341
1331
}
1342
1332
1343
- waveforms_folder = [
1344
- f for f in output_dir .parent .rglob (r"*/waveforms*" ) if f .is_dir ()
1345
- ]
1346
-
1347
1333
if (output_dir / "mean_waveforms.npy" ).exists ():
1348
1334
unit_waveforms = np .load (
1349
1335
output_dir / "mean_waveforms.npy"
@@ -1376,75 +1362,6 @@ def yield_unit_waveforms():
1376
1362
}
1377
1363
yield unit_peak_waveform , unit_electrode_waveforms
1378
1364
1379
- # Spike interface mean and peak waveform extraction from we object
1380
-
1381
- elif len (waveforms_folder ) > 0 & (waveforms_folder [0 ]).exists ():
1382
- we_kilosort = si .load_waveforms (waveforms_folder [0 ].parent )
1383
- unit_templates = we_kilosort .get_all_templates ()
1384
- unit_waveforms = np .reshape (
1385
- unit_templates ,
1386
- (
1387
- unit_templates .shape [1 ],
1388
- unit_templates .shape [3 ],
1389
- unit_templates .shape [2 ],
1390
- ),
1391
- )
1392
-
1393
- # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms)
1394
- def yield_unit_waveforms ():
1395
- for unit_no , unit_waveform in zip (
1396
- kilosort_dataset .data ["cluster_ids" ], unit_waveforms
1397
- ):
1398
- unit_peak_waveform = {}
1399
- unit_electrode_waveforms = []
1400
- if unit_no in units :
1401
- for channel , channel_waveform in zip (
1402
- kilosort_dataset .data ["channel_map" ], unit_waveform
1403
- ):
1404
- unit_electrode_waveforms .append (
1405
- {
1406
- ** units [unit_no ],
1407
- ** channel2electrode_map [channel ],
1408
- "waveform_mean" : channel_waveform ,
1409
- }
1410
- )
1411
- if (
1412
- channel2electrode_map [channel ]["electrode" ]
1413
- == units [unit_no ]["electrode" ]
1414
- ):
1415
- unit_peak_waveform = {
1416
- ** units [unit_no ],
1417
- "peak_electrode_waveform" : channel_waveform ,
1418
- }
1419
- yield unit_peak_waveform , unit_electrode_waveforms
1420
-
1421
- # Approach not using spike interface templates (ie. taking mean of each unit waveform)
1422
- # def yield_unit_waveforms():
1423
- # for unit_id in we_kilosort.unit_ids:
1424
- # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0)
1425
- # unit_peak_waveform = {}
1426
- # unit_electrode_waveforms = []
1427
- # if unit_id in units:
1428
- # for channel, channel_waveform in zip(
1429
- # kilosort_dataset.data["channel_map"], unit_waveform
1430
- # ):
1431
- # unit_electrode_waveforms.append(
1432
- # {
1433
- # **units[unit_id],
1434
- # **channel2electrodes[channel],
1435
- # "waveform_mean": channel_waveform,
1436
- # }
1437
- # )
1438
- # if (
1439
- # channel2electrodes[channel]["electrode"]
1440
- # == units[unit_id]["electrode"]
1441
- # ):
1442
- # unit_peak_waveform = {
1443
- # **units[unit_id],
1444
- # "peak_electrode_waveform": channel_waveform,
1445
- # }
1446
- # yield unit_peak_waveform, unit_electrode_waveforms
1447
-
1448
1365
else :
1449
1366
if acq_software == "SpikeGLX" :
1450
1367
spikeglx_meta_filepath = get_spikeglx_meta_filepath (key )
0 commit comments