12
12
13
13
class Kilosort :
14
14
15
- ks_files = [
15
+ kilosort_files = [
16
16
'params.py' ,
17
17
'amplitudes.npy' ,
18
18
'channel_map.npy' ,
@@ -36,18 +36,18 @@ class Kilosort:
36
36
]
37
37
38
38
# keys to self.files, .data are file name e.g. self.data['params'], etc.
39
- ks_keys = [path .splitext (ks_file )[0 ] for ks_file in ks_files ]
39
+ kilosort_keys = [path .splitext (kilosort_file )[0 ] for kilosort_file in kilosort_files ]
40
40
41
- def __init__ (self , ks_dir ):
42
- self ._ks_dir = pathlib .Path (ks_dir )
41
+ def __init__ (self , kilosort_dir ):
42
+ self ._kilosort_dir = pathlib .Path (kilosort_dir )
43
43
self ._files = {}
44
44
self ._data = None
45
45
self ._clusters = None
46
46
47
- params_filepath = ks_dir / 'params.py'
47
+ params_filepath = kilosort_dir / 'params.py'
48
48
49
49
if not params_filepath .exists ():
50
- raise FileNotFoundError (f'No Kilosort output found in: { ks_dir } ' )
50
+ raise FileNotFoundError (f'No Kilosort output found in: { kilosort_dir } ' )
51
51
52
52
self ._info = {'time_created' : datetime .fromtimestamp (params_filepath .stat ().st_ctime ),
53
53
'time_modified' : datetime .fromtimestamp (params_filepath .stat ().st_mtime )}
@@ -64,42 +64,44 @@ def info(self):
64
64
65
65
def _stat (self ):
66
66
self ._data = {}
67
- for ks_filename in Kilosort .ks_files :
68
- ks_filepath = self ._ks_dir / ks_filename
67
+ for kilosort_filename in Kilosort .kilosort_files :
68
+ kilosort_filepath = self ._kilosort_dir / kilosort_filename
69
69
70
- if not ks_filepath .exists ():
71
- log .debug ('skipping {} - does not exist' .format (ks_filepath ))
70
+ if not kilosort_filepath .exists ():
71
+ log .debug ('skipping {} - does not exist' .format (kilosort_filepath ))
72
72
continue
73
73
74
- base , ext = path .splitext (ks_filename )
75
- self ._files [base ] = ks_filepath
74
+ base , ext = path .splitext (kilosort_filename )
75
+ self ._files [base ] = kilosort_filepath
76
76
77
- if ks_filename == 'params.py' :
78
- log .debug ('loading params.py {}' .format (ks_filepath ))
77
+ if kilosort_filename == 'params.py' :
78
+ log .debug ('loading params.py {}' .format (kilosort_filepath ))
79
79
# params.py is a 'key = val' file
80
80
params = {}
81
- for line in open (ks_filepath , 'r' ).readlines ():
81
+ for line in open (kilosort_filepath , 'r' ).readlines ():
82
82
k , v = line .strip ('\n ' ).split ('=' )
83
83
params [k .strip ()] = convert_to_number (v .strip ())
84
84
log .debug ('params: {}' .format (params ))
85
85
self ._data [base ] = params
86
86
87
87
if ext == '.npy' :
88
- log .debug ('loading npy {}' .format (ks_filepath ))
89
- d = np .load (ks_filepath , mmap_mode = 'r' , allow_pickle = False , fix_imports = False )
88
+ log .debug ('loading npy {}' .format (kilosort_filepath ))
89
+ d = np .load (kilosort_filepath , mmap_mode = 'r' ,
90
+ allow_pickle = False , fix_imports = False )
90
91
self ._data [base ] = (np .reshape (d , d .shape [0 ])
91
92
if d .ndim == 2 and d .shape [1 ] == 1 else d )
92
93
93
94
# Read the Cluster Groups
94
95
for cluster_pattern , cluster_col_name in zip (['cluster_groups.*' , 'cluster_KSLabel.*' ],
95
96
['group' , 'KSLabel' ]):
96
97
try :
97
- cluster_file = next (self ._ks_dir .glob (cluster_pattern ))
98
- cluster_file_suffix = cluster_file .suffix
99
- assert cluster_file_suffix in ('.csv' , '.tsv' , '.xlsx' )
100
- break
98
+ cluster_file = next (self ._kilosort_dir .glob (cluster_pattern ))
101
99
except StopIteration :
102
100
pass
101
+
102
+ cluster_file_suffix = cluster_file .suffix
103
+ assert cluster_file_suffix in ('.csv' , '.tsv' , '.xlsx' )
104
+ break
103
105
else :
104
106
raise FileNotFoundError (
105
107
'Neither "cluster_groups" nor "cluster_KSLabel" file found!' )
@@ -118,7 +120,7 @@ def get_best_channel(self, unit):
118
120
template_idx = self .data ['spike_templates' ][
119
121
np .where (self .data ['spike_clusters' ] == unit )[0 ][0 ]]
120
122
channel_templates = self .data ['templates' ][template_idx , :, :]
121
- max_channel_idx = np .abs (np . abs ( channel_templates ).max (axis = 0 ) ).argmax ()
123
+ max_channel_idx = np .abs (channel_templates ).max (axis = 0 ).argmax ()
122
124
max_channel = self .data ['channel_map' ][max_channel_idx ]
123
125
124
126
return max_channel , max_channel_idx
@@ -174,12 +176,10 @@ def extract_clustering_info(cluster_output_dir):
174
176
175
177
# ---- Quality control? ----
176
178
metric_filepath = cluster_output_dir / 'metrics.csv'
177
- if metric_filepath .exists ():
178
- is_qc = True
179
+ is_qc = metric_filepath .exists ()
180
+ if is_qc :
179
181
if creation_time is None :
180
182
creation_time = datetime .fromtimestamp (metric_filepath .stat ().st_ctime )
181
- else :
182
- is_qc = False
183
183
184
184
if creation_time is None :
185
185
spiketimes_filepath = next (cluster_output_dir .glob ('spike_times.npy' ))
0 commit comments