Skip to content

Commit 4c0682c

Browse files
add test interpolation
1 parent dc04b60 commit 4c0682c

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/test_wdgrl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tensorflow.keras.optimizers import Adam
1010

1111
from adapt.feature_based import WDGRL
12+
from adapt.feature_based._wdgrl import _Interpolation
1213

1314
Xs = np.concatenate((
1415
np.linspace(0, 1, 100).reshape(-1, 1),
@@ -51,6 +52,19 @@ def _get_task(input_shape=(1,), output_shape=(1,)):
5152
return model
5253

5354

55+
def test_interpolation():
56+
np.random.seed(0)
57+
tf.random.set_seed(0)
58+
59+
zeros = tf.identity(np.zeros((3, 1), dtype=np.float32))
60+
ones= tf.identity(np.ones((3, 1), dtype=np.float32))
61+
62+
inter, dist = _Interpolation().call([zeros, ones])
63+
assert np.all(np.round(dist, 3) == np.round(inter, 3))
64+
assert np.all(inter >= zeros)
65+
assert np.all(inter <= ones)
66+
67+
5468
def test_fit_lambda_zero():
5569
tf.random.set_seed(1)
5670
np.random.seed(1)

0 commit comments

Comments
 (0)