Skip to content

Commit 258839b

Browse files
author
Thinh Nguyen
committed
split find_valid_full_path into find_full_path and find_root_directory
1 parent 4e824cf commit 258839b

File tree

3 files changed

+79
-57
lines changed

3 files changed

+79
-57
lines changed

element_array_ephys/__init__.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,51 @@
55
dj.config['enable_python_native_blobs'] = True
66

77

def find_full_path(root_directories, relative_path):
    """
    Given a relative path, search and return the full-path
    from the provided potential root directories (in the given order).

    :param root_directories: str/path or list of str/paths - potential root directory(ies)
    :param relative_path: the relative path to resolve against the root directories
    :return: full-path (pathlib.Path) of the first existing combination found
    :raises FileNotFoundError: if no root directory combined with
        relative_path points to an existing location
    """
    relative_path = pathlib.Path(relative_path)

    # the path may already be valid as given (absolute, or relative to
    # the current working directory) - use it directly in that case
    if relative_path.exists():
        return relative_path

    # turn to list if only a single root directory is provided
    if isinstance(root_directories, (str, pathlib.Path)):
        root_directories = [root_directories]

    for root_dir in root_directories:
        # build the candidate once instead of twice per iteration
        full_path = pathlib.Path(root_dir) / relative_path
        if full_path.exists():
            return full_path

    raise FileNotFoundError('No valid full-path found (from {})'
                            ' for {}'.format(root_directories, relative_path))
def find_root_directory(root_directories, full_path):
    """
    Given multiple potential root directories and a full-path,
    search and return one directory that is the parent of the given path.

    :param root_directories: str/path or list of str/paths - potential root directory(ies)
    :param full_path: the full path to search the root directory for
    :return: the first root directory (returned as provided, unmodified)
        that is a parent of full_path
    :raises FileNotFoundError: if full_path does not exist, or none of
        the root directories is a parent of it
    """
    full_path = pathlib.Path(full_path)

    if not full_path.exists():
        raise FileNotFoundError(f'{full_path} does not exist!')

    # turn to list if only a single root directory is provided
    if isinstance(root_directories, (str, pathlib.Path)):
        root_directories = [root_directories]

    try:
        # NOTE: PurePath.is_relative_to requires Python 3.9+
        return next(root_dir for root_dir in root_directories
                    if full_path.is_relative_to(root_dir))
    except StopIteration:
        raise FileNotFoundError('No valid root directory found (from {})'
                                ' for {}'.format(root_directories, full_path))

element_array_ephys/ephys.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import importlib
99

1010
from .readers import spikeglx, kilosort, openephys
11-
from . import probe, find_valid_full_path
11+
from . import probe, find_full_path, find_root_directory
1212

1313
schema = dj.schema()
1414

@@ -186,7 +186,7 @@ def make(self, key):
186186
'acq_software': acq_software,
187187
'sampling_rate': spikeglx_meta.meta['imSampRate']})
188188

189-
_, root_dir = find_valid_full_path(get_ephys_root_data_dir(), meta_filepath)
189+
root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
190190
self.EphysFile.insert1({
191191
**key,
192192
'file_path': meta_filepath.relative_to(root_dir).as_posix()})
@@ -221,7 +221,7 @@ def make(self, key):
221221
'acq_software': acq_software,
222222
'sampling_rate': probe_data.ap_meta['sample_rate']})
223223

