Skip to content

Commit eea283c

Browse files
Add test tradaboost
1 parent af5fc37 commit eea283c

File tree

5 files changed

+78
-22
lines changed

5 files changed

+78
-22
lines changed

adapt/instance_based/_tradaboost.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def fit(self, X, y, Xt=None, yt=None,
235235
Xt, yt = self._get_target_data(Xt, yt)
236236
Xt, yt = check_arrays(Xt, yt, accept_sparse=True)
237237

238-
if isinstance(self, TrAdaBoost):
238+
if not isinstance(self, TrAdaBoostR2) and isinstance(self.estimator, BaseEstimator):
239239
self.label_encoder_ = LabelEncoder()
240240
ys = self.label_encoder_.fit_transform(ys)
241241
yt = self.label_encoder_.transform(yt)
@@ -454,7 +454,10 @@ def predict(self, X):
454454
predictions.append(y_pred)
455455
predictions = np.stack(predictions, -1)
456456
weighted_vote = predictions.dot(weights).argmax(1)
457-
return self.label_encoder_.inverse_transform(weighted_vote)
457+
if hasattr(self, "label_encoder_"):
458+
return self.label_encoder_.inverse_transform(weighted_vote)
459+
else:
460+
return weighted_vote
458461

459462

460463
def predict_weights(self, domain="src"):
@@ -951,7 +954,7 @@ def func(x):
951954
def _cross_val_score(self, Xs, ys, Xt, yt,
952955
sample_weight_src, sample_weight_tgt,
953956
**fit_params):
954-
if len(Xt) >= self.cv:
957+
if Xt.shape[0] >= self.cv:
955958
cv = self.cv
956959
else:
957960
cv = Xt.shape[0]

src_docs/_static/css/custom.css

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,19 @@ img.map-adapt {
251251
#selecting-the-right-domain-adaptation-model {
252252
padding-bottom: 600px;
253253
}
254+
255+
256+
blockquote {
257+
border-left: 5px solid #D3D3D3;
258+
padding: 0 1em;
259+
}
260+
261+
262+
div.alert.alert-block.alert-info {
263+
background: #e7f2fa;
264+
padding: 12px;
265+
margin-bottom: 12px;
266+
border-top-color: #6ab0de;
267+
border-top-width: 12px;
268+
border-top-style: solid;
269+
}

src_docs/_templates/layout.html

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
{%- if pathto("examples/Regression") == "#" %}{% set Regression = "" %}{% else %}{% set Regression = "#" %}{% endif %}
88
{%- if pathto("examples/sample_bias") == "#" %}{% set sample_bias = "" %}{% else %}{% set sample_bias = "#" %}{% endif %}
99
{%- if pathto("examples/Multi_fidelity") == "#" %}{% set Multi_fidelity = "" %}{% else %}{% set Multi_fidelity = "#" %}{% endif %}
10-
{%- if pathto("examples/Rotation") == "#" %}{% set Rotation = "" %}{% else %}{% set Rotation = "#" %}{% endif %}
10+
{%- if pathto("examples/Rotation") == "#" %}{% set Rotation = "" %}{% else %}{% set Rotation = "#" %}{% endif %}
11+
{%- if pathto("examples/tradaboost_experiments") == "#" %}{% set tradaboost_experiments = "" %}{% else %}{% set tradaboost_experiments = "#" %}{% endif %}
1112

1213
{% block menu %}
1314
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
@@ -150,5 +151,12 @@
150151
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("examples/Multi_fidelity") }}{{ Multi_fidelity }}{{ "RegularTransferNN" }}">RegularTransferNN</a></li>
151152
</ul>
152153
</li>
154+
<li class="toctree-l1"><a class="reference internal" href="{{ pathto("examples/tradaboost_experiments") }}">TrAdaBoost Experiments</a><ul>
155+
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("examples/tradaboost_experiments") }}{{ tradaboost_experiments }}{{ "Mushrooms" }}">Mushrooms</a></li>
156+
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("examples/tradaboost_experiments") }}{{ tradaboost_experiments }}{{ "20-NewsGroup" }}">20-NewsGroup</a></li>
157+
</ul>
158+
</li>
159+
160+
153161
</ul>
154162
{% endblock %}

