Skip to content

Commit a925450

Browse files
update to comply with datajoint 0.13 deferred schema use
1 parent d1decf2 commit a925450

File tree

9 files changed

+143
-110
lines changed

9 files changed

+143
-110
lines changed

djephys/utils.py

Lines changed: 0 additions & 12 deletions
This file was deleted.
File renamed without changes.

djephys/ephys.py renamed to elements_ephys/ephys.py

Lines changed: 104 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,90 @@
1+
import datajoint as dj
12
import pathlib
23
import re
34
import numpy as np
4-
import datajoint as dj
5+
import inspect
56
import uuid
7+
import hashlib
8+
from collections.abc import Mapping
9+
10+
from .readers import neuropixels, kilosort
11+
from .probe import schema as probe_schema, ProbeType, ElectrodeConfig
12+
13+
14+
schema = dj.schema()
15+
16+
17+
def activate(ephys_schema_name, probe_schema_name=None, create_schema=True, create_tables=True, add_objects=None):
18+
upstream_tables = ("Session", "SkullReference")
19+
required_functions = ("get_neuropixels_data_directory", "get_paramset_idx", "get_kilosort_output_directory")
20+
assert isinstance(add_objects, Mapping)
21+
try:
22+
raise next(RuntimeError("Table %s is required for module ephys" % name)
23+
for name in upstream_tables
24+
if not isinstance(add_objects.get(name, None), (dj.Manual, dj.Lookup, dj.Imported, dj.Computed)))
25+
except StopIteration:
26+
pass # all ok
27+
28+
assert isinstance(add_objects, Mapping)
29+
try:
30+
raise next(RuntimeError("Function %s is required for module ephys" % name)
31+
for name in required_functions
32+
if not inspect.isfunction(add_objects.get(name, None)))
33+
except StopIteration:
34+
pass # all ok
35+
36+
if not probe_schema.is_activated:
37+
probe_schema.activate(probe_schema_name or ephys_schema_name,
38+
create_schema=create_schema, create_tables=create_tables)
39+
schema.activate(ephys_schema_name, create_schema=create_schema,
40+
create_tables=create_tables, add_objects=add_objects)
641

7-
from . import utils
8-
from .probe import schema, Probe, ProbeType, ElectrodeConfig
9-
from ephys_loaders import neuropixels, kilosort
1042

11-
from djutils.templates import required
43+
# REQUIREMENTS: The workflow module must define these functions ---------------
44+
45+
46+
def get_neuropixels_data_directory():
47+
return None
48+
49+
50+
def get_kilosort_output_directory(clustering_task_key: dict) -> str:
51+
"""
52+
Retrieve the Kilosort output directory for a given ClusteringTask
53+
:param clustering_task_key: a dictionary of one EphysRecording
54+
:return: a string for full path to the resulting Kilosort output directory
55+
"""
56+
assert set(EphysRecording().primary_key) <= set(clustering_task_key)
57+
raise NotImplementedError('Workflow module should define')
58+
59+
60+
def get_paramset_idx(ephys_rec_key: dict) -> int:
61+
"""
62+
Retrieve attribute `paramset_idx` from the ClusteringParamSet record for the given EphysRecording key.
63+
:param ephys_rec_key: a dictionary of one EphysRecording
64+
:return: int specifying the `paramset_idx`
65+
"""
66+
assert set(EphysRecording().primary_key) <= set(ephys_rec_key)
67+
raise NotImplementedError('Workflow module should define')
68+
69+
70+
71+
def dict_to_uuid(key):
72+
"""
73+
Given a dictionary `key`, returns a hash string
74+
"""
75+
hashed = hashlib.md5()
76+
for k, v in sorted(key.items()):
77+
hashed.update(str(k).encode())
78+
hashed.update(str(v).encode())
79+
return uuid.UUID(hex=hashed.hexdigest())
1280

1381
# ===================================== Probe Insertion =====================================
1482

1583

1684
@schema
1785
class ProbeInsertion(dj.Manual): # (acute)
18-
19-
_Session = ...
20-
2186
definition = """
22-
-> self._Session
87+
-> Session
2388
insertion_number: tinyint unsigned
2489
---
2590
-> Probe
@@ -32,12 +97,10 @@ class ProbeInsertion(dj.Manual): # (acute)
3297
@schema
3398
class InsertionLocation(dj.Manual):
3499

35-
_SkullReference = ...
36-
37100
definition = """
38101
-> ProbeInsertion
39102
---
40-
-> self._SkullReference
103+
-> SkullReference
41104
ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
42105
ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
43106
depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
@@ -48,7 +111,7 @@ class InsertionLocation(dj.Manual):
48111

