Skip to content

Commit 276b81a

Browse files
committed
revise hparams of bernoulli
1 parent d77f342 commit 276b81a

File tree

3 files changed

+61
-99
lines changed

3 files changed

+61
-99
lines changed

bayesml/base.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_h_params(self):
1717
def save_h_params(self,filename):
1818
"""Save the hyperparameters using python ``pickle`` module.
1919
20-
They are saved as a python dictionary obtained by ``get_h_params()``.
20+
They are saved as a python dictionary obtained by ``GenModel.get_h_params()``.
2121
2222
Parameters
2323
----------
@@ -43,8 +43,8 @@ def load_h_params(self,filename):
4343
filename : str
4444
The filename (including a extention like .pkl) to be loaded.
4545
It must be a pickled python dictionary obtained by
46-
``save_h_params()`` of GenModel, ``save_h0_params()`` of LearnModel
47-
or ``save_hn_params()`` of LearnModel.
46+
``GenModel.save_h_params()``, ``LearnModel.save_h0_params()``
47+
or ``LearnModel.save_hn_params()``.
4848
4949
Warnings
5050
--------
@@ -57,13 +57,13 @@ def load_h_params(self,filename):
5757
with open(filename, 'rb') as f:
5858
tmp_h_params = pickle.load(f)
5959
if type(tmp_h_params) is dict:
60-
self.set_h_params(**tmp_h_params)
60+
self.set_h_params(*tmp_h_params.values())
6161
return
6262

6363
raise(ParameterFormatError(
64-
filename+" must be a pickled python dictionary with "
65-
+str(self.get_h_params().keys())
66-
+", where `h_`, `h0_`, and `hn_` can be replaced to each other.")
64+
filename+" must be a pickled python dictionary obtained by "
65+
+"``GenModel.save_h_params()``, ``LearnModel.save_h0_params()`` "
66+
+'or ``LearnModel.save_hn_params()``.')
6767
)
6868

6969
@abstractmethod
@@ -81,7 +81,7 @@ def get_params(self):
8181
def save_params(self,filename):
8282
"""Save the parameters using python ``pickle`` module.
8383
84-
They are saved as a pickled python dictionary obtained by ``get_params()``.
84+
They are saved as a pickled python dictionary obtained by ``GenModel.get_params()``.
8585
8686
Parameters
8787
----------
@@ -106,7 +106,7 @@ def load_params(self,filename):
106106
----------
107107
filename : str
108108
The filename (including a extention like .pkl) to be loaded.
109-
It must be a pickled python dictionary with keys obtained by ``get_params().keys()``.
109+
It must be a pickled python dictionary obtained by ``GenModel.save_params()``.
110110
111111
Warnings
112112
--------
@@ -119,11 +119,10 @@ def load_params(self,filename):
119119
with open(filename, 'rb') as f:
120120
params = pickle.load(f)
121121
if type(params) is dict:
122-
if params.keys() == self.get_params().keys():
123-
self.set_params(**params)
124-
return
122+
self.set_params(*params.values())
123+
return
125124

126-
raise(ParameterFormatError(filename+" must be a pickled python dictionary with "+str(self.get_params().keys())))
125+
raise(ParameterFormatError(filename+" must be a pickled python dictionary obtained by ``GenModel.save_params()``"))
127126

128127
@abstractmethod
129128
def gen_sample(self):
@@ -149,7 +148,7 @@ def get_h0_params(self):
149148
def save_h0_params(self,filename):
150149
"""Save the hyperparameters using python ``pickle`` module.
151150
152-
They are saved as a pickled python dictionary obtained by ``get_h0_params()``.
151+
They are saved as a pickled python dictionary obtained by ``LearnModel.get_h0_params()``.
153152
154153
Parameters
155154
----------
@@ -175,8 +174,8 @@ def load_h0_params(self,filename):
175174
filename : str
176175
The filename (including a extention like .pkl) to be loaded.
177176
It must be a pickled python dictionary obtained by
178-
``save_h_params()`` of GenModel, ``save_h0_params()`` of LearnModel
179-
or ``save_hn_params()`` of LearnModel.
177+
``GenModel.save_h_params()``, ``LearnModel.save_h0_params()``
178+
or ``LearnModel.save_hn_params()``.
180179
181180
Warnings
182181
--------
@@ -189,13 +188,13 @@ def load_h0_params(self,filename):
189188
with open(filename, 'rb') as f:
190189
tmp_h_params = pickle.load(f)
191190
if type(tmp_h_params) is dict:
192-
self.set_h0_params(**tmp_h_params)
191+
self.set_h0_params(*tmp_h_params.values())
193192
return
194193

195194
raise(ParameterFormatError(
196-
filename+" must be a pickled python dictionary with "
197-
+str(self.get_h0_params().keys())
198-
+", where `h_`, `h0_`, and `hn_` can be replaced to each other.")
195+
filename+" must be a pickled python dictionary obtained by "
196+
+"``GenModel.save_h_params()``, ``LearnModel.save_h0_params()`` "
197+
+'or ``LearnModel.save_hn_params()``.')
199198
)
200199

201200
@abstractmethod
@@ -209,7 +208,7 @@ def get_hn_params(self):
209208
def save_hn_params(self,filename):
210209
"""Save the hyperparameters using python ``pickle`` module.
211210
212-
They are saved as a pickled python dictionary obtained by ``get_hn_params()``.
211+
They are saved as a pickled python dictionary obtained by ``LearnModel.get_hn_params()``.
213212
214213
Parameters
215214
----------
@@ -235,8 +234,8 @@ def load_hn_params(self,filename):
235234
filename : str
236235
The filename (including a extention like .pkl) to be loaded.
237236
It must be a pickled python dictionary obtained by
238-
``save_h_params()`` of GenModel, ``save_h0_params()`` of LearnModel
239-
or ``save_hn_params()`` of LearnModel.
237+
``GenModel.save_h_params()``, ``LearnModel.save_h0_params()``
238+
or ``LearnModel.save_hn_params()``.
240239
241240
Warnings
242241
--------
@@ -249,13 +248,13 @@ def load_hn_params(self,filename):
249248
with open(filename, 'rb') as f:
250249
tmp_h_params = pickle.load(f)
251250
if type(tmp_h_params) is dict:
252-
self.set_hn_params(**tmp_h_params)
251+
self.set_hn_params(*tmp_h_params.values())
253252
return
254253

