@@ -1045,10 +1045,13 @@ def make(self, key):
1045
1045
# Find representative channel for each unit
1046
1046
unit_peak_channel : dict [int , np .ndarray ] = (
1047
1047
si .ChannelSparsity .from_best_channels (
1048
- sorting_analyzer , 1 , peak_sign = "both"
1048
+ sorting_analyzer ,
1049
+ 1 ,
1049
1050
).unit_id_to_channel_indices
1050
1051
)
1051
- unit_peak_channel : dict [int , int ] = {u : chn [0 ] for u , chn in unit_peak_channel .items ()}
1052
+ unit_peak_channel : dict [int , int ] = {
1053
+ u : chn [0 ] for u , chn in unit_peak_channel .items ()
1054
+ }
1052
1055
1053
1056
spike_count_dict : dict [int , int ] = si_sorting .count_num_spikes_per_unit ()
1054
1057
# {unit: spike_count}
@@ -1084,7 +1087,9 @@ def make(self, key):
1084
1087
)
1085
1088
unit_spikes_loc = spike_locations .get_data ()[unit_spikes_df .index ]
1086
1089
_ , spike_depths = zip (* unit_spikes_loc ) # x-coordinates, y-coordinates
1087
- spike_times = si_sorting .get_unit_spike_train (unit_id , return_times = True )
1090
+ spike_times = si_sorting .get_unit_spike_train (
1091
+ unit_id , return_times = True
1092
+ )
1088
1093
1089
1094
assert len (spike_times ) == len (spike_sites ) == len (spike_depths )
1090
1095
@@ -1243,13 +1248,13 @@ def make(self, key):
1243
1248
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
1244
1249
if si_sorting_analyzer_dir .exists (): # read from spikeinterface outputs
1245
1250
import spikeinterface as si
1246
-
1251
+
1247
1252
sorting_analyzer = si .load_sorting_analyzer (folder = si_sorting_analyzer_dir )
1248
1253
1249
1254
# Find representative channel for each unit
1250
1255
unit_peak_channel : dict [int , np .ndarray ] = (
1251
1256
si .ChannelSparsity .from_best_channels (
1252
- sorting_analyzer , 1 , peak_sign = "both"
1257
+ sorting_analyzer , 1
1253
1258
).unit_id_to_channel_indices
1254
1259
) # {unit: peak_channel_index}
1255
1260
unit_peak_channel = {u : chn [0 ] for u , chn in unit_peak_channel .items ()}
@@ -1272,7 +1277,9 @@ def yield_unit_waveforms():
1272
1277
)
1273
1278
unit_peak_waveform = {
1274
1279
** unit ,
1275
- "peak_electrode_waveform" : unit_waveforms [:, unit_peak_channel [unit ["unit" ]]],
1280
+ "peak_electrode_waveform" : unit_waveforms [
1281
+ :, unit_peak_channel [unit ["unit" ]]
1282
+ ],
1276
1283
}
1277
1284
1278
1285
unit_electrode_waveforms = [
@@ -1495,7 +1502,7 @@ def make(self, key):
1495
1502
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
1496
1503
if si_sorting_analyzer_dir .exists (): # read from spikeinterface outputs
1497
1504
import spikeinterface as si
1498
-
1505
+
1499
1506
sorting_analyzer = si .load_sorting_analyzer (folder = si_sorting_analyzer_dir )
1500
1507
qc_metrics = sorting_analyzer .get_extension ("quality_metrics" ).get_data ()
1501
1508
template_metrics = sorting_analyzer .get_extension (
0 commit comments