Skip to content

Commit 05ccfdb

Browse files
committed
fix: update ingestion from spikeinterface results
1 parent 015341c commit 05ccfdb

File tree

1 file changed

+27
-110
lines changed

1 file changed

+27
-110
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 27 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,18 +1040,16 @@ def make(self, key):
10401040
si_waveform_dir = output_dir / sorter_name / "waveform"
10411041
si_sorting_dir = output_dir / sorter_name / "spike_sorting"
10421042

1043-
if si_waveform_dir.exists():
1044-
1045-
# Read from spikeinterface outputs
1043+
if si_waveform_dir.exists(): # Read from spikeinterface outputs
10461044
we: si.WaveformExtractor = si.load_waveforms(
10471045
si_waveform_dir, with_recording=False
10481046
)
10491047
si_sorting: si.sorters.BaseSorter = si.load_extractor(
1050-
si_sorting_dir / "si_sorting.pkl"
1048+
si_sorting_dir / "si_sorting.pkl", base_folder=output_dir
10511049
)
10521050

10531051
unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
1054-
we, outputs="id"
1052+
we, outputs="index"
10551053
) # {unit: peak_channel_id}
10561054

10571055
spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
@@ -1061,7 +1059,8 @@ def make(self, key):
10611059

10621060
# reorder channel2electrode_map according to recording channel ids
10631061
channel2electrode_map = {
1064-
chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids
1062+
chn_idx: channel2electrode_map[chn_idx]
1063+
for chn_idx in we.channel_ids_to_indices(we.channel_ids)
10651064
}
10661065

10671066
# Get unit id to quality label mapping
@@ -1090,7 +1089,7 @@ def make(self, key):
10901089
# Get channel depth
10911090
channel_depth_ind = np.array(
10921091
[
1093-
channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"]
1092+
we.get_probe().contact_positions[unit_peak_channel[unit_id]][1]
10941093
for unit_id in si_sorting.unit_ids
10951094
]
10961095
)
@@ -1132,7 +1131,6 @@ def make(self, key):
11321131
],
11331132
}
11341133
)
1135-
11361134
else: # read from kilosort outputs
11371135
kilosort_dataset = kilosort.Kilosort(output_dir)
11381136
acq_software, sample_rate = (EphysRecording & key).fetch1(
@@ -1286,46 +1284,38 @@ def make(self, key):
12861284
) # {unit: peak_channel_index}
12871285

12881286
# reorder channel2electrode_map according to recording channel ids
1287+
channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist()
12891288
channel2electrode_map = {
1290-
chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids
1289+
chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
12911290
}
12921291

