Skip to content

Commit c873acf

Browse files
committed
WIP: QC dashboard
1 parent 678cd95 commit c873acf

File tree

5 files changed

+346
-6
lines changed

5 files changed

+346
-6
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def activate(
3434
ephys_schema_name (str): A string containing the name of the ephys schema.
3535
probe_schema_name (str): A string containing the name of the probe scehma.
3636
create_schema (bool): If True, schema will be created in the database.
37-
create_tables (bool): If True, tables related to the schema will be created in the database.
3837
linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
3938
4039
Dependencies:

element_array_ephys/ephys_report.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import datetime
33
import datajoint as dj
44
import typing as T
5-
import json
5+
6+
from element_interface.utils import insert1_skip_full_duplicates
67

78
schema = dj.schema()
89

@@ -136,6 +137,89 @@ def make(self, key):
136137
)
137138

138139

140+
@schema
141+
class QualityMetricCutoffs(dj.Lookup):
142+
definition = """
143+
cutoffs_id : smallint
144+
---
145+
amplitude_cutoff_maximum=null : float # Defualt null, no cutoff applied
146+
presence_ratio_minimum=null : float # Defualt null, no cutoff applied
147+
isi_violations_maximum=null : float # Defualt null, no cutoff applied
148+
"""
149+
150+
contents = [(0, None, None, None), (1, 0.1, 0.9, 0.5)]
151+
152+
def insert_new_cutoffs(
153+
cls,
154+
cutoffs_id: int,
155+
amplitude_cutoff_maximum: float,
156+
presence_ratio_minimum: float,
157+
isi_violations_maximum: float,
158+
):
159+
insert1_skip_full_duplicates( # depends on element-interface/pull/43
160+
cls,
161+
dict(
162+
cutoffs_id=cutoffs_id,
163+
amplitude_cutoff_maximum=amplitude_cutoff_maximum,
164+
presence_ratio_minimum=presence_ratio_minimum,
165+
isi_violations_maximum=isi_violations_maximum,
166+
),
167+
)
168+
169+
170+
@schema
171+
class QualityMetricSet(dj.Manual):
172+
definition = """
173+
-> ephys.QualityMetrics
174+
-> QualityMetricCutoffs
175+
"""
176+
177+
178+
@schema
179+
class QualityMetricReport(dj.Computed):
180+
definition = """
181+
-> QualityMetricSet
182+
"""
183+
184+
class Cluster(dj.Part):
185+
definition = """
186+
-> master
187+
---
188+
firing_rate_plot : longblob
189+
presence_ratio_plot : longblob
190+
amp_cutoff_plot : longblob
191+
isi_violation_plot : longblob
192+
snr_plot : longblob
193+
iso_dist_plot : longblob
194+
d_prime_plot : longblob
195+
nn_hit_plot : longblob
196+
"""
197+
198+
def make(self, key):
199+
from .plotting.qc import QualityMetricFigs
200+
201+
cutoffs = (QualityMetricCutoffs & key).fetch1()
202+
qc_key = ephys.QualityMetrics & key
203+
qc_figs = QualityMetricFigs(qc_key, **cutoffs)
204+
205+
# ADD SAVEFIG?
206+
207+
self.insert1(key)
208+
self.Cluster.insert1(
209+
dict(
210+
**key,
211+
firing_rate_plot=qc_figs.firing_rate_plot().to_plotly_json(),
212+
presence_ratio_plot=qc_figs.presence_ratio_plot().to_plotly_json(),
213+
amp_cutoff_plot=qc_figs.amp_cutoff_plot().to_plotly_json(),
214+
isi_violation_plot=qc_figs.isi_violation_plot().to_plotly_json(),
215+
snr_plot=qc_figs.snr_plot().to_plotly_json(),
216+
iso_dist_plot=qc_figs.iso_dist_plot().to_plotly_json(),
217+
d_prime_plot=qc_figs.d_prime_plot().to_plotly_json(),
218+
nn_hit_plot=qc_figs.nn_hit_plot().to_plotly_json(),
219+
)
220+
)
221+
222+
139223
def _make_save_dir(root_dir: pathlib.Path = None) -> pathlib.Path:
140224
if root_dir is None:
141225
root_dir = pathlib.Path().absolute()

