Skip to content

Commit 80826d3

Browse files
committed
Update poisson
1 parent 80d6628 commit 80826d3

File tree

4 files changed

+98
-65
lines changed

4 files changed

+98
-65
lines changed

bayesml/bernoulli/_bernoulli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __init__(self,h0_alpha=0.5,h0_beta=0.5):
223223
self.set_h0_params(h0_alpha,h0_beta)
224224

225225
def get_constants(self):
226-
"""Get constants of GenModel.
226+
"""Get constants of LearnModel.
227227
228228
This model does not have any constants.
229229
Therefore, this function returns an emtpy dict ``{}``.

bayesml/exponential/_exponential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def __init__(self,h0_alpha=1.0, h0_beta=1.0):
225225
self.set_h0_params(h0_alpha,h0_beta)
226226

227227
def get_constants(self):
228-
"""Get constants of GenModel.
228+
"""Get constants of LearnModel.
229229
230230
This model does not have any constants.
231231
Therefore, this function returns an emtpy dict ``{}``.

bayesml/normal/_normal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(self,h0_m=0.0,h0_kappa=1.0,h0_alpha=1.0,h0_beta=1.0):
270270
)
271271

272272
def get_constants(self):
273-
"""Get constants of GenModel.
273+
"""Get constants of LearnModel.
274274
275275
This model does not have any constants.
276276
Therefore, this function returns an emtpy dict ``{}``.
@@ -558,6 +558,7 @@ def pred_and_update(self,x,loss="squared"):
558558
If the loss function is \"KL\", the predictive distribution itself will be returned
559559
as numpy.ndarray.
560560
"""
561+
_check.float_(x,'x',DataFormatError)
561562
self.calc_pred_dist()
562563
prediction = self.make_prediction(loss=loss)
563564
self.update_posterior(x)

bayesml/poisson/_poisson.py

Lines changed: 94 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,45 @@ class GenModel(base.Generative):
3232
by default None
3333
"""
3434

35-
def __init__(self,*,lambda_=1.0,h_alpha=1.0,h_beta=1.0,seed=None):
36-
self.lambda_ = _check.pos_float(lambda_,'lambda_',ParameterFormatError)
37-
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
38-
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
35+
def __init__(self,lambda_=1.0,h_alpha=1.0,h_beta=1.0,seed=None):
3936
self.rng = np.random.default_rng(seed)
4037

41-
def set_h_params(self,h_alpha,h_beta):
38+
# params
39+
self.lambda_ = 1.0
40+
41+
# h_params
42+
self.h_alpha = 1.0
43+
self.h_beta = 1.0
44+
45+
self.set_params(lambda_)
46+
self.set_h_params(h_alpha,h_beta)
47+
48+
def get_constants(self):
49+
"""Get constants of GenModel.
50+
51+
This model does not have any constants.
52+
Therefore, this function returns an emtpy dict ``{}``.
53+
54+
Returns
55+
-------
56+
constants : an empty dict
57+
"""
58+
return {}
59+
60+
def set_h_params(self,h_alpha=None,h_beta=None):
4261
"""Set the hyperparameters of the prior distribution.
4362
4463
Parameters
4564
----------
46-
h_alpha : float
47-
a positive real number
48-
h_beta : float
49-
a positibe real number
65+
h_alpha : float, optional
66+
a positive real number, by default None
67+
h_beta : float, optional
68+
a positibe real number, by default None
5069
"""
51-
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
52-
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
70+
if h_alpha is not None:
71+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
72+
if h_beta is not None:
73+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
5374
return self
5475

5576
def get_h_params(self):
@@ -71,15 +92,16 @@ def gen_params(self):
7192
self.lambda_ = self.rng.gamma(shape=self.h_alpha,scale=1.0/self.h_beta)
7293
return self
7394

74-
def set_params(self,lambda_):
95+
def set_params(self,lambda_=None):
7596
"""Set the parameter of the sthocastic data generative model.
7697
7798
Parameters
7899
----------
79-
lambda_ : float
80-
a positive real number
100+
lambda_ : float, optional
101+
a positive real number, by default None
81102
"""
82-
self.lambda_ = _check.pos_float(lambda_,'lambda_',ParameterFormatError)
103+
if lambda_ is not None:
104+
self.lambda_ = _check.pos_float(lambda_,'lambda_',ParameterFormatError)
83105
return self
84106

