
Commit f40d0b4

Merge pull request #15 from antoinedemathelin/master
Add WANN
2 parents 535a153 + f674bc0 commit f40d0b4

File tree

1 file changed: +262 −0


adapt/instance_based/_wann.py

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
"""
Weighting Adversarial Neural Network (WANN)
"""
from copy import deepcopy

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Layer, multiply
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.constraints import MaxNorm

from adapt.utils import (GradientHandler,
                         check_arrays,
                         check_one_array,
                         check_network,
                         get_default_task)
from adapt.feature_based import BaseDeepFeature


class StopTraining(Callback):

    def on_train_batch_end(self, batch, logs=None):
        # Stop the pretraining as soon as the loss gets small enough.
        if logs is not None and logs.get("loss", np.inf) < 0.01:
            print("Weights initialization succeeded!")
            self.model.stop_training = True


class WANN(BaseDeepFeature):
    """
    WANN: Weighting Adversarial Neural Network is an instance-based domain
    adaptation method suited for regression tasks. It assumes the supervised
    setting where some labeled target data are available.

    The goal of WANN is to compute a reweighting of the source instances
    which corrects the "shift" between the source and target domains. This
    is done by minimizing the Y-discrepancy distance between the source and
    target distributions.

    WANN involves three networks:

    - the weighting network, which learns the source weights.
    - the task network, which learns the task.
    - the discrepancy network, which is used to estimate a distance
      between the reweighted source and target distributions:
      the Y-discrepancy.

    Parameters
    ----------
    task : tensorflow Model (default=None)
        Task network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as task network.

    weighter : tensorflow Model (default=None)
        Weighter network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as
        weighter network.

    C : float (default=1.)
        Clipping constant for the regularization of the
        weighting network. A low value of ``C`` produces a
        smoother weighting map. If ``C<=0``, no regularization is added.

    init_weights : bool (default=True)
        If True, a pretraining of ``weighter`` is performed such
        that all predicted weights start close to one.

    loss : string or tensorflow loss (default="mse")
        Loss function used for the task.

    metrics : dict or list of string or tensorflow metrics (default=None)
        Metrics given to the model. If a list is provided,
        metrics are used on both ``task`` and ``discriminator``
        outputs. To give separate metrics, please provide a
        dict of metrics lists with ``"task"`` and ``"disc"`` as keys.

    optimizer : string or tensorflow optimizer (default=None)
        Optimizer of the model. If ``None``, the
        optimizer is set to ``tf.keras.optimizers.Adam(0.001)``.

    copy : boolean (default=True)
        Whether to make a copy of ``weighter``, ``task`` and
        ``discriminator`` or not.

    random_state : int (default=None)
        Seed of the random generator.
    """

    def __init__(self,
                 task=None,
                 weighter=None,
                 C=1.,
                 init_weights=True,
                 loss="mse",
                 metrics=None,
                 optimizer=None,
                 copy=True,
                 random_state=None):

        super().__init__(weighter, task, None,
                         loss, metrics, optimizer, copy,
                         random_state)

        self.init_weights = init_weights
        self.init_weights_ = init_weights
        self.C = C

        if weighter is None:
            self.weighter_ = get_default_task()  # activation="relu"
        else:
            self.weighter_ = self.encoder_

        if self.C > 0.:
            self._add_regularization()

        # The discriminator is an independent copy of the task network.
        self.discriminator_ = check_network(self.task_,
                                            copy=True,
                                            display_name="task",
                                            force_copy=True)
        self.discriminator_._name = self.discriminator_._name + "_2"

    def _add_regularization(self):
        # Constrain the norm of each layer's kernel and bias to C.
        for layer in self.weighter_.layers:
            if hasattr(layer, "kernel_constraint"):
                layer.kernel_constraint = MaxNorm(self.C)
            if hasattr(layer, "bias_constraint"):
                layer.bias_constraint = MaxNorm(self.C)

    def fit(self, Xs, ys, Xt, yt, **fit_params):
        Xs, ys, Xt, yt = check_arrays(Xs, ys, Xt, yt)

        # Pretrain the weighter only once, before the first fit.
        if self.init_weights_:
            self._init_weighter(Xs)
            self.init_weights_ = False
        self._fit(Xs, ys, Xt, yt, **fit_params)
        return self

    def _init_weighter(self, Xs):
        # Pretrain the weighter to predict weights close to one
        # (the regression targets are a vector of ones).
        self.weighter_.compile(optimizer=deepcopy(self.optimizer), loss="mse")
        batch_size = 64
        epochs = max(1, int(64 * 1000 / len(Xs)))
        callback = StopTraining()
        self.weighter_.fit(Xs, np.ones(len(Xs)),
                           epochs=epochs, batch_size=batch_size,
                           callbacks=[callback], verbose=0)

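    # Note: the epochs heuristic above amounts to roughly 1000 gradient
    # steps regardless of the source size, since
    # epochs * (len(Xs) / batch_size) ~= (64 * 1000 / len(Xs)) * (len(Xs) / 64) = 1000.
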
    def _initialize_networks(self, shape_Xt):
        # Build the networks' weights with a dummy forward pass.
        self.weighter_.predict(np.zeros((1,) + shape_Xt))
        self.task_.predict(np.zeros((1,) + shape_Xt))
        self.discriminator_.predict(np.zeros((1,) + shape_Xt))

    def create_model(self, inputs_Xs, inputs_Xt):

        Flip = GradientHandler(-1.)

        # Get the networks' outputs for both source and target
        weights_s = self.weighter_(inputs_Xs)
        weights_s = tf.math.abs(weights_s)
        task_s = self.task_(inputs_Xs)
        task_t = self.task_(inputs_Xt)
        disc_s = self.discriminator_(inputs_Xs)
        disc_t = self.discriminator_(inputs_Xt)

        # Gradient reversal layer at the end of the discriminator
        disc_s = Flip(disc_s)
        disc_t = Flip(disc_t)

        return dict(task_s=task_s, task_t=task_t,
                    disc_s=disc_s, disc_t=disc_t,
                    weights_s=weights_s)

    def get_loss(self, inputs_ys, inputs_yt, task_s,
                 task_t, disc_s, disc_t, weights_s):

        # Task loss on source, weighted by the learned source weights
        loss_task_s = self.loss_(inputs_ys, task_s)
        loss_task_s = multiply([weights_s, loss_task_s])

        # Discriminator loss on the reweighted source...
        loss_disc_s = self.loss_(inputs_ys, disc_s)
        loss_disc_s = multiply([weights_s, loss_disc_s])

        # ...and on the target
        loss_disc_t = self.loss_(inputs_yt, disc_t)

        # Y-discrepancy estimate between reweighted source and target
        loss_disc = (tf.reduce_mean(loss_disc_t) -
                     tf.reduce_mean(loss_disc_s))

        loss = tf.reduce_mean(loss_task_s) + loss_disc
        return loss

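    # Reading get_loss together with the gradient reversal in create_model,
    # the optimization roughly solves:
    #
    #     min_{task f, weighter w}  max_{discriminator h}
    #         mean_i w(xs_i) * L(ys_i, f(xs_i))          # weighted task loss
    #         + mean_j L(yt_j, h(xt_j))                  # disc. loss on target
    #         - mean_i w(xs_i) * L(ys_i, h(xs_i))        # weighted disc. loss on source
    #
    # The last two terms estimate the Y-discrepancy between the reweighted
    # source and the target distributions; the GradientHandler(-1.) layer
    # turns the discriminator's minimization into the inner maximization.
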
    def get_metrics(self, inputs_ys, inputs_yt, task_s,
                    task_t, disc_s, disc_t, weights_s):

        metrics = {}

        loss_s = self.loss_(inputs_ys, task_s)
        loss_t = self.loss_(inputs_yt, task_t)

        metrics["task_s"] = tf.reduce_mean(loss_s)
        metrics["task_t"] = tf.reduce_mean(loss_t)

        names_task, names_disc = self._get_metric_names()

        for metric, name in zip(self.metrics_task_, names_task):
            metrics[name + "_s"] = metric(inputs_ys, task_s)
            metrics[name + "_t"] = metric(inputs_yt, task_t)
        return metrics

    def predict(self, X):
        """
        Return the predictions of the task network.

        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        y_pred : array
            Predictions of the task network.
        """
        X = check_one_array(X)
        return self.task_.predict(X)

    def predict_weights(self, X):
        """
        Return the predictions of the weighting network.

        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        weights : array
            Predicted source weights.
        """
        X = check_one_array(X)
        return np.abs(self.weighter_.predict(X))

    def predict_disc(self, X):
        """
        Return the predictions of the discriminator network.

        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        y_disc : array
            Predictions of the discriminator network.
        """
        X = check_one_array(X)
        return self.discriminator_.predict(X)
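
For context, a minimal usage sketch of the new class on hypothetical toy data. It assumes WANN is exported from adapt.instance_based (as the file path suggests) and that extra fit arguments are forwarded to the underlying Keras fit:

import numpy as np
from adapt.instance_based import WANN

# Toy 1D regression with a shifted target domain (hypothetical data)
np.random.seed(0)
Xs = np.random.uniform(-1., 1., (1000, 1))
ys = np.ravel(Xs ** 2) + 0.1 * np.random.randn(1000)
Xt = np.random.uniform(0., 1., (100, 1))
yt = np.ravel(Xt ** 2) + 0.1 * np.random.randn(100)

# Fit on labeled source data plus the few labeled target data
model = WANN(C=1., init_weights=True, random_state=0)
model.fit(Xs, ys, Xt, yt, epochs=100, batch_size=64, verbose=0)

y_pred = model.predict(Xt)           # task-network predictions
weights = model.predict_weights(Xs)  # learned source weights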

0 commit comments
