1
+ """
2
+ Weighting Adversarial Neural Network (WANN)
3
+ """
4
+ from copy import deepcopy
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from tensorflow .keras import Sequential , Model
9
+ from tensorflow .keras .layers import Layer , multiply
10
+ from tensorflow .keras .callbacks import Callback
11
+ from tensorflow .keras .constraints import MaxNorm
12
+
13
+ from adapt .utils import (GradientHandler ,
14
+ check_arrays ,
15
+ check_one_array ,
16
+ check_network ,
17
+ get_default_task )
18
+ from adapt .feature_based import BaseDeepFeature
19
+
20
+
21
class StopTraining(Callback):
    """Keras callback that stops training once the loss is low enough.

    Used to end the pretraining of the weighting network as soon as
    its batch loss falls below 0.01 (i.e. predicted weights are close
    to one).
    """

    def on_train_batch_end(self, batch, logs=None):
        # Fix 1: mutable default argument ``logs={}`` replaced by None.
        # Fix 2: ``logs.get('loss')`` can return None (key absent on
        # some TF versions / first batch), and ``None < 0.01`` raises
        # TypeError in Python 3 — guard before comparing.
        loss = (logs or {}).get("loss")
        if loss is not None and loss < 0.01:
            print("Weights initialization succeeded !")
            self.model.stop_training = True
