Skip to content

Commit de87b59

Browse files
committed
Merge branch 'develop-check' into develop-hiddenmarkovnormal-initgetset
2 parents bd174ef + 69d052a commit de87b59

File tree

1 file changed

+6
-24
lines changed

1 file changed

+6
-24
lines changed

bayesml/_check.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
44
# Yasushi Esaki <esakiful@gmail.com>
5+
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
56
import numpy as np
67

78
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -230,27 +231,8 @@ def onehot_vecs(val,val_name,exception_class):
230231
return val
231232
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
232233

233-
def dim_consistency(value_dict: dict, exception_class):
234-
check_value_dict = {}
235-
for key in value_dict:
236-
if value_dict[key] is not None and value_dict[key] not in check_value_dict.values():
237-
check_value_dict[key] = value_dict[key]
238-
if len(check_value_dict) == 0:
239-
return None
240-
elif len(check_value_dict) > 1:
241-
message = f"The following values must be the same: {list(value_dict.keys())}. "
242-
message += f"The following values are different: {list(check_value_dict.keys())}. "
243-
# print("===== Error =====")
244-
for key in check_value_dict:
245-
# print(f"{key} = {check_value_dict[key]}")
246-
message += f"\n {key} = {check_value_dict[key]}"
247-
else:
248-
return list(check_value_dict.values())[0]
249-
raise(exception_class(message))
250-
251-
def shape_consistency(val, val_name, correct, correct_name, exception_class):
252-
if val.shape not in correct:
253-
message = f"{val_name}.shape must coincide with {correct_name}:"
254-
+f"{val_name}.shape={val.shape}, {correct_name}={correct}"
255-
raise(exception_class(message))
256-
234+
def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
235+
if val != correct:
236+
message = (f"{val_name} must coincide with {correct_name}: "
237+
+ f"{val_name} = {val}, {correct_name} = {correct}")
238+
raise(exception_class(message))

0 commit comments

Comments
 (0)