Skip to content

Commit 1efbd7b

Browse files
Add tests instancebased and base
1 parent 0693d5e commit 1efbd7b

File tree

7 files changed

+198
-17
lines changed

7 files changed

+198
-17
lines changed

adapt/base.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def _save_validation_data(self, Xs, Xt):
363363
size = min(self.val_sample_size, len(Xt))
364364
tgt_index = np.random.choice(len(Xt), size, replace=False)
365365
self.Xt_ = Xt[tgt_index]
366+
self.src_index_ = src_index
366367
else:
367368
self.Xs_ = Xs
368369
self.Xt_ = Xt
@@ -780,9 +781,20 @@ def _get_config_keras_model(self, model):
780781
config = model.get_config()
781782
klass = model.__class__
782783

783-
return dict(weights=weights,
784-
config=config,
785-
klass=klass)
784+
config = dict(weights=weights,
785+
config=config,
786+
klass=klass)
787+
788+
if hasattr(model, "loss"):
789+
config["loss"] = model.loss
790+
791+
if hasattr(model, "optimizer"):
792+
try:
793+
config["optimizer_klass"] = model.optimizer.__class__
794+
config["optimizer_config"] = model.optimizer.get_config()
795+
except:
796+
pass
797+
return config
786798

787799

788800
def _from_config_keras_model(self, dict_):
@@ -794,6 +806,16 @@ def _from_config_keras_model(self, dict_):
794806

795807
if weights is not None:
796808
model.set_weights(weights)
809+
810+
if "loss" in dict_ and "optimizer_klass" in dict_:
811+
loss = dict_["loss"]
812+
optimizer = dict_["optimizer_klass"].from_config(
813+
dict_["optimizer_config"])
814+
try:
815+
model.compile(loss=loss, optimizer=optimizer)
816+
except:
817+
print("Unable to compile model")
818+
797819
return model
798820

799821

@@ -1414,14 +1436,6 @@ def _initialize_weights(self, shape_X):
14141436
X_enc = self.encoder_(np.zeros((1,) + shape_X))
14151437
if hasattr(self, "discriminator_"):
14161438
self.discriminator_(X_enc)
1417-
1418-
1419-
def _check_pretrain(self, params):
1420-
self.pretrain = False
1421-
for param in params:
1422-
if "pretrain__" in param:
1423-
self.pretrain = True
1424-
break
14251439

14261440

14271441
def _unpack_data(self, data):

