Skip to content

Commit 36c0863

Browse files
Merge pull request #33 from yuta-nakahara/develop-check
Develop shape_consistency
2 parents d9e6699 + 69d052a commit 36c0863

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

bayesml/_check.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Code Author
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4+
# Yasushi Esaki <esakiful@gmail.com>
5+
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
46
import numpy as np
57

68
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -100,6 +102,22 @@ def pos_def_sym_mat(val,val_name,exception_class):
100102
pass
101103
raise(exception_class(val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray."))
102104

105+
def sym_mats(val,val_name,exception_class):
106+
if type(val) is np.ndarray:
107+
if val.ndim >= 2 and val.shape[-1] == val.shape[-2]:
108+
if np.allclose(val, np.swapaxes(val,-1,-2)):
109+
return val
110+
raise(exception_class(val_name + " must be a symmetric 2-dimensional numpy.ndarray."))
111+
112+
def pos_def_sym_mats(val,val_name,exception_class):
113+
sym_mats(val,val_name,exception_class)
114+
try:
115+
np.linalg.cholesky(val)
116+
return val
117+
except np.linalg.LinAlgError:
118+
pass
119+
raise(exception_class(val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray."))
120+
103121
def float_(val,val_name,exception_class):
104122
if np.issubdtype(type(val),np.floating):
105123
return val
@@ -163,6 +181,14 @@ def float_vec_sum_1(val,val_name,exception_class):
163181
return val
164182
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
165183

184+
def float_vecs_sum_1(val,val_name,exception_class):
185+
if type(val) is np.ndarray:
186+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
187+
return val.astype(float)
188+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
189+
return val
190+
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1."))
191+
166192
def int_(val,val_name,exception_class):
167193
if np.issubdtype(type(val),np.integer):
168194
return val
@@ -189,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
189215
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val >= 0) and np.all(val.sum(axis=-1)==1):
190216
return val
191217
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
218+
219+
def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
220+
if val != correct:
221+
message = (f"{val_name} must coincide with {correct_name}: "
222+
+ f"{val_name} = {val}, {correct_name} = {correct}")
223+
raise(exception_class(message))

0 commit comments

Comments
 (0)