@@ -1037,98 +1037,70 @@ def make(self, key):
 
         # Get sorter method and create output directory.
         sorter_name = clustering_method.replace(".", "_")
-        si_waveform_dir = output_dir / sorter_name / "waveform"
-        si_sorting_dir = output_dir / sorter_name / "spike_sorting"
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
-        if si_waveform_dir.exists():  # Read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                si_waveform_dir, with_recording=False
-            )
-            si_sorting: si.sorters.BaseSorter = si.load_extractor(
-                si_sorting_dir / "si_sorting.pkl", base_folder=output_dir
-            )
+        if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+            si_sorting = sorting_analyzer.sorting
 
-            unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
-                we, outputs="index"
-            )  # {unit: peak_channel_id}
+            # Find representative channel for each unit
+            unit_peak_channel: dict[int, np.ndarray] = (
+                si.ChannelSparsity.from_best_channels(
+                    sorting_analyzer, 1, peak_sign="neg"
+                ).unit_id_to_channel_indices
+            )
+            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            spikes = si_sorting.to_spike_vector()
-
             # reorder channel2electrode_map according to recording channel ids
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx]
-                for chn_idx in we.channel_ids_to_indices(we.channel_ids)
+                for chn_idx in sorting_analyzer.channel_ids_to_indices(
+                    sorting_analyzer.channel_ids
+                )
             }
 
             # Get unit id to quality label mapping
-            try:
-                cluster_quality_label_map = pd.read_csv(
-                    si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
-                    delimiter="\t",
+            cluster_quality_label_map = {
+                int(unit_id): (
+                    si_sorting.get_unit_property(unit_id, "KSLabel")
+                    if "KSLabel" in si_sorting.get_property_keys()
+                    else "n.a."
                 )
-            except FileNotFoundError:
-                cluster_quality_label_map = {}
-            else:
-                cluster_quality_label_map: dict[
-                    int, str
-                ] = cluster_quality_label_map.set_index("cluster_id")[
-                    "KSLabel"
-                ].to_dict()  # {unit: quality_label}
-
-            # Get electrode where peak unit activity is recorded
-            peak_electrode_ind = np.array(
-                [
-                    channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Get channel depth
-            channel_depth_ind = np.array(
-                [
-                    we.get_probe().contact_positions[unit_peak_channel[unit_id]][1]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Assign electrode and depth for each spike
-            new_spikes = np.empty(
-                spikes.shape,
-                spikes.dtype.descr + [("electrode", "<i8"), ("depth", "<i8")],
-            )
-
-            for field in spikes.dtype.names:
-                new_spikes[field] = spikes[field]
-            del spikes
+                for unit_id in si_sorting.unit_ids
+            }
 
-            new_spikes["electrode"] = peak_electrode_ind[new_spikes["unit_index"]]
-            new_spikes["depth"] = channel_depth_ind[new_spikes["unit_index"]]
+            spike_locations = sorting_analyzer.get_extension("spike_locations")
+            spikes_df = pd.DataFrame(spike_locations.spikes)
 
             units = []
-
-            for unit_id in si_sorting.unit_ids:
+            for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
                 unit_id = int(unit_id)
+                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
+                spike_sites = np.array(
+                    [
+                        channel2electrode_map[chn_idx]["electrode"]
+                        for chn_idx in unit_spikes_df.channel_index
+                    ]
+                )
+                unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
+                _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
+                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+
+                assert len(spike_times) == len(spike_sites) == len(spike_depths)
+
                 units.append(
                     {
                         **key,
                         **channel2electrode_map[unit_peak_channel[unit_id]],
                         "unit": unit_id,
-                        "cluster_quality_label": cluster_quality_label_map.get(
-                            unit_id, "n.a."
-                        ),
-                        "spike_times": si_sorting.get_unit_spike_train(
-                            unit_id, return_times=True
-                        ),
+                        "cluster_quality_label": cluster_quality_label_map[unit_id],
+                        "spike_times": spike_times,
                         "spike_count": spike_count_dict[unit_id],
-                        "spike_sites": new_spikes["electrode"][
-                            new_spikes["unit_index"] == unit_id
-                        ],
-                        "spike_depths": new_spikes["depth"][
-                            new_spikes["unit_index"] == unit_id
-                        ],
+                        "spike_sites": spike_sites,
+                        "spike_depths": spike_depths,
                     }
                 )
         else:  # read from kilosort outputs
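
Note: the new branch above assumes a `sorting_analyzer` folder saved by the upstream clustering step, with the `spike_locations` extension already computed. A minimal, illustrative sketch of that read path (spikeinterface >= 0.101 API; the folder path and printout below are placeholders, not part of the pipeline):

    import pandas as pd
    import spikeinterface.full as si

    # Hypothetical location; in the pipeline the path comes from the ClusteringTask output_dir.
    analyzer = si.load_sorting_analyzer(folder="clustering_output/kilosort4/sorting_analyzer")
    sorting = analyzer.sorting

    # Per-spike location estimates; rows align with the extension's spike vector.
    spike_locations = analyzer.get_extension("spike_locations")
    spikes_df = pd.DataFrame(spike_locations.spikes)  # unit_index, sample_index, channel_index

    for unit_idx, unit_id in enumerate(sorting.unit_ids):
        spike_times = sorting.get_unit_spike_train(unit_id, return_times=True)  # seconds
        unit_rows = spikes_df[spikes_df.unit_index == unit_idx]
        unit_locs = spike_locations.get_data()[unit_rows.index]  # structured array of (x, y) per spike
        _, depths = zip(*unit_locs)  # same unpacking as the diff: x first, y (depth) second
        print(int(unit_id), len(spike_times), len(depths))
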
@@ -1272,33 +1244,38 @@ def make(self, key):
             chn.pop("channel_idx"): chn for chn in channel2electrode_map
         }
 
-        si_waveform_dir = output_dir / sorter_name / "waveform"
-        if si_waveform_dir.exists():  # read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                si_waveform_dir, with_recording=False
-            )
-            unit_id_to_peak_channel_map: dict[
-                int, np.ndarray
-            ] = si.ChannelSparsity.from_best_channels(
-                we, 1, peak_sign="neg"
-            ).unit_id_to_channel_indices  # {unit: peak_channel_index}
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+        if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+
+            # Find representative channel for each unit
+            unit_peak_channel: dict[int, np.ndarray] = (
+                si.ChannelSparsity.from_best_channels(
+                    sorting_analyzer, 1, peak_sign="neg"
+                ).unit_id_to_channel_indices
+            )  # {unit: peak_channel_index}
+            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             # reorder channel2electrode_map according to recording channel ids
-            channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist()
+            channel_indices = sorting_analyzer.channel_ids_to_indices(
+                sorting_analyzer.channel_ids
+            ).tolist()
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
             }
 
+            templates = sorting_analyzer.get_extension("templates")
+
             def yield_unit_waveforms():
                 for unit in (CuratedClustering.Unit & key).fetch(
                     "KEY", order_by="unit"
                 ):
                     # Get mean waveform for this unit from all channels - (sample x channel)
-                    unit_waveforms = we.get_template(
-                        unit_id=unit["unit"], mode="average", force_dense=True
+                    unit_waveforms = templates.get_unit_template(
+                        unit_id=unit["unit"], operator="average"
                     )
                     peak_chn_idx = channel_indices.index(
-                        unit_id_to_peak_channel_map[unit["unit"]][0]
+                        unit_peak_channel[unit["unit"]]
                     )
                     unit_peak_waveform = {
                         **unit,
@@ -1316,7 +1293,7 @@ def yield_unit_waveforms():
 
                     yield unit_peak_waveform, unit_electrode_waveforms
 
-        else:  # read from kilosort outputs
+        else:  # read from kilosort outputs (ecephys pipeline)
             kilosort_dataset = kilosort.Kilosort(output_dir)
 
             acq_software, probe_serial_number = (
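
Note: in the waveform branch, the `templates` extension takes over from `WaveformExtractor.get_template`. A rough sketch of the calls involved (illustrative only; assumes spikeinterface >= 0.101, dense templates, and that the `templates` extension was computed when the analyzer was saved; the folder path is a placeholder):

    import spikeinterface.full as si

    analyzer = si.load_sorting_analyzer(folder="clustering_output/kilosort4/sorting_analyzer")

    # One "best" channel per unit, by largest negative peak (mirrors the diff's usage).
    unit_peak_channel = si.ChannelSparsity.from_best_channels(
        analyzer, 1, peak_sign="neg"
    ).unit_id_to_channel_indices  # {unit_id: array([channel_index])}

    templates = analyzer.get_extension("templates")
    for unit_id in analyzer.unit_ids:
        mean_wf = templates.get_unit_template(unit_id=unit_id, operator="average")  # (samples, channels)
        peak_trace = mean_wf[:, unit_peak_channel[unit_id][0]]  # waveform on the unit's peak channel
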
@@ -1522,43 +1499,54 @@ def make(self, key):
         output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
         sorter_name = clustering_method.replace(".", "_")
 
-        # find metric_fp
-        for metric_fp in [
-            output_dir / "metrics.csv",
-            output_dir / sorter_name / "metrics" / "metrics.csv",
-        ]:
-            if metric_fp.exists():
-                break
-        else:
-            raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+        if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+            qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
+            template_metrics = sorting_analyzer.get_extension(
+                "template_metrics"
+            ).get_data()
+            metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
 
-        metrics_df = pd.read_csv(metric_fp)
-
-        # Conform the dataframe to match the table definition
-        if "cluster_id" in metrics_df.columns:
-            metrics_df.set_index("cluster_id", inplace=True)
-        else:
             metrics_df.rename(
-                columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
+                columns={
+                    "amplitude_median": "amplitude",
+                    "isi_violations_ratio": "isi_violation",
+                    "isi_violations_count": "number_violation",
+                    "silhouette": "silhouette_score",
+                    "rp_contamination": "contamination_rate",
+                    "drift_ptp": "max_drift",
+                    "drift_mad": "cumulative_drift",
+                    "half_width": "halfwidth",
+                    "peak_trough_ratio": "pt_ratio",
+                    "peak_to_valley": "duration",
+                },
+                inplace=True,
             )
-            metrics_df.set_index("cluster_id", inplace=True)
-        metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
-        metrics_df.columns = metrics_df.columns.str.lower()
-
-        metrics_df.rename(
-            columns={
-                "isi_violations_ratio": "isi_violation",
-                "isi_violations_count": "number_violation",
-                "silhouette": "silhouette_score",
-                "rp_contamination": "contamination_rate",
-                "drift_ptp": "max_drift",
-                "drift_mad": "cumulative_drift",
-                "half_width": "halfwidth",
-                "peak_trough_ratio": "pt_ratio",
-            },
-            inplace=True,
-        )
+        else:  # read from kilosort outputs (ecephys pipeline)
+            # find metric_fp
+            for metric_fp in [
+                output_dir / "metrics.csv",
+            ]:
+                if metric_fp.exists():
+                    break
+            else:
+                raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
+
+            metrics_df = pd.read_csv(metric_fp)
 
+            # Conform the dataframe to match the table definition
+            if "cluster_id" in metrics_df.columns:
+                metrics_df.set_index("cluster_id", inplace=True)
+            else:
+                metrics_df.rename(
+                    columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
+                )
+                metrics_df.set_index("cluster_id", inplace=True)
+
+            metrics_df.columns = metrics_df.columns.str.lower()
+
+        metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
         metrics_list = [
            dict(metrics_df.loc[unit_key["unit"]], **unit_key)
            for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
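
Note: the spikeinterface metrics branch concatenates the `quality_metrics` and `template_metrics` extensions and renames columns to the QualityMetrics table attribute names. A condensed sketch of that mapping (assumes both extensions were computed when the analyzer was built; folder path is a placeholder, rename keys are a subset of those shown in the diff):

    import numpy as np
    import pandas as pd
    import spikeinterface.full as si

    analyzer = si.load_sorting_analyzer(folder="clustering_output/kilosort4/sorting_analyzer")
    qc = analyzer.get_extension("quality_metrics").get_data()       # one row per unit
    tm = analyzer.get_extension("template_metrics").get_data()      # one row per unit
    metrics_df = pd.concat([qc, tm], axis=1)

    # spikeinterface metric names -> table attribute names
    metrics_df = metrics_df.rename(
        columns={
            "amplitude_median": "amplitude",
            "isi_violations_ratio": "isi_violation",
            "isi_violations_count": "number_violation",
            "half_width": "halfwidth",
            "peak_trough_ratio": "pt_ratio",
            "peak_to_valley": "duration",
        }
    ).replace([np.inf, -np.inf], np.nan)  # store NaN rather than infinities
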