Skip to content

Commit f70ae4e

Browse files
committed
feat: ✨ replace get_neuropixels_channel2electrode_map with channel_info
1 parent 417219f commit f70ae4e

File tree

1 file changed

+36
-18
lines changed

1 file changed

+36
-18
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,7 @@ def make(self, key):
10961096
}
10971097
)
10981098

1099-
else:
1099+
else: # read from kilosort outputs
11001100
kilosort_dataset = kilosort.Kilosort(output_dir)
11011101
acq_software, sample_rate = (EphysRecording & key).fetch1(
11021102
"acq_software", "sampling_rate"
@@ -1131,14 +1131,19 @@ def make(self, key):
11311131
kilosort_dataset.extract_spike_depths()
11321132

11331133
# Get channel and electrode-site mapping
1134-
channel2electrodes = get_neuropixels_channel2electrode_map(
1135-
key, acq_software
1134+
channel_info = (
1135+
(EphysRecording.Channel & key)
1136+
.proj(..., "-channel_name")
1137+
.fetch(as_dict=True, order_by="channel_idx")
11361138
)
1139+
channel_info: dict[int, dict] = {
1140+
ch.pop("channel_idx"): ch for ch in channel_info
1141+
} # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
11371142

11381143
# -- Spike-sites and Spike-depths --
11391144
spike_sites = np.array(
11401145
[
1141-
channel2electrodes[s]["electrode"]
1146+
channel_info[s]["electrode"]
11421147
for s in kilosort_dataset.data["spike_sites"]
11431148
]
11441149
)
@@ -1157,9 +1162,10 @@ def make(self, key):
11571162

11581163
units.append(
11591164
{
1165+
**key,
11601166
"unit": unit,
11611167
"cluster_quality_label": unit_lbl,
1162-
**channel2electrodes[unit_channel],
1168+
**channel_info[unit_channel],
11631169
"spike_times": unit_spike_times,
11641170
"spike_count": spike_count,
11651171
"spike_sites": spike_sites[
@@ -1228,13 +1234,21 @@ class Waveform(dj.Part):
12281234

12291235
def make(self, key):
12301236
"""Populates waveform tables."""
1231-
output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
1237+
clustering_method, output_dir = (
1238+
ClusteringTask * ClusteringParamSet & key
1239+
).fetch1("clustering_method", "clustering_output_dir")
12321240
output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
1241+
sorter_name = (
1242+
"kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
1243+
)
12331244

1234-
if (output_dir / "waveform").exists(): # read from spikeinterface outputs
1245+
if (
1246+
output_dir / sorter_name / "waveform"
1247+
).exists(): # read from spikeinterface outputs
12351248

1249+
waveform_dir = output_dir / sorter_name / "waveform"
12361250
we: si.WaveformExtractor = si.load_waveforms(
1237-
output_dir / "waveform", with_recording=False
1251+
waveform_dir, with_recording=False
12381252
)
12391253
unit_id_to_peak_channel_indices: dict[int, np.ndarray] = (
12401254
si.ChannelSparsity.from_best_channels(
@@ -1299,11 +1313,15 @@ def make(self, key):
12991313
EphysRecording * ProbeInsertion & key
13001314
).fetch1("acq_software", "probe")
13011315

1302-
# -- Get channel and electrode-site mapping
1303-
recording_key = (EphysRecording & key).fetch1("KEY")
1304-
channel2electrodes = get_neuropixels_channel2electrode_map(
1305-
recording_key, acq_software
1316+
# Get channel and electrode-site mapping
1317+
channel_info = (
1318+
(EphysRecording.Channel & key)
1319+
.proj(..., "-channel_name")
1320+
.fetch(as_dict=True, order_by="channel_idx")
13061321
)
1322+
channel_info: dict[int, dict] = {
1323+
ch.pop("channel_idx"): ch for ch in channel_info
1324+
} # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
13071325

13081326
# Get all units
13091327
units = {
@@ -1335,12 +1353,12 @@ def yield_unit_waveforms():
13351353
unit_electrode_waveforms.append(
13361354
{
13371355
**units[unit_no],
1338-
**channel2electrodes[channel],
1356+
**channel_info[channel],
13391357
"waveform_mean": channel_waveform,
13401358
}
13411359
)
13421360
if (
1343-
channel2electrodes[channel]["electrode"]
1361+
channel_info[channel]["electrode"]
13441362
== units[unit_no]["electrode"]
13451363
):
13461364
unit_peak_waveform = {
@@ -1377,12 +1395,12 @@ def yield_unit_waveforms():
13771395
unit_electrode_waveforms.append(
13781396
{
13791397
**units[unit_no],
1380-
**channel2electrodes[channel],
1398+
**channel_info[channel],
13811399
"waveform_mean": channel_waveform,
13821400
}
13831401
)
13841402
if (
1385-
channel2electrodes[channel]["electrode"]
1403+
channel_info[channel]["electrode"]
13861404
== units[unit_no]["electrode"]
13871405
):
13881406
unit_peak_waveform = {
@@ -1451,13 +1469,13 @@ def yield_unit_waveforms():
14511469
unit_electrode_waveforms.append(
14521470
{
14531471
**unit_dict,
1454-
**channel2electrodes[channel],
1472+
**channel_info[channel],
14551473
"waveform_mean": channel_waveform.mean(axis=0),
14561474
"waveforms": channel_waveform,
14571475
}
14581476
)
14591477
if (
1460-
channel2electrodes[channel]["electrode"]
1478+
channel_info[channel]["electrode"]
14611479
== unit_dict["electrode"]
14621480
):
14631481
unit_peak_waveform = {

0 commit comments

Comments
 (0)