     from ecephys_spike_sorting.scripts.create_input_json import createInputJson
     from ecephys_spike_sorting.scripts.helpers import SpikeGLX_utils
 except Exception as e:
-    print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}')
+    print(f'Warning: Failed loading "ecephys_spike_sorting" package - {str(e)}')

 # import pykilosort package
 try:
     import pykilosort
 except Exception as e:
-    print(f'Error in loading "pykilosort" package - {str(e)}')
+    print(f'Warning: Failed loading "pykilosort" package - {str(e)}')


 class SGLXKilosortPipeline:
@@ -67,7 +67,6 @@ def __init__(
         ni_present=False,
         ni_extract_string=None,
     ):
-
         self._npx_input_dir = pathlib.Path(npx_input_dir)

         self._ks_output_dir = pathlib.Path(ks_output_dir)
@@ -85,6 +84,13 @@ def __init__(
         self._json_directory = self._ks_output_dir / "json_configs"
         self._json_directory.mkdir(parents=True, exist_ok=True)

+        self._module_input_json = (
+            self._json_directory / f"{self._npx_input_dir.name}-input.json"
+        )
+        self._module_logfile = (
+            self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
+        )
+
         self._CatGT_finished = False
         self.ks_input_params = None
         self._modules_input_hash = None
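Both pipeline classes now compute the module-input JSON and run-modules log paths once in __init__, keyed on the probe directory name, instead of re-deriving the log name by string replacement at each use. A quick sketch of the resulting layout, with hypothetical directories:

    import pathlib

    npx_input_dir = pathlib.Path("/data/subject1/npx/imec0")       # hypothetical
    ks_output_dir = pathlib.Path("/data/subject1/kilosort/imec0")  # hypothetical

    json_directory = ks_output_dir / "json_configs"
    module_input_json = json_directory / f"{npx_input_dir.name}-input.json"
    module_logfile = json_directory / f"{npx_input_dir.name}-run_modules-log.txt"

    print(module_input_json)  # /data/subject1/kilosort/imec0/json_configs/imec0-input.json
    print(module_logfile)     # /data/subject1/kilosort/imec0/json_configs/imec0-run_modules-log.txt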
@@ -223,20 +229,20 @@ def generate_modules_input_json(self):
             **params,
         )

-        self._modules_input_hash = dict_to_uuid(self.ks_input_params)
+        self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))

-    def run_modules(self):
+    def run_modules(self, modules_to_run=None):
         if self._run_CatGT and not self._CatGT_finished:
             self.run_CatGT()

         print("---- Running Modules ----")
         self.generate_modules_input_json()
         module_input_json = self._module_input_json.as_posix()
-        module_logfile = module_input_json.replace(
-            "-input.json", "-run_modules-log.txt"
-        )
+        module_logfile = self._module_logfile.as_posix()
+
+        modules = modules_to_run or self._modules

-        for module in self._modules:
+        for module in modules:
             module_status = self._get_module_status(module)
             if module_status["completion_time"] is not None:
                 continue
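Two behavioral changes land here (mirrored below for the OpenEphys class). First, the parameter hash now covers only the user-supplied params plus the Kilosort version, rather than the fully expanded ks_input_params, presumably so that derived fields (absolute paths, per-session filenames) no longer invalidate the hash between otherwise identical runs. dict_to_uuid is the project's dict-hashing helper; a minimal sketch of what such a helper typically looks like (the project's actual implementation may differ):

    import hashlib
    import uuid

    def dict_to_uuid(key: dict) -> uuid.UUID:
        # Deterministically reduce a dict to a UUID by hashing its sorted items.
        hashed = hashlib.md5()
        for k, v in sorted(key.items()):
            hashed.update(str(k).encode())
            hashed.update(str(v).encode())
        return uuid.UUID(hex=hashed.hexdigest())

Second, run_modules accepts an optional modules_to_run, letting callers rerun a subset of modules (hypothetical usage; pipeline construction elided):

    pipeline.run_modules()                                    # all configured modules
    pipeline.run_modules(modules_to_run=["kilosort_helper"])  # targeted rerun

Note that "modules_to_run or self._modules" treats an empty list the same as None and falls back to the full module list.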
@@ -312,13 +318,11 @@ def _update_module_status(self, updated_module_status={}):
         else:
             # handle cases of processing rerun on different parameters (the hash changes)
             # delete outdated files
-            outdated_files = [
-                f
+            [
+                f.unlink()
                 for f in self._json_directory.glob("*")
                 if f.is_file() and f.name != self._module_input_json.name
             ]
-            for f in outdated_files:
-                f.unlink()

         modules_status = {
             module: {"start_time": None, "completion_time": None, "duration": None}
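When _update_module_status sees a changed parameter hash, it deletes every file in json_configs except the current module-input JSON, clearing stale module outputs, status files, and logs from earlier parameter sets. A self-contained sketch of the same filter, using throwaway files with hypothetical names:

    import pathlib
    import tempfile

    json_directory = pathlib.Path(tempfile.mkdtemp())
    module_input_json = json_directory / "imec0-input.json"  # hypothetical current config
    for name in (
        "imec0-input.json",
        "imec0-run_modules-log.txt",
        "kilosort_helper-output.json",
    ):
        (json_directory / name).touch()

    for f in json_directory.glob("*"):
        if f.is_file() and f.name != module_input_json.name:
            f.unlink()  # remove everything except the current input JSON

    print([f.name for f in json_directory.iterdir()])  # ['imec0-input.json']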
@@ -371,14 +375,26 @@ def _update_total_duration(self):
             for k, v in modules_status.items()
             if k not in ("cumulative_execution_duration", "total_duration")
         )
+
+        for m in self._modules:
+            first_start_time = modules_status[m]["start_time"]
+            if first_start_time is not None:
+                break
+
+        for m in self._modules[::-1]:
+            last_completion_time = modules_status[m]["completion_time"]
+            if last_completion_time is not None:
+                break
+
+        if first_start_time is None or last_completion_time is None:
+            return
+
         total_duration = (
             datetime.strptime(
-                modules_status[self._modules[-1]]["completion_time"],
+                last_completion_time,
                 "%Y-%m-%d %H:%M:%S.%f",
             )
-            - datetime.strptime(
-                modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-            )
+            - datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
         ).total_seconds()
         self._update_module_status(
             {
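The old total-duration computation indexed the first and last entries of self._modules directly, which breaks once modules_to_run allows partial reruns: skipped modules keep None timestamps. The new code scans forward for the earliest recorded start and backward for the latest recorded completion, returning early if neither exists. A runnable sketch of the scan with a hypothetical status dict:

    from datetime import datetime

    modules = ["depth_estimation", "median_subtraction", "kilosort_helper"]
    modules_status = {
        "depth_estimation": {"start_time": None, "completion_time": None},
        "median_subtraction": {"start_time": None, "completion_time": None},
        "kilosort_helper": {
            "start_time": "2023-01-01 10:00:00.000000",
            "completion_time": "2023-01-01 10:12:30.500000",
        },
    }

    for m in modules:
        first_start_time = modules_status[m]["start_time"]
        if first_start_time is not None:
            break

    for m in modules[::-1]:
        last_completion_time = modules_status[m]["completion_time"]
        if last_completion_time is not None:
            break

    fmt = "%Y-%m-%d %H:%M:%S.%f"
    if first_start_time is not None and last_completion_time is not None:
        total = (
            datetime.strptime(last_completion_time, fmt)
            - datetime.strptime(first_start_time, fmt)
        ).total_seconds()
        print(total)  # 750.5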
@@ -414,7 +430,6 @@ class OpenEphysKilosortPipeline:
     def __init__(
         self, npx_input_dir: str, ks_output_dir: str, params: dict, KS2ver: str
     ):
-
         self._npx_input_dir = pathlib.Path(npx_input_dir)

         self._ks_output_dir = pathlib.Path(ks_output_dir)
@@ -426,7 +441,13 @@ def __init__(
         self._json_directory = self._ks_output_dir / "json_configs"
         self._json_directory.mkdir(parents=True, exist_ok=True)

-        self._median_subtraction_status = {}
+        self._module_input_json = (
+            self._json_directory / f"{self._npx_input_dir.name}-input.json"
+        )
+        self._module_logfile = (
+            self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
+        )
+
         self.ks_input_params = None
         self._modules_input_hash = None
         self._modules_input_hash_fp = None
@@ -451,9 +472,6 @@ def make_chanmap_file(self):

     def generate_modules_input_json(self):
         self.make_chanmap_file()
-        self._module_input_json = (
-            self._json_directory / f"{self._npx_input_dir.name}-input.json"
-        )

         continuous_file = self._get_raw_data_filepaths()

@@ -497,35 +515,37 @@ def generate_modules_input_json(self):
             **params,
         )

-        self._modules_input_hash = dict_to_uuid(self.ks_input_params)
+        self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))

-    def run_modules(self):
+    def run_modules(self, modules_to_run=None):
         print("---- Running Modules ----")
         self.generate_modules_input_json()
         module_input_json = self._module_input_json.as_posix()
-        module_logfile = module_input_json.replace(
-            "-input.json", "-run_modules-log.txt"
-        )
+        module_logfile = self._module_logfile.as_posix()

-        for module in self._modules:
+        modules = modules_to_run or self._modules
+
+        for module in modules:
             module_status = self._get_module_status(module)
             if module_status["completion_time"] is not None:
                 continue

-            if module == "median_subtraction" and self._median_subtraction_status:
-                median_subtraction_status = self._get_module_status(
-                    "median_subtraction"
-                )
-                median_subtraction_status["duration"] = self._median_subtraction_status[
-                    "duration"
-                ]
-                median_subtraction_status["completion_time"] = datetime.strptime(
-                    median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-                ) + timedelta(seconds=median_subtraction_status["duration"])
-                self._update_module_status(
-                    {"median_subtraction": median_subtraction_status}
+            if module == "median_subtraction":
+                median_subtraction_duration = (
+                    self._get_median_subtraction_duration_from_log()
                 )
-                continue
+                if median_subtraction_duration is not None:
+                    median_subtraction_status = self._get_module_status(
+                        "median_subtraction"
+                    )
+                    median_subtraction_status["duration"] = median_subtraction_duration
+                    median_subtraction_status["completion_time"] = datetime.strptime(
+                        median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
+                    ) + timedelta(seconds=median_subtraction_status["duration"])
+                    self._update_module_status(
+                        {"median_subtraction": median_subtraction_status}
+                    )
+                    continue

             module_output_json = self._get_module_output_json_filename(module)
             command = [
@@ -576,26 +596,11 @@ def _get_raw_data_filepaths(self):
         assert "depth_estimation" in self._modules
         continuous_file = self._ks_output_dir / "continuous.dat"
         if continuous_file.exists():
-            if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
-                # if the copied continuous.dat was actually modified,
-                # median_subtraction may have been completed - let's check
-                module_input_json = self._module_input_json.as_posix()
-                module_logfile = module_input_json.replace(
-                    "-input.json", "-run_modules-log.txt"
-                )
-                with open(module_logfile, "r") as f:
-                    previous_line = ""
-                    for line in f.readlines():
-                        if line.startswith(
-                            "ecephys spike sorting: median subtraction module"
-                        ) and previous_line.startswith("Total processing time:"):
-                            # regex to search for the processing duration - a float value
-                            duration = int(
-                                re.search("\d+\.?\d+", previous_line).group()
-                            )
-                            self._median_subtraction_status["duration"] = duration
-                            return continuous_file
-                        previous_line = line
+            if raw_ap_fp.stat().st_mtime == continuous_file.stat().st_mtime:
+                return continuous_file
+            else:
+                if self._module_logfile.exists():
+                    return continuous_file

         shutil.copy2(raw_ap_fp, continuous_file)
         return continuous_file
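The simplified check leans on shutil.copy2 preserving file metadata: immediately after the copy, source and destination report the same st_mtime, so an equal mtime means continuous.dat is still an untouched copy of the raw file, while a later mtime plus an existing run log suggests median subtraction already rewrote it in place. A sketch of that invariant with throwaway files (the raw file is backdated so the comparison is deterministic regardless of filesystem timestamp granularity):

    import os
    import pathlib
    import shutil
    import tempfile
    import time

    tmp = pathlib.Path(tempfile.mkdtemp())
    raw_ap_fp = tmp / "raw_continuous.dat"  # hypothetical raw recording
    raw_ap_fp.write_bytes(b"\x00" * 1024)
    past = time.time() - 60
    os.utime(raw_ap_fp, (past, past))       # backdate the raw file

    continuous_file = tmp / "continuous.dat"
    shutil.copy2(raw_ap_fp, continuous_file)  # copy2 preserves mtime via copystat

    # Equal mtimes: the copy has not been modified since it was made.
    assert raw_ap_fp.stat().st_mtime == continuous_file.stat().st_mtime

    continuous_file.write_bytes(b"\x01" * 1024)  # simulate in-place median subtraction
    assert raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime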
@@ -614,13 +619,11 @@ def _update_module_status(self, updated_module_status={}):
         else:
             # handle cases of processing rerun on different parameters (the hash changes)
             # delete outdated files
-            outdated_files = [
-                f
+            [
+                f.unlink()
                 for f in self._json_directory.glob("*")
                 if f.is_file() and f.name != self._module_input_json.name
             ]
-            for f in outdated_files:
-                f.unlink()

         modules_status = {
             module: {"start_time": None, "completion_time": None, "duration": None}
@@ -673,14 +676,26 @@ def _update_total_duration(self):
             for k, v in modules_status.items()
             if k not in ("cumulative_execution_duration", "total_duration")
         )
+
+        for m in self._modules:
+            first_start_time = modules_status[m]["start_time"]
+            if first_start_time is not None:
+                break
+
+        for m in self._modules[::-1]:
+            last_completion_time = modules_status[m]["completion_time"]
+            if last_completion_time is not None:
+                break
+
+        if first_start_time is None or last_completion_time is None:
+            return
+
         total_duration = (
             datetime.strptime(
-                modules_status[self._modules[-1]]["completion_time"],
+                last_completion_time,
                 "%Y-%m-%d %H:%M:%S.%f",
             )
-            - datetime.strptime(
-                modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-            )
+            - datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
         ).total_seconds()
         self._update_module_status(
             {
@@ -689,6 +704,26 @@ def _update_total_duration(self):
             }
         )

+    def _get_median_subtraction_duration_from_log(self):
+        raw_ap_fp = self._npx_input_dir / "continuous.dat"
+        continuous_file = self._ks_output_dir / "continuous.dat"
+        if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
+            # if the copied continuous.dat was actually modified,
+            # median_subtraction may have been completed - let's check
+            if self._module_logfile.exists():
+                with open(self._module_logfile, "r") as f:
+                    previous_line = ""
+                    for line in f.readlines():
+                        if line.startswith(
+                            "ecephys spike sorting: median subtraction module"
+                        ) and previous_line.startswith("Total processing time:"):
+                            # regex to search for the processing duration - a float value
+                            duration = int(
+                                re.search("\d+\.?\d+", previous_line).group()
+                            )
+                            return duration
+                        previous_line = line
+

 def run_pykilosort(
     continuous_file,
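_get_median_subtraction_duration_from_log, factored out of _get_raw_data_filepaths, pairs each "Total processing time:" line with the module banner that follows it, which is why previous_line is inspected when the median-subtraction header is found. One caveat: the comment says the duration is a float, yet the match is passed to int(), which raises ValueError on a fractional string like "123.45"; float() and a raw-string pattern would be safer. A standalone sketch of the parsing, under an assumed two-line log layout:

    import re

    # Hypothetical excerpt: ecephys is assumed to print a module's total
    # processing time just before the next module's banner line.
    log_lines = [
        "Total processing time: 123.45 seconds",
        "ecephys spike sorting: median subtraction module",
    ]

    previous_line = ""
    duration = None
    for line in log_lines:
        if line.startswith(
            "ecephys spike sorting: median subtraction module"
        ) and previous_line.startswith("Total processing time:"):
            # float(), not int(): the logged duration may be fractional
            duration = float(re.search(r"\d+\.?\d+", previous_line).group())
        previous_line = line

    print(duration)  # 123.45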