Skip to content

Commit 226142b

Browse files
committed
feat: ✨ Update WaveformSet make function
1 parent 8ab4c58 commit 226142b

File tree

1 file changed

+16
-31
lines changed

1 file changed

+16
-31
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,16 @@ def make(self, key):
12481248
"kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
12491249
)
12501250

1251+
# Get channel and electrode-site mapping
1252+
channel_info = (
1253+
(EphysRecording.Channel & key)
1254+
.proj(..., "-channel_name")
1255+
.fetch(as_dict=True, order_by="channel_idx")
1256+
)
1257+
channel_info: dict[int, dict] = {
1258+
ch.pop("channel_idx"): ch for ch in channel_info
1259+
} # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
1260+
12511261
if (
12521262
output_dir / sorter_name / "waveform"
12531263
).exists(): # read from spikeinterface outputs
@@ -1256,27 +1266,12 @@ def make(self, key):
12561266
we: si.WaveformExtractor = si.load_waveforms(
12571267
waveform_dir, with_recording=False
12581268
)
1259-
unit_id_to_peak_channel_indices: dict[int, np.ndarray] = (
1269+
unit_id_to_peak_channel_map: dict[int, np.ndarray] = (
12601270
si.ChannelSparsity.from_best_channels(
12611271
we, 1, peak_sign="neg"
12621272
).unit_id_to_channel_indices
12631273
) # {unit: peak_channel_index}
12641274

1265-
units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit")
1266-
1267-
# Get electrode info
1268-
electrode_config_key = (
1269-
EphysRecording * probe.ElectrodeConfig & key
1270-
).fetch1("KEY")
1271-
1272-
electrode_query = (
1273-
probe.ProbeType.Electrode.proj() * probe.ElectrodeConfig.Electrode
1274-
& electrode_config_key
1275-
)
1276-
electrode_info = electrode_query.fetch(
1277-
"KEY", order_by="electrode", as_dict=True
1278-
)
1279-
12801275
# Get mean waveform for each unit from all channels
12811276
mean_waveforms = we.get_all_templates(
12821277
mode="average"
@@ -1285,26 +1280,26 @@ def make(self, key):
12851280
unit_peak_waveform = []
12861281
unit_electrode_waveforms = []
12871282

1288-
for unit in units:
1283+
for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"):
12891284
unit_peak_waveform.append(
12901285
{
12911286
**unit,
12921287
"peak_electrode_waveform": we.get_template(
12931288
unit_id=unit["unit"], mode="average", force_dense=True
1294-
)[:, unit_id_to_peak_channel_indices[unit["unit"]][0]],
1289+
)[:, unit_id_to_peak_channel_map[unit["unit"]][0]],
12951290
}
12961291
)
12971292

12981293
unit_electrode_waveforms.extend(
12991294
[
13001295
{
13011296
**unit,
1302-
**e,
1297+
**channel_info[c],
13031298
"waveform_mean": mean_waveforms[
1304-
unit["unit"], :, e["electrode"]
1299+
unit["unit"] - 1, :, c
13051300
],
13061301
}
1307-
for e in electrode_info
1302+
for c in channel_info
13081303
]
13091304
)
13101305

@@ -1319,16 +1314,6 @@ def make(self, key):
13191314
EphysRecording * ProbeInsertion & key
13201315
).fetch1("acq_software", "probe")
13211316

1322-
# Get channel and electrode-site mapping
1323-
channel_info = (
1324-
(EphysRecording.Channel & key)
1325-
.proj(..., "-channel_name")
1326-
.fetch(as_dict=True, order_by="channel_idx")
1327-
)
1328-
channel_info: dict[int, dict] = {
1329-
ch.pop("channel_idx"): ch for ch in channel_info
1330-
} # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
1331-
13321317
# Get all units
13331318
units = {
13341319
u["unit"]: u

0 commit comments

Comments
 (0)