Skip to content

Commit 465eb80

Browse files
committed
revise h_params of categorical
1 parent 276b81a commit 465eb80

File tree

2 files changed

+60
-77
lines changed

2 files changed

+60
-77
lines changed

bayesml/categorical/_categorical.py

Lines changed: 19 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -81,34 +81,16 @@ def __init__(
8181
+" if two or more of them are specified."))
8282

8383
self.rng = np.random.default_rng(seed)
84-
self._H_PARAM_KEYS = {'h_alpha_vec'}
85-
self._H0_PARAM_KEYS = {'h0_alpha_vec'}
86-
self._HN_PARAM_KEYS = {'hn_alpha_vec'}
8784

88-
def set_h_params(self,**kwargs):
85+
def set_h_params(self,h_alpha_vec):
8986
"""Set the hyperparameters of the prior distribution.
9087
9188
Parameters
9289
----------
93-
**kwargs
94-
a python dictionary {'h_alpha_vec':ndarray},
95-
{'h0_alpha_vec':ndarray}, or {'hn_alpha_vec':ndarray}.
96-
They are obtained by ``get_h_params()`` of GenModel,
97-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
90+
h_alpha_vec : numpy ndarray
91+
a vector of positive real numbers
9892
"""
99-
if kwargs.keys() == self._H_PARAM_KEYS:
100-
self.h_alpha_vec = _check.pos_float_vec(kwargs['h_alpha_vec'],'h_alpha_vec',ParameterFormatError)
101-
elif kwargs.keys() == self._H0_PARAM_KEYS:
102-
self.h_alpha_vec = _check.pos_float_vec(kwargs['h0_alpha_vec'],'h_alpha_vec',ParameterFormatError)
103-
elif kwargs.keys() == self._HN_PARAM_KEYS:
104-
self.h_alpha_vec = _check.pos_float_vec(kwargs['hn_alpha_vec'],'h_alpha_vec',ParameterFormatError)
105-
else:
106-
raise(ParameterFormatError(
107-
"The input of this function must be a python dictionary with keys:"
108-
+str(self._H_PARAM_KEYS)+" or "
109-
+str(self._H0_PARAM_KEYS)+" or "
110-
+str(self._HN_PARAM_KEYS)+".")
111-
)
93+
self.h_alpha_vec = _check.pos_float_vec(h_alpha_vec,'h_alpha_vec',ParameterFormatError)
11294

11395
self.degree = self.h_alpha_vec.shape[0]
11496
if self.degree != self.theta_vec.shape[0]:
@@ -264,6 +246,13 @@ class LearnModel(base.Posterior, base.PredictiveMixin):
264246
degree is assumed to be 3.
265247
h0_alpha_vec : numpy.ndarray, optional
266248
a vector of positive real numbers, by default [1/2, 1/2, ... , 1/2]
249+
250+
Attributes
251+
----------
252+
hn_alpha_vec : numpy.ndarray
253+
a vector of positive real numbers
254+
p_theta_vec : numpy.ndarray
255+
a real vector in :math:`[0, 1]^d`
267256
"""
268257
def __init__(self, degree=None, h0_alpha_vec=None):
269258
if degree is not None:
@@ -289,35 +278,15 @@ def __init__(self, degree=None, h0_alpha_vec=None):
289278
self.hn_alpha_vec = np.copy(self.h0_alpha_vec)
290279
self.p_theta_vec = self.hn_alpha_vec / self.hn_alpha_vec.sum()
291280

292-
self._H_PARAM_KEYS = {'h_alpha_vec'}
293-
self._H0_PARAM_KEYS = {'h0_alpha_vec'}
294-
self._HN_PARAM_KEYS = {'hn_alpha_vec'}
295-
296-
def set_h0_params(self,**kwargs):
281+
def set_h0_params(self,h0_alpha_vec):
297282
"""Set the hyperparameters of the prior distribution.
298283
299284
Parameters
300285
----------
301-
**kwargs
302-
a python dictionary {'h_alpha_vec':ndarray},
303-
{'h0_alpha_vec':ndarray}, or {'hn_alpha_vec':ndarray}.
304-
They are obtained by ``get_h_params()`` of GenModel,
305-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
286+
h0_alpha_vec : numpy.ndarray
287+
a vector of positive real numbers
306288
"""
307-
if kwargs.keys() == self._H_PARAM_KEYS:
308-
self.h0_alpha_vec = _check.pos_float_vec(kwargs['h_alpha_vec'],'h0_alpha_vec',ParameterFormatError)
309-
elif kwargs.keys() == self._H0_PARAM_KEYS:
310-
self.h0_alpha_vec = _check.pos_float_vec(kwargs['h0_alpha_vec'],'h0_alpha_vec',ParameterFormatError)
311-
elif kwargs.keys() == self._HN_PARAM_KEYS:
312-
self.h0_alpha_vec = _check.pos_float_vec(kwargs['hn_alpha_vec'],'h0_alpha_vec',ParameterFormatError)
313-
else:
314-
raise(ParameterFormatError(
315-
"The input of this function must be a python dictionary with keys:"
316-
+str(self._H_PARAM_KEYS)+" or "
317-
+str(self._H0_PARAM_KEYS)+" or "
318-
+str(self._HN_PARAM_KEYS)+".")
319-
)
320-
289+
self.h0_alpha_vec = _check.pos_float_vec(h0_alpha_vec,'h0_alpha_vec',ParameterFormatError)
321290
self.degree = self.h0_alpha_vec.shape[0]
322291
self.reset_hn_params()
323292

@@ -331,31 +300,15 @@ def get_h0_params(self):
331300
"""
332301
return {"h0_alpha_vec": self.h0_alpha_vec}
333302

