@@ -343,9 +343,7 @@ def get_metrics(self, inputs_ys, inputs_yt,
343
343
return metrics
344
344
345
345
346
- def _build (self , shape_Xs , shape_ys ,
347
- shape_Xt , shape_yt ):
348
-
346
+ def _initialize_networks (self , shape_Xt ):
349
347
# Call predict to avoid strange behaviour with
350
348
# Sequential model whith unspecified input_shape
351
349
zeros_enc_ = self .encoder_ .predict (np .zeros ((1 ,) + shape_Xt ));
@@ -357,39 +355,6 @@ def _build(self, shape_Xs, shape_ys,
357
355
np .expand_dims (zeros_task_ , 1 ))
358
356
zeros_mapping_ = np .reshape (zeros_mapping_ , (1 , - 1 ))
359
357
self .discriminator_ .predict (zeros_mapping_ );
360
-
361
- inputs_Xs = Input (shape_Xs )
362
- inputs_ys = Input (shape_ys )
363
- inputs_Xt = Input (shape_Xt )
364
-
365
- if shape_yt is None :
366
- inputs_yt = None
367
- inputs = [inputs_Xs , inputs_ys , inputs_Xt ]
368
- else :
369
- inputs_yt = Input (shape_yt )
370
- inputs = [inputs_Xs , inputs_ys ,
371
- inputs_Xt , inputs_yt ]
372
-
373
- outputs = self .create_model (inputs_Xs = inputs_Xs ,
374
- inputs_Xt = inputs_Xt )
375
-
376
- self .model_ = Model (inputs , outputs )
377
-
378
- loss = self .get_loss (inputs_ys = inputs_ys ,
379
- ** outputs )
380
- metrics = self .get_metrics (inputs_ys = inputs_ys ,
381
- inputs_yt = inputs_yt ,
382
- ** outputs )
383
-
384
- self .model_ .add_loss (loss )
385
- for k in metrics :
386
- self .model_ .add_metric (tf .reduce_mean (metrics [k ]),
387
- name = k , aggregation = "mean" )
388
-
389
- tf .compat .v1 .logging .set_verbosity (tf .compat .v1 .logging .ERROR )
390
- self .model_ .compile (optimizer = self .optimizer )
391
- self .history_ = {}
392
- return self
393
358
394
359
395
360
def predict_disc (self , X ):
0 commit comments