 import inspect
 import uuid
 import hashlib
+import importlib
 from collections.abc import Mapping

 from .readers import neuropixels, kilosort
 from . import probe

 schema = dj.schema()

-context = locals()

-table_classes = (dj.Manual, dj.Lookup, dj.Imported, dj.Computed)
-
+
+required_upstream_tables = ("Session", "SkullReference")
+required_functions = ("get_neuropixels_data_directory", "get_paramset_idx", "get_kilosort_output_directory")

-def activate(ephys_schema_name, probe_schema_name=None, create_schema=True, create_tables=True, add_objects=None):
-    assert isinstance(add_objects, Mapping)
+_table_classes = (dj.Manual, dj.Lookup, dj.Imported, dj.Computed)
+_required_objects = {}

-    upstream_tables = ("Session", "SkullReference")
-    for name in upstream_tables:
-        assert name in add_objects, "Upstream table %s is required in ephys.activate(add_objects=...)" % name
-        table = add_objects[name]
+
+def activate(ephys_schema_name, probe_schema_name=None, create_schema=True, create_tables=True, ephys_requirement=None):
+    global _required_objects
+
+    if not isinstance(ephys_requirement, Mapping):
+        if isinstance(ephys_requirement, str):
+            ephys_requirement = importlib.import_module(ephys_requirement)
+
+        if inspect.ismodule(ephys_requirement):
+            ephys_requirement = {key: getattr(ephys_requirement, key) for key in dir(ephys_requirement)}
+        else:
+            raise ValueError("Argument 'ephys_requirement' must be a dictionary, a module's name or a module")
+
+    for name in required_upstream_tables:
+        assert name in ephys_requirement, "Upstream table %s is required in ephys.activate(ephys_requirement=...)" % name
+        table = ephys_requirement[name]
         if inspect.isclass(table):
             table = table()
-        assert isinstance(table, table_classes), "Upstream table %s must be a DataJoint table " \
-                                                 "object in ephys.activate(add_objects=...)" % name
+        assert isinstance(table, _table_classes), "Upstream table %s must be a DataJoint table " \
+                                                  "object in ephys.activate(ephys_requirement=...)" % name
+        _required_objects[name] = ephys_requirement[name]

-    required_functions = ("get_neuropixels_data_directory", "get_paramset_idx", "get_kilosort_output_directory")
     for name in required_functions:
-        assert name in add_objects, "Functions %s is required in ephys.activate(add_objects=...)" % name
-        assert inspect.isfunction(add_objects[name]), "%s must be a function in ephys.activate(add_objects=...)" % name
-        context.update(**{name: add_objects[name]})
+        assert name in ephys_requirement, "Functions %s is required in ephys.activate(ephys_requirement=...)" % name
+        assert inspect.isfunction(ephys_requirement[name]), "%s must be a function in ephys.activate(ephys_requirement=...)" % name
+        _required_objects[name] = ephys_requirement[name]

     # activate
     if probe.schema.database is not None:
         probe.schema.activate(probe_schema_name or ephys_schema_name,
                               create_schema=create_schema, create_tables=create_tables)

     schema.activate(ephys_schema_name, create_schema=create_schema,
-                    create_tables=create_tables, add_objects=add_objects)
+                    create_tables=create_tables, add_objects=_required_objects)


 # -------------- Functions required by the elements-ephys ---------------
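
A minimal usage sketch of the new activate() signature introduced above. The workflow-side names here (the package path "my_workflow.pipeline" and the objects it exposes) are illustrative assumptions, not part of this commit; the only requirement the code enforces is that ephys_requirement resolves to a mapping or module containing Session, SkullReference, and the three get_* functions.

# Hypothetical workflow code -- a sketch, not part of this commit.
from my_workflow import pipeline        # assumed workflow module exposing the required names
from elements_ephys import ephys        # import path assumed; adjust to how this module is packaged

# Option 1: pass the module itself (a dotted module name as a string works the same way)
ephys.activate('my_ephys_schema', probe_schema_name='my_probe_schema',
               ephys_requirement=pipeline)

# Option 2: pass an explicit mapping containing only the required objects
ephys.activate('my_ephys_schema',
               ephys_requirement={
                   'Session': pipeline.Session,
                   'SkullReference': pipeline.SkullReference,
                   'get_neuropixels_data_directory': pipeline.get_neuropixels_data_directory,
                   'get_paramset_idx': pipeline.get_paramset_idx,
                   'get_kilosort_output_directory': pipeline.get_kilosort_output_directory,
               })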
@@ -53,8 +66,7 @@ def get_neuropixels_data_directory(probe_insertion_key: dict) -> str:
     :param probe_insertion_key: a dictionary of one ProbeInsertion `key`
     :return: a string for full path to the resulting Neuropixels data directory
     """
-    assert set(ProbeInsertion().primary_key) <= set(probe_insertion_key)
-    raise NotImplementedError('Workflow module should define function: "get_neuropixels_data_directory"')
+    return _required_objects['get_neuropixels_data_directory'](probe_insertion_key)


 def get_kilosort_output_directory(clustering_task_key: dict) -> str:
@@ -63,8 +75,7 @@ def get_kilosort_output_directory(clustering_task_key: dict) -> str:
     :param clustering_task_key: a dictionary of one ClusteringTask `key`
     :return: a string for full path to the resulting Kilosort output directory
     """
-    assert set(EphysRecording().primary_key) <= set(clustering_task_key)
-    raise NotImplementedError('Workflow module should define function: "get_kilosort_output_directory"')
+    return _required_objects['get_kilosort_output_directory'](clustering_task_key)


 def get_paramset_idx(ephys_rec_key: dict) -> int:
@@ -73,8 +84,7 @@ def get_paramset_idx(ephys_rec_key: dict) -> int:
     :param ephys_rec_key: a dictionary of one EphysRecording `key`
     :return: int specifying the `paramset_idx`
     """
-    assert set(EphysRecording().primary_key) <= set(ephys_rec_key)
-    raise NotImplementedError('Workflow module should define function: get_paramset_idx')
+    return _required_objects['get_paramset_idx'](ephys_rec_key)


 # ----------------------------- Table declarations ----------------------
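
Each of the three functions changed above now simply forwards its key to the callable registered under the same name in _required_objects during activate(). For orientation, below is a sketch of a hypothetical workflow module ("my_workflow/pipeline.py") that would satisfy required_upstream_tables and required_functions; the schema name, primary-key attributes, and directory layout are assumptions for illustration only, not part of this commit.

# my_workflow/pipeline.py -- hypothetical workflow module consumed by ephys.activate().
import pathlib
import datajoint as dj

schema = dj.schema('my_workflow')


@schema
class SkullReference(dj.Lookup):
    definition = """
    skull_reference: varchar(60)
    """
    contents = zip(['Bregma', 'Lambda'])


@schema
class Session(dj.Manual):
    definition = """
    subject: varchar(32)
    session_datetime: datetime
    """


def get_neuropixels_data_directory(probe_insertion_key: dict) -> str:
    # Assumed layout: one folder per subject and insertion under a fixed root;
    # 'subject' and 'insertion_number' are assumed key attributes.
    root = pathlib.Path('/data/neuropixels')
    return str(root / probe_insertion_key['subject'] / f"insertion{probe_insertion_key['insertion_number']}")


def get_kilosort_output_directory(clustering_task_key: dict) -> str:
    # Assume Kilosort results live in a 'kilosort_output' subfolder of the raw data directory.
    return str(pathlib.Path(get_neuropixels_data_directory(clustering_task_key), 'kilosort_output'))


def get_paramset_idx(ephys_rec_key: dict) -> int:
    # Assume a single default clustering parameter set for every recording.
    return 0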