Skip to content

Commit 5e6ee0c

Browse files
Add docs new method + change score base
1 parent b54dfdd commit 5e6ee0c

File tree

14 files changed

+520
-160
lines changed

14 files changed

+520
-160
lines changed

adapt/base.py

Lines changed: 104 additions & 117 deletions
Large diffs are not rendered by default.

adapt/feature_based/_adda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def pretrain_step(self, data):
178178
ys_pred = tf.reshape(ys_pred, tf.shape(ys))
179179

180180
# Compute the loss value
181-
loss = self.task_loss_(ys, ys_pred)
181+
loss = tf.reduce_mean(self.task_loss_(ys, ys_pred))
182182
task_loss = loss + sum(self.task_.losses)
183183
enc_loss = loss + sum(self.encoder_src_.losses)
184184

adapt/feature_based/_coral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def transform(self, X, domain="tgt"):
185185
Input data.
186186
187187
domain : str (default="tgt")
188-
Choose between ``"source", "src"`` and
188+
Choose between ``"source", "src"`` or
189189
``"target", "tgt"`` feature embedding.
190190
191191
Returns

adapt/feature_based/_fmmd.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class fMMD(BaseAdaptEstimator):
9292
"""
9393
fMMD : feature Selection with MMD
9494
95-
LDM selects input features inorder to minimize the
95+
LDM selects input features in order to minimize the
9696
maximum mean discrepancy (MMD) between the source and
9797
the target data.
9898
@@ -160,7 +160,25 @@ def __init__(self,
160160
super().__init__(**kwargs)
161161

162162

163-
def fit_transform(self, Xs, Xt, **fit_params):
163+
def fit_transform(self, Xs, Xt, **kwargs):
164+
"""
165+
Fit embeddings.
166+
167+
Parameters
168+
----------
169+
Xs : array
170+
Input source data.
171+
172+
Xt : array
173+
Input target data.
174+
175+
kwargs : key, value argument
176+
Not used, present here for adapt consistency.
177+
178+
Returns
179+
-------
180+
Xs_emb : embedded source data
181+
"""
164182
Xs = check_array(Xs)
165183
Xt = check_array(Xt)
166184
set_random_seed(self.random_state)
@@ -217,5 +235,18 @@ def F(x=None, z=None):
217235

218236

219237
def transform(self, X):
238+
"""
239+
Return the projection of X on the selected features.
240+
241+
Parameters
242+
----------
243+
X : array
244+
Input data.
245+
246+
Returns
247+
-------
248+
X_emb : array
249+
Embeddings of X.
250+
"""
220251
X = check_array(X)
221252
return X[:, self.selected_features_]

adapt/feature_based/_mcd.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,27 @@ def train_step(self, data):
223223
return logs
224224

225225

226+
def predict_avg(self, X):
227+
"""
228+
Return the average predictions between
229+
task_ and discriminator_ networks.
230+
231+
Parameters
232+
----------
233+
X : array
234+
Input data
235+
236+
Returns
237+
-------
238+
y_avg : array
239+
Average predictions
240+
"""
241+
ypt = self.task_.predict(self.transform(X))
242+
ypd = self.discriminator_.predict(self.transform(X))
243+
yp_avg = 0.5 * (ypt+ypd)
244+
return yp_avg
245+
246+
226247
def _initialize_networks(self):
227248
if self.encoder is None:
228249
self.encoder_ = get_default_encoder(name="encoder")

adapt/feature_based/_sa.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,25 @@ def __init__(self,
4949
super().__init__(**kwargs)
5050

5151

52-
def fit_transform(self, Xs, Xt, **fit_params):
52+
def fit_transform(self, Xs, Xt, **kwargs):
53+
"""
54+
Fit embeddings.
55+
56+
Parameters
57+
----------
58+
Xs : array
59+
Input source data.
60+
61+
Xt : array
62+
Input target data.
63+
64+
kwargs : key, value argument
65+
Not used, present here for adapt consistency.
66+
67+
Returns
68+
-------
69+
Xs_emb : embedded source data
70+
"""
5371
Xs = check_array(Xs)
5472
Xt = check_array(Xt)
5573
set_random_seed(self.random_state)
@@ -67,6 +85,29 @@ def fit_transform(self, Xs, Xt, **fit_params):
6785

6886

6987
def transform(self, X, domain="tgt"):
88+
"""
89+
Project X in the target subspace.
90+
91+
The parameter ``domain`` specifies if X should
92+
be considered as source or target data. As the
93+
transformation is asymmetric, the source transformation
94+
should be applied on source data and the target
95+
transformation on target data.
96+
97+
Parameters
98+
----------
99+
X : array
100+
Input data.
101+
102+
domain : str (default="tgt")
103+
Choose between ``"source", "src"`` or
104+
``"target", "tgt"`` feature embedding.
105+
106+
Returns
107+
-------
108+
X_emb : array
109+
Embeddings of X.
110+
"""
70111
X = check_array(X)
71112

72113
if domain in ["tgt", "target"]:

adapt/instance_based/_ldm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,24 @@ def __init__(self,
6666

6767

6868
def fit_weights(self, Xs, Xt, **kwargs):
69+
"""
70+
Fit importance weighting.
71+
72+
Parameters
73+
----------
74+
Xs : array
75+
Input source data.
76+
77+
Xt : array
78+
Input target data.
79+
80+
kwargs : key, value argument
81+
Not used, present here for adapt consistency.
82+
83+
Returns
84+
-------
85+
weights_ : sample weights
86+
"""
6987
Xs = check_array(Xs)
7088
Xt = check_array(Xt)
7189
set_random_seed(self.random_state)

adapt/instance_based/_nnw.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,24 @@ def __init__(self,
100100

101101

102102
def fit_weights(self, Xs, Xt, **kwargs):
103+
"""
104+
Fit importance weighting.
105+
106+
Parameters
107+
----------
108+
Xs : array
109+
Input source data.
110+
111+
Xt : array
112+
Input target data.
113+
114+
kwargs : key, value argument
115+
Not used, present here for adapt consistency.
116+
117+
Returns
118+
-------
119+
weights_ : sample weights
120+
"""
103121
Xs = check_array(Xs)
104122
Xt = check_array(Xt)
105123
set_random_seed(self.random_state)

0 commit comments

Comments
 (0)