Skip to content

Commit 107fa79

Browse files
Merge pull request #52 from yuta-nakahara/develop-update_basic_models
Develop update basic models
2 parents a7520a3 + 90d3ab8 commit 107fa79

File tree

9 files changed

+1344
-1370
lines changed

9 files changed

+1344
-1370
lines changed

bayesml/autoregressive/_autoregressive.py

Lines changed: 252 additions & 286 deletions
Large diffs are not rendered by default.

bayesml/bernoulli/_bernoulli.py

Lines changed: 97 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
import numpy as np
77
from scipy.stats import beta as ss_beta
8-
# from scipy.stats import betabino as ss_betabinom
8+
# from scipy.stats import betabinom as ss_betabinom
99
import matplotlib.pyplot as plt
1010

1111
from .. import base
@@ -18,33 +18,54 @@ class GenModel(base.Generative):
1818
Parameters
1919
----------
2020
theta : float, optional
21-
a real number in :math:`[0, 1]`, by default 0.5
21+
a real number in :math:`[0, 1]`, by default 0.5.
2222
h_alpha : float, optional
23-
a positive real number, by default 0.5
23+
a positive real number, by default 0.5.
2424
h_beta : float, optional
25-
a positibe real number, by default 0.5
25+
a positive real number, by default 0.5.
2626
seed : {None, int}, optional
2727
A seed to initialize numpy.random.default_rng(),
28-
by default None
28+
by default None.
2929
"""
30-
def __init__(self,*,theta=0.5,h_alpha=0.5,h_beta=0.5,seed=None):
31-
self.theta = _check.float_in_closed01(theta,'theta',ParameterFormatError)
32-
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
33-
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
30+
def __init__(self,theta=0.5,h_alpha=0.5,h_beta=0.5,seed=None):
3431
self.rng = np.random.default_rng(seed)
3532

36-
def set_h_params(self,h_alpha,h_beta):
33+
# params
34+
self.theta = 0.5
35+
36+
# h_params
37+
self.h_alpha = 0.5
38+
self.h_beta = 0.5
39+
40+
self.set_params(theta)
41+
self.set_h_params(h_alpha,h_beta)
42+
43+
def get_constants(self):
44+
"""Get constants of GenModel.
45+
46+
This model does not have any constants.
47+
Therefore, this function returns an emtpy dict ``{}``.
48+
49+
Returns
50+
-------
51+
constants : an empty dict
52+
"""
53+
return {}
54+
55+
def set_h_params(self,h_alpha=None,h_beta=None):
3756
"""Set the hyperparameters of the prior distribution.
3857
3958
Parameters
4059
----------
41-
h_alpha : float
42-
a positive real number
60+
h_alpha : float, optional
61+
a positive real number, by default None.
4362
h_beta : float, optional
44-
a positibe real number
63+
a positive real number, by default None.
4564
"""
46-
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
47-
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
65+
if h_alpha is not None:
66+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
67+
if h_beta is not None:
68+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
4869
return self
4970

5071
def get_h_params(self):
@@ -66,15 +87,16 @@ def gen_params(self):
6687
self.theta = self.rng.beta(self.h_alpha,self.h_beta)
6788
return self
6889

69-
def set_params(self,theta):
90+
def set_params(self,theta=None):
7091
"""Set the parameter of the sthocastic data generative model.
7192
7293
Parameters
7394
----------
74-
theta : float
75-
a real number :math:`\theta \in [0, 1]`
95+
theta : float, optional
96+
a real number :math:`\theta \in [0, 1]`, by default None.
7697
"""
77-
self.theta = _check.float_in_closed01(theta,'theta',ParameterFormatError)
98+
if theta is not None:
99+
self.theta = _check.float_in_closed01(theta,'theta',ParameterFormatError)
78100
return self
79101

80102
def get_params(self):
@@ -128,9 +150,9 @@ def visualize_model(self,sample_size=20,sample_num=5):
128150
Parameters
129151
----------
130152
sample_size : int, optional
131-
A positive integer, by default 20
153+
A positive integer, by default 20.
132154
sample_num : int, optional
133-
A positive integer, by default 5
155+
A positive integer, by default 5.
134156
135157
Examples
136158
--------
@@ -173,38 +195,59 @@ class LearnModel(base.Posterior,base.PredictiveMixin):
173195
Parameters
174196
----------
175197
h0_alpha : float, optional
176-
a positive real number, by default 0.5
198+
a positive real number, by default 0.5.
177199
h0_beta : float, optional
178-
a positibe real number, by default 0.5
200+
a positive real number, by default 0.5.
179201
180202
Attributes
181203
----------
182204
hn_alpha : float
183205
a positive real number
184206
hn_beta : float
185-
a positibe real number
207+
a positive real number
186208
p_theta : float
187209
a real number :math:`\theta_\mathrm{p} \in [0, 1]`
188210
"""
189211
def __init__(self,h0_alpha=0.5,h0_beta=0.5):
190-
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
191-
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
192-
self.hn_alpha = self.h0_alpha
193-
self.hn_beta = self.h0_beta
194-
self.p_theta = self.hn_alpha / (self.hn_alpha + self.hn_beta)
212+
# h0_params
213+
self.h0_alpha = 0.5
214+
self.h0_beta = 0.5
215+
216+
# hn_params
217+
self.hn_alpha = 0.5
218+
self.hn_beta = 0.5
219+
220+
# p_params
221+
self.p_theta = 0.5
222+
223+
self.set_h0_params(h0_alpha,h0_beta)
195224

