3
3
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4
4
import numpy as np
5
5
6
- FLOATS = list ({'float128' ,'float64' ,'float32' ,'float16' } & set (dir (np )) | {float })
7
- INTS = list ({'int64' ,'int32' ,'int16' ,'int8' } & set (dir (np )) | {int })
6
+ _EPSILON = np .sqrt (np .finfo (np .float64 ).eps )
8
7
9
8
def float_in_closed01 (val ,val_name ,exception_class ):
10
- if type (val ) in FLOATS :
9
+ if np . issubdtype ( type (val ), np . floating ) :
11
10
if val >= 0.0 and val <= 1.0 :
12
11
return val
13
- if type (val ) in INTS :
12
+ if np . issubdtype ( type (val ), np . integer ) :
14
13
if val >= 0.0 and val <= 1.0 :
15
14
return float (val )
16
15
raise (exception_class (val_name + " must be in [0,1]." ))
17
16
18
17
def pos_float (val ,val_name ,exception_class ):
19
- if type (val ) in FLOATS :
18
+ if np . issubdtype ( type (val ), np . floating ) :
20
19
if val > 0.0 :
21
20
return val
22
- if type (val ) in INTS :
21
+ if np . issubdtype ( type (val ), np . integer ) :
23
22
if val > 0.0 :
24
23
return float (val )
25
24
raise (exception_class (val_name + " must be positive (not including 0.0)." ))
26
25
27
26
def pos_int (val ,val_name ,exception_class ):
28
- if type (val ) in INTS :
27
+ if np . issubdtype ( type (val ), np . integer ) :
29
28
if val > 0 :
30
29
return val
31
30
raise (exception_class (val_name + " must be int. Its value must be positive (not including 0)." ))
32
31
33
32
def nonneg_int (val ,val_name ,exception_class ):
34
- if type (val ) in INTS :
33
+ if np . issubdtype ( type (val ), np . integer ) :
35
34
if val >= 0 :
36
35
return val
37
36
raise (exception_class (val_name + " must be int. Its value must be non-negative (including 0)." ))
@@ -42,18 +41,18 @@ def nonneg_ints(val,val_name,exception_class):
42
41
except :
43
42
pass
44
43
if type (val ) is np .ndarray :
45
- if val .dtype in INTS and np .all (val >= 0 ):
44
+ if np . issubdtype ( val .dtype , np . integer ) and np .all (val >= 0 ):
46
45
return val
47
46
raise (exception_class (val_name + " must be int or a numpy.ndarray whose dtype is int. Its values must be non-negative (including 0)." ))
48
47
49
48
def nonneg_int_vec (val ,val_name ,exception_class ):
50
49
if type (val ) is np .ndarray :
51
- if val .dtype in INTS and val .ndim == 1 and np .all (val >= 0 ):
50
+ if np . issubdtype ( val .dtype , np . integer ) and val .ndim == 1 and np .all (val >= 0 ):
52
51
return val
53
52
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray whose dtype is int. Its values must be non-negative (including 0)." ))
54
53
55
54
def int_of_01 (val ,val_name ,exception_class ):
56
- if type (val ) in INTS :
55
+ if np . issubdtype ( type (val ), np . integer ) :
57
56
if val == 0 or val == 1 :
58
57
return val
59
58
raise (exception_class (val_name + " must be int. Its value must be 0 or 1." ))
@@ -64,23 +63,23 @@ def ints_of_01(val,val_name,exception_class):
64
63
except :
65
64
pass
66
65
if type (val ) is np .ndarray :
67
- if val .dtype in INTS and np .all (val >= 0 ) and np .all (val <= 1 ):
66
+ if np . issubdtype ( val .dtype , np . integer ) and np .all (val >= 0 ) and np .all (val <= 1 ):
68
67
return val
69
68
raise (exception_class (val_name + " must be int or a numpy.ndarray whose dtype is int. Its values must be 0 or 1." ))
70
69
71
70
def int_vec_of_01 (val ,val_name ,exception_class ):
72
71
if type (val ) is np .ndarray :
73
- if val .dtype in INTS and val .ndim == 1 and np .all (val >= 0 ) and np .all (val <= 1 ):
72
+ if np . issubdtype ( val .dtype , np . integer ) and val .ndim == 1 and np .all (val >= 0 ) and np .all (val <= 1 ):
74
73
return val
75
74
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray whose dtype is int. Its values must be 0 or 1." ))
76
75
77
76
def scalar (val ,val_name ,exception_class ):
78
- if type (val ) in INTS or type (val ) in FLOATS :
77
+ if np . issubdtype ( type (val ), np . integer ) or np . issubdtype ( type (val ), np . floating ) :
79
78
return val
80
79
raise (exception_class (val_name + " must be a scalar." ))
81
80
82
81
def pos_scalar (val ,val_name ,exception_class ):
83
- if type (val ) in INTS or type (val ) in FLOATS :
82
+ if np . issubdtype ( type (val ), np . integer ) or np . issubdtype ( type (val ), np . floating ) :
84
83
if val > 0.0 :
85
84
return val
86
85
raise (exception_class (val_name + " must be a positive scalar." ))
@@ -102,9 +101,9 @@ def pos_def_sym_mat(val,val_name,exception_class):
102
101
raise (exception_class (val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray." ))
103
102
104
103
def float_ (val ,val_name ,exception_class ):
105
- if type (val ) in FLOATS :
104
+ if np . issubdtype ( type (val ), np . floating ) :
106
105
return val
107
- if type (val ) in INTS :
106
+ if np . issubdtype ( type (val ), np . integer ) :
108
107
return float (val )
109
108
raise (exception_class (val_name + " must be a scalar." ))
110
109
@@ -114,9 +113,9 @@ def floats(val,val_name,exception_class):
114
113
except :
115
114
pass
116
115
if type (val ) is np .ndarray :
117
- if val .dtype in INTS :
116
+ if np . issubdtype ( val .dtype , np . integer ) :
118
117
return val .astype (float )
119
- if val .dtype in FLOATS :
118
+ if np . issubdtype ( val .dtype , np . floating ) :
120
119
return val
121
120
raise (exception_class (val_name + " must be float or a numpy.ndarray." ))
122
121
@@ -126,33 +125,67 @@ def pos_floats(val,val_name,exception_class):
126
125
except :
127
126
pass
128
127
if type (val ) is np .ndarray :
129
- if val .dtype in INTS and np .all (val > 0 ):
128
+ if np . issubdtype ( val .dtype , np . integer ) and np .all (val > 0 ):
130
129
return val .astype (float )
131
- if val .dtype in FLOATS and np .all (val > 0.0 ):
130
+ if np . issubdtype ( val .dtype , np . floating ) and np .all (val > 0.0 ):
132
131
return val
133
132
raise (exception_class (val_name + " must be float or a numpy.ndarray. Its values must be positive (not including 0)" ))
134
133
135
134
def float_vec (val ,val_name ,exception_class ):
136
135
if type (val ) is np .ndarray :
137
- if val .dtype in INTS and val .ndim == 1 :
136
+ if np . issubdtype ( val .dtype , np . integer ) and val .ndim == 1 :
138
137
return val .astype (float )
139
- if val .dtype in FLOATS and val .ndim == 1 :
138
+ if np . issubdtype ( val .dtype , np . floating ) and val .ndim == 1 :
140
139
return val
141
140
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray." ))
142
141
143
142
def pos_float_vec (val ,val_name ,exception_class ):
144
143
if type (val ) is np .ndarray :
145
- if val .dtype in INTS and val .ndim == 1 and np .all (val > 0 ):
144
+ if np . issubdtype ( val .dtype , np . integer ) and val .ndim == 1 and np .all (val > 0 ):
146
145
return val .astype (float )
147
- if val .dtype in FLOATS and val .ndim == 1 and np .all (val > 0.0 ):
146
+ if np . issubdtype ( val .dtype , np . floating ) and val .ndim == 1 and np .all (val > 0.0 ):
148
147
return val
149
148
raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray. Its values must be positive (not including 0)" ))
150
149
151
150
def float_vecs (val ,val_name ,exception_class ):
152
151
if type (val ) is np .ndarray :
153
- if val .dtype in INTS and val .ndim >= 1 :
152
+ if np . issubdtype ( val .dtype , np . integer ) and val .ndim >= 1 :
154
153
return val .astype (float )
155
- if val .dtype in FLOATS and val .ndim >= 1 :
154
+ if np . issubdtype ( val .dtype , np . floating ) and val .ndim >= 1 :
156
155
return val
157
156
raise (exception_class (val_name + " must be a numpy.ndarray whose ndim >= 1." ))
158
157
158
+ def float_vec_sum_1 (val ,val_name ,exception_class ):
159
+ if type (val ) is np .ndarray :
160
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim == 1 and abs (val .sum () - 1. ) <= _EPSILON :
161
+ return val .astype (float )
162
+ if np .issubdtype (val .dtype ,np .floating ) and val .ndim == 1 and abs (val .sum () - 1. ) <= _EPSILON :
163
+ return val
164
+ raise (exception_class (val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1." ))
165
+
166
+ def int_ (val ,val_name ,exception_class ):
167
+ if np .issubdtype (type (val ),np .integer ):
168
+ return val
169
+ raise (exception_class (val_name + " must be an integer." ))
170
+
171
+ def ints (val ,val_name ,exception_class ):
172
+ try :
173
+ return int_ (val ,val_name ,exception_class )
174
+ except :
175
+ pass
176
+ if type (val ) is np .ndarray :
177
+ if np .issubdtype (val .dtype ,np .integer ):
178
+ return val
179
+ raise (exception_class (val_name + " must be int or a numpy.ndarray whose dtype is int." ))
180
+
181
+ def onehot_vec (val ,val_name ,exception_class ):
182
+ if type (val ) is np .ndarray :
183
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim == 1 and np .all (val >= 0 ) and val .sum ()== 1 :
184
+ return val
185
+ raise (exception_class (val_name + " must be a one-hot vector (1-dimensional ndarray) whose dtype must be int." ))
186
+
187
+ def onehot_vecs (val ,val_name ,exception_class ):
188
+ if type (val ) is np .ndarray :
189
+ if np .issubdtype (val .dtype ,np .integer ) and val .ndim >= 1 and np .all (val >= 0 ) and np .all (val .sum (axis = - 1 )== 1 ):
190
+ return val
191
+ raise (exception_class (val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors." ))
0 commit comments