src_docs/examples/tradaboost_experiments.ipynb

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,32 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "8ea3c629-ffd7-48f6-9ff1-96a43d120c9f",
6+
"metadata": {},
7+
"source": [
8+
"# Reproduction of the TrAdaBoost experiments"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "3aedc081-d6af-45dc-a9e1-bcd71e83f90b",
14+
"metadata": {},
15+
"source": [
16+
"<div class=\"btn btn-notebook\" role=\"button\">\n",
17+
" <img src=\"../_static/images/github_logo_32px.png\"> [View on GitHub](https://github.com/adapt-python/notebooks/blob/d0364973c642ea4880756cef4e9f2ee8bb5e8495/Two_moons.ipynb)\n",
18+
"</div>"
19+
]
20+
},
321
{
422
"cell_type": "markdown",
523
"id": "a22504a0-5ff7-498e-bc35-a6c101926204",
624
"metadata": {},
725
"source": [
8-
"# Reproduction of the TrAdaBoost experiments\n",
9-
"\n",
1026
"The purpose of this example is to reproduce the results obtained in the paper [Boosting for Transfer Learning (2007)](https://cse.hkust.edu.hk/~qyang/Docs/2007/tradaboost.pdf). In this work, the authors developed a transfer algorithm called TrAdaBoost dedicated to [supervised domain adaptation](https://adapt-python.github.io/adapt/map.html). You can find more details about this algorithm [here](https://adapt-python.github.io/adapt/generated/adapt.instance_based.TrAdaBoost.html). The goal of this algorithm is to combine a source dataset containing many labeled instances with a target dataset containing few labels in order to learn a good model on the target domain.\n",
1127
"\n",
1228
"We try to reproduce the following two experiments:\n",
29+
"\n",
1330
"- Mushrooms\n",
1431
"- 20newsgroups\n",
1532
"\n"
@@ -314,8 +331,7 @@
314331
"metadata": {},
315332
"source": [
316333
"<div class=\"alert alert-block alert-info\">\n",
317-
"<b>Note:</b> When looking at the number of instances in each category of the *stalk-shape* attribute, it seems that the authors inversed the source data set with the target one in the text above. Indeed, when looking at Table 1 in the paper, the number of source instances should be 4608 which corresponds to the <b>tapering</b> class and not the <b>enlarging</b> one.</div>\n",
318-
"\n"
334+
"**Note:** When looking at the number of instances in each category of the *stalk-shape* attribute, it seems that the authors swapped the source data set with the target one in the text above. Indeed, when looking at Table 1 in the paper, the number of source instances should be 4608 which corresponds to the **tapering** class and not the **enlarging** one.</div>"
319335
]
320336
},
321337
{
@@ -552,7 +568,7 @@
552568
"id": "babd1cce-c39e-4516-9a10-9d7e9f00f190",
553569
"metadata": {},
554570
"source": [
555-
"## 20 NewsGroup experiments"
571+
"## 20 NewsGroup"
556572
]
557573
},
558574
{
@@ -641,17 +657,6 @@
641657
"We conduct the three proposed experiments \"rec vs talk\", \"rec vs sci\" and \"sci vs talk\". We set the number of TrAdaBoost estimators to 10 instead of 100. We found that using 100 estimators gives poor results for TrAdaBoost."
642658
]
643659
},
644-
{
645-
"cell_type": "code",
646-
"execution_count": 26,
647-
"id": "1d47ec68-f638-42aa-9c11-c37caa61fe14",
648-
"metadata": {},
649-
"outputs": [],
650-
"source": [
651-
"# source_sci = ['sci.crypt', 'sci.electronics']\n",
652-
"# target_sci = ['sci.med', 'sci.space']"
653-
]
654-
},
655660
{
656661
"cell_type": "markdown",
657662
"id": "054fa3be-c83e-4c64-a3ff-58c51ea397fe",

tests/test_tradaboost.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import copy
66
import numpy as np
7-
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
7+
import scipy
8+
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, RidgeClassifier
89
from sklearn.metrics import r2_score, accuracy_score
910
import tensorflow as tf
1011

@@ -184,4 +185,27 @@ def test_tradaboost_lr():
184185
model.fit(Xs, ys_classif)
185186
err2 = model.estimator_errors_
186187

187-
assert np.sum(err1) > 10 * np.sum(err2)
188+
assert np.sum(err1) > 5 * np.sum(err2)
189+
190+
191+
def test_tradaboost_sparse_matrix():
192+
X = scipy.sparse.csr_matrix(np.eye(200))
193+
y = np.random.randn(100)
194+
yc = np.random.choice(["e", "p"], 100)
195+
Xt = X[:100]
196+
Xs = X[100:]
197+
198+
model = TrAdaBoost(RidgeClassifier(), Xt=Xt[:10], yt=yc[:10])
199+
model.fit(Xs, yc)
200+
model.score(Xt, yc)
201+
model.predict(Xs)
202+
203+
model = TrAdaBoostR2(Ridge(), Xt=Xt[:10], yt=y[:10])
204+
model.fit(Xs, y)
205+
model.score(Xt, y)
206+
model.predict(Xs)
207+
208+
model = TwoStageTrAdaBoostR2(Ridge(), Xt=Xt[:10], yt=y[:10], n_estimators=3)
209+
model.fit(Xs, y)
210+
model.score(Xt, y)
211+
model.predict(Xs)

0 commit comments

Comments
 (0)