Skip to content

Commit 3e9b675

Browse files
author
Thinh Nguyen
authored
Merge pull request #103 from JaerongA/no_curation_plot
add ephys_report schema for data visualizations
2 parents db75e4d + d0c6797 commit 3e9b675

File tree

13 files changed

+2963
-1366
lines changed

13 files changed

+2963
-1366
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,7 @@ dj_local_conf_old.json
118118
docker-compose.y*ml
119119

120120
# include
121-
!docs/docker-compose.yaml
121+
!docs/docker-compose.yaml
122+
123+
# vscode settings
124+
.vscode

element_array_ephys/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import os
44

55

6-
dj.config['enable_python_native_blobs'] = True
6+
dj.config["enable_python_native_blobs"] = True
77

88

99
def get_logger(name):
1010
log = logging.getLogger(name)
11-
log.setLevel(os.getenv('LOGLEVEL', 'INFO'))
11+
log.setLevel(os.getenv("LOGLEVEL", "INFO"))
1212
return log
1313

14+
1415
from . import ephys_acute as ephys

element_array_ephys/ephys_acute.py

Lines changed: 626 additions & 365 deletions
Large diffs are not rendered by default.

element_array_ephys/ephys_chronic.py

Lines changed: 576 additions & 333 deletions
Large diffs are not rendered by default.

element_array_ephys/ephys_no_curation.py

Lines changed: 591 additions & 339 deletions
Large diffs are not rendered by default.

element_array_ephys/ephys_precluster.py

Lines changed: 499 additions & 290 deletions
Large diffs are not rendered by default.

element_array_ephys/ephys_report.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import pathlib
2+
import datetime
3+
import datajoint as dj
4+
import typing as T
5+
import json
6+
7+
schema = dj.schema()
8+
9+
ephys = None
10+
11+
12+
def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True):
13+
"""
14+
activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None)
15+
:param schema_name: schema name on the database server to activate the `ephys_report` schema
16+
:param ephys_schema_name: schema name of the activated ephys element for which this ephys_report schema will be downstream from
17+
:param create_schema: when True (default), create schema in the database if it does not yet exist.
18+
:param create_tables: when True (default), create tables in the database if they do not yet exist.
19+
(The "activation" of this ephys_report module should be evoked by one of the ephys modules only)
20+
"""
21+
global ephys
22+
ephys = dj.create_virtual_module("ephys", ephys_schema_name)
23+
schema.activate(
24+
schema_name,
25+
create_schema=create_schema,
26+
create_tables=create_tables,
27+
add_objects=ephys.__dict__,
28+
)
29+
30+
31+
@schema
32+
class ProbeLevelReport(dj.Computed):
33+
definition = """
34+
-> ephys.CuratedClustering
35+
shank : tinyint unsigned
36+
---
37+
drift_map_plot: attach
38+
"""
39+
40+
def make(self, key):
41+
42+
from . import probe
43+
from .plotting.probe_level import plot_driftmap
44+
45+
save_dir = _make_save_dir()
46+
47+
units = ephys.CuratedClustering.Unit & key & "cluster_quality_label='good'"
48+
49+
shanks = set((probe.ProbeType.Electrode & units).fetch("shank"))
50+
51+
for shank_no in shanks:
52+
53+
table = (
54+
units
55+
* ephys.ProbeInsertion.proj()
56+
* probe.ProbeType.Electrode.proj("shank")
57+
& {"shank": shank_no}
58+
)
59+
60+
spike_times, spike_depths = table.fetch(
61+
"spike_times", "spike_depths", order_by="unit"
62+
)
63+
64+
# Get the figure
65+
fig = plot_driftmap(spike_times, spike_depths, colormap="gist_heat_r")
66+
fig_prefix = (
67+
"-".join(
68+
[
69+
v.strftime("%Y%m%d%H%M%S")
70+
if isinstance(v, datetime.datetime)
71+
else str(v)
72+
for v in key.values()
73+
]
74+
)
75+
+ f"-{shank_no}"
76+
)
77+
78+
# Save fig and insert
79+
fig_dict = _save_figs(
80+
figs=(fig,),
81+
fig_names=("drift_map_plot",),
82+
save_dir=save_dir,
83+
fig_prefix=fig_prefix,
84+
extension=".png",
85+
)
86+
87+
self.insert1({**key, **fig_dict, "shank": shank_no})
88+
89+
90+
@schema
91+
class UnitLevelReport(dj.Computed):
92+
definition = """
93+
-> ephys.CuratedClustering.Unit
94+
---
95+
cluster_quality_label : varchar(100)
96+
waveform_plotly : longblob
97+
autocorrelogram_plotly : longblob
98+
depth_waveform_plotly : longblob
99+
"""
100+
101+
def make(self, key):
102+
103+
from .plotting.unit_level import (
104+
plot_waveform,
105+
plot_correlogram,
106+
plot_depth_waveforms,
107+
)
108+
109+
sampling_rate = (ephys.EphysRecording & key).fetch1(
110+
"sampling_rate"
111+
) / 1e3 # in kHz
112+
113+
peak_electrode_waveform, spike_times, cluster_quality_label = (
114+
(ephys.CuratedClustering.Unit & key) * ephys.WaveformSet.PeakWaveform
115+
).fetch1("peak_electrode_waveform", "spike_times", "cluster_quality_label")
116+
117+
# Get the figure
118+
waveform_fig = plot_waveform(
119+
waveform=peak_electrode_waveform, sampling_rate=sampling_rate
120+
)
121+
122+
correlogram_fig = plot_correlogram(
123+
spike_times=spike_times, bin_size=0.001, window_size=1
124+
)
125+
126+
depth_waveform_fig = plot_depth_waveforms(ephys, unit_key=key, y_range=60)
127+
128+
self.insert1(
129+
{
130+
**key,
131+
"cluster_quality_label": cluster_quality_label,
132+
"waveform_plotly": waveform_fig.to_plotly_json(),
133+
"autocorrelogram_plotly": correlogram_fig.to_plotly_json(),
134+
"depth_waveform_plotly": depth_waveform_fig.to_plotly_json(),
135+
}
136+
)
137+
138+
139+
def _make_save_dir(root_dir: pathlib.Path = None) -> pathlib.Path:
140+
if root_dir is None:
141+
root_dir = pathlib.Path().absolute()
142+
save_dir = root_dir / "ephys_figures"
143+
save_dir.mkdir(parents=True, exist_ok=True)
144+
return save_dir
145+
146+
147+
def _save_figs(
148+
figs, fig_names, save_dir, fig_prefix, extension=".png"
149+
) -> T.Dict[str, pathlib.Path]:
150+
fig_dict = {}
151+
for fig, fig_name in zip(figs, fig_names):
152+
fig_filepath = save_dir / (fig_prefix + "_" + fig_name + extension)
153+
fig.tight_layout()
154+
fig.savefig(fig_filepath)
155+
fig_dict[fig_name] = fig_filepath.as_posix()
156+
157+
return fig_dict

