Skip to content

Commit b77e0a8

Browse files
committed
revise h_params of exponential
1 parent 19b4fb3 commit b77e0a8

File tree

2 files changed

+23
-74
lines changed

2 files changed

+23
-74
lines changed

bayesml/exponential/_exponential.py

Lines changed: 21 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -36,37 +36,19 @@ def __init__(self,*,lambda_=1.0,h_alpha=1.0,h_beta=1.0,seed=None):
3636
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
3737
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
3838
self.rng = np.random.default_rng(seed)
39-
self._H_PARAM_KEYS = {'h_alpha','h_beta'}
40-
self._H0_PARAM_KEYS = {'h0_alpha','h0_beta'}
41-
self._HN_PARAM_KEYS = {'hn_alpha','hn_beta'}
4239

43-
def set_h_params(self,**kwargs):
40+
def set_h_params(self,h_alpha,h_beta):
4441
"""Set the hyperparameters of the prior distribution.
4542
4643
Parameters
4744
----------
48-
**kwargs
49-
a python dictionary {'h_alpha':float, 'h_beta':float} or
50-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
51-
They are obtained by ``get_h_params()`` of GenModel,
52-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
45+
h_alpha : float
46+
a positive real number
47+
h_beta : float
48+
a positive real number
5349
"""
54-
if kwargs.keys() == self._H_PARAM_KEYS:
55-
self.h_alpha = _check.pos_float(kwargs['h_alpha'],'h_alpha',ParameterFormatError)
56-
self.h_beta = _check.pos_float(kwargs['h_beta'],'h_beta',ParameterFormatError)
57-
elif kwargs.keys() == self._H0_PARAM_KEYS:
58-
self.h_alpha = _check.pos_float(kwargs['h0_alpha'],'h_alpha',ParameterFormatError)
59-
self.h_beta = _check.pos_float(kwargs['h0_beta'],'h_beta',ParameterFormatError)
60-
elif kwargs.keys() == self._HN_PARAM_KEYS:
61-
self.h_alpha = _check.pos_float(kwargs['hn_alpha'],'h_alpha',ParameterFormatError)
62-
self.h_beta = _check.pos_float(kwargs['hn_beta'],'h_beta',ParameterFormatError)
63-
else:
64-
raise(ParameterFormatError(
65-
"The input of this function must be a python dictionary with keys:"
66-
+str(self._H_PARAM_KEYS)+" or "
67-
+str(self._H0_PARAM_KEYS)+" or "
68-
+str(self._HN_PARAM_KEYS)+".")
69-
)
50+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
51+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
7052

7153
def get_h_params(self):
7254
"""Get the hyperparameters of the prior distribution.
@@ -212,37 +194,19 @@ def __init__(self,h0_alpha=2.0, h0_beta=1.0):
212194
self.hn_beta = self.h0_beta
213195
self.p_kappa = self.hn_alpha
214196
self.p_lambda = self.hn_beta
215-
self._H_PARAM_KEYS = {'h_alpha', 'h_beta'}
216-
self._H0_PARAM_KEYS = {'h0_alpha', 'h0_beta'}
217-
self._HN_PARAM_KEYS = {'hn_alpha', 'hn_beta'}
218197

219-
def set_h0_params(self,**kwargs):
198+
def set_h0_params(self,h0_alpha, h0_beta):
220199
"""Set initial values of the hyperparameter of the posterior distribution.
221200
222201
Parameters
223202
----------
224-
**kwargs
225-
a python dictionary {'h_alpha':float, 'h_beta':float} or
226-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
227-
They are obtained by ``get_h_params()`` of GenModel,
228-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
203+
h0_alpha : float
204+
a positive real number
205+
h0_beta : float
206+
a positibe real number
229207
"""
230-
if kwargs.keys() == self._H_PARAM_KEYS:
231-
self.h0_alpha = _check.pos_float(kwargs['h_alpha'],'h0_alpha',ParameterFormatError)
232-
self.h0_beta = _check.pos_float(kwargs['h_beta'],'h0_beta',ParameterFormatError)
233-
elif kwargs.keys() == self._H0_PARAM_KEYS:
234-
self.h0_alpha = _check.pos_float(kwargs['h0_alpha'],'h0_alpha',ParameterFormatError)
235-
self.h0_beta = _check.pos_float(kwargs['h0_beta'],'h0_beta',ParameterFormatError)
236-
elif kwargs.keys() == self._HN_PARAM_KEYS:
237-
self.h0_alpha = _check.pos_float(kwargs['hn_alpha'],'h0_alpha',ParameterFormatError)
238-
self.h0_beta = _check.pos_float(kwargs['hn_beta'],'h0_beta',ParameterFormatError)
239-
else:
240-
raise(ParameterFormatError(
241-
"The input of this function must be a python dictionary with keys:"
242-
+str(self._H_PARAM_KEYS)+" or "
243-
+str(self._H0_PARAM_KEYS)+" or "
244-
+str(self._HN_PARAM_KEYS)+".")
245-
)
208+
self.h0_alpha = _check.pos_float(h0_alpha, 'h0_alpha', ParameterFormatError)
209+
self.h0_beta = _check.pos_float(h0_beta, 'h0_beta', ParameterFormatError)
246210
self.reset_hn_params()
247211

248212
def get_h0_params(self):
@@ -256,33 +220,18 @@ def get_h0_params(self):
256220
"""
257221
return {"h0_alpha":self.h0_alpha, "h0_beta": self.h0_beta}
258222

