Skip to content

Commit beee556

Browse files
Improve deep
1 parent a0260a5 commit beee556

File tree

10 files changed

+28
-15
lines changed

10 files changed

+28
-15
lines changed

adapt/feature_based/_adda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
407407
return outputs
408408

409409

410-
def get_loss(self, inputs_ys, disc_src, disc_tgt,
410+
def get_loss(self, inputs_ys, inputs_yt, disc_src, disc_tgt,
411411
disc_tgt_nograd, task_tgt):
412412

413413
loss_disc = (-K.log(disc_src + EPS)

adapt/feature_based/_cdan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
259259
return outputs
260260

261261

262-
def get_loss(self, inputs_ys,
262+
def get_loss(self, inputs_ys, inputs_yt,
263263
task_src, task_tgt,
264264
disc_src, disc_tgt,
265265
task_src_nograd,

adapt/feature_based/_dann.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
225225
return outputs
226226

227227

228-
def get_loss(self, inputs_ys,
228+
def get_loss(self, inputs_ys, inputs_yt,
229229
task_src, task_tgt,
230230
disc_src, disc_tgt):
231231

adapt/feature_based/_deep.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,17 @@ def __init__(self,
244244
self.random_state = random_state
245245

246246

247+
def _initialize_networks(self, shape_Xt):
248+
zeros_enc_ = self.encoder_.predict(np.zeros((1,) + shape_Xt));
249+
self.task_.predict(zeros_enc_);
250+
self.discriminator_.predict(zeros_enc_);
251+
252+
247253
def _build(self, shape_Xs, shape_ys,
248254
shape_Xt, shape_yt):
249-
250255
# Call predict to avoid strange behaviour with
251256
# Sequential model whith unspecified input_shape
252-
zeros_enc_ = self.encoder_.predict(np.zeros((1,) + shape_Xt));
253-
self.task_.predict(zeros_enc_);
254-
self.discriminator_.predict(zeros_enc_);
257+
self._initialize_networks(shape_Xt)
255258

256259
inputs_Xs = Input(shape_Xs)
257260
inputs_ys = Input(shape_ys)
@@ -271,6 +274,7 @@ def _build(self, shape_Xs, shape_ys,
271274
self.model_ = Model(inputs, outputs)
272275

273276
loss = self.get_loss(inputs_ys=inputs_ys,
277+
inputs_yt=inputs_yt,
274278
**outputs)
275279
metrics = self.get_metrics(inputs_ys=inputs_ys,
276280
inputs_yt=inputs_yt,
@@ -390,14 +394,17 @@ def create_model(self, inputs_Xs, inputs_Xt):
390394
pass
391395

392396

393-
def get_loss(self, inputs_ys, **ouputs):
397+
def get_loss(self, inputs_ys, inputs_yt, **ouputs):
394398
"""
395399
Get loss.
396400
397401
Parameters
398402
----------
399403
inputs_ys : InputLayer
400404
Input layer for ys entries.
405+
406+
inputs_yt : InputLayer
407+
Input layer for yt entries.
401408
402409
outputs : dict of tf Tensors
403410
Model outputs tensors.

adapt/feature_based/_deepcoral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
209209
return outputs
210210

211211

212-
def get_loss(self, inputs_ys,
212+
def get_loss(self, inputs_ys, inputs_yt,
213213
task_src, task_tgt,
214214
cov_src, cov_tgt):
215215

adapt/feature_based/_mcd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
238238
return outputs
239239

240240

241-
def get_loss(self, inputs_ys, task_src,
241+
def get_loss(self, inputs_ys, inputs_yt, task_src,
242242
task_tgt, task_sec_src, task_sec_tgt):
243243

244244
loss_task = 0.5 * (self.loss_(inputs_ys, task_src) + self.loss_(inputs_ys, task_sec_src))

adapt/feature_based/_mdd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
171171
return outputs
172172

173173

174-
def get_loss(self, inputs_ys, task_src,
174+
def get_loss(self, inputs_ys, inputs_yt, task_src,
175175
task_src_nograd, task_tgt_nograd,
176176
task_tgt, disc_src, disc_tgt):
177177

adapt/feature_based/_wdgrl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
225225
return outputs
226226

227227

228-
def get_loss(self, inputs_ys,
228+
def get_loss(self, inputs_ys, inputs_yt,
229229
task_src, task_tgt,
230230
disc_src, disc_tgt,
231231
disc_grad):

adapt/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,17 @@ def get_default_encoder():
306306
return model
307307

308308

309-
def get_default_task():
309+
def get_default_task(activation=None):
310310
"""
311311
Return a tensorflow Model of two hidden layers
312312
with 10 neurons each and relu activations. The
313313
last layer is composed of one neuron with linear
314314
activation.
315+
316+
Parameters
317+
----------
318+
activation : str (default=None)
319+
Final activation
315320
316321
Returns
317322
-------
@@ -321,7 +326,7 @@ def get_default_task():
321326
model.add(Flatten())
322327
model.add(Dense(10, activation="relu"))
323328
model.add(Dense(10, activation="relu"))
324-
model.add(Dense(1, activation=None))
329+
model.add(Dense(1, activation=activation))
325330
return model
326331

327332

tests/test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def is_equal_estimator(v1, v2):
5050
elif isinstance(v1, Tree):
5151
pass # TODO create a function to check if two tree are equal
5252
else:
53-
assert v1 == v2
53+
if not "input" in str(v1):
54+
assert v1 == v2
5455
return True
5556

5657

0 commit comments

Comments
 (0)