Skip to content

Commit f6a52d9

Browse files
committed
chore: more robust channel mapping
1 parent 6155f13 commit f6a52d9

File tree

1 file changed

+10
-19
lines changed

1 file changed

+10
-19
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,9 +1028,8 @@ def make(self, key):
10281028

10291029
# Get channel and electrode-site mapping
10301030
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
1031-
channel2electrode_map = electrode_query.fetch(as_dict=True)
10321031
channel2electrode_map: dict[int, dict] = {
1033-
chn.pop("channel_idx"): chn for chn in channel2electrode_map
1032+
chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
10341033
}
10351034

10361035
# Get sorter method and create output directory.
@@ -1054,12 +1053,10 @@ def make(self, key):
10541053
spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
10551054
# {unit: spike_count}
10561055

1057-
# reorder channel2electrode_map according to recording channel ids
1056+
# update channel2electrode_map to match with probe's channel index
10581057
channel2electrode_map = {
1059-
chn_idx: channel2electrode_map[chn_idx]
1060-
for chn_idx in sorting_analyzer.channel_ids_to_indices(
1061-
sorting_analyzer.channel_ids
1062-
)
1058+
idx: channel2electrode_map[int(chn_idx)]
1059+
for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
10631060
}
10641061

10651062
# Get unit id to quality label mapping
@@ -1239,9 +1236,8 @@ def make(self, key):
12391236

12401237
# Get channel and electrode-site mapping
12411238
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
1242-
channel2electrode_map = electrode_query.fetch(as_dict=True)
12431239
channel2electrode_map: dict[int, dict] = {
1244-
chn.pop("channel_idx"): chn for chn in channel2electrode_map
1240+
chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
12451241
}
12461242

12471243
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
@@ -1258,12 +1254,10 @@ def make(self, key):
12581254
) # {unit: peak_channel_index}
12591255
unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
12601256

1261-
# reorder channel2electrode_map according to recording channel ids
1262-
channel_indices = sorting_analyzer.channel_ids_to_indices(
1263-
sorting_analyzer.channel_ids
1264-
).tolist()
1257+
# update channel2electrode_map to match with probe's channel index
12651258
channel2electrode_map = {
1266-
chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
1259+
idx: channel2electrode_map[int(chn_idx)]
1260+
for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
12671261
}
12681262

12691263
templates = sorting_analyzer.get_extension("templates")
@@ -1276,12 +1270,9 @@ def yield_unit_waveforms():
12761270
unit_waveforms = templates.get_unit_template(
12771271
unit_id=unit["unit"], operator="average"
12781272
)
1279-
peak_chn_idx = channel_indices.index(
1280-
unit_peak_channel[unit["unit"]]
1281-
)
12821273
unit_peak_waveform = {
12831274
**unit,
1284-
"peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
1275+
"peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
12851276
}
12861277

12871278
unit_electrode_waveforms = [
@@ -1290,7 +1281,7 @@ def yield_unit_waveforms():
12901281
**channel2electrode_map[chn_idx],
12911282
"waveform_mean": unit_waveforms[:, chn_idx],
12921283
}
1293-
for chn_idx in channel_indices
1284+
for chn_idx in channel2electrode_map
12941285
]
12951286

12961287
yield unit_peak_waveform, unit_electrode_waveforms

0 commit comments

Comments
 (0)