class Kilosort:

    _kilosort_core_files = [
        "params.py",
        "amplitudes.npy",
        "channel_map.npy",
        "channel_positions.npy",
        "pc_features.npy",
        "pc_feature_ind.npy",
        "similar_templates.npy",
        "spike_templates.npy",
        "spike_times.npy",
        "template_features.npy",
        "template_feature_ind.npy",
        "templates.npy",
        "templates_ind.npy",
        "whitening_mat.npy",
        "whitening_mat_inv.npy",
        "spike_clusters.npy",
    ]

    _kilosort_additional_files = [
        "spike_times_sec.npy",
        "spike_times_sec_adj.npy",
        "cluster_groups.csv",
        "cluster_KSLabel.tsv",
    ]

    kilosort_files = _kilosort_core_files + _kilosort_additional_files
@@ -48,9 +48,11 @@ def __init__(self, kilosort_dir):

        self.validate()

        params_filepath = kilosort_dir / "params.py"
        self._info = {
            "time_created": datetime.fromtimestamp(params_filepath.stat().st_ctime),
            "time_modified": datetime.fromtimestamp(params_filepath.stat().st_mtime),
        }

    @property
    def data(self):
@@ -72,136 +74,157 @@ def validate(self):
            if not full_path.exists():
                missing_files.append(f)
        if missing_files:
            raise FileNotFoundError(
                f"Kilosort files missing in ({self._kilosort_dir}): {missing_files}"
            )

    def _load(self):
        self._data = {}
        for kilosort_filename in Kilosort.kilosort_files:
            kilosort_filepath = self._kilosort_dir / kilosort_filename

            if not kilosort_filepath.exists():
                log.debug("skipping {} - does not exist".format(kilosort_filepath))
                continue

            base, ext = path.splitext(kilosort_filename)
            self._files[base] = kilosort_filepath

            if kilosort_filename == "params.py":
                log.debug("loading params.py {}".format(kilosort_filepath))
                # params.py is a 'key = val' file
                params = {}
                for line in open(kilosort_filepath, "r").readlines():
                    k, v = line.strip("\n").split("=")
                    params[k.strip()] = convert_to_number(v.strip())
                log.debug("params: {}".format(params))
                self._data[base] = params

            if ext == ".npy":
                log.debug("loading npy {}".format(kilosort_filepath))
                d = np.load(
                    kilosort_filepath,
                    mmap_mode="r",
                    allow_pickle=False,
                    fix_imports=False,
                )
                self._data[base] = (
                    np.reshape(d, d.shape[0]) if d.ndim == 2 and d.shape[1] == 1 else d
                )

        self._data["channel_map"] = self._data["channel_map"].flatten()

        # Read the Cluster Groups
        for cluster_pattern, cluster_col_name in zip(
            ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"]
        ):
            try:
                cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
            except StopIteration:
                pass
            else:
                cluster_file_suffix = cluster_file.suffix
                assert cluster_file_suffix in (".tsv", ".xlsx")
                break
        else:
            raise FileNotFoundError(
                'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
            )

        if cluster_file_suffix == ".tsv":
            df = pd.read_csv(cluster_file, sep="\t", header=0)
        elif cluster_file_suffix == ".xlsx":
            df = pd.read_excel(cluster_file, engine="openpyxl")
        else:
            df = pd.read_csv(cluster_file, delimiter="\t")

        self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
        self._data["cluster_ids"] = np.array(df["cluster_id"].values)

    def get_best_channel(self, unit):
        """Return the channel with the largest absolute template amplitude for the given unit, and its index into the channel map."""
        template_idx = self.data["spike_templates"][
            np.where(self.data["spike_clusters"] == unit)[0][0]
        ]
        channel_templates = self.data["templates"][template_idx, :, :]
        max_channel_idx = np.abs(channel_templates).max(axis=0).argmax()
        max_channel = self.data["channel_map"][max_channel_idx]

        return max_channel, max_channel_idx

    def extract_spike_depths(self):
        """Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m"""

        if "pc_features" in self.data:
            ycoords = self.data["channel_positions"][:, 1]
            pc_features = self.data["pc_features"][:, 0, :]  # 1st PC only
            pc_features = np.where(pc_features < 0, 0, pc_features)

            # ---- compute center of mass of these features (spike depths) ----

            # which channels for each spike?
            spk_feature_ind = self.data["pc_feature_ind"][
                self.data["spike_templates"], :
            ]
            # ycoords of those channels?
            spk_feature_ycoord = ycoords[spk_feature_ind]
            # center of mass is sum(coords.*features)/sum(features)
            self._data["spike_depths"] = np.sum(
                spk_feature_ycoord * pc_features ** 2, axis=1
            ) / np.sum(pc_features ** 2, axis=1)
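            # Illustrative sanity check of the center-of-mass formula above (made-up
            # numbers): a spike with first-PC features [0, 2, 1] on channels at
            # ycoords [0, 20, 40] gets depth (0*0 + 20*4 + 40*1) / (0 + 4 + 1) = 24,
            # in the same units as channel_positions.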
        else:
            self._data["spike_depths"] = None

        # ---- extract spike sites ----
        max_site_ind = np.argmax(np.abs(self.data["templates"]).max(axis=1), axis=1)
        spike_site_ind = max_site_ind[self.data["spike_templates"]]
        self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]


def extract_clustering_info(cluster_output_dir):
    creation_time = None

    phy_curation_indicators = [
        "Merge clusters",
        "Split cluster",
        "Change metadata_group",
    ]
    # ---- Manual curation? ----
    phylog_filepath = cluster_output_dir / "phy.log"
    if phylog_filepath.exists():
        phylog = pd.read_fwf(phylog_filepath, colspecs=[(6, 40), (41, 250)])
        phylog.columns = ["meta", "detail"]
        curation_row = [
            bool(re.match("|".join(phy_curation_indicators), str(s)))
            for s in phylog.detail
        ]
        is_curated = bool(np.any(curation_row))
        if creation_time is None and is_curated:
            row_meta = phylog.meta[np.where(curation_row)[0].max()]
            datetime_str = re.search(r"\d{2}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}", row_meta)
            if datetime_str:
                creation_time = datetime.strptime(
                    datetime_str.group(), "%Y-%m-%d %H:%M:%S"
                )
            else:
                creation_time = datetime.fromtimestamp(phylog_filepath.stat().st_ctime)
                time_str = re.search(r"\d{2}:\d{2}:\d{2}", row_meta)
                if time_str:
                    creation_time = datetime.combine(
                        creation_time.date(),
                        datetime.strptime(time_str.group(), "%H:%M:%S").time(),
                    )
    else:
        is_curated = False

    # ---- Quality control? ----
    metric_filepath = cluster_output_dir / "metrics.csv"
    is_qc = metric_filepath.exists()
    if is_qc:
        if creation_time is None:
            creation_time = datetime.fromtimestamp(metric_filepath.stat().st_ctime)

    if creation_time is None:
        spiketimes_filepath = next(cluster_output_dir.glob("spike_times.npy"))
        creation_time = datetime.fromtimestamp(spiketimes_filepath.stat().st_ctime)

    return creation_time, is_curated, is_qc
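

# Example usage (illustrative sketch only; the session path and unit id are
# hypothetical, and `Path` is assumed to be imported at module level alongside
# numpy, pandas, datetime, re, and the `log` / `convert_to_number` helpers):
#
#     ks_dir = Path("/data/session_0/kilosort_output")
#     ks = Kilosort(ks_dir)                 # validate() raises if core files are missing
#     spike_times = ks.data["spike_times"]  # .data presumably triggers _load(); npy files are memory-mapped
#     channel, channel_idx = ks.get_best_channel(unit=0)
#     ks.extract_spike_depths()             # adds "spike_depths" and "spike_sites" to the data dict
#     creation_time, is_curated, is_qc = extract_clustering_info(ks_dir)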