Skip to content

Commit d066598

Browse files
committed
Revise set_params and set_h_params
1 parent de87b59 commit d066598

File tree

1 file changed

+71
-91
lines changed

1 file changed

+71
-91
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 71 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def __init__(
4545
self.lambda_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,1,1])
4646

4747
# h_params
48+
self.h_eta_vec = np.ones(self.c_num_classes) / 2.0
49+
self.h_zeta_vecs = np.ones([self.c_num_classes,self.c_num_classes]) / 2.0
4850
self.h_m_vecs = np.zeros([self.c_num_classes,self.c_degree])
4951
self.h_kappas = np.ones([self.c_num_classes])
5052
self.h_nus = np.ones(self.c_num_classes) * self.c_degree
5153
self.h_w_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,1,1])
52-
self.h_eta_vec = np.ones(self.c_num_classes) / 2.0
53-
self.h_zeta_vecs = np.ones([self.c_num_classes,self.c_num_classes]) / 2.0
5454

5555
self.set_params(pi_vec,a_mat,mu_vecs,lambda_mats)
56-
self.set_h_params(h_m_vecs,h_kappas,h_nus,h_w_mats,h_eta_vec,h_zeta_vecs)
56+
self.set_h_params(h_eta_vec,h_zeta_vecs,h_m_vecs,h_kappas,h_nus,h_w_mats)
5757

