Skip to content

Commit e319bb7

Browse files
add tests
1 parent d757693 commit e319bb7

File tree

6 files changed

+48
-2
lines changed

6 files changed

+48
-2
lines changed

adapt/feature_based/_mdd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class MDD(BaseDeepFeature):
7373
task_ : tensorflow Model
7474
Principal task network.
7575
76-
task_adv_ : tensorflow Model
76+
discriminator_ : tensorflow Model
7777
Adversarial task network.
7878
7979
model_ : tensorflow Model

adapt/instance_based/_kliep.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def fit_estimator(self, X, y,
290290
sample_weight /= (sample_weight.sum() + EPS)
291291
if sample_weight.sum() <= 0:
292292
sample_weight = np.ones(len(sample_weight))
293+
sample_weight /= (sample_weight.sum() + EPS)
293294
bootstrap_index = np.random.choice(
294295
len(X), size=len(X), replace=True,
295296
p=sample_weight)

adapt/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ class GradientHandler(Layer):
364364
"""
365365
def __init__(self, lambda_init=1., name="g_handler"):
366366
super().__init__(name=name)
367-
self.lambda_init=1.
367+
self.lambda_init=lambda_init
368368
self.lambda_ = tf.Variable(lambda_init,
369369
trainable=False,
370370
dtype="float32")

tests/test_coral.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
from sklearn.linear_model import LogisticRegression
7+
from scipy import linalg
78
import tensorflow as tf
89
from tensorflow.keras import Sequential, Model
910
from tensorflow.keras.layers import Dense
@@ -62,6 +63,17 @@ def test_fit_coral():
6263
assert np.abs(model.estimator_.coef_[0][0] /
6364
model.estimator_.coef_[0][1] + 0.5) < 0.1
6465
assert (model.predict(Xt) == yt).sum() / len(Xt) >= 0.99
66+
67+
68+
def test_fit_coral_complex():
69+
np.random.seed(0)
70+
model = CORAL(LogisticRegression(), lambda_=10000.)
71+
Xs_ = np.random.randn(10, 100)
72+
Xt_ = np.random.randn(10, 100)
73+
model.fit(Xs_, ys[:10], Xt_)
74+
assert np.iscomplexobj(linalg.inv(linalg.sqrtm(model.Cs_)))
75+
assert np.iscomplexobj(linalg.sqrtm(model.Ct_))
76+
model.predict(Xs_)
6577

6678

6779
def test_fit_deepcoral():

tests/test_kliep.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,21 @@
44

55
import numpy as np
66
from sklearn.linear_model import LinearRegression
7+
from sklearn.base import BaseEstimator
78

89
from adapt.instance_based import KLIEP
910

11+
12+
class DummyEstimator(BaseEstimator):
13+
14+
def __init__(self):
15+
pass
16+
17+
def fit(self, X, y):
18+
self.y = y
19+
return self
20+
21+
1022
np.random.seed(0)
1123
Xs = np.concatenate((
1224
np.random.randn(50)*0.1,
@@ -34,3 +46,21 @@ def test_fit():
3446
assert model.weights_[:50].sum() > 90
3547
assert model.weights_[50:].sum() < 0.5
3648
assert np.abs(model.predict(Xt) - yt).sum() < 20
49+
50+
51+
def test_fit_estimator_bootstrap_index():
52+
np.random.seed(0)
53+
ys_ = np.random.randn(100)
54+
model = KLIEP(DummyEstimator(),
55+
sigmas=[10, 100])
56+
model.fit_estimator(Xs, ys_, sample_weight=np.random.random(len(ys)))
57+
assert len(set(list(model.estimator_.y.ravel())) & set(list(ys_.ravel()))) > 50
58+
59+
60+
def test_fit_estimator_sample_weight_zeros():
61+
np.random.seed(0)
62+
ys_ = np.random.randn(100)
63+
model = KLIEP(DummyEstimator(),
64+
sigmas=[10, 100])
65+
model.fit_estimator(Xs, ys_, sample_weight=np.zeros(len(ys)))
66+
assert len(set(list(model.estimator_.y.ravel())) & set(list(ys_.ravel()))) > 50

tests/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,9 @@ def test_gradienthandler(lambda_):
433433
gradient = tape.gradient(grad_handler(inputs),
434434
inputs)
435435
assert np.all(gradient == lambda_ * np.ones(3))
436+
config = grad_handler.get_config()
437+
assert config['lambda_init'] == lambda_
438+
436439

437440

438441
def test_make_classification_da():

0 commit comments

Comments
 (0)