@@ -42,32 +42,32 @@ def __init__(
42
42
self ._ephys = ephys
43
43
self ._key = key
44
44
self ._scale = scale
45
- self ._plots = {}
45
+ self ._plots = {} # Empty default to defer set to dict property below
46
46
self ._fig_width = fig_width
47
47
self ._amplitude_cutoff_max = amplitude_cutoff_maximum
48
48
self ._presence_ratio_min = presence_ratio_minimum
49
49
self ._isi_violations_max = isi_violations_maximum
50
50
self ._dark_mode = dark_mode
51
- self ._units = pd .DataFrame ()
51
+ self ._units = pd .DataFrame () # Empty default
52
52
self ._x_fmt = dict (showgrid = False , zeroline = False , linewidth = 2 , ticks = "outside" )
53
53
self ._y_fmt = dict (showgrid = False , linewidth = 0 , zeroline = True , visible = False )
54
- self ._no_data_text = "No data available"
55
- self ._null_series = pd .Series (np .nan )
54
+ self ._no_data_text = "No data available" # What to show when no data in table
55
+ self ._null_series = pd .Series (np .nan ) # What to substitute when no data
56
56
57
57
@property
58
58
def key (self ) -> dict :
59
59
"""Key in ephys.QualityMetrics table"""
60
60
return self ._key
61
61
62
- @key .setter
62
+ @key .setter # Allows `cls.property = new_item` notation
63
63
def key (self , key : dict ):
64
64
"""Use class_instance.key = your_key to reset key"""
65
65
if key not in self ._ephys .QualityMetrics .fetch ("KEY" ):
66
- # if not already key, check if unquely identifies entry
66
+ # If not already full key, check if unquely identifies entry
67
67
key = (self ._ephys .QualityMetrics & key ).fetch1 ("KEY" )
68
68
self ._key = key
69
69
70
- @key .deleter
70
+ @key .deleter # Allows `del cls.property` to clear key
71
71
def key (self ):
72
72
"""Use del class_instance.key to clear key"""
73
73
logger .info ("Cleared key" )
@@ -103,15 +103,17 @@ def cutoffs(self, add_to_tables: bool = False, **cutoff_kwargs):
103
103
"isi_violations_maximum" , self ._isi_violations_max
104
104
)
105
105
_ = self .units
106
+
106
107
if add_to_tables :
107
108
ephys_report .QualityMetricCutoffs .insert_new_cutoffs (** cutoff_kwargs )
109
+ logger .info ("Added cutoffs to QualityMetricCutoffs table" )
108
110
109
111
@property
110
112
def units (self ) -> pd .DataFrame :
111
113
"""Pandas dataframe of QC metrics"""
112
114
if not self ._key :
113
- logger .info ("No key set" )
114
115
return self ._null_series
116
+
115
117
if self ._units .empty :
116
118
restrictions = ["TRUE" ]
117
119
if self ._amplitude_cutoff_max :
@@ -120,14 +122,15 @@ def units(self) -> pd.DataFrame:
120
122
restrictions .append (f"presence_ratio > { self ._presence_ratio_min } " )
121
123
if self ._isi_violations_max :
122
124
restrictions .append (f"isi_violation < { self ._isi_violations_max } " )
123
- " AND " .join (restrictions )
125
+ " AND " .join (restrictions ) # Build restriction from cutoffs
124
126
return (
125
127
self ._ephys .QualityMetrics
126
128
* self ._ephys .QualityMetrics .Cluster
127
129
* self ._ephys .QualityMetrics .Waveform
128
130
& self ._key
129
131
& restrictions
130
132
).fetch (format = "frame" )
133
+
131
134
return self ._units
132
135
133
136
def _format_fig (
@@ -145,12 +148,13 @@ class init, 1.
145
148
Returns:
146
149
go.Figure: Formatted figure
147
150
"""
148
-
149
151
if not fig :
150
152
fig = go .Figure ()
151
153
if not scale :
152
154
scale = self ._scale
155
+
153
156
width = self ._fig_width * scale
157
+
154
158
return fig .update_layout (
155
159
template = "plotly_dark" if self ._dark_mode else "simple_white" ,
156
160
width = width ,
@@ -160,14 +164,15 @@ class init, 1.
160
164
)
161
165
162
166
def _empty_fig (
163
- self , annotation = "Select a key to visualize QC metrics" , scale = None
167
+ self , text = "Select a key to visualize QC metrics" , scale = None
164
168
) -> go .Figure :
165
169
"""Return figure object for when no key is provided"""
166
170
if not scale :
167
171
scale = self ._scale
172
+
168
173
return (
169
174
self ._format_fig (scale = scale )
170
- .add_annotation (text = annotation , showarrow = False )
175
+ .add_annotation (text = text , showarrow = False )
171
176
.update_layout (xaxis = self ._y_fmt , yaxis = self ._y_fmt )
172
177
)
173
178
@@ -196,10 +201,14 @@ class initialization.
196
201
scale = self ._scale
197
202
if not fig :
198
203
fig = self ._format_fig (scale = scale )
199
- # if data.isnull().all():
200
- histogram , histogram_bins = np .histogram (data , bins = bins , density = True )
201
204
202
- fig .add_trace (
205
+ if not data .isnull ().all ():
206
+ histogram , histogram_bins = np .histogram (data , bins = bins , density = True )
207
+ else :
208
+ # To quiet divide by zero error when no data
209
+ histogram , histogram_bins = np .ndarray (0 ), np .ndarray (0 )
210
+
211
+ return fig .add_trace (
203
212
go .Scatter (
204
213
x = histogram_bins [:- 1 ],
205
214
y = gaussian_filter1d (histogram , 1 ), # TODO: remove smoothing
@@ -209,7 +218,6 @@ class initialization.
209
218
),
210
219
** trace_kwargs ,
211
220
)
212
- return fig
213
221
214
222
def get_single_fig (self , fig_name : str , scale : float = None ) -> go .Figure :
215
223
"""Return a single figure of the plots listed in the plot_list property
@@ -224,7 +232,6 @@ def get_single_fig(self, fig_name: str, scale: float = None) -> go.Figure:
224
232
"""
225
233
if not self ._key :
226
234
return self ._empty_fig ()
227
-
228
235
if not scale :
229
236
scale = self ._scale
230
237
@@ -234,7 +241,7 @@ def get_single_fig(self, fig_name: str, scale: float = None) -> go.Figure:
234
241
vline = fig_dict .get ("vline" , None )
235
242
236
243
if data .isnull ().all ():
237
- return self ._empty_fig (annotation = self ._no_data_text )
244
+ return self ._empty_fig (text = self ._no_data_text )
238
245
239
246
fig = (
240
247
self ._plot_metric (data = data , bins = bins , scale = scale )
@@ -265,11 +272,11 @@ def get_grid(self, n_columns: int = 4, scale: float = 1.0) -> go.Figure:
265
272
266
273
if not self ._key :
267
274
return self ._empty_fig ()
268
-
269
- n_rows = int (np .ceil (len (self .plots ) / n_columns ))
270
275
if not scale :
271
276
scale = self ._scale
272
277
278
+ n_rows = int (np .ceil (len (self .plots ) / n_columns ))
279
+
273
280
fig = self ._format_fig (
274
281
fig = make_subplots (
275
282
rows = n_rows ,
@@ -280,12 +287,12 @@ def get_grid(self, n_columns: int = 4, scale: float = 1.0) -> go.Figure:
280
287
),
281
288
scale = scale ,
282
289
ratio = (n_columns / n_rows ),
283
- ).update_layout (
290
+ ).update_layout ( # Global title
284
291
title = dict (text = "Histograms of Quality Metrics" , xanchor = "center" , x = 0.5 ),
285
292
font = dict (size = 12 * scale ),
286
293
)
287
294
288
- for idx , plot in enumerate (self ._plots .values ()):
295
+ for idx , plot in enumerate (self ._plots .values ()): # Each subplot
289
296
this_row = int (np .floor (idx / n_columns ) + 1 )
290
297
this_col = idx % n_columns + 1
291
298
data = plot .get ("data" , self ._null_series )
@@ -302,7 +309,7 @@ def get_grid(self, n_columns: int = 4, scale: float = 1.0) -> go.Figure:
302
309
),
303
310
]
304
311
)
305
- fig = self ._plot_metric ( # still need to plot so vlines y_value works right
312
+ fig = self ._plot_metric ( # still need to plot empty to cal y_vals min/max
306
313
data = data ,
307
314
bins = plot ["bins" ],
308
315
fig = fig ,
@@ -317,12 +324,11 @@ def get_grid(self, n_columns: int = 4, scale: float = 1.0) -> go.Figure:
317
324
)
318
325
if vline :
319
326
y_vals = fig .to_dict ()["data" ][idx ]["y" ]
320
- # y_vals = plot["data"]
321
- fig .add_shape (
327
+ fig .add_shape ( # Add overlay WRT whole fig
322
328
go .layout .Shape (
323
329
type = "line" ,
324
330
yref = "paper" ,
325
- xref = "x" ,
331
+ xref = "x" , # relative to subplot x
326
332
x0 = vline ,
327
333
y0 = min (y_vals ),
328
334
x1 = vline ,
@@ -332,14 +338,13 @@ def get_grid(self, n_columns: int = 4, scale: float = 1.0) -> go.Figure:
332
338
row = this_row ,
333
339
col = this_col ,
334
340
)
335
- fig .update_xaxes (** self ._x_fmt )
336
- fig .update_yaxes (** self ._y_fmt )
337
- return fig
341
+
342
+ return fig .update_xaxes (** self ._x_fmt ).update_yaxes (** self ._y_fmt )
338
343
339
344
@property
340
345
def plot_list (self ):
341
346
"""List of plots that can be rendered inidividually by name or as grid"""
342
- if not self .plots :
347
+ if not self ._plots :
343
348
_ = self .plots
344
349
return [plot for plot in self ._plots ]
345
350
0 commit comments