Skip to content

Commit 27e1320

Browse files
Update tradaboost
1 parent 5e6ee0c commit 27e1320

File tree

2 files changed

+99
-18
lines changed

2 files changed

+99
-18
lines changed

adapt/instance_based/_tradaboost.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,9 @@ def fit(self, X, y, Xt=None, yt=None,
275275
sample_weight_src = sample_weight_src / sum_weights
276276
sample_weight_tgt = sample_weight_tgt / sum_weights
277277

278-
self.estimator_errors_ = np.array(self.estimator_errors_)
279-
self.estimator_weights_ = np.array([
280-
-np.log(err / (1-err) + EPS) + 2*EPS
281-
for err in self.estimator_errors_])
278+
self.estimator_weights_ = [
279+
-np.log(err / (2.-err) + EPS) + 2*EPS
280+
for err in self.estimator_errors_]
282281
return self
283282

284283

@@ -339,10 +338,10 @@ def _boost(self, iboost, Xs, ys, Xt, yt,
339338
sample_weight_tgt.sum())
340339

341340
# For multiclass classification and regression, the error can be greater than 0.5
342-
# if estimator_error > 0.5:
343-
# estimator_error = 0.5
341+
# if estimator_error > 0.49:
342+
# estimator_error = 0.49
344343

345-
beta_t = 2*estimator_error / (2. - estimator_error)
344+
beta_t = estimator_error / (2. - estimator_error)
346345

347346
beta_s = 1. / (1. + np.sqrt(
348347
2. * np.log(len(Xs)) / self.n_estimators
@@ -386,7 +385,8 @@ def predict(self, X):
386385
"""
387386
X = check_array(X)
388387
N = len(self.estimators_)
389-
weights = self.estimator_weights_[int(N/2):]
388+
weights = np.array(self.estimator_weights_)
389+
weights = weights[int(N/2):]
390390
predictions = []
391391
for est in self.estimators_[int(N/2):]:
392392
if isinstance(est, BaseEstimator):
@@ -457,9 +457,9 @@ def score(self, X, y):
457457
X, y = check_arrays(X, y)
458458
yp = self.predict(X)
459459
if isinstance(self, TrAdaBoostR2):
460-
score = r2_score(yp, y)
460+
score = r2_score(y, yp)
461461
else:
462-
score = accuracy_score(yp, y)
462+
score = accuracy_score(y, yp)
463463
return score
464464

465465

@@ -589,7 +589,7 @@ def predict(self, X):
589589
"""
590590
X = check_array(X)
591591
N = len(self.estimators_)
592-
weights = self.estimator_weights_
592+
weights = np.array(self.estimator_weights_)
593593
weights = weights[int(N/2):]
594594
predictions = []
595595
for est in self.estimators_[int(N/2):]:
@@ -811,13 +811,13 @@ def fit(self, X, y, Xt=None, yt=None,
811811
print("Iteration %i - Cross-validation score: %.4f (%.4f)"%
812812
(iboost, np.mean(cv_score), np.std(cv_score)))
813813

814-
self.estimator_errors_.append(cv_score.mean())
815-
816814
sample_weight_src, sample_weight_tgt = self._boost(
817815
iboost, Xs, ys, Xt, yt,
818816
sample_weight_src, sample_weight_tgt,
819817
**fit_params
820818
)
819+
820+
self.estimator_errors_.append(cv_score.mean())
821821

822822
if sample_weight_src is None:
823823
break
@@ -827,7 +827,6 @@ def fit(self, X, y, Xt=None, yt=None,
827827
sample_weight_src = sample_weight_src / sum_weights
828828
sample_weight_tgt = sample_weight_tgt / sum_weights
829829

830-
self.estimator_errors_ = np.array(self.estimator_errors_)
831830
return self
832831

833832

@@ -959,7 +958,7 @@ def predict(self, X):
959958
"""
960959
X = check_array(X)
961960
best_estimator = self.estimators_[
962-
self.estimator_errors_.argmin()]
961+
np.argmin(self.estimator_errors_)]
963962
return best_estimator.predict(X)
964963

965964

@@ -984,7 +983,7 @@ def predict_weights(self, domain="src"):
984983
weights : source sample weights
985984
"""
986985
if hasattr(self, "sample_weights_src_"):
987-
arg = self.estimator_errors_.argmin()
986+
arg = np.argmin(self.estimator_errors_)
988987
if domain in ["src", "source"]:
989988
return self.sample_weights_src_[arg]
990989
elif domain in ["tgt", "target"]:

tests/test_tradaboost.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import copy
66
import numpy as np
7-
from sklearn.linear_model import LinearRegression, LogisticRegression
7+
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
8+
from sklearn.metrics import r2_score, accuracy_score
89
import tensorflow as tf
910

1011
from adapt.instance_based import (TrAdaBoost,
@@ -34,6 +35,8 @@ def test_tradaboost_fit():
3435
solver='lbfgs'),
3536
n_estimators=20)
3637
model.fit(Xs, ys_classif, Xt=Xt[:10], yt=yt_classif[:10])
38+
score = model.score(Xs, ys_classif)
39+
assert score == accuracy_score(ys_classif, model.predict(Xs))
3740
assert len(model.sample_weights_src_[0]) == 100
3841
assert (model.sample_weights_src_[0][:50].sum() ==
3942
model.sample_weights_src_[0][50:].sum())
@@ -58,13 +61,18 @@ def test_tradaboost_fit_keras_model():
5861
model.fit(Xs, np.random.random((100, 2)),
5962
Xt=Xt[:10], yt=np.random.random((10, 2)))
6063

64+
score = model.score(Xs, ys_classif)
65+
assert score == accuracy_score(ys_classif, model.predict(Xs))
66+
6167

6268
def test_tradaboostr2_fit():
6369
np.random.seed(0)
6470
model = TrAdaBoostR2(LinearRegression(fit_intercept=False),
6571
n_estimators=100,
6672
Xt=Xt[:10], yt=yt_reg[:10])
6773
model.fit(Xs, ys_reg)
74+
score = model.score(Xs, ys_reg)
75+
assert score == r2_score(ys_reg, model.predict(Xs))
6876
assert np.abs(model.estimators_[-1].coef_[0] - 1.) < 1
6977
assert np.abs(model.sample_weights_src_[-1][:50].sum() /
7078
model.sample_weights_src_[-1][50:].sum()) > 10
@@ -80,7 +88,9 @@ def test_twostagetradaboostr2_fit():
8088
np.random.seed(0)
8189
model = TwoStageTrAdaBoostR2(LinearRegression(fit_intercept=False),
8290
n_estimators=10)
83-
model.fit(Xs, ys_reg.ravel(), Xt=Xt[:10], yt=yt_reg[:10].ravel())
91+
model.fit(Xs, ys_reg.ravel(), Xt=Xt[:10], yt=yt_reg[:10].ravel())
92+
score = model.score(Xs, ys_reg)
93+
assert score == r2_score(ys_reg, model.predict(Xs))
8494
assert np.abs(model.estimators_[-1].estimators_[-1].coef_[0]
8595
- 1.) < 1
8696
assert np.abs(model.sample_weights_src_[-1][:50].sum() /
@@ -103,3 +113,75 @@ def test_tradaboost_deepcopy():
103113
copy_model = copy.deepcopy(model)
104114
assert np.all(model.predict(Xt) == copy_model.predict(Xt))
105115
assert hex(id(model)) != hex(id(copy_model))
116+
117+
118+
def test_tradaboost_multiclass():
119+
np.random.seed(0)
120+
X = np.random.randn(10, 3)
121+
y = np.random.choice(3, 10)
122+
model = TrAdaBoost(LogisticRegression(penalty='none',
123+
solver='lbfgs'), Xt=X, yt=y,
124+
n_estimators=20)
125+
model.fit(X, y)
126+
yp = model.predict(X)
127+
score = model.score(X, y)
128+
assert set(np.unique(yp)) == set([0,1,2])
129+
assert score == accuracy_score(y, yp)
130+
131+
132+
def test_tradaboost_multireg():
133+
np.random.seed(0)
134+
X = np.random.randn(10, 3)
135+
y = np.random.randn(10, 5)
136+
model = TrAdaBoostR2(LinearRegression(),
137+
Xt=X, yt=y,
138+
n_estimators=20)
139+
model.fit(X, y)
140+
yp = model.predict(X)
141+
score = model.score(X, y)
142+
assert np.all(yp.shape == (10, 5))
143+
assert score == r2_score(y, yp)
144+
145+
model = TwoStageTrAdaBoostR2(LinearRegression(),
146+
Xt=X, yt=y,
147+
n_estimators=3,
148+
n_estimators_fs=3)
149+
model.fit(X, y)
150+
yp = model.predict(X)
151+
score = model.score(X, y)
152+
assert np.all(yp.shape == (10, 5))
153+
assert score == r2_score(y, yp)
154+
155+
156+
def test_tradaboost_above_05():
157+
np.random.seed(0)
158+
X = np.random.randn(10, 3)
159+
y = np.random.randn(10, 5)
160+
model = TrAdaBoostR2(LinearRegression(),
161+
Xt=Xt[:10], yt=yt_reg[:10],
162+
n_estimators=20)
163+
model.fit(Xs, ys_reg)
164+
assert np.any(np.array(model.estimator_errors_)>0.5)
165+
166+
model = TrAdaBoostR2(Ridge(1.),
167+
Xt=Xt[:20], yt=yt_reg[:20],
168+
n_estimators=20)
169+
model.fit(Xs, ys_reg)
170+
assert np.all(np.array(model.estimator_errors_)<0.5)
171+
172+
173+
def test_tradaboost_lr():
174+
np.random.seed(0)
175+
model = TrAdaBoost(LogisticRegression(penalty='none'),
176+
Xt=Xt[:10], yt=yt_classif[:10],
177+
n_estimators=20, lr=.1)
178+
model.fit(Xs, ys_classif)
179+
err1 = model.estimator_errors_
180+
181+
model = TrAdaBoost(LogisticRegression(penalty='none'),
182+
Xt=Xt[:10], yt=yt_classif[:10],
183+
n_estimators=20, lr=2.)
184+
model.fit(Xs, ys_classif)
185+
err2 = model.estimator_errors_
186+
187+
assert np.sum(err1) > 10 * np.sum(err2)

0 commit comments

Comments
 (0)