259-
def set_hn_params(self, **kwargs):
223+
def set_hn_params(self,hn_alpha, hn_beta):
260224
"""Set updated values of the hyperparameter of the posterior distribution.
261225
262226
Parameters
263227
----------
264-
**kwargs
265-
a python dictionary {'h_alpha':float, 'h_beta':float} or
266-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
267-
They are obtained by ``get_h_params()`` of GenModel,
268-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
228+
hn_alpha : float
229+
a positive real number
230+
hn_beta : float
231+
a positibe real number
269232
"""
270-
if kwargs.keys() == self._H_PARAM_KEYS:
271-
self.hn_alpha = _check.pos_float(kwargs['h_alpha'], 'hn_alpha', ParameterFormatError)
272-
self.hn_beta = _check.pos_float(kwargs['h_beta'], 'hn_beta', ParameterFormatError)
273-
elif kwargs.keys() == self._H0_PARAM_KEYS:
274-
self.hn_alpha = _check.pos_float(kwargs['h0_alpha'], 'hn_alpha', ParameterFormatError)
275-
self.hn_beta = _check.pos_float(kwargs['h0_beta'], 'hn_beta', ParameterFormatError)
276-
elif kwargs.keys() == self._HN_PARAM_KEYS:
277-
self.hn_alpha = _check.pos_float(kwargs['hn_alpha'], 'hn_alpha', ParameterFormatError)
278-
self.hn_beta = _check.pos_float(kwargs['hn_beta'], 'hn_beta', ParameterFormatError)
279-
else:
280-
raise (ParameterFormatError(
281-
"The input of this function must be a python dictionary with keys:"
282-
+ str(self._H_PARAM_KEYS) + " or "
283-
+ str(self._H0_PARAM_KEYS) + " or "
284-
+ str(self._HN_PARAM_KEYS) + ".")
285-
)
233+
self.hn_alpha = _check.pos_float(hn_alpha, 'hn_alpha', ParameterFormatError)
234+
self.hn_beta = _check.pos_float(hn_beta, 'hn_beta', ParameterFormatError)
286235
self.calc_pred_dist()
287236

288237
def get_hn_params(self):

doc/devdoc/examples/h_params_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from bayesml import multivariate_normal as bayesml_model
1+
from bayesml import exponential as bayesml_model
22
import numpy as np
33

4-
h0_params = {'h0_m_vec':np.ones(2),'h0_kappa':3,'h0_nu':2,'h0_w_mat':np.eye(2)*2}
4+
h0_params = {'h0_alpha':2,'h0_beta':3}
55

66
print('Gen to Learn 1')
77
model = bayesml_model.GenModel()

0 commit comments

Comments
 (0)