adapt/instance_based/_tradaboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def _boost(self, iboost, Xs, ys, Xt, yt,
804804
ys_pred = estimator.predict(Xs)
805805
yt_pred = estimator.predict(Xt)
806806

807-
if ys_pred.ndim == 1:
807+
if ys_pred.ndim == 1 or ys.ndim == 1:
808808
ys = ys.reshape(-1, 1)
809809
yt = yt.reshape(-1, 1)
810810
ys_pred = ys_pred.reshape(-1, 1)

adapt/instance_based/_wann.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ class WANN(BaseAdaptDeep):
5252
5353
history_ : dict
5454
history of the losses and metrics across the epochs.
55+
56+
57+
References
58+
----------
59+
.. [1] `[1] <https://arxiv.org/pdf/2006.08251.pdf>`_ A. de Mathelin, \
60+
G. Richard, F. Deheeger, M. Mougeot and N. Vayatis "Adversarial Weighting \
61+
for Domain Adaptation in Regression". In ICTAI, 2021.
5562
"""
5663

5764
def __init__(self,

tests/test_base.py

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Test base
33
"""
44

5+
import copy
56
import shutil
67
import numpy as np
78
import pytest
@@ -26,6 +27,30 @@
2627
yt = 0.2 * Xt[:, 0].ravel()
2728

2829

30+
class DummyFeatureBased(BaseAdaptEstimator):
31+
32+
def fit_transform(self, Xs, **kwargs):
33+
return Xs
34+
35+
def transform(self, Xs):
36+
return Xs
37+
38+
39+
class DummyInstanceBased(BaseAdaptEstimator):
40+
41+
def fit_weights(self, Xs, **kwargs):
42+
return np.ones(len(Xs))
43+
44+
def predict_weights(self):
45+
return np.ones(100)
46+
47+
48+
class DummyParameterBased(BaseAdaptEstimator):
49+
50+
def fit(self, Xs, ys):
51+
return self.fit_estimator(Xs, ys)
52+
53+
2954
def test_base_adapt_estimator():
3055
base_adapt = BaseAdaptEstimator(Xt=Xt)
3156
for check in check_estimator(base_adapt, generate_only=True):
@@ -37,7 +62,78 @@ def test_base_adapt_estimator():
3762
else:
3863
raise
3964

40-
65+
66+
def test_base_adapt_score():
67+
model = DummyParameterBased(Xt=Xt, random_state=0)
68+
model.fit(Xs, ys)
69+
model.score(Xt, yt)
70+
71+
model = DummyFeatureBased(Xt=Xt, random_state=0)
72+
model.fit(Xs, ys)
73+
s1 = model.score(Xt, yt)
74+
s2 = model._score_adapt(Xs, Xt)
75+
assert s1 == s2
76+
77+
model = DummyInstanceBased(Xt=Xt, random_state=0)
78+
model.fit(Xs, ys)
79+
model.score(Xt, yt)
80+
s1 = model.score(Xt, yt)
81+
s2 = model._score_adapt(Xs, Xt)
82+
assert s1 == s2
83+
84+
85+
def test_base_adapt_val_sample_size():
86+
model = DummyFeatureBased(Xt=Xt, random_state=0, val_sample_size=10)
87+
model.fit(Xs, ys)
88+
model.score(Xt, yt)
89+
assert len(model.Xs_) == 10
90+
assert len(model.Xt_) == 10
91+
assert np.all(model.Xs_ == Xs[model.src_index_])
92+
93+
94+
def test_base_adapt_keras_estimator():
95+
est = Sequential()
96+
est.add(Dense(1, input_shape=Xs.shape[1:]))
97+
est.compile(loss="mse", optimizer=Adam(0.01))
98+
model = BaseAdaptEstimator(est, Xt=Xt)
99+
model.fit(Xs, ys)
100+
assert model.estimator_.loss == "mse"
101+
assert isinstance(model.estimator_.optimizer, Adam)
102+
assert model.estimator_.optimizer.learning_rate == 0.01
103+
104+
model = BaseAdaptEstimator(est, Xt=Xt, loss="mae",
105+
optimizer=Adam(0.01, beta_1=0.5),
106+
learning_rate=0.1)
107+
model.fit(Xs, ys)
108+
assert model.estimator_.loss == "mae"
109+
assert isinstance(model.estimator_.optimizer, Adam)
110+
assert model.estimator_.optimizer.learning_rate == 0.1
111+
assert model.estimator_.optimizer.beta_1 == 0.5
112+
113+
model = BaseAdaptEstimator(est, Xt=Xt, optimizer="sgd")
114+
model.fit(Xs, ys)
115+
assert not isinstance(model.estimator_.optimizer, Adam)
116+
117+
est = Sequential()
118+
est.add(Dense(1, input_shape=Xs.shape[1:]))
119+
model = BaseAdaptEstimator(est, Xt=Xt, loss="mae",
120+
optimizer=Adam(0.01, beta_1=0.5),
121+
learning_rate=0.1)
122+
model.fit(Xs, ys)
123+
assert model.estimator_.loss == "mae"
124+
assert isinstance(model.estimator_.optimizer, Adam)
125+
assert model.estimator_.optimizer.learning_rate == 0.1
126+
assert model.estimator_.optimizer.beta_1 == 0.5
127+
128+
s1 = model.score_estimator(Xt[:10], yt[:10])
129+
s2 = model.estimator_.evaluate(Xt[:10], yt[:10])
130+
assert s1 == s2
131+
132+
copy_model = copy.deepcopy(model)
133+
assert s1 == copy_model.score_estimator(Xt[:10], yt[:10])
134+
assert hex(id(model)) != hex(id(copy_model))
135+
136+
41137
def test_base_adapt_deep():
42138
model = BaseAdaptDeep(Xt=Xt, loss="mse", optimizer=Adam(), learning_rate=0.1)
43139
model.fit(Xs, ys)
@@ -52,8 +148,8 @@ def test_base_adapt_deep():
52148
assert hasattr(model, "encoder_")
53149
assert hasattr(model, "task_")
54150
assert hasattr(model, "discriminator_")
55-
56-
151+
152+
57153
def test_base_adapt_deep():
58154
model = BaseAdaptDeep(Xt=Xt, loss="mse",
59155
epochs=2,
@@ -63,6 +159,7 @@ def test_base_adapt_deep():
63159
model.fit(Xs, ys)
64160
yp = model.predict(Xt)
65161
score = model.score(Xt, yt)
162+
score_est = model.score_estimator(Xt, yt)
66163
X_enc = model.transform(Xs)
67164
ypt = model.predict_task(Xt)
68165
ypd = model.predict_disc(Xt)
@@ -71,6 +168,7 @@ def test_base_adapt_deep():
71168
new_model.fit(Xs, ys)
72169
yp2 = new_model.predict(Xt)
73170
score2 = new_model.score(Xt, yt)
171+
score_est2 = new_model.score_estimator(Xt, yt)
74172
X_enc2 = new_model.transform(Xs)
75173
ypt2 = new_model.predict_task(Xt)
76174
ypd2 = new_model.predict_disc(Xt)
@@ -88,6 +186,7 @@ def test_base_adapt_deep():
88186

89187
assert np.all(yp == yp2)
90188
assert score == score2
189+
assert score_est == score_est2
91190
assert np.all(ypt == ypt2)
92191
assert np.all(ypd == ypd2)
93192
assert np.all(X_enc == X_enc2)

tests/test_kliep.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def test_fit():
4747
assert model.weights_[:50].sum() > 90
4848
assert model.weights_[50:].sum() < 0.5
4949
assert np.abs(model.predict(Xt) - yt).sum() < 20
50+
assert np.all(model.weights_ == model.predict_weights())
51+
assert np.all(model.weights_ == model.predict_weights(Xs))
5052

5153

5254
def test_fit_estimator_bootstrap_index():

tests/test_kmm.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,24 @@ def test_fit():
3232
assert np.abs(model.estimator_.coef_[0][0] - 0.2) < 1
3333
assert model.weights_[:50].sum() > 50
3434
assert model.weights_[50:].sum() < 0.1
35-
assert np.abs(model.predict(Xt) - yt).sum() < 10
35+
assert np.abs(model.predict(Xt) - yt).sum() < 10
36+
assert np.all(model.weights_ == model.predict_weights())
37+
38+
39+
def test_tol():
40+
np.random.seed(0)
41+
model = KMM(LinearRegression(fit_intercept=False), gamma=1., tol=0.1)
42+
model.fit(Xs, ys, Xt=Xt)
43+
assert np.abs(model.estimator_.coef_[0][0] - 0.2) > 5
44+
45+
46+
def test_batch():
47+
np.random.seed(0)
48+
model = KMM(LinearRegression(fit_intercept=False), gamma=1.,
49+
max_size=35)
50+
model.fit(Xs, ys, Xt=Xt)
51+
assert np.abs(model.estimator_.coef_[0][0] - 0.2) < 1
52+
assert model.weights_[:50].sum() > 50
53+
assert model.weights_[50:].sum() < 0.1
54+
assert np.abs(model.predict(Xt) - yt).sum() < 10
55+
assert np.all(model.weights_ == model.predict_weights())

tests/test_tradaboost.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
Test functions for tradaboost module.
33
"""
44

5+
import copy
56
import numpy as np
67
from sklearn.linear_model import LinearRegression, LogisticRegression
8+
import tensorflow as tf
79

810
from adapt.instance_based import (TrAdaBoost,
911
TrAdaBoostR2,
@@ -40,6 +42,23 @@ def test_tradaboost_fit():
4042
assert (model.predict(Xt).ravel() == yt_classif).sum() > 90
4143

4244

45+
def test_tradaboost_fit_keras_model():
46+
np.random.seed(0)
47+
est = tf.keras.Sequential()
48+
est.add(tf.keras.layers.Dense(1, activation="sigmoid"))
49+
est.compile(loss="bce", optimizer="adam")
50+
model = TrAdaBoost(est, n_estimators=2, random_state=0)
51+
model.fit(Xs, ys_classif, Xt=Xt[:10], yt=yt_classif[:10])
52+
yp = model.predict(Xt)
53+
54+
est = tf.keras.Sequential()
55+
est.add(tf.keras.layers.Dense(2, activation="softmax"))
56+
est.compile(loss="mse", optimizer="adam")
57+
model = TrAdaBoost(est, n_estimators=2, random_state=0)
58+
model.fit(Xs, np.random.random((100, 2)),
59+
Xt=Xt[:10], yt=np.random.random((10, 2)))
60+
61+
4362
def test_tradaboostr2_fit():
4463
np.random.seed(0)
4564
model = TrAdaBoostR2(LinearRegression(fit_intercept=False),
@@ -51,16 +70,36 @@ def test_tradaboostr2_fit():
5170
model.sample_weights_src_[-1][50:].sum()) > 10
5271
assert model.sample_weights_tgt_[-1].sum() > 0.7
5372
assert np.abs(model.predict(Xt) - yt_reg).sum() < 1
73+
assert np.all(model.predict_weights(domain="src") ==
74+
model.sample_weights_src_[-1])
75+
assert np.all(model.predict_weights(domain="tgt") ==
76+
model.sample_weights_tgt_[-1])
5477

5578

5679
def test_twostagetradaboostr2_fit():
5780
np.random.seed(0)
5881
model = TwoStageTrAdaBoostR2(LinearRegression(fit_intercept=False),
5982
n_estimators=10)
60-
model.fit(Xs, ys_reg, Xt=Xt[:10], yt=yt_reg[:10])
83+
model.fit(Xs, ys_reg.ravel(), Xt=Xt[:10], yt=yt_reg[:10].ravel())
6184
assert np.abs(model.estimators_[-1].estimators_[-1].coef_[0]
6285
- 1.) < 1
6386
assert np.abs(model.sample_weights_src_[-1][:50].sum() /
6487
model.sample_weights_src_[-1][50:].sum()) > 10
6588
assert model.sample_weights_tgt_[-1].sum() > 0.7
6689
assert np.abs(model.predict(Xt) - yt_reg).sum() < 1
90+
argmin = np.argmin(model.estimator_errors_)
91+
assert np.all(model.predict_weights(domain="src") ==
92+
model.sample_weights_src_[argmin])
93+
assert np.all(model.predict_weights(domain="tgt") ==
94+
model.sample_weights_tgt_[argmin])
95+
96+
97+
def test_tradaboost_deepcopy():
98+
np.random.seed(0)
99+
model = TrAdaBoost(LogisticRegression(penalty='none',
100+
solver='lbfgs'),
101+
n_estimators=20)
102+
model.fit(Xs, ys_classif, Xt=Xt[:10], yt=yt_classif[:10])
103+
copy_model = copy.deepcopy(model)
104+
assert np.all(model.predict(Xt) == copy_model.predict(Xt))
105+
assert hex(id(model)) != hex(id(copy_model))

0 commit comments

Comments
 (0)