Skip to content

Commit 041359b

Browse files
Merge pull request #16 from antoinedemathelin/master
test: Add test WANN
2 parents f40d0b4 + 7cc2b68 commit 041359b

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

adapt/instance_based/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
from ._kliep import KLIEP
66
from ._kmm import KMM
77
from ._tradaboost import TrAdaBoost, TrAdaBoostR2, TwoStageTrAdaBoostR2
8+
from ._wann import WANN
89

9-
__all__ = ["KLIEP", "KMM", "TrAdaBoost", "TrAdaBoostR2", "TwoStageTrAdaBoostR2"]
10+
__all__ = ["KLIEP", "KMM", "TrAdaBoost", "TrAdaBoostR2", "TwoStageTrAdaBoostR2", "WANN"]

tests/test_wann.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Test functions for wann module.
3+
"""
4+
5+
import numpy as np
6+
from sklearn.linear_model import LinearRegression
7+
from adapt.instance_based import WANN
8+
9+
np.random.seed(0)
10+
Xs = np.concatenate((
11+
np.random.randn(50)*0.1,
12+
np.random.randn(50)*0.1 + 1.,
13+
)).reshape(-1, 1)
14+
Xt = (np.random.randn(100) * 0.1).reshape(-1, 1)
15+
ys = np.array([0.2 * x if x<0.5
16+
else 10 for x in Xs.ravel()]).reshape(-1, 1)
17+
yt = np.array([0.2 * x if x<0.5
18+
else 10 for x in Xt.ravel()]).reshape(-1, 1)
19+
20+
def test_fit():
21+
np.random.seed(0)
22+
model = WANN(random_state=0)
23+
model.fit(Xs, ys, Xt, yt, epochs=200, verbose=0)
24+
assert np.abs(model.predict(Xt) - yt).sum() < 10

0 commit comments

Comments
 (0)