Skip to content

Commit 51e2ced

Browse files
committed
feat: add memoized_result on spike sorting
1 parent 9094754 commit 51e2ced

File tree

3 files changed

+60
-46
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
66
## [0.4.0] - 2024-05-28
77

88
+ Add - support for SpikeInterface version >= 0.101.0 (updated API)
9+
+ Add - feature for memoization of spike sorting results (prevent duplicated runs)
910

1011

1112
## [0.3.4] - 2024-03-22

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import spikeinterface as si
1010
from element_array_ephys import probe, readers
11-
from element_interface.utils import find_full_path
11+
from element_interface.utils import find_full_path, memoized_result
1212
from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
1313

1414
from . import si_preprocessing
@@ -192,23 +192,29 @@ def make(self, key):
192192
recording_file, base_folder=output_dir
193193
)
194194

195+
sorting_params = params["SI_SORTING_PARAMS"]
196+
sorting_output_dir = output_dir / sorter_name / "spike_sorting"
197+
195198
# Run sorting
196-
# Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
197-
si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
198-
sorter_name=sorter_name,
199-
recording=si_recording,
200-
output_folder=output_dir / sorter_name / "spike_sorting",
201-
remove_existing_folder=True,
202-
verbose=True,
203-
docker_image=sorter_name not in si.sorters.installed_sorters(),
204-
**params.get("SI_SORTING_PARAMS", {}),
199+
@memoized_result(
200+
uniqueness_dict=sorting_params,
201+
output_directory=sorting_output_dir,
205202
)
203+
def _run_sorter():
204+
# Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
205+
si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
206+
sorter_name=sorter_name,
207+
recording=si_recording,
208+
output_folder=sorting_output_dir,
209+
remove_existing_folder=True,
210+
verbose=True,
211+
docker_image=sorter_name not in si.sorters.installed_sorters(),
212+
**sorting_params,
213+
)
206214

207-
# Save sorting object
208-
sorting_save_path = (
209-
output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
210-
)
211-
si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
215+
# Save sorting object
216+
sorting_save_path = sorting_output_dir / "si_sorting.pkl"
217+
si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
212218

213219
self.insert1(
214220
{
@@ -254,15 +260,20 @@ def make(self, key):
254260
sorting_file, base_folder=output_dir
255261
)
256262

257-
job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
263+
postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
264+
265+
job_kwargs = postprocessing_params.get(
258266
"job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
259267
)
260268

261-
# Sorting Analyzer
262269
analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
263-
if (analyzer_output_dir / "extensions").exists():
264-
sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
265-
else:
270+
271+
@memoized_result(
272+
uniqueness_dict=postprocessing_params,
273+
output_directory=analyzer_output_dir,
274+
)
275+
def _sorting_analyzer_compute():
276+
# Sorting Analyzer
266277
sorting_analyzer = si.create_sorting_analyzer(
267278
sorting=si_sorting,
268279
recording=si_recording,
@@ -273,31 +284,33 @@ def make(self, key):
273284
**job_kwargs
274285
)
275286

276-
# The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
277-
# each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
278-
extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
279-
extensions_to_compute = {
280-
ext_name: extensions_params[ext_name]
281-
for ext_name in sorting_analyzer.get_computable_extensions()
282-
if ext_name in extensions_params
283-
}
284-
285-
sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
286-
287-
# Save to phy format
288-
if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False):
289-
si.exporters.export_to_phy(
290-
sorting_analyzer=sorting_analyzer,
291-
output_folder=output_dir / sorter_name / "phy",
292-
**job_kwargs,
293-
)
294-
# Generate spike interface report
295-
if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True):
296-
si.exporters.export_report(
297-
sorting_analyzer=sorting_analyzer,
298-
output_folder=output_dir / sorter_name / "spikeinterface_report",
299-
**job_kwargs,
300-
)
287+
# The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
288+
# each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
289+
extensions_params = postprocessing_params.get("extensions", {})
290+
extensions_to_compute = {
291+
ext_name: extensions_params[ext_name]
292+
for ext_name in sorting_analyzer.get_computable_extensions()
293+
if ext_name in extensions_params
294+
}
295+
296+
sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
297+
298+
# Save to phy format
299+
if postprocessing_params.get("export_to_phy", False):
300+
si.exporters.export_to_phy(
301+
sorting_analyzer=sorting_analyzer,
302+
output_folder=analyzer_output_dir / "phy",
303+
**job_kwargs,
304+
)
305+
# Generate spike interface report
306+
if postprocessing_params.get("export_report", True):
307+
si.exporters.export_report(
308+
sorting_analyzer=sorting_analyzer,
309+
output_folder=analyzer_output_dir / "spikeinterface_report",
310+
**job_kwargs,
311+
)
312+
313+
_sorting_analyzer_compute()
301314

302315
self.insert1(
303316
{

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"scikit-image",
4040
"nbformat>=4.2.0",
4141
"pyopenephys>=1.1.6",
42-
"element-interface @ git+https://github.com/datajoint/element-interface.git",
42+
"element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results",
4343
"numba",
4444
],
4545
extras_require={

0 commit comments

Comments (0)