224-
_, root_dir = find_valid_full_path(
224+
root_dir = find_root_directory(
225225
get_ephys_root_data_dir(),
226226
probe_data.recording_info['recording_files'][0])
227227
self.EphysFile.insert([{**key,
@@ -417,7 +417,7 @@ class Clustering(dj.Imported):
417417
def make(self, key):
418418
task_mode, output_dir = (ClusteringTask & key).fetch1(
419419
'task_mode', 'clustering_output_dir')
420-
kilosort_dir, _ = find_valid_full_path(get_ephys_root_data_dir(), output_dir)
420+
kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
421421

422422
if task_mode == 'load':
423423
kilosort_dataset = kilosort.Kilosort(kilosort_dir) # check if the directory is a valid Kilosort output
@@ -455,7 +455,7 @@ def create1_from_clustering_task(self, key, curation_note=''):
455455

456456
task_mode, output_dir = (ClusteringTask & key).fetch1(
457457
'task_mode', 'clustering_output_dir')
458-
kilosort_dir, _ = find_valid_full_path(get_ephys_root_data_dir(), output_dir)
458+
kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
459459

460460
creation_time, is_curated, is_qc = kilosort.extract_clustering_info(kilosort_dir)
461461
# Synthesize curation_id
@@ -487,7 +487,7 @@ class Unit(dj.Part):
487487

488488
def make(self, key):
489489
output_dir = (Curation & key).fetch1('curation_output_dir')
490-
kilosort_dir, _ = find_valid_full_path(get_ephys_root_data_dir(), output_dir)
490+
kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
491491

492492
kilosort_dataset = kilosort.Kilosort(kilosort_dir)
493493
acq_software = (EphysRecording & key).fetch1('acq_software')
@@ -561,7 +561,7 @@ class UnitElectrode(dj.Part):
561561

562562
def make(self, key):
563563
output_dir = (Curation & key).fetch1('curation_output_dir')
564-
kilosort_dir, _ = find_valid_full_path(get_ephys_root_data_dir(), output_dir)
564+
kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
565565

566566
kilosort_dataset = kilosort.Kilosort(kilosort_dir)
567567

@@ -645,8 +645,8 @@ def get_spikeglx_meta_filepath(ephys_recording_key):
645645
& 'file_path LIKE "%.ap.meta"').fetch1('file_path')
646646

647647
try:
648-
spikeglx_meta_filepath, _ = find_valid_full_path(get_ephys_root_data_dir(),
649-
spikeglx_meta_filepath)
648+
spikeglx_meta_filepath = find_full_path(get_ephys_root_data_dir(),
649+
spikeglx_meta_filepath)
650650
except FileNotFoundError:
651651
# if not found, search in session_dir again
652652
if not spikeglx_meta_filepath.exists():

element_array_ephys/readers/kilosort.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
class Kilosort:
1414

15-
ks_files = [
15+
kilosort_files = [
1616
'params.py',
1717
'amplitudes.npy',
1818
'channel_map.npy',
@@ -36,18 +36,18 @@ class Kilosort:
3636
]
3737

3838
# keys to self.files, .data are file name e.g. self.data['params'], etc.
39-
ks_keys = [path.splitext(ks_file)[0] for ks_file in ks_files]
39+
kilosort_keys = [path.splitext(kilosort_file)[0] for kilosort_file in kilosort_files]
4040

41-
def __init__(self, ks_dir):
42-
self._ks_dir = pathlib.Path(ks_dir)
41+
def __init__(self, kilosort_dir):
42+
self._kilosort_dir = pathlib.Path(kilosort_dir)
4343
self._files = {}
4444
self._data = None
4545
self._clusters = None
4646

47-
params_filepath = ks_dir / 'params.py'
47+
params_filepath = kilosort_dir / 'params.py'
4848

4949
if not params_filepath.exists():
50-
raise FileNotFoundError(f'No Kilosort output found in: {ks_dir}')
50+
raise FileNotFoundError(f'No Kilosort output found in: {kilosort_dir}')
5151

5252
self._info = {'time_created': datetime.fromtimestamp(params_filepath.stat().st_ctime),
5353
'time_modified': datetime.fromtimestamp(params_filepath.stat().st_mtime)}
@@ -64,42 +64,44 @@ def info(self):
6464

6565
def _stat(self):
6666
self._data = {}
67-
for ks_filename in Kilosort.ks_files:
68-
ks_filepath = self._ks_dir / ks_filename
67+
for kilosort_filename in Kilosort.kilosort_files:
68+
kilosort_filepath = self._kilosort_dir / kilosort_filename
6969

70-
if not ks_filepath.exists():
71-
log.debug('skipping {} - does not exist'.format(ks_filepath))
70+
if not kilosort_filepath.exists():
71+
log.debug('skipping {} - does not exist'.format(kilosort_filepath))
7272
continue
7373

74-
base, ext = path.splitext(ks_filename)
75-
self._files[base] = ks_filepath
74+
base, ext = path.splitext(kilosort_filename)
75+
self._files[base] = kilosort_filepath
7676

77-
if ks_filename == 'params.py':
78-
log.debug('loading params.py {}'.format(ks_filepath))
77+
if kilosort_filename == 'params.py':
78+
log.debug('loading params.py {}'.format(kilosort_filepath))
7979
# params.py is a 'key = val' file
8080
params = {}
81-
for line in open(ks_filepath, 'r').readlines():
81+
for line in open(kilosort_filepath, 'r').readlines():
8282
k, v = line.strip('\n').split('=')
8383
params[k.strip()] = convert_to_number(v.strip())
8484
log.debug('params: {}'.format(params))
8585
self._data[base] = params
8686

8787
if ext == '.npy':
88-
log.debug('loading npy {}'.format(ks_filepath))
89-
d = np.load(ks_filepath, mmap_mode='r', allow_pickle=False, fix_imports=False)
88+
log.debug('loading npy {}'.format(kilosort_filepath))
89+
d = np.load(kilosort_filepath, mmap_mode='r',
90+
allow_pickle=False, fix_imports=False)
9091
self._data[base] = (np.reshape(d, d.shape[0])
9192
if d.ndim == 2 and d.shape[1] == 1 else d)
9293

9394
# Read the Cluster Groups
9495
for cluster_pattern, cluster_col_name in zip(['cluster_groups.*', 'cluster_KSLabel.*'],
9596
['group', 'KSLabel']):
9697
try:
97-
cluster_file = next(self._ks_dir.glob(cluster_pattern))
98-
cluster_file_suffix = cluster_file.suffix
99-
assert cluster_file_suffix in ('.csv', '.tsv', '.xlsx')
100-
break
98+
cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
10199
except StopIteration:
102100
pass
101+
102+
cluster_file_suffix = cluster_file.suffix
103+
assert cluster_file_suffix in ('.csv', '.tsv', '.xlsx')
104+
break
103105
else:
104106
raise FileNotFoundError(
105107
'Neither "cluster_groups" nor "cluster_KSLabel" file found!')
@@ -118,7 +120,7 @@ def get_best_channel(self, unit):
118120
template_idx = self.data['spike_templates'][
119121
np.where(self.data['spike_clusters'] == unit)[0][0]]
120122
channel_templates = self.data['templates'][template_idx, :, :]
121-
max_channel_idx = np.abs(np.abs(channel_templates).max(axis=0)).argmax()
123+
max_channel_idx = np.abs(channel_templates).max(axis=0).argmax()
122124
max_channel = self.data['channel_map'][max_channel_idx]
123125

124126
return max_channel, max_channel_idx
@@ -174,12 +176,10 @@ def extract_clustering_info(cluster_output_dir):
174176

175177
# ---- Quality control? ----
176178
metric_filepath = cluster_output_dir / 'metrics.csv'
177-
if metric_filepath.exists():
178-
is_qc = True
179+
is_qc = metric_filepath.exists()
180+
if is_qc:
179181
if creation_time is None:
180182
creation_time = datetime.fromtimestamp(metric_filepath.stat().st_ctime)
181-
else:
182-
is_qc = False
183183

184184
if creation_time is None:
185185
spiketimes_filepath = next(cluster_output_dir.glob('spike_times.npy'))

0 commit comments

Comments
 (0)