27
+
28
+
29
class WANN(BaseDeepFeature):
    """
    WANN: Weighting Adversarial Neural Network is an instance-based domain adaptation
    method suited for regression tasks. It supposes the supervised setting where some
    labeled target data are available.

    The goal of WANN is to compute a source instances reweighting which correct
    "shifts" between source and target domain. This is done by minimizing the
    Y-discrepancy distance between source and target distributions

    WANN involves three networks:
        - the weighting network which learns the source weights.
        - the task network which learns the task.
        - the discrepancy network which is used to estimate a distance
          between the reweighted source and target distributions: the Y-discrepancy

    Parameters
    ----------
    task : tensorflow Model (default=None)
        Task netwok. If ``None``, a two layers network with 10
        neurons per layer and ReLU activation is used as task network.

    weighter : tensorflow Model (default=None)
        Encoder netwok. If ``None``, a two layers network with 10
        neurons per layer and ReLU activation is used as
        weighter network.

    C : float (default=1.)
        Clipping constant for the weighting networks
        regularization. Low value of ``C`` produce smoother
        weighting map. If ``C<=0``, No regularization is added.

    init_weights : bool (default=True)
        If True a pretraining of ``weighter`` is made such
        that all predicted weights start close to one.

    loss : string or tensorflow loss (default="mse")
        Loss function used for the task.

    metrics : dict or list of string or tensorflow metrics (default=None)
        Metrics given to the model. If a list is provided,
        metrics are used on both ``task`` and ``discriminator``
        outputs. To give seperated metrics, please provide a
        dict of metrics list with ``"task"`` and ``"disc"`` as keys.

    optimizer : string or tensorflow optimizer (default=None)
        Optimizer of the model. If ``None``, the
        optimizer is set to tf.keras.optimizers.Adam(0.001)

    copy : boolean (default=True)
        Whether to make a copy of ``encoder``, ``task`` and
        ``discriminator`` or not.

    random_state : int (default=None)
        Seed of random generator.
    """

    def __init__(self,
                 task=None,
                 weighter=None,
                 C=1.,
                 init_weights=True,
                 loss="mse",
                 metrics=None,
                 optimizer=None,
                 copy=True,
                 random_state=None):

        # The weighter plays the role of the "encoder" for the base class;
        # no discriminator is passed (it is built below from the task).
        super().__init__(weighter, task, None,
                         loss, metrics, optimizer, copy,
                         random_state)

        self.init_weights = init_weights
        # init_weights_ is the mutable flag consumed (and cleared) by fit().
        self.init_weights_ = init_weights
        self.C = C

        if weighter is None:
            self.weighter_ = get_default_task()  # activation="relu"
        else:
            self.weighter_ = self.encoder_

        if self.C > 0.:
            self._add_regularization()

        # The discriminator is a fresh copy of the task network
        # (used to estimate the Y-discrepancy).
        self.discriminator_ = check_network(self.task_,
                                            copy=True,
                                            display_name="task",
                                            force_copy=True)
        self.discriminator_._name = self.discriminator_._name + "_2"

    def _add_regularization(self):
        """Clip the norm of the weighter's kernels and biases to ``C``.

        Bug fix: the original loop iterated over ``self.weighter_.layers``
        but tested and set the constraints on ``self.weighter_`` (the
        Model object) instead of on each ``layer``, so the MaxNorm
        constraint was never actually applied.
        """
        for layer in self.weighter_.layers:
            if hasattr(layer, "kernel_constraint"):
                layer.kernel_constraint = MaxNorm(self.C)
            if hasattr(layer, "bias_constraint"):
                layer.bias_constraint = MaxNorm(self.C)

    def fit(self, Xs, ys, Xt, yt, **fit_params):
        """Fit WANN on source (Xs, ys) and labeled target (Xt, yt) data.

        Pretrains the weighter once (if ``init_weights``) so that all
        predicted weights start close to one, then runs the adversarial
        fit of the base class. Returns ``self``.
        """
        Xs, ys, Xt, yt = check_arrays(Xs, ys, Xt, yt)

        if self.init_weights_:
            self._init_weighter(Xs)
            # Only pretrain on the first call to fit.
            self.init_weights_ = False
        self._fit(Xs, ys, Xt, yt, **fit_params)
        return self

    def _init_weighter(self, Xs):
        """Pretrain the weighter to output ~1 for every source instance."""
        self.weighter_.compile(optimizer=deepcopy(self.optimizer), loss="mse")
        batch_size = 64
        # Scale the number of epochs so roughly 64000 samples are seen
        # regardless of the dataset size.
        epochs = max(1, int(64 * 1000 / len(Xs)))
        callback = StopTraining()  # early stop once loss < 0.01
        self.weighter_.fit(Xs, np.ones(len(Xs)),
                           epochs=epochs, batch_size=batch_size,
                           callbacks=[callback], verbose=0)

    def _initialize_networks(self, shape_Xt):
        """Build the three sub-networks by running a dummy prediction."""
        self.weighter_.predict(np.zeros((1,) + shape_Xt))
        self.task_.predict(np.zeros((1,) + shape_Xt))
        self.discriminator_.predict(np.zeros((1,) + shape_Xt))

    def create_model(self, inputs_Xs, inputs_Xt):
        """Wire the weighter, task and discriminator into one graph.

        Returns a dict of output tensors consumed by ``get_loss`` and
        ``get_metrics``.
        """
        Flip = GradientHandler(-1.)

        # Get networks output for both source and target.
        # Weights are forced positive via abs().
        weights_s = self.weighter_(inputs_Xs)
        weights_s = tf.math.abs(weights_s)
        task_s = self.task_(inputs_Xs)
        task_t = self.task_(inputs_Xt)
        disc_s = self.discriminator_(inputs_Xs)
        disc_t = self.discriminator_(inputs_Xt)

        # Gradient-reversal layer at the end of the discriminator:
        # the discriminator maximizes the discrepancy while the rest
        # of the model minimizes it.
        disc_s = Flip(disc_s)
        disc_t = Flip(disc_t)

        return dict(task_s=task_s, task_t=task_t,
                    disc_s=disc_s, disc_t=disc_t,
                    weights_s=weights_s)

    def get_loss(self, inputs_ys, inputs_yt, task_s,
                 task_t, disc_s, disc_t, weights_s):
        """Weighted task loss plus the (reversed) Y-discrepancy term."""
        # Source task loss, reweighted per instance.
        loss_task_s = self.loss_(inputs_ys, task_s)
        loss_task_s = multiply([weights_s, loss_task_s])

        # Y-discrepancy estimate: difference between the discriminator
        # losses on target and reweighted source.
        loss_disc_s = self.loss_(inputs_ys, disc_s)
        loss_disc_s = multiply([weights_s, loss_disc_s])

        loss_disc_t = self.loss_(inputs_yt, disc_t)

        loss_disc = (tf.reduce_mean(loss_disc_t) -
                     tf.reduce_mean(loss_disc_s))

        loss = tf.reduce_mean(loss_task_s) + loss_disc
        return loss

    def get_metrics(self, inputs_ys, inputs_yt, task_s,
                    task_t, disc_s, disc_t, weights_s):
        """Unweighted task losses and user metrics on source and target."""
        metrics = {}

        loss_s = self.loss_(inputs_ys, task_s)
        loss_t = self.loss_(inputs_yt, task_t)

        metrics["task_s"] = tf.reduce_mean(loss_s)
        metrics["task_t"] = tf.reduce_mean(loss_t)

        names_task, names_disc = self._get_metric_names()

        for metric, name in zip(self.metrics_task_, names_task):
            metrics[name + "_s"] = metric(inputs_ys, task_s)
            metrics[name + "_t"] = metric(inputs_yt, task_t)
        return metrics

    def predict(self, X):
        """
        Predict method: return the prediction of task network

        Parameters
        ----------
        X: array
            input data

        Returns
        -------
        y_pred: array
            prediction of task network
        """
        X = check_one_array(X)
        return self.task_.predict(X)

    def predict_weights(self, X):
        """
        Return the predictions of weighting network

        Parameters
        ----------
        X: array
            input data

        Returns
        -------
        array:
            weights
        """
        # Validate input like predict/predict_disc do (consistency fix);
        # abs() mirrors the forcing applied in create_model.
        X = check_one_array(X)
        return np.abs(self.weighter_.predict(X))

    def predict_disc(self, X):
        """
        Return predictions of the discriminator.

        Parameters
        ----------
        X : array
            input data

        Returns
        -------
        y_disc : array
            predictions of discriminator network
        """
        X = check_one_array(X)
        return self.discriminator_.predict(X)
0 commit comments