Skip to content

Commit d9c75c8

Browse files
committed
docs: 📝 name change & add docstring
1 parent 8a65635 commit d9c75c8

File tree

3 files changed

+60
-51
lines changed

3 files changed

+60
-51
lines changed

element_array_ephys/ephys_report.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def make(self, key):
102102

103103
from .plotting.unit_level import (
104104
plot_waveform,
105-
plot_correlogram,
105+
plot_auto_correlogram,
106106
plot_depth_waveforms,
107107
)
108108

@@ -119,7 +119,7 @@ def make(self, key):
119119
waveform=peak_electrode_waveform, sampling_rate=sampling_rate
120120
)
121121

122-
correlogram_fig = plot_correlogram(
122+
correlogram_fig = plot_auto_correlogram(
123123
spike_times=spike_times, bin_size=0.001, window_size=1
124124
)
125125

element_array_ephys/plotting/acorr.py renamed to element_array_ephys/plotting/corr.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,12 @@
1-
"""autocorrelation code adapted from ibllib
2-
https://github.com/int-brain-lab/ibllib
1+
"""Code adapted from International Brain Laboratory, T. (2021). ibllib [Computer software]. https://github.com/int-brain-lab/ibllib
32
"""
43

54
import numpy as np
65

76

8-
def _index_of(arr, lookup):
9-
"""Replace scalars in an array by their indices in a lookup table.
7+
def _index_of(arr: np.ndarray, lookup: np.ndarray):
8+
"""Replace scalars in an array by their indices in a lookup table."""
109

11-
Implicitly assume that:
12-
13-
* All elements of arr and lookup are non-negative integers.
14-
* All elements or arr belong to lookup.
15-
16-
This is not checked for performance reasons.
17-
18-
"""
19-
# Equivalent of np.digitize(arr, lookup) - 1, but much faster.
20-
# TODO: assertions to disable in production for performance reasons.
21-
# TODO: np.searchsorted(lookup, arr) is faster on small arrays with large
22-
# values
2310
lookup = np.asarray(lookup, dtype=np.int32)
2411
m = (lookup.max() if len(lookup) else 0) + 1
2512
tmp = np.zeros(m + 1, dtype=int)
@@ -32,7 +19,9 @@ def _index_of(arr, lookup):
3219

3320
def _increment(arr, indices):
3421
"""Increment some indices in a 1D vector of non-negative integers.
35-
Repeated indices are taken into account."""
22+
Repeated indices are taken into account.
23+
"""
24+
3625
bbins = np.bincount(indices)
3726
arr[: len(bbins)] += bbins
3827
return arr
@@ -63,24 +52,22 @@ def _symmetrize_correlograms(correlograms):
6352
return np.dstack((sym, correlograms))
6453

6554

66-
def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):
55+
def xcorr(
56+
spike_times: np.ndarray,
57+
spike_clusters: np.ndarray,
58+
bin_size: float,
59+
window_size: int,
60+
) -> np.ndarray:
6761
"""Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`.
6862
69-
Parameters
70-
----------
71-
72-
:param spike_times: Spike times in seconds.
73-
:type spike_times: array-like
74-
:param spike_clusters: Spike-cluster mapping.
75-
:type spike_clusters: array-like
76-
:param bin_size: Size of the bin, in seconds.
77-
:type bin_size: float
78-
:param window_size: Size of the window, in seconds.
79-
:type window_size: float
80-
81-
Returns an `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
82-
cross-correlograms.
63+
Args:
64+
spike_times (np.ndarray): Spike times in seconds.
65+
spike_clusters (np.ndarray): Spike-cluster mapping.
66+
bin_size (float): Size of the time bin in seconds.
67+
window_size (int): Size of the correlogram window in seconds.
8368
69+
Returns:
70+
np.ndarray: cross-correlogram array
8471
"""
8572
assert np.all(np.diff(spike_times) >= 0), "The spike times must be increasing."
8673
assert spike_times.ndim == 1
@@ -140,21 +127,16 @@ def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):
140127
return _symmetrize_correlograms(correlograms)
141128

142129

143-
def acorr(spike_times, bin_size=None, window_size=None):
144-
"""Compute the auto-correlogram of a neuron.
145-
146-
Parameters
147-
----------
148-
149-
:param spike_times: Spike times in seconds.
150-
:type spike_times: array-like
151-
:param bin_size: Size of the bin, in seconds.
152-
:type bin_size: float
153-
:param window_size: Size of the window, in seconds.
154-
:type window_size: float
130+
def acorr(spike_times: np.ndarray, bin_size: float, window_size: int) -> np.ndarray:
131+
"""Compute the auto-correlogram of a unit.
155132
156-
Returns an `(winsize_samples,)` array with the auto-correlogram.
133+
Args:
134+
spike_times (np.ndarray): Spike times in seconds.
135+
bin_size (float, optional): Size of the time bin in seconds.
136+
window_size (int, optional): Size of the correlogram window in seconds.
157137
138+
Returns:
139+
np.ndarray: auto-correlogram array (winsize_samples,)
158140
"""
159141
xc = xcorr(
160142
spike_times,

element_array_ephys/plotting/unit_level.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33
import numpy as np
44
import pandas as pd
55
import plotly.graph_objs as go
6-
6+
from typing import Any
77
from .. import probe
88

99

1010
def plot_waveform(waveform: np.ndarray, sampling_rate: float) -> go.Figure:
11+
"""Plot unit waveform.
12+
13+
Args:
14+
waveform (np.ndarray): Amplitude of a spike waveform in μV.
15+
sampling_rate (float): Sampling rate in kHz.
1116
17+
Returns:
18+
go.Figure: Plotly figure object.
19+
"""
1220
waveform_df = pd.DataFrame(data={"waveform": waveform})
1321
waveform_df["timestamp"] = waveform_df.index / sampling_rate
1422

@@ -33,11 +41,20 @@ def plot_waveform(waveform: np.ndarray, sampling_rate: float) -> go.Figure:
3341
return fig
3442

3543

36-
def plot_correlogram(
44+
def plot_auto_correlogram(
3745
spike_times: np.ndarray, bin_size: float = 0.001, window_size: int = 1
3846
) -> go.Figure:
47+
"""Plot the auto-correlogram of a unit.
48+
49+
Args:
50+
spike_times (np.ndarray): Spike timestamps in seconds
51+
bin_size (float, optional): Size of the time bin (lag) in seconds. Defaults to 0.001.
52+
window_size (int, optional): Size of the correlogram window in seconds. Defaults to 1.
3953
40-
from .acorr import acorr
54+
Returns:
55+
go.Figure: Plotly figure object.
56+
"""
57+
from .corr import acorr
4158

4259
correlogram = acorr(
4360
spike_times=spike_times, bin_size=bin_size, window_size=window_size
@@ -76,9 +93,19 @@ def plot_correlogram(
7693

7794
def plot_depth_waveforms(
7895
ephys: Module,
79-
unit_key: dict,
96+
unit_key: dict[str, Any],
8097
y_range: float = 60,
8198
) -> go.Figure:
99+
"""Plot waveforms
100+
101+
Args:
102+
ephys (Module): Imported ephys module.
103+
unit_key (dict[str, Any]): Key dictionary from ephys.CuratedClustering.Unit table.
104+
y_range (float, optional): Vertical range to show waveforms relative to the peak waveform in μm. Defaults to 60.
105+
106+
Returns:
107+
go.Figure: Plotly figure object.
108+
"""
82109

83110
sampling_rate = (ephys.EphysRecording & unit_key).fetch1(
84111
"sampling_rate"

0 commit comments

Comments
 (0)