Skip to content

Commit 94686f5

Browse files
author
Thinh Nguyen
committed
prototype design for multiple curations
1 parent f76086c commit 94686f5

File tree

2 files changed

+92
-48
lines changed

2 files changed

+92
-48
lines changed

elements_ephys/ephys.py

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -319,16 +319,6 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, paramset_d
319319
cls.insert1(param_dict)
320320

321321

322-
@schema
323-
class ClusteringTask(dj.Manual):
324-
definition = """
325-
-> EphysRecording
326-
-> ClusteringParamSet
327-
---
328-
clustering_output_dir: varchar(255) # clustering output directory relative to root data directory
329-
"""
330-
331-
332322
@schema
333323
class ClusterQualityLabel(dj.Lookup):
334324
definition = """
@@ -345,39 +335,89 @@ class ClusterQualityLabel(dj.Lookup):
345335
]
346336

347337

338+
@schema
class ClusteringTask(dj.Manual):
    """Manually inserted task: one clustering run per (EphysRecording, ClusteringParamSet).

    `task_mode` selects whether `Clustering.make` loads pre-computed results
    from `clustering_output_dir` or triggers the computation itself.
    """
    definition = """
    -> EphysRecording
    -> ClusteringParamSet
    ---
    clustering_output_dir: varchar(255)  # clustering output directory relative to root data directory
    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
    """
347+
348+
348349
@schema
class Clustering(dj.Imported):
    """
    A processing table to handle each ClusteringTask:
    + If `task_mode == "trigger"`: trigger clustering analysis according to the ClusteringParamSet (e.g. launch a kilosort job)
    + If `task_mode == "load"`: verify output and create a corresponding entry in the Curation table
    """
    definition = """
    -> ClusteringTask
    ---
    clustering_time: datetime  # time of generation of this set of clustering results
    """

    def make(self, key):
        """Process one ClusteringTask according to its `task_mode`.

        In 'load' mode, validates the Kilosort output directory, records the
        clustering time, and seeds the Curation table with an initial entry
        pointing at the same output directory.
        """
        root_dir = pathlib.Path(get_ephys_root_data_dir())
        task_mode, output_dir = (ClusteringTask & key).fetch1('task_mode', 'clustering_output_dir')
        ks_dir = root_dir / output_dir

        if task_mode == 'load':
            ks = kilosort.Kilosort(ks_dir)  # check if the directory is a valid Kilosort output
            creation_time, is_curated, is_qc = kilosort.extract_clustering_info(ks_dir)
            # Synthesize curation_id: one past the current max for this task
            # (or 1 if no curation exists yet — aggr yields None on empty set)
            curation_id = (dj.U().aggr(Curation & key, n='max(curation_id)').fetch1('n') or 0) + 1

            self.insert1({**key, 'clustering_time': creation_time})
            Curation.insert1({**key, 'curation_id': curation_id,
                              'curation_time': creation_time, 'curation_output_dir': output_dir,
                              'quality_control': is_qc, 'manual_curation': is_curated})
        elif task_mode == 'trigger':
            raise NotImplementedError('Automatic triggering of clustering analysis is not yet supported')
        else:
            raise ValueError(f'Unknown task mode: {task_mode}')
381+
382+
383+
@schema
class Curation(dj.Manual):
    """One curated set of clustering results for a ClusteringTask.

    Multiple curations per task are supported via `curation_id`; the first
    entry is auto-created by `Clustering.make` in 'load' mode.
    """
    definition = """
    -> ClusteringTask
    curation_id: int
    ---
    curation_time: datetime  # time of generation of this set of curated clustering results
    curation_output_dir: varchar(255)  # output directory of the curated results, relative to root data directory
    quality_control: bool  # has this clustering result undergone quality control?
    manual_curation: bool  # has manual curation been performed on this clustering result?
    curation_note='': varchar(2000)
    """
358395

359-
class Unit(dj.Part):
360-
definition = """
361-
-> master
362-
unit: int
363-
---
364-
-> probe.ElectrodeConfig.Electrode # electrode on the probe that this unit has highest response amplitude
365-
-> ClusterQualityLabel
366-
spike_count: int # how many spikes in this recording of this unit
367-
spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
368-
spike_sites : longblob # array of electrode associated with each spike
369-
spike_depths : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
370-
"""
396+
397+
@schema
398+
class Unit(dj.Imported):
399+
definition = """
400+
-> Curation
401+
unit: int
402+
---
403+
-> probe.ElectrodeConfig.Electrode # electrode on the probe that this unit has highest response amplitude
404+
-> ClusterQualityLabel
405+
spike_count: int # how many spikes in this recording of this unit
406+
spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
407+
spike_sites : longblob # array of electrode associated with each spike
408+
spike_depths : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
409+
"""
410+
411+
@property
412+
def key_source(self):
413+
return Curation()
371414

372415
def make(self, key):
373416
root_dir = pathlib.Path(get_ephys_root_data_dir())
374-
ks_dir = root_dir / (ClusteringTask & key).fetch1('clustering_output_dir')
417+
ks_dir = root_dir / (Curation & key).fetch1('curation_output_dir')
375418
ks = kilosort.Kilosort(ks_dir)
376419
acq_software = (EphysRecording & key).fetch1('acq_software')
377420

378-
# ---------- Clustering ----------
379-
creation_time, is_curated, is_qc = kilosort.extract_clustering_info(ks_dir)
380-
381421
# ---------- Unit ----------
382422
# -- Remove 0-spike units
383423
withspike_idx = [i for i, u in enumerate(ks.data['cluster_ids']) if (ks.data['spike_clusters'] == u).any()]
@@ -413,15 +453,13 @@ def make(self, key):
413453
'spike_sites': spike_sites[ks.data['spike_clusters'] == unit],
414454
'spike_depths': spike_depths[ks.data['spike_clusters'] == unit]})
415455

416-
self.insert1({**key, 'clustering_time': creation_time,
417-
'quality_control': is_qc, 'manual_curation': is_curated})
418-
self.Unit.insert([{**key, **u} for u in units])
456+
self.insert([{**key, **u} for u in units])
419457

420458

421459
@schema
422460
class Waveform(dj.Imported):
423461
definition = """
424-
-> Clustering.Unit
462+
-> Unit
425463
---
426464
peak_chn_waveform_mean: longblob # mean over all spikes at the peak channel for this unit
427465
"""
@@ -437,11 +475,11 @@ class Electrode(dj.Part):
437475

438476
@property
439477
def key_source(self):
440-
return Clustering()
478+
return Curation()
441479

442480
def make(self, key):
443481
root_dir = pathlib.Path(get_ephys_root_data_dir())
444-
ks_dir = root_dir / (ClusteringTask & key).fetch1('clustering_output_dir')
482+
ks_dir = root_dir / (Curation & key).fetch1('curation_output_dir')
445483
ks = kilosort.Kilosort(ks_dir)
446484

447485
acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1('acq_software', 'probe')
@@ -450,10 +488,10 @@ def make(self, key):
450488
rec_key = (EphysRecording & key).fetch1('KEY')
451489
chn2electrodes = get_neuropixels_chn2electrode_map(rec_key, acq_software)
452490

453-
is_qc = (Clustering & key).fetch1('quality_control')
491+
is_qc = (Curation & key).fetch1('quality_control')
454492

455493
# Get all units
456-
units = {u['unit']: u for u in (Clustering.Unit & key).fetch(as_dict=True, order_by='unit')}
494+
units = {u['unit']: u for u in (Unit & key).fetch(as_dict=True, order_by='unit')}
457495

458496
unit_waveforms, unit_peak_waveforms = [], []
459497
if is_qc:
@@ -494,7 +532,7 @@ def make(self, key):
494532
@schema
495533
class ClusterQualityMetrics(dj.Imported):
496534
definition = """
497-
-> Clustering.Unit
535+
-> Unit
498536
---
499537
amp: float
500538
snr: float