5858
def set_params(
5959
self,
@@ -69,115 +69,95 @@ def set_params(
6969
pi_vec : numpy.ndarray
7070
a real vector in :math:`[0, 1]^K`. The sum of its elements must be 1.
7171
a_mat : numpy.ndarray
72-
a real matrix in :math:`[0, 1]^{KxK}`. The sum of each column elements must be 1.
72+
a real matrix in :math:`[0, 1]^{KxK}`. The sum of each row elements must be 1.
7373
mu_vecs : numpy.ndarray
7474
vectors of real numbers
7575
lambda_mats : numpy.ndarray
7676
positive definite symetric matrices
7777
"""
78-
79-
# [Check values]
80-
tmp_pi_vec = None if pi_vec is None else _check.float_vec_sum_1(pi_vec, "pi_vec", ParameterFormatError)
81-
tmp_a_mat = None if a_mat is None else _check.float_vec_sum_1(a_mat, "a_mat", ParameterFormatError, ndim=2, sum_axis=1)
82-
tmp_mu_vecs = None if mu_vecs is None else _check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError)
83-
tmp_lambda_mats = None if lambda_mats is None else _check.float_vecs(lambda_mats, "lambda_mats", ParameterFormatError)
84-
85-
# [Dimension consistency]
86-
if tmp_pi_vec is not None:
78+
if pi_vec is not None:
79+
_check.float_vec_sum_1(pi_vec, "pi_vec", ParameterFormatError)
8780
_check.shape_consistency(
88-
val=tmp_pi_vec, val_name="pi_vec",
89-
correct=[(self.c_num_classes,)], correct_name="(self.c_num_classes,)",
90-
exception_class=ParameterFormatError
91-
)
92-
self.pi_vec[:] = tmp_pi_vec
93-
if tmp_a_mat is not None:
81+
pi_vec.shape[0],"pi_vec.shape[0]",
82+
self.c_num_classes,"self.c_num_classes",
83+
ParameterFormatError
84+
)
85+
self.pi_vec[:] = pi_vec
86+
87+
if a_mat is not None:
88+
_check.float_vecs_sum_1(a_mat, "a_mat", ParameterFormatError)
9489
_check.shape_consistency(
95-
val=tmp_a_mat, val_name="a_mat",
96-
correct=[(self.c_num_classes, self.c_num_classes)], correct_name="(self.c_num_classes, self.c_num_classes)",
97-
exception_class=ParameterFormatError
98-
)
99-
self.a_mat[:] = tmp_a_mat
100-
if tmp_mu_vecs is not None:
90+
a_mat.shape[-1], "a_mat.shape[-1]",
91+
self.c_num_classes, "self.c_num_classes",
92+
ParameterFormatError
93+
)
94+
self.a_mat[:] = a_mat
95+
96+
if mu_vecs is not None:
97+
_check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError)
10198
_check.shape_consistency(
102-
val=tmp_mu_vecs, val_name="mu_vecs",
103-
correct=[(self.c_degree,), (self.c_num_classes, self.c_degree)],
104-
correct_name="(self.c_degree,), (self.c_num_classes, self.c_degree)",
105-
exception_class=ParameterFormatError
106-
)
107-
self.mu_vecs[:] = tmp_mu_vecs
108-
if tmp_lambda_mats is not None:
99+
mu_vecs.shape[-1],"mu_vecs.shape[-1]",
100+
self.c_degree,"self.c_degree",
101+
ParameterFormatError
102+
)
103+
self.mu_vecs[:] = mu_vecs
104+
105+
if lambda_mats is not None:
106+
_check.pos_def_sym_mats(lambda_mats,'lambda_mats',ParameterFormatError)
109107
_check.shape_consistency(
110-
val=tmp_lambda_mats, val_name="lambda_mats",
111-
correct=[(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)],
112-
correct_name="(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)",
113-
exception_class=ParameterFormatError
114-
)
115-
self.lambda_mats[:] = tmp_lambda_mats
108+
lambda_mats.shape[-1],"lambda_mats.shape[-1] and lambda_mats.shape[-2]",
109+
self.c_degree,"self.c_degree",
110+
ParameterFormatError
111+
)
112+
self.lambda_mats[:] = lambda_mats
116113

117114
def set_h_params(
118115
self,
116+
h_eta_vec=None,
117+
h_zeta_vecs=None,
119118
h_m_vecs=None,
120119
h_kappas=None,
121120
h_nus=None,
122121
h_w_mats=None,
123-
h_eta_vec=None,
124-
h_zeta_vecs=None
125122
):
126-
# Noneでない入力について,以下をチェックする.
127-
# * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
128-
# * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
129-
# 例
130-
# if h0_m_vecs is not None:
131-
# _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
132-
# if h0_m_vecs.shape[-1] != self.degree:
133-
# raise(ParameterFormatError(
134-
# "h0_m_vecs.shape[-1] must coincide with self.degree:"
135-
# +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
136-
# self.h0_m_vecs[:] = h0_m_vecs
137-
pass
138123

139-
# [Check values]
140-
tmp_h_m_vecs = None if h_m_vecs is None else _check.float_vecs(h_m_vecs, "h_m_vecs", ParameterFormatError)
141-
tmp_h_kappas = None if h_kappas is None else np.array(_check.pos_floats(h_kappas, "h_kappas", ParameterFormatError))
142-
tmp_h_nus = None if h_nus is None else np.array(_check.floats(h_nus, "h_nus", ParameterFormatError))
143-
tmp_h_w_mats = None if h_w_mats is None else _check.pos_float_vecs(h_w_mats, "h_w_mats", ParameterFormatError)
144-
tmp_h_eta_vec = None if h_eta_vec is None else _check.pos_float_vec(h_eta_vec, "h_eta_vec", ParameterFormatError)
145-
tmp_h_zeta_vecs = None if h_zeta_vecs is None else _check.pos_float_vecs(h_zeta_vecs, "h_zeta_vecs", ParameterFormatError)
124+
if h_eta_vec is not None:
125+
_check.pos_floats(h_eta_vec,'h_eta_vec',ParameterFormatError)
126+
self.h_eta_vec[:] = h_eta_vec
146127

147-
# [Dimension consistency]
148-
if tmp_h_m_vecs is not None:
149-
_check.shape_consistency(
150-
val=tmp_h_m_vecs, val_name="h_m_vecs",
151-
correct=[(self.c_num_classes,self.c_degree), (self.c_degree,)],
152-
correct_name="(self.c_num_classes,self.c_degree), (self.c_degree,)",
153-
exception_class=ParameterFormatError
154-
)
155-
self.h_m_vecs[:] = tmp_h_m_vecs
156-
if tmp_h_kappas is not None:
157-
_check.shape_consistency(
158-
val=tmp_h_kappas, val_name="h_kappas",
159-
correct=[(self.c_num_classes,), ()], correct_name="(self.c_num_classes,), ()",
160-
exception_class=ParameterFormatError
161-
)
162-
self.h_kappas[:] = tmp_h_kappas
163-
if tmp_h_nus is not None:
164-
if not np.all(tmp_h_nus > self.c_degree - 1):
165-
raise(ParameterFormatError("The all values in h_nus must be greater than self.c_degree."))
128+
if h_zeta_vecs is not None:
129+
_check.pos_floats(h_zeta_vecs, 'h_zeta_vecs', ParameterFormatError)
130+
self.h_zeta_vecs[:] = h_zeta_vecs
131+
132+
if h_m_vecs is not None:
133+
_check.float_vecs(h_m_vecs, "h_m_vecs", ParameterFormatError)
166134
_check.shape_consistency(
167-
val=tmp_h_nus, val_name="h_nus",
168-
correct=[(self.c_num_classes,), ()],
169-
correct_name="(self.c_num_classes,), ()",
170-
exception_class=ParameterFormatError
171-
)
172-
self.h_nus[:] = tmp_h_nus
173-
if tmp_h_w_mats is not None:
135+
h_m_vecs.shape[-1],"h_m_vecs.shape[-1]",
136+
self.c_degree,"self.c_degree",
137+
ParameterFormatError
138+
)
139+
self.h_m_vecs[:] = h_m_vecs
140+
141+
if h_kappas is not None:
142+
_check.pos_floats(h_kappas, "h_kappas", ParameterFormatError)
143+
self.h_kappas[:] = h_kappas
144+
145+
if h_nus is not None:
146+
_check.floats(h_nus, "h_nus", ParameterFormatError)
147+
if np.all(h_nus <= self.c_degree - 1):
148+
raise(ParameterFormatError(
149+
"All the values in h_nus must be greater than self.c_degree - 1: "
150+
+ f"self.c_degree = {self.c_degree}, h_nus = {h_nus}"))
151+
self.h_nus[:] = h_nus
152+
153+
if h_w_mats is not None:
154+
_check.pos_def_sym_mats(h_w_mats,'h_w_mats',ParameterFormatError)
174155
_check.shape_consistency(
175-
val=tmp_h_w_mats.shape, val_name="h_w_mats",
176-
correct=[(self.c_num_classes,self.c_degree,self.c_degree), (self.c_degree,self.c_degree)],
177-
correct_name="(self.c_num_classes,self.c_degree,self.c_degree), (self.c_degree,self.c_degree)",
178-
exception_class=ParameterFormatError
179-
)
180-
self.h_w_mats[:] = tmp_h_w_mats
156+
h_w_mats.shape[-1],"h_w_mats.shape[-1] and h_w_mats.shape[-2]",
157+
self.c_degree,"self.c_degree",
158+
ParameterFormatError
159+
)
160+
self.h_w_mats[:] = h_w_mats
181161

182162
def get_params(self):
183163
# paramsを辞書として返す関数.

0 commit comments

Comments
 (0)