Skip to content

Commit 1a1b18f

Browse files
committed
feat: replace output_folder with folder when calling run_sorter, use default value for peak_sign
1 parent b459709 commit 1a1b18f

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,10 +1045,13 @@ def make(self, key):
10451045
# Find representative channel for each unit
10461046
unit_peak_channel: dict[int, np.ndarray] = (
10471047
si.ChannelSparsity.from_best_channels(
1048-
sorting_analyzer, 1, peak_sign="both"
1048+
sorting_analyzer,
1049+
1,
10491050
).unit_id_to_channel_indices
10501051
)
1051-
unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
1052+
unit_peak_channel: dict[int, int] = {
1053+
u: chn[0] for u, chn in unit_peak_channel.items()
1054+
}
10521055

10531056
spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
10541057
# {unit: spike_count}
@@ -1084,7 +1087,9 @@ def make(self, key):
10841087
)
10851088
unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
10861089
_, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates
1087-
spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
1090+
spike_times = si_sorting.get_unit_spike_train(
1091+
unit_id, return_times=True
1092+
)
10881093

10891094
assert len(spike_times) == len(spike_sites) == len(spike_depths)
10901095

@@ -1243,13 +1248,13 @@ def make(self, key):
12431248
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
12441249
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
12451250
import spikeinterface as si
1246-
1251+
12471252
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
12481253

12491254
# Find representative channel for each unit
12501255
unit_peak_channel: dict[int, np.ndarray] = (
12511256
si.ChannelSparsity.from_best_channels(
1252-
sorting_analyzer, 1, peak_sign="both"
1257+
sorting_analyzer, 1
12531258
).unit_id_to_channel_indices
12541259
) # {unit: peak_channel_index}
12551260
unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
@@ -1272,7 +1277,9 @@ def yield_unit_waveforms():
12721277
)
12731278
unit_peak_waveform = {
12741279
**unit,
1275-
"peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
1280+
"peak_electrode_waveform": unit_waveforms[
1281+
:, unit_peak_channel[unit["unit"]]
1282+
],
12761283
}
12771284

12781285
unit_electrode_waveforms = [
@@ -1495,7 +1502,7 @@ def make(self, key):
14951502
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
14961503
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
14971504
import spikeinterface as si
1498-
1505+
14991506
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
15001507
qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
15011508
template_metrics = sorting_analyzer.get_extension(

element_array_ephys/spike_sorting/si_spike_sorting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _run_sorter():
205205
si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
206206
sorter_name=sorter_name,
207207
recording=si_recording,
208-
output_folder=sorting_output_dir,
208+
folder=sorting_output_dir,
209209
remove_existing_folder=True,
210210
verbose=True,
211211
docker_image=sorter_name not in si.sorters.installed_sorters(),

0 commit comments

Comments (0)