Skip to content

Commit 1df41ea

Browse files
committed
get channel to electrode mapping in CuratedClustering
1 parent 727af24 commit 1df41ea

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import numpy as np
1010
import pandas as pd
1111
import spikeinterface as si
12-
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
12+
from element_interface.utils import (dict_to_uuid, find_full_path,
13+
find_root_directory)
1314
from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
1415

1516
from . import ephys_report, probe
@@ -1022,18 +1023,19 @@ def make(self, key):
10221023
extremum_channel_inds=unit_peak_channel_map
10231024
)
10241025

1025-
# Get electrode info !#TODO: need to be modified
1026+
# Get electrode & channel info
10261027
electrode_config_key = (
10271028
EphysRecording * probe.ElectrodeConfig & key
10281029
).fetch1("KEY")
10291030

10301031
electrode_query = (
10311032
probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
10321033
& electrode_config_key
1033-
)
1034+
) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel)
1035+
10341036
channel2electrode_map = dict(
1035-
zip(*electrode_query.fetch("channel", "electrode"))
1036-
)
1037+
zip(*electrode_query.fetch("channel_idx", "electrode"))
1038+
) # {channel: electrode}
10371039

10381040
# Get unit id to quality label mapping
10391041
cluster_quality_label_map = {}
@@ -1051,24 +1053,24 @@ def make(self, key):
10511053
pass
10521054

10531055
# Get channel to electrode mapping
1054-
channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord")))
1056+
channel2depth_map = dict(zip(*electrode_query.fetch("channel_idx", "y_coord"))) # {channel: depth}
10551057

10561058
peak_electrode_ind = np.array(
10571059
[
10581060
channel2electrode_map[unit_peak_channel_map[unit_id]]
10591061
for unit_id in si_sorting.unit_ids
10601062
]
1061-
)
1063+
) # get the electrode where peak unit activity is recorded
10621064

10631065
# Get channel to depth mapping
1064-
electrode_depth_ind = np.array(
1066+
channel_depth_ind = np.array(
10651067
[
10661068
channel2depth_map[unit_peak_channel_map[unit_id]]
10671069
for unit_id in si_sorting.unit_ids
10681070
]
10691071
)
10701072
spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]]
1071-
spikes["depth"] = electrode_depth_ind[spikes["unit_index"]]
1073+
spikes["depth"] = channel_depth_ind[spikes["unit_index"]]
10721074

10731075
units = []
10741076

@@ -1233,11 +1235,11 @@ def make(self, key):
12331235
we: si.WaveformExtractor = si.load_waveforms(
12341236
output_dir / "waveform", with_recording=False
12351237
)
1236-
unit_id_to_peak_channel_indices: dict[
1237-
int, np.ndarray
1238-
] = si.ChannelSparsity.from_best_channels(
1239-
we, 1, peak_sign="neg"
1240-
).unit_id_to_channel_indices # {unit: peak_channel_index}
1238+
unit_id_to_peak_channel_indices: dict[int, np.ndarray] = (
1239+
si.ChannelSparsity.from_best_channels(
1240+
we, 1, peak_sign="neg"
1241+
).unit_id_to_channel_indices
1242+
) # {unit: peak_channel_index}
12411243

12421244
units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit")
12431245

0 commit comments

Comments
 (0)