Skip to content

Commit 93895a9

Browse files
committed
Refactor Quality Metrics Logic + black formatting
1 parent c819d19 commit 93895a9

File tree

2 files changed

+73
-33
lines changed

2 files changed

+73
-33
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,11 +1277,11 @@ def make(self, key):
12771277
we: si.WaveformExtractor = si.load_waveforms(
12781278
si_waveform_dir, with_recording=False
12791279
)
1280-
unit_id_to_peak_channel_map: dict[int, np.ndarray] = (
1281-
si.ChannelSparsity.from_best_channels(
1282-
we, 1, peak_sign="neg"
1283-
).unit_id_to_channel_indices
1284-
) # {unit: peak_channel_index}
1280+
unit_id_to_peak_channel_map: dict[
1281+
int, np.ndarray
1282+
] = si.ChannelSparsity.from_best_channels(
1283+
we, 1, peak_sign="neg"
1284+
).unit_id_to_channel_indices # {unit: peak_channel_index}
12851285

12861286
# reorder channel2electrode_map according to recording channel ids
12871287
channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist()
@@ -1315,6 +1315,7 @@ def yield_unit_waveforms():
13151315
]
13161316

13171317
yield unit_peak_waveform, unit_electrode_waveforms
1318+
13181319
else: # read from kilosort outputs
13191320
kilosort_dataset = kilosort.Kilosort(output_dir)
13201321

@@ -1546,9 +1547,14 @@ def make(self, key):
15461547

15471548
metrics_df.rename(
15481549
columns={
1549-
"isi_viol": "isi_violation",
1550-
"num_viol": "number_violation",
1551-
"contam_rate": "contamination_rate",
1550+
"isi_violations_ratio": "isi_violation",
1551+
"isi_violations_count": "number_violation",
1552+
"silhouette": "silhouette_score",
1553+
"rp_contamination": "contamination_rate",
1554+
"drift_ptp": "max_drift",
1555+
"drift_mad": "cumulative_drift",
1556+
"half_width": "halfwidth",
1557+
"peak_trough_ratio": "pt_ratio",
15521558
},
15531559
inplace=True,
15541560
)

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,19 @@ def make(self, key):
134134
# Add probe information to recording object
135135
electrodes_df = (
136136
(
137-
ephys.EphysRecording.Channel * probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
137+
ephys.EphysRecording.Channel
138+
* probe.ElectrodeConfig.Electrode
139+
* probe.ProbeType.Electrode
138140
& key
139141
)
140142
.fetch(format="frame")
141143
.reset_index()
142144
)
143145

144146
# Create SI probe object
145-
si_probe = readers.probe_geometry.to_probeinterface(electrodes_df[["electrode", "x_coord", "y_coord", "shank"]])
147+
si_probe = readers.probe_geometry.to_probeinterface(
148+
electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]
149+
)
146150
si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
147151
si_recording.set_probe(probe=si_probe, in_place=True)
148152

@@ -184,7 +188,9 @@ def make(self, key):
184188
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
185189
sorter_name = clustering_method.replace(".", "_")
186190
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
187-
si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir)
191+
si_recording: si.BaseRecording = si.load_extractor(
192+
recording_file, base_folder=output_dir
193+
)
188194

189195
# Run sorting
190196
# Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
@@ -241,8 +247,12 @@ def make(self, key):
241247
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
242248
sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
243249

244-
si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir)
245-
si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file, base_folder=output_dir)
250+
si_recording: si.BaseRecording = si.load_extractor(
251+
recording_file, base_folder=output_dir
252+
)
253+
si_sorting: si.sorters.BaseSorter = si.load_extractor(
254+
sorting_file, base_folder=output_dir
255+
)
246256

247257
# Extract waveforms
248258
we: si.WaveformExtractor = si.extract_waveforms(
@@ -257,37 +267,61 @@ def make(self, key):
257267
**params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}),
258268
)
259269

260-
# Calculate QC Metrics
261-
metrics: pd.DataFrame = si.qualitymetrics.compute_quality_metrics(
262-
we,
263-
metric_names=[
264-
"firing_rate",
265-
"snr",
266-
"presence_ratio",
267-
"isi_violation",
268-
"num_spikes",
269-
"amplitude_cutoff",
270-
"amplitude_median",
271-
"sliding_rp_violation",
272-
"rp_violation",
273-
"drift",
274-
],
275-
)
276-
# Add PCA based metrics. These will be added to the metrics dataframe above.
270+
# Calculate Cluster and Waveform Metrics
271+
272+
# To provide waveform_principal_component
277273
_ = si.postprocessing.compute_principal_components(
278274
waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None)
279275
)
280-
metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we)
276+
277+
# To estimate the location of each spike in the sorting output.
278+
# The drift metrics require the `spike_locations` waveform extension.
279+
_ = si.postprocessing.compute_spike_locations(waveform_extractor=we)
280+
281+
# The `sd_ratio` metric requires the `spike_amplitudes` waveform extension.
282+
# It is highly recommended before calculating amplitude-based quality metrics.
283+
_ = si.postprocessing.compute_spike_amplitudes(waveform_extractor=we)
284+
285+
# To compute correlograms for spike trains.
286+
_ = si.postprocessing.compute_correlograms(we)
287+
288+
metric_names = si.qualitymetrics.get_quality_metric_list()
289+
metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list())
290+
291+
# To compute commonly used cluster quality metrics.
292+
qc_metrics = si.qualitymetrics.compute_quality_metrics(
293+
waveform_extractor=we,
294+
metric_names=metric_names,
295+
)
296+
297+
# To compute commonly used waveform/template metrics.
298+
template_metric_names = si.postprocessing.get_template_metric_names()
299+
template_metric_names.extend(["amplitude", "duration"])
300+
301+
template_metrics = si.postprocessing.compute_template_metrics(
302+
waveform_extractor=we,
303+
include_multi_channel_metrics=True,
304+
metric_names=template_metric_names,
305+
)
306+
307+
# Save the output (metrics.csv to the output dir)
308+
metrics = pd.DataFrame()
309+
metrics = pd.concat([qc_metrics, template_metrics], axis=1)
281310

282311
# Save the output (metrics.csv to the output dir)
283312
metrics_output_dir = output_dir / sorter_name / "metrics"
284313
metrics_output_dir.mkdir(parents=True, exist_ok=True)
285314
metrics.to_csv(metrics_output_dir / "metrics.csv")
286315

287316
# Save to phy format
288-
si.exporters.export_to_phy(waveform_extractor=we, output_folder=output_dir / sorter_name / "phy")
317+
si.exporters.export_to_phy(
318+
waveform_extractor=we, output_folder=output_dir / sorter_name / "phy"
319+
)
289320
# Generate spike interface report
290-
si.exporters.export_report(waveform_extractor=we, output_folder=output_dir / sorter_name / "spikeinterface_report")
321+
si.exporters.export_report(
322+
waveform_extractor=we,
323+
output_folder=output_dir / sorter_name / "spikeinterface_report",
324+
)
291325

292326
self.insert1(
293327
{

0 commit comments

Comments
 (0)