
Commit 535a153

Merge pull request #14 from antoinedemathelin/master

feat: Add CDAN

2 parents 7b39f7c + 99f97ab

16 files changed (+521, -19 lines)


adapt/feature_based/__init__.py (2 additions, 1 deletion)

@@ -12,5 +12,6 @@
 from ._mcd import MCD
 from ._mdd import MDD
 from ._wdgrl import WDGRL
+from ._cdan import CDAN
 
-__all__ = ["FE", "CORAL", "DeepCORAL", "ADDA", "DANN", "mSDA", "MCD", "MDD", "WDGRL", "BaseDeepFeature"]
+__all__ = ["FE", "CORAL", "DeepCORAL", "ADDA", "DANN", "mSDA", "MCD", "MDD", "WDGRL", "BaseDeepFeature", "CDAN"]

adapt/feature_based/_adda.py (1 addition, 1 deletion)

@@ -407,7 +407,7 @@ def create_model(self, inputs_Xs, inputs_Xt):
         return outputs
 
 
-    def get_loss(self, inputs_ys, disc_src, disc_tgt,
+    def get_loss(self, inputs_ys, inputs_yt, disc_src, disc_tgt,
                  disc_tgt_nograd, task_tgt):
 
         loss_disc = (-K.log(disc_src + EPS)

adapt/feature_based/_cdan.py (new file: 374 additions)
"""
CDAN
"""

import warnings
from copy import deepcopy

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer, Input, subtract, Dense, Flatten
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

from adapt.feature_based import BaseDeepFeature
from adapt.utils import (GradientHandler,
                         check_arrays,
                         check_one_array,
                         check_network)


EPS = K.epsilon()


def _get_default_classifier():
    model = Sequential()
    model.add(Flatten())
    model.add(Dense(10, activation="relu"))
    model.add(Dense(10, activation="relu"))
    model.add(Dense(2, activation="softmax"))
    return model


class CDAN(BaseDeepFeature):
    """
    CDAN (Conditional Adversarial Domain Adaptation) is an
    unsupervised domain adaptation method built on the model of
    :ref:`DANN <adapt.feature_based.DANN>`. In CDAN the discriminator
    is conditioned on the predictions of the task network for
    source and target data. This should, in theory, focus the
    source-target matching on instances belonging to the same class.

    To condition the **discriminator** network on each class, a
    multilinear map of shape ``nb_class * encoder.output_shape[1]``
    is given as input. If this shape is too large (> 4096), a random
    sub-multilinear map of lower dimension is used instead.

    The optimization formulation of CDAN is the following:

    .. math::

        \min_{\phi, F} & \; \mathcal{L}_{task}(F(\phi(X_S)), y_S) -
        \lambda \\big( \log(1 - D(\phi(X_S) \\bigotimes F(X_S))) + \\\\
        \log(D(\phi(X_T) \\bigotimes F(X_T))) \\big) \\\\
        \max_{D} & \; \log(1 - D(\phi(X_S) \\bigotimes F(X_S))) + \\\\
        \log(D(\phi(X_T) \\bigotimes F(X_T)))

    Where:

    - :math:`(X_S, y_S), (X_T)` are respectively the labeled source data
      and the unlabeled target data.
    - :math:`\phi, F, D` are respectively the **encoder**, the **task**
      and the **discriminator** networks.
    - :math:`\lambda` is the trade-off parameter.
    - :math:`\phi(X_S) \\bigotimes F(X_S)` is the multilinear map between
      the encoded sources and the task predictions.

    In CDAN+E, an entropy regularization is added to prioritize the
    transfer of easy-to-transfer examples. The optimization formulation
    of CDAN+E is the following:

    .. math::

        \min_{\phi, F} & \; \mathcal{L}_{task}(F(\phi(X_S)), y_S) -
        \lambda \\big( \log(1 - W_S D(\phi(X_S) \\bigotimes F(X_S))) + \\\\
        W_T \log(D(\phi(X_T) \\bigotimes F(X_T))) \\big) \\\\
        \max_{D} & \; \log(1 - W_S D(\phi(X_S) \\bigotimes F(X_S))) + \\\\
        W_T \log(D(\phi(X_T) \\bigotimes F(X_T)))

    Where:

    - :math:`W_S = 1+\exp(-\\text{entropy}(F(X_S)))` and :math:`W_T` is
      defined analogously on the target predictions.
    - :math:`\\text{entropy}(F(X_S)) = - \sum_{i < C} F(X_S)_i \log(F(X_S)_i)`
      with :math:`C` the number of classes.

    .. figure:: ../_static/images/cdan.png
        :align: center

        CDAN architecture (source: [1])

    Notes
    -----
    CDAN is specific to multi-class classification tasks. Be sure to add a
    softmax activation at the end of the task network.

    Parameters
    ----------
    encoder : tensorflow Model (default=None)
        Encoder network. If ``None``, a shallow network with 10
        neurons and ReLU activation is used as encoder network.

    task : tensorflow Model (default=None)
        Task network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as task network.
        ``task`` should end with a softmax activation.

    discriminator : tensorflow Model (default=None)
        Discriminator network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as discriminator
        network. Note that the output shape of the discriminator should
        be ``(None, 1)`` and the input shape
        ``(None, encoder.output_shape[1] * nb_class)``.

    lambda_ : float or None (default=1)
        Trade-off parameter. This parameter gives the trade-off
        for the encoder between learning the task and matching
        the source and target distributions. If ``lambda_`` is small,
        the encoder will focus on the task. If ``lambda_=0``, CDAN
        is equivalent to a "source only" method.

    entropy : boolean (default=True)
        Whether to use the entropy regularization or not.
        Adding this regularization will focus the
        ``discriminator`` on easy-to-transfer examples.
        This, in theory, should make the transfer "safer".

    max_features : int (default=4096)
        If ``encoder.output_shape[1] * nb_class`` is higher than
        ``max_features``, the multilinear map is computed by
        considering random sub-vectors of the encoder and task outputs.

    loss : string or tensorflow loss (default="mse")
        Loss function used for the task.

    metrics : dict or list of string or tensorflow metrics (default=None)
        Metrics given to the model. If a list is provided,
        metrics are used on both ``task`` and ``discriminator``
        outputs. To give separate metrics, please provide a
        dict of metrics lists with ``"task"`` and ``"disc"`` as keys.

    optimizer : string or tensorflow optimizer (default=None)
        Optimizer of the model. If ``None``, the
        optimizer is set to ``tf.keras.optimizers.Adam(0.001)``.

    copy : boolean (default=True)
        Whether to make a copy of ``encoder``, ``task`` and
        ``discriminator`` or not.

    random_state : int (default=None)
        Seed of the random generator.

    Attributes
    ----------
    encoder_ : tensorflow Model
        encoder network.

    task_ : tensorflow Model
        task network.

    discriminator_ : tensorflow Model
        discriminator network.

    model_ : tensorflow Model
        Fitted model: the union of the ``encoder_``,
        ``task_`` and ``discriminator_`` networks.

    history_ : dict
        History of the losses and metrics across the epochs.
        If ``yt`` is given in the ``fit`` method, target metrics
        and losses are recorded too.

    See also
    --------
    DANN
    ADDA
    WDGRL

    References
    ----------
    .. [1] `[1] <https://arxiv.org/pdf/1705.10667.pdf>`_ Long, M., Cao, \
Z., Wang, J., and Jordan, M. I. "Conditional adversarial domain adaptation". \
In NIPS, 2018.
    """
    def __init__(self,
                 encoder=None,
                 task=None,
                 discriminator=None,
                 lambda_=1.,
                 entropy=True,
                 max_features=4096,
                 loss="mse",
                 metrics=None,
                 optimizer=None,
                 copy=True,
                 random_state=None):

        self.lambda_ = lambda_
        self.entropy = entropy
        self.max_features = max_features

        if task is None:
            task = _get_default_classifier()
        super().__init__(encoder, task, discriminator,
                         loss, metrics, optimizer, copy,
                         random_state)

    def create_model(self, inputs_Xs, inputs_Xt):
        encoded_src = self.encoder_(inputs_Xs)
        encoded_tgt = self.encoder_(inputs_Xt)
        task_src = self.task_(encoded_src)
        task_tgt = self.task_(encoded_tgt)

        no_grad = GradientHandler(0., name="no_grad")
        flip = GradientHandler(-self.lambda_, name="flip")

        task_src_nograd = no_grad(task_src)
        task_tgt_nograd = no_grad(task_tgt)

        if task_src.shape[1] * encoded_src.shape[1] > self.max_features:
            self._random_task = tf.random.normal([task_src.shape[1],
                                                  self.max_features])
            self._random_enc = tf.random.normal([encoded_src.shape[1],
                                                 self.max_features])

            mapping_task_src = tf.matmul(task_src_nograd, self._random_task)
            mapping_enc_src = tf.matmul(encoded_src, self._random_enc)
            mapping_src = tf.multiply(mapping_enc_src, mapping_task_src)
            mapping_src /= (tf.math.sqrt(tf.cast(self.max_features, tf.float32)) + EPS)

            mapping_task_tgt = tf.matmul(task_tgt_nograd, self._random_task)
            mapping_enc_tgt = tf.matmul(encoded_tgt, self._random_enc)
            mapping_tgt = tf.multiply(mapping_enc_tgt, mapping_task_tgt)
            mapping_tgt /= (tf.math.sqrt(tf.cast(self.max_features, tf.float32)) + EPS)

        else:
            mapping_src = tf.matmul(
                tf.expand_dims(encoded_src, 2),
                tf.expand_dims(task_src_nograd, 1))
            mapping_tgt = tf.matmul(
                tf.expand_dims(encoded_tgt, 2),
                tf.expand_dims(task_tgt_nograd, 1))

            mapping_src = Flatten("channels_first")(mapping_src)
            mapping_tgt = Flatten("channels_first")(mapping_tgt)

        disc_src = flip(mapping_src)
        disc_src = self.discriminator_(disc_src)
        disc_tgt = flip(mapping_tgt)
        disc_tgt = self.discriminator_(disc_tgt)

        outputs = dict(task_src=task_src,
                       task_tgt=task_tgt,
                       disc_src=disc_src,
                       disc_tgt=disc_tgt,
                       task_src_nograd=task_src_nograd,
                       task_tgt_nograd=task_tgt_nograd)
        return outputs
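
    # --- Expository note (not part of the original commit) ---
    # Shape walk-through for create_model: with encoder outputs of shape
    # (batch, d) and task predictions of shape (batch, c), the else-branch
    # builds the exact multilinear map as an outer product
    # (batch, d, 1) x (batch, 1, c) -> (batch, d, c), flattened to
    # (batch, d*c) for the discriminator. When d*c > max_features, both
    # factors are instead projected with fixed random matrices to
    # (batch, max_features) and multiplied elementwise, scaled by
    # 1/sqrt(max_features): the randomized multilinear map of the CDAN
    # paper, which approximates the full outer product in much lower
    # dimension.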

    def get_loss(self, inputs_ys, inputs_yt,
                 task_src, task_tgt,
                 disc_src, disc_tgt,
                 task_src_nograd,
                 task_tgt_nograd):

        loss_task = self.loss_(inputs_ys, task_src)

        if self.entropy:
            entropy_src = -tf.reduce_sum(task_src_nograd *
                                         tf.math.log(task_src_nograd+EPS),
                                         axis=1, keepdims=True)
            entropy_tgt = -tf.reduce_sum(task_tgt_nograd *
                                         tf.math.log(task_tgt_nograd+EPS),
                                         axis=1, keepdims=True)
            weight_src = 1.+tf.exp(-entropy_src)
            weight_tgt = 1.+tf.exp(-entropy_tgt)
            weight_src /= (tf.reduce_mean(weight_src) + EPS)
            weight_tgt /= (tf.reduce_mean(weight_tgt) + EPS)
            weight_src *= .5
            weight_tgt *= .5

            assert str(weight_src.shape) == str(disc_src.shape)
            assert str(weight_tgt.shape) == str(disc_tgt.shape)

            loss_disc = (-tf.math.log(1-weight_src*disc_src + EPS)
                         -tf.math.log(weight_tgt*disc_tgt + EPS))
        else:
            loss_disc = (-tf.math.log(1-disc_src + EPS)
                         -tf.math.log(disc_tgt + EPS))

        loss = tf.reduce_mean(loss_task) + tf.reduce_mean(loss_disc)
        return loss
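
    # --- Expository note (not part of the original commit) ---
    # A single scalar loss implements the min-max game: the discriminator
    # weights receive the plain gradient of loss_disc, which drives
    # D(src) -> 0 and D(tgt) -> 1, while the GradientHandler(-lambda_)
    # layer applied in create_model reverses and scales the gradient
    # flowing back into the encoder, training it to fool the
    # discriminator. The entropy weights w = 1 + exp(-H), rescaled above
    # to mean 0.5, emphasize confident (low-entropy) predictions, i.e.
    # easy-to-transfer examples.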

    def get_metrics(self, inputs_ys, inputs_yt,
                    task_src, task_tgt,
                    disc_src, disc_tgt,
                    task_src_nograd,
                    task_tgt_nograd):
        metrics = {}

        task_s = self.loss_(inputs_ys, task_src)

        if self.entropy:
            entropy_src = -tf.reduce_sum(task_src_nograd *
                                         tf.math.log(task_src_nograd+EPS),
                                         axis=1, keepdims=True)
            entropy_tgt = -tf.reduce_sum(task_tgt_nograd *
                                         tf.math.log(task_tgt_nograd+EPS),
                                         axis=1, keepdims=True)
            weight_src = 1.+tf.exp(-entropy_src)
            weight_tgt = 1.+tf.exp(-entropy_tgt)
            weight_src /= (tf.reduce_mean(weight_src) + EPS)
            weight_tgt /= (tf.reduce_mean(weight_tgt) + EPS)
            weight_src *= .5
            weight_tgt *= .5
            disc = (-tf.math.log(1-weight_src*disc_src + EPS)
                    -tf.math.log(weight_tgt*disc_tgt + EPS))
        else:
            disc = (-tf.math.log(1-disc_src + EPS)
                    -tf.math.log(disc_tgt + EPS))

        metrics["task_s"] = K.mean(task_s)
        metrics["disc"] = K.mean(disc)
        if inputs_yt is not None:
            task_t = self.loss_(inputs_yt, task_tgt)
            metrics["task_t"] = K.mean(task_t)

        names_task, names_disc = self._get_metric_names()

        for metric, name in zip(self.metrics_task_, names_task):
            metrics[name + "_s"] = metric(inputs_ys, task_src)
            if inputs_yt is not None:
                metrics[name + "_t"] = metric(inputs_yt, task_tgt)

        for metric, name in zip(self.metrics_disc_, names_disc):
            pred = K.concatenate((disc_src, disc_tgt), axis=0)
            true = K.concatenate((K.zeros_like(disc_src),
                                  K.ones_like(disc_tgt)), axis=0)
            metrics[name] = metric(true, pred)
        return metrics
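
    # --- Expository note (not part of the original commit) ---
    # For the discriminator metrics above, source samples are labeled 0
    # and target samples 1, matching the loss where the discriminator is
    # pushed toward D(src) -> 0 and D(tgt) -> 1.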

    def _initialize_networks(self, shape_Xt):
        # Call predict to avoid strange behaviour with
        # Sequential models with unspecified input_shape
        zeros_enc_ = self.encoder_.predict(np.zeros((1,) + shape_Xt))
        zeros_task_ = self.task_.predict(zeros_enc_)
        if zeros_task_.shape[1] * zeros_enc_.shape[1] > self.max_features:
            self.discriminator_.predict(np.zeros((1, self.max_features)))
        else:
            zeros_mapping_ = np.matmul(np.expand_dims(zeros_enc_, 2),
                                       np.expand_dims(zeros_task_, 1))
            zeros_mapping_ = np.reshape(zeros_mapping_, (1, -1))
            self.discriminator_.predict(zeros_mapping_)

    def predict_disc(self, X):
        X_enc = self.encoder_.predict(X)
        X_task = self.task_.predict(X_enc)
        if X_enc.shape[1] * X_task.shape[1] > self.max_features:
            X_enc = X_enc.dot(self._random_enc.numpy())
            X_task = X_task.dot(self._random_task.numpy())
            X_disc = X_enc * X_task
            X_disc /= np.sqrt(self.max_features)
        else:
            X_disc = np.matmul(np.expand_dims(X_enc, 2),
                               np.expand_dims(X_task, 1))
            X_disc = X_disc.transpose([0, 2, 1])
            X_disc = X_disc.reshape(-1, X_enc.shape[1] * X_task.shape[1])
        y_disc = self.discriminator_.predict(X_disc)
        return y_disc
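
To make the conditioning map concrete outside Keras, here is a small
self-contained NumPy sketch mirroring the two branches of predict_disc;
all names and sizes are illustrative only.

    import numpy as np

    rng = np.random.default_rng(0)
    batch, d, c, max_features = 4, 128, 10, 64   # here d * c > max_features

    X_enc = rng.standard_normal((batch, d))      # stand-in encoder outputs
    X_task = rng.standard_normal((batch, c))     # stand-in task predictions

    if d * c > max_features:
        # randomized multilinear map: project each factor, then multiply
        R_enc = rng.standard_normal((d, max_features))
        R_task = rng.standard_normal((c, max_features))
        X_disc = (X_enc @ R_enc) * (X_task @ R_task) / np.sqrt(max_features)
    else:
        # exact multilinear map: flattened outer product
        X_disc = np.matmul(X_enc[:, :, None], X_task[:, None, :])
        X_disc = X_disc.transpose([0, 2, 1]).reshape(batch, d * c)

    print(X_disc.shape)                          # (4, 64): discriminator input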
