Skip to content

Commit 015341c

Browse files
committed
feat: test spikeinterface for spikeglx data
1 parent d778b1e commit 015341c

File tree

3 files changed

+72
-82
lines changed

3 files changed

+72
-82
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 67 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -352,24 +352,24 @@ def make(self, key):
352352
raise NotImplementedError(
353353
f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented."
354354
)
355-
else:
356-
probe_type = spikeglx_meta.probe_model
357-
electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
358355

359-
probe_electrodes = {
360-
(shank, shank_col, shank_row): key
361-
for key, shank, shank_col, shank_row in zip(
362-
*electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
363-
)
364-
} # electrode configuration
365-
electrode_group_members = [
366-
probe_electrodes[(shank, shank_col, shank_row)]
367-
for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
368-
] # recording session-specific electrode configuration
369-
370-
econfig_entry, econfig_electrodes = generate_electrode_config_entry(
371-
probe_type, electrode_group_members
356+
probe_type = spikeglx_meta.probe_model
357+
electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
358+
359+
probe_electrodes = {
360+
(shank, shank_col, shank_row): key
361+
for key, shank, shank_col, shank_row in zip(
362+
*electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
372363
)
364+
} # electrode configuration
365+
electrode_group_members = [
366+
probe_electrodes[(shank, shank_col, shank_row)]
367+
for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
368+
] # recording session-specific electrode configuration
369+
370+
econfig_entry, econfig_electrodes = generate_electrode_config_entry(
371+
probe_type, electrode_group_members
372+
)
373373

374374
ephys_recording_entry = {
375375
**key,
@@ -398,18 +398,6 @@ def make(self, key):
398398

399399
# Insert channel information
400400
# Get channel and electrode-site mapping
401-
electrode_query = (
402-
probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
403-
& {"electrode_config_hash": econfig_entry["electrode_config_hash"]}
404-
)
405-
406-
probe_electrodes = {
407-
(shank, shank_col, shank_row): key
408-
for key, shank, shank_col, shank_row in zip(
409-
*electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
410-
)
411-
}
412-
413401
channel2electrode_map = {
414402
recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
415403
for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
@@ -418,7 +406,12 @@ def make(self, key):
418406
}
419407

420408
ephys_channel_entries = [
421-
{**key, "channel_idx": channel_idx, **channel_info}
409+
{
410+
**key,
411+
"electrode_config_hash": econfig_entry["electrode_config_hash"],
412+
"channel_idx": channel_idx,
413+
**channel_info,
414+
}
422415
for channel_idx, channel_info in channel2electrode_map.items()
423416
]
424417
elif acq_software == "Open Ephys":
@@ -438,24 +431,24 @@ def make(self, key):
438431

439432
if probe_data.probe_model not in supported_probe_types:
440433
raise NotImplementedError(
441-
f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented."
434+
f"Processing for neuropixels probe model {probe_data.probe_model} not yet implemented."
442435
)
443-
else:
444-
probe_type = probe_data.probe_model
445-
electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
446436

447-
probe_electrodes = {
448-
key["electrode"]: key for key in electrode_query.fetch("KEY")
449-
} # electrode configuration
437+
probe_type = probe_data.probe_model
438+
electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
450439

451-
electrode_group_members = [
452-
probe_electrodes[channel_idx]
453-
for channel_idx in probe_data.ap_meta["channels_indices"]
454-
] # recording session-specific electrode configuration
440+
probe_electrodes = {
441+
key["electrode"]: key for key in electrode_query.fetch("KEY")
442+
} # electrode configuration
455443

456-
econfig_entry, econfig_electrodes = generate_electrode_config_entry(
457-
probe_type, electrode_group_members
458-
)
444+
electrode_group_members = [
445+
probe_electrodes[channel_idx]
446+
for channel_idx in probe_data.ap_meta["channels_indices"]
447+
] # recording session-specific electrode configuration
448+
449+
econfig_entry, econfig_electrodes = generate_electrode_config_entry(
450+
probe_type, electrode_group_members
451+
)
459452

460453
ephys_recording_entry = {
461454
**key,
@@ -480,29 +473,24 @@ def make(self, key):
480473
for fp in probe_data.recording_info["recording_files"]
481474
]
482475

483-
# Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough
484-
del probe_data, dataset
485-
gc.collect()
486-
487-
probe_dataset = get_openephys_probe_data(key)
488-
electrode_query = (
489-
probe.ProbeType.Electrode
490-
* probe.ElectrodeConfig.Electrode
491-
* EphysRecording
492-
& key
493-
)
494-
probe_electrodes = {
495-
key["electrode"]: key for key in electrode_query.fetch("KEY")
496-
}
497476
channel2electrode_map = {
498477
channel_idx: probe_electrodes[channel_idx]
499-
for channel_idx in probe_dataset.ap_meta["channels_indices"]
478+
for channel_idx in probe_data.ap_meta["channels_indices"]
500479
}
501480

502481
ephys_channel_entries = [
503-
{**key, "channel_idx": channel_idx, **channel_info}
482+
{
483+
**key,
484+
"electrode_config_hash": econfig_entry["electrode_config_hash"],
485+
"channel_idx": channel_idx,
486+
**channel_info,
487+
}
504488
for channel_idx, channel_info in channel2electrode_map.items()
505489
]
490+
491+
# Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough
492+
del probe_data, dataset
493+
gc.collect()
506494
else:
507495
raise NotImplementedError(
508496
f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented."
@@ -1041,10 +1029,7 @@ def make(self, key):
10411029
output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
10421030

10431031
# Get channel and electrode-site mapping
1044-
electrode_query = (
1045-
(EphysRecording.Channel & key)
1046-
.proj(..., "-channel_name")
1047-
)
1032+
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
10481033
channel2electrode_map = electrode_query.fetch(as_dict=True)
10491034
channel2electrode_map: dict[int, dict] = {
10501035
chn.pop("channel_idx"): chn for chn in channel2electrode_map
@@ -1058,7 +1043,9 @@ def make(self, key):
10581043
if si_waveform_dir.exists():
10591044

10601045
# Read from spikeinterface outputs
1061-
we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False)
1046+
we: si.WaveformExtractor = si.load_waveforms(
1047+
si_waveform_dir, with_recording=False
1048+
)
10621049
si_sorting: si.sorters.BaseSorter = si.load_extractor(
10631050
si_sorting_dir / "si_sorting.pkl"
10641051
)
@@ -1139,10 +1126,10 @@ def make(self, key):
11391126
"spike_count": spike_count_dict[unit_id],
11401127
"spike_sites": new_spikes["electrode"][
11411128
new_spikes["unit_index"] == unit_id
1142-
],
1129+
],
11431130
"spike_depths": new_spikes["depth"][
11441131
new_spikes["unit_index"] == unit_id
1145-
],
1132+
],
11461133
}
11471134
)
11481135