49112

50113
# ===================================== Ephys Recording =====================================
51-
# The abstract function _get_npx_data_dir() should expect one argument in the form of a
114+
# The abstract function _get_neuropixels_data_directory() should expect one argument in the form of a
52115
# dictionary with the keys from user-defined Subject and Session, as well as
53116
# "insertion_number" (as int) based on the "ProbeInsertion" table definition in this djephys
54117

@@ -62,33 +125,27 @@ class EphysRecording(dj.Imported):
62125
sampling_rate: float # (Hz)
63126
"""
64127

65-
@staticmethod
66-
@required
67-
def _get_npx_data_dir():
68-
return None
69-
70128
def make(self, key):
71-
npx_dir = EphysRecording._get_npx_data_dir(key)
72-
73-
meta_filepath = next(pathlib.Path(npx_dir).glob('*.ap.meta'))
129+
neuropixels_dir = get_neuropixels_data_directory(key)
130+
meta_filepath = next(pathlib.Path(neuropixels_dir).glob('*.ap.meta'))
74131

75-
npx_meta = neuropixels.NeuropixelsMeta(meta_filepath)
132+
neuropixels_meta = neuropixels.NeuropixelsMeta(meta_filepath)
76133

77-
if re.search('(1.0|2.0)', npx_meta.probe_model):
134+
if re.search('(1.0|2.0)', neuropixels_meta.probe_model):
78135
eg_members = []
79-
probe_type = {'probe_type': npx_meta.probe_model}
136+
probe_type = {'probe_type': neuropixels_meta.probe_model}
80137
q_electrodes = ProbeType.Electrode & probe_type
81-
for shank, shank_col, shank_row, is_used in npx_meta.shankmap['data']:
138+
for shank, shank_col, shank_row, is_used in neuropixels_meta.shankmap['data']:
82139
electrode = (q_electrodes & {'shank': shank,
83140
'shank_col': shank_col,
84141
'shank_row': shank_row}).fetch1('KEY')
85142
eg_members.append({**electrode, 'used_in_reference': is_used})
86143
else:
87144
raise NotImplementedError('Processing for neuropixels probe model {} not yet implemented'.format(
88-
npx_meta.probe_model))
145+
neuropixels_meta.probe_model))
89146

90147
# ---- compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) ----
91-
ec_hash = uuid.UUID(utils.dict_to_hash({k['electrode']: k for k in eg_members}))
148+
ec_hash = uuid.UUID(dict_to_uuid({k['electrode']: k for k in eg_members}))
92149

93150
el_list = sorted([k['electrode'] for k in eg_members])
94151
el_jumps = [-1] + np.where(np.diff(el_list) > 1)[0].tolist() + [len(el_list) - 1]
@@ -101,7 +158,7 @@ def make(self, key):
101158
ElectrodeConfig.insert1({**e_config, **probe_type, 'electrode_config_name': ec_name})
102159
ElectrodeConfig.Electrode.insert({**e_config, **m} for m in eg_members)
103160

104-
self.insert1({**key, **e_config, 'sampling_rate': npx_meta.meta['imSampRate']})
161+
self.insert1({**key, **e_config, 'sampling_rate': neuropixels_meta.meta['imSampRate']})
105162

106163

107164
# ===========================================================================================
@@ -130,14 +187,14 @@ class Electrode(dj.Part):
130187
"""
131188

132189
def make(self, key):
133-
npx_dir = EphysRecording._get_npx_data_dir(key)
134-
npx_recording = neuropixels.Neuropixels(npx_dir)
190+
neuropixels_dir = EphysRecording._get_neuropixels_data_directory(key)
191+
neuropixels_recording = neuropixels.Neuropixels(neuropixels_dir)
135192

136-
lfp = npx_recording.lfdata[:, :-1].T # exclude the sync channel
193+
lfp = neuropixels_recording.lfdata[:, :-1].T # exclude the sync channel
137194

