
Commit 561df39

Author: Thinh Nguyen

Merge pull request #115 from iamamutt/main

Remove ibllib dependency

2 parents: 85a1f0a + dd6e215

File tree

7 files changed (+277 additions, -105 deletions)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -122,3 +122,4 @@ docker-compose.y*ml
 
 # vscode settings
 .vscode
+*.code-workspace

element_array_ephys/ephys_no_curation.py

Lines changed: 91 additions & 96 deletions
Large diffs are not rendered by default.

element_array_ephys/ephys_report.py

Lines changed: 2 additions & 2 deletions
@@ -102,7 +102,7 @@ def make(self, key):
 
         from .plotting.unit_level import (
             plot_waveform,
-            plot_correlogram,
+            plot_auto_correlogram,
             plot_depth_waveforms,
         )
 
@@ -119,7 +119,7 @@ def make(self, key):
             waveform=peak_electrode_waveform, sampling_rate=sampling_rate
         )
 
-        correlogram_fig = plot_correlogram(
+        correlogram_fig = plot_auto_correlogram(
            spike_times=spike_times, bin_size=0.001, window_size=1
        )
 
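For reference, a minimal sketch of the renamed call as ephys_report.make() now uses it. This is illustrative only: it assumes the package is importable as element_array_ephys, the spike train is synthetic (a stand-in for spike times fetched from ephys.CuratedClustering.Unit), and the output file name is arbitrary.

import numpy as np

from element_array_ephys.plotting.unit_level import plot_auto_correlogram

# Synthetic, sorted spike train in seconds (stand-in for pipeline data).
spike_times = np.sort(np.random.default_rng(42).uniform(0, 120, size=6_000))

# Same parameters as the report table: 1 ms bins over a 1 s window.
correlogram_fig = plot_auto_correlogram(
    spike_times=spike_times, bin_size=0.001, window_size=1
)
correlogram_fig.write_html("auto_correlogram.html")  # or .show() in a notebook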

element_array_ephys/plotting/corr.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@ (new file)

"""Code adapted from International Brain Laboratory, T. (2021). ibllib [Computer software]. https://github.com/int-brain-lab/ibllib
"""

import numpy as np


def _index_of(arr: np.ndarray, lookup: np.ndarray):
    """Replace scalars in an array by their indices in a lookup table."""

    lookup = np.asarray(lookup, dtype=np.int32)
    m = (lookup.max() if len(lookup) else 0) + 1
    tmp = np.zeros(m + 1, dtype=int)
    # Ensure that -1 values are kept.
    tmp[-1] = -1
    if len(lookup):
        tmp[lookup] = np.arange(len(lookup))
    return tmp[arr]


def _increment(arr, indices):
    """Increment some indices in a 1D vector of non-negative integers.
    Repeated indices are taken into account.
    """

    bbins = np.bincount(indices)
    arr[: len(bbins)] += bbins
    return arr


def _diff_shifted(arr, steps=1):
    return arr[steps:] - arr[: len(arr) - steps]


def _create_correlograms_array(n_clusters, winsize_bins):
    return np.zeros((n_clusters, n_clusters, winsize_bins // 2 + 1), dtype=np.int32)


def _symmetrize_correlograms(correlograms):
    """Return the symmetrized version of the CCG arrays."""

    n_clusters, _, n_bins = correlograms.shape
    assert n_clusters == _

    # We symmetrize c[i, j, 0].
    # This is necessary because the algorithm in correlograms()
    # is sensitive to the order of identical spikes.
    correlograms[..., 0] = np.maximum(correlograms[..., 0], correlograms[..., 0].T)

    sym = correlograms[..., 1:][..., ::-1]
    sym = np.transpose(sym, (1, 0, 2))

    return np.dstack((sym, correlograms))


def xcorr(
    spike_times: np.ndarray,
    spike_clusters: np.ndarray,
    bin_size: float,
    window_size: int,
) -> np.ndarray:
    """Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`.

    Args:
        spike_times (np.ndarray): Spike times in seconds.
        spike_clusters (np.ndarray): Spike-cluster mapping.
        bin_size (float): Size of the time bin in seconds.
        window_size (int): Size of the correlogram window in seconds.

    Returns:
        np.ndarray: cross-correlogram array
    """
    assert np.all(np.diff(spike_times) >= 0), "The spike times must be increasing."
    assert spike_times.ndim == 1
    assert spike_times.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(0.5 * window_size / bin_size) + 1

    # Take the cluster order into account.
    clusters = np.unique(spike_clusters)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask precises which spikes have matching spikes
    # within the correlogram time window.
    mask = np.ones_like(spike_times, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Interval between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_times, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = np.round(spike_diff / bin_size).astype(np.int64)

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins / 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
            correlograms.shape,
        )

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    return _symmetrize_correlograms(correlograms)


def acorr(spike_times: np.ndarray, bin_size: float, window_size: int) -> np.ndarray:
    """Compute the auto-correlogram of a unit.

    Args:
        spike_times (np.ndarray): Spike times in seconds.
        bin_size (float, optional): Size of the time bin in seconds.
        window_size (int, optional): Size of the correlogram window in seconds.

    Returns:
        np.ndarray: auto-correlogram array (winsize_samples,)
    """
    xc = xcorr(
        spike_times,
        np.zeros_like(spike_times, dtype=np.int32),
        bin_size=bin_size,
        window_size=window_size,
    )
    return xc[0, 0, :]
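A minimal usage sketch of the vendored helpers above, run on synthetic data. It assumes only numpy and that the module is importable as element_array_ephys.plotting.corr; the spike times and cluster labels are made up for illustration.

import numpy as np

from element_array_ephys.plotting.corr import acorr, xcorr

rng = np.random.default_rng(0)

# Two synthetic units: sorted spike times in seconds plus a cluster label per spike.
spike_times = np.sort(rng.uniform(0, 100, size=4_000))
spike_clusters = rng.integers(0, 2, size=spike_times.size).astype(np.int32)

# Pairwise cross-correlograms: shape (n_clusters, n_clusters, n_bins), where
# n_bins = 2 * int(0.5 * window_size / bin_size) + 1 (here 1001 bins of 1 ms).
ccg = xcorr(spike_times, spike_clusters, bin_size=0.001, window_size=1)
print(ccg.shape)  # (2, 2, 1001)

# Auto-correlogram of a single unit over the same lag bins.
acg = acorr(spike_times[spike_clusters == 0], bin_size=0.001, window_size=1)
print(acg.shape)  # (1001,)

Because this module depends only on numpy, dropping ibllib does not change the call signatures; plot_auto_correlogram now imports acorr from here instead of from brainbox.singlecell.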

element_array_ephys/plotting/probe_level.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
-import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt
+import numpy as np
 import seaborn as sns
 
 

element_array_ephys/plotting/unit_level.py

Lines changed: 33 additions & 4 deletions
@@ -1,12 +1,22 @@
-from .. import probe
 from modulefinder import Module
+
 import numpy as np
 import pandas as pd
 import plotly.graph_objs as go
+from typing import Any
+from .. import probe
 
 
 def plot_waveform(waveform: np.ndarray, sampling_rate: float) -> go.Figure:
+    """Plot unit waveform.
+
+    Args:
+        waveform (np.ndarray): Amplitude of a spike waveform in μV.
+        sampling_rate (float): Sampling rate in kHz.
 
+    Returns:
+        go.Figure: Plotly figure object.
+    """
     waveform_df = pd.DataFrame(data={"waveform": waveform})
     waveform_df["timestamp"] = waveform_df.index / sampling_rate
 
@@ -31,11 +41,20 @@ def plot_waveform(waveform: np.ndarray, sampling_rate: float) -> go.Figure:
     return fig
 
 
-def plot_correlogram(
+def plot_auto_correlogram(
     spike_times: np.ndarray, bin_size: float = 0.001, window_size: int = 1
 ) -> go.Figure:
+    """Plot the auto-correlogram of a unit.
 
-    from brainbox.singlecell import acorr
+    Args:
+        spike_times (np.ndarray): Spike timestamps in seconds
+        bin_size (float, optional): Size of the time bin (lag) in seconds. Defaults to 0.001.
+        window_size (int, optional): Size of the correlogram window in seconds. Defaults to 1.
+
+    Returns:
+        go.Figure: Plotly figure object.
+    """
+    from .corr import acorr
 
     correlogram = acorr(
         spike_times=spike_times, bin_size=bin_size, window_size=window_size
@@ -74,9 +93,19 @@ def plot_correlogram(
 
 def plot_depth_waveforms(
     ephys: Module,
-    unit_key: dict,
+    unit_key: dict[str, Any],
     y_range: float = 60,
 ) -> go.Figure:
+    """Plot waveforms
+
+    Args:
+        ephys (Module): Imported ephys module.
+        unit_key (dict[str, Any]): Key dictionary from ephys.CuratedClustering.Unit table.
+        y_range (float, optional): Vertical range to show waveforms relative to the peak waveform in μm. Defaults to 60.
+
+    Returns:
+        go.Figure: Plotly figure object.
+    """
 
     sampling_rate = (ephys.EphysRecording & unit_key).fetch1(
         "sampling_rate"

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -3,5 +3,5 @@ pyopenephys @ git+https://github.com/datajoint-company/pyopenephys.git
 openpyxl
 pynwb>=2.0.0
 element-interface @ git+https://github.com/datajoint/element-interface.git
-ibllib @ git+https://github.com/int-brain-lab/ibllib.git
-plotly==5.9.0
+plotly
+seaborn
