Skip to content

Commit 9cd44c2

Browse files
committed
modify set_params and add shape_consistency
1 parent 571a7c0 commit 9cd44c2

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

bayesml/_check.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,10 @@ def dim_consistency(value_dict: dict, exception_class):
239239
else:
240240
return list(check_value_dict.values())[0]
241241
raise(exception_class(message))
242+
243+
def shape_consistency(val, val_name, correct, correct_name, exception_class):
244+
if val.shape not in correct:
245+
message = f"{val_name}.shape must coincide with {correct_name}:"
246+
+f"{val_name}.shape={val.shape}, {correct_name}={correct}"
247+
raise(exception_class(message))
248+

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -76,50 +76,43 @@ def set_params(
7676
positive definite symetric matrices
7777
"""
7878

79+
# [Check values]
80+
tmp_pi_vec = None if pi_vec is None else _check.float_vec_sum_1(pi_vec, "pi_vec", ParameterFormatError)
81+
tmp_a_mat = None if a_mat is None else _check.float_vec_sum_1(a_mat, "a_mat", ParameterFormatError, ndim=2, sum_axis=1)
82+
tmp_mu_vecs = None if mu_vecs is None else _check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError)
83+
tmp_lambda_mats = None if lambda_mats is None else _check.float_vecs(lambda_mats, "lambda_mats", ParameterFormatError)
84+
7985
# [Dimension value consistency]
80-
if pi_vec is not None:
81-
if pi_vec.shape[0] != self.c_num_classes:
82-
raise(ParameterFormatError(
83-
"pi_vec.shape[0] must coincide with self.c_num_classes:"
84-
+f"pi_vec.shape[0]={pi_vec.shape[0]}, self.c_num_classes={self.c_num_classes}"))
85-
self.pi_vec = _check.float_vec_sum_1(pi_vec, "pi_vec", ParameterFormatError)
86-
if a_mat is not None:
87-
if a_mat.shape[0] != self.c_num_classes or a_mat.shape[1] != self.c_num_classes:
88-
raise(ParameterFormatError(
89-
"a_mat.shape[0] and a_mat.shape[1] must coincide with self.c_num_classes:"
90-
+f"a_mat.shape[0]={a_mat.shape[0]}, a_mat.shape[1]={a_mat.shape[1]}, self.c_num_classes={self.c_num_classes}"))
91-
self.a_mat = _check.float_vec_sum_1(a_mat, "a_mat", ParameterFormatError, ndim=2, sum_axis=1)
92-
if mu_vecs is not None:
93-
message = ""
94-
if mu_vecs.shape[0] != self.c_degree:
95-
message += "mu_vecs.shape[0] must coincide with self.c_degree:"
96-
+f"mu_vecs.shape[0]={mu_vecs.shape[0]}, self.c_degree={self.c_degree}"
97-
if len(mu_vecs) == 2:
98-
if mu_vecs.shape[1] != self.c_num_classes:
99-
message += "mu_vecs.shape[1] must coincide with self.c_num_classes:"
100-
+f"mu_vecs.shape[1]={mu_vecs.shape[1]}, self.c_num_classes={self.c_num_classes}"
101-
if message != "":
102-
raise(ParameterFormatError(message))
103-
self.mu_vecs = _check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError)
104-
if len(mu_vecs) == 1:
105-
self.mu_vecs = np.broadcast_to(np.reshape(self.mu_vecs, (self.c_degree, 1)), (self.c_degree, self.c_num_classes))
106-
if lambda_mats is not None:
107-
message = ""
108-
if lambda_mats.shape[:2] != (self.c_degree, self.c_degree):
109-
message += "lambda_mats.shape[:2] must coincide with (self.c_degree, self.c_degree):"
110-
+f"lambda_mats.shape[:2]={lambda_mats.shape[:2]}, (self.c_degree, self.c_degree)={(self.c_degree, self.c_degree)}"
111-
if len(lambda_mats) == 3:
112-
if lambda_mats.shape[2] != self.c_num_classes:
113-
message += "lambda_mats.shape[2] must coincide with self.c_num_classes:"
114-
+f"lambda_mats.shape[2]={lambda_mats.shape[2]}, self.c_num_classes={self.c_num_classes}"
115-
if message != "":
116-
raise(ParameterFormatError(message))
117-
self.lambda_mats = _check.float_vecs(lambda_mats, "lambda_mats", ParameterFormatError)
118-
if len(self.lambda_mats) == 2:
119-
self.lambda_mats = np.broadcast_to(
120-
np.reshape(self.lambda_mats, (self.c_degree,self.c_degree,1)),
121-
(self.c_degree,self.c_degree,self.c_num_classes)
122-
)
86+
if tmp_pi_vec is not None:
87+
_check.shape_consistency(
88+
val=pi_vec, val_name="pi_vec",
89+
correct=[(self.c_num_classes,)], correct_name="(self.c_num_classes,)",
90+
exception_class=ParameterFormatError
91+
)
92+
self.pi_vec[:] = tmp_pi_vec
93+
if tmp_a_mat is not None:
94+
_check.shape_consistency(
95+
val=a_mat.shape, val_name="a_mat",
96+
correct=[(self.c_num_classes, self.c_num_classes)], correct_name="(self.c_num_classes, self.c_num_classes)",
97+
exception_class=ParameterFormatError
98+
)
99+
self.a_mat[:] = tmp_a_mat
100+
if tmp_mu_vecs is not None:
101+
_check.shape_consistency(
102+
val=mu_vecs.shape, val_name="mu_vecs",
103+
correct=[(self.c_degree,), (self.c_num_classes, self.c_degree)],
104+
correct_name="(self.c_degree,), (self.c_num_classes, self.c_degree)",
105+
exception_class=ParameterFormatError
106+
)
107+
self.mu_vecs[:] = tmp_mu_vecs
108+
if tmp_lambda_mats is not None:
109+
_check.shape_consistency(
110+
val=lambda_mats.shape, val_name="lambda_mats",
111+
correct=[(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)],
112+
correct_name="(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)",
113+
exception_class=ParameterFormatError
114+
)
115+
self.lambda_mats[:] = tmp_lambda_mats
123116

124117
def set_h_params(
125118
self,

0 commit comments

Comments
 (0)