Commit 6b375f9

Author: Thinh Nguyen
Commit message: BLACK formatting
1 parent 72e784b

3 files changed: +254 -189 lines
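
Note: Black is Python's "uncompromising" formatter, and every hunk below is the kind of mechanical rewrite it applies: single quotes normalized to double quotes, long calls exploded one argument per line with a trailing comma, lines wrapped at the default 88 characters. Runtime behavior is unchanged. A diff like this is typically produced by running `black element_array_ephys/` at the repository root; as a minimal sketch, the same normalization can be reproduced through Black's programmatic API (black.format_str / black.Mode, default settings):

import black

# One single-quoted line from the pre-Black version of kilosort.py.
src = "params_filepath = kilosort_dir / 'params.py'\n"

# Default Mode: 88-character line length, double-quote normalization.
print(black.format_str(src, mode=black.Mode()))
# -> params_filepath = kilosort_dir / "params.py"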

element_array_ephys/readers/kilosort.py (99 additions, 76 deletions)

@@ -13,29 +13,29 @@
 class Kilosort:

     _kilosort_core_files = [
-        'params.py',
-        'amplitudes.npy',
-        'channel_map.npy',
-        'channel_positions.npy',
-        'pc_features.npy',
-        'pc_feature_ind.npy',
-        'similar_templates.npy',
-        'spike_templates.npy',
-        'spike_times.npy',
-        'template_features.npy',
-        'template_feature_ind.npy',
-        'templates.npy',
-        'templates_ind.npy',
-        'whitening_mat.npy',
-        'whitening_mat_inv.npy',
-        'spike_clusters.npy'
+        "params.py",
+        "amplitudes.npy",
+        "channel_map.npy",
+        "channel_positions.npy",
+        "pc_features.npy",
+        "pc_feature_ind.npy",
+        "similar_templates.npy",
+        "spike_templates.npy",
+        "spike_times.npy",
+        "template_features.npy",
+        "template_feature_ind.npy",
+        "templates.npy",
+        "templates_ind.npy",
+        "whitening_mat.npy",
+        "whitening_mat_inv.npy",
+        "spike_clusters.npy",
     ]

     _kilosort_additional_files = [
-        'spike_times_sec.npy',
-        'spike_times_sec_adj.npy',
-        'cluster_groups.csv',
-        'cluster_KSLabel.tsv'
+        "spike_times_sec.npy",
+        "spike_times_sec_adj.npy",
+        "cluster_groups.csv",
+        "cluster_KSLabel.tsv",
     ]

     kilosort_files = _kilosort_core_files + _kilosort_additional_files
@@ -48,9 +48,11 @@ def __init__(self, kilosort_dir):

         self.validate()

-        params_filepath = kilosort_dir / 'params.py'
-        self._info = {'time_created': datetime.fromtimestamp(params_filepath.stat().st_ctime),
-                      'time_modified': datetime.fromtimestamp(params_filepath.stat().st_mtime)}
+        params_filepath = kilosort_dir / "params.py"
+        self._info = {
+            "time_created": datetime.fromtimestamp(params_filepath.stat().st_ctime),
+            "time_modified": datetime.fromtimestamp(params_filepath.stat().st_mtime),
+        }

     @property
     def data(self):
@@ -72,136 +74,157 @@ def validate(self):
             if not full_path.exists():
                 missing_files.append(f)
         if missing_files:
-            raise FileNotFoundError(f'Kilosort files missing in ({self._kilosort_dir}):'
-                                    f' {missing_files}')
+            raise FileNotFoundError(
+                f"Kilosort files missing in ({self._kilosort_dir}):" f" {missing_files}"
+            )

     def _load(self):
         self._data = {}
         for kilosort_filename in Kilosort.kilosort_files:
             kilosort_filepath = self._kilosort_dir / kilosort_filename

             if not kilosort_filepath.exists():
-                log.debug('skipping {} - does not exist'.format(kilosort_filepath))
+                log.debug("skipping {} - does not exist".format(kilosort_filepath))
                 continue

             base, ext = path.splitext(kilosort_filename)
             self._files[base] = kilosort_filepath

-            if kilosort_filename == 'params.py':
-                log.debug('loading params.py {}'.format(kilosort_filepath))
+            if kilosort_filename == "params.py":
+                log.debug("loading params.py {}".format(kilosort_filepath))
                 # params.py is a 'key = val' file
                 params = {}
-                for line in open(kilosort_filepath, 'r').readlines():
-                    k, v = line.strip('\n').split('=')
+                for line in open(kilosort_filepath, "r").readlines():
+                    k, v = line.strip("\n").split("=")
                     params[k.strip()] = convert_to_number(v.strip())
-                log.debug('params: {}'.format(params))
+                log.debug("params: {}".format(params))
                 self._data[base] = params

-            if ext == '.npy':
-                log.debug('loading npy {}'.format(kilosort_filepath))
-                d = np.load(kilosort_filepath, mmap_mode='r',
-                            allow_pickle=False, fix_imports=False)
-                self._data[base] = (np.reshape(d, d.shape[0])
-                                    if d.ndim == 2 and d.shape[1] == 1 else d)
+            if ext == ".npy":
+                log.debug("loading npy {}".format(kilosort_filepath))
+                d = np.load(
+                    kilosort_filepath,
+                    mmap_mode="r",
+                    allow_pickle=False,
+                    fix_imports=False,
+                )
+                self._data[base] = (
+                    np.reshape(d, d.shape[0]) if d.ndim == 2 and d.shape[1] == 1 else d
+                )

-        self._data['channel_map'] = self._data['channel_map'].flatten()
+        self._data["channel_map"] = self._data["channel_map"].flatten()

         # Read the Cluster Groups
-        for cluster_pattern, cluster_col_name in zip(['cluster_group.*', 'cluster_KSLabel.*'],
-                                                     ['group', 'KSLabel']):
+        for cluster_pattern, cluster_col_name in zip(
+            ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"]
+        ):
             try:
                 cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
             except StopIteration:
                 pass
             else:
                 cluster_file_suffix = cluster_file.suffix
-                assert cluster_file_suffix in ('.tsv', '.xlsx')
+                assert cluster_file_suffix in (".tsv", ".xlsx")
                 break
         else:
             raise FileNotFoundError(
-                'Neither "cluster_groups" nor "cluster_KSLabel" file found!')
+                'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
+            )

-        if cluster_file_suffix == '.tsv':
-            df = pd.read_csv(cluster_file, sep='\t', header=0)
-        elif cluster_file_suffix == '.xlsx':
-            df = pd.read_excel(cluster_file, engine='openpyxl')
+        if cluster_file_suffix == ".tsv":
+            df = pd.read_csv(cluster_file, sep="\t", header=0)
+        elif cluster_file_suffix == ".xlsx":
+            df = pd.read_excel(cluster_file, engine="openpyxl")
         else:
-            df = pd.read_csv(cluster_file, delimiter='\t')
+            df = pd.read_csv(cluster_file, delimiter="\t")

-        self._data['cluster_groups'] = np.array(df[cluster_col_name].values)
-        self._data['cluster_ids'] = np.array(df['cluster_id'].values)
+        self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
+        self._data["cluster_ids"] = np.array(df["cluster_id"].values)

     def get_best_channel(self, unit):
-        template_idx = self.data['spike_templates'][
-            np.where(self.data['spike_clusters'] == unit)[0][0]]
-        channel_templates = self.data['templates'][template_idx, :, :]
+        template_idx = self.data["spike_templates"][
+            np.where(self.data["spike_clusters"] == unit)[0][0]
+        ]
+        channel_templates = self.data["templates"][template_idx, :, :]
         max_channel_idx = np.abs(channel_templates).max(axis=0).argmax()
-        max_channel = self.data['channel_map'][max_channel_idx]
+        max_channel = self.data["channel_map"][max_channel_idx]

         return max_channel, max_channel_idx

     def extract_spike_depths(self):
-        """ Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m """
-
-        if 'pc_features' in self.data:
-            ycoords = self.data['channel_positions'][:, 1]
-            pc_features = self.data['pc_features'][:, 0, :]  # 1st PC only
+        """Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m"""
+
+        if "pc_features" in self.data:
+            ycoords = self.data["channel_positions"][:, 1]
+            pc_features = self.data["pc_features"][:, 0, :]  # 1st PC only
             pc_features = np.where(pc_features < 0, 0, pc_features)

             # ---- compute center of mass of these features (spike depths) ----

             # which channels for each spike?
-            spk_feature_ind = self.data['pc_feature_ind'][self.data['spike_templates'], :]
+            spk_feature_ind = self.data["pc_feature_ind"][
+                self.data["spike_templates"], :
+            ]
             # ycoords of those channels?
             spk_feature_ycoord = ycoords[spk_feature_ind]
             # center of mass is sum(coords.*features)/sum(features)
-            self._data['spike_depths'] = (np.sum(spk_feature_ycoord * pc_features**2, axis=1)
-                                          / np.sum(pc_features**2, axis=1))
+            self._data["spike_depths"] = np.sum(
+                spk_feature_ycoord * pc_features ** 2, axis=1
+            ) / np.sum(pc_features ** 2, axis=1)
         else:
-            self._data['spike_depths'] = None
+            self._data["spike_depths"] = None

         # ---- extract spike sites ----
-        max_site_ind = np.argmax(np.abs(self.data['templates']).max(axis=1), axis=1)
-        spike_site_ind = max_site_ind[self.data['spike_templates']]
-        self._data['spike_sites'] = self.data['channel_map'][spike_site_ind]
+        max_site_ind = np.argmax(np.abs(self.data["templates"]).max(axis=1), axis=1)
+        spike_site_ind = max_site_ind[self.data["spike_templates"]]
+        self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]


 def extract_clustering_info(cluster_output_dir):
     creation_time = None

-    phy_curation_indicators = ['Merge clusters', 'Split cluster', 'Change metadata_group']
+    phy_curation_indicators = [
+        "Merge clusters",
+        "Split cluster",
+        "Change metadata_group",
+    ]
     # ---- Manual curation? ----
-    phylog_filepath = cluster_output_dir / 'phy.log'
+    phylog_filepath = cluster_output_dir / "phy.log"
     if phylog_filepath.exists():
         phylog = pd.read_fwf(phylog_filepath, colspecs=[(6, 40), (41, 250)])
-        phylog.columns = ['meta', 'detail']
-        curation_row = [bool(re.match('|'.join(phy_curation_indicators), str(s)))
-                        for s in phylog.detail]
+        phylog.columns = ["meta", "detail"]
+        curation_row = [
+            bool(re.match("|".join(phy_curation_indicators), str(s)))
+            for s in phylog.detail
+        ]
         is_curated = bool(np.any(curation_row))
         if creation_time is None and is_curated:
             row_meta = phylog.meta[np.where(curation_row)[0].max()]
-            datetime_str = re.search('\d{2}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}', row_meta)
+            datetime_str = re.search("\d{2}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}", row_meta)
             if datetime_str:
-                creation_time = datetime.strptime(datetime_str.group(), '%Y-%m-%d %H:%M:%S')
+                creation_time = datetime.strptime(
+                    datetime_str.group(), "%Y-%m-%d %H:%M:%S"
+                )
             else:
                 creation_time = datetime.fromtimestamp(phylog_filepath.stat().st_ctime)
-                time_str = re.search('\d{2}:\d{2}:\d{2}', row_meta)
+                time_str = re.search("\d{2}:\d{2}:\d{2}", row_meta)
                 if time_str:
                     creation_time = datetime.combine(
                         creation_time.date(),
-                        datetime.strptime(time_str.group(), '%H:%M:%S').time())
+                        datetime.strptime(time_str.group(), "%H:%M:%S").time(),
+                    )
     else:
         is_curated = False

     # ---- Quality control? ----
-    metric_filepath = cluster_output_dir / 'metrics.csv'
+    metric_filepath = cluster_output_dir / "metrics.csv"
     is_qc = metric_filepath.exists()
     if is_qc:
         if creation_time is None:
             creation_time = datetime.fromtimestamp(metric_filepath.stat().st_ctime)

     if creation_time is None:
-        spiketimes_filepath = next(cluster_output_dir.glob('spike_times.npy'))
+        spiketimes_filepath = next(cluster_output_dir.glob("spike_times.npy"))
         creation_time = datetime.fromtimestamp(spiketimes_filepath.stat().st_ctime)

     return creation_time, is_curated, is_qc
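
Note: the reformat leaves the module's public surface untouched. A minimal usage sketch of the reader as it appears in this diff (the directory path is hypothetical, and it assumes the `data` property lazily calls `_load()` on first access):

from pathlib import Path

from element_array_ephys.readers.kilosort import Kilosort, extract_clustering_info

ks_dir = Path("/data/subject1/session0/ks_output")  # hypothetical Kilosort output dir

ks = Kilosort(ks_dir)  # __init__ runs validate(); missing core files raise FileNotFoundError
spike_times = ks.data["spike_times"]  # loaded .npy arrays are keyed by file stem

best_channel, best_channel_idx = ks.get_best_channel(unit=0)
ks.extract_spike_depths()  # fills ks.data["spike_depths"] and ks.data["spike_sites"]

creation_time, is_curated, is_qc = extract_clustering_info(ks_dir)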

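Note: extract_spike_depths estimates each spike's depth as a center of mass, per its inline comment: depth = sum(ycoords * w) / sum(w), with weights w = (1st-PC feature)**2. A tiny worked example with made-up numbers:

import numpy as np

ycoords = np.array([0.0, 20.0, 40.0])      # hypothetical channel y-positions (um)
pc_features = np.array([[0.1, 0.9, 0.2]])  # hypothetical 1st-PC amplitudes, one spike

w = pc_features ** 2                       # negative features are zeroed upstream
depths = np.sum(ycoords * w, axis=1) / np.sum(w, axis=1)
print(depths)  # [20.69767442] -> the spike localizes near the channel at y=20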