@@ -36,37 +36,19 @@ def __init__(self,*,lambda_=1.0,h_alpha=1.0,h_beta=1.0,seed=None):
36
36
self .h_alpha = _check .pos_float (h_alpha ,'h_alpha' ,ParameterFormatError )
37
37
self .h_beta = _check .pos_float (h_beta ,'h_beta' ,ParameterFormatError )
38
38
self .rng = np .random .default_rng (seed )
39
- self ._H_PARAM_KEYS = {'h_alpha' ,'h_beta' }
40
- self ._H0_PARAM_KEYS = {'h0_alpha' ,'h0_beta' }
41
- self ._HN_PARAM_KEYS = {'hn_alpha' ,'hn_beta' }
42
39
43
- def set_h_params (self ,** kwargs ):
40
+ def set_h_params (self ,h_alpha , h_beta ):
44
41
"""Set the hyperparameters of the prior distribution.
45
42
46
43
Parameters
47
44
----------
48
- **kwargs
49
- a python dictionary {'h_alpha':float, 'h_beta':float} or
50
- {'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
51
- They are obtained by ``get_h_params()`` of GenModel,
52
- ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
45
+ h_alpha : float
46
+ a positive real number
47
+ h_beta : float
48
+ a positive real number
53
49
"""
54
- if kwargs .keys () == self ._H_PARAM_KEYS :
55
- self .h_alpha = _check .pos_float (kwargs ['h_alpha' ],'h_alpha' ,ParameterFormatError )
56
- self .h_beta = _check .pos_float (kwargs ['h_beta' ],'h_beta' ,ParameterFormatError )
57
- elif kwargs .keys () == self ._H0_PARAM_KEYS :
58
- self .h_alpha = _check .pos_float (kwargs ['h0_alpha' ],'h_alpha' ,ParameterFormatError )
59
- self .h_beta = _check .pos_float (kwargs ['h0_beta' ],'h_beta' ,ParameterFormatError )
60
- elif kwargs .keys () == self ._HN_PARAM_KEYS :
61
- self .h_alpha = _check .pos_float (kwargs ['hn_alpha' ],'h_alpha' ,ParameterFormatError )
62
- self .h_beta = _check .pos_float (kwargs ['hn_beta' ],'h_beta' ,ParameterFormatError )
63
- else :
64
- raise (ParameterFormatError (
65
- "The input of this function must be a python dictionary with keys:"
66
- + str (self ._H_PARAM_KEYS )+ " or "
67
- + str (self ._H0_PARAM_KEYS )+ " or "
68
- + str (self ._HN_PARAM_KEYS )+ "." )
69
- )
50
+ self .h_alpha = _check .pos_float (h_alpha ,'h_alpha' ,ParameterFormatError )
51
+ self .h_beta = _check .pos_float (h_beta ,'h_beta' ,ParameterFormatError )
70
52
71
53
def get_h_params (self ):
72
54
"""Get the hyperparameters of the prior distribution.
@@ -212,37 +194,19 @@ def __init__(self,h0_alpha=2.0, h0_beta=1.0):
212
194
self .hn_beta = self .h0_beta
213
195
self .p_kappa = self .hn_alpha
214
196
self .p_lambda = self .hn_beta
215
- self ._H_PARAM_KEYS = {'h_alpha' , 'h_beta' }
216
- self ._H0_PARAM_KEYS = {'h0_alpha' , 'h0_beta' }
217
- self ._HN_PARAM_KEYS = {'hn_alpha' , 'hn_beta' }
218
197
219
- def set_h0_params (self ,** kwargs ):
198
+ def set_h0_params (self ,h0_alpha , h0_beta ):
220
199
"""Set initial values of the hyperparameter of the posterior distribution.
221
200
222
201
Parameters
223
202
----------
224
- **kwargs
225
- a python dictionary {'h_alpha':float, 'h_beta':float} or
226
- {'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
227
- They are obtained by ``get_h_params()`` of GenModel,
228
- ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
203
+ h0_alpha : float
204
+ a positive real number
205
+ h0_beta : float
206
+ a positibe real number
229
207
"""
230
- if kwargs .keys () == self ._H_PARAM_KEYS :
231
- self .h0_alpha = _check .pos_float (kwargs ['h_alpha' ],'h0_alpha' ,ParameterFormatError )
232
- self .h0_beta = _check .pos_float (kwargs ['h_beta' ],'h0_beta' ,ParameterFormatError )
233
- elif kwargs .keys () == self ._H0_PARAM_KEYS :
234
- self .h0_alpha = _check .pos_float (kwargs ['h0_alpha' ],'h0_alpha' ,ParameterFormatError )
235
- self .h0_beta = _check .pos_float (kwargs ['h0_beta' ],'h0_beta' ,ParameterFormatError )
236
- elif kwargs .keys () == self ._HN_PARAM_KEYS :
237
- self .h0_alpha = _check .pos_float (kwargs ['hn_alpha' ],'h0_alpha' ,ParameterFormatError )
238
- self .h0_beta = _check .pos_float (kwargs ['hn_beta' ],'h0_beta' ,ParameterFormatError )
239
- else :
240
- raise (ParameterFormatError (
241
- "The input of this function must be a python dictionary with keys:"
242
- + str (self ._H_PARAM_KEYS )+ " or "
243
- + str (self ._H0_PARAM_KEYS )+ " or "
244
- + str (self ._HN_PARAM_KEYS )+ "." )
245
- )
208
+ self .h0_alpha = _check .pos_float (h0_alpha , 'h0_alpha' , ParameterFormatError )
209
+ self .h0_beta = _check .pos_float (h0_beta , 'h0_beta' , ParameterFormatError )
246
210
self .reset_hn_params ()
247
211
248
212
def get_h0_params (self ):
@@ -256,33 +220,18 @@ def get_h0_params(self):
256
220
"""
257
221
return {"h0_alpha" :self .h0_alpha , "h0_beta" : self .h0_beta }
258
222
259
- def set_hn_params (self , ** kwargs ):
223
+ def set_hn_params (self ,hn_alpha , hn_beta ):
260
224
"""Set updated values of the hyperparameter of the posterior distribution.
261
225
262
226
Parameters
263
227
----------
264
- **kwargs
265
- a python dictionary {'h_alpha':float, 'h_beta':float} or
266
- {'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
267
- They are obtained by ``get_h_params()`` of GenModel,
268
- ``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
228
+ hn_alpha : float
229
+ a positive real number
230
+ hn_beta : float
231
+ a positibe real number
269
232
"""
270
- if kwargs .keys () == self ._H_PARAM_KEYS :
271
- self .hn_alpha = _check .pos_float (kwargs ['h_alpha' ], 'hn_alpha' , ParameterFormatError )
272
- self .hn_beta = _check .pos_float (kwargs ['h_beta' ], 'hn_beta' , ParameterFormatError )
273
- elif kwargs .keys () == self ._H0_PARAM_KEYS :
274
- self .hn_alpha = _check .pos_float (kwargs ['h0_alpha' ], 'hn_alpha' , ParameterFormatError )
275
- self .hn_beta = _check .pos_float (kwargs ['h0_beta' ], 'hn_beta' , ParameterFormatError )
276
- elif kwargs .keys () == self ._HN_PARAM_KEYS :
277
- self .hn_alpha = _check .pos_float (kwargs ['hn_alpha' ], 'hn_alpha' , ParameterFormatError )
278
- self .hn_beta = _check .pos_float (kwargs ['hn_beta' ], 'hn_beta' , ParameterFormatError )
279
- else :
280
- raise (ParameterFormatError (
281
- "The input of this function must be a python dictionary with keys:"
282
- + str (self ._H_PARAM_KEYS ) + " or "
283
- + str (self ._H0_PARAM_KEYS ) + " or "
284
- + str (self ._HN_PARAM_KEYS ) + "." )
285
- )
233
+ self .hn_alpha = _check .pos_float (hn_alpha , 'hn_alpha' , ParameterFormatError )
234
+ self .hn_beta = _check .pos_float (hn_beta , 'hn_beta' , ParameterFormatError )
286
235
self .calc_pred_dist ()
287
236
288
237
def get_hn_params (self ):
0 commit comments