Skip to content

Commit cb97d6d

Browse files
Merge pull request #19 from yuta-nakahara/develop-revise_hparams
Revise hparams
2 parents d77f342 + 1aaa6fc commit cb97d6d

File tree

13 files changed

+327
-790
lines changed

13 files changed

+327
-790
lines changed

bayesml/autoregressive/_autoregressive.py

Lines changed: 39 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -114,44 +114,25 @@ def __init__(
114114
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
115115
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
116116
self.rng = np.random.default_rng(seed)
117-
self._H_PARAM_KEYS = {'h_mu_vec','h_lambda_mat','h_alpha','h_beta'}
118-
self._H0_PARAM_KEYS = {'h0_mu_vec','h0_lambda_mat','h0_alpha','h0_beta'}
119-
self._HN_PARAM_KEYS = {'hn_mu_vec','hn_lambda_mat','hn_alpha','hn_beta'}
120117

121-
def set_h_params(self,**kwargs):
118+
def set_h_params(self,h_mu_vec,h_lambda_mat,h_alpha,h_beta):
122119
"""Set the hyperparameters of the prior distribution.
123120
124121
Parameters
125122
----------
126-
**kwargs
127-
a python dictionary {'h_mu_vec':ndarray, 'h_lambda_mat':ndarray, 'h_alpha':float, 'h_beta':float} or
128-
{'h0_mu_vec':ndarray, 'h0_lambda_mat':ndarray, 'h0_alpha':float, 'h0_beta':float}
129-
or {'hn_mu_vec':ndarray, 'hn_lambda_mat':ndarray, 'hn_alpha':float, 'hn_beta':float}
130-
They are obtained by ``get_h_params()`` of GenModel,
131-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
123+
h_mu_vec : numpy ndarray
124+
a vector of real numbers
125+
h_lambda_mat : numpy ndarray
126+
a positibe definate matrix
127+
h_alpha : float
128+
a positive real number
129+
h_beta : float
130+
a positibe real number
132131
"""
133-
if kwargs.keys() == self._H_PARAM_KEYS:
134-
self.h_mu_vec = _check.float_vec(kwargs['h_mu_vec'],'h_mu_vec',ParameterFormatError)
135-
self.h_lambda_mat = _check.pos_def_sym_mat(kwargs['h_lambda_mat'],'h_lambda_mat',ParameterFormatError)
136-
self.h_alpha = _check.pos_float(kwargs['h_alpha'],'h_alpha',ParameterFormatError)
137-
self.h_beta = _check.pos_float(kwargs['h_beta'],'h_beta',ParameterFormatError)
138-
elif kwargs.keys() == self._H0_PARAM_KEYS:
139-
self.h_mu_vec = _check.float_vec(kwargs['h0_mu_vec'],'h_mu_vec',ParameterFormatError)
140-
self.h_lambda_mat = _check.pos_def_sym_mat(kwargs['h0_lambda_mat'],'h_lambda_mat',ParameterFormatError)
141-
self.h_alpha = _check.pos_float(kwargs['h0_alpha'],'h_alpha',ParameterFormatError)
142-
self.h_beta = _check.pos_float(kwargs['h0_beta'],'h_beta',ParameterFormatError)
143-
elif kwargs.keys() == self._HN_PARAM_KEYS:
144-
self.h_mu_vec = _check.float_vec(kwargs['hn_mu_vec'],'h_mu_vec',ParameterFormatError)
145-
self.h_lambda_mat = _check.pos_def_sym_mat(kwargs['hn_lambda_mat'],'h_lambda_mat',ParameterFormatError)
146-
self.h_alpha = _check.pos_float(kwargs['hn_alpha'],'h_alpha',ParameterFormatError)
147-
self.h_beta = _check.pos_float(kwargs['hn_beta'],'h_beta',ParameterFormatError)
148-
else:
149-
raise(ParameterFormatError(
150-
"The input of this function must be a python dictionary with keys:"
151-
+str(self._H_PARAM_KEYS)+" or "
152-
+str(self._H0_PARAM_KEYS)+" or "
153-
+str(self._HN_PARAM_KEYS)+".")
154-
)
132+
self.h_mu_vec = _check.float_vec(h_mu_vec,'h_mu_vec',ParameterFormatError)
133+
self.h_lambda_mat = _check.pos_def_sym_mat(h_lambda_mat,'h_lambda_mat',ParameterFormatError)
134+
self.h_alpha = _check.pos_float(h_alpha,'h_alpha',ParameterFormatError)
135+
self.h_beta = _check.pos_float(h_beta,'h_beta',ParameterFormatError)
155136

156137
if (self.h_mu_vec.shape[0] != self.h_lambda_mat.shape[0]):
157138
raise(ParameterFormatError(
@@ -397,47 +378,27 @@ def __init__(
397378
self.p_lambda = self.hn_alpha / self.hn_beta / (1.0 + _explanatory_vec @ np.linalg.solve(self.hn_lambda_mat,_explanatory_vec))
398379
self.p_nu = 2.0 * self.hn_alpha
399380

400-
self._H_PARAM_KEYS = {'h_mu_vec','h_lambda_mat','h_alpha','h_beta'}
401-
self._H0_PARAM_KEYS = {'h0_mu_vec','h0_lambda_mat','h0_alpha','h0_beta'}
402-
self._HN_PARAM_KEYS = {'hn_mu_vec','hn_lambda_mat','hn_alpha','hn_beta'}
403-
404-
def set_h0_params(self,**kwargs):
381+
def set_h0_params(self,h0_mu_vec,h0_lambda_mat,h0_alpha,h0_beta):
405382
"""Set initial values of the hyperparameter of the posterior distribution.
406383
407384
Note that the parameters of the predictive distribution are also calculated from
408385
``self.h0_mu_vec``, ``slef.h0_lambda_mat``, ``self.h0_alpha`` and ``self.h0_beta``.
409386
410387
Parameters
411388
----------
412-
**kwargs
413-
a python dictionary {'h_mu_vec':ndarray, 'h_lambda_mat':ndarray, 'h_alpha':float, 'h_beta':float} or
414-
{'h0_mu_vec':ndarray, 'h0_lambda_mat':ndarray, 'h0_alpha':float, 'h0_beta':float}
415-
or {'hn_mu_vec':ndarray, 'hn_lambda_mat':ndarray, 'hn_alpha':float, 'hn_beta':float}
416-
They are obtained by ``get_h_params()`` of GenModel,
417-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
389+
h0_mu_vec : numpy ndarray
390+
a vector of real numbers
391+
h0_lambda_mat : numpy ndarray
392+
a positibe definate matrix
393+
h0_alpha : float
394+
a positive real number
395+
h0_beta : float
396+
a positibe real number
418397
"""
419-
if kwargs.keys() == self._H_PARAM_KEYS:
420-
self.h0_mu_vec = _check.float_vec(kwargs['h_mu_vec'],'h0_mu_vec',ParameterFormatError)
421-
self.h0_lambda_mat = _check.pos_def_sym_mat(kwargs['h_lambda_mat'],'h0_lambda_mat',ParameterFormatError)
422-
self.h0_alpha = _check.pos_float(kwargs['h_alpha'],'h0_alpha',ParameterFormatError)
423-
self.h0_beta = _check.pos_float(kwargs['h_beta'],'h0_beta',ParameterFormatError)
424-
elif kwargs.keys() == self._H0_PARAM_KEYS:
425-
self.h0_mu_vec = _check.float_vec(kwargs['h0_mu_vec'],'h0_mu_vec',ParameterFormatError)
426-
self.h0_lambda_mat = _check.pos_def_sym_mat(kwargs['h0_lambda_mat'],'h0_lambda_mat',ParameterFormatError)
427-
self.h0_alpha = _check.pos_float(kwargs['h0_alpha'],'h0_alpha',ParameterFormatError)
428-
self.h0_beta = _check.pos_float(kwargs['h0_beta'],'h0_beta',ParameterFormatError)
429-
elif kwargs.keys() == self._HN_PARAM_KEYS:
430-
self.h0_mu_vec = _check.float_vec(kwargs['hn_mu_vec'],'h0_mu_vec',ParameterFormatError)
431-
self.h0_lambda_mat = _check.pos_def_sym_mat(kwargs['hn_lambda_mat'],'h0_lambda_mat',ParameterFormatError)
432-
self.h0_alpha = _check.pos_float(kwargs['hn_alpha'],'h0_alpha',ParameterFormatError)
433-
self.h0_beta = _check.pos_float(kwargs['hn_beta'],'h0_beta',ParameterFormatError)
434-
else:
435-
raise(ParameterFormatError(
436-
"The input of this function must be a python dictionary with keys:"
437-
+str(self._H_PARAM_KEYS)+" or "
438-
+str(self._H0_PARAM_KEYS)+" or "
439-
+str(self._HN_PARAM_KEYS)+".")
440-
)
398+
self.h0_mu_vec = _check.float_vec(h0_mu_vec,'h0_mu_vec',ParameterFormatError)
399+
self.h0_lambda_mat = _check.pos_def_sym_mat(h0_lambda_mat,'h0_lambda_mat',ParameterFormatError)
400+
self.h0_alpha = _check.pos_float(h0_alpha,'h0_alpha',ParameterFormatError)
401+
self.h0_beta = _check.pos_float(h0_beta,'h0_beta',ParameterFormatError)
441402

442403
self.degree = self.h0_mu_vec.shape[0]-1
443404
if (self.h0_mu_vec.shape[0] != self.h0_lambda_mat.shape[0]):
@@ -472,43 +433,27 @@ def get_hn_params(self):
472433
"""
473434
return {"hn_mu_vec":self.hn_mu_vec, "hn_lambda_mat":self.hn_lambda_mat, "hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
474435

475-
def set_hn_params(self,**kwargs):
436+
def set_hn_params(self,hn_mu_vec,hn_lambda_mat,hn_alpha,hn_beta):
476437
"""Set updated values of the hyperparameter of the posterior distribution.
477438
478439
Note that the parameters of the predictive distribution are also calculated from
479440
``self.hn_mu_vec``, ``slef.hn_lambda_mat``, ``self.hn_alpha`` and ``self.hn_beta``.
480441
481442
Parameters
482443
----------
483-
**kwargs
484-
a python dictionary {'h_mu_vec':ndarray, 'h_lambda_mat':ndarray, 'h_alpha':float, 'h_beta':float} or
485-
{'h0_mu_vec':ndarray, 'h0_lambda_mat':ndarray, 'h0_alpha':float, 'h0_beta':float}
486-
or {'hn_mu_vec':ndarray, 'hn_lambda_mat':ndarray, 'hn_alpha':float, 'hn_beta':float}
487-
They are obtained by ``get_h_params()`` of GenModel,
488-
``get_h0_params`` of LearnModel or ``get_hn_params`` of LearnModel.
444+
hn_mu_vec : numpy ndarray
445+
a vector of real numbers
446+
hn_lambda_mat : numpy ndarray
447+
a positibe definate matrix
448+
hn_alpha : float
449+
a positive real number
450+
hn_beta : float
451+
a positibe real number
489452
"""
490-
if kwargs.keys() == self._H_PARAM_KEYS:
491-
self.hn_mu_vec = _check.float_vec(kwargs['h_mu_vec'],'hn_mu_vec',ParameterFormatError)
492-
self.hn_lambda_mat = _check.pos_def_sym_mat(kwargs['h_lambda_mat'],'hn_lambda_mat',ParameterFormatError)
493-
self.hn_alpha = _check.pos_float(kwargs['h_alpha'],'hn_alpha',ParameterFormatError)
494-
self.hn_beta = _check.pos_float(kwargs['h_beta'],'hn_beta',ParameterFormatError)
495-
elif kwargs.keys() == self._H0_PARAM_KEYS:
496-
self.hn_mu_vec = _check.float_vec(kwargs['h0_mu_vec'],'hn_mu_vec',ParameterFormatError)
497-
self.hn_lambda_mat = _check.pos_def_sym_mat(kwargs['h0_lambda_mat'],'hn_lambda_mat',ParameterFormatError)
498-
self.hn_alpha = _check.pos_float(kwargs['h0_alpha'],'hn_alpha',ParameterFormatError)
499-
self.hn_beta = _check.pos_float(kwargs['h0_beta'],'hn_beta',ParameterFormatError)
500-
elif kwargs.keys() == self._HN_PARAM_KEYS:
501-
self.hn_mu_vec = _check.float_vec(kwargs['hn_mu_vec'],'hn_mu_vec',ParameterFormatError)
502-
self.hn_lambda_mat = _check.pos_def_sym_mat(kwargs['hn_lambda_mat'],'hn_lambda_mat',ParameterFormatError)
503-
self.hn_alpha = _check.pos_float(kwargs['hn_alpha'],'hn_alpha',ParameterFormatError)
504-
self.hn_beta = _check.pos_float(kwargs['hn_beta'],'hn_beta',ParameterFormatError)
505-
else:
506-
raise(ParameterFormatError(
507-
"The input of this function must be a python dictionary with keys:"
508-
+str(self._H_PARAM_KEYS)+" or "
509-
+str(self._H0_PARAM_KEYS)+" or "
510-
+str(self._HN_PARAM_KEYS)+".")
511-
)
453+
self.hn_mu_vec = _check.float_vec(hn_mu_vec,'hn_mu_vec',ParameterFormatError)
454+
self.hn_lambda_mat = _check.pos_def_sym_mat(hn_lambda_mat,'hn_lambda_mat',ParameterFormatError)
455+
self.hn_alpha = _check.pos_float(hn_alpha,'hn_alpha',ParameterFormatError)
456+
self.hn_beta = _check.pos_float(hn_beta,'hn_beta',ParameterFormatError)
512457

513458
self.degree = self.hn_mu_vec.shape[0]-1
514459
if (self.hn_mu_vec.shape[0] != self.hn_lambda_mat.shape[0]):

bayesml/base.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_h_params(self):
1717
def save_h_params(self,filename):
1818
"""Save the hyperparameters using python ``pickle`` module.
1919
20-
They are saved as a python dictionary obtained by ``get_h_params()``.
20+
They are saved as a python dictionary obtained by ``GenModel.get_h_params()``.
2121
2222
Parameters
2323
----------
@@ -43,8 +43,8 @@ def load_h_params(self,filename):
4343
filename : str
4444
The filename (including a extention like .pkl) to be loaded.
4545
It must be a pickled python dictionary obtained by
46-
``save_h_params()`` of GenModel, ``save_h0_params()`` of LearnModel
47-
or ``save_hn_params()`` of LearnModel.
46+
``GenModel.save_h_params()``, ``LearnModel.save_h0_params()``
47+
or ``LearnModel.save_hn_params()``.
4848
4949
Warnings
5050
--------
@@ -57,13 +57,13 @@ def load_h_params(self,filename):
5757
with open(filename, 'rb') as f:
5858
tmp_h_params = pickle.load(f)
5959
if type(tmp_h_params) is dict:
60-
self.set_h_params(**tmp_h_params)
60+
self.set_h_params(*tmp_h_params.values())
6161
return
6262

6363
raise(ParameterFormatError(
64-
filename+" must be a pickled python dictionary with "
65-
+str(self.get_h_params().keys())
66-
+", where `h_`, `h0_`, and `hn_` can be replaced to each other.")
64+
filename+" must be a pickled python dictionary obtained by "
65+
+"``GenModel.save_h_params()``, ``LearnModel.save_h0_params()`` "
66+
+'or ``LearnModel.save_hn_params()``.')
6767
)
6868

6969
@abstractmethod
@@ -81,7 +81,7 @@ def get_params(self):
8181
def save_params(self,filename):
8282
"""Save the parameters using python ``pickle`` module.
8383
84-
They are saved as a pickled python dictionary obtained by ``get_params()``.
84+
They are saved as a pickled python dictionary obtained by ``GenModel.get_params()``.
8585
8686
Parameters
8787
----------
@@ -106,7 +106,7 @@ def load_params(self,filename):
106106
----------
107107
filename : str
108108
The filename (including a extention like .pkl) to be loaded.
109-
It must be a pickled python dictionary with keys obtained by ``get_params().keys()``.
109+
It must be a pickled python dictionary obtained by ``GenModel.save_params()``.
110110
111111
Warnings
112112
--------
@@ -119,11 +119,10 @@ def load_params(self,filename):
119119
with open(filename, 'rb') as f:
120120
params = pickle.load(f)
121121
if type(params) is dict:
122-
if params.keys() == self.get_params().keys():
123-
self.set_params(**params)
124-
return
122+
self.set_params(*params.values())
123+
return
125124

126-
raise(ParameterFormatError(filename+" must be a pickled python dictionary with "+str(self.get_params().keys())))
125+
raise(ParameterFormatError(filename+" must be a pickled python dictionary obtained by ``GenModel.save_params()``"))
127126

128127
@abstractmethod
129128
def gen_sample(self):
@@ -149,7 +148,7 @@ def get_h0_params(self):
149148
def save_h0_params(self,filename):
150149
"""Save the hyperparameters using python ``pickle`` module.
151150
152-
They are saved as a pickled python dictionary obtained by ``get_h0_params()``.
151+
They are saved as a pickled python dictionary obtained by ``LearnModel.get_h0_params()``.
153152
154153
Parameters
155154
----------
@@ -175,8 +174,8 @@ def load_h0_params(self,filename):
175174
filename : str
176175
The filename (including a extention like .pkl) to be loaded.
177176
It must be a pickled python dictionary obtained by
178-
``save_h_params()`` of GenModel, ``save_h0_params()`` of LearnModel
179-
or ``save_hn_params()`` of LearnModel.
177+
``GenModel.save_h_params()``, ``LearnModel.save_h0_params()``
178+
or ``LearnModel.save_hn_params()``.
180179
181180
Warnings
182181
--------
@@ -189,13 +188,13 @@ def load_h0_params(self,filename):
189188
with open(filename, 'rb') as f:
190189
tmp_h_params = pickle.load(f)
191190
if type(tmp_h_params) is dict:
192-
self.set_h0_params(**tmp_h_params)
191+
self.set_h0_params(*tmp_h_params.values())
193192
return
194193

195194
raise(ParameterFormatError(
196-
filename+" must be a pickled python dictionary with "
197-
+str(self.get_h0_params().keys())
198-
+", where `h_`, `h0_`, and `hn_` can be replaced to each other.")
195+
filename+" must be a pickled python dictionary obtained by "
196+
+"``GenModel.save_h_params()``, ``LearnModel.save_h0_params()`` "
197+
+'or ``LearnModel.save_hn_params()``.')
199198
)
200199

201200
@abstractmethod
@@ -209,7 +208,7 @@ def get_hn_params(self):
209208
def save_hn_params(self,filename):
210209
"""Save the hyperparameters using python ``pickle`` module.
211210
212-
They are saved as a pickled python dictionary obtained by ``get_hn_params()``.
211+
They are saved as a pickled python dictionary obtained by ``LearnModel.get_hn_params()``.
213212
214213
Parameters
215214
----------
@@ -235,8 +234,8 @@ def load_hn_params(self,filename):
235234
filename : str
236235
The filename (including a extention like .pkl) to be loaded.
237236
It must be a pickled python dictionary obtained by
238-
``save_h_params()`` of GenModel, ``save_h0_params()`` of LearnModel
239-
or ``save_hn_params()`` of LearnModel.
237+
``GenModel.save_h_params()``, ``LearnModel.save_h0_params()``
238+
or ``LearnModel.save_hn_params()``.
240239
241240
Warnings
242241
--------
@@ -249,13 +248,13 @@ def load_hn_params(self,filename):
249248
with open(filename, 'rb') as f:
250249
tmp_h_params = pickle.load(f)
251250
if type(tmp_h_params) is dict:
252-
self.set_hn_params(**tmp_h_params)
251+
self.set_hn_params(*tmp_h_params.values())
253252
return
254253

255254
raise(ParameterFormatError(
256-
filename+" must be a pickled python dictionary with "
257-
+str(self.get_hn_params().keys())
258-
+", where `h_`, `h0_`, and `hn_` can be replaced to each other.")
255+
filename+" must be a pickled python dictionary obtained by "
256+
+"``GenModel.save_h_params()``, ``LearnModel.save_h0_params()`` "
257+
+'or ``LearnModel.save_hn_params()``.')
259258
)
260259

261260
@abstractmethod

0 commit comments

Comments
 (0)