Skip to content

Commit 11f4bf7

Browse files
committed
Update bernoulli
1 parent a7520a3 commit 11f4bf7

File tree

2 files changed

+161
-55
lines changed

2 files changed

+161
-55
lines changed

bayesml/bernoulli/_bernoulli.py

Lines changed: 84 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,44 @@ class GenModel(base.Generative):
2828
by default None
2929
"""
3030
def __init__(self,*,theta=0.5,h_alpha=0.5,h_beta=0.5,seed=None):
31-
self.theta = _check.float_in_closed01(theta,'theta',ParameterFormatError)
32-
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
33-
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
3431
self.rng = np.random.default_rng(seed)
3532

36-
def set_h_params(self,h_alpha,h_beta):
33+
# params
34+
self.theta = 0.5
35+
36+
# h_params
37+
self.h_alpha = 0.5
38+
self.h_beta = 0.5
39+
40+
self.set_params(theta)
41+
self.set_h_params(h_alpha,h_beta)
42+
43+
def get_constants(self):
44+
"""Get constants of GenModel.
45+
46+
This model does not have any constants.
47+
Therefore, this function returns an empty dict ``{}``.
48+
49+
Returns
50+
-------
51+
constants : an empty dict
52+
"""
53+
return {}
54+
55+
def set_h_params(self,h_alpha=None,h_beta=None):
3756
"""Set the hyperparameters of the prior distribution.
3857
3958
Parameters
4059
----------
41-
h_alpha : float
42-
a positive real number
60+
h_alpha : float, optional
61+
a positive real number, by default None
4362
h_beta : float, optional
44-
a positibe real number
63+
a positive real number, by default None
4564
"""
46-
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
47-
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
65+
if h_alpha is not None:
66+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
67+
if h_beta is not None:
68+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
4869
return self
4970

5071
def get_h_params(self):
@@ -66,15 +87,16 @@ def gen_params(self):
6687
self.theta = self.rng.beta(self.h_alpha,self.h_beta)
6788
return self
6889

69-
def set_params(self,theta):
90+
def set_params(self,theta=None):
7091
"""Set the parameter of the stochastic data generative model.
7192
7293
Parameters
7394
----------
74-
theta : float
75-
a real number :math:`\theta \in [0, 1]`
95+
theta : float, optional
96+
a real number :math:`\theta \in [0, 1]`, by default None.
7697
"""
77-
self.theta = _check.float_in_closed01(theta,'theta',ParameterFormatError)
98+
if theta is not None:
99+
self.theta = _check.float_in_closed01(theta,'theta',ParameterFormatError)
78100
return self
79101

80102
def get_params(self):
@@ -187,24 +209,45 @@ class LearnModel(base.Posterior,base.PredictiveMixin):
187209
a real number :math:`\theta_\mathrm{p} \in [0, 1]`
188210
"""
189211
def __init__(self,h0_alpha=0.5,h0_beta=0.5):
190-
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
191-
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
192-
self.hn_alpha = self.h0_alpha
193-
self.hn_beta = self.h0_beta
194-
self.p_theta = self.hn_alpha / (self.hn_alpha + self.hn_beta)
212+
# h0_params
213+
self.h0_alpha = 0.5
214+
self.h0_beta = 0.5
215+
216+
# hn_params
217+
self.hn_alpha = 0.5
218+
self.hn_beta = 0.5
219+
220+
# p_params
221+
self.p_theta = 0.5
222+
223+
self.set_h0_params(h0_alpha,h0_beta)
195224

196-
def set_h0_params(self,h0_alpha,h0_beta):
225+
def get_constants(self):
226+
"""Get constants of LearnModel.
227+
228+
This model does not have any constants.
229+
Therefore, this function returns an empty dict ``{}``.
230+
231+
Returns
232+
-------
233+
constants : an empty dict
234+
"""
235+
return {}
236+
237+
def set_h0_params(self,h0_alpha=None,h0_beta=None):
197238
"""Set initial values of the hyperparameter of the posterior distribution.
198239
199240
Parameters
200241
----------
201-
h0_alpha : float
202-
a positive real number
203-
h0_beta : float
204-
a positibe real number
242+
h0_alpha : float, optional
243+
a positive real number, by default None
244+
h0_beta : float, optional
245+
a positive real number, by default None
205246
"""
206-
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
207-
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
247+
if h0_alpha is not None:
248+
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
249+
if h0_beta is not None:
250+
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
208251
self.reset_hn_params()
209252
return self
210253

