Skip to content

Commit 31cc3fb

Browse files
Fix same weights in MCD, MDD
1 parent 1aaa976 commit 31cc3fb

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

adapt/feature_based/_mcd.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,23 @@ def _initialize_networks(self):
214214
if self.task is None:
215215
self.discriminator_ = get_default_task(name="discriminator")
216216
else:
217+
# Impose Copy, else undesired behaviour
217218
self.discriminator_ = check_network(self.task,
218-
copy=self.copy,
219-
name="discriminator")
219+
copy=True,
220+
name="discriminator")
221+
222+
223+
def _initialize_weights(self, shape_X):
224+
# Init weights encoder
225+
self(np.zeros((1,) + shape_X))
226+
X_enc = self.encoder_(np.zeros((1,) + shape_X))
227+
self.task_(X_enc)
228+
self.discriminator_(X_enc)
229+
230+
# Add noise to discriminator in order to
231+
# differentiate from task
232+
weights = self.discriminator_.get_weights()
233+
for i in range(len(weights)):
234+
weights[i] += (0.01 * weights[i] *
235+
np.random.standard_normal(weights[i].shape))
236+
self.discriminator_.set_weights(weights)

adapt/feature_based/_mdd.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,22 @@ def _initialize_networks(self):
187187
if self.task is None:
188188
self.discriminator_ = get_default_task(name="discriminator")
189189
else:
190+
# Impose Copy, else undesired behaviour
190191
self.discriminator_ = check_network(self.task,
191-
copy=self.copy,
192+
copy=True,
192193
name="discriminator")
194+
195+
def _initialize_weights(self, shape_X):
196+
# Init weights encoder
197+
self(np.zeros((1,) + shape_X))
198+
X_enc = self.encoder_(np.zeros((1,) + shape_X))
199+
self.task_(X_enc)
200+
self.discriminator_(X_enc)
201+
202+
# Add noise to discriminator in order to
203+
# differentiate from task
204+
weights = self.discriminator_.get_weights()
205+
for i in range(len(weights)):
206+
weights[i] += (0.01 * weights[i] *
207+
np.random.standard_normal(weights[i].shape))
208+
self.discriminator_.set_weights(weights)

tests/test_mdd.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,20 @@ def test_fit():
6363
model.encoder_.get_weights()[0][0][0]) < 0.25
6464
assert np.sum(np.abs(model.predict(Xs).ravel() - ys)) < 0.1
6565
assert np.sum(np.abs(model.predict(Xt).ravel() - yt)) < 5.
66+
67+
68+
def test_not_same_weights():
69+
tf.random.set_seed(0)
70+
np.random.seed(0)
71+
task = _get_task()
72+
encoder = _get_encoder()
73+
X_enc = encoder.predict(Xs)
74+
task.predict(X_enc)
75+
model = MDD(encoder, task, copy=False,
76+
loss="mse", optimizer=Adam(0.01), metrics=["mse"])
77+
model.fit(Xs, ys, Xt, yt,
78+
epochs=0, batch_size=34, verbose=0)
79+
assert np.any(model.task_.get_weights()[0] !=
80+
model.discriminator_.get_weights()[0])
81+
assert np.all(model.task_.get_weights()[0] ==
82+
task.get_weights()[0])

0 commit comments

Comments
 (0)