@@ -82,32 +82,32 @@ def set_params(
82
82
tmp_mu_vecs = None if mu_vecs is None else _check .float_vecs (mu_vecs , "mu_vecs" , ParameterFormatError )
83
83
tmp_lambda_mats = None if lambda_mats is None else _check .float_vecs (lambda_mats , "lambda_mats" , ParameterFormatError )
84
84
85
- # [Dimension value consistency]
85
+ # [Dimension consistency]
86
86
if tmp_pi_vec is not None :
87
87
_check .shape_consistency (
88
- val = pi_vec , val_name = "pi_vec" ,
88
+ val = tmp_pi_vec , val_name = "pi_vec" ,
89
89
correct = [(self .c_num_classes ,)], correct_name = "(self.c_num_classes,)" ,
90
90
exception_class = ParameterFormatError
91
91
)
92
92
self .pi_vec [:] = tmp_pi_vec
93
93
if tmp_a_mat is not None :
94
94
_check .shape_consistency (
95
- val = a_mat . shape , val_name = "a_mat" ,
95
+ val = tmp_a_mat , val_name = "a_mat" ,
96
96
correct = [(self .c_num_classes , self .c_num_classes )], correct_name = "(self.c_num_classes, self.c_num_classes)" ,
97
97
exception_class = ParameterFormatError
98
98
)
99
99
self .a_mat [:] = tmp_a_mat
100
100
if tmp_mu_vecs is not None :
101
101
_check .shape_consistency (
102
- val = mu_vecs . shape , val_name = "mu_vecs" ,
102
+ val = tmp_mu_vecs , val_name = "mu_vecs" ,
103
103
correct = [(self .c_degree ,), (self .c_num_classes , self .c_degree )],
104
104
correct_name = "(self.c_degree,), (self.c_num_classes, self.c_degree)" ,
105
105
exception_class = ParameterFormatError
106
106
)
107
107
self .mu_vecs [:] = tmp_mu_vecs
108
108
if tmp_lambda_mats is not None :
109
109
_check .shape_consistency (
110
- val = lambda_mats . shape , val_name = "lambda_mats" ,
110
+ val = tmp_lambda_mats , val_name = "lambda_mats" ,
111
111
correct = [(self .c_degree ,self .c_degree ), (self .c_num_classes ,self .c_degree ,self .c_degree )],
112
112
correct_name = "(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)" ,
113
113
exception_class = ParameterFormatError
@@ -136,6 +136,49 @@ def set_h_params(
136
136
# self.h0_m_vecs[:] = h0_m_vecs
137
137
pass
138
138
139
+ # [Check values]
140
+ tmp_h_m_vecs = None if h_m_vecs is None else _check .float_vecs (h_m_vecs , "h_m_vecs" , ParameterFormatError )
141
+ tmp_h_kappas = None if h_kappas is None else np .array (_check .pos_floats (h_kappas , "h_kappas" , ParameterFormatError ))
142
+ tmp_h_nus = None if h_nus is None else np .array (_check .floats (h_nus , "h_nus" , ParameterFormatError ))
143
+ tmp_h_w_mats = None if h_w_mats is None else _check .pos_float_vecs (h_w_mats , "h_w_mats" , ParameterFormatError )
144
+ tmp_h_eta_vec = None if h_eta_vec is None else _check .pos_float_vec (h_eta_vec , "h_eta_vec" , ParameterFormatError )
145
+ tmp_h_zeta_vecs = None if h_zeta_vecs is None else _check .pos_float_vecs (h_zeta_vecs , "h_zeta_vecs" , ParameterFormatError )
146
+
147
+ # [Dimension consistency]
148
+ if tmp_h_m_vecs is not None :
149
+ _check .shape_consistency (
150
+ val = tmp_h_m_vecs , val_name = "h_m_vecs" ,
151
+ correct = [(self .c_num_classes ,self .c_degree ), (self .c_degree ,)],
152
+ correct_name = "(self.c_num_classes,self.c_degree), (self.c_degree,)" ,
153
+ exception_class = ParameterFormatError
154
+ )
155
+ self .h_m_vecs [:] = tmp_h_m_vecs
156
+ if tmp_h_kappas is not None :
157
+ _check .shape_consistency (
158
+ val = tmp_h_kappas , val_name = "h_kappas" ,
159
+ correct = [(self .c_num_classes ,), ()], correct_name = "(self.c_num_classes,), ()" ,
160
+ exception_class = ParameterFormatError
161
+ )
162
+ self .h_kappas [:] = tmp_h_kappas
163
+ if tmp_h_nus is not None :
164
+ if not np .all (tmp_h_nus > self .c_degree - 1 ):
165
+ raise (ParameterFormatError ("The all values in h_nus must be greater than self.c_degree." ))
166
+ _check .shape_consistency (
167
+ val = tmp_h_nus , val_name = "h_nus" ,
168
+ correct = [(self .c_num_classes ,), ()],
169
+ correct_name = "(self.c_num_classes,), ()" ,
170
+ exception_class = ParameterFormatError
171
+ )
172
+ self .h_nus [:] = tmp_h_nus
173
+ if tmp_h_w_mats is not None :
174
+ _check .shape_consistency (
175
+ val = tmp_h_w_mats .shape , val_name = "h_w_mats" ,
176
+ correct = [(self .c_num_classes ,self .c_degree ,self .c_degree ), (self .c_degree ,self .c_degree )],
177
+ correct_name = "(self.c_num_classes,self.c_degree,self.c_degree), (self.c_degree,self.c_degree)" ,
178
+ exception_class = ParameterFormatError
179
+ )
180
+ self .h_w_mats [:] = tmp_h_w_mats
181
+
139
182
def get_params (self ):
140
183
# paramsを辞書として返す関数.
141
184
# 要素の順番はset_paramsの引数の順にそろえる.
0 commit comments