@@ -224,13 +267,15 @@ def set_hn_params(self,hn_alpha,hn_beta):
224267
225268
Parameters
226269
----------
227-
hn_alpha : float
228-
a positive real number
229-
hn_beta : float
230-
a positibe real number
270+
hn_alpha : float, optional
271+
a positive real number, by default None
272+
hn_beta : float, optional
273+
a positive real number, by default None
231274
"""
232-
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
233-
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
275+
if hn_alpha is not None:
276+
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
277+
if hn_beta is not None:
278+
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
234279
self.calc_pred_dist()
235280
return self
236281

@@ -245,28 +290,6 @@ def get_hn_params(self):
245290
"""
246291
return {"hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
247292

248-
def reset_hn_params(self):
249-
"""Reset the hyperparameters of the posterior distribution to their initial values.
250-
251-
They are reset to `self.h0_alpha` and `self.h0_beta`.
252-
Note that the parameters of the predictive distribution are also calculated from `self.h0_alpha` and `self.h0_beta`.
253-
"""
254-
self.hn_alpha = self.h0_alpha
255-
self.hn_beta = self.h0_beta
256-
self.calc_pred_dist()
257-
return self
258-
259-
def overwrite_h0_params(self):
260-
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
261-
262-
They are overwritten by `self.hn_alpha` and `self.hn_beta`.
263-
Note that the parameters of the predictive distribution are also calculated from `self.hn_alpha` and `self.hn_beta`.
264-
"""
265-
self.h0_alpha = self.hn_alpha
266-
self.h0_beta = self.hn_beta
267-
self.calc_pred_dist()
268-
return self
269-
270293
def update_posterior(self,x):
271294
"""Update the hyperparameters of the posterior distribution using training data.
272295
@@ -276,8 +299,14 @@ def update_posterior(self,x):
276299
All the elements must be 0 or 1.
277300
"""
278301
_check.ints_of_01(x,'x',DataFormatError)
279-
self.hn_alpha += np.sum(x==1)
280-
self.hn_beta += np.sum(x==0)
302+
self.hn_alpha += np.count_nonzero(x==1)
303+
self.hn_beta += np.count_nonzero(x==0)
304+
return self
305+
306+
def _update_posterior(self,x):
307+
"""Update posterior without input check."""
308+
self.hn_alpha += np.count_nonzero(x==1)
309+
self.hn_beta += np.count_nonzero(x==0)
281310
return self
282311

283312
def estimate_params(self,loss="squared",dict_out=False):

bayesml/test.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from bayesml.bernoulli import GenModel
2+
from bayesml.bernoulli import LearnModel
3+
4+
if __name__ == '__main__':
5+
gen_model = GenModel()
6+
h_params = gen_model.get_h_params()
7+
params = gen_model.get_params()
8+
learn_model = LearnModel()
9+
h0_params = learn_model.get_h0_params()
10+
hn_params = learn_model.get_hn_params()
11+
12+
learn_model.set_h0_params(*h_params.values())
13+
print('ok1')
14+
learn_model.set_h0_params(*h0_params.values())
15+
print('ok2')
16+
learn_model.set_h0_params(*hn_params.values())
17+
print('ok3')
18+
try:
19+
learn_model.set_h0_params(**params)
20+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')
21+
except:
22+
print('ok4')
23+
24+
learn_model.set_hn_params(*h_params.values())
25+
print('ok5')
26+
learn_model.set_hn_params(*h0_params.values())
27+
print('ok6')
28+
learn_model.set_hn_params(*hn_params.values())
29+
print('ok7')
30+
try:
31+
learn_model.set_hn_params(**params)
32+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')
33+
except:
34+
print('ok8')
35+
36+
gen_model.set_h_params(*h_params.values())
37+
print('ok9')
38+
gen_model.set_h_params(*h0_params.values())
39+
print('ok10')
40+
gen_model.set_h_params(*hn_params.values())
41+
print('ok11')
42+
try:
43+
gen_model.set_h_params(**params)
44+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')
45+
except:
46+
print('ok12')
47+
48+
gen_model.set_params(**params)
49+
print('ok13')
50+
try:
51+
gen_model.set_params(*h_params.values())
52+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')
53+
except:
54+
print('ok14')
55+
56+
import copy
57+
gen_model = GenModel()
58+
x = gen_model.gen_sample(100)
59+
learn_model = LearnModel()
60+
h0_params = copy.deepcopy(learn_model.get_h0_params())
61+
hn_params = copy.deepcopy(learn_model.get_hn_params())
62+
learn_model.update_posterior(x)
63+
if str(hn_params) != str(learn_model.get_hn_params()):
64+
print('ok15')
65+
else:
66+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')
67+
learn_model.reset_hn_params()
68+
if str(hn_params) == str(learn_model.get_hn_params()):
69+
print('ok16')
70+
else:
71+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')
72+
learn_model.update_posterior(x)
73+
learn_model.overwrite_h0_params()
74+
if str(h0_params) != str(learn_model.get_h0_params()):
75+
print('ok17')
76+
else:
77+
print('!!!!!!!!!!!NG!!!!!!!!!!!!')

0 commit comments

Comments
 (0)