Skip to content

Commit c9e7f1e

Browse files
author
Thinh Nguyen
committed
implement new "activation" mechanism -> using dict, module name or module as 'requirement'
1 parent 852f5a4 commit c9e7f1e

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

elements_ephys/ephys.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,56 @@
55
import inspect
66
import uuid
77
import hashlib
8+
import importlib
89
from collections.abc import Mapping
910

1011
from .readers import neuropixels, kilosort
1112
from . import probe
1213

1314
schema = dj.schema()
1415

15-
context = locals()
1616

17-
table_classes = (dj.Manual, dj.Lookup, dj.Imported, dj.Computed)
1817

18+
required_upstream_tables = ("Session", "SkullReference")
19+
required_functions = ("get_neuropixels_data_directory", "get_paramset_idx", "get_kilosort_output_directory")
1920

20-
def activate(ephys_schema_name, probe_schema_name=None, create_schema=True, create_tables=True, add_objects=None):
21-
assert isinstance(add_objects, Mapping)
21+
_table_classes = (dj.Manual, dj.Lookup, dj.Imported, dj.Computed)
22+
_required_objects = {}
2223

23-
upstream_tables = ("Session", "SkullReference")
24-
for name in upstream_tables:
25-
assert name in add_objects, "Upstream table %s is required in ephys.activate(add_objects=...)" % name
26-
table = add_objects[name]
24+
25+
def activate(ephys_schema_name, probe_schema_name=None, create_schema=True, create_tables=True, ephys_requirement=None):
26+
global _required_objects
27+
28+
if not isinstance(ephys_requirement, Mapping):
29+
if isinstance(ephys_requirement, str):
30+
ephys_requirement = importlib.import_module(ephys_requirement)
31+
32+
if inspect.ismodule(ephys_requirement):
33+
ephys_requirement = {key: getattr(ephys_requirement, key) for key in dir(ephys_requirement)}
34+
else:
35+
raise ValueError("Argument 'ephys_requirement' must be a dictionary, a module's name or a module")
36+
37+
for name in required_upstream_tables:
38+
assert name in ephys_requirement, "Upstream table %s is required in ephys.activate(ephys_requirement=...)" % name
39+
table = ephys_requirement[name]
2740
if inspect.isclass(table):
2841
table = table()
29-
assert isinstance(table, table_classes), "Upstream table %s must be a DataJoint table " \
30-
"object in ephys.activate(add_objects=...)" % name
42+
assert isinstance(table, _table_classes), "Upstream table %s must be a DataJoint table " \
43+
"object in ephys.activate(ephys_requirement=...)" % name
44+
_required_objects[name] = ephys_requirement[name]
3145

32-
required_functions = ("get_neuropixels_data_directory", "get_paramset_idx", "get_kilosort_output_directory")
3346
for name in required_functions:
34-
assert name in add_objects, "Functions %s is required in ephys.activate(add_objects=...)" % name
35-
assert inspect.isfunction(add_objects[name]), "%s must be a function in ephys.activate(add_objects=...)" % name
36-
context.update(**{name: add_objects[name]})
47+
assert name in ephys_requirement, "Functions %s is required in ephys.activate(ephys_requirement=...)" % name
48+
assert inspect.isfunction(ephys_requirement[name]), "%s must be a function in ephys.activate(ephys_requirement=...)" % name
49+
_required_objects[name] = ephys_requirement[name]
3750

3851
# activate
3952
if probe.schema.database is not None:
4053
probe.schema.activate(probe_schema_name or ephys_schema_name,
4154
create_schema=create_schema, create_tables=create_tables)
4255

4356
schema.activate(ephys_schema_name, create_schema=create_schema,
44-
create_tables=create_tables, add_objects=add_objects)
57+
create_tables=create_tables, add_objects=_required_objects)
4558

4659

4760
# -------------- Functions required by the elements-ephys ---------------
@@ -53,8 +66,7 @@ def get_neuropixels_data_directory(probe_insertion_key: dict) -> str:
5366
:param probe_insertion_key: a dictionary of one ProbeInsertion `key`
5467
:return: a string for full path to the resulting Neuropixels data directory
5568
"""
56-
assert set(ProbeInsertion().primary_key) <= set(probe_insertion_key)
57-
raise NotImplementedError('Workflow module should define function: "get_neuropixels_data_directory"')
69+
return _required_objects['get_neuropixels_data_directory'](probe_insertion_key)
5870

5971

6072
def get_kilosort_output_directory(clustering_task_key: dict) -> str:
@@ -63,8 +75,7 @@ def get_kilosort_output_directory(clustering_task_key: dict) -> str:
6375
:param clustering_task_key: a dictionary of one ClusteringTask `key`
6476
:return: a string for full path to the resulting Kilosort output directory
6577
"""
66-
assert set(EphysRecording().primary_key) <= set(clustering_task_key)
67-
raise NotImplementedError('Workflow module should define function: "get_kilosort_output_directory"')
78+
return _required_objects['get_kilosort_output_directory'](clustering_task_key)
6879

6980

7081
def get_paramset_idx(ephys_rec_key: dict) -> int:
@@ -73,8 +84,7 @@ def get_paramset_idx(ephys_rec_key: dict) -> int:
7384
:param ephys_rec_key: a dictionary of one EphysRecording `key`
7485
:return: int specifying the `paramset_idx`
7586
"""
76-
assert set(EphysRecording().primary_key) <= set(ephys_rec_key)
77-
raise NotImplementedError('Workflow module should define function: get_paramset_idx')
87+
return _required_objects['get_paramset_idx'](ephys_rec_key)
7888

7989

8090
# ----------------------------- Table declarations ----------------------

0 commit comments

Comments
 (0)