Skip to content

Commit c428e47

Browse files
committed
Handle case where pc_features does not exist
1 parent d53f7a9 commit c428e47

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

element_array_ephys/ephys_precluster.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,8 @@ def make(self, key):
667667
# -- Spike-sites and Spike-depths --
668668
spike_sites = np.array([channel2electrodes[s]['electrode']
669669
for s in kilosort_dataset.data['spike_sites']])
670-
spike_depths = kilosort_dataset.data['spike_depths']
670+
spike_depths = kilosort_dataset.data['spike_depths'] \
671+
if 'pc_features' in kilosort_dataset.data else None
671672

672673
# -- Insert unit, label, peak-chn
673674
units = []
@@ -685,7 +686,7 @@ def make(self, key):
685686
'spike_times': unit_spike_times,
686687
'spike_count': spike_count,
687688
'spike_sites': spike_sites[kilosort_dataset.data['spike_clusters'] == unit],
688-
'spike_depths': spike_depths[kilosort_dataset.data['spike_clusters'] == unit]})
689+
'spike_depths': spike_depths[kilosort_dataset.data['spike_clusters'] == unit] if spike_depths else None})
689690

690691
self.insert1(key)
691692
self.Unit.insert([{**key, **u} for u in units])

element_array_ephys/readers/kilosort.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,20 @@ def get_best_channel(self, unit):
127127

128128
def extract_spike_depths(self):
129129
""" Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m """
130-
ycoords = self.data['channel_positions'][:, 1]
131-
pc_features = self.data['pc_features'][:, 0, :] # 1st PC only
132-
pc_features = np.where(pc_features < 0, 0, pc_features)
133-
134-
# ---- compute center of mass of these features (spike depths) ----
135-
136-
# which channels for each spike?
137-
spk_feature_ind = self.data['pc_feature_ind'][self.data['spike_templates'], :]
138-
# ycoords of those channels?
139-
spk_feature_ycoord = ycoords[spk_feature_ind]
140-
# center of mass is sum(coords.*features)/sum(features)
141-
self._data['spike_depths'] = (np.sum(spk_feature_ycoord * pc_features**2, axis=1)
142-
/ np.sum(pc_features**2, axis=1))
130+
if 'pc_features' in self.data:
131+
ycoords = self.data['channel_positions'][:, 1]
132+
pc_features = self.data['pc_features'][:, 0, :] # 1st PC only
133+
pc_features = np.where(pc_features < 0, 0, pc_features)
134+
135+
# ---- compute center of mass of these features (spike depths) ----
136+
137+
# which channels for each spike?
138+
spk_feature_ind = self.data['pc_feature_ind'][self.data['spike_templates'], :]
139+
# ycoords of those channels?
140+
spk_feature_ycoord = ycoords[spk_feature_ind]
141+
# center of mass is sum(coords.*features)/sum(features)
142+
self._data['spike_depths'] = (np.sum(spk_feature_ycoord * pc_features**2, axis=1)
143+
/ np.sum(pc_features**2, axis=1))
143144

144145
# ---- extract spike sites ----
145146
max_site_ind = np.argmax(np.abs(self.data['templates']).max(axis=1), axis=1)

0 commit comments

Comments
 (0)