@@ -134,15 +134,19 @@ def make(self, key):
134
134
# Add probe information to recording object
135
135
electrodes_df = (
136
136
(
137
- ephys .EphysRecording .Channel * probe .ElectrodeConfig .Electrode * probe .ProbeType .Electrode
137
+ ephys .EphysRecording .Channel
138
+ * probe .ElectrodeConfig .Electrode
139
+ * probe .ProbeType .Electrode
138
140
& key
139
141
)
140
142
.fetch (format = "frame" )
141
143
.reset_index ()
142
144
)
143
145
144
146
# Create SI probe object
145
- si_probe = readers .probe_geometry .to_probeinterface (electrodes_df [["electrode" , "x_coord" , "y_coord" , "shank" ]])
147
+ si_probe = readers .probe_geometry .to_probeinterface (
148
+ electrodes_df [["electrode" , "x_coord" , "y_coord" , "shank" ]]
149
+ )
146
150
si_probe .set_device_channel_indices (electrodes_df ["channel_idx" ].values )
147
151
si_recording .set_probe (probe = si_probe , in_place = True )
148
152
@@ -184,7 +188,9 @@ def make(self, key):
184
188
output_dir = find_full_path (ephys .get_ephys_root_data_dir (), output_dir )
185
189
sorter_name = clustering_method .replace ("." , "_" )
186
190
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
187
- si_recording : si .BaseRecording = si .load_extractor (recording_file , base_folder = output_dir )
191
+ si_recording : si .BaseRecording = si .load_extractor (
192
+ recording_file , base_folder = output_dir
193
+ )
188
194
189
195
# Run sorting
190
196
# Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
@@ -241,8 +247,12 @@ def make(self, key):
241
247
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
242
248
sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
243
249
244
- si_recording : si .BaseRecording = si .load_extractor (recording_file , base_folder = output_dir )
245
- si_sorting : si .sorters .BaseSorter = si .load_extractor (sorting_file , base_folder = output_dir )
250
+ si_recording : si .BaseRecording = si .load_extractor (
251
+ recording_file , base_folder = output_dir
252
+ )
253
+ si_sorting : si .sorters .BaseSorter = si .load_extractor (
254
+ sorting_file , base_folder = output_dir
255
+ )
246
256
247
257
# Extract waveforms
248
258
we : si .WaveformExtractor = si .extract_waveforms (
@@ -257,37 +267,61 @@ def make(self, key):
257
267
** params .get ("SI_JOB_KWARGS" , {"n_jobs" : - 1 , "chunk_size" : 30000 }),
258
268
)
259
269
260
- # Calculate QC Metrics
261
- metrics : pd .DataFrame = si .qualitymetrics .compute_quality_metrics (
262
- we ,
263
- metric_names = [
264
- "firing_rate" ,
265
- "snr" ,
266
- "presence_ratio" ,
267
- "isi_violation" ,
268
- "num_spikes" ,
269
- "amplitude_cutoff" ,
270
- "amplitude_median" ,
271
- "sliding_rp_violation" ,
272
- "rp_violation" ,
273
- "drift" ,
274
- ],
275
- )
276
- # Add PCA based metrics. These will be added to the metrics dataframe above.
270
+ # Calculate Cluster and Waveform Metrics
271
+
272
+ # To provide waveform_principal_component
277
273
_ = si .postprocessing .compute_principal_components (
278
274
waveform_extractor = we , ** params .get ("SI_QUALITY_METRICS_PARAMS" , None )
279
275
)
280
- metrics = si .qualitymetrics .compute_quality_metrics (waveform_extractor = we )
276
+
277
+ # To estimate the location of each spike in the sorting output.
278
+ # The drift metrics require the `spike_locations` waveform extension.
279
+ _ = si .postprocessing .compute_spike_locations (waveform_extractor = we )
280
+
281
+ # The `sd_ratio` metric requires the `spike_amplitudes` waveform extension.
282
+ # It is highly recommended before calculating amplitude-based quality metrics.
283
+ _ = si .postprocessing .compute_spike_amplitudes (waveform_extractor = we )
284
+
285
+ # To compute correlograms for spike trains.
286
+ _ = si .postprocessing .compute_correlograms (we )
287
+
288
+ metric_names = si .qualitymetrics .get_quality_metric_list ()
289
+ metric_names .extend (si .qualitymetrics .get_quality_pca_metric_list ())
290
+
291
+ # To compute commonly used cluster quality metrics.
292
+ qc_metrics = si .qualitymetrics .compute_quality_metrics (
293
+ waveform_extractor = we ,
294
+ metric_names = metric_names ,
295
+ )
296
+
297
+ # To compute commonly used waveform/template metrics.
298
+ template_metric_names = si .postprocessing .get_template_metric_names ()
299
+ template_metric_names .extend (["amplitude" , "duration" ])
300
+
301
+ template_metrics = si .postprocessing .compute_template_metrics (
302
+ waveform_extractor = we ,
303
+ include_multi_channel_metrics = True ,
304
+ metric_names = template_metric_names ,
305
+ )
306
+
307
+ # Save the output (metrics.csv to the output dir)
308
+ metrics = pd .DataFrame ()
309
+ metrics = pd .concat ([qc_metrics , template_metrics ], axis = 1 )
281
310
282
311
# Save the output (metrics.csv to the output dir)
283
312
metrics_output_dir = output_dir / sorter_name / "metrics"
284
313
metrics_output_dir .mkdir (parents = True , exist_ok = True )
285
314
metrics .to_csv (metrics_output_dir / "metrics.csv" )
286
315
287
316
# Save to phy format
288
- si .exporters .export_to_phy (waveform_extractor = we , output_folder = output_dir / sorter_name / "phy" )
317
+ si .exporters .export_to_phy (
318
+ waveform_extractor = we , output_folder = output_dir / sorter_name / "phy"
319
+ )
289
320
# Generate spike interface report
290
- si .exporters .export_report (waveform_extractor = we , output_folder = output_dir / sorter_name / "spikeinterface_report" )
321
+ si .exporters .export_report (
322
+ waveform_extractor = we ,
323
+ output_folder = output_dir / sorter_name / "spikeinterface_report" ,
324
+ )
291
325
292
326
self .insert1 (
293
327
{
0 commit comments