Skip to content

Commit 983d61a

Browse files
authored
Merge pull request #10 from ttngu207/main
Ephys pipeline with support for multiple curations
2 parents 70a813b + e98b34f commit 983d61a

File tree

7 files changed

+305
-222
lines changed

7 files changed

+305
-222
lines changed

elements_ephys/ephys.py

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -328,16 +328,6 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, paramset_d
328328
cls.insert1(param_dict)
329329

330330

331-
@schema
332-
class ClusteringTask(dj.Manual):
333-
definition = """
334-
-> EphysRecording
335-
-> ClusteringParamSet
336-
---
337-
clustering_output_dir: varchar(255) # clustering output directory relative to root data directory
338-
"""
339-
340-
341331
@schema
342332
class ClusterQualityLabel(dj.Lookup):
343333
definition = """
@@ -354,15 +344,82 @@ class ClusterQualityLabel(dj.Lookup):
354344
]
355345

356346

347+
@schema
348+
class ClusteringTask(dj.Manual):
349+
definition = """
350+
-> EphysRecording
351+
-> ClusteringParamSet
352+
---
353+
clustering_output_dir: varchar(255) # clustering output directory relative to root data directory
354+
task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
355+
"""
356+
357+
357358
@schema
358359
class Clustering(dj.Imported):
360+
"""
361+
A processing table to handle each ClusteringTask:
362+
+ If `task_mode == "trigger"`: trigger clustering analysis according to the ClusteringParamSet (e.g. launch a kilosort job)
363+
+ If `task_mode == "load"`: verify output
364+
"""
359365
definition = """
360366
-> ClusteringTask
361367
---
362-
clustering_time: datetime # time of generation of this set of clustering results
363-
quality_control: bool # has this clustering result undergone quality control?
364-
manual_curation: bool # has manual curation been performed on this clustering result?
365-
clustering_note='': varchar(2000)
368+
clustering_time: datetime # time of generation of this set of clustering results
369+
"""
370+
371+
def make(self, key):
372+
root_dir = pathlib.Path(get_ephys_root_data_dir())
373+
task_mode, output_dir = (ClusteringTask & key).fetch1('task_mode', 'clustering_output_dir')
374+
ks_dir = root_dir / output_dir
375+
376+
if task_mode == 'load':
377+
ks = kilosort.Kilosort(ks_dir) # check if the directory is a valid Kilosort output
378+
creation_time, _, _ = kilosort.extract_clustering_info(ks_dir)
379+
elif task_mode == 'trigger':
380+
raise NotImplementedError('Automatic triggering of clustering analysis is not yet supported')
381+
else:
382+
raise ValueError(f'Unknown task mode: {task_mode}')
383+
384+
self.insert1({**key, 'clustering_time': creation_time})
385+
386+
387+
@schema
388+
class Curation(dj.Manual):
389+
definition = """
390+
-> Clustering
391+
curation_id: int
392+
---
393+
curation_time: datetime # time of generation of this set of curated clustering results
394+
curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
395+
quality_control: bool # has this clustering result undergone quality control?
396+
manual_curation: bool # has manual curation been performed on this clustering result?
397+
curation_note='': varchar(2000)
398+
"""
399+
400+
def create1_from_clustering_task(self, key, curation_note=''):
401+
"""
402+
A convenient function to create a new corresponding "Curation" for a particular "ClusteringTask"
403+
"""
404+
if key not in Clustering():
405+
raise ValueError(f'No corresponding entry in Clustering available for: {key}; do `Clustering.populate(key)`')
406+
407+
root_dir = pathlib.Path(get_ephys_root_data_dir())
408+
task_mode, output_dir = (ClusteringTask & key).fetch1('task_mode', 'clustering_output_dir')
409+
ks_dir = root_dir / output_dir
410+
creation_time, is_curated, is_qc = kilosort.extract_clustering_info(ks_dir)
411+
# Synthesize curation_id
412+
curation_id = dj.U().aggr(self & key, n='ifnull(max(curation_id)+1,1)').fetch1('n')
413+
self.insert1({**key, 'curation_id': curation_id,
414+
'curation_time': creation_time, 'curation_output_dir': output_dir,
415+
'quality_control': is_qc, 'manual_curation': is_curated,
416+
'curation_note': curation_note})
417+
418+
419+
@schema
420+
class CuratedClustering(dj.Imported):
421+
definition = """
422+
-> Curation
366423
"""
367424

368425
class Unit(dj.Part):
@@ -380,13 +437,10 @@ class Unit(dj.Part):
380437

381438
def make(self, key):
382439
root_dir = pathlib.Path(get_ephys_root_data_dir())
383-
ks_dir = root_dir / (ClusteringTask & key).fetch1('clustering_output_dir')
440+
ks_dir = root_dir / (Curation & key).fetch1('curation_output_dir')
384441
ks = kilosort.Kilosort(ks_dir)
385442
acq_software = (EphysRecording & key).fetch1('acq_software')
386443

387-
# ---------- Clustering ----------
388-
creation_time, is_curated, is_qc = kilosort.extract_clustering_info(ks_dir)
389-
390444
# ---------- Unit ----------
391445
# -- Remove 0-spike units
392446
withspike_idx = [i for i, u in enumerate(ks.data['cluster_ids']) if (ks.data['spike_clusters'] == u).any()]
@@ -422,15 +476,14 @@ def make(self, key):
422476
'spike_sites': spike_sites[ks.data['spike_clusters'] == unit],
423477
'spike_depths': spike_depths[ks.data['spike_clusters'] == unit]})
424478

425-
self.insert1({**key, 'clustering_time': creation_time,
426-
'quality_control': is_qc, 'manual_curation': is_curated})
479+
self.insert1(key)
427480
self.Unit.insert([{**key, **u} for u in units])
428481

429482

430483
@schema
431484
class Waveform(dj.Imported):
432485
definition = """
433-
-> Clustering.Unit
486+
-> CuratedClustering.Unit
434487
---
435488
peak_chn_waveform_mean: longblob # mean over all spikes at the peak channel for this unit
436489
"""
@@ -446,11 +499,11 @@ class Electrode(dj.Part):
446499

447500
@property
448501
def key_source(self):
449-
return Clustering()
502+
return Curation()
450503

451504
def make(self, key):
452505
root_dir = pathlib.Path(get_ephys_root_data_dir())
453-
ks_dir = root_dir / (ClusteringTask & key).fetch1('clustering_output_dir')
506+
ks_dir = root_dir / (Curation & key).fetch1('curation_output_dir')
454507
ks = kilosort.Kilosort(ks_dir)
455508

456509
acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1('acq_software', 'probe')
@@ -459,10 +512,10 @@ def make(self, key):
459512
rec_key = (EphysRecording & key).fetch1('KEY')
460513
chn2electrodes = get_neuropixels_chn2electrode_map(rec_key, acq_software)
461514

462-
is_qc = (Clustering & key).fetch1('quality_control')
515+
is_qc = (Curation & key).fetch1('quality_control')
463516

464517
# Get all units
465-
units = {u['unit']: u for u in (Clustering.Unit & key).fetch(as_dict=True, order_by='unit')}
518+
units = {u['unit']: u for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by='unit')}
466519

467520
unit_waveforms, unit_peak_waveforms = [], []
468521
if is_qc:
@@ -503,7 +556,7 @@ def make(self, key):
503556
@schema
504557
class ClusterQualityMetrics(dj.Imported):
505558
definition = """
506-
-> Clustering.Unit
559+
-> CuratedClustering.Unit
507560
---
508561
amp: float
509562
snr: float