255254
raise(ParameterFormatError(
256-
filename+" must be a pickled python dictionary with "
257-
+str(self.get_hn_params().keys())
258-
+", where `h_`, `h0_`, and `hn_` can be replaced to each other.")
255+
filename+" must be a pickled python dictionary obtained by "
256+
+"``GenModel.save_h_params()``, ``LearnModel.save_h0_params()`` "
257+
+'or ``LearnModel.save_hn_params()``.')
259258
)
260259

261260
@abstractmethod

bayesml/bernoulli/_bernoulli.py

Lines changed: 21 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -32,37 +32,19 @@ def __init__(self,*,theta=0.5,h_alpha=0.5,h_beta=0.5,seed=None):
3232
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
3333
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
3434
self.rng = np.random.default_rng(seed)
35-
self._H_PARAM_KEYS = {'h_alpha','h_beta'}
36-
self._H0_PARAM_KEYS = {'h0_alpha','h0_beta'}
37-
self._HN_PARAM_KEYS = {'hn_alpha','hn_beta'}
3835

39-
def set_h_params(self,**kwargs):
36+
def set_h_params(self,h_alpha,h_beta):
4037
"""Set the hyperparameters of the prior distribution.
4138
4239
Parameters
4340
----------
44-
**kwargs
45-
a python dictionary {'h_alpha':float, 'h_beta':float} or
46-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
47-
They are obtained by ``get_h_params()`` of GenModel,
48-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
41+
h_alpha : float
42+
a positive real number
43+
h_beta : float, optional
44+
a positibe real number
4945
"""
50-
if kwargs.keys() == self._H_PARAM_KEYS:
51-
self.h_alpha = _check.pos_float(kwargs['h_alpha'],'h_alpha',ParameterFormatError)
52-
self.h_beta = _check.pos_float(kwargs['h_beta'],'h_beta',ParameterFormatError)
53-
elif kwargs.keys() == self._H0_PARAM_KEYS:
54-
self.h_alpha = _check.pos_float(kwargs['h0_alpha'],'h_alpha',ParameterFormatError)
55-
self.h_beta = _check.pos_float(kwargs['h0_beta'],'h_beta',ParameterFormatError)
56-
elif kwargs.keys() == self._HN_PARAM_KEYS:
57-
self.h_alpha = _check.pos_float(kwargs['hn_alpha'],'h_alpha',ParameterFormatError)
58-
self.h_beta = _check.pos_float(kwargs['hn_beta'],'h_beta',ParameterFormatError)
59-
else:
60-
raise(ParameterFormatError(
61-
"The input of this function must be a python dictionary with keys:"
62-
+str(self._H_PARAM_KEYS)+" or "
63-
+str(self._H0_PARAM_KEYS)+" or "
64-
+str(self._HN_PARAM_KEYS)+".")
65-
)
46+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
47+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
6648