element_array_ephys/plotting/__init__.py

Whitespace-only changes.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import numpy as np
2+
import matplotlib
3+
import matplotlib.pyplot as plt
4+
import seaborn as sns
5+
6+
7+
def plot_raster(units, spike_times) -> matplotlib.figure.Figure:
8+
9+
units = np.arange(1, len(units) + 1)
10+
x = np.hstack(spike_times)
11+
y = np.hstack([np.full_like(s, u) for u, s in zip(units, spike_times)])
12+
fig, ax = plt.subplots(1, 1, figsize=(32, 8), dpi=100)
13+
ax.plot(x, y, "|")
14+
ax.set(
15+
xlabel="Time (s)",
16+
ylabel="Unit",
17+
xlim=[0 - 0.5, x[-1] + 0.5],
18+
ylim=(1, len(units)),
19+
)
20+
sns.despine()
21+
fig.tight_layout()
22+
23+
return fig
24+
25+
26+
def plot_driftmap(
27+
spike_times: np.ndarray, spike_depths: np.ndarray, colormap="gist_heat_r"
28+
) -> matplotlib.figure.Figure:
29+
30+
spike_times = np.hstack(spike_times)
31+
spike_depths = np.hstack(spike_depths)
32+
33+
# Time-depth 2D histogram
34+
time_bin_count = 1000
35+
depth_bin_count = 200
36+
37+
spike_bins = np.linspace(0, spike_times.max(), time_bin_count)
38+
depth_bins = np.linspace(0, np.nanmax(spike_depths), depth_bin_count)
39+
40+
spk_count, spk_edges, depth_edges = np.histogram2d(
41+
spike_times, spike_depths, bins=[spike_bins, depth_bins]
42+
)
43+
spk_rates = spk_count / np.mean(np.diff(spike_bins))
44+
spk_edges = spk_edges[:-1]
45+
depth_edges = depth_edges[:-1]
46+
47+
# Canvas setup
48+
fig = plt.figure(figsize=(12, 5), dpi=200)
49+
grid = plt.GridSpec(15, 12)
50+
51+
ax_cbar = plt.subplot(grid[0, 0:10])
52+
ax_driftmap = plt.subplot(grid[2:, 0:10])
53+
ax_spkcount = plt.subplot(grid[2:, 10:])
54+
55+
# Plot main
56+
im = ax_driftmap.imshow(
57+
spk_rates.T,
58+
aspect="auto",
59+
cmap=colormap,
60+
extent=[spike_bins[0], spike_bins[-1], depth_bins[-1], depth_bins[0]],
61+
)
62+
# Cosmetic
63+
ax_driftmap.invert_yaxis()
64+
ax_driftmap.set(
65+
xlabel="Time (s)",
66+
ylabel="Distance from the probe tip ($\mu$m)",
67+
ylim=[depth_edges[0], depth_edges[-1]],
68+
)
69+
70+
# Colorbar for firing rates
71+
cb = fig.colorbar(im, cax=ax_cbar, orientation="horizontal")
72+
cb.outline.set_visible(False)
73+
cb.ax.xaxis.tick_top()
74+
cb.set_label("Firing rate (Hz)")
75+
cb.ax.xaxis.set_label_position("top")
76+
77+
# Plot spike count
78+
ax_spkcount.plot(spk_count.sum(axis=0) / 10e3, depth_edges, "k")
79+
ax_spkcount.set_xlabel("Spike count (x$10^3$)")
80+
ax_spkcount.set_yticks([])
81+
ax_spkcount.set_ylim(depth_edges[0], depth_edges[-1])
82+
sns.despine()
83+
84+
return fig

0 commit comments

Comments
 (0)