elements_ephys/readers/kilosort.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from os import path
22
from datetime import datetime
3+
import pathlib
34
import pandas as pd
45
import numpy as np
56
import re
@@ -37,14 +38,19 @@ class Kilosort:
3738
# keys to self.files, .data are file name e.g. self.data['params'], etc.
3839
ks_keys = [path.splitext(i)[0] for i in ks_files]
3940

40-
def __init__(self, ks_dir):
    """Load Kilosort output metadata from *ks_dir*.

    Accepts a str or pathlib.Path; raises FileNotFoundError if the
    directory does not contain a `params.py` (i.e. is not Kilosort output).
    """
    self._ks_dir = pathlib.Path(ks_dir)
    self._files = {}
    self._data = None
    self._clusters = None

    # BUGFIX: use the normalized Path, not the raw argument — `ks_dir / 'params.py'`
    # raises TypeError when callers pass a plain string, which the
    # pathlib.Path(ks_dir) conversion above exists to support.
    params_fp = self._ks_dir / 'params.py'

    if not params_fp.exists():
        raise FileNotFoundError(f'No Kilosort output found in: {ks_dir}')

    self._info = {'time_created': datetime.fromtimestamp(params_fp.stat().st_ctime),
                  'time_modified': datetime.fromtimestamp(params_fp.stat().st_mtime)}
4854

4955
@property
5056
def data(self):
@@ -59,7 +65,7 @@ def info(self):
5965
def _stat(self):
6066
self._data = {}
6167
for i in Kilosort.ks_files:
62-
f = self._dname / i
68+
f = self._ks_dir / i
6369

6470
if not f.exists():
6571
log.debug('skipping {} - doesnt exist'.format(f))
@@ -84,12 +90,12 @@ def _stat(self):
8490
self._data[base] = np.reshape(d, d.shape[0]) if d.ndim == 2 and d.shape[1] == 1 else d
8591

8692
# Read the Cluster Groups
87-
if (self._dname / 'cluster_groups.csv').exists():
88-
df = pd.read_csv(self._dname / 'cluster_groups.csv', delimiter='\t')
93+
if (self._ks_dir / 'cluster_groups.csv').exists():
94+
df = pd.read_csv(self._ks_dir / 'cluster_groups.csv', delimiter= '\t')
8995
self._data['cluster_groups'] = np.array(df['group'].values)
9096
self._data['cluster_ids'] = np.array(df['cluster_id'].values)
91-
elif (self._dname / 'cluster_KSLabel.tsv').exists():
92-
df = pd.read_csv(self._dname / 'cluster_KSLabel.tsv', sep = "\t", header = 0)
97+
elif (self._ks_dir / 'cluster_KSLabel.tsv').exists():
98+
df = pd.read_csv(self._ks_dir / 'cluster_KSLabel.tsv', sep = "\t", header = 0)
9399
self._data['cluster_groups'] = np.array(df['KSLabel'].values)
94100
self._data['cluster_ids'] = np.array(df['cluster_id'].values)
95101
else:

0 commit comments

Comments
 (0)