elements_ephys/readers/kilosort.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from os import path
22
from datetime import datetime
3+
import pathlib
34
import pandas as pd
45
import numpy as np
56
import re
@@ -37,14 +38,19 @@ class Kilosort:
3738
# keys to self.files, .data are file name e.g. self.data['params'], etc.
3839
ks_keys = [path.splitext(i)[0] for i in ks_files]
3940

40-
def __init__(self, dname):
41-
self._dname = dname
41+
def __init__(self, ks_dir):
42+
self._ks_dir = pathlib.Path(ks_dir)
4243
self._files = {}
4344
self._data = None
4445
self._clusters = None
4546

46-
self._info = {'time_created': datetime.fromtimestamp((dname / 'params.py').stat().st_ctime),
47-
'time_modified': datetime.fromtimestamp((dname / 'params.py').stat().st_mtime)}
47+
params_fp = ks_dir / 'params.py'
48+
49+
if not params_fp.exists():
50+
raise FileNotFoundError(f'No Kilosort output found in: {ks_dir}')
51+
52+
self._info = {'time_created': datetime.fromtimestamp(params_fp.stat().st_ctime),
53+
'time_modified': datetime.fromtimestamp(params_fp.stat().st_mtime)}
4854

4955
@property
5056
def data(self):
@@ -59,7 +65,7 @@ def info(self):
5965
def _stat(self):
6066
self._data = {}
6167
for i in Kilosort.ks_files:
62-
f = self._dname / i
68+
f = self._ks_dir / i
6369