196-
def set_h0_params(self,h0_alpha,h0_beta):
225+
def get_constants(self):
226+
"""Get constants of LearnModel.
227+
228+
This model does not have any constants.
229+
Therefore, this function returns an emtpy dict ``{}``.
230+
231+
Returns
232+
-------
233+
constants : an empty dict
234+
"""
235+
return {}
236+
237+
def set_h0_params(self,h0_alpha=None,h0_beta=None):
197238
"""Set initial values of the hyperparameter of the posterior distribution.
198239
199240
Parameters
200241
----------
201-
h0_alpha : float
202-
a positive real number
203-
h0_beta : float
204-
a positibe real number
242+
h0_alpha : float, optional
243+
a positive real number, by default None.
244+
h0_beta : float, optionanl
245+
a positive real number, by default None.
205246
"""
206-
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
207-
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
247+
if h0_alpha is not None:
248+
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
249+
if h0_beta is not None:
250+
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
208251
self.reset_hn_params()
209252
return self
210253

@@ -219,18 +262,20 @@ def get_h0_params(self):
219262
"""
220263
return {"h0_alpha":self.h0_alpha, "h0_beta":self.h0_beta}
221264

222-
def set_hn_params(self,hn_alpha,hn_beta):
265+
def set_hn_params(self,hn_alpha=None,hn_beta=None):
223266
"""Set updated values of the hyperparameter of the posterior distribution.
224267
225268
Parameters
226269
----------
227-
hn_alpha : float
228-
a positive real number
229-
hn_beta : float
230-
a positibe real number
270+
hn_alpha : float, optional
271+
a positive real number, by default None.
272+
hn_beta : float, optional
273+
a positive real number, by default None.
231274
"""
232-
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
233-
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
275+
if hn_alpha is not None:
276+
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
277+
if hn_beta is not None:
278+
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
234279
self.calc_pred_dist()
235280
return self
236281

@@ -245,28 +290,6 @@ def get_hn_params(self):
245290
"""
246291
return {"hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
247292

248-
def reset_hn_params(self):
249-
"""Reset the hyperparameters of the posterior distribution to their initial values.
250-
251-
They are reset to `self.h0_alpha` and `self.h0_beta`.
252-
Note that the parameters of the predictive distribution are also calculated from `self.h0_alpha` and `self.h0_beta`.
253-
"""
254-
self.hn_alpha = self.h0_alpha
255-
self.hn_beta = self.h0_beta
256-
self.calc_pred_dist()
257-
return self
258-
259-
def overwrite_h0_params(self):
260-
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
261-
262-
They are overwritten by `self.hn_alpha` and `self.hn_beta`.
263-
Note that the parameters of the predictive distribution are also calculated from `self.hn_alpha` and `self.hn_beta`.
264-
"""
265-
self.h0_alpha = self.hn_alpha
266-
self.h0_beta = self.hn_beta
267-
self.calc_pred_dist()
268-
return self
269-
270293
def update_posterior(self,x):
271294
"""Update the hyperparameters of the posterior distribution using traning data.
272295
@@ -276,8 +299,14 @@ def update_posterior(self,x):
276299
All the elements must be 0 or 1.
277300
"""
278301
_check.ints_of_01(x,'x',DataFormatError)
279-
self.hn_alpha += np.sum(x==1)
280-
self.hn_beta += np.sum(x==0)
302+
self.hn_alpha += np.count_nonzero(x==1)
303+
self.hn_beta += np.count_nonzero(x==0)
304+
return self
305+
306+
def _update_posterior(self,x):
307+
"""Update opsterior without input check."""
308+
self.hn_alpha += np.count_nonzero(x==1)
309+
self.hn_beta += np.count_nonzero(x==0)
281310
return self
282311

283312
def estimate_params(self,loss="squared",dict_out=False):
@@ -347,7 +376,7 @@ def estimate_interval(self,credibility=0.95):
347376
Parameters
348377
----------
349378
credibility : float, optional
350-
A posterior probability that the interval conitans the paramter, by default 0.95
379+
A posterior probability that the interval conitans the paramter, by default 0.95.
351380
352381
Returns
353382
-------

0 commit comments

Comments
 (0)