
Commit 2a4b166

Merge pull request #190 from ttngu207/dev_spikeinterface_v101
Dev spikeinterface v101
2 parents 403d1df + 51e2ced commit 2a4b166

File tree

5 files changed: +180 additions, -197 deletions


CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -3,6 +3,12 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [0.4.0] - 2024-05-28
+
++ Add - support for SpikeInterface version >= 0.101.0 (updated API)
++ Add - feature for memoization of spike sorting results (prevent duplicated runs)
+
+
 ## [0.3.4] - 2024-03-22
 
 + Add - pytest
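For reference, SpikeInterface 0.101 retires the WaveformExtractor workflow in favor of the SortingAnalyzer object, which is what the updated element reads from. A minimal sketch of that read pattern, assuming an analyzer was previously saved under <output_dir>/<sorter_name>/sorting_analyzer (the concrete path below is made up for illustration):

    import spikeinterface.full as si

    # Hypothetical analyzer folder written by an earlier sorting run.
    analyzer_dir = "/data/clustering_output/kilosort2_5/sorting_analyzer"

    # Load the saved analyzer; previously computed extensions (templates, metrics, ...)
    # come back with it.
    sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_dir)
    sorting = sorting_analyzer.sorting

    print(sorting.unit_ids)                     # unit ids assigned by the sorter
    print(sorting.count_num_spikes_per_unit())  # {unit_id: spike_count}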

element_array_ephys/ephys_no_curation.py

Lines changed: 105 additions & 117 deletions
@@ -1037,98 +1037,70 @@ def make(self, key):
 
         # Get sorter method and create output directory.
         sorter_name = clustering_method.replace(".", "_")
-        si_waveform_dir = output_dir / sorter_name / "waveform"
-        si_sorting_dir = output_dir / sorter_name / "spike_sorting"
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
-        if si_waveform_dir.exists():  # Read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                si_waveform_dir, with_recording=False
-            )
-            si_sorting: si.sorters.BaseSorter = si.load_extractor(
-                si_sorting_dir / "si_sorting.pkl", base_folder=output_dir
-            )
+        if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+            si_sorting = sorting_analyzer.sorting
 
-            unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
-                we, outputs="index"
-            )  # {unit: peak_channel_id}
+            # Find representative channel for each unit
+            unit_peak_channel: dict[int, np.ndarray] = (
+                si.ChannelSparsity.from_best_channels(
+                    sorting_analyzer, 1, peak_sign="neg"
+                ).unit_id_to_channel_indices
+            )
+            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            spikes = si_sorting.to_spike_vector()
-
             # reorder channel2electrode_map according to recording channel ids
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx]
-                for chn_idx in we.channel_ids_to_indices(we.channel_ids)
+                for chn_idx in sorting_analyzer.channel_ids_to_indices(
+                    sorting_analyzer.channel_ids
+                )
             }
 
             # Get unit id to quality label mapping
-            try:
-                cluster_quality_label_map = pd.read_csv(
-                    si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
-                    delimiter="\t",
+            cluster_quality_label_map = {
+                int(unit_id): (
+                    si_sorting.get_unit_property(unit_id, "KSLabel")
+                    if "KSLabel" in si_sorting.get_property_keys()
+                    else "n.a."
                 )
-            except FileNotFoundError:
-                cluster_quality_label_map = {}
-            else:
-                cluster_quality_label_map: dict[
-                    int, str
-                ] = cluster_quality_label_map.set_index("cluster_id")[
-                    "KSLabel"
-                ].to_dict()  # {unit: quality_label}
-
-            # Get electrode where peak unit activity is recorded
-            peak_electrode_ind = np.array(
-                [
-                    channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Get channel depth
-            channel_depth_ind = np.array(
-                [
-                    we.get_probe().contact_positions[unit_peak_channel[unit_id]][1]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Assign electrode and depth for each spike
-            new_spikes = np.empty(
-                spikes.shape,
-                spikes.dtype.descr + [("electrode", "<i8"), ("depth", "<i8")],
-            )
-
-            for field in spikes.dtype.names:
-                new_spikes[field] = spikes[field]
-            del spikes
+                for unit_id in si_sorting.unit_ids
+            }
 
-            new_spikes["electrode"] = peak_electrode_ind[new_spikes["unit_index"]]
-            new_spikes["depth"] = channel_depth_ind[new_spikes["unit_index"]]
+            spike_locations = sorting_analyzer.get_extension("spike_locations")
+            spikes_df = pd.DataFrame(spike_locations.spikes)
 
             units = []
-
-            for unit_id in si_sorting.unit_ids:
+            for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
                 unit_id = int(unit_id)
+                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
+                spike_sites = np.array(
+                    [
+                        channel2electrode_map[chn_idx]["electrode"]
+                        for chn_idx in unit_spikes_df.channel_index
+                    ]
+                )
+                unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
+                _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
+                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+
+                assert len(spike_times) == len(spike_sites) == len(spike_depths)
+
                 units.append(
                     {
                         **key,
                         **channel2electrode_map[unit_peak_channel[unit_id]],
                         "unit": unit_id,
-                        "cluster_quality_label": cluster_quality_label_map.get(
-                            unit_id, "n.a."
-                        ),
-                        "spike_times": si_sorting.get_unit_spike_train(
-                            unit_id, return_times=True
-                        ),
+                        "cluster_quality_label": cluster_quality_label_map[unit_id],
+                        "spike_times": spike_times,
                         "spike_count": spike_count_dict[unit_id],
-                        "spike_sites": new_spikes["electrode"][
-                            new_spikes["unit_index"] == unit_id
-                        ],
-                        "spike_depths": new_spikes["depth"][
-                            new_spikes["unit_index"] == unit_id
-                        ],
+                        "spike_sites": spike_sites,
+                        "spike_depths": spike_depths,
                     }
                 )
         else:  # read from kilosort outputs
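As an aside on the hunk above: per-spike depths now come from the analyzer's "spike_locations" extension instead of being assembled by hand from the spike vector. A rough, illustrative sketch of that lookup (the folder path is hypothetical, it assumes the extension was computed when the analyzer was built, and it assumes the location estimates expose a "y" field for depth along the probe):

    import spikeinterface.full as si

    sorting_analyzer = si.load_sorting_analyzer(
        folder="/data/clustering_output/kilosort2_5/sorting_analyzer"  # hypothetical path
    )
    sorting = sorting_analyzer.sorting

    spike_locations = sorting_analyzer.get_extension("spike_locations")
    locations = spike_locations.get_data()    # one location estimate per spike
    spike_vector = sorting.to_spike_vector()  # fields: sample_index, unit_index, segment_index

    for unit_idx, unit_id in enumerate(sorting.unit_ids):
        unit_mask = spike_vector["unit_index"] == unit_idx
        spike_depths = locations["y"][unit_mask]  # depth of each spike of this unit
        spike_times = sorting.get_unit_spike_train(unit_id, return_times=True)
        assert len(spike_times) == len(spike_depths)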
@@ -1272,33 +1244,38 @@ def make(self, key):
             chn.pop("channel_idx"): chn for chn in channel2electrode_map
         }
 
-        si_waveform_dir = output_dir / sorter_name / "waveform"
-        if si_waveform_dir.exists():  # read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                si_waveform_dir, with_recording=False
-            )
-            unit_id_to_peak_channel_map: dict[
-                int, np.ndarray
-            ] = si.ChannelSparsity.from_best_channels(
-                we, 1, peak_sign="neg"
-            ).unit_id_to_channel_indices  # {unit: peak_channel_index}
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+        if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+
+            # Find representative channel for each unit
+            unit_peak_channel: dict[int, np.ndarray] = (
+                si.ChannelSparsity.from_best_channels(
+                    sorting_analyzer, 1, peak_sign="neg"
+                ).unit_id_to_channel_indices
+            )  # {unit: peak_channel_index}
+            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             # reorder channel2electrode_map according to recording channel ids
-            channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist()
+            channel_indices = sorting_analyzer.channel_ids_to_indices(
+                sorting_analyzer.channel_ids
+            ).tolist()
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
             }
 
+            templates = sorting_analyzer.get_extension("templates")
+
             def yield_unit_waveforms():
                 for unit in (CuratedClustering.Unit & key).fetch(
                     "KEY", order_by="unit"
                 ):
                     # Get mean waveform for this unit from all channels - (sample x channel)
-                    unit_waveforms = we.get_template(
-                        unit_id=unit["unit"], mode="average", force_dense=True
+                    unit_waveforms = templates.get_unit_template(
+                        unit_id=unit["unit"], operator="average"
                     )
                     peak_chn_idx = channel_indices.index(
-                        unit_id_to_peak_channel_map[unit["unit"]][0]
+                        unit_peak_channel[unit["unit"]]
                     )
                     unit_peak_waveform = {
                         **unit,

@@ -1316,7 +1293,7 @@ def yield_unit_waveforms():
 
                     yield unit_peak_waveform, unit_electrode_waveforms
 
-        else:  # read from kilosort outputs
+        else:  # read from kilosort outputs (ecephys pipeline)
             kilosort_dataset = kilosort.Kilosort(output_dir)
 
             acq_software, probe_serial_number = (
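Related sketch for the waveform hunks above: average per-unit templates are now read from the "templates" extension and the peak channel from ChannelSparsity, rather than from WaveformExtractor.get_template. Illustrative only (hypothetical folder; assumes the templates extension was computed for the saved analyzer):

    import spikeinterface.full as si

    sorting_analyzer = si.load_sorting_analyzer(
        folder="/data/clustering_output/kilosort2_5/sorting_analyzer"  # hypothetical path
    )

    templates = sorting_analyzer.get_extension("templates")

    # Most negative-going channel per unit, i.e. the representative/peak channel.
    peak_channels = si.ChannelSparsity.from_best_channels(
        sorting_analyzer, 1, peak_sign="neg"
    ).unit_id_to_channel_indices

    for unit_id in sorting_analyzer.unit_ids:
        mean_waveform = templates.get_unit_template(unit_id=unit_id, operator="average")
        peak_chn_idx = int(peak_channels[unit_id][0])
        peak_waveform = mean_waveform[:, peak_chn_idx]  # (n_samples,) on the peak channel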
@@ -1522,43 +1499,54 @@ def make(self, key):
         output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
         sorter_name = clustering_method.replace(".", "_")
 
-        # find metric_fp
-        for metric_fp in [
-            output_dir / "metrics.csv",
-            output_dir / sorter_name / "metrics" / "metrics.csv",
-        ]:
-            if metric_fp.exists():
-                break
-        else:
-            raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+        if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+            qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
+            template_metrics = sorting_analyzer.get_extension(
+                "template_metrics"
+            ).get_data()
+            metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
 
-        metrics_df = pd.read_csv(metric_fp)
-
-        # Conform the dataframe to match the table definition
-        if "cluster_id" in metrics_df.columns:
-            metrics_df.set_index("cluster_id", inplace=True)
-        else:
             metrics_df.rename(
-                columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
+                columns={
+                    "amplitude_median": "amplitude",
+                    "isi_violations_ratio": "isi_violation",
+                    "isi_violations_count": "number_violation",
+                    "silhouette": "silhouette_score",
+                    "rp_contamination": "contamination_rate",
+                    "drift_ptp": "max_drift",
+                    "drift_mad": "cumulative_drift",
+                    "half_width": "halfwidth",
+                    "peak_trough_ratio": "pt_ratio",
+                    "peak_to_valley": "duration",
+                },
+                inplace=True,
             )
-            metrics_df.set_index("cluster_id", inplace=True)
-        metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
-        metrics_df.columns = metrics_df.columns.str.lower()
-
-        metrics_df.rename(
-            columns={
-                "isi_violations_ratio": "isi_violation",
-                "isi_violations_count": "number_violation",
-                "silhouette": "silhouette_score",
-                "rp_contamination": "contamination_rate",
-                "drift_ptp": "max_drift",
-                "drift_mad": "cumulative_drift",
-                "half_width": "halfwidth",
-                "peak_trough_ratio": "pt_ratio",
-            },
-            inplace=True,
-        )
+        else:  # read from kilosort outputs (ecephys pipeline)
+            # find metric_fp
+            for metric_fp in [
+                output_dir / "metrics.csv",
+            ]:
+                if metric_fp.exists():
+                    break
+            else:
+                raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
+
+            metrics_df = pd.read_csv(metric_fp)
 
+            # Conform the dataframe to match the table definition
+            if "cluster_id" in metrics_df.columns:
+                metrics_df.set_index("cluster_id", inplace=True)
+            else:
+                metrics_df.rename(
+                    columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
+                )
+                metrics_df.set_index("cluster_id", inplace=True)
+
+            metrics_df.columns = metrics_df.columns.str.lower()
+
+        metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
         metrics_list = [
             dict(metrics_df.loc[unit_key["unit"]], **unit_key)
            for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
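One more illustrative sketch for the metrics hunk: quality metrics and template metrics are now pulled from analyzer extensions and merged into a single per-unit table before the column renames above (hypothetical folder; assumes both extensions were computed during sorting):

    import numpy as np
    import pandas as pd
    import spikeinterface.full as si

    sorting_analyzer = si.load_sorting_analyzer(
        folder="/data/clustering_output/kilosort2_5/sorting_analyzer"  # hypothetical path
    )

    qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
    template_metrics = sorting_analyzer.get_extension("template_metrics").get_data()

    # One row per unit: cluster QC metrics alongside waveform-shape metrics.
    metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
    metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
    print(metrics_df.head())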
