Skip to content

Commit 3843027

Browse files
committed
Add float_vecs_sum_1
1 parent 0ff39a8 commit 3843027

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

bayesml/_check.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Code Author
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4+
# Yasushi Esaki <esakiful@gmail.com>
45
import numpy as np
56

67
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -179,6 +180,14 @@ def float_vec_sum_1(val,val_name,exception_class):
179180
return val
180181
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
181182

183+
def float_vecs_sum_1(val,val_name,exception_class):
184+
if type(val) is np.ndarray:
185+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
186+
return val.astype(float)
187+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
188+
return val
189+
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1."))
190+
182191
def int_(val,val_name,exception_class):
183192
if np.issubdtype(type(val),np.integer):
184193
return val

0 commit comments

Comments
 (0)