Skip to content

Commit c09204f

Browse files
committed
add set_params
1 parent 8a03f2c commit c09204f

File tree

1 file changed

+59
-12
lines changed

1 file changed

+59
-12
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Code Author
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
4+
from email import message
45
import warnings
56
import numpy as np
67
from scipy.stats import multivariate_normal as ss_multivariate_normal
@@ -61,18 +62,64 @@ def set_params(
6162
mu_vecs=None,
6263
lambda_mats=None
6364
):
64-
# Noneでない入力について,以下をチェックする.
65-
# * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
66-
# * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
67-
# 例
68-
# if h0_m_vecs is not None:
69-
# _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
70-
# if h0_m_vecs.shape[-1] != self.degree:
71-
# raise(ParameterFormatError(
72-
# "h0_m_vecs.shape[-1] must coincide with self.degree:"
73-
# +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
74-
# self.h0_m_vecs[:] = h0_m_vecs
75-
pass
65+
"""Set the parameter of the sthocastic data generative model.
66+
67+
Parameters
68+
----------
69+
pi_vec : numpy.ndarray
70+
a real vector in :math:`[0, 1]^K`. The sum of its elements must be 1.
71+
a_mat : numpy.ndarray
72+
a real matrix in :math:`[0, 1]^{KxK}`. The sum of each column elements must be 1.
73+
mu_vecs : numpy.ndarray
74+
vectors of real numbers
75+
lambda_mats : numpy.ndarray
76+
positive definite symetric matrices
77+
"""
78+
79+
# [Dimension value consistency]
80+
if pi_vec is not None:
81+
if pi_vec.shape[0] != self.c_num_classes:
82+
raise(ParameterFormatError(
83+
"pi_vec.shape[0] must coincide with self.c_num_classes:"
84+
+f"pi_vec.shape[0]={pi_vec.shape[0]}, self.c_num_classes={self.c_num_classes}"))
85+
self.pi_vec = _check.float_vec_sum_1(pi_vec, "pi_vec", ParameterFormatError)
86+
if a_mat is not None:
87+
if a_mat.shape[0] != self.c_num_classes or a_mat.shape[1] != self.c_num_classes:
88+
raise(ParameterFormatError(
89+
"a_mat.shape[0] and a_mat.shape[1] must coincide with self.c_num_classes:"
90+
+f"a_mat.shape[0]={a_mat.shape[0]}, a_mat.shape[1]={a_mat.shape[1]}, self.c_num_classes={self.c_num_classes}"))
91+
self.a_mat = _check.float_vec_sum_1(a_mat, "a_mat", ParameterFormatError, ndim=2, sum_axis=1)
92+
if mu_vecs is not None:
93+
message = ""
94+
if mu_vecs.shape[0] != self.c_degree:
95+
message += "mu_vecs.shape[0] must coincide with self.c_degree:"
96+
+f"mu_vecs.shape[0]={mu_vecs.shape[0]}, self.c_degree={self.c_degree}"
97+
if len(mu_vecs) == 2:
98+
if mu_vecs.shape[1] != self.c_num_classes:
99+
message += "mu_vecs.shape[1] must coincide with self.c_num_classes:"
100+
+f"mu_vecs.shape[1]={mu_vecs.shape[1]}, self.c_num_classes={self.c_num_classes}"
101+
if message != "":
102+
raise(ParameterFormatError(message))
103+
self.mu_vecs = _check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError)
104+
if len(mu_vecs) == 1:
105+
self.mu_vecs = np.broadcast_to(np.reshape(self.mu_vecs, (self.c_degree, 1)), (self.c_degree, self.c_num_classes))
106+
if lambda_mats is not None:
107+
message = ""
108+
if lambda_mats.shape[:2] != (self.c_degree, self.c_degree)
109+
message += "lambda_mats.shape[:2] must coincide with (self.c_degree, self.c_degree):"
110+
+f"lambda_mats.shape[:2]={lambda_mats.shape[:2]}, (self.c_degree, self.c_degree)={(self.c_degree, self.c_degree)}"
111+
if len(lambda_mats) == 3:
112+
if lambda_mats.shape[2] != self.c_num_classes:
113+
message += "lambda_mats.shape[2] must coincide with self.c_num_classes:"
114+
+f"lambda_mats.shape[2]={lambda_mats.shape[2]}, self.c_num_classes={self.c_num_classes}"
115+
if message != "":
116+
raise(ParameterFormatError(message))
117+
self.lambda_mats = _check.float_vecs(lambda_mats, "lambda_mats", ParameterFormatError)
118+
if len(self.lambda_mats) == 2:
119+
self.lambda_mats = np.broadcast_to(
120+
np.reshape(self.lambda_mats, (self.c_degree,self.c_degree,1)),
121+
(self.c_degree,self.c_degree,self.c_num_classes)
122+
)
76123

77124
def set_h_params(
78125
self,

0 commit comments

Comments
 (0)