Commit 1d30cb8

Merge pull request #142 from ttngu207/main
Update kilosort_triggering.py
2 parents 47dea95 + 5e1f055 commit 1d30cb8

File tree: 2 files changed, +108 -69 lines changed


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -3,6 +3,10 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [Unreleased] - 2023-06-23
+
++ Update - Improve kilosort triggering routine - better logging, remove temporary files, robust resumable processing
+
 ## [0.2.10] - 2023-05-26
 
 + Add - Kilosort, NWB, and DANDI citations
```
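
The "robust resumable processing" called out in this entry is visible in the diff below: `run_modules` persists per-module `start_time`/`completion_time` records under `json_configs`, keyed by a hash of the input parameters (`dict_to_uuid`), and skips any module already marked complete. A minimal, self-contained sketch of that pattern follows; the status-file name and the `params_hash` and `run_one_module` helpers are illustrative stand-ins, not the element's actual API.

```python
import hashlib
import json
import pathlib
from datetime import datetime


def params_hash(params: dict) -> str:
    """Stable hash of the sorting parameters (stand-in for dict_to_uuid)."""
    return hashlib.md5(json.dumps(params, sort_keys=True).encode()).hexdigest()


def run_resumable(modules, params, json_dir: pathlib.Path, run_one_module):
    """Run each module once, skipping those already completed in a prior run."""
    status_fp = json_dir / f"{params_hash(params)}-module_status.json"
    if status_fp.exists():
        status = json.loads(status_fp.read_text())
    else:
        status = {m: {"start_time": None, "completion_time": None} for m in modules}

    for module in modules:
        if status[module]["completion_time"] is not None:
            continue  # finished in a previous run - resume past it
        status[module]["start_time"] = str(datetime.now())
        run_one_module(module)  # e.g. invoke an ecephys_spike_sorting module
        status[module]["completion_time"] = str(datetime.now())
        status_fp.write_text(json.dumps(status, indent=2))  # persist after each module
```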

element_array_ephys/readers/kilosort_triggering.py

Lines changed: 104 additions & 69 deletions
```diff
@@ -20,13 +20,13 @@
     from ecephys_spike_sorting.scripts.create_input_json import createInputJson
     from ecephys_spike_sorting.scripts.helpers import SpikeGLX_utils
 except Exception as e:
-    print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}')
+    print(f'Warning: Failed loading "ecephys_spike_sorting" package - {str(e)}')
 
 # import pykilosort package
 try:
     import pykilosort
 except Exception as e:
-    print(f'Error in loading "pykilosort" package - {str(e)}')
+    print(f'Warning: Failed loading "pykilosort" package - {str(e)}')
 
 
 class SGLXKilosortPipeline:
@@ -67,7 +67,6 @@ def __init__(
         ni_present=False,
         ni_extract_string=None,
     ):
-
         self._npx_input_dir = pathlib.Path(npx_input_dir)
 
         self._ks_output_dir = pathlib.Path(ks_output_dir)
@@ -85,6 +84,13 @@ def __init__(
         self._json_directory = self._ks_output_dir / "json_configs"
         self._json_directory.mkdir(parents=True, exist_ok=True)
 
+        self._module_input_json = (
+            self._json_directory / f"{self._npx_input_dir.name}-input.json"
+        )
+        self._module_logfile = (
+            self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
+        )
+
         self._CatGT_finished = False
         self.ks_input_params = None
         self._modules_input_hash = None
@@ -223,20 +229,20 @@ def generate_modules_input_json(self):
             **params,
         )
 
-        self._modules_input_hash = dict_to_uuid(self.ks_input_params)
+        self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))
 
-    def run_modules(self):
+    def run_modules(self, modules_to_run=None):
         if self._run_CatGT and not self._CatGT_finished:
             self.run_CatGT()
 
         print("---- Running Modules ----")
         self.generate_modules_input_json()
         module_input_json = self._module_input_json.as_posix()
-        module_logfile = module_input_json.replace(
-            "-input.json", "-run_modules-log.txt"
-        )
+        module_logfile = self._module_logfile.as_posix()
+
+        modules = modules_to_run or self._modules
 
-        for module in self._modules:
+        for module in modules:
             module_status = self._get_module_status(module)
             if module_status["completion_time"] is not None:
                 continue
@@ -312,13 +318,11 @@ def _update_module_status(self, updated_module_status={}):
         else:
             # handle cases of processing rerun on different parameters (the hash changes)
             # delete outdated files
-            outdated_files = [
-                f
+            [
+                f.unlink()
                 for f in self._json_directory.glob("*")
                 if f.is_file() and f.name != self._module_input_json.name
             ]
-            for f in outdated_files:
-                f.unlink()
 
         modules_status = {
             module: {"start_time": None, "completion_time": None, "duration": None}
@@ -371,14 +375,26 @@ def _update_total_duration(self):
             for k, v in modules_status.items()
             if k not in ("cumulative_execution_duration", "total_duration")
         )
+
+        for m in self._modules:
+            first_start_time = modules_status[m]["start_time"]
+            if first_start_time is not None:
+                break
+
+        for m in self._modules[::-1]:
+            last_completion_time = modules_status[m]["completion_time"]
+            if last_completion_time is not None:
+                break
+
+        if first_start_time is None or last_completion_time is None:
+            return
+
         total_duration = (
             datetime.strptime(
-                modules_status[self._modules[-1]]["completion_time"],
+                last_completion_time,
                 "%Y-%m-%d %H:%M:%S.%f",
             )
-            - datetime.strptime(
-                modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-            )
+            - datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
         ).total_seconds()
         self._update_module_status(
             {
@@ -414,7 +430,6 @@ class OpenEphysKilosortPipeline:
     def __init__(
         self, npx_input_dir: str, ks_output_dir: str, params: dict, KS2ver: str
     ):
-
         self._npx_input_dir = pathlib.Path(npx_input_dir)
 
         self._ks_output_dir = pathlib.Path(ks_output_dir)
@@ -426,7 +441,13 @@ def __init__(
         self._json_directory = self._ks_output_dir / "json_configs"
         self._json_directory.mkdir(parents=True, exist_ok=True)
 
-        self._median_subtraction_status = {}
+        self._module_input_json = (
+            self._json_directory / f"{self._npx_input_dir.name}-input.json"
+        )
+        self._module_logfile = (
+            self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
+        )
+
         self.ks_input_params = None
         self._modules_input_hash = None
         self._modules_input_hash_fp = None
@@ -451,9 +472,6 @@ def make_chanmap_file(self):
 
     def generate_modules_input_json(self):
        self.make_chanmap_file()
-        self._module_input_json = (
-            self._json_directory / f"{self._npx_input_dir.name}-input.json"
-        )
 
         continuous_file = self._get_raw_data_filepaths()
 
@@ -497,35 +515,37 @@ def generate_modules_input_json(self):
             **params,
         )
 
-        self._modules_input_hash = dict_to_uuid(self.ks_input_params)
+        self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))
 
-    def run_modules(self):
+    def run_modules(self, modules_to_run=None):
         print("---- Running Modules ----")
         self.generate_modules_input_json()
         module_input_json = self._module_input_json.as_posix()
-        module_logfile = module_input_json.replace(
-            "-input.json", "-run_modules-log.txt"
-        )
+        module_logfile = self._module_logfile.as_posix()
 
-        for module in self._modules:
+        modules = modules_to_run or self._modules
+
+        for module in modules:
             module_status = self._get_module_status(module)
             if module_status["completion_time"] is not None:
                 continue
 
-            if module == "median_subtraction" and self._median_subtraction_status:
-                median_subtraction_status = self._get_module_status(
-                    "median_subtraction"
-                )
-                median_subtraction_status["duration"] = self._median_subtraction_status[
-                    "duration"
-                ]
-                median_subtraction_status["completion_time"] = datetime.strptime(
-                    median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-                ) + timedelta(seconds=median_subtraction_status["duration"])
-                self._update_module_status(
-                    {"median_subtraction": median_subtraction_status}
+            if module == "median_subtraction":
+                median_subtraction_duration = (
+                    self._get_median_subtraction_duration_from_log()
                 )
-                continue
+                if median_subtraction_duration is not None:
+                    median_subtraction_status = self._get_module_status(
+                        "median_subtraction"
+                    )
+                    median_subtraction_status["duration"] = median_subtraction_duration
+                    median_subtraction_status["completion_time"] = datetime.strptime(
+                        median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
+                    ) + timedelta(seconds=median_subtraction_status["duration"])
+                    self._update_module_status(
+                        {"median_subtraction": median_subtraction_status}
+                    )
+                    continue
 
             module_output_json = self._get_module_output_json_filename(module)
             command = [
@@ -576,26 +596,11 @@ def _get_raw_data_filepaths(self):
         assert "depth_estimation" in self._modules
         continuous_file = self._ks_output_dir / "continuous.dat"
         if continuous_file.exists():
-            if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
-                # if the copied continuous.dat was actually modified,
-                # median_subtraction may have been completed - let's check
-                module_input_json = self._module_input_json.as_posix()
-                module_logfile = module_input_json.replace(
-                    "-input.json", "-run_modules-log.txt"
-                )
-                with open(module_logfile, "r") as f:
-                    previous_line = ""
-                    for line in f.readlines():
-                        if line.startswith(
-                            "ecephys spike sorting: median subtraction module"
-                        ) and previous_line.startswith("Total processing time:"):
-                            # regex to search for the processing duration - a float value
-                            duration = int(
-                                re.search("\d+\.?\d+", previous_line).group()
-                            )
-                            self._median_subtraction_status["duration"] = duration
-                            return continuous_file
-                        previous_line = line
+            if raw_ap_fp.stat().st_mtime == continuous_file.stat().st_mtime:
+                return continuous_file
+            else:
+                if self._module_logfile.exists():
+                    return continuous_file
 
         shutil.copy2(raw_ap_fp, continuous_file)
         return continuous_file
@@ -614,13 +619,11 @@ def _update_module_status(self, updated_module_status={}):
         else:
             # handle cases of processing rerun on different parameters (the hash changes)
             # delete outdated files
-            outdated_files = [
-                f
+            [
+                f.unlink()
                 for f in self._json_directory.glob("*")
                 if f.is_file() and f.name != self._module_input_json.name
             ]
-            for f in outdated_files:
-                f.unlink()
 
         modules_status = {
             module: {"start_time": None, "completion_time": None, "duration": None}
@@ -673,14 +676,26 @@ def _update_total_duration(self):
             for k, v in modules_status.items()
             if k not in ("cumulative_execution_duration", "total_duration")
         )
+
+        for m in self._modules:
+            first_start_time = modules_status[m]["start_time"]
+            if first_start_time is not None:
+                break
+
+        for m in self._modules[::-1]:
+            last_completion_time = modules_status[m]["completion_time"]
+            if last_completion_time is not None:
+                break
+
+        if first_start_time is None or last_completion_time is None:
+            return
+
         total_duration = (
             datetime.strptime(
-                modules_status[self._modules[-1]]["completion_time"],
+                last_completion_time,
                 "%Y-%m-%d %H:%M:%S.%f",
             )
-            - datetime.strptime(
-                modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-            )
+            - datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
         ).total_seconds()
         self._update_module_status(
             {
@@ -689,6 +704,26 @@ def _update_total_duration(self):
             }
         )
 
+    def _get_median_subtraction_duration_from_log(self):
+        raw_ap_fp = self._npx_input_dir / "continuous.dat"
+        continuous_file = self._ks_output_dir / "continuous.dat"
+        if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
+            # if the copied continuous.dat was actually modified,
+            # median_subtraction may have been completed - let's check
+            if self._module_logfile.exists():
+                with open(self._module_logfile, "r") as f:
+                    previous_line = ""
+                    for line in f.readlines():
+                        if line.startswith(
+                            "ecephys spike sorting: median subtraction module"
+                        ) and previous_line.startswith("Total processing time:"):
+                            # regex to search for the processing duration - a float value
+                            duration = int(
+                                re.search("\d+\.?\d+", previous_line).group()
+                            )
+                            return duration
+                        previous_line = line
+
 
 def run_pykilosort(
     continuous_file,
```
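
A hedged usage sketch of the updated API: the constructor arguments match the `__init__` signature shown in the diff above, while the paths, parameter values, and module names passed here are placeholders for illustration only.

```python
from element_array_ephys.readers.kilosort_triggering import OpenEphysKilosortPipeline

# Example values only - substitute real session paths and Kilosort parameters.
pipeline = OpenEphysKilosortPipeline(
    npx_input_dir="/data/raw/probe0",
    ks_output_dir="/data/kilosort/probe0",
    params={"Th": [10, 4]},  # illustrative parameter dict
    KS2ver="2.5",
)

# Modules with a recorded completion_time are skipped, so rerunning after
# an interruption resumes where processing stopped.
pipeline.run_modules()

# Optionally rerun only a subset of modules (names are illustrative).
pipeline.run_modules(modules_to_run=["median_subtraction", "kilosort_helper"])
```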