6749
def get_h_params(self):
6850
"""Get the hyperparameters of the prior distribution.
@@ -207,37 +189,19 @@ def __init__(self,h0_alpha=0.5,h0_beta=0.5):
207189
self.hn_alpha = self.h0_alpha
208190
self.hn_beta = self.h0_beta
209191
self.p_theta = self.hn_alpha / (self.hn_alpha + self.hn_beta)
210-
self._H_PARAM_KEYS = {'h_alpha','h_beta'}
211-
self._H0_PARAM_KEYS = {'h0_alpha','h0_beta'}
212-
self._HN_PARAM_KEYS = {'hn_alpha','hn_beta'}
213192

214-
def set_h0_params(self,**kwargs):
193+
def set_h0_params(self,h0_alpha,h0_beta):
215194
"""Set initial values of the hyperparameter of the posterior distribution.
216195
217196
Parameters
218197
----------
219-
**kwargs
220-
a python dictionary {'h_alpha':float, 'h_beta':float} or
221-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
222-
They are obtained by ``get_h_params()`` of GenModel,
223-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
198+
h0_alpha : float
199+
a positive real number
200+
h0_beta : float
201+
a positibe real number
224202
"""
225-
if kwargs.keys() == self._H_PARAM_KEYS:
226-
self.h0_alpha = _check.pos_float(kwargs['h_alpha'],'h0_alpha',ParameterFormatError)
227-
self.h0_beta = _check.pos_float(kwargs['h_beta'],'h0_beta',ParameterFormatError)
228-
elif kwargs.keys() == self._H0_PARAM_KEYS:
229-
self.h0_alpha = _check.pos_float(kwargs['h0_alpha'],'h0_alpha',ParameterFormatError)
230-
self.h0_beta = _check.pos_float(kwargs['h0_beta'],'h0_beta',ParameterFormatError)
231-
elif kwargs.keys() == self._HN_PARAM_KEYS:
232-
self.h0_alpha = _check.pos_float(kwargs['hn_alpha'],'h0_alpha',ParameterFormatError)
233-
self.h0_beta = _check.pos_float(kwargs['hn_beta'],'h0_beta',ParameterFormatError)
234-
else:
235-
raise(ParameterFormatError(
236-
"The input of this function must be a python dictionary with keys:"
237-
+str(self._H_PARAM_KEYS)+" or "
238-
+str(self._H0_PARAM_KEYS)+" or "
239-
+str(self._HN_PARAM_KEYS)+".")
240-
)
203+
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
204+
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
241205
self.reset_hn_params()
242206

243207
def get_h0_params(self):
@@ -251,33 +215,18 @@ def get_h0_params(self):
251215
"""
252216
return {"h0_alpha":self.h0_alpha, "h0_beta":self.h0_beta}
253217

254-
def set_hn_params(self,**kwargs):
218+
def set_hn_params(self,hn_alpha,hn_beta):
255219
"""Set updated values of the hyperparameter of the posterior distribution.
256220
257221
Parameters
258222
----------
259-
**kwargs
260-
a python dictionary {'h_alpha':float, 'h_beta':float} or
261-
{'h0_alpha':float, 'h0_beta':float} or {'hn_alpha':float, 'hn_beta':float}
262-
They are obtained by ``get_h_params()`` of GenModel,
263-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
223+
h0_alpha : float
224+
a positive real number
225+
h0_beta : float
226+
a positibe real number
264227
"""
265-
if kwargs.keys() == self._H_PARAM_KEYS:
266-
self.hn_alpha = _check.pos_float(kwargs['h_alpha'],'hn_alpha',ParameterFormatError)
267-
self.hn_beta = _check.pos_float(kwargs['h_beta'],'hn_beta',ParameterFormatError)
268-
elif kwargs.keys() == self._H0_PARAM_KEYS:
269-
self.hn_alpha = _check.pos_float(kwargs['h0_alpha'],'hn_alpha',ParameterFormatError)
270-
self.hn_beta = _check.pos_float(kwargs['h0_beta'],'hn_beta',ParameterFormatError)
271-
elif kwargs.keys() == self._HN_PARAM_KEYS:
272-
self.hn_alpha = _check.pos_float(kwargs['hn_alpha'],'hn_alpha',ParameterFormatError)
273-
self.hn_beta = _check.pos_float(kwargs['hn_beta'],'hn_beta',ParameterFormatError)
274-
else:
275-
raise(ParameterFormatError(
276-
"The input of this function must be a python dictionary with keys:"
277-
+str(self._H_PARAM_KEYS)+" or "
278-
+str(self._H0_PARAM_KEYS)+" or "
279-
+str(self._HN_PARAM_KEYS)+".")
280-
)
228+
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
229+
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
281230
self.calc_pred_dist()
282231

283232
def get_hn_params(self):

doc/devdoc/examples/h_params_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from bayesml import bernoulli
2+
3+
model = bernoulli.GenModel(h_alpha=0.5,h_beta=0.3)
4+
params = model.get_params()
5+
model.save_params('tmp.pkl')
6+
print(params)
7+
8+
model_2 = bernoulli.GenModel(theta=0.9,h_alpha=1.0,h_beta=1.0)
9+
print(model_2.get_params())
10+
# model_2.set_h_params(*h_params.values())
11+
# print(model_2.get_h_params())
12+
13+
model_2.load_params('tmp.pkl')
14+
print(model_2.get_params())

0 commit comments

Comments
 (0)