8
8
import pandas as pd
9
9
import spikeinterface as si
10
10
from element_array_ephys import probe , readers
11
- from element_interface .utils import find_full_path
11
+ from element_interface .utils import find_full_path , memoized_result
12
12
from spikeinterface import exporters , postprocessing , qualitymetrics , sorters
13
13
14
14
from . import si_preprocessing
@@ -192,23 +192,29 @@ def make(self, key):
192
192
recording_file , base_folder = output_dir
193
193
)
194
194
195
+ sorting_params = params ["SI_SORTING_PARAMS" ]
196
+ sorting_output_dir = output_dir / sorter_name / "spike_sorting"
197
+
195
198
# Run sorting
196
- # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
197
- si_sorting : si .sorters .BaseSorter = si .sorters .run_sorter (
198
- sorter_name = sorter_name ,
199
- recording = si_recording ,
200
- output_folder = output_dir / sorter_name / "spike_sorting" ,
201
- remove_existing_folder = True ,
202
- verbose = True ,
203
- docker_image = sorter_name not in si .sorters .installed_sorters (),
204
- ** params .get ("SI_SORTING_PARAMS" , {}),
199
+ @memoized_result (
200
+ uniqueness_dict = sorting_params ,
201
+ output_directory = sorting_output_dir ,
205
202
)
203
+ def _run_sorter ():
204
+ # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
205
+ si_sorting : si .sorters .BaseSorter = si .sorters .run_sorter (
206
+ sorter_name = sorter_name ,
207
+ recording = si_recording ,
208
+ output_folder = sorting_output_dir ,
209
+ remove_existing_folder = True ,
210
+ verbose = True ,
211
+ docker_image = sorter_name not in si .sorters .installed_sorters (),
212
+ ** sorting_params ,
213
+ )
206
214
207
- # Save sorting object
208
- sorting_save_path = (
209
- output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
210
- )
211
- si_sorting .dump_to_pickle (sorting_save_path , relative_to = output_dir )
215
+ # Save sorting object
216
+ sorting_save_path = sorting_output_dir / "si_sorting.pkl"
217
+ si_sorting .dump_to_pickle (sorting_save_path , relative_to = output_dir )
212
218
213
219
self .insert1 (
214
220
{
@@ -254,15 +260,20 @@ def make(self, key):
254
260
sorting_file , base_folder = output_dir
255
261
)
256
262
257
- job_kwargs = params ["SI_POSTPROCESSING_PARAMS" ].get (
263
+ postprocessing_params = params ["SI_POSTPROCESSING_PARAMS" ]
264
+
265
+ job_kwargs = postprocessing_params .get (
258
266
"job_kwargs" , {"n_jobs" : - 1 , "chunk_duration" : "1s" }
259
267
)
260
268
261
- # Sorting Analyzer
262
269
analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
263
- if (analyzer_output_dir / "extensions" ).exists ():
264
- sorting_analyzer = si .load_sorting_analyzer (folder = analyzer_output_dir )
265
- else :
270
+
271
+ @memoized_result (
272
+ uniqueness_dict = postprocessing_params ,
273
+ output_directory = analyzer_output_dir ,
274
+ )
275
+ def _sorting_analyzer_compute ():
276
+ # Sorting Analyzer
266
277
sorting_analyzer = si .create_sorting_analyzer (
267
278
sorting = si_sorting ,
268
279
recording = si_recording ,
@@ -273,31 +284,33 @@ def make(self, key):
273
284
** job_kwargs
274
285
)
275
286
276
- # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
277
- # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
278
- extensions_params = params ["SI_POSTPROCESSING_PARAMS" ].get ("extensions" , {})
279
- extensions_to_compute = {
280
- ext_name : extensions_params [ext_name ]
281
- for ext_name in sorting_analyzer .get_computable_extensions ()
282
- if ext_name in extensions_params
283
- }
284
-
285
- sorting_analyzer .compute (extensions_to_compute , ** job_kwargs )
286
-
287
- # Save to phy format
288
- if params ["SI_POSTPROCESSING_PARAMS" ].get ("export_to_phy" , False ):
289
- si .exporters .export_to_phy (
290
- sorting_analyzer = sorting_analyzer ,
291
- output_folder = output_dir / sorter_name / "phy" ,
292
- ** job_kwargs ,
293
- )
294
- # Generate spike interface report
295
- if params ["SI_POSTPROCESSING_PARAMS" ].get ("export_report" , True ):
296
- si .exporters .export_report (
297
- sorting_analyzer = sorting_analyzer ,
298
- output_folder = output_dir / sorter_name / "spikeinterface_report" ,
299
- ** job_kwargs ,
300
- )
287
+ # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
288
+ # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
289
+ extensions_params = postprocessing_params .get ("extensions" , {})
290
+ extensions_to_compute = {
291
+ ext_name : extensions_params [ext_name ]
292
+ for ext_name in sorting_analyzer .get_computable_extensions ()
293
+ if ext_name in extensions_params
294
+ }
295
+
296
+ sorting_analyzer .compute (extensions_to_compute , ** job_kwargs )
297
+
298
+ # Save to phy format
299
+ if postprocessing_params .get ("export_to_phy" , False ):
300
+ si .exporters .export_to_phy (
301
+ sorting_analyzer = sorting_analyzer ,
302
+ output_folder = analyzer_output_dir / "phy" ,
303
+ ** job_kwargs ,
304
+ )
305
+ # Generate spike interface report
306
+ if postprocessing_params .get ("export_report" , True ):
307
+ si .exporters .export_report (
308
+ sorting_analyzer = sorting_analyzer ,
309
+ output_folder = analyzer_output_dir / "spikeinterface_report" ,
310
+ ** job_kwargs ,
311
+ )
312
+
313
+ _sorting_analyzer_compute ()
301
314
302
315
self .insert1 (
303
316
{
0 commit comments