Skip to content

Commit 0ccbec9

Browse files
committed
fix CuratedClustering make function
1 parent f70ae4e commit 0ccbec9

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,12 @@ def make(self, key):
10321032
& electrode_config_key
10331033
) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel)
10341034

1035+
channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx")
1036+
1037+
channel_info: dict[int, dict] = {
1038+
ch.pop("channel_idx"): ch for ch in channel_info
1039+
}
1040+
10351041
channel2electrode_map = dict(
10361042
zip(*electrode_query.fetch("channel_idx", "electrode"))
10371043
) # {channel: electrode}
@@ -1058,27 +1064,37 @@ def make(self, key):
10581064

10591065
peak_electrode_ind = np.array(
10601066
[
1061-
channel2electrode_map[unit_peak_channel_map[unit_id]]
1067+
channel_info[unit_peak_channel_map[unit_id]]["electrode"]
10621068
for unit_id in si_sorting.unit_ids
10631069
]
10641070
) # get the electrode where peak unit activity is recorded
10651071

10661072
# Get channel to depth mapping
10671073
channel_depth_ind = np.array(
10681074
[
1069-
channel2depth_map[unit_peak_channel_map[unit_id]]
1075+
channel_info[unit_peak_channel_map[unit_id]]["y_coord"]
10701076
for unit_id in si_sorting.unit_ids
10711077
]
10721078
)
1073-
spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]]
1074-
spikes["depth"] = channel_depth_ind[spikes["unit_index"]]
1079+
1080+
# Assign electrode and depth for each spike
1081+
new_spikes = np.empty(spikes.shape, spikes.dtype.descr + [('electrode', '<i8'), ('depth', '<i8')])
1082+
1083+
for field in spikes.dtype.names:
1084+
new_spikes[field] = spikes[field]
1085+
del spikes
1086+
1087+
new_spikes["electrode"] = peak_electrode_ind[new_spikes["unit_index"]]
1088+
new_spikes["depth"] = channel_depth_ind[new_spikes["unit_index"]]
10751089

10761090
units = []
10771091

10781092
for unit_id in si_sorting.unit_ids:
10791093
unit_id = int(unit_id)
10801094
units.append(
1081-
{
1095+
{
1096+
**key,
1097+
**channel_info[unit_peak_channel_map[unit_id]],
10821098
"unit": unit_id,
10831099
"cluster_quality_label": cluster_quality_label_map.get(
10841100
unit_id, "n.a."
@@ -1087,11 +1103,11 @@ def make(self, key):
10871103
unit_id, return_times=True
10881104
),
10891105
"spike_count": spike_count_dict[unit_id],
1090-
"spike_sites": spikes["electrode"][
1091-
spikes["unit_index"] == unit_id
1106+
"spike_sites": new_spikes["electrode"][
1107+
new_spikes["unit_index"] == unit_id
10921108
],
1093-
"spike_depths": spikes["depth"][
1094-
spikes["unit_index"] == unit_id
1109+
"spike_depths": new_spikes["depth"][
1110+
new_spikes["unit_index"] == unit_id
10951111
],
10961112
}
10971113
)
@@ -1178,7 +1194,7 @@ def make(self, key):
11781194
)
11791195

11801196
self.insert1(key)
1181-
self.Unit.insert([{**key, **u} for u in units])
1197+
self.Unit.insert(units, ignore_extra_fields=True)
11821198

11831199

11841200
@schema

0 commit comments

Comments
 (0)