@@ -45,15 +45,15 @@ def __init__(
45
45
self .lambda_mats = np .tile (np .identity (self .c_degree ),[self .c_num_classes ,1 ,1 ])
46
46
47
47
# h_params
48
+ self .h_eta_vec = np .ones (self .c_num_classes ) / 2.0
49
+ self .h_zeta_vecs = np .ones ([self .c_num_classes ,self .c_num_classes ]) / 2.0
48
50
self .h_m_vecs = np .zeros ([self .c_num_classes ,self .c_degree ])
49
51
self .h_kappas = np .ones ([self .c_num_classes ])
50
52
self .h_nus = np .ones (self .c_num_classes ) * self .c_degree
51
53
self .h_w_mats = np .tile (np .identity (self .c_degree ),[self .c_num_classes ,1 ,1 ])
52
- self .h_eta_vec = np .ones (self .c_num_classes ) / 2.0
53
- self .h_zeta_vecs = np .ones ([self .c_num_classes ,self .c_num_classes ]) / 2.0
54
54
55
55
self .set_params (pi_vec ,a_mat ,mu_vecs ,lambda_mats )
56
- self .set_h_params (h_m_vecs ,h_kappas ,h_nus ,h_w_mats , h_eta_vec , h_zeta_vecs )
56
+ self .set_h_params (h_eta_vec , h_zeta_vecs , h_m_vecs ,h_kappas ,h_nus ,h_w_mats )
57
57
58
58
def set_params (
59
59
self ,
@@ -69,115 +69,95 @@ def set_params(
69
69
pi_vec : numpy.ndarray
70
70
a real vector in :math:`[0, 1]^K`. The sum of its elements must be 1.
71
71
a_mat : numpy.ndarray
72
- a real matrix in :math:`[0, 1]^{KxK}`. The sum of each column elements must be 1.
72
+ a real matrix in :math:`[0, 1]^{KxK}`. The sum of each row elements must be 1.
73
73
mu_vecs : numpy.ndarray
74
74
vectors of real numbers
75
75
lambda_mats : numpy.ndarray
76
76
positive definite symetric matrices
77
77
"""
78
-
79
- # [Check values]
80
- tmp_pi_vec = None if pi_vec is None else _check .float_vec_sum_1 (pi_vec , "pi_vec" , ParameterFormatError )
81
- tmp_a_mat = None if a_mat is None else _check .float_vec_sum_1 (a_mat , "a_mat" , ParameterFormatError , ndim = 2 , sum_axis = 1 )
82
- tmp_mu_vecs = None if mu_vecs is None else _check .float_vecs (mu_vecs , "mu_vecs" , ParameterFormatError )
83
- tmp_lambda_mats = None if lambda_mats is None else _check .float_vecs (lambda_mats , "lambda_mats" , ParameterFormatError )
84
-
85
- # [Dimension consistency]
86
- if tmp_pi_vec is not None :
78
+ if pi_vec is not None :
79
+ _check .float_vec_sum_1 (pi_vec , "pi_vec" , ParameterFormatError )
87
80
_check .shape_consistency (
88
- val = tmp_pi_vec , val_name = "pi_vec" ,
89
- correct = [(self .c_num_classes ,)], correct_name = "(self.c_num_classes,)" ,
90
- exception_class = ParameterFormatError
91
- )
92
- self .pi_vec [:] = tmp_pi_vec
93
- if tmp_a_mat is not None :
81
+ pi_vec .shape [0 ],"pi_vec.shape[0]" ,
82
+ self .c_num_classes ,"self.c_num_classes" ,
83
+ ParameterFormatError
84
+ )
85
+ self .pi_vec [:] = pi_vec
86
+
87
+ if a_mat is not None :
88
+ _check .float_vecs_sum_1 (a_mat , "a_mat" , ParameterFormatError )
94
89
_check .shape_consistency (
95
- val = tmp_a_mat , val_name = "a_mat" ,
96
- correct = [(self .c_num_classes , self .c_num_classes )], correct_name = "(self.c_num_classes, self.c_num_classes)" ,
97
- exception_class = ParameterFormatError
98
- )
99
- self .a_mat [:] = tmp_a_mat
100
- if tmp_mu_vecs is not None :
90
+ a_mat .shape [- 1 ], "a_mat.shape[-1]" ,
91
+ self .c_num_classes , "self.c_num_classes" ,
92
+ ParameterFormatError
93
+ )
94
+ self .a_mat [:] = a_mat
95
+
96
+ if mu_vecs is not None :
97
+ _check .float_vecs (mu_vecs , "mu_vecs" , ParameterFormatError )
101
98
_check .shape_consistency (
102
- val = tmp_mu_vecs , val_name = "mu_vecs" ,
103
- correct = [(self .c_degree ,), (self .c_num_classes , self .c_degree )],
104
- correct_name = "(self.c_degree,), (self.c_num_classes, self.c_degree)" ,
105
- exception_class = ParameterFormatError
106
- )
107
- self .mu_vecs [:] = tmp_mu_vecs
108
- if tmp_lambda_mats is not None :
99
+ mu_vecs .shape [- 1 ],"mu_vecs.shape[-1]" ,
100
+ self .c_degree ,"self.c_degree" ,
101
+ ParameterFormatError
102
+ )
103
+ self .mu_vecs [:] = mu_vecs
104
+
105
+ if lambda_mats is not None :
106
+ _check .pos_def_sym_mats (lambda_mats ,'lambda_mats' ,ParameterFormatError )
109
107
_check .shape_consistency (
110
- val = tmp_lambda_mats , val_name = "lambda_mats" ,
111
- correct = [(self .c_degree ,self .c_degree ), (self .c_num_classes ,self .c_degree ,self .c_degree )],
112
- correct_name = "(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)" ,
113
- exception_class = ParameterFormatError
114
- )
115
- self .lambda_mats [:] = tmp_lambda_mats
108
+ lambda_mats .shape [- 1 ],"lambda_mats.shape[-1] and lambda_mats.shape[-2]" ,
109
+ self .c_degree ,"self.c_degree" ,
110
+ ParameterFormatError
111
+ )
112
+ self .lambda_mats [:] = lambda_mats
116
113
117
114
def set_h_params (
118
115
self ,
116
+ h_eta_vec = None ,
117
+ h_zeta_vecs = None ,
119
118
h_m_vecs = None ,
120
119
h_kappas = None ,
121
120
h_nus = None ,
122
121
h_w_mats = None ,
123
- h_eta_vec = None ,
124
- h_zeta_vecs = None
125
122
):
126
- # Noneでない入力について,以下をチェックする.
127
- # * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
128
- # * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
129
- # 例
130
- # if h0_m_vecs is not None:
131
- # _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
132
- # if h0_m_vecs.shape[-1] != self.degree:
133
- # raise(ParameterFormatError(
134
- # "h0_m_vecs.shape[-1] must coincide with self.degree:"
135
- # +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
136
- # self.h0_m_vecs[:] = h0_m_vecs
137
- pass
138
123
139
- # [Check values]
140
- tmp_h_m_vecs = None if h_m_vecs is None else _check .float_vecs (h_m_vecs , "h_m_vecs" , ParameterFormatError )
141
- tmp_h_kappas = None if h_kappas is None else np .array (_check .pos_floats (h_kappas , "h_kappas" , ParameterFormatError ))
142
- tmp_h_nus = None if h_nus is None else np .array (_check .floats (h_nus , "h_nus" , ParameterFormatError ))
143
- tmp_h_w_mats = None if h_w_mats is None else _check .pos_float_vecs (h_w_mats , "h_w_mats" , ParameterFormatError )
144
- tmp_h_eta_vec = None if h_eta_vec is None else _check .pos_float_vec (h_eta_vec , "h_eta_vec" , ParameterFormatError )
145
- tmp_h_zeta_vecs = None if h_zeta_vecs is None else _check .pos_float_vecs (h_zeta_vecs , "h_zeta_vecs" , ParameterFormatError )
124
+ if h_eta_vec is not None :
125
+ _check .pos_floats (h_eta_vec ,'h_eta_vec' ,ParameterFormatError )
126
+ self .h_eta_vec [:] = h_eta_vec
146
127
147
- # [Dimension consistency]
148
- if tmp_h_m_vecs is not None :
149
- _check .shape_consistency (
150
- val = tmp_h_m_vecs , val_name = "h_m_vecs" ,
151
- correct = [(self .c_num_classes ,self .c_degree ), (self .c_degree ,)],
152
- correct_name = "(self.c_num_classes,self.c_degree), (self.c_degree,)" ,
153
- exception_class = ParameterFormatError
154
- )
155
- self .h_m_vecs [:] = tmp_h_m_vecs
156
- if tmp_h_kappas is not None :
157
- _check .shape_consistency (
158
- val = tmp_h_kappas , val_name = "h_kappas" ,
159
- correct = [(self .c_num_classes ,), ()], correct_name = "(self.c_num_classes,), ()" ,
160
- exception_class = ParameterFormatError
161
- )
162
- self .h_kappas [:] = tmp_h_kappas
163
- if tmp_h_nus is not None :
164
- if not np .all (tmp_h_nus > self .c_degree - 1 ):
165
- raise (ParameterFormatError ("The all values in h_nus must be greater than self.c_degree." ))
128
+ if h_zeta_vecs is not None :
129
+ _check .pos_floats (h_zeta_vecs , 'h_zeta_vecs' , ParameterFormatError )
130
+ self .h_zeta_vecs [:] = h_zeta_vecs
131
+
132
+ if h_m_vecs is not None :
133
+ _check .float_vecs (h_m_vecs , "h_m_vecs" , ParameterFormatError )
166
134
_check .shape_consistency (
167
- val = tmp_h_nus , val_name = "h_nus" ,
168
- correct = [(self .c_num_classes ,), ()],
169
- correct_name = "(self.c_num_classes,), ()" ,
170
- exception_class = ParameterFormatError
171
- )
172
- self .h_nus [:] = tmp_h_nus
173
- if tmp_h_w_mats is not None :
135
+ h_m_vecs .shape [- 1 ],"h_m_vecs.shape[-1]" ,
136
+ self .c_degree ,"self.c_degree" ,
137
+ ParameterFormatError
138
+ )
139
+ self .h_m_vecs [:] = h_m_vecs
140
+
141
+ if h_kappas is not None :
142
+ _check .pos_floats (h_kappas , "h_kappas" , ParameterFormatError )
143
+ self .h_kappas [:] = h_kappas
144
+
145
+ if h_nus is not None :
146
+ _check .floats (h_nus , "h_nus" , ParameterFormatError )
147
+ if np .all (h_nus <= self .c_degree - 1 ):
148
+ raise (ParameterFormatError (
149
+ "All the values in h_nus must be greater than self.c_degree - 1: "
150
+ + f"self.c_degree = { self .c_degree } , h_nus = { h_nus } " ))
151
+ self .h_nus [:] = h_nus
152
+
153
+ if h_w_mats is not None :
154
+ _check .pos_def_sym_mats (h_w_mats ,'h_w_mats' ,ParameterFormatError )
174
155
_check .shape_consistency (
175
- val = tmp_h_w_mats .shape , val_name = "h_w_mats" ,
176
- correct = [(self .c_num_classes ,self .c_degree ,self .c_degree ), (self .c_degree ,self .c_degree )],
177
- correct_name = "(self.c_num_classes,self.c_degree,self.c_degree), (self.c_degree,self.c_degree)" ,
178
- exception_class = ParameterFormatError
179
- )
180
- self .h_w_mats [:] = tmp_h_w_mats
156
+ h_w_mats .shape [- 1 ],"h_w_mats.shape[-1] and h_w_mats.shape[-2]" ,
157
+ self .c_degree ,"self.c_degree" ,
158
+ ParameterFormatError
159
+ )
160
+ self .h_w_mats [:] = h_w_mats
181
161
182
162
def get_params (self ):
183
163
# paramsを辞書として返す関数.
0 commit comments