
Commit 242818c

revise h_params of normal
1 parent 9cd63b8 commit 242818c

2 files changed: +41 -98 lines

bayesml/normal/_normal.py

Lines changed: 39 additions & 96 deletions
@@ -47,44 +47,25 @@ def __init__(self,*,mu=0.0,tau=1.0,h_m=0.0,h_kappa=1.0,h_alpha=1.0,h_beta=1.0,se
         self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
         self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
         self.rng = np.random.default_rng(seed)
-        self._H_PARAM_KEYS = {'h_m','h_kappa','h_alpha','h_beta'}
-        self._H0_PARAM_KEYS = {'h0_m','h0_kappa','h0_alpha','h0_beta'}
-        self._HN_PARAM_KEYS = {'hn_m','hn_kappa','hn_alpha','hn_beta'}

-    def set_h_params(self,**kwargs):
+    def set_h_params(self,h_m,h_kappa,h_alpha,h_beta):
         """Set the hyperparameters of the prior distribution.

         Parameters
         ----------
-        **kwargs
-            a python dictionary {'h_m':float, 'h_kappa':float, 'h_alpha':float, 'h_beta':float} or
-            {'h0_m':float, 'h0_kappa':float, 'h0_alpha':float, 'h0_beta':float} or
-            {'hn_m':float, 'hn_kappa':float, 'hn_alpha':float, 'hn_beta':float}
-            They are obtained by ``get_h_params()`` of GenModel,
-            ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
+        h_m : float
+            a real number
+        h_kappa : float
+            a positive real number
+        h_alpha : float
+            a positive real number
+        h_beta : float
+            a positive real number
         """
-        if kwargs.keys() == self._H_PARAM_KEYS:
-            self.h_m = _check.float_(kwargs['h_m'],'h_m',ParameterFormatError)
-            self.h_kappa = _check.pos_float(kwargs['h_kappa'],'h_alpha',ParameterFormatError)
-            self.h_alpha = _check.pos_float(kwargs['h_alpha'],'h_alpha',ParameterFormatError)
-            self.h_beta = _check.pos_float(kwargs['h_beta'],'h_beta',ParameterFormatError)
-        elif kwargs.keys() == self._H0_PARAM_KEYS:
-            self.h_m = _check.float_(kwargs['h0_m'],'h_m',ParameterFormatError)
-            self.h_kappa = _check.pos_float(kwargs['h0_kappa'],'h_alpha',ParameterFormatError)
-            self.h_alpha = _check.pos_float(kwargs['h0_alpha'],'h_alpha',ParameterFormatError)
-            self.h_beta = _check.pos_float(kwargs['h0_beta'],'h_beta',ParameterFormatError)
-        elif kwargs.keys() == self._HN_PARAM_KEYS:
-            self.h_m = _check.float_(kwargs['hn_m'],'h_m',ParameterFormatError)
-            self.h_kappa = _check.pos_float(kwargs['hn_kappa'],'h_alpha',ParameterFormatError)
-            self.h_alpha = _check.pos_float(kwargs['hn_alpha'],'h_alpha',ParameterFormatError)
-            self.h_beta = _check.pos_float(kwargs['hn_beta'],'h_beta',ParameterFormatError)
-        else:
-            raise(ParameterFormatError(
-                "The input of this function must be a python dictionary with keys:"
-                +str(self._H_PARAM_KEYS)+" or "
-                +str(self._H0_PARAM_KEYS)+" or "
-                +str(self._HN_PARAM_KEYS)+".")
-            )
+        self.h_m = _check.float_(h_m,'h_m',ParameterFormatError)
+        self.h_kappa = _check.pos_float(h_kappa,'h_kappa',ParameterFormatError)
+        self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
+        self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)

     def get_h_params(self):
         """Get the hyperparameters of the prior distribution.
@@ -251,45 +232,24 @@ def __init__(self,h0_m=0.0,h0_kappa=1.0,h0_alpha=1.0,h0_beta=1.0):
         self.p_nu = 2*self.hn_alpha
         self.p_lambda = self.hn_kappa / (self.hn_kappa+1) * self.hn_alpha / self.hn_beta

-        self._H_PARAM_KEYS = {'h_m','h_kappa','h_alpha','h_beta'}
-        self._H0_PARAM_KEYS = {'h0_m','h0_kappa','h0_alpha','h0_beta'}
-        self._HN_PARAM_KEYS = {'hn_m','hn_kappa','hn_alpha','hn_beta'}
-
-    def set_h0_params(self,**kwargs):
+    def set_h0_params(self,h0_m,h0_kappa,h0_alpha,h0_beta):
         """Set the hyperparameters of the prior distribution.

         Parameters
         ----------
-        **kwargs
-            a python dictionary {'h_m':float, 'h_kappa':float, 'h_alpha':float, 'h_beta':float} or
-            {'h0_m':float, 'h0_kappa':float, 'h0_alpha':float, 'h0_beta':float} or
-            {'hn_m':float, 'hn_kappa':float, 'hn_alpha':float, 'hn_beta':float}
-            They are obtained by ``get_h_params()`` of GenModel,
-            ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
+        h0_m : float
+            a real number
+        h0_kappa : float
+            a positive real number
+        h0_alpha : float
+            a positive real number
+        h0_beta : float
+            a positive real number
         """
-
-        if kwargs.keys() == self._H_PARAM_KEYS:
-            self.h0_m = _check.float_(kwargs['h_m'],'h0_m',ParameterFormatError)
-            self.h0_kappa = _check.pos_float(kwargs['h_kappa'],'h0_alpha',ParameterFormatError)
-            self.h0_alpha = _check.pos_float(kwargs['h_alpha'],'h0_alpha',ParameterFormatError)
-            self.h0_beta = _check.pos_float(kwargs['h_beta'],'h0_beta',ParameterFormatError)
-        elif kwargs.keys() == self._H0_PARAM_KEYS:
-            self.h0_m = _check.float_(kwargs['h0_m'],'h0_m',ParameterFormatError)
-            self.h0_kappa = _check.pos_float(kwargs['h0_kappa'],'h0_alpha',ParameterFormatError)
-            self.h0_alpha = _check.pos_float(kwargs['h0_alpha'],'h0_alpha',ParameterFormatError)
-            self.h0_beta = _check.pos_float(kwargs['h0_beta'],'h0_beta',ParameterFormatError)
-        elif kwargs.keys() == self._HN_PARAM_KEYS:
-            self.h0_m = _check.float_(kwargs['hn_m'],'h0_m',ParameterFormatError)
-            self.h0_kappa = _check.pos_float(kwargs['hn_kappa'],'h0_alpha',ParameterFormatError)
-            self.h0_alpha = _check.pos_float(kwargs['hn_alpha'],'h0_alpha',ParameterFormatError)
-            self.h0_beta = _check.pos_float(kwargs['hn_beta'],'h0_beta',ParameterFormatError)
-        else:
-            raise(ParameterFormatError(
-                "The input of this function must be a python dictionary with keys:"
-                +str(self._H_PARAM_KEYS)+" or "
-                +str(self._H0_PARAM_KEYS)+" or "
-                +str(self._HN_PARAM_KEYS)+".")
-            )
+        self.h0_m = _check.float_(h0_m,'h0_m',ParameterFormatError)
+        self.h0_kappa = _check.pos_float(h0_kappa,'h0_kappa',ParameterFormatError)
+        self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
+        self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
         self.reset_hn_params()

     def get_h0_params(self):
@@ -305,41 +265,24 @@ def get_h0_params(self):
         """
         return {"h0_m":self.h0_m, "h0_kappa":self.h0_kappa, "h0_alpha":self.h0_alpha, "h0_beta":self.h0_beta}

-    def set_hn_params(self,**kwargs):
+    def set_hn_params(self,hn_m,hn_kappa,hn_alpha,hn_beta):
         """Set updated values of the hyperparameter of the posterior distribution.

         Parameters
         ----------
-        **kwargs
-            a python dictionary {'h_m':float, 'h_kappa':float, 'h_alpha':float, 'h_beta':float} or
-            {'h0_m':float, 'h0_kappa':float, 'h0_alpha':float, 'h0_beta':float} or
-            {'hn_m':float, 'hn_kappa':float, 'hn_alpha':float, 'hn_beta':float}
-            They are obtained by ``get_h_params()`` of GenModel,
-            ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
+        hn_m : float
+            a real number
+        hn_kappa : float
+            a positive real number
+        hn_alpha : float
+            a positive real number
+        hn_beta : float
+            a positive real number
         """
-
-        if kwargs.keys() == self._H_PARAM_KEYS:
-            self.hn_m = _check.float_(kwargs['h_m'],'hn_m',ParameterFormatError)
-            self.hn_kappa = _check.pos_float(kwargs['h_kappa'],'hn_alpha',ParameterFormatError)
-            self.hn_alpha = _check.pos_float(kwargs['h_alpha'],'hn_alpha',ParameterFormatError)
-            self.hn_beta = _check.pos_float(kwargs['h_beta'],'hn_beta',ParameterFormatError)
-        elif kwargs.keys() == self._H0_PARAM_KEYS:
-            self.hn_m = _check.float_(kwargs['h0_m'],'hn_m',ParameterFormatError)
-            self.hn_kappa = _check.pos_float(kwargs['h0_kappa'],'hn_alpha',ParameterFormatError)
-            self.hn_alpha = _check.pos_float(kwargs['h0_alpha'],'hn_alpha',ParameterFormatError)
-            self.hn_beta = _check.pos_float(kwargs['h0_beta'],'hn_beta',ParameterFormatError)
-        elif kwargs.keys() == self._HN_PARAM_KEYS:
-            self.hn_m = _check.float_(kwargs['hn_m'],'hn_m',ParameterFormatError)
-            self.hn_kappa = _check.pos_float(kwargs['hn_kappa'],'hn_alpha',ParameterFormatError)
-            self.hn_alpha = _check.pos_float(kwargs['hn_alpha'],'hn_alpha',ParameterFormatError)
-            self.hn_beta = _check.pos_float(kwargs['hn_beta'],'hn_beta',ParameterFormatError)
-        else:
-            raise(ParameterFormatError(
-                "The input of this function must be a python dictionary with keys:"
-                +str(self._H_PARAM_KEYS)+" or "
-                +str(self._H0_PARAM_KEYS)+" or "
-                +str(self._HN_PARAM_KEYS)+".")
-            )
+        self.hn_m = _check.float_(hn_m,'hn_m',ParameterFormatError)
+        self.hn_kappa = _check.pos_float(hn_kappa,'hn_kappa',ParameterFormatError)
+        self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
+        self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
         self.calc_pred_dist()

     def get_hn_params(self):

doc/devdoc/examples/h_params_test.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
-from bayesml import poisson as bayesml_model
+from bayesml import normal as bayesml_model
 import numpy as np

-h0_params = {'h0_alpha':2,'h0_beta':3}
+h0_params = {'h0_m':2,'h0_kappa':1,'h0_alpha':3,'h0_beta':4}

 print('Gen to Learn 1')
 model = bayesml_model.GenModel()
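
A side note on the updated test dictionary: its h0_-prefixed keys no longer match GenModel.set_h_params, which now expects h_m, h_kappa, h_alpha, h_beta, so the dictionary cannot be splatted into the generative model the way the old keyword API allowed. A hypothetical workaround, relying on the insertion order shown above, is to pass the values positionally:

from bayesml import normal

h0_params = {'h0_m':2,'h0_kappa':1,'h0_alpha':3,'h0_beta':4}

# h0_params.values() yields 2, 1, 3, 4 in insertion order, which maps to
# h_m, h_kappa, h_alpha, h_beta in the new positional signature.
gen_model = normal.GenModel()
gen_model.set_h_params(*h0_params.values())
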

0 commit comments
