@@ -1048,8 +1048,8 @@ def make(self, key):
                 si.ChannelSparsity.from_best_channels(
                     sorting_analyzer, 1, peak_sign="neg"
                 ).unit_id_to_channel_indices
-            )  # {unit: peak_channel_index}
-            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
+            )
+            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
@@ -1076,9 +1076,9 @@ def make(self, key):
             spikes_df = pd.DataFrame(spike_locations.spikes)
 
             units = []
-            for unit_id in si_sorting.unit_ids:
+            for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
                 unit_id = int(unit_id)
-                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_id]
+                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
                 spike_sites = np.array(
                     [
                         channel2electrode_map[chn_idx]["electrode"]
@@ -1087,16 +1087,17 @@ def make(self, key):
                 )
                 unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                 _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
+                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+
+                assert len(spike_times) == len(spike_sites) == len(spike_depths)
 
                 units.append(
                     {
                         **key,
                         **channel2electrode_map[unit_peak_channel[unit_id]],
                         "unit": unit_id,
                         "cluster_quality_label": cluster_quality_label_map[unit_id],
-                        "spike_times": si_sorting.get_unit_spike_train(
-                            unit_id, return_times=True
-                        ),
+                        "spike_times": spike_times,
                         "spike_count": spike_count_dict[unit_id],
                         "spike_sites": spike_sites,
                         "spike_depths": spike_depths,