Skip to content

Commit 9cd63b8

Browse files
committed
revise h_params of poisson
1 parent 465eb80 commit 9cd63b8

File tree

2 files changed

+23
-74
lines changed

2 files changed

+23
-74
lines changed

bayesml/poisson/_poisson.py

Lines changed: 21 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -37,37 +37,19 @@ def __init__(self,*,lambda_=1.0,h_alpha=1.0,h_beta=1.0,seed=None):
3737
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
3838
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
3939
self.rng = np.random.default_rng(seed)
40-
self._H_PARAM_KEYS = {'h_alpha','h_beta'}
41-
self._H0_PARAM_KEYS = {'h0_alpha','h0_beta'}
42-
self._HN_PARAM_KEYS = {'hn_alpha','hn_beta'}
4340

44-
def set_h_params(self,**kwargs):
41+
def set_h_params(self,h_alpha,h_beta):
4542
"""Set the hyperparameters of the prior distribution.
4643
4744
Parameters
4845
----------
49-
**kwargs
50-
a python dictionary {'h_alpha':float, 'h_beta':float} or
51-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
52-
They are obtained by ``get_h_params()`` of GenModel,
53-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
46+
h_alpha : float
47+
a positive real number
48+
h_beta : float
49+
a positibe real number
5450
"""
55-
if kwargs.keys() == self._H_PARAM_KEYS:
56-
self.h_alpha = _check.pos_float(kwargs['h_alpha'],'h_alpha',ParameterFormatError)
57-
self.h_beta = _check.pos_float(kwargs['h_beta'],'h_beta',ParameterFormatError)
58-
elif kwargs.keys() == self._H0_PARAM_KEYS:
59-
self.h_alpha = _check.pos_float(kwargs['h0_alpha'],'h_alpha',ParameterFormatError)
60-
self.h_beta = _check.pos_float(kwargs['h0_beta'],'h_beta',ParameterFormatError)
61-
elif kwargs.keys() == self._HN_PARAM_KEYS:
62-
self.h_alpha = _check.pos_float(kwargs['hn_alpha'],'h_alpha',ParameterFormatError)
63-
self.h_beta = _check.pos_float(kwargs['hn_beta'],'h_beta',ParameterFormatError)
64-
else:
65-
raise(ParameterFormatError(
66-
"The input of this function must be a python dictionary with keys:"
67-
+str(self._H_PARAM_KEYS)+" or "
68-
+str(self._H0_PARAM_KEYS)+" or "
69-
+str(self._HN_PARAM_KEYS)+".")
70-
)
51+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
52+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
7153

7254
def get_h_params(self):
7355
"""Get the hyperparameters of the prior distribution.
@@ -205,37 +187,19 @@ def __init__(self,h0_alpha=1.0,h0_beta=1.0):
205187
self.hn_beta = self.h0_beta
206188
self.p_r = self.hn_alpha
207189
self.p_theta = 1.0 / (1.0+self.hn_beta)
208-
self._H_PARAM_KEYS = {'h_alpha','h_beta'}
209-
self._H0_PARAM_KEYS = {'h0_alpha','h0_beta'}
210-
self._HN_PARAM_KEYS = {'hn_alpha','hn_beta'}
211190

212-
def set_h0_params(self,**kwargs):
191+
def set_h0_params(self,h0_alpha,h0_beta):
213192
"""Set initial values of the hyperparameter of the posterior distribution.
214193
215194
Parameters
216195
----------
217-
**kwargs
218-
a python dictionary {'h_alpha':float, 'h_beta':float} or
219-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
220-
They are obtained by ``get_h_params()`` of GenModel,
221-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
196+
h0_alpha : float
197+
a positive real number
198+
h0_beta : float
199+
a positibe real number
222200
"""
223-
if kwargs.keys() == self._H_PARAM_KEYS:
224-
self.h0_alpha = _check.pos_float(kwargs['h_alpha'],'h0_alpha',ParameterFormatError)
225-
self.h0_beta = _check.pos_float(kwargs['h_beta'],'h0_beta',ParameterFormatError)
226-
elif kwargs.keys() == self._H0_PARAM_KEYS:
227-
self.h0_alpha = _check.pos_float(kwargs['h0_alpha'],'h0_alpha',ParameterFormatError)
228-
self.h0_beta = _check.pos_float(kwargs['h0_beta'],'h0_beta',ParameterFormatError)
229-
elif kwargs.keys() == self._HN_PARAM_KEYS:
230-
self.h0_alpha = _check.pos_float(kwargs['hn_alpha'],'h0_alpha',ParameterFormatError)
231-
self.h0_beta = _check.pos_float(kwargs['hn_beta'],'h0_beta',ParameterFormatError)
232-
else:
233-
raise(ParameterFormatError(
234-
"The input of this function must be a python dictionary with keys:"
235-
+str(self._H_PARAM_KEYS)+" or "
236-
+str(self._H0_PARAM_KEYS)+" or "
237-
+str(self._HN_PARAM_KEYS)+".")
238-
)
201+
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
202+
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
239203
self.reset_hn_params()
240204

241205
def get_h0_params(self):
@@ -249,33 +213,18 @@ def get_h0_params(self):
249213
"""
250214
return {"h0_alpha":self.h0_alpha, "h0_beta":self.h0_beta}
251215

252-
def set_hn_params(self,**kwargs):
216+
def set_hn_params(self,hn_alpha,hn_beta):
253217
"""Set updated values of the hyperparameter of the posterior distribution.
254218
255219
Parameters
256220
----------
257-
**kwargs
258-
a python dictionary {'h_alpha':float, 'h_beta':float} or
259-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
260-
They are obtained by ``get_h_params()`` of GenModel,
261-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
221+
hn_alpha : float
222+
a positive real number
223+
hn_beta : float
224+
a positibe real number
262225
"""
263-
if kwargs.keys() == self._H_PARAM_KEYS:
264-
self.hn_alpha = _check.pos_float(kwargs['h_alpha'],'hn_alpha',ParameterFormatError)
265-
self.hn_beta = _check.pos_float(kwargs['h_beta'],'hn_beta',ParameterFormatError)
266-
elif kwargs.keys() == self._H0_PARAM_KEYS:
267-
self.hn_alpha = _check.pos_float(kwargs['h0_alpha'],'hn_alpha',ParameterFormatError)
268-
self.hn_beta = _check.pos_float(kwargs['h0_beta'],'hn_beta',ParameterFormatError)
269-
elif kwargs.keys() == self._HN_PARAM_KEYS:
270-
self.hn_alpha = _check.pos_float(kwargs['hn_alpha'],'hn_alpha',ParameterFormatError)
271-
self.hn_beta = _check.pos_float(kwargs['hn_beta'],'hn_beta',ParameterFormatError)
272-
else:
273-
raise(ParameterFormatError(
274-
"The input of this function must be a python dictionary with keys:"
275-
+str(self._H_PARAM_KEYS)+" or "
276-
+str(self._H0_PARAM_KEYS)+" or "
277-
+str(self._HN_PARAM_KEYS)+".")
278-
)
226+
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
227+
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
279228
self.calc_pred_dist()
280229

281230
def get_hn_params(self):

doc/devdoc/examples/h_params_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from bayesml import categorical as bayesml_model
1+
from bayesml import poisson as bayesml_model
22
import numpy as np
33

4-
h0_params = {'h0_alpha_vec':np.ones(5)}
4+
h0_params = {'h0_alpha':2,'h0_beta':3}
55

66
print('Gen to Learn 1')
77
model = bayesml_model.GenModel()

0 commit comments

Comments
 (0)