@@ -21,16 +21,16 @@ class GenModel(base.Generative):
21
21
Parameters
22
22
----------
23
23
lambda_ : float, optional
24
- a positive real number, 1.0 by default.
24
+ a positive real number, by default 1.0 .
25
25
h_alpha : float, optional
26
- a positive real number, 1.0 by default.
26
+ a positive real number, by default 1.0 .
27
27
h_beta : float, optional
28
- a positive real number, 1.0 by default.
28
+ a positive real number, by default 1.0 .
29
29
seed : {None, int}, optional
30
30
A seed to initialize numpy.random.default_rng(),
31
31
by default None
32
32
"""
33
- def __init__ (self ,* , lambda_ = 1.0 ,h_alpha = 1.0 ,h_beta = 1.0 ,seed = None ):
33
+ def __init__ (self ,lambda_ = 1.0 ,h_alpha = 1.0 ,h_beta = 1.0 ,seed = None ):
34
34
self .rng = np .random .default_rng (seed )
35
35
36
36
# params
@@ -61,9 +61,9 @@ def set_h_params(self,h_alpha=None,h_beta=None):
61
61
Parameters
62
62
----------
63
63
h_alpha : float, optional
64
- a positive real number, None by default.
64
+ a positive real number, by default None .
65
65
h_beta : float, optional
66
- a positive real number, None by default.
66
+ a positive real number, by default None .
67
67
"""
68
68
if h_alpha is not None :
69
69
self .h_alpha = _check .pos_float (h_alpha ,'h_alpha' ,ParameterFormatError )
@@ -96,7 +96,7 @@ def set_params(self,lambda_=None):
96
96
Parameters
97
97
----------
98
98
lambda_ : float, optional
99
- a positive real number, None by default.
99
+ a positive real number, by default None .
100
100
"""
101
101
if lambda_ is not None :
102
102
self .lambda_ = _check .pos_float (lambda_ , 'lambda_' , ParameterFormatError )
@@ -153,9 +153,9 @@ def visualize_model(self,sample_size=100,hist_bins=10):
153
153
Parameters
154
154
----------
155
155
sample_size : int, optional
156
- A positive integer, 100 by default.
156
+ A positive integer, by default 100 .
157
157
hist_bins : float, optional
158
- A positive float, 10 by default.
158
+ A positive float, by default 10 .
159
159
160
160
Examples
161
161
--------
@@ -194,9 +194,9 @@ class LearnModel(base.Posterior,base.PredictiveMixin):
194
194
Parameters
195
195
----------
196
196
h0_alpha : float, optional
197
- a positive real number, 1.0 by default.
197
+ a positive real number, by default 1.0 .
198
198
h0_beta : float, optional
199
- a positibe real number, 1.0 by default.
199
+ a positibe real number, by default 1.0 .
200
200
201
201
Attributes
202
202
----------
@@ -242,9 +242,9 @@ def set_h0_params(self,h0_alpha=None,h0_beta=None):
242
242
Parameters
243
243
----------
244
244
h0_alpha : float, optional
245
- a positive real number, None by default.
245
+ a positive real number, by default None .
246
246
h0_beta : float, optional
247
- a positibe real number, None by default.
247
+ a positibe real number, by default None .
248
248
"""
249
249
if h0_alpha is not None :
250
250
self .h0_alpha = _check .pos_float (h0_alpha , 'h0_alpha' , ParameterFormatError )
@@ -270,9 +270,9 @@ def set_hn_params(self,hn_alpha=None,hn_beta=None):
270
270
Parameters
271
271
----------
272
272
hn_alpha : float, optional
273
- a positive real number, None by default.
273
+ a positive real number, by default None .
274
274
hn_beta : float, optional
275
- a positibe real number, None by default.
275
+ a positibe real number, by default None .
276
276
"""
277
277
if hn_alpha is not None :
278
278
self .hn_alpha = _check .pos_float (hn_alpha , 'hn_alpha' , ParameterFormatError )
@@ -301,8 +301,17 @@ def update_posterior(self,x):
301
301
All the elements must be positive real numbers.
302
302
"""
303
303
_check .pos_floats (x , 'x' , DataFormatError )
304
+ try :
305
+ self .hn_alpha += x .size
306
+ except :
307
+ self .hn_alpha += 1
308
+ self .hn_beta += np .sum (x )
309
+ return self
310
+
311
+ def _update_posterior (self ,x ):
312
+ """Update opsterior without input check."""
304
313
self .hn_alpha += x .size
305
- self .hn_beta += x .sum ()
314
+ self .hn_beta += np .sum (x )
306
315
return self
307
316
308
317
def estimate_params (self ,loss = "squared" ,dict_out = False ):
@@ -311,10 +320,10 @@ def estimate_params(self,loss="squared",dict_out=False):
311
320
Parameters
312
321
----------
313
322
loss : str, optional
314
- Loss function underlying the Bayes risk function, \" squared\" by default .
323
+ Loss function underlying the Bayes risk function, by default \" squared\" .
315
324
This function supports \" squared\" , \" 0-1\" , \" abs\" , and \" KL\" .
316
325
dict_out : bool, optional
317
- If ``True``, output will be a dict, ``False`` by default .
326
+ If ``True``, output will be a dict, by default ``False``.
318
327
319
328
Returns
320
329
-------
@@ -361,7 +370,7 @@ def estimate_interval(self,credibility=0.95):
361
370
Parameters
362
371
----------
363
372
credibility : float, optional
364
- A posterior probability that the interval conitans the paramter, 0.95 by default.
373
+ A posterior probability that the interval conitans the paramter, by default 0.95 .
365
374
366
375
Returns
367
376
-------
@@ -417,7 +426,7 @@ def make_prediction(self,loss="squared"):
417
426
Parameters
418
427
----------
419
428
loss : str, optional
420
- Loss function underlying the Bayes risk function, \" squared\" by default .
429
+ Loss function underlying the Bayes risk function, by default \" squared\" .
421
430
This function supports \" squared\" , \" 0-1\" , \" abs\" , and \" KL\" .
422
431
423
432
Returns
@@ -451,7 +460,7 @@ def pred_and_update(self,x,loss="squared"):
451
460
x : float
452
461
a positive real number
453
462
loss : str, optional
454
- Loss function underlying the Bayes risk function, \" squared\" by default .
463
+ Loss function underlying the Bayes risk function, by default \" squared\" .
455
464
This function supports \" squared\" , \" 0-1\" , \" abs\" , and \" KL\" .
456
465
457
466
Returns
0 commit comments