element_array_ephys/plotting/qc.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import numpy as np
2+
import pandas as pd
3+
import plotly.graph_objs as go
4+
from scipy.ndimage import gaussian_filter1d
5+
from .. import ephys_no_curation as ephys
6+
7+
8+
class QualityMetricFigs(object):
9+
def __init__(self, key=None, **kwargs) -> None:
10+
self._key = key
11+
self._amplitude_cutoff_max = kwargs.get("amplitude_cutoff_maximum", None)
12+
self._presence_ratio_min = kwargs.get("presence_ratio_minimum", None)
13+
self._isi_violations_max = kwargs.get("isi_violations_maximum", None)
14+
self._units = pd.DataFrame()
15+
16+
@property
17+
def units(self):
18+
assert self._key, "Must use key when retrieving units for QC figures"
19+
if self._units.empty:
20+
restrictions = ["TRUE"]
21+
if self._amplitude_cutoff_max:
22+
restrictions.append(f"amplitude_cutoff < {self._amplitude_cutoff_max}")
23+
if self._presence_ratio_min:
24+
restrictions.append(f"presence_ratio > {self._presence_ratio_min}")
25+
if self._isi_violations_max:
26+
restrictions.append(f"isi_violation < {self._isi_violations_max}")
27+
" AND ".join(restrictions)
28+
return (
29+
ephys.QualityMetrics
30+
* ephys.QualityMetrics.Cluster
31+
* ephys.QualityMetrics.Waveform
32+
& self._key
33+
& restrictions
34+
).fetch(format="frame")
35+
return self._units
36+
37+
def _plot_metric(
38+
self,
39+
data,
40+
bins,
41+
x_axis_label,
42+
scale=1,
43+
vline=None,
44+
):
45+
fig = go.Figure()
46+
fig.update_layout(
47+
xaxis_title=x_axis_label,
48+
template="plotly_dark", # "simple_white",
49+
width=350 * scale,
50+
height=350 * scale,
51+
margin=dict(l=20 * scale, r=20 * scale, t=20 * scale, b=20 * scale),
52+
xaxis=dict(showgrid=False, zeroline=False, linewidth=2, ticks="outside"),
53+
yaxis=dict(showgrid=False, linewidth=0, zeroline=True, visible=False),
54+
)
55+
if data.isnull().all():
56+
return fig.add_annotation(text="No data available", showarrow=False)
57+
58+
histogram, histogram_bins = np.histogram(data, bins=bins, density=True)
59+
60+
fig.add_trace(
61+
go.Scatter(
62+
x=histogram_bins[:-1],
63+
y=gaussian_filter1d(histogram, 1), # TODO: remove smoothing
64+
mode="lines",
65+
line=dict(color="rgb(0, 160, 223)", width=2 * scale), # DataJoint Blue
66+
hovertemplate="%{x:.2f}<br>%{y:.2f}<extra></extra>",
67+
)
68+
)
69+
70+
if vline:
71+
fig.add_vline(x=vline, line_width=2 * scale, line_dash="dash")
72+
73+
return fig
74+
75+
def empty_fig(self): # TODO: Remove before submission?
76+
return self._plot_metric(
77+
pd.Series(["nan"]), np.linspace(0, 0, 0), "This fig left blank"
78+
)
79+
80+
def firing_rate_plot(self):
81+
return self._plot_metric(
82+
np.log10(self.units["firing_rate"]),
83+
np.linspace(-3, 2, 100), # If linear, use np.linspace(0, 50, 100)
84+
"log<sub>10</sub> firing rate (Hz)",
85+
)
86+
87+
def presence_ratio_plot(self):
88+
return self._plot_metric(
89+
self.units["presence_ratio"],
90+
np.linspace(0, 1, 100),
91+
"Presence ratio",
92+
vline=0.9,
93+
)
94+
95+
def amp_cutoff_plot(self):
96+
return self._plot_metric(
97+
self.units["amplitude_cutoff"],
98+
np.linspace(0, 0.5, 200),
99+
"Amplitude cutoff",
100+
vline=0.1,
101+
)
102+
103+
def isi_violation_plot(self):
104+
return self._plot_metric(
105+
np.log10(self.units["isi_violation"] + 1e-5), # Offset b/c log(0)
106+
np.linspace(-6, 2.5, 100), # If linear np.linspace(0, 10, 200)
107+
"log<sub>10</sub> ISI violations",
108+
vline=np.log10(0.5),
109+
)
110+
111+
def snr_plot(self):
112+
return self._plot_metric(self.units["snr"], np.linspace(0, 10, 100), "SNR")
113+
114+
def iso_dist_plot(self):
115+
return self._plot_metric(
116+
self.units["isolation_distance"],
117+
np.linspace(0, 170, 50),
118+
"Isolation distance",
119+
)
120+
121+
def d_prime_plot(self):
122+
return self._plot_metric(
123+
self.units["d_prime"], np.linspace(0, 15, 50), "d-prime"
124+
)
125+
126+
def nn_hit_plot(self):
127+
return self._plot_metric(
128+
self.units["nn_hit_rate"],
129+
np.linspace(0, 1, 100),
130+
"Nearest-neighbors hit rate",
131+
)

element_array_ephys/plotting/unit_level.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def plot_waveform(waveform: np.ndarray, sampling_rate: float) -> go.Figure:
2626
x=waveform_df["timestamp"],
2727
y=waveform_df["waveform"],
2828
mode="lines",
29-
line=dict(color="rgb(51, 76.5, 204)", width=2),
29+
line=dict(color="rgb(0, 160, 223)", width=2), # DataJoint Blue
3030
hovertemplate="%{y:.2f} μV<br>" + "%{x:.2f} ms<extra></extra>",
3131
)
3232
)

