@@ -27,9 +27,6 @@ def normal_pdf( x :np.ndarray,
2727 = \dfrac{e^{- \frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^{2}}}
2828 {\sigma \sqrt{2\pi}}$$"""
2929
30- np .e
31- # dev = x.device
32-
3330 mean = np .array (mean ) #if not isinstance(mean, torch.Tensor) else mean
3431 std = np .array (std ) #.to(dev) if not isinstance(std, torch.Tensor) else std
3532
@@ -46,22 +43,22 @@ def sample( x :np.ndarray,
4643 # - samples from x
4744 # - original x min (None = no good numbes in x)
4845 # - original x max (None = no good numbes in x)
49-
46+
5047 # Ignore NaN and Inf.
5148 x = x [ np .isfinite (x ) ]
5249 x_min = x_max = None
5350
54- if x .size :
51+ if x .size :
5552 x_min , x_max = x .min (), x .max ()
56-
53+
5754 # An option to ignore zeros
5855 if not plt0 : x = x [x != 0. ]
5956
6057 if x .size > max_s and max_s > 0 :
6158 rng = np .random .default_rng ( get_config ().plt_seed )
6259 x = rng .choice (x .reshape (- 1 ), max_s ) # Sample with replacement for efficiency
6360
64- return (x , x_min , x_max )
61+ return (x , x_min , x_max )
6562
6663# %% ../nbs/02_repr_plt.ipynb 7
6764def find_xlims ( x_min :Union [float , None ],
@@ -71,7 +68,7 @@ def find_xlims( x_min :Union[float, None],
7168 center :str ):
7269
7370 assert center in ["zero" , "mean" , "range" ]
74-
71+
7572 if x_min is None or x_max is None : return (- 1. , 1 ,)
7673 if x_min == x_max and center == "range" : center = "zero"
7774 if x_mean is None or x_std is None and center == "mean" : center = "zero"
@@ -92,9 +89,9 @@ def find_xlims( x_min :Union[float, None],
9289 # Center the plot around zero
9390 abs_max_value = max (abs (x_min ), abs (x_max ), 1. )
9491 xlim_min , xlim_max = - abs_max_value , abs_max_value
95-
9692
97- # Give some extra space around the
93+
94+ # Give some extra space around the data
9895 xlim_min -= abs (xlim_max - xlim_min ) * 0.02
9996 xlim_max += abs (xlim_max - xlim_min ) * 0.02
10097
@@ -111,7 +108,7 @@ def plot_histogram( x :np.ndarray,
111108 # Adjust the number of bins proportional to the fraction of x axis occupied
112109 # by the histogram
113110 xlims = ax .get_xlim ()
114-
111+
115112 bins = min (bins , 100 )
116113 bins = np .ceil ( bins * ( (x .max ()- x .min ())/ (xlims [1 ]- xlims [0 ]) ) ).astype (int )
117114 bins = max (bins , 10 )
@@ -148,7 +145,7 @@ def plot_sigmas(x_min :Union[float, None],
148145
149146 for s in range (- sigmas , sigmas + 1 ):
150147 x_pos = (x_mean + s * x_std )
151- if xlims [0 ] < x_pos < xlims [1 ]:
148+ if xlims [0 ] < x_pos < xlims [1 ] and ( sigmas <= 20 or not s % 10 ) :
152149 greek = ["-σ" , "μ" , "+σ" ][s + 1 ] if - 1 <= s <= 1 else f"{ s :+} σ"
153150 weight = 'bold' if not s else None
154151 ax .axvline (x_pos , 0 , 1 , c = "black" )
@@ -196,7 +193,7 @@ def plot_str(t_str, ax):
196193
197194# %% ../nbs/02_repr_plt.ipynb 13
198195@config (show_mem_above = np .inf )
199- def fig_plot ( x :np .ndarray , #
196+ def fig_plot ( x :np .ndarray , #
200197 center :str = "zero" , # Center plot on `zero`, `mean`, or `range`
201198 max_s :int = 10000 , # Draw up to this many samples. =0 to draw all
202199 plt0 :Any = True , # Take zero values into account
@@ -210,14 +207,14 @@ def fig_plot( x :np.ndarray, #
210207 # display backend-specific info.
211208 if summary is None : summary = str (lovely (x , color = False ))
212209 orig_numel = x .size
213-
210+
214211 x , x_min , x_max = sample (x , max_s , plt0 )
215212 x_mean , x_std = (x .mean (), x .std (ddof = ddof )) if x .size else (None ,None )
216213
217214
218215 t_str = ""
219216 if x .size != orig_numel :
220- t_str += str (x .size )
217+ t_str += str (x .size )
221218 if not plt0 : t_str += " non-zero"
222219 t_str += f" samples (μ={ pretty_str (x_mean )} , σ={ pretty_str (x_std )} ) of "
223220 t_str += summary
@@ -235,15 +232,15 @@ def fig_plot( x :np.ndarray, #
235232 ax .set_xlim (* xlims )
236233 plot_histogram (x , ax )
237234 plot_pdf (x_mean , x_std , ax )
238-
235+
239236 # Add extra space to make sure the labels clear the histogram
240237 ylim = ax .get_ylim ()
241238 ax .set_ylim ( ylim [0 ], ylim [1 ]* 1.3 )
242239
243240 plot_sigmas (x_min , x_max , x_mean , x_std , ax )
244241 plot_minmax (x_min , x_max , ax )
245242 plot_str (t_str , ax )
246-
243+
247244 ax .set_yticks ([])
248245
249246 if show : plt .show ()
@@ -253,9 +250,9 @@ def fig_plot( x :np.ndarray, #
253250
254251# %% ../nbs/02_repr_plt.ipynb 14
255252# This is here for the monkey-patched tensor use case.
256- # Gives the ability to call both .plt and .plt(ax=ax).
253+ # Gives the ability to call both .plt and .plt(ax=ax).
257254
258- class PlotProxy ():
255+ class PlotProxy ():
259256 """Flexible `PIL.Image.Image` wrapper"""
260257
261258 def __init__ (self , x :np .ndarray ):
@@ -276,7 +273,7 @@ def __call__( self,
276273 self .params .update ( { k :v for
277274 k ,v in locals ().items ()
278275 if k != "self" and v is not None } )
279-
276+
280277 _ = self .fig # Trigger figure generation
281278 return self
282279
0 commit comments