@@ -98,8 +98,8 @@ def set_params(
98
98
if a_mat is not None :
99
99
_check .float_vecs_sum_1 (a_mat , "a_mat" , ParameterFormatError )
100
100
_check .shape_consistency (
101
- a_mat .shape [- 1 ], "a_mat.shape[-1]" ,
102
- self .c_num_classes , "self.c_num_classes" ,
101
+ a_mat .shape [- 1 ],"a_mat.shape[-1]" ,
102
+ self .c_num_classes ,"self.c_num_classes" ,
103
103
ParameterFormatError
104
104
)
105
105
self .a_mat [:] = a_mat
@@ -184,19 +184,15 @@ def get_h_params(self):
184
184
'h_nus' :self .h_nus ,
185
185
'h_w_mats' :self .h_w_mats }
186
186
187
- # まだ実装しなくてよい
188
187
def gen_params (self ):
189
188
pass
190
189
191
- # まだ実装しなくてよい
192
190
def gen_sample (self ):
193
191
pass
194
192
195
- # まだ実装しなくてよい
196
193
def save_sample (self ):
197
194
pass
198
195
199
- # まだ実装しなくてよい
200
196
def visualize_model (self ):
201
197
pass
202
198
@@ -207,12 +203,12 @@ def __init__(
207
203
c_num_classes ,
208
204
c_degree ,
209
205
* ,
206
+ h0_eta_vec = None ,
207
+ h0_zeta_vecs = None ,
210
208
h0_m_vecs = None ,
211
209
h0_kappas = None ,
212
210
h0_nus = None ,
213
211
h0_w_mats = None ,
214
- h0_eta_vec = None ,
215
- h0_zeta_vecs = None ,
216
212
seed = None
217
213
):
218
214
# constants
@@ -221,22 +217,22 @@ def __init__(
221
217
self .rng = np .random .default_rng (seed )
222
218
223
219
# h0_params
220
+ self .h0_eta_vec = np .ones (self .c_num_classes ) / 2.0
221
+ self .h0_zeta_vecs = np .ones ([self .c_num_classes ,self .c_num_classes ]) / 2.0
224
222
self .h0_m_vecs = np .zeros ([self .c_num_classes ,self .c_degree ])
225
223
self .h0_kappas = np .ones ([self .c_num_classes ])
226
224
self .h0_nus = np .ones (self .c_num_classes ) * self .c_degree
227
225
self .h0_w_mats = np .tile (np .identity (self .c_degree ),[self .c_num_classes ,1 ,1 ])
228
226
self .h0_w_mats_inv = np .linalg .inv (self .h0_w_mats )
229
- self .h0_eta_vec = np .ones (self .c_num_classes ) / 2.0
230
- self .h0_zeta_vecs = np .ones ([self .c_num_classes ,self .c_num_classes ]) / 2.0
231
227
232
228
# hn_params
229
+ self .hn_eta_vec = np .empty (self .c_num_classes )
230
+ self .hn_zeta_vecs = np .empty ([self .c_num_classes ,self .c_num_classes ])
233
231
self .hn_m_vecs = np .empty ([self .c_num_classes ,self .c_degree ])
234
232
self .hn_kappas = np .empty ([self .c_num_classes ])
235
233
self .hn_nus = np .empty (self .c_num_classes )
236
234
self .hn_w_mats = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
237
235
self .hn_w_mats_inv = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
238
- self .hn_eta_vec = np .empty (self .c_num_classes )
239
- self .hn_zeta_vecs = np .empty ([self .c_num_classes ,self .c_num_classes ])
240
236
241
237
# p_params
242
238
self .p_mu_vecs = np .empty ([self .c_num_classes ,self .c_degree ])
@@ -245,118 +241,158 @@ def __init__(
245
241
self .p_lambda_mats_inv = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
246
242
247
243
self .set_h0_params (
244
+ h0_eta_vec ,
245
+ h0_zeta_vecs ,
248
246
h0_m_vecs ,
249
247
h0_kappas ,
250
248
h0_nus ,
251
249
h0_w_mats ,
252
- h0_eta_vec ,
253
- h0_zeta_vecs ,
254
250
)
255
251
256
252
def set_h0_params (
257
253
self ,
254
+ h0_eta_vec = None ,
255
+ h0_zeta_vecs = None ,
258
256
h0_m_vecs = None ,
259
257
h0_kappas = None ,
260
258
h0_nus = None ,
261
259
h0_w_mats = None ,
262
- h0_eta_vec = None ,
263
- h0_zeta_vecs = None ,
264
260
):
265
- # Noneでない入力について,以下をチェックする.
266
- # * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
267
- # * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
268
- # 例
269
- # if h0_m_vecs is not None:
270
- # _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
271
- # if h0_m_vecs.shape[-1] != self.degree:
272
- # raise(ParameterFormatError(
273
- # "h0_m_vecs.shape[-1] must coincide with self.degree:"
274
- # +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
275
- # self.h0_m_vecs[:] = h0_m_vecs
276
-
277
- # 最後にreset_hn_params()を呼ぶようにする
261
+ if h0_eta_vec is not None :
262
+ _check .pos_floats (h0_eta_vec ,'h0_eta_vec' ,ParameterFormatError )
263
+ self .h0_eta_vec [:] = h0_eta_vec
264
+
265
+ if h0_zeta_vecs is not None :
266
+ _check .pos_floats (h0_zeta_vecs , 'h0_zeta_vecs' , ParameterFormatError )
267
+ self .h0_zeta_vecs [:] = h0_zeta_vecs
268
+
269
+ if h0_m_vecs is not None :
270
+ _check .float_vecs (h0_m_vecs , "h0_m_vecs" , ParameterFormatError )
271
+ _check .shape_consistency (
272
+ h0_m_vecs .shape [- 1 ],"h0_m_vecs.shape[-1]" ,
273
+ self .c_degree ,"self.c_degree" ,
274
+ ParameterFormatError
275
+ )
276
+ self .h0_m_vecs [:] = h0_m_vecs
277
+
278
+ if h0_kappas is not None :
279
+ _check .pos_floats (h0_kappas , "h0_kappas" , ParameterFormatError )
280
+ self .h0_kappas [:] = h0_kappas
281
+
282
+ if h0_nus is not None :
283
+ _check .floats (h0_nus , "h0_nus" , ParameterFormatError )
284
+ if np .all (h0_nus <= self .c_degree - 1 ):
285
+ raise (ParameterFormatError (
286
+ "All the values in h0_nus must be greater than self.c_degree - 1: "
287
+ + f"self.c_degree = { self .c_degree } , h0_nus = { h0_nus } " ))
288
+ self .h0_nus [:] = h0_nus
289
+
290
+ if h0_w_mats is not None :
291
+ _check .pos_def_sym_mats (h0_w_mats ,'h0_w_mats' ,ParameterFormatError )
292
+ _check .shape_consistency (
293
+ h0_w_mats .shape [- 1 ],"h0_w_mats.shape[-1] and h0_w_mats.shape[-2]" ,
294
+ self .c_degree ,"self.c_degree" ,
295
+ ParameterFormatError
296
+ )
297
+ self .h0_w_mats [:] = h0_w_mats
298
+
299
+ self .h0_w_mats_inv [:] = np .linalg .inv (self .h0_w_mats )
300
+
278
301
self .reset_hn_params ()
279
302
280
303
def get_h0_params (self ):
281
- # h0_paramsを辞書として返す関数.
282
- # 要素の順番はset_h_paramsの引数の順にそろえる.
283
- pass
304
+ return {'h0_eta_vec' :self .h0_eta_vec ,
305
+ 'h0_zeta_vecs' :self .h0_zeta_vecs ,
306
+ 'h0_m_vecs' :self .h0_m_vecs ,
307
+ 'h0_kappas' :self .h0_kappas ,
308
+ 'h0_nus' :self .h0_nus ,
309
+ 'h0_w_mats' :self .h0_w_mats }
284
310
285
311
def set_hn_params (
286
312
self ,
313
+ hn_eta_vec = None ,
314
+ hn_zeta_vecs = None ,
287
315
hn_m_vecs = None ,
288
316
hn_kappas = None ,
289
317
hn_nus = None ,
290
318
hn_w_mats = None ,
291
- hn_eta_vec = None ,
292
- hn_zeta_vecs = None ,
293
319
):
294
- # Noneでない入力について,以下をチェックする.
295
- # * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
296
- # * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
297
- # 例
298
- # if h0_m_vecs is not None:
299
- # _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
300
- # if h0_m_vecs.shape[-1] != self.degree:
301
- # raise(ParameterFormatError(
302
- # "h0_m_vecs.shape[-1] must coincide with self.degree:"
303
- # +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
304
- # self.h0_m_vecs[:] = h0_m_vecs
305
-
306
- # 最後にcalc_pred_dist()を呼ぶようにする
320
+ if hn_eta_vec is not None :
321
+ _check .pos_floats (hn_eta_vec ,'hn_eta_vec' ,ParameterFormatError )
322
+ self .hn_eta_vec [:] = hn_eta_vec
323
+
324
+ if hn_zeta_vecs is not None :
325
+ _check .pos_floats (hn_zeta_vecs , 'hn_zeta_vecs' , ParameterFormatError )
326
+ self .hn_zeta_vecs [:] = hn_zeta_vecs
327
+
328
+ if hn_m_vecs is not None :
329
+ _check .float_vecs (hn_m_vecs , "hn_m_vecs" , ParameterFormatError )
330
+ _check .shape_consistency (
331
+ hn_m_vecs .shape [- 1 ],"hn_m_vecs.shape[-1]" ,
332
+ self .c_degree ,"self.c_degree" ,
333
+ ParameterFormatError
334
+ )
335
+ self .hn_m_vecs [:] = hn_m_vecs
336
+
337
+ if hn_kappas is not None :
338
+ _check .pos_floats (hn_kappas , "hn_kappas" , ParameterFormatError )
339
+ self .hn_kappas [:] = hn_kappas
340
+
341
+ if hn_nus is not None :
342
+ _check .floats (hn_nus , "hn_nus" , ParameterFormatError )
343
+ if np .all (hn_nus <= self .c_degree - 1 ):
344
+ raise (ParameterFormatError (
345
+ "All the values in hn_nus must be greater than self.c_degree - 1: "
346
+ + f"self.c_degree = { self .c_degree } , hn_nus = { hn_nus } " ))
347
+ self .hn_nus [:] = hn_nus
348
+
349
+ if hn_w_mats is not None :
350
+ _check .pos_def_sym_mats (hn_w_mats ,'hn_w_mats' ,ParameterFormatError )
351
+ _check .shape_consistency (
352
+ hn_w_mats .shape [- 1 ],"hn_w_mats.shape[-1] and hn_w_mats.shape[-2]" ,
353
+ self .c_degree ,"self.c_degree" ,
354
+ ParameterFormatError
355
+ )
356
+ self .hn_w_mats [:] = hn_w_mats
357
+
358
+ self .hn_w_mats_inv [:] = np .linalg .inv (self .hn_w_mats )
359
+
307
360
self .calc_pred_dist ()
308
361
309
362
def get_hn_params (self ):
310
- # hn_paramsを辞書として返す関数.
311
- # 要素の順番はset_h_paramsの引数の順にそろえる.
312
- pass
363
+ return {'hn_eta_vec' :self .hn_eta_vec ,
364
+ 'hn_zeta_vecs' :self .hn_zeta_vecs ,
365
+ 'hn_m_vecs' :self .hn_m_vecs ,
366
+ 'hn_kappas' :self .hn_kappas ,
367
+ 'hn_nus' :self .hn_nus ,
368
+ 'hn_w_mats' :self .hn_w_mats }
313
369
314
370
def reset_hn_params (self ):
315
- # h0_paramsの値をhn_paramsの値にそのままコピーする
316
- # 配列サイズを揃えてあるので,簡単に書けるはず.
317
- # 例
318
- # self.hn_alpha_vec[:] = self.h0_alpha_vec
319
- # self.hn_m_vecs[:] = self.h0_m_vecs
320
- # self.hn_kappas[:] = self.h0_kappas
321
-
322
- # 最後にcalc_pred_distを呼ぶ.
323
- self .calc_pred_dist ()
371
+ self .set_hn_params (* self .get_h0_params ().values ())
324
372
325
373
def overwrite_h0_params (self ):
326
- # hn_paramsの値をh0_paramsの値にそのままコピーする
327
- # 配列サイズを揃えてあるので,簡単に書けるはず.
328
- # 例
329
- # self.h0_alpha_vec[:] = self.hn_alpha_vec
330
- # self.h0_m_vecs[:] = self.hn_m_vecs
331
- # self.h0_kappas[:] = self.hn_kappas
332
-
333
- # 最後にcalc_pred_distを呼ぶ.
334
- self .calc_pred_dist ()
374
+ self .set_h0_params (* self .get_hn_params ().values ())
335
375
336
- # まだ実装しなくてよい
337
376
def update_posterior ():
338
377
pass
339
378
340
- # まだ実装しなくてよい
341
379
def estimate_params (self ,loss = "squared" ):
342
380
pass
343
381
344
- # まだ実装しなくてよい
345
382
def visualize_posterior (self ):
346
383
pass
347
384
348
385
def get_p_params (self ):
349
- # p_paramsを辞書として返す関数.
350
- pass
386
+ return {'p_mu_vecs' :self .p_mu_vecs ,
387
+ 'p_nus' :self .p_nus ,
388
+ 'p_lambda_mats' :self .p_lambda_mats ,
389
+ 'p_lambda_mats_inv' :self .p_lambda_mats_inv }
351
390
352
- # まだ実装しなくてよい
353
391
def calc_pred_dist (self ):
354
392
pass
355
393
356
- # まだ実装しなくてよい
357
394
def make_prediction (self ,loss = "squared" ):
358
395
pass
359
396
360
- # まだ実装しなくてよい
361
397
def pred_and_update (self ,x ,loss = "squared" ):
362
398
pass
0 commit comments