Skip to content

Commit bd174ef

Browse files
committed
little modify and add set_h_params before eta and zeta
1 parent 9cd44c2 commit bd174ef

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

bayesml/_check.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ def float_vecs(val,val_name,exception_class):
178178
return val
179179
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1."))
180180

181+
def pos_float_vecs(val,val_name,exception_class):
182+
if type(val) is np.ndarray:
183+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val>0):
184+
return val.astype(float)
185+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(val>0.0):
186+
return val
187+
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray. Its values must be positive (not including 0)"))
188+
181189
def float_vec_sum_1(val,val_name,exception_class,ndim=1,sum_axis=0):
182190
if type(val) is np.ndarray:
183191
sum_val = np.sum(val, axis=sum_axis)

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,32 +82,32 @@ def set_params(
8282
tmp_mu_vecs = None if mu_vecs is None else _check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError)
8383
tmp_lambda_mats = None if lambda_mats is None else _check.float_vecs(lambda_mats, "lambda_mats", ParameterFormatError)
8484

85-
# [Dimension value consistency]
85+
# [Dimension consistency]
8686
if tmp_pi_vec is not None:
8787
_check.shape_consistency(
88-
val=pi_vec, val_name="pi_vec",
88+
val=tmp_pi_vec, val_name="pi_vec",
8989
correct=[(self.c_num_classes,)], correct_name="(self.c_num_classes,)",
9090
exception_class=ParameterFormatError
9191
)
9292
self.pi_vec[:] = tmp_pi_vec
9393
if tmp_a_mat is not None:
9494
_check.shape_consistency(
95-
val=a_mat.shape, val_name="a_mat",
95+
val=tmp_a_mat, val_name="a_mat",
9696
correct=[(self.c_num_classes, self.c_num_classes)], correct_name="(self.c_num_classes, self.c_num_classes)",
9797
exception_class=ParameterFormatError
9898
)
9999
self.a_mat[:] = tmp_a_mat
100100
if tmp_mu_vecs is not None:
101101
_check.shape_consistency(
102-
val=mu_vecs.shape, val_name="mu_vecs",
102+
val=tmp_mu_vecs, val_name="mu_vecs",
103103
correct=[(self.c_degree,), (self.c_num_classes, self.c_degree)],
104104
correct_name="(self.c_degree,), (self.c_num_classes, self.c_degree)",
105105
exception_class=ParameterFormatError
106106
)
107107
self.mu_vecs[:] = tmp_mu_vecs
108108
if tmp_lambda_mats is not None:
109109
_check.shape_consistency(
110-
val=lambda_mats.shape, val_name="lambda_mats",
110+
val=tmp_lambda_mats, val_name="lambda_mats",
111111
correct=[(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)],
112112
correct_name="(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)",
113113
exception_class=ParameterFormatError
@@ -136,6 +136,49 @@ def set_h_params(
136136
# self.h0_m_vecs[:] = h0_m_vecs
137137
pass
138138

139+
# [Check values]
140+
tmp_h_m_vecs = None if h_m_vecs is None else _check.float_vecs(h_m_vecs, "h_m_vecs", ParameterFormatError)
141+
tmp_h_kappas = None if h_kappas is None else np.array(_check.pos_floats(h_kappas, "h_kappas", ParameterFormatError))
142+
tmp_h_nus = None if h_nus is None else np.array(_check.floats(h_nus, "h_nus", ParameterFormatError))
143+
tmp_h_w_mats = None if h_w_mats is None else _check.pos_float_vecs(h_w_mats, "h_w_mats", ParameterFormatError)
144+
tmp_h_eta_vec = None if h_eta_vec is None else _check.pos_float_vec(h_eta_vec, "h_eta_vec", ParameterFormatError)
145+
tmp_h_zeta_vecs = None if h_zeta_vecs is None else _check.pos_float_vecs(h_zeta_vecs, "h_zeta_vecs", ParameterFormatError)
146+
147+
# [Dimension consistency]
148+
if tmp_h_m_vecs is not None:
149+
_check.shape_consistency(
150+
val=tmp_h_m_vecs, val_name="h_m_vecs",
151+
correct=[(self.c_num_classes,self.c_degree), (self.c_degree,)],
152+
correct_name="(self.c_num_classes,self.c_degree), (self.c_degree,)",
153+
exception_class=ParameterFormatError
154+
)
155+
self.h_m_vecs[:] = tmp_h_m_vecs
156+
if tmp_h_kappas is not None:
157+
_check.shape_consistency(
158+
val=tmp_h_kappas, val_name="h_kappas",
159+
correct=[(self.c_num_classes,), ()], correct_name="(self.c_num_classes,), ()",
160+
exception_class=ParameterFormatError
161+
)
162+
self.h_kappas[:] = tmp_h_kappas
163+
if tmp_h_nus is not None:
164+
if not np.all(tmp_h_nus > self.c_degree - 1):
165+
raise(ParameterFormatError("The all values in h_nus must be greater than self.c_degree."))
166+
_check.shape_consistency(
167+
val=tmp_h_nus, val_name="h_nus",
168+
correct=[(self.c_num_classes,), ()],
169+
correct_name="(self.c_num_classes,), ()",
170+
exception_class=ParameterFormatError
171+
)
172+
self.h_nus[:] = tmp_h_nus
173+
if tmp_h_w_mats is not None:
174+
_check.shape_consistency(
175+
val=tmp_h_w_mats.shape, val_name="h_w_mats",
176+
correct=[(self.c_num_classes,self.c_degree,self.c_degree), (self.c_degree,self.c_degree)],
177+
correct_name="(self.c_num_classes,self.c_degree,self.c_degree), (self.c_degree,self.c_degree)",
178+
exception_class=ParameterFormatError
179+
)
180+
self.h_w_mats[:] = tmp_h_w_mats
181+
139182
def get_params(self):
140183
# paramsを辞書として返す関数.
141184
# 要素の順番はset_paramsの引数の順にそろえる.

0 commit comments

Comments
 (0)