
Commit 5834b4a

Merge pull request #2 from dimitri-yatsenko/main
Refactor to use schema.activate from DataJoint 0.13
2 parents d1decf2 + 5edc3ce commit 5834b4a

File tree

8 files changed: +149 additions, -118 deletions

djephys/utils.py

Lines changed: 0 additions & 12 deletions
This file was deleted.
File renamed without changes.

djephys/ephys.py renamed to elements_ephys/ephys.py

Lines changed: 110 additions & 75 deletions

@@ -1,28 +1,91 @@
+import datajoint as dj
 import pathlib
 import re
 import numpy as np
-import datajoint as dj
+import inspect
 import uuid
+import hashlib
+from collections.abc import Mapping
+
+from .readers import neuropixels, kilosort
+from . import probe
+
+schema = dj.schema()
+
+
+def activate(ephys_schema_name, probe_schema_name=None, create_schema=True, create_tables=True, add_objects=None):
+    upstream_tables = ("Session", "SkullReference")
+    assert isinstance(add_objects, Mapping)
+    try:
+        raise RuntimeError("Table %s is required for module ephys" % next(
+            name for name in upstream_tables
+            if not isinstance(add_objects.get(name, None), (dj.Manual, dj.Lookup, dj.Imported, dj.Computed))))
+    except StopIteration:
+        pass  # all ok
+
+    required_functions = ("get_neuropixels_data_directory", "get_paramset_idx", "get_kilosort_output_directory")
+    assert isinstance(add_objects, Mapping)
+    try:
+        raise RuntimeError("Function %s is required for module ephys" % next(
+            name for name in required_functions
+            if not inspect.isfunction(add_objects.get(name, None))))
+    except StopIteration:
+        pass  # all ok
+
+    if not probe.schema.is_activated:
+        probe.schema.activate(probe_schema_name or ephys_schema_name,
+                              create_schema=create_schema, create_tables=create_tables)
+    schema.activate(ephys_schema_name, create_schema=create_schema,
+                    create_tables=create_tables, add_objects=add_objects)
+
 
-from . import utils
-from .probe import schema, Probe, ProbeType, ElectrodeConfig
-from ephys_loaders import neuropixels, kilosort
+# REQUIREMENTS: The workflow module must define these functions ---------------
 
-from djutils.templates import required
+
+def get_neuropixels_data_directory():
+    return None
+
+
+def get_kilosort_output_directory(clustering_task_key: dict) -> str:
+    """
+    Retrieve the Kilosort output directory for a given ClusteringTask
+    :param clustering_task_key: a dictionary of one EphysRecording
+    :return: a string for full path to the resulting Kilosort output directory
+    """
+    assert set(EphysRecording().primary_key) <= set(clustering_task_key)
+    raise NotImplementedError('Workflow module should define')
+
+
+def get_paramset_idx(ephys_rec_key: dict) -> int:
+    """
+    Retrieve attribute `paramset_idx` from the ClusteringParamSet record for the given EphysRecording key.
+    :param ephys_rec_key: a dictionary of one EphysRecording
+    :return: int specifying the `paramset_idx`
+    """
+    assert set(EphysRecording().primary_key) <= set(ephys_rec_key)
+    raise NotImplementedError('Workflow module should define')
+
+
+def dict_to_uuid(key):
+    """
+    Given a dictionary `key`, returns a hash string
+    """
+    hashed = hashlib.md5()
+    for k, v in sorted(key.items()):
+        hashed.update(str(k).encode())
+        hashed.update(str(v).encode())
+    return uuid.UUID(hex=hashed.hexdigest())
 
 
 # ===================================== Probe Insertion =====================================
 
 
 @schema
 class ProbeInsertion(dj.Manual):  # (acute)
-
-    _Session = ...
-
     definition = """
-    -> self._Session
+    -> Session
     insertion_number: tinyint unsigned
     ---
-    -> Probe
+    -> probe.Probe
     """
 
 
@@ -32,12 +95,10 @@ class ProbeInsertion(dj.Manual): # (acute)
 @schema
 class InsertionLocation(dj.Manual):
 
-    _SkullReference = ...
-
     definition = """
     -> ProbeInsertion
     ---
-    -> self._SkullReference
+    -> SkullReference
    ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
    ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
    depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
@@ -48,7 +109,7 @@ class InsertionLocation(dj.Manual):
 
 
 # ===================================== Ephys Recording =====================================
-# The abstract function _get_npx_data_dir() should expect one argument in the form of a
+# The abstract function _get_neuropixels_data_directory() should expect one argument in the form of a
 # dictionary with the keys from user-defined Subject and Session, as well as
 # "insertion_number" (as int) based on the "ProbeInsertion" table definition in this djephys
 
@@ -62,33 +123,27 @@ class EphysRecording(dj.Imported):
    sampling_rate: float # (Hz)
    """
 
-    @staticmethod
-    @required
-    def _get_npx_data_dir():
-        return None
-
     def make(self, key):
-        npx_dir = EphysRecording._get_npx_data_dir(key)
-
-        meta_filepath = next(pathlib.Path(npx_dir).glob('*.ap.meta'))
+        neuropixels_dir = get_neuropixels_data_directory(key)
+        meta_filepath = next(pathlib.Path(neuropixels_dir).glob('*.ap.meta'))
 
-        npx_meta = neuropixels.NeuropixelsMeta(meta_filepath)
+        neuropixels_meta = neuropixels.NeuropixelsMeta(meta_filepath)
 
-        if re.search('(1.0|2.0)', npx_meta.probe_model):
+        if re.search('(1.0|2.0)', neuropixels_meta.probe_model):
             eg_members = []
-            probe_type = {'probe_type': npx_meta.probe_model}
-            q_electrodes = ProbeType.Electrode & probe_type
-            for shank, shank_col, shank_row, is_used in npx_meta.shankmap['data']:
+            probe_type = {'probe_type': neuropixels_meta.probe_model}
+            q_electrodes = probe.ProbeType.Electrode & probe_type
+            for shank, shank_col, shank_row, is_used in neuropixels_meta.shankmap['data']:
                 electrode = (q_electrodes & {'shank': shank,
                                              'shank_col': shank_col,
                                              'shank_row': shank_row}).fetch1('KEY')
                 eg_members.append({**electrode, 'used_in_reference': is_used})
         else:
             raise NotImplementedError('Processing for neuropixels probe model {} not yet implemented'.format(
-                npx_meta.probe_model))
+                neuropixels_meta.probe_model))
 
         # ---- compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) ----
-        ec_hash = uuid.UUID(utils.dict_to_hash({k['electrode']: k for k in eg_members}))
+        ec_hash = uuid.UUID(dict_to_uuid({k['electrode']: k for k in eg_members}))
 
         el_list = sorted([k['electrode'] for k in eg_members])
         el_jumps = [-1] + np.where(np.diff(el_list) > 1)[0].tolist() + [len(el_list) - 1]
@@ -101,7 +156,7 @@ def make(self, key):
             ElectrodeConfig.insert1({**e_config, **probe_type, 'electrode_config_name': ec_name})
             ElectrodeConfig.Electrode.insert({**e_config, **m} for m in eg_members)
 
-        self.insert1({**key, **e_config, 'sampling_rate': npx_meta.meta['imSampRate']})
+        self.insert1({**key, **e_config, 'sampling_rate': neuropixels_meta.meta['imSampRate']})
 
 
 # ===========================================================================================
@@ -124,29 +179,29 @@ class LFP(dj.Imported):
     class Electrode(dj.Part):
         definition = """
        -> master
-       -> ElectrodeConfig.Electrode
+       -> probe.ElectrodeConfig.Electrode
        ---
        lfp: longblob  # (mV) recorded lfp at this electrode
        """
 
     def make(self, key):
-        npx_dir = EphysRecording._get_npx_data_dir(key)
-        npx_recording = neuropixels.Neuropixels(npx_dir)
+        neuropixels_dir = EphysRecording._get_neuropixels_data_directory(key)
+        neuropixels_recording = neuropixels.Neuropixels(neuropixels_dir)
 
-        lfp = npx_recording.lfdata[:, :-1].T  # exclude the sync channel
+        lfp = neuropixels_recording.lfdata[:, :-1].T  # exclude the sync channel
 
         self.insert1(dict(key,
-                          lfp_sampling_rate=npx_recording.lfmeta['imSampRate'],
-                          lfp_time_stamps=np.arange(lfp.shape[1]) / npx_recording.lfmeta['imSampRate'],
+                          lfp_sampling_rate=neuropixels_recording.lfmeta['imSampRate'],
+                          lfp_time_stamps=np.arange(lfp.shape[1]) / neuropixels_recording.lfmeta['imSampRate'],
                           lfp_mean=lfp.mean(axis=0)))
         '''
        Only store LFP for every 9th channel (defined in skip_chn_counts), counting in reverse
        Due to high channel density, close-by channels exhibit highly similar lfp
        '''
-        q_electrodes = ProbeType.Electrode * ElectrodeConfig.Electrode & key
+        q_electrodes = probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode & key
         electrodes = []
         for recorded_site in np.arange(lfp.shape[0]):
-            shank, shank_col, shank_row, _ = npx_recording.npx_meta.shankmap['data'][recorded_site]
+            shank, shank_col, shank_row, _ = neuropixels_recording.neuropixels_meta.shankmap['data'][recorded_site]
             electrodes.append((q_electrodes
                                & {'shank': shank,
                                   'shank_col': shank_col,
@@ -191,7 +246,7 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, paramset_d
                       'paramset_idx': paramset_idx,
                       'paramset_desc': paramset_desc,
                       'params': params,
-                      'param_set_hash': uuid.UUID(utils.dict_to_hash(params))}
+                      'param_set_hash': dict_to_uuid(params)}
        q_param = cls & {'param_set_hash': param_dict['param_set_hash']}
 
        if q_param:  # If the specified param-set already exists
@@ -212,26 +267,6 @@ class ClusteringTask(dj.Imported):
    -> ClusteringParamSet
    """
 
-    @staticmethod
-    @required
-    def _get_paramset_idx(ephys_rec_key: dict) -> int:
-        """
-        Retrieve the 'paramset_idx' (for ClusteringParamSet) to be used for this EphysRecording
-        :param ephys_rec_key: a dictionary of one EphysRecording
-        :return: int specifying the 'paramset_idx'
-        """
-        return None
-
-    @staticmethod
-    @required
-    def _get_ks_data_dir(clustering_task_key: dict) -> str:
-        """
-        Retrieve the Kilosort output directory for a given ClusteringTask
-        :param clustering_task_key: a dictionary of one EphysRecording
-        :return: a string for full path to the resulting Kilosort output directory
-        """
-        return None
-
     def make(self, key):
        key['paramset_idx'] = ClusteringTask._get_paramset_idx(key)
        self.insert1(key)
@@ -274,7 +309,7 @@ class Unit(dj.Part):
        -> master
        unit: int
        ---
-       -> ElectrodeConfig.Electrode  # electrode on the probe that this unit has highest response amplitude
+       -> probe.ElectrodeConfig.Electrode  # electrode on the probe that this unit has highest response amplitude
        -> ClusterQualityLabel
        spike_count: int  # how many spikes in this recording of this unit
        spike_times: longblob  # (s) spike times of this unit, relative to the start of the EphysRecording
@@ -294,7 +329,7 @@ def make(self, key):
        valid_units = ks.data['cluster_ids'][withspike_idx]
        valid_unit_labels = ks.data['cluster_groups'][withspike_idx]
        # -- Get channel and electrode-site mapping
-        chn2electrodes = get_npx_chn2electrode_map(key)
+        chn2electrodes = get_neuropixels_chn2electrode_map(key)
 
        # -- Spike-times --
        # spike_times_sec_adj > spike_times_sec > spike_times
@@ -339,7 +374,7 @@ class Waveform(dj.Imported):
    class Electrode(dj.Part):
        definition = """
        -> master
-       -> ElectrodeConfig.Electrode
+       -> probe.ElectrodeConfig.Electrode
        ---
        waveform_mean: longblob  # mean over all spikes
        waveforms=null: longblob  # (spike x sample) waveform of each spike at each electrode
@@ -352,16 +387,16 @@ def key_source(self):
    def make(self, key):
        units = {u['unit']: u for u in (Clustering.Unit & key).fetch(as_dict=True, order_by='unit')}
 
-        npx_dir = EphysRecording._get_npx_data_dir(key)
-        meta_filepath = next(pathlib.Path(npx_dir).glob('*.ap.meta'))
-        npx_meta = neuropixels.NeuropixelsMeta(meta_filepath)
+        neuropixels_dir = EphysRecording._get_neuropixels_data_directory(key)
+        meta_filepath = next(pathlib.Path(neuropixels_dir).glob('*.ap.meta'))
+        neuropixels_meta = neuropixels.NeuropixelsMeta(meta_filepath)
 
        ks_dir = ClusteringTask._get_ks_data_dir(key)
        ks = kilosort.Kilosort(ks_dir)
 
        # -- Get channel and electrode-site mapping
        rec_key = (EphysRecording & key).fetch1('KEY')
-        chn2electrodes = get_npx_chn2electrode_map(rec_key)
+        chn2electrodes = get_neuropixels_chn2electrode_map(rec_key)
 
        is_qc = (Clustering & key).fetch1('quality_control')
 
@@ -375,10 +410,10 @@ def make(self, key):
                if chn2electrodes[chn]['electrode'] == units[unit_no]['electrode']:
                    unit_peak_waveforms.append({**units[unit_no], 'peak_chn_waveform_mean': chn_wf})
        else:
-            npx_recording = neuropixels.Neuropixels(npx_dir)
+            neuropixels_recording = neuropixels.Neuropixels(neuropixels_dir)
            for unit_no, unit_dict in units.items():
                spks = (Clustering.Unit & unit_dict).fetch1('unit_spike_times')
-                wfs = npx_recording.extract_spike_waveforms(spks, ks.data['channel_map'])  # (sample x channel x spike)
+                wfs = neuropixels_recording.extract_spike_waveforms(spks, ks.data['channel_map'])  # (sample x channel x spike)
                wfs = wfs.transpose((1, 2, 0))  # (channel x spike x sample)
                for chn, chn_wf in zip(ks.data['channel_map'], wfs):
                    unit_waveforms.append({**unit_dict, **chn2electrodes[chn],
@@ -425,15 +460,15 @@ def make(self, key):
 # ========================== HELPER FUNCTIONS =======================
 
 
-def get_npx_chn2electrode_map(ephys_recording_key):
-    npx_dir = EphysRecording._get_npx_data_dir(ephys_recording_key)
-    meta_filepath = next(pathlib.Path(npx_dir).glob('*.ap.meta'))
-    npx_meta = neuropixels.NeuropixelsMeta(meta_filepath)
-    e_config_key = (EphysRecording * ElectrodeConfig & ephys_recording_key).fetch1('KEY')
+def get_neuropixels_chn2electrode_map(ephys_recording_key):
+    neuropixels_dir = EphysRecording._get_neuropixels_data_directory(ephys_recording_key)
+    meta_filepath = next(pathlib.Path(neuropixels_dir).glob('*.ap.meta'))
+    neuropixels_meta = neuropixels.NeuropixelsMeta(meta_filepath)
+    e_config_key = (EphysRecording * probe.ElectrodeConfig & ephys_recording_key).fetch1('KEY')
 
-    q_electrodes = ProbeType.Electrode * ElectrodeConfig.Electrode & e_config_key
+    q_electrodes = probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode & e_config_key
    chn2electrode_map = {}
-    for recorded_site, (shank, shank_col, shank_row, _) in enumerate(npx_meta.shankmap['data']):
+    for recorded_site, (shank, shank_col, shank_row, _) in enumerate(neuropixels_meta.shankmap['data']):
        chn2electrode_map[recorded_site] = (q_electrodes
                                            & {'shank': shank,
                                               'shank_col': shank_col,
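For orientation, below is a minimal sketch of how a downstream workflow script could wire up the refactored module. Every name in it — the my_lab_* schema names, the Session and SkullReference definitions, and the three directory/parameter lookup functions — is a hypothetical placeholder, not part of this commit; the sketch only illustrates the add_objects contract that the new activate() checks for, and it assumes a configured DataJoint connection.

# Hypothetical workflow module (sketch only): supplies the upstream tables and
# lookup functions that elements_ephys.ephys.activate() expects via add_objects.
import datajoint as dj
from elements_ephys import ephys

workflow_schema = dj.schema('my_lab_experiment')  # assumed upstream schema name


@workflow_schema
class SkullReference(dj.Lookup):
    definition = """
    skull_reference: varchar(60)
    """
    contents = zip(['Bregma', 'Lambda'])


@workflow_schema
class Session(dj.Manual):
    definition = """
    subject: varchar(32)
    session_datetime: datetime
    """


def get_neuropixels_data_directory(probe_insertion_key: dict) -> str:
    # assumed on-disk layout: one raw Neuropixels folder per probe insertion
    return '/data/raw/{subject}/insertion{insertion_number}'.format(**probe_insertion_key)


def get_paramset_idx(ephys_rec_key: dict) -> int:
    return 0  # e.g. a single default ClusteringParamSet for every recording


def get_kilosort_output_directory(clustering_task_key: dict) -> str:
    return get_neuropixels_data_directory(clustering_task_key) + '/kilosort_output'


# activate() verifies these names before activating the probe and ephys schemas;
# its upstream-table check uses isinstance(), hence table instances are passed here.
ephys.activate('my_lab_ephys', probe_schema_name='my_lab_probe',
               add_objects=dict(Session=Session(),
                                SkullReference=SkullReference(),
                                get_neuropixels_data_directory=get_neuropixels_data_directory,
                                get_paramset_idx=get_paramset_idx,
                                get_kilosort_output_directory=get_kilosort_output_directory))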
djephys/probe.py renamed to elements_ephys/probe.py

Lines changed: 7 additions & 6 deletions

@@ -1,13 +1,14 @@
+"""
+Neuropixels Probes
+"""
+
 import datajoint as dj
 import numpy as np
+schema = dj.schema()
 
-from djutils.templates import SchemaTemplate
-
-
-schema = SchemaTemplate()
-
-
-# ===================================== Neuropixels Probes =====================================
 
+def activate(schema_name, create_schema=True, create_tables=True):
+    schema.activate(schema_name, create_schema=create_schema, create_tables=create_tables)
 
 
 @schema
 class ProbeType(dj.Lookup):

File renamed without changes.

ephys_loaders/kilosort.py renamed to elements_ephys/readers/kilosort.py

Lines changed: 13 additions & 3 deletions

@@ -5,11 +5,21 @@
 import re
 import logging
 
-from .utils import handle_string
-
 log = logging.getLogger(__name__)
 
 
+def convert_to_number(value: str):
+    if isinstance(value, str):
+        try:
+            value = int(value)
+        except ValueError:
+            try:
+                value = float(value)
+            except ValueError:
+                pass
+    return value
+
+
 class Kilosort:
 
     ks_files = [
@@ -163,4 +173,4 @@ def extract_clustering_info(cluster_output_dir):
     spk_fp = next(cluster_output_dir.glob('spike_times.npy'))
     creation_time = datetime.fromtimestamp(spk_fp.stat().st_ctime)
 
-    return creation_time, is_curated, is_qc
+    return creation_time, is_curated, is_qc
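The Kilosort reader now carries its own convert_to_number helper in place of the removed handle_string import. A small standalone illustration of its behavior follows; the sample inputs are assumptions in the spirit of values read from a Kilosort params.py, not taken from this commit.

# Standalone copy of the convert_to_number helper added above, with illustrative inputs.
def convert_to_number(value: str):
    if isinstance(value, str):
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                pass
    return value


assert convert_to_number('30000') == 30000                 # integer-like string -> int
assert convert_to_number('0.02') == 0.02                   # decimal string -> float
assert convert_to_number('temp_wh.dat') == 'temp_wh.dat'   # anything else stays a string
assert convert_to_number(42) == 42                         # non-strings pass through unchanged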