Skip to content

Commit d757693

Browse files
Fix coral, kliep, utils
1 parent 5dd2483 commit d757693

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

adapt/feature_based/_coral.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,17 @@ def fit_embeddings(self, Xs, Xt):
196196

197197
self.Cs_ = self.lambda_ * cov_Xs + np.eye(Xs.shape[1])
198198
self.Ct_ = self.lambda_ * cov_Xt + np.eye(Xt.shape[1])
199-
Xs_emb = np.matmul(Xs, linalg.inv(linalg.sqrtm(self.Cs_)))
200-
Xs_emb = np.matmul(Xs_emb, linalg.sqrtm(self.Ct_))
199+
200+
Cs_sqrt_inv = linalg.inv(linalg.sqrtm(self.Cs_))
201+
Ct_sqrt = linalg.sqrtm(self.Ct_)
202+
203+
if np.iscomplexobj(Cs_sqrt_inv):
204+
Cs_sqrt_inv = Cs_sqrt_inv.real
205+
if np.iscomplexobj(Ct_sqrt):
206+
Ct_sqrt = Ct_sqrt.real
207+
208+
Xs_emb = np.matmul(Xs, Cs_sqrt_inv)
209+
Xs_emb = np.matmul(Xs_emb, Ct_sqrt)
201210

202211
if self.verbose:
203212
new_cov_Xs = np.cov(Xs_emb, rowvar=False)
@@ -280,8 +289,16 @@ def predict_features(self, X, domain="tgt"):
280289
if domain in ["tgt", "target"]:
281290
X_emb = X
282291
elif domain in ["src", "source"]:
283-
X_emb = np.matmul(X, linalg.inv(linalg.sqrtm(self.Cs_)))
284-
X_emb = np.matmul(X_emb, linalg.sqrtm(self.Ct_))
292+
Cs_sqrt_inv = linalg.inv(linalg.sqrtm(self.Cs_))
293+
Ct_sqrt = linalg.sqrtm(self.Ct_)
294+
295+
if np.iscomplexobj(Cs_sqrt_inv):
296+
Cs_sqrt_inv = Cs_sqrt_inv.real
297+
if np.iscomplexobj(Ct_sqrt):
298+
Ct_sqrt = Ct_sqrt.real
299+
300+
X_emb = np.matmul(X, Cs_sqrt_inv)
301+
X_emb = np.matmul(X_emb, Ct_sqrt)
285302
else:
286303
raise ValueError("`domain `argument "
287304
"should be `tgt` or `src`, "

adapt/instance_based/_kliep.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ def fit_estimator(self, X, y,
288288
else:
289289
if sample_weight is not None:
290290
sample_weight /= (sample_weight.sum() + EPS)
291+
if sample_weight.sum() <= 0:
292+
sample_weight = np.ones(len(sample_weight))
291293
bootstrap_index = np.random.choice(
292294
len(X), size=len(X), replace=True,
293295
p=sample_weight)
@@ -359,7 +361,7 @@ def _fit(self, Xs, Xt, sigma):
359361
alpha += b * ((((1-np.dot(np.transpose(b), alpha)) /
360362
(np.dot(np.transpose(b), b) + EPS))))
361363
alpha = np.maximum(0, alpha)
362-
alpha /= np.dot(np.transpose(b), alpha)
364+
alpha /= (np.dot(np.transpose(b), alpha) + EPS)
363365
objective = np.mean(np.log(np.dot(A, alpha) + EPS))
364366
k += 1
365367

@@ -389,7 +391,7 @@ def _cross_val_jscore(self, Xs, Xt, sigma, cv):
389391
centers,
390392
sigma),
391393
alphas
392-
)
394+
) + EPS
393395
))
394396
cv_scores.append(j_score)
395397
return cv_scores

adapt/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ class GradientHandler(Layer):
364364
"""
365365
def __init__(self, lambda_init=1., name="g_handler"):
366366
super().__init__(name=name)
367+
self.lambda_init=1.
367368
self.lambda_ = tf.Variable(lambda_init,
368369
trainable=False,
369370
dtype="float32")
@@ -384,6 +385,21 @@ def call(self, x):
384385
return _grad_handler(x, self.lambda_)
385386

386387

388+
def get_config(self):
389+
"""
390+
Return config dictionary.
391+
392+
Returns
393+
-------
394+
dict
395+
"""
396+
config = super().get_config().copy()
397+
config.update({
398+
'lambda_init': self.lambda_init
399+
})
400+
return config
401+
402+
387403
def make_classification_da(n_samples=100,
388404
n_features=2,
389405
random_state=2):

0 commit comments

Comments
 (0)