Skip to content

Commit 5a31632

Browse files
authored
Merge pull request #195 from ttngu207/dev_spikeinterface_v101
feat: separate `export` (phy and report) into a separate table
2 parents 82d73c9 + a4a8380 commit 5a31632

File tree

1 file changed

+78
-16
lines changed

1 file changed

+78
-16
lines changed

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ class PostProcessing(dj.Imported):
239239
---
240240
execution_time: datetime # datetime of the start of this step
241241
execution_duration: float # execution duration in hours
242+
do_si_export=0: bool # whether to export to phy
242243
"""
243244

244245
def make(self, key):
@@ -295,22 +296,6 @@ def _sorting_analyzer_compute():
295296

296297
sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
297298

298-
# Save to phy format
299-
if postprocessing_params.get("export_to_phy", False):
300-
si.exporters.export_to_phy(
301-
sorting_analyzer=sorting_analyzer,
302-
output_folder=analyzer_output_dir / "phy",
303-
use_relative_path=True,
304-
**job_kwargs,
305-
)
306-
# Generate spike interface report
307-
if postprocessing_params.get("export_report", True):
308-
si.exporters.export_report(
309-
sorting_analyzer=sorting_analyzer,
310-
output_folder=analyzer_output_dir / "spikeinterface_report",
311-
**job_kwargs,
312-
)
313-
314299
_sorting_analyzer_compute()
315300

316301
self.insert1(
@@ -321,10 +306,87 @@ def _sorting_analyzer_compute():
321306
datetime.utcnow() - execution_time
322307
).total_seconds()
323308
/ 3600,
309+
"do_si_export": postprocessing_params.get("export_to_phy", False)
310+
or postprocessing_params.get("export_report", False),
324311
}
325312
)
326313

327314
# Once finished, insert this `key` into ephys.Clustering
328315
ephys.Clustering.insert1(
329316
{**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True
330317
)
318+
319+
320+
@schema
321+
class SIExport(dj.Computed):
322+
"""A SpikeInterface export report and to Phy"""
323+
324+
definition = """
325+
-> PostProcessing
326+
---
327+
execution_time: datetime
328+
execution_duration: float
329+
"""
330+
331+
@property
332+
def key_source(self):
333+
return PostProcessing & "do_si_export = 1"
334+
335+
def make(self, key):
336+
execution_time = datetime.utcnow()
337+
338+
clustering_method, output_dir, params = (
339+
ephys.ClusteringTask * ephys.ClusteringParamSet & key
340+
).fetch1("clustering_method", "clustering_output_dir", "params")
341+
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
342+
sorter_name = clustering_method.replace(".", "_")
343+
344+
postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
345+
346+
job_kwargs = postprocessing_params.get(
347+
"job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
348+
)
349+
350+
analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
351+
sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
352+
353+
@memoized_result(
354+
uniqueness_dict=postprocessing_params,
355+
output_directory=analyzer_output_dir / "phy",
356+
)
357+
def _export_to_phy():
358+
# Save to phy format
359+
si.exporters.export_to_phy(
360+
sorting_analyzer=sorting_analyzer,
361+
output_folder=analyzer_output_dir / "phy",
362+
use_relative_path=True,
363+
**job_kwargs,
364+
)
365+
366+
@memoized_result(
367+
uniqueness_dict=postprocessing_params,
368+
output_directory=analyzer_output_dir / "spikeinterface_report",
369+
)
370+
def _export_report():
371+
# Generate spike interface report
372+
si.exporters.export_report(
373+
sorting_analyzer=sorting_analyzer,
374+
output_folder=analyzer_output_dir / "spikeinterface_report",
375+
**job_kwargs,
376+
)
377+
378+
if postprocessing_params.get("export_report", False):
379+
_export_report()
380+
if postprocessing_params.get("export_to_phy", False):
381+
_export_to_phy()
382+
383+
self.insert1(
384+
{
385+
**key,
386+
"execution_time": execution_time,
387+
"execution_duration": (
388+
datetime.utcnow() - execution_time
389+
).total_seconds()
390+
/ 3600,
391+
}
392+
)

0 commit comments

Comments
 (0)