|
1 | 1 | # Code Author
|
2 | 2 | # Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
|
3 | 3 | # Jun Nishikawa <jun.b.nishikawa@gmail.com>
|
| 4 | +from email import message |
4 | 5 | import warnings
|
5 | 6 | import numpy as np
|
6 | 7 | from scipy.stats import multivariate_normal as ss_multivariate_normal
|
@@ -61,18 +62,64 @@ def set_params(
|
61 | 62 | mu_vecs=None,
|
62 | 63 | lambda_mats=None
|
63 | 64 | ):
|
64 |
| - # Noneでない入力について,以下をチェックする. |
65 |
| - # * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など) |
66 |
| - # * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める |
67 |
| - # 例 |
68 |
| - # if h0_m_vecs is not None: |
69 |
| - # _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError) |
70 |
| - # if h0_m_vecs.shape[-1] != self.degree: |
71 |
| - # raise(ParameterFormatError( |
72 |
| - # "h0_m_vecs.shape[-1] must coincide with self.degree:" |
73 |
| - # +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}")) |
74 |
| - # self.h0_m_vecs[:] = h0_m_vecs |
75 |
| - pass |
| 65 | + """Set the parameter of the sthocastic data generative model. |
| 66 | +
|
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + pi_vec : numpy.ndarray |
| 70 | + a real vector in :math:`[0, 1]^K`. The sum of its elements must be 1. |
| 71 | + a_mat : numpy.ndarray |
| 72 | + a real matrix in :math:`[0, 1]^{KxK}`. The sum of each column elements must be 1. |
| 73 | + mu_vecs : numpy.ndarray |
| 74 | + vectors of real numbers |
| 75 | + lambda_mats : numpy.ndarray |
| 76 | + positive definite symetric matrices |
| 77 | + """ |
| 78 | + |
| 79 | + # [Dimension value consistency] |
| 80 | + if pi_vec is not None: |
| 81 | + if pi_vec.shape[0] != self.c_num_classes: |
| 82 | + raise(ParameterFormatError( |
| 83 | + "pi_vec.shape[0] must coincide with self.c_num_classes:" |
| 84 | + +f"pi_vec.shape[0]={pi_vec.shape[0]}, self.c_num_classes={self.c_num_classes}")) |
| 85 | + self.pi_vec = _check.float_vec_sum_1(pi_vec, "pi_vec", ParameterFormatError) |
| 86 | + if a_mat is not None: |
| 87 | + if a_mat.shape[0] != self.c_num_classes or a_mat.shape[1] != self.c_num_classes: |
| 88 | + raise(ParameterFormatError( |
| 89 | + "a_mat.shape[0] and a_mat.shape[1] must coincide with self.c_num_classes:" |
| 90 | + +f"a_mat.shape[0]={a_mat.shape[0]}, a_mat.shape[1]={a_mat.shape[1]}, self.c_num_classes={self.c_num_classes}")) |
| 91 | + self.a_mat = _check.float_vec_sum_1(a_mat, "a_mat", ParameterFormatError, ndim=2, sum_axis=1) |
| 92 | + if mu_vecs is not None: |
| 93 | + message = "" |
| 94 | + if mu_vecs.shape[0] != self.c_degree: |
| 95 | + message += "mu_vecs.shape[0] must coincide with self.c_degree:" |
| 96 | + +f"mu_vecs.shape[0]={mu_vecs.shape[0]}, self.c_degree={self.c_degree}" |
| 97 | + if len(mu_vecs) == 2: |
| 98 | + if mu_vecs.shape[1] != self.c_num_classes: |
| 99 | + message += "mu_vecs.shape[1] must coincide with self.c_num_classes:" |
| 100 | + +f"mu_vecs.shape[1]={mu_vecs.shape[1]}, self.c_num_classes={self.c_num_classes}" |
| 101 | + if message != "": |
| 102 | + raise(ParameterFormatError(message)) |
| 103 | + self.mu_vecs = _check.float_vecs(mu_vecs, "mu_vecs", ParameterFormatError) |
| 104 | + if len(mu_vecs) == 1: |
| 105 | + self.mu_vecs = np.broadcast_to(np.reshape(self.mu_vecs, (self.c_degree, 1)), (self.c_degree, self.c_num_classes)) |
| 106 | + if lambda_mats is not None: |
| 107 | + message = "" |
| 108 | + if lambda_mats.shape[:2] != (self.c_degree, self.c_degree) |
| 109 | + message += "lambda_mats.shape[:2] must coincide with (self.c_degree, self.c_degree):" |
| 110 | + +f"lambda_mats.shape[:2]={lambda_mats.shape[:2]}, (self.c_degree, self.c_degree)={(self.c_degree, self.c_degree)}" |
| 111 | + if len(lambda_mats) == 3: |
| 112 | + if lambda_mats.shape[2] != self.c_num_classes: |
| 113 | + message += "lambda_mats.shape[2] must coincide with self.c_num_classes:" |
| 114 | + +f"lambda_mats.shape[2]={lambda_mats.shape[2]}, self.c_num_classes={self.c_num_classes}" |
| 115 | + if message != "": |
| 116 | + raise(ParameterFormatError(message)) |
| 117 | + self.lambda_mats = _check.float_vecs(lambda_mats, "lambda_mats", ParameterFormatError) |
| 118 | + if len(self.lambda_mats) == 2: |
| 119 | + self.lambda_mats = np.broadcast_to( |
| 120 | + np.reshape(self.lambda_mats, (self.c_degree,self.c_degree,1)), |
| 121 | + (self.c_degree,self.c_degree,self.c_num_classes) |
| 122 | + ) |
76 | 123 |
|
77 | 124 | def set_h_params(
|
78 | 125 | self,
|
|
0 commit comments