Commit ed1b107

Change MDD loss + Add MCD multiple steps encoder
1 parent 31cc3fb commit ed1b107

2 files changed: 25 additions & 5 deletions

adapt/feature_based/_mcd.py

Lines changed: 20 additions & 1 deletion
@@ -141,6 +141,26 @@ def train_step(self, data):
         Xs = Xs[0]
         ys = ys[0]
 
+
+        for _ in range(4):
+            with tf.GradientTape() as enc_tape:
+                Xt_enc = self.encoder_(Xt, training=True)
+                yt_pred = self.task_(Xt_enc, training=True)
+                yt_disc = self.discriminator_(Xt_enc, training=True)
+
+                # Reshape
+                yt_pred = tf.reshape(yt_pred, tf.shape(ys))
+                yt_disc = tf.reshape(yt_disc, tf.shape(ys))
+
+                discrepancy = tf.reduce_mean(tf.abs(yt_pred - yt_disc))
+                enc_loss = discrepancy
+                enc_loss += sum(self.encoder_.losses)
+
+            # Compute gradients
+            trainable_vars_enc = self.encoder_.trainable_variables
+            gradients_enc = enc_tape.gradient(enc_loss, trainable_vars_enc)
+            self.optimizer.apply_gradients(zip(gradients_enc, trainable_vars_enc))
+
         # loss
         with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
             # Forward pass
@@ -174,7 +194,6 @@ def train_step(self, data):
             disc_loss += sum(self.discriminator_.losses)
             enc_loss += sum(self.encoder_.losses)
 
-
         # Compute gradients
         trainable_vars_task = self.task_.trainable_variables
         trainable_vars_enc = self.encoder_.trainable_variables
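The added loop performs several encoder-only updates per training step: the encoder is trained to minimize the L1 discrepancy between the two classifier heads (`task_` and `discriminator_`) on the target batch, which corresponds to step C of the MCD training scheme (Saito et al., 2018) repeated four times. A minimal standalone sketch of that update, assuming generic Keras models and a `tf.keras` optimizer (the function name and its arguments are illustrative, not part of the adapt API):

```python
import tensorflow as tf

def mcd_encoder_steps(encoder, task, discriminator, optimizer, Xt, n_steps=4):
    """Repeat MCD's encoder update: move target features toward a
    region where the two classifier heads agree (small L1 discrepancy)."""
    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            Xt_enc = encoder(Xt, training=True)
            yt_pred = task(Xt_enc, training=True)
            yt_disc = discriminator(Xt_enc, training=True)
            # L1 discrepancy between the two heads on the same features
            discrepancy = tf.reduce_mean(tf.abs(yt_pred - yt_disc))
        # Only the encoder's weights are updated in this phase
        grads = tape.gradient(discrepancy, encoder.trainable_variables)
        optimizer.apply_gradients(zip(grads, encoder.trainable_variables))
```

The repeat count plays the role of the MCD paper's hyperparameter for the number of generator updates per step; the commit hardcodes it to 4.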

adapt/feature_based/_mdd.py

Lines changed: 5 additions & 4 deletions
@@ -25,7 +25,7 @@ class MDD(BaseAdaptDeep):
 
     Parameters
     ----------
-    lambda_ : float (default=1.)
+    lambda_ : float (default=0.1)
         Trade-off parameter
 
     gamma : float (default=4.)
@@ -71,7 +71,7 @@ def __init__(self,
                  task=None,
                  Xt=None,
                  yt=None,
-                 lambda_=1.,
+                 lambda_=0.1,
                  gamma=4.,
                  copy=True,
                  verbose=1,
@@ -124,8 +124,8 @@ def train_step(self, data):
                                         tf.shape(ys_pred)[1])
                 argmax_tgt = tf.one_hot(tf.math.argmax(yt_pred, -1),
                                         tf.shape(yt_pred)[1])
-                disc_loss_src = self.task_loss_(argmax_src, ys_disc)
-                disc_loss_tgt = self.task_loss_(argmax_tgt, yt_disc)
+                disc_loss_src = -tf.math.log(tf.reduce_sum(argmax_src * ys_disc, 1) + EPS)
+                disc_loss_tgt = tf.math.log(1. - tf.reduce_sum(argmax_tgt * yt_disc, 1) + EPS)
             else:
                 disc_loss_src = self.task_loss_(ys_pred, ys_disc)
                 disc_loss_tgt = self.task_loss_(yt_pred, yt_disc)
@@ -168,6 +168,7 @@ def train_step(self, data):
         logs = {m.name: m.result() for m in self.metrics}
         # disc_metrics = self._get_disc_metrics(ys_disc, yt_disc)
         logs.update({"disc_loss": disc_loss})
+        logs.update({"disc_src": disc_loss_src, "disc_tgt": disc_loss_tgt})
         return logs
 
 
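For classification, the discriminator terms now use the explicit logarithmic form of the margin disparity discrepancy (Zhang et al., 2019) instead of calling `task_loss_` on one-hot pseudo-labels: a cross-entropy toward the task head's argmax on source data, and a `log(1 - p)` term on the pseudo-label class on target data. A standalone sketch of just these two terms, assuming softmax inputs; the value of `EPS` is an assumption, and how the surrounding `train_step` weighs the terms with `gamma` is outside this hunk:

```python
import tensorflow as tf

EPS = 1e-7  # numerical-stability constant; the library's actual value may differ

def mdd_disparity_terms(ys_pred, yt_pred, ys_disc, yt_disc):
    """Per-sample disparity between the task head (ys_pred, yt_pred) and
    the adversarial head (ys_disc, yt_disc); all inputs are softmax
    probabilities of shape (batch, n_classes)."""
    # Pseudo-labels: one-hot argmax of the task head's predictions
    argmax_src = tf.one_hot(tf.math.argmax(ys_pred, -1), tf.shape(ys_pred)[1])
    argmax_tgt = tf.one_hot(tf.math.argmax(yt_pred, -1), tf.shape(yt_pred)[1])
    # Source: -log of the adversarial head's probability on the pseudo-label,
    # i.e. cross-entropy toward the task head's prediction
    disc_loss_src = -tf.math.log(tf.reduce_sum(argmax_src * ys_disc, 1) + EPS)
    # Target: log(1 - p) on the pseudo-label class (the paper's "modified" loss)
    disc_loss_tgt = tf.math.log(1. - tf.reduce_sum(argmax_tgt * yt_disc, 1) + EPS)
    return disc_loss_src, disc_loss_tgt
```

Alongside the loss change, the commit lowers the default trade-off `lambda_` from 1. to 0.1 and exposes both terms in the training logs as `disc_src` and `disc_tgt`.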