@@ -25,12 +25,12 @@ def __init__(
25
25
a_mat = None ,
26
26
mu_vecs = None ,
27
27
lambda_mats = None ,
28
+ h_eta_vec = None ,
29
+ h_zeta_vecs = None ,
28
30
h_m_vecs = None ,
29
31
h_kappas = None ,
30
32
h_nus = None ,
31
33
h_w_mats = None ,
32
- h_eta_vec = None ,
33
- h_zeta_vecs = None ,
34
34
seed = None
35
35
):
36
36
# constants
@@ -52,8 +52,19 @@ def __init__(
52
52
self .h_nus = np .ones (self .c_num_classes ) * self .c_degree
53
53
self .h_w_mats = np .tile (np .identity (self .c_degree ),[self .c_num_classes ,1 ,1 ])
54
54
55
- self .set_params (pi_vec ,a_mat ,mu_vecs ,lambda_mats )
56
- self .set_h_params (h_eta_vec ,h_zeta_vecs ,h_m_vecs ,h_kappas ,h_nus ,h_w_mats )
55
+ self .set_params (
56
+ pi_vec ,
57
+ a_mat ,
58
+ mu_vecs ,
59
+ lambda_mats )
60
+
61
+ self .set_h_params (
62
+ h_eta_vec ,
63
+ h_zeta_vecs ,
64
+ h_m_vecs ,
65
+ h_kappas ,
66
+ h_nus ,
67
+ h_w_mats )
57
68
58
69
def set_params (
59
70
self ,
@@ -160,14 +171,18 @@ def set_h_params(
160
171
self .h_w_mats [:] = h_w_mats
161
172
162
173
def get_params (self ):
163
- # paramsを辞書として返す関数.
164
- # 要素の順番はset_paramsの引数の順にそろえる.
165
- pass
174
+ return {'pi_vec' :self .pi_vec ,
175
+ 'a_mat' :self .a_mat ,
176
+ 'mu_vecs' :self .mu_vecs ,
177
+ 'lambda_mats' : self .lambda_mats }
166
178
167
179
def get_h_params (self ):
168
- # h_paramsを辞書として返す関数.
169
- # 要素の順番はset_h_paramsの引数の順にそろえる.
170
- pass
180
+ return {'h_eta_vec' :self .h_eta_vec ,
181
+ 'h_zeta_vecs' :self .h_zeta_vecs ,
182
+ 'h_m_vecs' :self .h_m_vecs ,
183
+ 'h_kappas' :self .h_kappas ,
184
+ 'h_nus' :self .h_nus ,
185
+ 'h_w_mats' :self .h_w_mats }
171
186
172
187
# まだ実装しなくてよい
173
188
def gen_params (self ):
0 commit comments