Commit d778b1e

fix: update channel-electrode mapping
1 parent 7309082 commit d778b1e

File tree

element_array_ephys/ephys_no_curation.py
element_array_ephys/spike_sorting/si_spike_sorting.py

2 files changed: +81 / -96 lines

element_array_ephys/ephys_no_curation.py

Lines changed: 76 additions & 88 deletions
@@ -1040,51 +1040,47 @@ def make(self, key):
         ).fetch1("clustering_method", "clustering_output_dir")
         output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
 
-        # Get sorter method and create output directory.
-        sorter_name = (
-            "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
+        # Get channel and electrode-site mapping
+        electrode_query = (
+            (EphysRecording.Channel & key)
+            .proj(..., "-channel_name")
         )
-        waveform_dir = output_dir / sorter_name / "waveform"
-        sorting_dir = output_dir / sorter_name / "spike_sorting"
+        channel2electrode_map = electrode_query.fetch(as_dict=True)
+        channel2electrode_map: dict[int, dict] = {
+            chn.pop("channel_idx"): chn for chn in channel2electrode_map
+        }
 
-        if waveform_dir.exists():  # read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                waveform_dir, with_recording=False
-            )
+        # Get sorter method and create output directory.
+        sorter_name = clustering_method.replace(".", "_")
+        si_waveform_dir = output_dir / sorter_name / "waveform"
+        si_sorting_dir = output_dir / sorter_name / "spike_sorting"
+
+        if si_waveform_dir.exists():
+
+            # Read from spikeinterface outputs
+            we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False)
             si_sorting: si.sorters.BaseSorter = si.load_extractor(
-                sorting_dir / "si_sorting.pkl"
+                si_sorting_dir / "si_sorting.pkl"
             )
 
-            unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel(
-                we, outputs="index"
-            )  # {unit: peak_channel_index}
+            unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
+                we, outputs="id"
+            )  # {unit: peak_channel_id}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            spikes = si_sorting.to_spike_vector(
-                extremum_channel_inds=unit_peak_channel_map
-            )
-
-            # Get electrode & channel info
-            electrode_config_key = (
-                EphysRecording * probe.ElectrodeConfig & key
-            ).fetch1("KEY")
-
-            electrode_query = (
-                probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
-                & electrode_config_key
-            ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel)
+            spikes = si_sorting.to_spike_vector()
 
-            channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx")
-            channel_info: dict[int, dict] = {
-                ch.pop("channel_idx"): ch for ch in channel_info
+            # reorder channel2electrode_map according to recording channel ids
+            channel2electrode_map = {
+                chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids
             }
 
             # Get unit id to quality label mapping
             try:
                 cluster_quality_label_map = pd.read_csv(
-                    sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
+                    si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
                     delimiter="\t",
                 )
             except FileNotFoundError:
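The hunk above builds the channel-to-electrode mapping once from EphysRecording.Channel, keyed by channel_idx, and then re-keys it to follow we.channel_ids, so the peak channel ids returned by si.get_template_extremum_channel(we, outputs="id") can be used as dictionary keys directly. A minimal, self-contained sketch of that re-keying step (the fetched rows and channel ids below are invented placeholders, not real query output):

# Sketch only: stand-in rows for what the DataJoint fetch(as_dict=True) might return.
fetched_rows = [
    {"channel_idx": 0, "electrode": 0, "y_coord": 20.0},  # invented values
    {"channel_idx": 1, "electrode": 1, "y_coord": 40.0},
]
channel2electrode_map = {row.pop("channel_idx"): row for row in fetched_rows}

# Hypothetical recording channel ids; spikeinterface often exposes these as strings.
recording_channel_ids = ["1", "0"]

# Re-key the map so it is ordered and indexed by the recording's channel ids.
channel2electrode_map = {
    chn_id: channel2electrode_map[int(chn_id)] for chn_id in recording_channel_ids
}
# A peak channel id such as "1" now looks up its electrode directly:
# channel2electrode_map["1"]["electrode"] -> 1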
@@ -1099,15 +1095,15 @@ def make(self, key):
             # Get electrode where peak unit activity is recorded
             peak_electrode_ind = np.array(
                 [
-                    channel_info[unit_peak_channel_map[unit_id]]["electrode"]
+                    channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
                     for unit_id in si_sorting.unit_ids
                 ]
             )
 
             # Get channel depth
             channel_depth_ind = np.array(
                 [
-                    channel_info[unit_peak_channel_map[unit_id]]["y_coord"]
+                    channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"]
                     for unit_id in si_sorting.unit_ids
                 ]
             )
@@ -1132,7 +1128,7 @@ def make(self, key):
                 units.append(
                     {
                         **key,
-                        **channel_info[unit_peak_channel_map[unit_id]],
+                        **channel2electrode_map[unit_peak_channel[unit_id]],
                         "unit": unit_id,
                         "cluster_quality_label": cluster_quality_label_map.get(
                             unit_id, "n.a."
@@ -1143,10 +1139,10 @@ def make(self, key):
                         "spike_count": spike_count_dict[unit_id],
                         "spike_sites": new_spikes["electrode"][
                             new_spikes["unit_index"] == unit_id
-                        ],
+                        ],
                         "spike_depths": new_spikes["depth"][
                             new_spikes["unit_index"] == unit_id
-                        ],
+                        ],
                     }
                 )
 
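The spike_sites and spike_depths fields above are filled by boolean-masking a structured spike vector on its unit_index field. A toy illustration of that masking pattern, with an invented array standing in for new_spikes:

import numpy as np

# Invented stand-in for new_spikes; the real vector comes from si_sorting.to_spike_vector()
# after electrode and depth fields have been attached.
new_spikes = np.array(
    [(3, 10, 12.5), (3, 10, 14.0), (7, 42, 40.0)],
    dtype=[("unit_index", "i8"), ("electrode", "i8"), ("depth", "f8")],
)

unit_id = 3  # assumes unit ids line up with unit_index values, as in the hunk above
spike_sites = new_spikes["electrode"][new_spikes["unit_index"] == unit_id]  # array([10, 10])
spike_depths = new_spikes["depth"][new_spikes["unit_index"] == unit_id]     # array([12.5, 14.])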
@@ -1184,20 +1180,10 @@ def make(self, key):
             spike_times = kilosort_dataset.data[spike_time_key]
             kilosort_dataset.extract_spike_depths()
 
-            # Get channel and electrode-site mapping
-            channel_info = (
-                (EphysRecording.Channel & key)
-                .proj(..., "-channel_name")
-                .fetch(as_dict=True, order_by="channel_idx")
-            )
-            channel_info: dict[int, dict] = {
-                ch.pop("channel_idx"): ch for ch in channel_info
-            }  # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
-
             # -- Spike-sites and Spike-depths --
             spike_sites = np.array(
                 [
-                    channel_info[s]["electrode"]
+                    channel2electrode_map[s]["electrode"]
                     for s in kilosort_dataset.data["spike_sites"]
                 ]
             )
@@ -1219,7 +1205,7 @@ def make(self, key):
                     **key,
                     "unit": unit,
                     "cluster_quality_label": unit_lbl,
-                    **channel_info[unit_channel],
+                    **channel2electrode_map[unit_channel],
                     "spike_times": unit_spike_times,
                     "spike_count": spike_count,
                     "spike_sites": spike_sites[
@@ -1292,33 +1278,31 @@ def make(self, key):
             ClusteringTask * ClusteringParamSet & key
         ).fetch1("clustering_method", "clustering_output_dir")
         output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-        sorter_name = (
-            "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
-        )
+        sorter_name = clustering_method.replace(".", "_")
 
         # Get channel and electrode-site mapping
-        channel_info = (
+        electrode_query = (
             (EphysRecording.Channel & key)
             .proj(..., "-channel_name")
-            .fetch(as_dict=True, order_by="channel_idx")
         )
-        channel_info: dict[int, dict] = {
-            ch.pop("channel_idx"): ch for ch in channel_info
-        }  # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
+        channel2electrode_map = electrode_query.fetch(as_dict=True)
+        channel2electrode_map: dict[int, dict] = {
+            chn.pop("channel_idx"): chn for chn in channel2electrode_map
+        }
 
-        if (
-            output_dir / sorter_name / "waveform"
-        ).exists():  # read from spikeinterface outputs
+        si_waveform_dir = output_dir / sorter_name / "waveform"
+        if si_waveform_dir.exists():  # read from spikeinterface outputs
+            we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False)
+            unit_id_to_peak_channel_map: dict[
+                int, np.ndarray
+            ] = si.ChannelSparsity.from_best_channels(
+                we, 1, peak_sign="neg"
+            ).unit_id_to_channel_indices  # {unit: peak_channel_index}
 
-            waveform_dir = output_dir / sorter_name / "waveform"
-            we: si.WaveformExtractor = si.load_waveforms(
-                waveform_dir, with_recording=False
-            )
-            unit_id_to_peak_channel_map: dict[int, np.ndarray] = (
-                si.ChannelSparsity.from_best_channels(
-                    we, 1, peak_sign="neg"
-                ).unit_id_to_channel_indices
-            )  # {unit: peak_channel_index}
+            # reorder channel2electrode_map according to recording channel ids
+            channel2electrode_map = {
+                chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids
+            }
 
             # Get mean waveform for each unit from all channels
             mean_waveforms = we.get_all_templates(
@@ -1329,30 +1313,32 @@ def make(self, key):
             unit_electrode_waveforms = []
 
             for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"):
+                unit_waveforms = we.get_template(
+                    unit_id=unit["unit"], mode="average", force_dense=True
+                )  # (sample x channel)
+                peak_chn_idx = list(we.channel_ids).index(
+                    unit_id_to_peak_channel_map[unit["unit"]][0]
+                )
                 unit_peak_waveform.append(
                     {
                         **unit,
-                        "peak_electrode_waveform": we.get_template(
-                            unit_id=unit["unit"], mode="average", force_dense=True
-                        )[:, unit_id_to_peak_channel_map[unit["unit"]][0]],
+                        "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
                     }
                 )
-
                 unit_electrode_waveforms.extend(
                     [
                         {
                             **unit,
-                            **channel_info[c],
-                            "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c],
+                            **channel2electrode_map[c],
+                            "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx],
                         }
-                        for c in channel_info
+                        for c_idx, c in enumerate(channel2electrode_map)
                     ]
                 )
 
             self.insert1(key)
             self.PeakWaveform.insert(unit_peak_waveform)
             self.Waveform.insert(unit_electrode_waveforms)
-
         else:
             kilosort_dataset = kilosort.Kilosort(output_dir)
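In the loop above, the peak-electrode waveform is taken by locating the unit's best channel within we.channel_ids and slicing that column from the dense template. A toy version of that lookup, assuming the sparsity map stores channel ids (all values below are invented):

import numpy as np

channel_ids = ["AP10", "AP11", "AP12"]                 # stand-in for we.channel_ids
unit_id_to_peak_channel_map = {7: np.array(["AP11"])}  # stand-in for the ChannelSparsity output
unit_waveforms = np.zeros((90, len(channel_ids)))      # (sample x channel) dense template

peak_chn_idx = list(channel_ids).index(unit_id_to_peak_channel_map[7][0])
peak_electrode_waveform = unit_waveforms[:, peak_chn_idx]  # waveform on the unit's peak channel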

@@ -1390,12 +1376,12 @@ def yield_unit_waveforms():
                        unit_electrode_waveforms.append(
                            {
                                **units[unit_no],
-                                **channel_info[channel],
+                                **channel2electrode_map[channel],
                                "waveform_mean": channel_waveform,
                            }
                        )
                        if (
-                            channel_info[channel]["electrode"]
+                            channel2electrode_map[channel]["electrode"]
                            == units[unit_no]["electrode"]
                        ):
                            unit_peak_waveform = {
@@ -1405,7 +1391,6 @@ def yield_unit_waveforms():
                        yield unit_peak_waveform, unit_electrode_waveforms
 
                # Spike interface mean and peak waveform extraction from we object
-
                elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists():
                    we_kilosort = si.load_waveforms(waveforms_folder[0].parent)
                    unit_templates = we_kilosort.get_all_templates()
@@ -1432,12 +1417,12 @@ def yield_unit_waveforms():
                        unit_electrode_waveforms.append(
                            {
                                **units[unit_no],
-                                **channel_info[channel],
+                                **channel2electrode_map[channel],
                                "waveform_mean": channel_waveform,
                            }
                        )
                        if (
-                            channel_info[channel]["electrode"]
+                            channel2electrode_map[channel]["electrode"]
                            == units[unit_no]["electrode"]
                        ):
                            unit_peak_waveform = {
@@ -1506,13 +1491,13 @@ def yield_unit_waveforms():
                        unit_electrode_waveforms.append(
                            {
                                **unit_dict,
-                                **channel_info[channel],
+                                **channel2electrode_map[channel],
                                "waveform_mean": channel_waveform.mean(axis=0),
                                "waveforms": channel_waveform,
                            }
                        )
                        if (
-                            channel_info[channel]["electrode"]
+                            channel2electrode_map[channel]["electrode"]
                            == unit_dict["electrode"]
                        ):
                            unit_peak_waveform = {
@@ -1630,12 +1615,15 @@ def make(self, key):
             ClusteringTask * ClusteringParamSet & key
         ).fetch1("clustering_method", "clustering_output_dir")
         output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-        sorter_name = (
-            "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
-        )
-        metric_fp = output_dir / sorter_name / "metrics" / "metrics.csv"
-        if not metric_fp.exists():
-            raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
+        sorter_name = clustering_method.replace(".", "_")
+
+        # find metric_fp
+        for metric_fp in [output_dir / "metrics.csv", output_dir / sorter_name / "metrics" / "metrics.csv"]:
+            if metric_fp.exists():
+                break
+        else:
+            raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
+
         metrics_df = pd.read_csv(metric_fp)
 
         # Conform the dataframe to match the table definition
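The metrics lookup above now checks two candidate locations and relies on Python's for/else: the else branch runs only when the loop completes without hitting break. A self-contained sketch of the same pattern with invented paths:

from pathlib import Path

output_dir = Path("/tmp/example_output")  # hypothetical output directory
candidate_files = [
    output_dir / "metrics.csv",
    output_dir / "kilosort2_5" / "metrics" / "metrics.csv",
]

for metric_fp in candidate_files:
    if metric_fp.exists():
        break  # keep the first candidate that exists; skips the else clause
else:
    # reached only if no candidate existed (the loop never hit break)
    raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")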

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 5 additions & 8 deletions
@@ -132,21 +132,18 @@ def make(self, key):
         )
 
         # Add probe information to recording object
-        electrode_config_key = (
-            probe.ElectrodeConfig * ephys.EphysRecording & key
-        ).fetch1("KEY")
         electrodes_df = (
             (
-                probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
-                & electrode_config_key
+                ephys.EphysRecording.Channel * probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
+                & key
             )
             .fetch(format="frame")
-            .reset_index()[["electrode", "x_coord", "y_coord", "shank"]]
+            .reset_index()
         )
 
         # Create SI probe object
-        si_probe = readers.probe_geometry.to_probeinterface(electrodes_df)
-        si_probe.set_device_channel_indices(range(len(electrodes_df)))
+        si_probe = readers.probe_geometry.to_probeinterface(electrodes_df[["electrode", "x_coord", "y_coord", "shank"]])
+        si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
         si_recording.set_probe(probe=si_probe, in_place=True)
 
         # Run preprocessing and save results to output folder
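The second file now passes the recording's channel_idx values to set_device_channel_indices instead of assuming contacts and acquisition channels share the same order. A rough probeinterface sketch of that wiring step; the geometry and channel order here are invented, not the element's actual layout:

import numpy as np
from probeinterface import Probe

# Three contacts with invented x/y coordinates (microns).
positions = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0]])

probe = Probe(ndim=2, si_units="um")
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 6})

# Map each contact to the acquisition channel it was recorded on
# (analogous to electrodes_df["channel_idx"].values in the diff).
probe.set_device_channel_indices(np.array([2, 0, 1]))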
