
Commit 10e1e0e

Merge pull request #12 from antoinedemathelin/master
feat: Add WDGRL method
2 parents 7e8c597 + cecb995 · commit 10e1e0e

File tree

7 files changed: +362 -1 lines changed


adapt/feature_based/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -11,5 +11,6 @@
 from ._deepcoral import DeepCORAL
 from ._mcd import MCD
 from ._mdd import MDD
+from ._wdgrl import WDGRL
 
-__all__ = ["FE", "CORAL", "DeepCORAL", "ADDA", "DANN", "mSDA", "MCD", "MDD", "BaseDeepFeature"]
+__all__ = ["FE", "CORAL", "DeepCORAL", "ADDA", "DANN", "mSDA", "MCD", "MDD", "WDGRL", "BaseDeepFeature"]

adapt/feature_based/_wdgrl.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
"""
WDGRL
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer, subtract
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

from adapt.utils import (GradientHandler,
                         check_arrays)
from adapt.feature_based import BaseDeepFeature

EPS = K.epsilon()


class _Interpolation(Layer):
    """
    Layer that produces interpolated points between
    two inputs, together with the distance of each
    interpolation to the first input.
    """

    def call(self, inputs):
        Xs = inputs[0]
        Xt = inputs[1]
        batch_size = tf.shape(Xs)[0]
        dim = tf.shape(Xs)[1:]
        alphas = tf.random.uniform([batch_size] + [1] * len(dim))
        tiled_shape = tf.concat(([1], dim), 0)
        tiled_alphas = tf.tile(alphas, tiled_shape)
        differences = Xt - Xs
        interpolates = Xs + tiled_alphas * differences
        distances = K.sqrt(K.mean(K.square(tiled_alphas * differences),
                                  axis=[i for i in range(1, len(dim))]) + EPS)
        return interpolates, distances


class WDGRL(BaseDeepFeature):
    """
    WDGRL (Wasserstein Distance Guided Representation Learning) is an
    unsupervised domain adaptation method built on the same model as
    :ref:`DANN <adapt.feature_based.DANN>`. In WDGRL, the discriminator
    is used to approximate the Wasserstein distance between the
    source and target encoded distributions, in the spirit of WGAN.

    The optimization formulation is the following:

    .. math::

        \min_{\phi, F} & \; \mathcal{L}_{task}(F(\phi(X_S)), y_S) +
        \lambda \\left(D(\phi(X_S)) - D(\phi(X_T)) \\right) \\\\
        \max_{D} & \; \\left(D(\phi(X_S)) - D(\phi(X_T)) \\right) -
        \\gamma (||\\nabla D(\\alpha \phi(X_S) + (1- \\alpha) \phi(X_T))||_2 - 1)^2

    Where:

    - :math:`(X_S, y_S), (X_T)` are respectively the labeled source data
      and the unlabeled target data.
    - :math:`\phi, F, D` are respectively the **encoder**, the **task**
      and the **discriminator** networks.
    - :math:`\lambda` is the trade-off parameter.
    - :math:`\\gamma` is the gradient penalty parameter.

    .. figure:: ../_static/images/wdgrl.png
        :align: center

        WDGRL architecture (source: [1])

    Parameters
    ----------
    encoder : tensorflow Model (default=None)
        Encoder network. If ``None``, a shallow network with 10
        neurons and ReLU activation is used as encoder network.

    task : tensorflow Model (default=None)
        Task network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as task network.

    discriminator : tensorflow Model (default=None)
        Discriminator network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as discriminator
        network. Note that the output shape of the discriminator should
        be ``(None, 1)``.

    lambda_ : float or None (default=1)
        Trade-off parameter. This parameter sets the trade-off
        for the encoder between learning the task and matching
        the source and target distributions. If ``lambda_`` is small,
        the encoder will focus on the task. If ``lambda_=0``, WDGRL
        is equivalent to a "source only" method.

    gamma : float (default=1.)
        Gradient penalization parameter. To approximate the
        Wasserstein distance well, the ``discriminator`` should be
        1-Lipschitz. This constraint is imposed by the gradient
        penalty term of the optimization. A good value of ``gamma``
        is not easy to find: one can check through the metrics that
        the gradient penalty term is of the same order as the
        "disc loss". If ``gamma=0``, no penalty is applied to the
        discriminator gradient.

    loss : string or tensorflow loss (default="mse")
        Loss function used for the task.

    metrics : dict or list of string or tensorflow metrics (default=None)
        Metrics given to the model. If a list is provided,
        metrics are used on both ``task`` and ``discriminator``
        outputs. To give separate metrics, please provide a
        dict of metrics lists with ``"task"`` and ``"disc"`` as keys.

    optimizer : string or tensorflow optimizer (default=None)
        Optimizer of the model. If ``None``, the
        optimizer is set to ``tf.keras.optimizers.Adam(0.001)``.

    copy : boolean (default=True)
        Whether to make a copy of ``encoder``, ``task`` and
        ``discriminator`` or not.

    random_state : int (default=None)
        Seed of the random generator.

    Attributes
    ----------
    encoder_ : tensorflow Model
        encoder network.

    task_ : tensorflow Model
        task network.

    discriminator_ : tensorflow Model
        discriminator network.

    model_ : tensorflow Model
        Fitted model: the union of the ``encoder_``,
        ``task_`` and ``discriminator_`` networks.

    history_ : dict
        history of the losses and metrics across the epochs.
        If ``yt`` is given in the ``fit`` method, target metrics
        and losses are recorded too.

    Examples
    --------
    >>> import numpy as np
    >>> from adapt.feature_based import WDGRL
    >>> np.random.seed(0)
    >>> Xs = np.concatenate((np.random.random((100, 1)),
    ...                      np.zeros((100, 1))), 1)
    >>> Xt = np.concatenate((np.random.random((100, 1)),
    ...                      np.ones((100, 1))), 1)
    >>> ys = 0.2 * Xs[:, 0]
    >>> yt = 0.2 * Xt[:, 0]
    >>> model = WDGRL(lambda_=0., random_state=0)
    >>> model.fit(Xs, ys, Xt, yt, epochs=100, verbose=0)
    >>> model.history_["task_t"][-1]
    0.0223...
    >>> model = WDGRL(lambda_=1, random_state=0)
    >>> model.fit(Xs, ys, Xt, yt, epochs=100, verbose=0)
    >>> model.history_["task_t"][-1]
    0.0044...

    See also
    --------
    DANN
    ADDA
    DeepCORAL

    References
    ----------
    .. [1] `[1] <https://arxiv.org/pdf/1707.01217.pdf>`_ Shen, J., Qu, Y., Zhang, W., \
and Yu, Y. Wasserstein distance guided representation learning for domain adaptation. \
In AAAI, 2018.
    """
    def __init__(self,
                 encoder=None,
                 task=None,
                 discriminator=None,
                 lambda_=1.,
                 gamma=1.,
                 loss="mse",
                 metrics=None,
                 optimizer=None,
                 copy=True,
                 random_state=None):

        self.lambda_ = lambda_
        self.gamma = gamma
        super().__init__(encoder, task, discriminator,
                         loss, metrics, optimizer, copy,
                         random_state)


    def create_model(self, inputs_Xs, inputs_Xt):

        encoded_src = self.encoder_(inputs_Xs)
        encoded_tgt = self.encoder_(inputs_Xt)
        task_src = self.task_(encoded_src)
        task_tgt = self.task_(encoded_tgt)

        flip = GradientHandler(-self.lambda_, name="flip")
        no_grad = GradientHandler(0, name="no_grad")

        disc_src = flip(encoded_src)
        disc_src = self.discriminator_(disc_src)
        disc_tgt = flip(encoded_tgt)
        disc_tgt = self.discriminator_(disc_tgt)

        encoded_src_no_grad = no_grad(encoded_src)
        encoded_tgt_no_grad = no_grad(encoded_tgt)

        interpolates, distances = _Interpolation()([encoded_src_no_grad,
                                                    encoded_tgt_no_grad])
        disc_grad = K.abs(
            subtract([self.discriminator_(interpolates),
                      self.discriminator_(encoded_src_no_grad)])
        )
        disc_grad /= distances

        outputs = dict(task_src=task_src,
                       task_tgt=task_tgt,
                       disc_src=disc_src,
                       disc_tgt=disc_tgt,
                       disc_grad=disc_grad)
        return outputs


    def get_loss(self, inputs_ys,
                 task_src, task_tgt,
                 disc_src, disc_tgt,
                 disc_grad):

        loss_task = self.loss_(inputs_ys, task_src)
        loss_disc = K.mean(disc_src) - K.mean(disc_tgt)
        gradient_penalty = K.mean(K.square(disc_grad - 1.))

        loss = K.mean(loss_task) - K.mean(loss_disc) + self.gamma * K.mean(gradient_penalty)
        return loss


    def get_metrics(self, inputs_ys, inputs_yt,
                    task_src, task_tgt,
                    disc_src, disc_tgt, disc_grad):
        metrics = {}

        task_s = self.loss_(inputs_ys, task_src)
        disc = K.mean(disc_src) - K.mean(disc_tgt)
        grad_pen = K.square(disc_grad - 1.)

        metrics["task_s"] = K.mean(task_s)
        metrics["disc"] = K.mean(disc)
        metrics["grad_pen"] = self.gamma * K.mean(grad_pen)

        if inputs_yt is not None:
            task_t = self.loss_(inputs_yt, task_tgt)
            metrics["task_t"] = K.mean(task_t)

        names_task, names_disc = self._get_metric_names()

        for metric, name in zip(self.metrics_task_, names_task):
            metrics[name + "_s"] = metric(inputs_ys, task_src)
            if inputs_yt is not None:
                metrics[name + "_t"] = metric(inputs_yt, task_tgt)
        return metrics
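
A note on the objective implemented above: the discriminator plays the role of a WGAN critic, and the ``disc_grad`` output is a finite-difference surrogate of the gradient penalty, evaluated along random interpolations between encoded source and target points. Below is a minimal standalone sketch of that critic objective. The toy tensors, the ``critic`` network and all variable names here are illustrative, not part of the commit:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import Sequential
    from tensorflow.keras.layers import Dense

    # Toy encoded batches, stand-ins for phi(Xs) and phi(Xt).
    zs = tf.constant(np.random.rand(32, 2), dtype=tf.float32)
    zt = tf.constant(np.random.rand(32, 2) + 1., dtype=tf.float32)

    # Toy critic with the same output contract as `discriminator`: shape (None, 1).
    critic = Sequential([Dense(10, activation="relu", input_shape=(2,)),
                         Dense(1)])

    # Wasserstein distance estimate between the two encoded batches.
    wass = tf.reduce_mean(critic(zs)) - tf.reduce_mean(critic(zt))

    # Finite-difference gradient surrogate along random interpolations,
    # in the spirit of `_Interpolation`: |D(z_a) - D(zs)| / ||z_a - zs||.
    alphas = tf.random.uniform((32, 1))
    z_a = zs + alphas * (zt - zs)
    dist = tf.sqrt(tf.reduce_mean(tf.square(alphas * (zt - zs)),
                                  axis=1, keepdims=True) + 1e-7)
    grad_pen = tf.reduce_mean(tf.square(tf.abs(critic(z_a) - critic(zs)) / dist - 1.))

    gamma = 1.
    critic_objective = wass - gamma * grad_pen  # the critic maximizes this
    print(float(wass), float(grad_pen))

In the commit itself, the critic ascends this objective while the ``flip`` gradient handler feeds the encoder the reversed gradient scaled by ``lambda_``, so a single compiled model trains both players.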

src_docs/_static/images/wdgrl.png

73.9 KB

src_docs/_templates/layout.html

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.feature_based.mSDA") }}">mSDA</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.feature_based.MCD") }}">MCD</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.feature_based.MDD") }}">MDD</a></li>
+<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.feature_based.WDGRL") }}">WDGRL</a></li>
 </ul>
 </li>
 <li class="toctree-l1"><a class="reference internal" href="{{ pathto("contents") }}{{ contents }}{{ "adapt-instance-based" }}">Instance-based</a><ul>

src_docs/contents.rst

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ and **target** distributions. The **task** is then learned in this **encoded fea
 feature_based.mSDA
 feature_based.MCD
 feature_based.MDD
+feature_based.WDGRL
 
 
 .. _adapt.instance_based:

src_docs/gallery/WDGRL.rst

Whitespace-only changes.

tests/test_wdgrl.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
"""
Test functions for wdgrl module.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

from adapt.feature_based import WDGRL
from adapt.feature_based._wdgrl import _Interpolation

Xs = np.concatenate((
    np.linspace(0, 1, 100).reshape(-1, 1),
    np.zeros((100, 1))
), axis=1)
Xt = np.concatenate((
    np.linspace(0, 1, 100).reshape(-1, 1),
    np.ones((100, 1))
), axis=1)
ys = 0.2 * Xs[:, 0].ravel()
yt = 0.2 * Xt[:, 0].ravel()


def _get_encoder(input_shape=Xs.shape[1:]):
    model = Sequential()
    model.add(Dense(1, input_shape=input_shape,
                    kernel_initializer="ones",
                    use_bias=False))
    model.compile(loss="mse", optimizer="adam")
    return model


def _get_discriminator(input_shape=(1,)):
    model = Sequential()
    model.add(Dense(10,
                    input_shape=input_shape,
                    activation="relu"))
    model.add(Dense(1,
                    activation=None))
    model.compile(loss="mse", optimizer="adam")
    return model


def _get_task(input_shape=(1,), output_shape=(1,)):
    model = Sequential()
    model.add(Dense(np.prod(output_shape),
                    use_bias=False,
                    input_shape=input_shape))
    model.compile(loss="mse", optimizer=Adam(0.1))
    return model


def test_interpolation():
    np.random.seed(0)
    tf.random.set_seed(0)

    zeros = tf.identity(np.zeros((3, 1), dtype=np.float32))
    ones = tf.identity(np.ones((3, 1), dtype=np.float32))

    inter, dist = _Interpolation().call([zeros, ones])
    assert np.all(np.round(dist, 3) == np.round(inter, 3))
    assert np.all(inter >= zeros)
    assert np.all(inter <= ones)


def test_fit_lambda_zero():
    tf.random.set_seed(1)
    np.random.seed(1)
    model = WDGRL(_get_encoder(), _get_task(), _get_discriminator(),
                  lambda_=0, loss="mse", optimizer=Adam(0.01), metrics=["mse"],
                  random_state=0)
    model.fit(Xs, ys, Xt, yt,
              epochs=300, verbose=0)
    assert isinstance(model.model_, Model)
    assert model.encoder_.get_weights()[0][1][0] == 1.0
    assert np.sum(np.abs(model.predict(Xs).ravel() - ys)) < 0.01
    assert np.sum(np.abs(model.predict(Xt).ravel() - yt)) > 10


def test_fit_lambda_one():
    tf.random.set_seed(1)
    np.random.seed(1)
    model = WDGRL(_get_encoder(), _get_task(), _get_discriminator(),
                  lambda_=1, gamma=0, loss="mse", optimizer=Adam(0.01),
                  metrics=["mse"], random_state=0)
    model.fit(Xs, ys, Xt, yt,
              epochs=300, verbose=0)
    assert isinstance(model.model_, Model)
    assert np.abs(model.encoder_.get_weights()[0][1][0] /
                  model.encoder_.get_weights()[0][0][0]) < 0.05
    assert np.sum(np.abs(model.predict(Xs).ravel() - ys)) < 2
    assert np.sum(np.abs(model.predict(Xt).ravel() - yt)) < 2
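
The interpolation test rests on a simple identity: with source points at 0 and target points at 1, each interpolated point coincides with its distance to the source. A plain-NumPy restatement of that property, as a hedged sketch (variable names are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    zeros = np.zeros((3, 1), dtype=np.float32)
    ones = np.ones((3, 1), dtype=np.float32)

    # Interpolate as _Interpolation does: z = zs + alpha * (zt - zs).
    alphas = rng.uniform(size=(3, 1)).astype(np.float32)
    inter = zeros + alphas * (ones - zeros)

    # With zs = 0 and zt = 1, |alpha * (zt - zs)| reduces to the
    # interpolate itself, which is what test_interpolation asserts.
    dist = np.sqrt(np.square(alphas * (ones - zeros)))
    assert np.allclose(inter, dist)
    assert np.all(inter >= zeros) and np.all(inter <= ones)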