@@ -1281,23 +1268,22 @@ def make(self, key):
12811268
sorter_name = clustering_method.replace(".", "_")
12821269

12831270
# Get channel and electrode-site mapping
1284-
electrode_query = (
1285-
(EphysRecording.Channel & key)
1286-
.proj(..., "-channel_name")
1287-
)
1271+
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
12881272
channel2electrode_map = electrode_query.fetch(as_dict=True)
12891273
channel2electrode_map: dict[int, dict] = {
12901274
chn.pop("channel_idx"): chn for chn in channel2electrode_map
12911275
}
12921276

12931277
si_waveform_dir = output_dir / sorter_name / "waveform"
12941278
if si_waveform_dir.exists(): # read from spikeinterface outputs
1295-
we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False)
1296-
unit_id_to_peak_channel_map: dict[
1297-
int, np.ndarray
1298-
] = si.ChannelSparsity.from_best_channels(
1299-
we, 1, peak_sign="neg"
1300-
).unit_id_to_channel_indices # {unit: peak_channel_index}
1279+
we: si.WaveformExtractor = si.load_waveforms(
1280+
si_waveform_dir, with_recording=False
1281+
)
1282+
unit_id_to_peak_channel_map: dict[int, np.ndarray] = (
1283+
si.ChannelSparsity.from_best_channels(
1284+
we, 1, peak_sign="neg"
1285+
).unit_id_to_channel_indices
1286+
) # {unit: peak_channel_index}
13011287

13021288
# reorder channel2electrode_map according to recording channel ids
13031289
channel2electrode_map = {
@@ -1391,6 +1377,7 @@ def yield_unit_waveforms():
13911377
yield unit_peak_waveform, unit_electrode_waveforms
13921378

13931379
# Spike interface mean and peak waveform extraction from we object
1380+
13941381
elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists():
13951382
we_kilosort = si.load_waveforms(waveforms_folder[0].parent)
13961383
unit_templates = we_kilosort.get_all_templates()
@@ -1618,7 +1605,10 @@ def make(self, key):
16181605
sorter_name = clustering_method.replace(".", "_")
16191606

16201607
# find metric_fp
1621-
for metric_fp in [output_dir / "metrics.csv", output_dir / sorter_name / "metrics" / "metrics.csv"]:
1608+
for metric_fp in [
1609+
output_dir / "metrics.csv",
1610+
output_dir / sorter_name / "metrics" / "metrics.csv",
1611+
]:
16221612
if metric_fp.exists():
16231613
break
16241614
else:

element_array_ephys/spike_sorting/si_preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from spikeinterface import preprocessing
33

44

5-
def catGT(recording):
5+
def CatGT(recording):
66
recording = si.preprocessing.phase_shift(recording)
77
recording = si.preprocessing.common_reference(
88
recording, operator="median", reference="global"

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def make(self, key):
127127
stream_names, stream_ids = si.extractors.get_neo_streams(
128128
acq_software, folder_path=data_dir
129129
)
130-
si_recording: si.BaseRecording = si_extractor[acq_software](
130+
si_recording: si.BaseRecording = si_extractor(
131131
folder_path=data_dir, stream_name=stream_names[0]
132132
)
133133

@@ -184,7 +184,7 @@ def make(self, key):
184184
output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
185185
sorter_name = clustering_method.replace(".", "_")
186186
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
187-
si_recording: si.BaseRecording = si.load_extractor(recording_file)
187+
si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir)
188188

189189
# Run sorting
190190
# Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
@@ -241,8 +241,8 @@ def make(self, key):
241241
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
242242
sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
243243

244-
si_recording: si.BaseRecording = si.load_extractor(recording_file)
245-
si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file)
244+
si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir)
245+
si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file, base_folder=output_dir)
246246

247247
# Extract waveforms
248248
we: si.WaveformExtractor = si.extract_waveforms(

0 commit comments

Comments
 (0)