1293-
# Get mean waveform for each unit from all channels
1294-
mean_waveforms = we.get_all_templates(
1295-
mode="average"
1296-
) # (unit x sample x channel)
1297-
1298-
unit_peak_waveform = []
1299-
unit_electrode_waveforms = []
1300-
1301-
for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"):
1302-
unit_waveforms = we.get_template(
1303-
unit_id=unit["unit"], mode="average", force_dense=True
1304-
) # (sample x channel)
1305-
peak_chn_idx = list(we.channel_ids).index(
1306-
unit_id_to_peak_channel_map[unit["unit"]][0]
1307-
)
1308-
unit_peak_waveform.append(
1309-
{
1292+
def yield_unit_waveforms():
1293+
for unit in (CuratedClustering.Unit & key).fetch(
1294+
"KEY", order_by="unit"
1295+
):
1296+
# Get mean waveform for this unit from all channels - (sample x channel)
1297+
unit_waveforms = we.get_template(
1298+
unit_id=unit["unit"], mode="average", force_dense=True
1299+
)
1300+
peak_chn_idx = channel_indices.index(
1301+
unit_id_to_peak_channel_map[unit["unit"]][0]
1302+
)
1303+
unit_peak_waveform = {
13101304
**unit,
13111305
"peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
13121306
}
1313-
)
1314-
unit_electrode_waveforms.extend(
1315-
[
1307+
1308+
unit_electrode_waveforms = [
13161309
{
13171310
**unit,
1318-
**channel2electrode_map[c],
1319-
"waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx],
1311+
**channel2electrode_map[chn_idx],
1312+
"waveform_mean": unit_waveforms[:, chn_idx],
13201313
}
1321-
for c_idx, c in enumerate(channel2electrode_map)
1314+
for chn_idx in channel_indices
13221315
]
1323-
)
13241316

1325-
self.insert1(key)
1326-
self.PeakWaveform.insert(unit_peak_waveform)
1327-
self.Waveform.insert(unit_electrode_waveforms)
1328-
else:
1317+
yield unit_peak_waveform, unit_electrode_waveforms
1318+
else: # read from kilosort outputs
13291319
kilosort_dataset = kilosort.Kilosort(output_dir)
13301320

13311321
acq_software, probe_serial_number = (
@@ -1340,10 +1330,6 @@ def make(self, key):
13401330
)
13411331
}
13421332

1343-
waveforms_folder = [
1344-
f for f in output_dir.parent.rglob(r"*/waveforms*") if f.is_dir()
1345-
]
1346-
13471333
if (output_dir / "mean_waveforms.npy").exists():
13481334
unit_waveforms = np.load(
13491335
output_dir / "mean_waveforms.npy"
@@ -1376,75 +1362,6 @@ def yield_unit_waveforms():
13761362
}
13771363
yield unit_peak_waveform, unit_electrode_waveforms
13781364

1379-
# Spike interface mean and peak waveform extraction from we object
1380-
1381-
elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists():
1382-
we_kilosort = si.load_waveforms(waveforms_folder[0].parent)
1383-
unit_templates = we_kilosort.get_all_templates()
1384-
unit_waveforms = np.reshape(
1385-
unit_templates,
1386-
(
1387-
unit_templates.shape[1],
1388-
unit_templates.shape[3],
1389-
unit_templates.shape[2],
1390-
),
1391-
)
1392-
1393-
# Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms)
1394-
def yield_unit_waveforms():
1395-
for unit_no, unit_waveform in zip(
1396-
kilosort_dataset.data["cluster_ids"], unit_waveforms
1397-
):
1398-
unit_peak_waveform = {}
1399-
unit_electrode_waveforms = []
1400-
if unit_no in units:
1401-
for channel, channel_waveform in zip(
1402-
kilosort_dataset.data["channel_map"], unit_waveform
1403-
):
1404-
unit_electrode_waveforms.append(
1405-
{
1406-
**units[unit_no],
1407-
**channel2electrode_map[channel],
1408-
"waveform_mean": channel_waveform,
1409-
}
1410-
)
1411-
if (
1412-
channel2electrode_map[channel]["electrode"]
1413-
== units[unit_no]["electrode"]
1414-
):
1415-
unit_peak_waveform = {
1416-
**units[unit_no],
1417-
"peak_electrode_waveform": channel_waveform,
1418-
}
1419-
yield unit_peak_waveform, unit_electrode_waveforms
1420-
1421-
# Approach not using spike interface templates (ie. taking mean of each unit waveform)
1422-
# def yield_unit_waveforms():
1423-
# for unit_id in we_kilosort.unit_ids:
1424-
# unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0)
1425-
# unit_peak_waveform = {}
1426-
# unit_electrode_waveforms = []
1427-
# if unit_id in units:
1428-
# for channel, channel_waveform in zip(
1429-
# kilosort_dataset.data["channel_map"], unit_waveform
1430-
# ):
1431-
# unit_electrode_waveforms.append(
1432-
# {
1433-
# **units[unit_id],
1434-
# **channel2electrodes[channel],
1435-
# "waveform_mean": channel_waveform,
1436-
# }
1437-
# )
1438-
# if (
1439-
# channel2electrodes[channel]["electrode"]
1440-
# == units[unit_id]["electrode"]
1441-
# ):
1442-
# unit_peak_waveform = {
1443-
# **units[unit_id],
1444-
# "peak_electrode_waveform": channel_waveform,
1445-
# }
1446-
# yield unit_peak_waveform, unit_electrode_waveforms
1447-
14481365
else:
14491366
if acq_software == "SpikeGLX":
14501367
spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)

0 commit comments

Comments
 (0)