6470
if not f.exists():
6571
log.debug('skipping {} - doesnt exist'.format(f))
@@ -84,12 +90,12 @@ def _stat(self):
8490
self._data[base] = np.reshape(d, d.shape[0]) if d.ndim == 2 and d.shape[1] == 1 else d
8591

8692
# Read the Cluster Groups
87-
if (self._dname / 'cluster_groups.csv').exists():
88-
df = pd.read_csv(self._dname / 'cluster_groups.csv', delimiter='\t')
93+
if (self._ks_dir / 'cluster_groups.csv').exists():
94+
df = pd.read_csv(self._ks_dir / 'cluster_groups.csv', delimiter= '\t')
8995
self._data['cluster_groups'] = np.array(df['group'].values)
9096
self._data['cluster_ids'] = np.array(df['cluster_id'].values)
91-
elif (self._dname / 'cluster_KSLabel.tsv').exists():
92-
df = pd.read_csv(self._dname / 'cluster_KSLabel.tsv', sep = "\t", header = 0)
97+
elif (self._ks_dir / 'cluster_KSLabel.tsv').exists():
98+
df = pd.read_csv(self._ks_dir / 'cluster_KSLabel.tsv', sep = "\t", header = 0)
9399
self._data['cluster_groups'] = np.array(df['KSLabel'].values)
94100
self._data['cluster_ids'] = np.array(df['cluster_id'].values)
95101
else:

elements_ephys/readers/openephys.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def load_probe_data(self):
6969
oe_probe.recording_info['recording_datetimes'].append(rec.datetime)
7070
oe_probe.recording_info['recording_durations'].append(float(rec.duration))
7171
oe_probe.recording_info['recording_files'].append(
72-
rec.absolute_foldername / cont_info['folder_name'])
72+
rec.absolute_foldername / 'continuous' / cont_info['folder_name'])
7373

7474
elif cont_info['source_processor_sub_idx'] == 1: # lfp data
7575
assert cont_info['sample_rate'] == analog_signal.sample_rate == 2500

elements_ephys/readers/spikeglx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def get_original_chans(self):
303303
# a block of contiguous channels specified as chan or chan1:chan2 inclusive
304304
ix = [int(r) for r in channel_range.split(':')]
305305
assert len(ix) in (1, 2), f"Invalid channel range spec '{channel_range}'"
306-
channels = np.append(np.r_[ix[0]:ix[-1] + 1])
306+
channels = np.append(channels, np.r_[ix[0]:ix[-1] + 1])
307307
return channels
308308

309309

images/attached_ephys_element.svg

Lines changed: 1 addition & 1 deletion
Loading

0 commit comments

Comments (0)