1
+ """
2
+ WDGRL
3
+ """
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from tensorflow .keras import Model , Sequential
8
+ from tensorflow .keras .layers import Layer , subtract
9
+ from tensorflow .keras .optimizers import Adam
10
+ import tensorflow .keras .backend as K
11
+
12
+ from adapt .utils import (GradientHandler ,
13
+ check_arrays )
14
+ from adapt .feature_based import BaseDeepFeature
15
+
16
+ EPS = K .epsilon ()
17
+
18
+
19
+ class _Interpolation (Layer ):
20
+ """
21
+ Layer that produces interpolates points between
22
+ two entries, with the distance of the interpolation
23
+ to the first entry.
24
+ """
25
+
26
+ def call (self , inputs ):
27
+ Xs = inputs [0 ]
28
+ Xt = inputs [1 ]
29
+ batch_size = tf .shape (Xs )[0 ]
30
+ dim = tf .shape (Xs )[1 :]
31
+ alphas = tf .random .uniform ([batch_size ]+ [1 ]* len (dim ))
32
+ tiled_shape = tf .concat (([1 ], dim ), 0 )
33
+ tiled_alphas = tf .tile (alphas , tiled_shape )
34
+ differences = Xt - Xs
35
+ interpolates = Xs + tiled_alphas * differences
36
+ distances = K .sqrt (K .mean (K .square (tiled_alphas * differences ),
37
+ axis = [i for i in range (1 , len (dim ))]) + EPS )
38
+ return interpolates , distances
39
+
40
+
41
+ class WDGRL (BaseDeepFeature ):
42
+ """
43
+ WDGRL (Wasserstein Distance Guided Representation Learning) is an
44
+ unsupervised domain adaptation method on the model of the
45
+ :ref:`DANN <adapt.feature_based.DANN>`. In WDGRL the discriminator
46
+ is used to approximate the Wasserstein distance between the
47
+ source and target encoded distributions in the spirit of WGAN.
48
+
49
+ The optimization formulation is the following:
50
+
51
+ .. math::
52
+
53
+ \min_{\phi, F} & \; \mathcal{L}_{task}(F(\phi(X_S)), y_S) +
54
+ \lambda \\ left(D(\phi(X_S)) - D(\phi(X_T)) \\ right) \\ \\
55
+ \max_{D} & \; \\ left(D(\phi(X_S)) - D(\phi(X_T)) \\ right) -
56
+ \\ gamma (||\\ nabla D(\\ alpha \phi(X_S) + (1- \\ alpha) \phi(X_T))||_2 - 1)^2
57
+
58
+ Where:
59
+
60
+ - :math:`(X_S, y_S), (X_T)` are respectively the labeled source data
61
+ and the unlabeled target data.
62
+ - :math:`\phi, F, D` are respectively the **encoder**, the **task**
63
+ and the **discriminator** networks
64
+ - :math:`\lambda` is the trade-off parameter.
65
+ - :math:`\\ gamma` is the gradient penalty parameter.
66
+
67
+ .. figure:: ../_static/images/wdgrl.png
68
+ :align: center
69
+
70
+ WDGRL architecture (source: [1])
71
+
72
+ Parameters
73
+ ----------
74
+ encoder : tensorflow Model (default=None)
75
+ Encoder netwok. If ``None``, a shallow network with 10
76
+ neurons and ReLU activation is used as encoder network.
77
+
78
+ task : tensorflow Model (default=None)
79
+ Task netwok. If ``None``, a two layers network with 10
80
+ neurons per layer and ReLU activation is used as task network.
81
+
82
+ discriminator : tensorflow Model (default=None)
83
+ Discriminator netwok. If ``None``, a two layers network with 10
84
+ neurons per layer and ReLU activation is used as discriminator
85
+ network. Note that the output shape of the discriminator should
86
+ be ``(None, 1)``.
87
+
88
+ lambda_ : float or None (default=1)
89
+ Trade-off parameter. This parameter gives the trade-off
90
+ for the encoder between learning the task and matching
91
+ the source and target distribution. If `lambda_`is small
92
+ the encoder will focus on the task. If `lambda_=0`, WDGRL
93
+ is equivalent to a "source only" method.
94
+
95
+ gamma : float (default=1.)
96
+ Gradient penalization parameter. To well approximate the
97
+ Wasserstein, the `discriminator`should be 1-Lipschitz.
98
+ This constraint is imposed by the gradient penalty term
99
+ of the optimization. The good value `gamma` to use is
100
+ not easy to find. One can check through the metrics that
101
+ the gradient penalty term is in the same order than the
102
+ "disc loss". If `gamma=0`, no penalty is given on the
103
+ discriminator gradient.
104
+
105
+ loss : string or tensorflow loss (default="mse")
106
+ Loss function used for the task.
107
+
108
+ metrics : dict or list of string or tensorflow metrics (default=None)
109
+ Metrics given to the model. If a list is provided,
110
+ metrics are used on both ``task`` and ``discriminator``
111
+ outputs. To give seperated metrics, please provide a
112
+ dict of metrics list with ``"task"`` and ``"disc"`` as keys.
113
+
114
+ optimizer : string or tensorflow optimizer (default=None)
115
+ Optimizer of the model. If ``None``, the
116
+ optimizer is set to tf.keras.optimizers.Adam(0.001)
117
+
118
+ copy : boolean (default=True)
119
+ Whether to make a copy of ``encoder``, ``task`` and
120
+ ``discriminator`` or not.
121
+
122
+ random_state : int (default=None)
123
+ Seed of random generator.
124
+
125
+ Attributes
126
+ ----------
127
+ encoder_ : tensorflow Model
128
+ encoder network.
129
+
130
+ task_ : tensorflow Model
131
+ task network.
132
+
133
+ discriminator_ : tensorflow Model
134
+ discriminator network.
135
+
136
+ model_ : tensorflow Model
137
+ Fitted model: the union of ``encoder_``,
138
+ ``task_`` and ``discriminator_`` networks.
139
+
140
+ history_ : dict
141
+ history of the losses and metrics across the epochs.
142
+ If ``yt`` is given in ``fit`` method, target metrics
143
+ and losses are recorded too.
144
+
145
+ Examples
146
+ --------
147
+ >>> import numpy as np
148
+ >>> from adapt.feature_based import WDGRL
149
+ >>> np.random.seed(0)
150
+ >>> Xs = np.concatenate((np.random.random((100, 1)),
151
+ ... np.zeros((100, 1))), 1)
152
+ >>> Xt = np.concatenate((np.random.random((100, 1)),
153
+ ... np.ones((100, 1))), 1)
154
+ >>> ys = 0.2 * Xs[:, 0]
155
+ >>> yt = 0.2 * Xt[:, 0]
156
+ >>> model = WDGRL(lambda_=0., random_state=0)
157
+ >>> model.fit(Xs, ys, Xt, yt, epochs=100, verbose=0)
158
+ >>> model.history_["task_t"][-1]
159
+ 0.0223...
160
+ >>> model = WDGRL(lambda_=1, random_state=0)
161
+ >>> model.fit(Xs, ys, Xt, yt, epochs=100, verbose=0)
162
+ >>> model.history_["task_t"][-1]
163
+ 0.0044...
164
+
165
+ See also
166
+ --------
167
+ DANN
168
+ ADDA
169
+ DeepCORAL
170
+
171
+ References
172
+ ----------
173
+ .. [1] `[1] <https://arxiv.org/pdf/1707.01217.pdf>`_ Shen, J., Qu, Y., Zhang, W., \
174
+ and Yu, Y. Wasserstein distance guided representation learning for domain adaptation. \
175
+ In AAAI, 2018.
176
+ """
177
+ def __init__ (self ,
178
+ encoder = None ,
179
+ task = None ,
180
+ discriminator = None ,
181
+ lambda_ = 1. ,
182
+ gamma = 1. ,
183
+ loss = "mse" ,
184
+ metrics = None ,
185
+ optimizer = None ,
186
+ copy = True ,
187
+ random_state = None ):
188
+
189
+ self .lambda_ = lambda_
190
+ self .gamma = gamma
191
+ super ().__init__ (encoder , task , discriminator ,
192
+ loss , metrics , optimizer , copy ,
193
+ random_state )
194
+
195
+
196
+ def create_model (self , inputs_Xs , inputs_Xt ):
197
+
198
+ encoded_src = self .encoder_ (inputs_Xs )
199
+ encoded_tgt = self .encoder_ (inputs_Xt )
200
+ task_src = self .task_ (encoded_src )
201
+ task_tgt = self .task_ (encoded_tgt )
202
+
203
+ flip = GradientHandler (- self .lambda_ , name = "flip" )
204
+ no_grad = GradientHandler (0 , name = "no_grad" )
205
+
206
+ disc_src = flip (encoded_src )
207
+ disc_src = self .discriminator_ (disc_src )
208
+ disc_tgt = flip (encoded_tgt )
209
+ disc_tgt = self .discriminator_ (disc_tgt )
210
+
211
+ encoded_src_no_grad = no_grad (encoded_src )
212
+ encoded_tgt_no_grad = no_grad (encoded_tgt )
213
+
214
+ interpolates , distances = _Interpolation ()([encoded_src_no_grad , encoded_tgt_no_grad ])
215
+ disc_grad = K .abs (
216
+ subtract ([self .discriminator_ (interpolates ), self .discriminator_ (encoded_src_no_grad )])
217
+ )
218
+ disc_grad /= distances
219
+
220
+ outputs = dict (task_src = task_src ,
221
+ task_tgt = task_tgt ,
222
+ disc_src = disc_src ,
223
+ disc_tgt = disc_tgt ,
224
+ disc_grad = disc_grad )
225
+ return outputs
226
+
227
+
228
+ def get_loss (self , inputs_ys ,
229
+ task_src , task_tgt ,
230
+ disc_src , disc_tgt ,
231
+ disc_grad ):
232
+
233
+ loss_task = self .loss_ (inputs_ys , task_src )
234
+ loss_disc = K .mean (disc_src ) - K .mean (disc_tgt )
235
+ gradient_penalty = K .mean (K .square (disc_grad - 1. ))
236
+
237
+ loss = K .mean (loss_task ) - K .mean (loss_disc ) + self .gamma * K .mean (gradient_penalty )
238
+ return loss
239
+
240
+
241
+ def get_metrics (self , inputs_ys , inputs_yt ,
242
+ task_src , task_tgt ,
243
+ disc_src , disc_tgt , disc_grad ):
244
+ metrics = {}
245
+
246
+ task_s = self .loss_ (inputs_ys , task_src )
247
+ disc = K .mean (disc_src ) - K .mean (disc_tgt )
248
+ grad_pen = K .square (disc_grad - 1. )
249
+
250
+ metrics ["task_s" ] = K .mean (task_s )
251
+ metrics ["disc" ] = K .mean (disc )
252
+ metrics ["grad_pen" ] = self .gamma * K .mean (grad_pen )
253
+
254
+ if inputs_yt is not None :
255
+ task_t = self .loss_ (inputs_yt , task_tgt )
256
+ metrics ["task_t" ] = K .mean (task_t )
257
+
258
+ names_task , names_disc = self ._get_metric_names ()
259
+
260
+ for metric , name in zip (self .metrics_task_ , names_task ):
261
+ metrics [name + "_s" ] = metric (inputs_ys , task_src )
262
+ if inputs_yt is not None :
263
+ metrics [name + "_t" ] = metric (inputs_yt , task_tgt )
264
+ return metrics
0 commit comments