Skip to content

Commit 8bfd21a

Browse files
Merge pull request #30 from yuta-nakahara/develop-hiddenmarkovnormal-initgetset
Develop hiddenmarkovnormal initgetset
2 parents 4b83991 + fd91b6b commit 8bfd21a

File tree

9 files changed

+863
-11
lines changed

9 files changed

+863
-11
lines changed

bayesml/_check.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Code Author
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4+
# Yasushi Esaki <esakiful@gmail.com>
5+
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
46
import numpy as np
57

68
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -51,6 +53,12 @@ def nonneg_int_vec(val,val_name,exception_class):
5153
return val
5254
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray whose dtype is int. Its values must be non-negative (including 0)."))
5355

56+
def nonneg_float_vec(val,val_name,exception_class):
57+
if type(val) is np.ndarray:
58+
if np.issubdtype(val.dtype,np.floating) and val.ndim == 1 and np.all(val>=0):
59+
return val
60+
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray whose dtype is float. Its values must be non-negative (including 0)."))
61+
5462
def int_of_01(val,val_name,exception_class):
5563
if np.issubdtype(type(val),np.integer):
5664
if val == 0 or val ==1:
@@ -171,13 +179,30 @@ def float_vecs(val,val_name,exception_class):
171179
return val
172180
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1."))
173181

174-
def float_vec_sum_1(val,val_name,exception_class):
182+
def pos_float_vecs(val,val_name,exception_class):
175183
if type(val) is np.ndarray:
176-
if np.issubdtype(val.dtype,np.integer) and val.ndim == 1 and abs(val.sum() - 1.) <= _EPSILON:
184+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val>0):
177185
return val.astype(float)
178-
if np.issubdtype(val.dtype,np.floating) and val.ndim == 1 and abs(val.sum() - 1.) <= _EPSILON:
186+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(val>0.0):
179187
return val
180-
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
188+
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray. Its values must be positive (not including 0)"))
189+
190+
def float_vec_sum_1(val,val_name,exception_class,ndim=1,sum_axis=0):
191+
if type(val) is np.ndarray:
192+
sum_val = np.sum(val, axis=sum_axis)
193+
if np.issubdtype(val.dtype,np.integer) and val.ndim == ndim and abs(sum_val.sum() - np.prod(sum_val.shape)) <= _EPSILON:
194+
return val.astype(float)
195+
if np.issubdtype(val.dtype,np.floating) and val.ndim == ndim and abs(sum_val.sum() - np.prod(sum_val.shape)) <= _EPSILON:
196+
return val
197+
raise(exception_class(val_name + f" must be a {ndim}-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
198+
199+
def float_vecs_sum_1(val,val_name,exception_class):
200+
if type(val) is np.ndarray:
201+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
202+
return val.astype(float)
203+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
204+
return val
205+
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1."))
181206

182207
def int_(val,val_name,exception_class):
183208
if np.issubdtype(type(val),np.integer):
@@ -205,3 +230,9 @@ def onehot_vecs(val,val_name,exception_class):
205230
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val >= 0) and np.all(val.sum(axis=-1)==1):
206231
return val
207232
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
233+
234+
def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
235+
if val != correct:
236+
message = (f"{val_name} must coincide with {correct_name}: "
237+
+ f"{val_name} = {val}, {correct_name} = {correct}")
238+
raise(exception_class(message))

bayesml/autoregressive/_autoregressive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ class GenModel(base.Generative):
3030
h_mu_vec : numpy ndarray, optional
3131
a vector of real numbers, by default [0.0, 0.0, ... , 0.0]
3232
h_lambda_mat : numpy ndarray, optional
33-
a positibe definate matrix, by default the identity matrix
33+
a positive definate matrix, by default the identity matrix
3434
h_alpha : float, optional
3535
a positive real number, by default 1.0
3636
h_beta : float, optional
37-
a positibe real number, by default 1.0
37+
a positive real number, by default 1.0
3838
seed : {None, int}, optional
3939
A seed to initialize numpy.random.default_rng(),
4040
by default None

bayesml/base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,13 +257,21 @@ def load_hn_params(self,filename):
257257
+'or ``LearnModel.save_hn_params()``.')
258258
)
259259

260-
@abstractmethod
261260
def reset_hn_params(self):
262-
pass
263-
264-
@abstractmethod
261+
"""Reset the hyperparameters of the posterior distribution to their initial values.
262+
263+
They are reset to the output of `self.get_h0_params()`.
264+
Note that the parameters of the predictive distribution are also calculated from them.
265+
"""
266+
self.set_hn_params(*self.get_h0_params().values())
267+
265268
def overwrite_h0_params(self):
266-
pass
269+
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
270+
271+
They are overwitten by the output of `self.get_hn_params()`.
272+
Note that the parameters of the predictive distribution are also calculated from them.
273+
"""
274+
self.set_h0_params(*self.get_hn_params().values())
267275

268276
@abstractmethod
269277
def update_posterior(self):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Document Author
2+
# Koki Kazama <kokikazama@aoni.waseda.jp>
3+
# Jun Nishikawa <Jun.B.Nishikawa@gmail.com>
4+
5+
from ._hiddenmarkovautoregressive import GenModel
6+
# from ._hiddenmarkovautoregressive import LearnModel
7+
8+
# __all__ = ["GenModel","LearnModel"]
9+
__all__ = ["GenModel"]

0 commit comments

Comments
 (0)