element_array_ephys/plotting/widget.py

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import pathlib
2+
from skimage import io
13
from modulefinder import Module
24
import ipywidgets as widgets
3-
import pathlib
5+
from ipywidgets import widgets as wg
46
from IPython.display import display
5-
from .. import ephys_report
7+
from plotly.io import from_json
8+
from plotly.subplots import make_subplots
69
import plotly.graph_objs as go
710
import plotly.express as px
8-
from skimage import io
11+
12+
from .. import ephys_report
913

1014

1115
def main(ephys: Module) -> widgets:
@@ -123,3 +127,125 @@ def plot_unit_widget(unit):
123127
unit_widget = widgets.interactive(plot_unit_widget, unit=unit_dropdown_wg)
124128

125129
return widgets.VBox([probe_widget, unit_widget])
130+
131+
132+
def qc_widget(ephys: Module) -> widgets:
133+
from .qc import QualityMetricFigs
134+
135+
title_button = wg.Button(
136+
description="Ephys Quality Control Metrics",
137+
button_style="info",
138+
layout=wg.Layout(
139+
height="auto", width="auto", grid_area="title_button", border="solid"
140+
),
141+
style=wg.ButtonStyle(button_color="#00a0df"),
142+
disabled=True,
143+
)
144+
145+
cluster_dropdown = wg.Dropdown(
146+
options=ephys.QualityMetrics.fetch("KEY"),
147+
description="Clusters:",
148+
description_tooltip='Press "Load" to visualize the clusters identified.',
149+
disabled=False,
150+
layout=wg.Layout(
151+
width="95%",
152+
display="flex",
153+
flex_flow="row",
154+
justify_content="space-between",
155+
grid_area="cluster_dropdown",
156+
),
157+
style={"description_width": "80px"},
158+
)
159+
160+
cutoff_dropdown = wg.Dropdown(
161+
options=ephys_report.QualityMetricCutoffs.fetch("KEY"),
162+
description="Cutoffs:",
163+
description_tooltip='Press "Load" to visualize the clusters identified.',
164+
disabled=False,
165+
layout=wg.Layout(
166+
width="95%",
167+
display="flex",
168+
flex_flow="row",
169+
justify_content="space-between",
170+
grid_area="cutoff_dropdown",
171+
),
172+
style={"description_width": "80px"},
173+
)
174+
175+
fig = make_subplots(
176+
rows=1,
177+
cols=2,
178+
shared_yaxes=False,
179+
horizontal_spacing=0.01,
180+
vertical_spacing=0,
181+
column_titles=["Firing", "Title2"],
182+
)
183+
fwg = go.FigureWidget(fig)
184+
185+
# figure_output = wg.VBox(
186+
# [QualityMetricFigs.empty_fig()],
187+
# layout=wg.Layout(width="95%", grid_area="figure_output"),
188+
# )
189+
# figure_output.add_class("box_style")
190+
191+
load_button = wg.Button(
192+
description="Load",
193+
tooltip="Load figures.",
194+
layout=wg.Layout(width="auto", grid_area="load_button"),
195+
)
196+
197+
def response(change, usedb=False): # TODO: Accept cutoff vals?
198+
global firing_rate_plot
199+
if usedb:
200+
if cluster_dropdown.value not in ephys_report.QualityMetricReport():
201+
ephys_report.QualityMetricReport.populate(cluster_dropdown.value)
202+
203+
firing_rate_plot = from_json(
204+
(ephys_report.QualityMetricReport & cluster_dropdown.value).fetch1(
205+
"firing_rate_plot"
206+
)
207+
)
208+
209+
presence_ratio_plot = from_json(
210+
(ephys_report.QualityMetricReport & cluster_dropdown.value).fetch1(
211+
"presence_ratio_plot"
212+
)
213+
)
214+
215+
else:
216+
qc_figs = QualityMetricFigs(cluster_dropdown)
217+
firing_rate_plot = qc_figs.empty_fig()
218+
presence_ratio_plot = qc_figs.empty_fig()
219+
220+
with fwg.batch_update():
221+
fwg.data[0] = firing_rate_plot
222+
fwg.data[1] = presence_ratio_plot
223+
224+
figure_output = wg.VBox(
225+
[fwg], layout=wg.Layout(width="95%", grid_area="figure_output")
226+
)
227+
figure_output.add_class("box_style")
228+
229+
load_button.on_click(response)
230+
231+
main_container = wg.GridBox(
232+
children=[
233+
title_button,
234+
cluster_dropdown,
235+
cutoff_dropdown,
236+
load_button,
237+
figure_output,
238+
],
239+
layout=wg.Layout(
240+
grid_template_areas="""
241+
"title_button title_button title_button"
242+
"cluster_dropdown . load_button"
243+
"cutoff_dropdown . load_button"
244+
"figure_output figure_output figure_output"
245+
"""
246+
),
247+
grid_template_rows="auto auto auto auto",
248+
grid_template_columns="auto auto auto",
249+
)
250+
251+
return main_container

0 commit comments

Comments
 (0)