Skip to content

Commit 92547d7

Browse files
Authored: Merge pull request #191 from ttngu207/dev_spikeinterface_v101
code cleanup + bugfix
2 parents 2a4b166 + f6a52d9 commit 92547d7

File tree

2 files changed

+21
-25
lines changed

2 files changed

+21
-25
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
import datajoint as dj
99
import numpy as np
1010
import pandas as pd
11-
import spikeinterface as si
1211
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
13-
from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
1412

1513
from . import ephys_report, probe
1614
from .readers import kilosort, openephys, spikeglx
1715

18-
log = dj.logger
16+
logger = dj.logger
1917

2018
schema = dj.schema()
2119

@@ -824,7 +822,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):
824822

825823
if mkdir:
826824
output_dir.mkdir(parents=True, exist_ok=True)
827-
log.info(f"{output_dir} created!")
825+
logger.info(f"{output_dir} created!")
828826

829827
return output_dir.relative_to(processed_dir) if relative else output_dir
830828

@@ -1030,16 +1028,17 @@ def make(self, key):
10301028

10311029
# Get channel and electrode-site mapping
10321030
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
1033-
channel2electrode_map = electrode_query.fetch(as_dict=True)
10341031
channel2electrode_map: dict[int, dict] = {
1035-
chn.pop("channel_idx"): chn for chn in channel2electrode_map
1032+
chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
10361033
}
10371034

10381035
# Get sorter method and create output directory.
10391036
sorter_name = clustering_method.replace(".", "_")
10401037
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
10411038

10421039
if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs
1040+
import spikeinterface as si
1041+
10431042
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
10441043
si_sorting = sorting_analyzer.sorting
10451044

@@ -1054,12 +1053,10 @@ def make(self, key):
10541053
spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
10551054
# {unit: spike_count}
10561055

1057-
# reorder channel2electrode_map according to recording channel ids
1056+
# update channel2electrode_map to match with probe's channel index
10581057
channel2electrode_map = {
1059-
chn_idx: channel2electrode_map[chn_idx]
1060-
for chn_idx in sorting_analyzer.channel_ids_to_indices(
1061-
sorting_analyzer.channel_ids
1062-
)
1058+
idx: channel2electrode_map[int(chn_idx)]
1059+
for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
10631060
}
10641061

10651062
# Get unit id to quality label mapping
@@ -1239,13 +1236,14 @@ def make(self, key):
12391236

12401237
# Get channel and electrode-site mapping
12411238
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
1242-
channel2electrode_map = electrode_query.fetch(as_dict=True)
12431239
channel2electrode_map: dict[int, dict] = {
1244-
chn.pop("channel_idx"): chn for chn in channel2electrode_map
1240+
chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
12451241
}
12461242

12471243
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
12481244
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
1245+
import spikeinterface as si
1246+
12491247
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
12501248

12511249
# Find representative channel for each unit
@@ -1256,12 +1254,10 @@ def make(self, key):
12561254
) # {unit: peak_channel_index}
12571255
unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
12581256

1259-
# reorder channel2electrode_map according to recording channel ids
1260-
channel_indices = sorting_analyzer.channel_ids_to_indices(
1261-
sorting_analyzer.channel_ids
1262-
).tolist()
1257+
# update channel2electrode_map to match with probe's channel index
12631258
channel2electrode_map = {
1264-
chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
1259+
idx: channel2electrode_map[int(chn_idx)]
1260+
for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
12651261
}
12661262

12671263
templates = sorting_analyzer.get_extension("templates")
@@ -1274,12 +1270,9 @@ def yield_unit_waveforms():
12741270
unit_waveforms = templates.get_unit_template(
12751271
unit_id=unit["unit"], operator="average"
12761272
)
1277-
peak_chn_idx = channel_indices.index(
1278-
unit_peak_channel[unit["unit"]]
1279-
)
12801273
unit_peak_waveform = {
12811274
**unit,
1282-
"peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
1275+
"peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
12831276
}
12841277

12851278
unit_electrode_waveforms = [
@@ -1288,7 +1281,7 @@ def yield_unit_waveforms():
12881281
**channel2electrode_map[chn_idx],
12891282
"waveform_mean": unit_waveforms[:, chn_idx],
12901283
}
1291-
for chn_idx in channel_indices
1284+
for chn_idx in channel2electrode_map
12921285
]
12931286

12941287
yield unit_peak_waveform, unit_electrode_waveforms
@@ -1501,6 +1494,8 @@ def make(self, key):
15011494

15021495
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
15031496
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
1497+
import spikeinterface as si
1498+
15041499
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
15051500
qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
15061501
template_metrics = sorting_analyzer.get_extension(

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def _run_sorter():
216216
sorting_save_path = sorting_output_dir / "si_sorting.pkl"
217217
si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
218218

219+
_run_sorter()
220+
219221
self.insert1(
220222
{
221223
**key,
@@ -248,7 +250,6 @@ def make(self, key):
248250
).fetch1("clustering_method", "clustering_output_dir", "params")
249251
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
250252
sorter_name = clustering_method.replace(".", "_")
251-
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
252253

253254
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
254255
sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
@@ -281,7 +282,7 @@ def _sorting_analyzer_compute():
281282
folder=analyzer_output_dir,
282283
sparse=True,
283284
overwrite=True,
284-
**job_kwargs
285+
**job_kwargs,
285286
)
286287

287288
# The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()

0 commit comments

Comments (0)