85107
def get_params(self):
@@ -175,34 +197,55 @@ class LearnModel(base.Posterior,base.PredictiveMixin):
175197
Attributes
176198
----------
177199
hn_alpha : float
178-
a positive real number, by default 1.0
200+
a positive real number
179201
hn_beta : float
180-
a positibe real number, by default 1.0
202+
a positibe real number
181203
p_r : float
182-
a positive real number, by default 1.0
204+
a positive real number
183205
p_theta : float
184-
a real number in :math:`[0, 1]`, by default 0.5
206+
a real number in :math:`[0, 1]`
185207
"""
186208
def __init__(self,h0_alpha=1.0,h0_beta=1.0):
187-
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
188-
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
189-
self.hn_alpha = self.h0_alpha
190-
self.hn_beta = self.h0_beta
191-
self.p_r = self.hn_alpha
192-
self.p_theta = 1.0 / (1.0+self.hn_beta)
193-
194-
def set_h0_params(self,h0_alpha,h0_beta):
209+
# h0_params
210+
self.h0_alpha = 1.0
211+
self.h0_beta = 1.0
212+
213+
# hn_params
214+
self.hn_alpha = 1.0
215+
self.hn_beta = 1.0
216+
217+
#p_params
218+
self.p_r = 1.0
219+
self.p_theta = 0.5
220+
221+
self.set_h0_params(h0_alpha,h0_beta)
222+
223+
def get_constants(self):
224+
"""Get constants of LearnModel.
225+
226+
This model does not have any constants.
227+
Therefore, this function returns an emtpy dict ``{}``.
228+
229+
Returns
230+
-------
231+
constants : an empty dict
232+
"""
233+
return {}
234+
235+
def set_h0_params(self,h0_alpha=None,h0_beta=None):
195236
"""Set initial values of the hyperparameter of the posterior distribution.
196237
197238
Parameters
198239
----------
199-
h0_alpha : float
200-
a positive real number
201-
h0_beta : float
202-
a positibe real number
240+
h0_alpha : float, optional
241+
a positive real number, by default None
242+
h0_beta : float, optional
243+
a positibe real number, by default None
203244
"""
204-
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
205-
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
245+
if h0_alpha is not None:
246+
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
247+
if h0_beta is not None:
248+
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
206249
self.reset_hn_params()
207250
return self
208251

@@ -217,18 +260,20 @@ def get_h0_params(self):
217260
"""
218261
return {"h0_alpha":self.h0_alpha, "h0_beta":self.h0_beta}
219262

220-
def set_hn_params(self,hn_alpha,hn_beta):
263+
def set_hn_params(self,hn_alpha=None,hn_beta=None):
221264
"""Set updated values of the hyperparameter of the posterior distribution.
222265
223266
Parameters
224267
----------
225-
hn_alpha : float
226-
a positive real number
227-
hn_beta : float
228-
a positibe real number
268+
hn_alpha : float, optional
269+
a positive real number, by default None
270+
hn_beta : float, optional
271+
a positibe real number, by default None
229272
"""
230-
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
231-
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
273+
if hn_alpha is not None:
274+
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
275+
if hn_beta is not None:
276+
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
232277
self.calc_pred_dist()
233278
return self
234279

@@ -243,28 +288,6 @@ def get_hn_params(self):
243288
"""
244289
return {"hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
245290

246-
def overwrite_h0_params(self):
247-
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
248-
249-
They are overwritten by `self.hn_alpha` and `self.hn_beta`.
250-
Note that the parameters of the predictive distribution are also calculated from `self.hn_alpha` and `self.hn_beta`.
251-
"""
252-
self.h0_alpha = self.hn_alpha
253-
self.h0_beta = self.hn_beta
254-
self.calc_pred_dist()
255-
return self
256-
257-
def reset_hn_params(self):
258-
"""Reset the hyperparameters of the posterior distribution to their initial values.
259-
260-
They are reset to `self.h0_alpha` and `self.h0_beta`.
261-
Note that the parameters of the predictive distribution are also calculated from `self.h0_alpha` and `self.h0_beta`.
262-
"""
263-
self.hn_alpha = self.h0_alpha
264-
self.hn_beta = self.h0_beta
265-
self.calc_pred_dist()
266-
return self
267-
268291
def update_posterior(self,x):
269292
"""Update the hyperparameters of the posterior distribution using traning data.
270293
@@ -275,6 +298,15 @@ def update_posterior(self,x):
275298
"""
276299
_check.nonneg_ints(x,'x',DataFormatError)
277300
self.hn_alpha += np.sum(x)
301+
try:
302+
self.hn_beta += x.size
303+
except:
304+
self.hn_beta += 1
305+
return self
306+
307+
def _update_posterior(self,x):
308+
"""Update opsterior without input check."""
309+
self.hn_alpha += np.sum(x)
278310
self.hn_beta += x.size
279311
return self
280312

0 commit comments

Comments
 (0)