1
1
# Code Author
2
2
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
3
3
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4
+ # Yasushi Esaki <esakiful@gmail.com>
5
+ # Jun Nishikawa <jun.b.nishikawa@gmail.com>
4
6
import numpy as np
5
7
6
8
_EPSILON = np .sqrt (np .finfo (np .float64 ).eps )
@@ -179,6 +181,14 @@ def float_vec_sum_1(val,val_name,exception_class):
179
181
return val
180
182
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1." ))
181
183
184
+ def float_vecs_sum_1 (val ,val_name ,exception_class ):
185
+ if type (val ) is np .ndarray :
186
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (np .abs (np .sum (val , axis = - 1 ) - 1. ) <= _EPSILON ):
187
+ return val .astype (float )
188
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim >= 1 and np .all (np .abs (np .sum (val , axis = - 1 ) - 1. ) <= _EPSILON ):
189
+ return val
190
+ raise (exception_class (val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1." ))
191
+
182
192
def int_ (val ,val_name ,exception_class ):
183
193
if np .issubdtype (type (val ),np .integer ):
184
194
return val
@@ -205,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
205
215
if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (val >= 0 ) and np .all (val .sum (axis = - 1 )== 1 ):
206
216
return val
207
217
raise (exception_class (val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors." ))
218
+
219
+ def shape_consistency (val : int , val_name : str , correct : int , correct_name : str , exception_class ):
220
+ if val != correct :
221
+ message = (f"{ val_name } must coincide with { correct_name } : "
222
+ + f"{ val_name } = { val } , { correct_name } = { correct } " )
223
+ raise (exception_class (message ))
0 commit comments