334-
def set_hn_params(self,**kwargs):
303+
def set_hn_params(self,hn_alpha_vec):
335304
"""Set updated values of the hyperparameter of the posterior distribution.
336305
337306
Parameters
338307
----------
339-
**kwargs
340-
a python dictionary {'h_alpha_vec':ndarray},
341-
{'h0_alpha_vec':ndarray}, or {'hn_alpha_vec':ndarray}.
342-
They are obtained by ``get_h_params()`` of GenModel,
343-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
308+
hn_alpha_vec : numpy.ndarray
309+
a vector of positive real numbers
344310
"""
345-
if kwargs.keys() == self._H_PARAM_KEYS:
346-
self.hn_alpha_vec = _check.pos_float_vec(kwargs['h_alpha_vec'],'hn_alpha_vec',ParameterFormatError)
347-
elif kwargs.keys() == self._H0_PARAM_KEYS:
348-
self.hn_alpha_vec = _check.pos_float_vec(kwargs['h0_alpha_vec'],'hn_alpha_vec',ParameterFormatError)
349-
elif kwargs.keys() == self._HN_PARAM_KEYS:
350-
self.hn_alpha_vec = _check.pos_float_vec(kwargs['hn_alpha_vec'],'hn_alpha_vec',ParameterFormatError)
351-
else:
352-
raise(ParameterFormatError(
353-
"The input of this function must be a python dictionary with keys:"
354-
+str(self._H_PARAM_KEYS)+" or "
355-
+str(self._H0_PARAM_KEYS)+" or "
356-
+str(self._HN_PARAM_KEYS)+".")
357-
)
358-
311+
self.hn_alpha_vec = _check.pos_float_vec(hn_alpha_vec,'hn_alpha_vec',ParameterFormatError)
359312
self.degree = self.hn_alpha_vec.shape[0]
360313
self.calc_pred_dist()
361314

doc/devdoc/examples/h_params_test.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,44 @@
1-
from bayesml import bernoulli
1+
from bayesml import categorical as bayesml_model
2+
import numpy as np
23

3-
model = bernoulli.GenModel(h_alpha=0.5,h_beta=0.3)
4-
params = model.get_params()
5-
model.save_params('tmp.pkl')
6-
print(params)
4+
h0_params = {'h0_alpha_vec':np.ones(5)}
75

8-
model_2 = bernoulli.GenModel(theta=0.9,h_alpha=1.0,h_beta=1.0)
9-
print(model_2.get_params())
10-
# model_2.set_h_params(*h_params.values())
11-
# print(model_2.get_h_params())
6+
print('Gen to Learn 1')
7+
model = bayesml_model.GenModel()
8+
print(model.get_h_params())
9+
model.save_h_params('tmp.pkl')
1210

13-
model_2.load_params('tmp.pkl')
14-
print(model_2.get_params())
11+
model_2 = bayesml_model.LearnModel(**h0_params)
12+
print(model_2.get_h0_params())
13+
model_2.load_h0_params('tmp.pkl')
14+
print(model_2.get_h0_params())
15+
16+
print('Gen to Learn 2')
17+
model = bayesml_model.GenModel()
18+
print(model.get_h_params())
19+
model.save_h_params('tmp.pkl')
20+
21+
model_2 = bayesml_model.LearnModel(**h0_params)
22+
print(model_2.get_hn_params())
23+
model_2.load_hn_params('tmp.pkl')
24+
print(model_2.get_hn_params())
25+
26+
print('Learn to Gen 1')
27+
model_2 = bayesml_model.LearnModel(**h0_params)
28+
print(model_2.get_h0_params())
29+
model_2.save_h0_params('tmp.pkl')
30+
31+
model = bayesml_model.GenModel()
32+
print(model.get_h_params())
33+
model.load_h_params('tmp.pkl')
34+
print(model.get_h_params())
35+
36+
print('Learn to Gen 2')
37+
model_2 = bayesml_model.LearnModel(**h0_params)
38+
print(model_2.get_hn_params())
39+
model_2.save_hn_params('tmp.pkl')
40+
41+
model = bayesml_model.GenModel()
42+
print(model.get_h_params())
43+
model.load_h_params('tmp.pkl')
44+
print(model.get_h_params())

0 commit comments

Comments
 (0)