1
1
# Code Author
2
2
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
3
3
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4
+ # Yasushi Esaki <esakiful@gmail.com>
5
+ # Jun Nishikawa <jun.b.nishikawa@gmail.com>
4
6
import numpy as np
5
7
6
8
_EPSILON = np .sqrt (np .finfo (np .float64 ).eps )
@@ -100,6 +102,22 @@ def pos_def_sym_mat(val,val_name,exception_class):
100
102
pass
101
103
raise (exception_class (val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray." ))
102
104
105
+ def sym_mats (val ,val_name ,exception_class ):
106
+ if type (val ) is np .ndarray :
107
+ if val .ndim >= 2 and val .shape [- 1 ] == val .shape [- 2 ]:
108
+ if np .allclose (val , np .swapaxes (val ,- 1 ,- 2 )):
109
+ return val
110
+ raise (exception_class (val_name + " must be a symmetric 2-dimensional numpy.ndarray." ))
111
+
112
+ def pos_def_sym_mats (val ,val_name ,exception_class ):
113
+ sym_mats (val ,val_name ,exception_class )
114
+ try :
115
+ np .linalg .cholesky (val )
116
+ return val
117
+ except np .linalg .LinAlgError :
118
+ pass
119
+ raise (exception_class (val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray." ))
120
+
103
121
def float_ (val ,val_name ,exception_class ):
104
122
if np .issubdtype (type (val ),np .floating ):
105
123
return val
@@ -163,6 +181,14 @@ def float_vec_sum_1(val,val_name,exception_class):
163
181
return val
164
182
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1." ))
165
183
184
+ def float_vecs_sum_1 (val ,val_name ,exception_class ):
185
+ if type (val ) is np .ndarray :
186
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (np .abs (np .sum (val , axis = - 1 ) - 1. ) <= _EPSILON ):
187
+ return val .astype (float )
188
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim >= 1 and np .all (np .abs (np .sum (val , axis = - 1 ) - 1. ) <= _EPSILON ):
189
+ return val
190
+ raise (exception_class (val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1." ))
191
+
166
192
def int_ (val ,val_name ,exception_class ):
167
193
if np .issubdtype (type (val ),np .integer ):
168
194
return val
@@ -189,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
189
215
if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (val >= 0 ) and np .all (val .sum (axis = - 1 )== 1 ):
190
216
return val
191
217
raise (exception_class (val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors." ))
218
+
219
+ def shape_consistency (val : int , val_name : str , correct : int , correct_name : str , exception_class ):
220
+ if val != correct :
221
+ message = (f"{ val_name } must coincide with { correct_name } : "
222
+ + f"{ val_name } = { val } , { correct_name } = { correct } " )
223
+ raise (exception_class (message ))
0 commit comments