Skip to content

Commit 7ee15bb

Browse files
Merge pull request #41 from yuta-nakahara/develop-contexttree-GenModel
Develop contexttree gen model
2 parents 0109468 + d89d42b commit 7ee15bb

14 files changed

+4796
-42
lines changed

bayesml/_check.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Code Author
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4+
# Yasushi Esaki <esakiful@gmail.com>
5+
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
46
import numpy as np
57

68
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -179,6 +181,14 @@ def float_vec_sum_1(val,val_name,exception_class):
179181
return val
180182
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
181183

184+
def float_vecs_sum_1(val,val_name,exception_class):
185+
if type(val) is np.ndarray:
186+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
187+
return val.astype(float)
188+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
189+
return val
190+
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1."))
191+
182192
def int_(val,val_name,exception_class):
183193
if np.issubdtype(type(val),np.integer):
184194
return val
@@ -205,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
205215
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val >= 0) and np.all(val.sum(axis=-1)==1):
206216
return val
207217
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
218+
219+
def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
220+
if val != correct:
221+
message = (f"{val_name} must coincide with {correct_name}: "
222+
+ f"{val_name} = {val}, {correct_name} = {correct}")
223+
raise(exception_class(message))

bayesml/base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,13 +257,21 @@ def load_hn_params(self,filename):
257257
+'or ``LearnModel.save_hn_params()``.')
258258
)
259259

260-
@abstractmethod
261260
def reset_hn_params(self):
262-
pass
263-
264-
@abstractmethod
261+
"""Reset the hyperparameters of the posterior distribution to their initial values.
262+
263+
They are reset to the output of `self.get_h0_params()`.
264+
Note that the parameters of the predictive distribution are also calculated from them.
265+
"""
266+
self.set_hn_params(*self.get_h0_params().values())
267+
265268
def overwrite_h0_params(self):
266-
pass
269+
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
270+
271+
They are overwitten by the output of `self.get_hn_params()`.
272+
Note that the parameters of the predictive distribution are also calculated from them.
273+
"""
274+
self.set_h0_params(*self.get_hn_params().values())
267275

268276
@abstractmethod
269277
def update_posterior(self):

bayesml/contexttree/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from ._contexttree import GenModel
2+
# from ._contexttree import LearnModel
3+
4+
__all__ = ["GenModel"]#, "LearnModel"]

0 commit comments

Comments
 (0)