Skip to content

Commit 99f97ab

Browse files
Fix get loss in deep
1 parent beee556 commit 99f97ab

File tree

2 files changed

+3
-38
lines changed

2 files changed

+3
-38
lines changed

adapt/feature_based/_cdan.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,7 @@ def get_metrics(self, inputs_ys, inputs_yt,
343343
return metrics
344344

345345

346-
def _build(self, shape_Xs, shape_ys,
347-
shape_Xt, shape_yt):
348-
346+
def _initialize_networks(self, shape_Xt):
349347
# Call predict to avoid strange behaviour with
350348
# Sequential model whith unspecified input_shape
351349
zeros_enc_ = self.encoder_.predict(np.zeros((1,) + shape_Xt));
@@ -357,39 +355,6 @@ def _build(self, shape_Xs, shape_ys,
357355
np.expand_dims(zeros_task_, 1))
358356
zeros_mapping_ = np.reshape(zeros_mapping_, (1, -1))
359357
self.discriminator_.predict(zeros_mapping_);
360-
361-
inputs_Xs = Input(shape_Xs)
362-
inputs_ys = Input(shape_ys)
363-
inputs_Xt = Input(shape_Xt)
364-
365-
if shape_yt is None:
366-
inputs_yt = None
367-
inputs = [inputs_Xs, inputs_ys, inputs_Xt]
368-
else:
369-
inputs_yt = Input(shape_yt)
370-
inputs = [inputs_Xs, inputs_ys,
371-
inputs_Xt, inputs_yt]
372-
373-
outputs = self.create_model(inputs_Xs=inputs_Xs,
374-
inputs_Xt=inputs_Xt)
375-
376-
self.model_ = Model(inputs, outputs)
377-
378-
loss = self.get_loss(inputs_ys=inputs_ys,
379-
**outputs)
380-
metrics = self.get_metrics(inputs_ys=inputs_ys,
381-
inputs_yt=inputs_yt,
382-
**outputs)
383-
384-
self.model_.add_loss(loss)
385-
for k in metrics:
386-
self.model_.add_metric(tf.reduce_mean(metrics[k]),
387-
name=k, aggregation="mean")
388-
389-
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
390-
self.model_.compile(optimizer=self.optimizer)
391-
self.history_ = {}
392-
return self
393358

394359

395360
def predict_disc(self, X):

tests/test_deep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
4545
return dict(task_s=task_s, task_t=task_t,
4646
disc_s=disc_s, disc_t=disc_t)
4747

48-
def get_loss(self, inputs_ys,
48+
def get_loss(self, inputs_ys, inputs_yt,
4949
task_s, task_t,
5050
disc_s, disc_t):
5151

@@ -132,7 +132,7 @@ def test_basedeep_metrics():
132132

133133
def test_basedeep_silent_methods():
134134
model = BaseDeepFeature()
135-
model.get_loss(0)
135+
model.get_loss(0, None)
136136
model.create_model(0, 0)
137137

138138

0 commit comments

Comments
 (0)