138195
self.insert1(dict(key,
139-
lfp_sampling_rate=npx_recording.lfmeta['imSampRate'],
140-
lfp_time_stamps=np.arange(lfp.shape[1]) / npx_recording.lfmeta['imSampRate'],
196+
lfp_sampling_rate=neuropixels_recording.lfmeta['imSampRate'],
197+
lfp_time_stamps=np.arange(lfp.shape[1]) / neuropixels_recording.lfmeta['imSampRate'],
141198
lfp_mean=lfp.mean(axis=0)))
142199
'''
143200
Only store LFP for every 9th channel (defined in skip_chn_counts), counting in reverse
@@ -146,7 +203,7 @@ def make(self, key):
146203
q_electrodes = ProbeType.Electrode * ElectrodeConfig.Electrode & key
147204
electrodes = []
148205
for recorded_site in np.arange(lfp.shape[0]):
149-
shank, shank_col, shank_row, _ = npx_recording.npx_meta.shankmap['data'][recorded_site]
206+
shank, shank_col, shank_row, _ = neuropixels_recording.neuropixels_meta.shankmap['data'][recorded_site]
150207
electrodes.append((q_electrodes
151208
& {'shank': shank,
152209
'shank_col': shank_col,
@@ -191,7 +248,7 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, paramset_d
191248
'paramset_idx': paramset_idx,
192249
'paramset_desc': paramset_desc,
193250
'params': params,
194-
'param_set_hash': uuid.UUID(utils.dict_to_hash(params))}
251+
'param_set_hash': dict_to_uuid(params)}
195252
q_param = cls & {'param_set_hash': param_dict['param_set_hash']}
196253

197254
if q_param: # If the specified param-set already exists
@@ -212,26 +269,6 @@ class ClusteringTask(dj.Imported):
212269
-> ClusteringParamSet
213270
"""
214271

215-
@staticmethod
216-
@required
217-
def _get_paramset_idx(ephys_rec_key: dict) -> int:
218-
"""
219-
Retrieve the 'paramset_idx' (for ClusteringParamSet) to be used for this EphysRecording
220-
:param ephys_rec_key: a dictionary of one EphysRecording
221-
:return: int specifying the 'paramset_idx'
222-
"""
223-
return None
224-
225-
@staticmethod
226-
@required
227-
def _get_ks_data_dir(clustering_task_key: dict) -> str:
228-
"""
229-
Retrieve the Kilosort output directory for a given ClusteringTask
230-
:param clustering_task_key: a dictionary of one EphysRecording
231-
:return: a string for full path to the resulting Kilosort output directory
232-
"""
233-
return None
234-
235272
def make(self, key):
236273
key['paramset_idx'] = ClusteringTask._get_paramset_idx(key)
237274
self.insert1(key)
@@ -294,7 +331,7 @@ def make(self, key):
294331
valid_units = ks.data['cluster_ids'][withspike_idx]
295332
valid_unit_labels = ks.data['cluster_groups'][withspike_idx]
296333
# -- Get channel and electrode-site mapping
297-
chn2electrodes = get_npx_chn2electrode_map(key)
334+
chn2electrodes = get_neuropixels_chn2electrode_map(key)
298335

299336
# -- Spike-times --
300337
# spike_times_sec_adj > spike_times_sec > spike_times
@@ -352,16 +389,16 @@ def key_source(self):
352389
def make(self, key):
353390
units = {u['unit']: u for u in (Clustering.Unit & key).fetch(as_dict=True, order_by='unit')}
354391

355-
npx_dir = EphysRecording._get_npx_data_dir(key)
356-
meta_filepath = next(pathlib.Path(npx_dir).glob('*.ap.meta'))
357-
npx_meta = neuropixels.NeuropixelsMeta(meta_filepath)
392+
neuropixels_dir = EphysRecording._get_neuropixels_data_directory(key)
393+
meta_filepath = next(pathlib.Path(neuropixels_dir).glob('*.ap.meta'))
394+
neuropixels_meta = neuropixels.NeuropixelsMeta(meta_filepath)
358395

359396
ks_dir = ClusteringTask._get_ks_data_dir(key)
360397
ks = kilosort.Kilosort(ks_dir)
361398

362399
# -- Get channel and electrode-site mapping
363400
rec_key = (EphysRecording & key).fetch1('KEY')
364-
chn2electrodes = get_npx_chn2electrode_map(rec_key)
401+
chn2electrodes = get_neuropixels_chn2electrode_map(rec_key)
365402

366403
is_qc = (Clustering & key).fetch1('quality_control')
367404

@@ -375,10 +412,10 @@ def make(self, key):
375412
if chn2electrodes[chn]['electrode'] == units[unit_no]['electrode']:
376413
unit_peak_waveforms.append({**units[unit_no], 'peak_chn_waveform_mean': chn_wf})
377414
else:
378-
npx_recording = neuropixels.Neuropixels(npx_dir)
415+
neuropixels_recording = neuropixels.Neuropixels(neuropixels_dir)
379416
for unit_no, unit_dict in units.items():
380417
spks = (Clustering.Unit & unit_dict).fetch1('unit_spike_times')
381-
wfs = npx_recording.extract_spike_waveforms(spks, ks.data['channel_map']) # (sample x channel x spike)
418+
wfs = neuropixels_recording.extract_spike_waveforms(spks, ks.data['channel_map']) # (sample x channel x spike)
382419
wfs = wfs.transpose((1, 2, 0)) # (channel x spike x sample)
383420
for chn, chn_wf in zip(ks.data['channel_map'], wfs):
384421
unit_waveforms.append({**unit_dict, **chn2electrodes[chn],
@@ -425,15 +462,15 @@ def make(self, key):
425462
# ========================== HELPER FUNCTIONS =======================
426463

427464

428-
def get_npx_chn2electrode_map(ephys_recording_key):
429-
npx_dir = EphysRecording._get_npx_data_dir(ephys_recording_key)
430-
meta_filepath = next(pathlib.Path(npx_dir).glob('*.ap.meta'))
431-
npx_meta = neuropixels.NeuropixelsMeta(meta_filepath)
465+
def get_neuropixels_chn2electrode_map(ephys_recording_key):
466+
neuropixels_dir = EphysRecording._get_neuropixels_data_directory(ephys_recording_key)
467+
meta_filepath = next(pathlib.Path(neuropixels_dir).glob('*.ap.meta'))
468+
neuropixels_meta = neuropixels.NeuropixelsMeta(meta_filepath)
432469
e_config_key = (EphysRecording * ElectrodeConfig & ephys_recording_key).fetch1('KEY')
433470

434471
q_electrodes = ProbeType.Electrode * ElectrodeConfig.Electrode & e_config_key
435472
chn2electrode_map = {}
436-
for recorded_site, (shank, shank_col, shank_row, _) in enumerate(npx_meta.shankmap['data']):
473+
for recorded_site, (shank, shank_col, shank_row, _) in enumerate(neuropixels_meta.shankmap['data']):
437474
chn2electrode_map[recorded_site] = (q_electrodes
438475
& {'shank': shank,
439476
'shank_col': shank_col,

djephys/probe.py renamed to elements_ephys/probe.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
"""
2+
Neuropixels Probes
3+
"""
4+
15
import datajoint as dj
26
import numpy as np
7+
schema = dj.schema()
38

4-
from djutils.templates import SchemaTemplate
5-
6-
7-
schema = SchemaTemplate()
8-
9-
# ===================================== Neuropixels Probes =====================================
109

10+
def activate(schema_name, create_schema=True, create_tables=True):
11+
schema.activate(schema_name, create_schema=create_schema, create_tables=create_tables)
1112

1213
@schema
1314
class ProbeType(dj.Lookup):
File renamed without changes.

ephys_loaders/kilosort.py renamed to elements_ephys/readers/kilosort.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@
55
import re
66
import logging
77

8-
from .utils import handle_string
9-
108
log = logging.getLogger(__name__)
119

1210

11+
def convert_to_number(value: str):
12+
if isinstance(value, str):
13+
try:
14+
value = int(value)
15+
except ValueError:
16+
try:
17+
value = float(value)
18+
except ValueError:
19+
pass
20+
return value
21+
22+
1323
class Kilosort:
1424

1525
ks_files = [
@@ -163,4 +173,4 @@ def extract_clustering_info(cluster_output_dir):
163173
spk_fp = next(cluster_output_dir.glob('spike_times.npy'))
164174
creation_time = datetime.fromtimestamp(spk_fp.stat().st_ctime)
165175

166-
return creation_time, is_curated, is_qc
176+
return creation_time, is_curated, is_qc

0 commit comments

Comments
 (0)