Skip to content

Commit c613164

Browse files
author
Joseph Burling
committed
refactor(deps): ➖ remove ibllib deps and add acorr func
1 parent 2e63edc commit c613164

File tree

5 files changed

+172
-4
lines changed

5 files changed

+172
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,4 @@ docker-compose.y*ml
122122

123123
# vscode settings
124124
.vscode
125+
*.code-workspace

element_array_ephys/plotting/acorr.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""autocorrelation code adapted from ibllib
2+
https://github.com/int-brain-lab/ibllib
3+
"""
4+
5+
import numpy as np
6+
7+
8+
def _index_of(arr, lookup):
9+
"""Replace scalars in an array by their indices in a lookup table.
10+
11+
Implicitly assume that:
12+
13+
* All elements of arr and lookup are non-negative integers.
14+
* All elements or arr belong to lookup.
15+
16+
This is not checked for performance reasons.
17+
18+
"""
19+
# Equivalent of np.digitize(arr, lookup) - 1, but much faster.
20+
# TODO: assertions to disable in production for performance reasons.
21+
# TODO: np.searchsorted(lookup, arr) is faster on small arrays with large
22+
# values
23+
lookup = np.asarray(lookup, dtype=np.int32)
24+
m = (lookup.max() if len(lookup) else 0) + 1
25+
tmp = np.zeros(m + 1, dtype=int)
26+
# Ensure that -1 values are kept.
27+
tmp[-1] = -1
28+
if len(lookup):
29+
tmp[lookup] = np.arange(len(lookup))
30+
return tmp[arr]
31+
32+
33+
def _increment(arr, indices):
34+
"""Increment some indices in a 1D vector of non-negative integers.
35+
Repeated indices are taken into account."""
36+
bbins = np.bincount(indices)
37+
arr[: len(bbins)] += bbins
38+
return arr
39+
40+
41+
def _diff_shifted(arr, steps=1):
42+
return arr[steps:] - arr[: len(arr) - steps]
43+
44+
45+
def _create_correlograms_array(n_clusters, winsize_bins):
46+
return np.zeros((n_clusters, n_clusters, winsize_bins // 2 + 1), dtype=np.int32)
47+
48+
49+
def _symmetrize_correlograms(correlograms):
50+
"""Return the symmetrized version of the CCG arrays."""
51+
52+
n_clusters, _, n_bins = correlograms.shape
53+
assert n_clusters == _
54+
55+
# We symmetrize c[i, j, 0].
56+
# This is necessary because the algorithm in correlograms()
57+
# is sensitive to the order of identical spikes.
58+
correlograms[..., 0] = np.maximum(correlograms[..., 0], correlograms[..., 0].T)
59+
60+
sym = correlograms[..., 1:][..., ::-1]
61+
sym = np.transpose(sym, (1, 0, 2))
62+
63+
return np.dstack((sym, correlograms))
64+
65+
66+
def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):
67+
"""Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`.
68+
69+
Parameters
70+
----------
71+
72+
:param spike_times: Spike times in seconds.
73+
:type spike_times: array-like
74+
:param spike_clusters: Spike-cluster mapping.
75+
:type spike_clusters: array-like
76+
:param bin_size: Size of the bin, in seconds.
77+
:type bin_size: float
78+
:param window_size: Size of the window, in seconds.
79+
:type window_size: float
80+
81+
Returns an `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
82+
cross-correlograms.
83+
84+
"""
85+
assert np.all(np.diff(spike_times) >= 0), "The spike times must be increasing."
86+
assert spike_times.ndim == 1
87+
assert spike_times.shape == spike_clusters.shape
88+
89+
# Find `binsize`.
90+
bin_size = np.clip(bin_size, 1e-5, 1e5) # in seconds
91+
92+
# Find `winsize_bins`.
93+
window_size = np.clip(window_size, 1e-5, 1e5) # in seconds
94+
winsize_bins = 2 * int(0.5 * window_size / bin_size) + 1
95+
96+
# Take the cluster order into account.
97+
clusters = np.unique(spike_clusters)
98+
n_clusters = len(clusters)
99+
100+
# Like spike_clusters, but with 0..n_clusters-1 indices.
101+
spike_clusters_i = _index_of(spike_clusters, clusters)
102+
103+
# Shift between the two copies of the spike trains.
104+
shift = 1
105+
106+
# At a given shift, the mask precises which spikes have matching spikes
107+
# within the correlogram time window.
108+
mask = np.ones_like(spike_times, dtype=bool)
109+
110+
correlograms = _create_correlograms_array(n_clusters, winsize_bins)
111+
112+
# The loop continues as long as there is at least one spike with
113+
# a matching spike.
114+
while mask[:-shift].any():
115+
# Interval between spike i and spike i+shift.
116+
spike_diff = _diff_shifted(spike_times, shift)
117+
118+
# Binarize the delays between spike i and spike i+shift.
119+
spike_diff_b = np.round(spike_diff / bin_size).astype(np.int64)
120+
121+
# Spikes with no matching spikes are masked.
122+
mask[:-shift][spike_diff_b > (winsize_bins / 2)] = False
123+
124+
# Cache the masked spike delays.
125+
m = mask[:-shift].copy()
126+
d = spike_diff_b[m]
127+
128+
# Find the indices in the raveled correlograms array that need
129+
# to be incremented, taking into account the spike clusters.
130+
indices = np.ravel_multi_index(
131+
(spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
132+
correlograms.shape,
133+
)
134+
135+
# Increment the matching spikes in the correlograms array.
136+
_increment(correlograms.ravel(), indices)
137+
138+
shift += 1
139+
140+
return _symmetrize_correlograms(correlograms)
141+
142+
143+
def acorr(spike_times, bin_size=None, window_size=None):
144+
"""Compute the auto-correlogram of a neuron.
145+
146+
Parameters
147+
----------
148+
149+
:param spike_times: Spike times in seconds.
150+
:type spike_times: array-like
151+
:param bin_size: Size of the bin, in seconds.
152+
:type bin_size: float
153+
:param window_size: Size of the window, in seconds.
154+
:type window_size: float
155+
156+
Returns an `(winsize_samples,)` array with the auto-correlogram.
157+
158+
"""
159+
xc = xcorr(
160+
spike_times,
161+
np.zeros_like(spike_times, dtype=np.int32),
162+
bin_size=bin_size,
163+
window_size=window_size,
164+
)
165+
return xc[0, 0, :]

element_array_ephys/plotting/probe_level.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import numpy as np
21
import matplotlib
32
import matplotlib.pyplot as plt
3+
import numpy as np
44
import seaborn as sns
55

66

element_array_ephys/plotting/unit_level.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from .. import probe
21
from modulefinder import Module
2+
33
import numpy as np
44
import pandas as pd
55
import plotly.graph_objs as go
66

7+
from .. import probe
8+
79

810
def plot_waveform(waveform: np.ndarray, sampling_rate: float) -> go.Figure:
911

@@ -35,7 +37,7 @@ def plot_correlogram(
3537
spike_times: np.ndarray, bin_size: float = 0.001, window_size: int = 1
3638
) -> go.Figure:
3739

38-
from brainbox.singlecell import acorr
40+
from .acorr import acorr
3941

4042
correlogram = acorr(
4143
spike_times=spike_times, bin_size=bin_size, window_size=window_size

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ pyopenephys @ git+https://github.com/datajoint-company/pyopenephys.git
33
openpyxl
44
pynwb>=2.0.0
55
element-interface @ git+https://github.com/datajoint/element-interface.git
6-
ibllib @ git+https://github.com/int-brain-lab/ibllib.git
76
plotly==5.9.0
7+
seaborn

0 commit comments

Comments
 (0)