@@ -1040,51 +1040,47 @@ def make(self, key):
1040
1040
).fetch1 ("clustering_method" , "clustering_output_dir" )
1041
1041
output_dir = find_full_path (get_ephys_root_data_dir (), output_dir )
1042
1042
1043
- # Get sorter method and create output directory.
1044
- sorter_name = (
1045
- "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
1043
+ # Get channel and electrode-site mapping
1044
+ electrode_query = (
1045
+ (EphysRecording .Channel & key )
1046
+ .proj (..., "-channel_name" )
1046
1047
)
1047
- waveform_dir = output_dir / sorter_name / "waveform"
1048
- sorting_dir = output_dir / sorter_name / "spike_sorting"
1048
+ channel2electrode_map = electrode_query .fetch (as_dict = True )
1049
+ channel2electrode_map : dict [int , dict ] = {
1050
+ chn .pop ("channel_idx" ): chn for chn in channel2electrode_map
1051
+ }
1049
1052
1050
- if waveform_dir .exists (): # read from spikeinterface outputs
1051
- we : si .WaveformExtractor = si .load_waveforms (
1052
- waveform_dir , with_recording = False
1053
- )
1053
+ # Get sorter method and create output directory.
1054
+ sorter_name = clustering_method .replace ("." , "_" )
1055
+ si_waveform_dir = output_dir / sorter_name / "waveform"
1056
+ si_sorting_dir = output_dir / sorter_name / "spike_sorting"
1057
+
1058
+ if si_waveform_dir .exists ():
1059
+
1060
+ # Read from spikeinterface outputs
1061
+ we : si .WaveformExtractor = si .load_waveforms (si_waveform_dir , with_recording = False )
1054
1062
si_sorting : si .sorters .BaseSorter = si .load_extractor (
1055
- sorting_dir / "si_sorting.pkl"
1063
+ si_sorting_dir / "si_sorting.pkl"
1056
1064
)
1057
1065
1058
- unit_peak_channel_map : dict [int , int ] = si .get_template_extremum_channel (
1059
- we , outputs = "index "
1060
- ) # {unit: peak_channel_index }
1066
+ unit_peak_channel : dict [int , int ] = si .get_template_extremum_channel (
1067
+ we , outputs = "id "
1068
+ ) # {unit: peak_channel_id }
1061
1069
1062
1070
spike_count_dict : dict [int , int ] = si_sorting .count_num_spikes_per_unit ()
1063
1071
# {unit: spike_count}
1064
1072
1065
- spikes = si_sorting .to_spike_vector (
1066
- extremum_channel_inds = unit_peak_channel_map
1067
- )
1068
-
1069
- # Get electrode & channel info
1070
- electrode_config_key = (
1071
- EphysRecording * probe .ElectrodeConfig & key
1072
- ).fetch1 ("KEY" )
1073
-
1074
- electrode_query = (
1075
- probe .ProbeType .Electrode * probe .ElectrodeConfig .Electrode
1076
- & electrode_config_key
1077
- ) * (dj .U ("electrode" , "channel_idx" ) & EphysRecording .Channel )
1073
+ spikes = si_sorting .to_spike_vector ()
1078
1074
1079
- channel_info = electrode_query . fetch ( as_dict = True , order_by = "channel_idx" )
1080
- channel_info : dict [ int , dict ] = {
1081
- ch . pop ( "channel_idx" ): ch for ch in channel_info
1075
+ # reorder channel2electrode_map according to recording channel ids
1076
+ channel2electrode_map = {
1077
+ chn_id : channel2electrode_map [ int ( chn_id )] for chn_id in we . channel_ids
1082
1078
}
1083
1079
1084
1080
# Get unit id to quality label mapping
1085
1081
try :
1086
1082
cluster_quality_label_map = pd .read_csv (
1087
- sorting_dir / "sorter_output" / "cluster_KSLabel.tsv" ,
1083
+ si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv" ,
1088
1084
delimiter = "\t " ,
1089
1085
)
1090
1086
except FileNotFoundError :
@@ -1099,15 +1095,15 @@ def make(self, key):
1099
1095
# Get electrode where peak unit activity is recorded
1100
1096
peak_electrode_ind = np .array (
1101
1097
[
1102
- channel_info [ unit_peak_channel_map [unit_id ]]["electrode" ]
1098
+ channel2electrode_map [ unit_peak_channel [unit_id ]]["electrode" ]
1103
1099
for unit_id in si_sorting .unit_ids
1104
1100
]
1105
1101
)
1106
1102
1107
1103
# Get channel depth
1108
1104
channel_depth_ind = np .array (
1109
1105
[
1110
- channel_info [ unit_peak_channel_map [unit_id ]]["y_coord" ]
1106
+ channel2electrode_map [ unit_peak_channel [unit_id ]]["y_coord" ]
1111
1107
for unit_id in si_sorting .unit_ids
1112
1108
]
1113
1109
)
@@ -1132,7 +1128,7 @@ def make(self, key):
1132
1128
units .append (
1133
1129
{
1134
1130
** key ,
1135
- ** channel_info [ unit_peak_channel_map [unit_id ]],
1131
+ ** channel2electrode_map [ unit_peak_channel [unit_id ]],
1136
1132
"unit" : unit_id ,
1137
1133
"cluster_quality_label" : cluster_quality_label_map .get (
1138
1134
unit_id , "n.a."
@@ -1143,10 +1139,10 @@ def make(self, key):
1143
1139
"spike_count" : spike_count_dict [unit_id ],
1144
1140
"spike_sites" : new_spikes ["electrode" ][
1145
1141
new_spikes ["unit_index" ] == unit_id
1146
- ],
1142
+ ],
1147
1143
"spike_depths" : new_spikes ["depth" ][
1148
1144
new_spikes ["unit_index" ] == unit_id
1149
- ],
1145
+ ],
1150
1146
}
1151
1147
)
1152
1148
@@ -1184,20 +1180,10 @@ def make(self, key):
1184
1180
spike_times = kilosort_dataset .data [spike_time_key ]
1185
1181
kilosort_dataset .extract_spike_depths ()
1186
1182
1187
- # Get channel and electrode-site mapping
1188
- channel_info = (
1189
- (EphysRecording .Channel & key )
1190
- .proj (..., "-channel_name" )
1191
- .fetch (as_dict = True , order_by = "channel_idx" )
1192
- )
1193
- channel_info : dict [int , dict ] = {
1194
- ch .pop ("channel_idx" ): ch for ch in channel_info
1195
- } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
1196
-
1197
1183
# -- Spike-sites and Spike-depths --
1198
1184
spike_sites = np .array (
1199
1185
[
1200
- channel_info [s ]["electrode" ]
1186
+ channel2electrode_map [s ]["electrode" ]
1201
1187
for s in kilosort_dataset .data ["spike_sites" ]
1202
1188
]
1203
1189
)
@@ -1219,7 +1205,7 @@ def make(self, key):
1219
1205
** key ,
1220
1206
"unit" : unit ,
1221
1207
"cluster_quality_label" : unit_lbl ,
1222
- ** channel_info [unit_channel ],
1208
+ ** channel2electrode_map [unit_channel ],
1223
1209
"spike_times" : unit_spike_times ,
1224
1210
"spike_count" : spike_count ,
1225
1211
"spike_sites" : spike_sites [
@@ -1292,33 +1278,31 @@ def make(self, key):
1292
1278
ClusteringTask * ClusteringParamSet & key
1293
1279
).fetch1 ("clustering_method" , "clustering_output_dir" )
1294
1280
output_dir = find_full_path (get_ephys_root_data_dir (), output_dir )
1295
- sorter_name = (
1296
- "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
1297
- )
1281
+ sorter_name = clustering_method .replace ("." , "_" )
1298
1282
1299
1283
# Get channel and electrode-site mapping
1300
- channel_info = (
1284
+ electrode_query = (
1301
1285
(EphysRecording .Channel & key )
1302
1286
.proj (..., "-channel_name" )
1303
- .fetch (as_dict = True , order_by = "channel_idx" )
1304
1287
)
1305
- channel_info : dict [int , dict ] = {
1306
- ch .pop ("channel_idx" ): ch for ch in channel_info
1307
- } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
1288
+ channel2electrode_map = electrode_query .fetch (as_dict = True )
1289
+ channel2electrode_map : dict [int , dict ] = {
1290
+ chn .pop ("channel_idx" ): chn for chn in channel2electrode_map
1291
+ }
1308
1292
1309
- if (
1310
- output_dir / sorter_name / "waveform"
1311
- ).exists (): # read from spikeinterface outputs
1293
+ si_waveform_dir = output_dir / sorter_name / "waveform"
1294
+ if si_waveform_dir .exists (): # read from spikeinterface outputs
1295
+ we : si .WaveformExtractor = si .load_waveforms (si_waveform_dir , with_recording = False )
1296
+ unit_id_to_peak_channel_map : dict [
1297
+ int , np .ndarray
1298
+ ] = si .ChannelSparsity .from_best_channels (
1299
+ we , 1 , peak_sign = "neg"
1300
+ ).unit_id_to_channel_indices # {unit: peak_channel_index}
1312
1301
1313
- waveform_dir = output_dir / sorter_name / "waveform"
1314
- we : si .WaveformExtractor = si .load_waveforms (
1315
- waveform_dir , with_recording = False
1316
- )
1317
- unit_id_to_peak_channel_map : dict [int , np .ndarray ] = (
1318
- si .ChannelSparsity .from_best_channels (
1319
- we , 1 , peak_sign = "neg"
1320
- ).unit_id_to_channel_indices
1321
- ) # {unit: peak_channel_index}
1302
+ # reorder channel2electrode_map according to recording channel ids
1303
+ channel2electrode_map = {
1304
+ chn_id : channel2electrode_map [int (chn_id )] for chn_id in we .channel_ids
1305
+ }
1322
1306
1323
1307
# Get mean waveform for each unit from all channels
1324
1308
mean_waveforms = we .get_all_templates (
@@ -1329,30 +1313,32 @@ def make(self, key):
1329
1313
unit_electrode_waveforms = []
1330
1314
1331
1315
for unit in (CuratedClustering .Unit & key ).fetch ("KEY" , order_by = "unit" ):
1316
+ unit_waveforms = we .get_template (
1317
+ unit_id = unit ["unit" ], mode = "average" , force_dense = True
1318
+ ) # (sample x channel)
1319
+ peak_chn_idx = list (we .channel_ids ).index (
1320
+ unit_id_to_peak_channel_map [unit ["unit" ]][0 ]
1321
+ )
1332
1322
unit_peak_waveform .append (
1333
1323
{
1334
1324
** unit ,
1335
- "peak_electrode_waveform" : we .get_template (
1336
- unit_id = unit ["unit" ], mode = "average" , force_dense = True
1337
- )[:, unit_id_to_peak_channel_map [unit ["unit" ]][0 ]],
1325
+ "peak_electrode_waveform" : unit_waveforms [:, peak_chn_idx ],
1338
1326
}
1339
1327
)
1340
-
1341
1328
unit_electrode_waveforms .extend (
1342
1329
[
1343
1330
{
1344
1331
** unit ,
1345
- ** channel_info [c ],
1346
- "waveform_mean" : mean_waveforms [unit ["unit" ] - 1 , :, c ],
1332
+ ** channel2electrode_map [c ],
1333
+ "waveform_mean" : mean_waveforms [unit ["unit" ] - 1 , :, c_idx ],
1347
1334
}
1348
- for c in channel_info
1335
+ for c_idx , c in enumerate ( channel2electrode_map )
1349
1336
]
1350
1337
)
1351
1338
1352
1339
self .insert1 (key )
1353
1340
self .PeakWaveform .insert (unit_peak_waveform )
1354
1341
self .Waveform .insert (unit_electrode_waveforms )
1355
-
1356
1342
else :
1357
1343
kilosort_dataset = kilosort .Kilosort (output_dir )
1358
1344
@@ -1390,12 +1376,12 @@ def yield_unit_waveforms():
1390
1376
unit_electrode_waveforms .append (
1391
1377
{
1392
1378
** units [unit_no ],
1393
- ** channel_info [channel ],
1379
+ ** channel2electrode_map [channel ],
1394
1380
"waveform_mean" : channel_waveform ,
1395
1381
}
1396
1382
)
1397
1383
if (
1398
- channel_info [channel ]["electrode" ]
1384
+ channel2electrode_map [channel ]["electrode" ]
1399
1385
== units [unit_no ]["electrode" ]
1400
1386
):
1401
1387
unit_peak_waveform = {
@@ -1405,7 +1391,6 @@ def yield_unit_waveforms():
1405
1391
yield unit_peak_waveform , unit_electrode_waveforms
1406
1392
1407
1393
# Spike interface mean and peak waveform extraction from we object
1408
-
1409
1394
elif len (waveforms_folder ) > 0 & (waveforms_folder [0 ]).exists ():
1410
1395
we_kilosort = si .load_waveforms (waveforms_folder [0 ].parent )
1411
1396
unit_templates = we_kilosort .get_all_templates ()
@@ -1432,12 +1417,12 @@ def yield_unit_waveforms():
1432
1417
unit_electrode_waveforms .append (
1433
1418
{
1434
1419
** units [unit_no ],
1435
- ** channel_info [channel ],
1420
+ ** channel2electrode_map [channel ],
1436
1421
"waveform_mean" : channel_waveform ,
1437
1422
}
1438
1423
)
1439
1424
if (
1440
- channel_info [channel ]["electrode" ]
1425
+ channel2electrode_map [channel ]["electrode" ]
1441
1426
== units [unit_no ]["electrode" ]
1442
1427
):
1443
1428
unit_peak_waveform = {
@@ -1506,13 +1491,13 @@ def yield_unit_waveforms():
1506
1491
unit_electrode_waveforms .append (
1507
1492
{
1508
1493
** unit_dict ,
1509
- ** channel_info [channel ],
1494
+ ** channel2electrode_map [channel ],
1510
1495
"waveform_mean" : channel_waveform .mean (axis = 0 ),
1511
1496
"waveforms" : channel_waveform ,
1512
1497
}
1513
1498
)
1514
1499
if (
1515
- channel_info [channel ]["electrode" ]
1500
+ channel2electrode_map [channel ]["electrode" ]
1516
1501
== unit_dict ["electrode" ]
1517
1502
):
1518
1503
unit_peak_waveform = {
@@ -1630,12 +1615,15 @@ def make(self, key):
1630
1615
ClusteringTask * ClusteringParamSet & key
1631
1616
).fetch1 ("clustering_method" , "clustering_output_dir" )
1632
1617
output_dir = find_full_path (get_ephys_root_data_dir (), output_dir )
1633
- sorter_name = (
1634
- "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
1635
- )
1636
- metric_fp = output_dir / sorter_name / "metrics" / "metrics.csv"
1637
- if not metric_fp .exists ():
1638
- raise FileNotFoundError (f"QC metrics file not found: { metric_fp } " )
1618
+ sorter_name = clustering_method .replace ("." , "_" )
1619
+
1620
+ # find metric_fp
1621
+ for metric_fp in [output_dir / "metrics.csv" , output_dir / sorter_name / "metrics" / "metrics.csv" ]:
1622
+ if metric_fp .exists ():
1623
+ break
1624
+ else :
1625
+ raise FileNotFoundError (f"QC metrics file not found in: { output_dir } " )
1626
+
1639
1627
metrics_df = pd .read_csv (metric_fp )
1640
1628
1641
1629
# Conform the dataframe to match the table definition
0 commit comments