Skip to content

Commit 3581099

Browse files
add test multisource and cce mdd
1 parent 93ad881 commit 3581099

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

tests/test_base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,4 +211,28 @@ def gens():
211211
model = BaseAdaptDeep()
212212
model.fit(dataset, Xt=Xt, validation_data=dataset.batch(10))
213213
model.predict(tf.data.Dataset.from_tensor_slices(Xs).batch(32))
214-
model.evaluate(dataset.batch(32))
214+
model.evaluate(dataset.batch(32))
215+
216+
217+
def _unpack_data_ms(self, data):
218+
data_src = data[0]
219+
data_tgt = data[1]
220+
Xs = data_src[0][0]
221+
ys = data_src[1][0]
222+
if isinstance(data_tgt, tuple):
223+
Xt = data_tgt[0]
224+
yt = data_tgt[1]
225+
return Xs, Xt, ys, yt
226+
else:
227+
Xt = data_tgt
228+
return Xs, Xt, ys, None
229+
230+
231+
def test_multisource():
232+
np.random.seed(0)
233+
model = BaseAdaptDeep()
234+
model._unpack_data = _unpack_data_ms.__get__(model)
235+
model.fit(Xs, ys, Xt=Xt, domains=np.random.choice(2, len(Xs)))
236+
model.predict(Xs)
237+
model.evaluate(Xs, ys)
238+
assert model.n_sources_ == 2

tests/test_mdd.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,4 @@ def test_cce():
9696
model = MDD(encoder, task, copy=False,
9797
loss="categorical_crossentropy", optimizer=Adam(0.01), metrics=["acc"])
9898
model.fit(Xs, ys_2, Xt, yt_2,
99-
epochs=0, batch_size=34, verbose=0)
100-
assert np.any(model.task_.get_weights()[0] !=
101-
model.discriminator_.get_weights()[0])
102-
assert np.all(model.task_.get_weights()[0] ==
103-
task.get_weights()[0])
99+
epochs=10, batch_size=34, verbose=0)

0 commit comments

Comments
 (0)