@@ -32,37 +32,19 @@ def __init__(self,*,theta=0.5,h_alpha=0.5,h_beta=0.5,seed=None):
32
32
self .h_alpha = _check .pos_float (h_alpha ,'h_alpha' ,ParameterFormatError )
33
33
self .h_beta = _check .pos_float (h_beta ,'h_beta' ,ParameterFormatError )
34
34
self .rng = np .random .default_rng (seed )
35
- self ._H_PARAM_KEYS = {'h_alpha' ,'h_beta' }
36
- self ._H0_PARAM_KEYS = {'h0_alpha' ,'h0_beta' }
37
- self ._HN_PARAM_KEYS = {'hn_alpha' ,'hn_beta' }
38
35
39
- def set_h_params (self ,** kwargs ):
36
+ def set_h_params (self ,h_alpha , h_beta ):
40
37
"""Set the hyperparameters of the prior distribution.
41
38
42
39
Parameters
43
40
----------
44
- **kwargs
45
- a python dictionary {'h_alpha':float, 'h_beta':float} or
46
- {'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
47
- They are obtained by ``get_h_params()`` of GenModel,
48
- ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
41
+ h_alpha : float
42
+ a positive real number
43
+ h_beta : float, optional
44
+ a positibe real number
49
45
"""
50
- if kwargs .keys () == self ._H_PARAM_KEYS :
51
- self .h_alpha = _check .pos_float (kwargs ['h_alpha' ],'h_alpha' ,ParameterFormatError )
52
- self .h_beta = _check .pos_float (kwargs ['h_beta' ],'h_beta' ,ParameterFormatError )
53
- elif kwargs .keys () == self ._H0_PARAM_KEYS :
54
- self .h_alpha = _check .pos_float (kwargs ['h0_alpha' ],'h_alpha' ,ParameterFormatError )
55
- self .h_beta = _check .pos_float (kwargs ['h0_beta' ],'h_beta' ,ParameterFormatError )
56
- elif kwargs .keys () == self ._HN_PARAM_KEYS :
57
- self .h_alpha = _check .pos_float (kwargs ['hn_alpha' ],'h_alpha' ,ParameterFormatError )
58
- self .h_beta = _check .pos_float (kwargs ['hn_beta' ],'h_beta' ,ParameterFormatError )
59
- else :
60
- raise (ParameterFormatError (
61
- "The input of this function must be a python dictionary with keys:"
62
- + str (self ._H_PARAM_KEYS )+ " or "
63
- + str (self ._H0_PARAM_KEYS )+ " or "
64
- + str (self ._HN_PARAM_KEYS )+ "." )
65
- )
46
+ self .h_alpha = _check .pos_float (h_alpha ,'h_alpha' ,ParameterFormatError )
47
+ self .h_beta = _check .pos_float (h_beta ,'h_beta' ,ParameterFormatError )
66
48
67
49
def get_h_params (self ):
68
50
"""Get the hyperparameters of the prior distribution.
@@ -207,37 +189,19 @@ def __init__(self,h0_alpha=0.5,h0_beta=0.5):
207
189
self .hn_alpha = self .h0_alpha
208
190
self .hn_beta = self .h0_beta
209
191
self .p_theta = self .hn_alpha / (self .hn_alpha + self .hn_beta )
210
- self ._H_PARAM_KEYS = {'h_alpha' ,'h_beta' }
211
- self ._H0_PARAM_KEYS = {'h0_alpha' ,'h0_beta' }
212
- self ._HN_PARAM_KEYS = {'hn_alpha' ,'hn_beta' }
213
192
214
- def set_h0_params (self ,** kwargs ):
193
+ def set_h0_params (self ,h0_alpha , h0_beta ):
215
194
"""Set initial values of the hyperparameter of the posterior distribution.
216
195
217
196
Parameters
218
197
----------
219
- **kwargs
220
- a python dictionary {'h_alpha':float, 'h_beta':float} or
221
- {'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
222
- They are obtained by ``get_h_params()`` of GenModel,
223
- ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
198
+ h0_alpha : float
199
+ a positive real number
200
+ h0_beta : float
201
+ a positibe real number
224
202
"""
225
- if kwargs .keys () == self ._H_PARAM_KEYS :
226
- self .h0_alpha = _check .pos_float (kwargs ['h_alpha' ],'h0_alpha' ,ParameterFormatError )
227
- self .h0_beta = _check .pos_float (kwargs ['h_beta' ],'h0_beta' ,ParameterFormatError )
228
- elif kwargs .keys () == self ._H0_PARAM_KEYS :
229
- self .h0_alpha = _check .pos_float (kwargs ['h0_alpha' ],'h0_alpha' ,ParameterFormatError )
230
- self .h0_beta = _check .pos_float (kwargs ['h0_beta' ],'h0_beta' ,ParameterFormatError )
231
- elif kwargs .keys () == self ._HN_PARAM_KEYS :
232
- self .h0_alpha = _check .pos_float (kwargs ['hn_alpha' ],'h0_alpha' ,ParameterFormatError )
233
- self .h0_beta = _check .pos_float (kwargs ['hn_beta' ],'h0_beta' ,ParameterFormatError )
234
- else :
235
- raise (ParameterFormatError (
236
- "The input of this function must be a python dictionary with keys:"
237
- + str (self ._H_PARAM_KEYS )+ " or "
238
- + str (self ._H0_PARAM_KEYS )+ " or "
239
- + str (self ._HN_PARAM_KEYS )+ "." )
240
- )
203
+ self .h0_alpha = _check .pos_float (h0_alpha ,'h0_alpha' ,ParameterFormatError )
204
+ self .h0_beta = _check .pos_float (h0_beta ,'h0_beta' ,ParameterFormatError )
241
205
self .reset_hn_params ()
242
206
243
207
def get_h0_params (self ):
@@ -251,33 +215,18 @@ def get_h0_params(self):
251
215
"""
252
216
return {"h0_alpha" :self .h0_alpha , "h0_beta" :self .h0_beta }
253
217
254
- def set_hn_params (self ,** kwargs ):
218
+ def set_hn_params (self ,hn_alpha , hn_beta ):
255
219
"""Set updated values of the hyperparameter of the posterior distribution.
256
220
257
221
Parameters
258
222
----------
259
- **kwargs
260
- a python dictionary {'h_alpha':float, 'h_beta':float} or
261
- {'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
262
- They are obtained by ``get_h_params()`` of GenModel,
263
- ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
223
+ h0_alpha : float
224
+ a positive real number
225
+ h0_beta : float
226
+ a positibe real number
264
227
"""
265
- if kwargs .keys () == self ._H_PARAM_KEYS :
266
- self .hn_alpha = _check .pos_float (kwargs ['h_alpha' ],'hn_alpha' ,ParameterFormatError )
267
- self .hn_beta = _check .pos_float (kwargs ['h_beta' ],'hn_beta' ,ParameterFormatError )
268
- elif kwargs .keys () == self ._H0_PARAM_KEYS :
269
- self .hn_alpha = _check .pos_float (kwargs ['h0_alpha' ],'hn_alpha' ,ParameterFormatError )
270
- self .hn_beta = _check .pos_float (kwargs ['h0_beta' ],'hn_beta' ,ParameterFormatError )
271
- elif kwargs .keys () == self ._HN_PARAM_KEYS :
272
- self .hn_alpha = _check .pos_float (kwargs ['hn_alpha' ],'hn_alpha' ,ParameterFormatError )
273
- self .hn_beta = _check .pos_float (kwargs ['hn_beta' ],'hn_beta' ,ParameterFormatError )
274
- else :
275
- raise (ParameterFormatError (
276
- "The input of this function must be a python dictionary with keys:"
277
- + str (self ._H_PARAM_KEYS )+ " or "
278
- + str (self ._H0_PARAM_KEYS )+ " or "
279
- + str (self ._HN_PARAM_KEYS )+ "." )
280
- )
228
+ self .hn_alpha = _check .pos_float (hn_alpha ,'hn_alpha' ,ParameterFormatError )
229
+ self .hn_beta = _check .pos_float (hn_beta ,'hn_beta' ,ParameterFormatError )
281
230
self .calc_pred_dist ()
282
231
283
232
def get_hn_params (self ):
0 commit comments