1
1
# Code Author
2
2
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
3
3
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4
+ # Yasushi Esaki <esakiful@gmail.com>
5
+ # Jun Nishikawa <jun.b.nishikawa@gmail.com>
4
6
import numpy as np
5
7
6
8
_EPSILON = np .sqrt (np .finfo (np .float64 ).eps )
@@ -51,6 +53,12 @@ def nonneg_int_vec(val,val_name,exception_class):
51
53
return val
52
54
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray whose dtype is int. Its values must be non-negative (including 0)." ))
53
55
56
+ def nonneg_float_vec (val ,val_name ,exception_class ):
57
+ if type (val ) is np .ndarray :
58
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim == 1 and np .all (val >= 0 ):
59
+ return val
60
+ raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray whose dtype is float. Its values must be non-negative (including 0)." ))
61
+
54
62
def int_of_01 (val ,val_name ,exception_class ):
55
63
if np .issubdtype (type (val ),np .integer ):
56
64
if val == 0 or val == 1 :
@@ -171,13 +179,30 @@ def float_vecs(val,val_name,exception_class):
171
179
return val
172
180
raise (exception_class (val_name + " must be a numpy.ndarray whose ndim >= 1." ))
173
181
174
- def float_vec_sum_1 (val ,val_name ,exception_class ):
182
+ def pos_float_vecs (val ,val_name ,exception_class ):
175
183
if type (val ) is np .ndarray :
176
- if np .issubdtype (val .dtype ,np .integer ) and val .ndim == 1 and abs ( val . sum () - 1. ) <= _EPSILON :
184
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np . all ( val > 0 ) :
177
185
return val .astype (float )
178
- if np .issubdtype (val .dtype ,np .floating ) and val .ndim == 1 and abs (val . sum () - 1. ) <= _EPSILON :
186
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim >= 1 and np . all (val > 0.0 ) :
179
187
return val
180
- raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1." ))
188
+ raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray. Its values must be positive (not including 0)" ))
189
+
190
+ def float_vec_sum_1 (val ,val_name ,exception_class ,ndim = 1 ,sum_axis = 0 ):
191
+ if type (val ) is np .ndarray :
192
+ sum_val = np .sum (val , axis = sum_axis )
193
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim == ndim and abs (sum_val .sum () - np .prod (sum_val .shape )) <= _EPSILON :
194
+ return val .astype (float )
195
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim == ndim and abs (sum_val .sum () - np .prod (sum_val .shape )) <= _EPSILON :
196
+ return val
197
+ raise (exception_class (val_name + f" must be a { ndim } -dimensional numpy.ndarray, and the sum of its elements must equal to 1." ))
198
+
199
+ def float_vecs_sum_1 (val ,val_name ,exception_class ):
200
+ if type (val ) is np .ndarray :
201
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (np .abs (np .sum (val , axis = - 1 ) - 1. ) <= _EPSILON ):
202
+ return val .astype (float )
203
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim >= 1 and np .all (np .abs (np .sum (val , axis = - 1 ) - 1. ) <= _EPSILON ):
204
+ return val
205
+ raise (exception_class (val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1." ))
181
206
182
207
def int_ (val ,val_name ,exception_class ):
183
208
if np .issubdtype (type (val ),np .integer ):
@@ -205,3 +230,9 @@ def onehot_vecs(val,val_name,exception_class):
205
230
if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (val >= 0 ) and np .all (val .sum (axis = - 1 )== 1 ):
206
231
return val
207
232
raise (exception_class (val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors." ))
233
+
234
+ def shape_consistency (val : int , val_name : str , correct : int , correct_name : str , exception_class ):
235
+ if val != correct :
236
+ message = (f"{ val_name } must coincide with { correct_name } : "
237
+ + f"{ val_name } = { val } , { correct_name } = { correct } " )
238
+ raise (exception_class (message ))
0 commit comments