Skip to content

Commit e9df92d

Browse files
committed
Update _check.py
1 parent bc14ef9 commit e9df92d

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

bayesml/_check.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,13 @@ def pos_float_vecs(val,val_name,exception_class):
199199
return val
200200
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray. Its values must be positive (not including 0)"))
201201

202-
def float_vec_sum_1(val,val_name,exception_class,ndim=1,sum_axis=0):
202+
def float_vec_sum_1(val,val_name,exception_class):
203203
if type(val) is np.ndarray:
204-
sum_val = np.sum(val, axis=sum_axis)
205-
if np.issubdtype(val.dtype,np.integer) and val.ndim == ndim and abs(sum_val.sum() - np.prod(sum_val.shape)) <= _EPSILON:
204+
if np.issubdtype(val.dtype,np.integer) and val.ndim == 1 and abs(val.sum() - 1.) <= _EPSILON:
206205
return val.astype(float)
207-
if np.issubdtype(val.dtype,np.floating) and val.ndim == ndim and abs(sum_val.sum() - np.prod(sum_val.shape)) <= _EPSILON:
206+
if np.issubdtype(val.dtype,np.floating) and val.ndim == 1 and abs(val.sum() - 1.) <= _EPSILON:
208207
return val
209-
raise(exception_class(val_name + f" must be a {ndim}-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
208+
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
210209

211210
def float_vecs_sum_1(val,val_name,exception_class):
212211
if type(val) is np.ndarray:

0 commit comments

Comments
 (0)