Skip to content

Commit 2a1aeaa

Browse files
committed
Add shape_consistency
1 parent 3843027 commit 2a1aeaa

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

bayesml/_check.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
44
# Yasushi Esaki <esakiful@gmail.com>
5+
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
56
import numpy as np
67

78
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -214,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
214215
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val >= 0) and np.all(val.sum(axis=-1)==1):
215216
return val
216217
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
218+
219+
def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
220+
if val != correct:
221+
message = f"{val_name} must coincide with {correct_name}:"
222+
+f"{val_name}={val}, {correct_name}={correct}"
223+
raise(exception_class(message))

0 commit comments

Comments
 (0)