Skip to content

Commit 0b3479b

Browse files
Merge pull request #35 from atiqm/master
feat: Add TransferForestClassifier
2 parents 666aed6 + 4fdcf23 commit 0b3479b

File tree

8 files changed

+836
-346
lines changed

8 files changed

+836
-346
lines changed

.github/workflows/check-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
run: |
1919
sudo apt install pandoc
2020
python -m pip install --upgrade pip
21-
pip install sphinx numpydoc nbsphinx sphinx_gallery sphinx_rtd_theme ipython
21+
pip install jinja2==3.0.3 sphinx numpydoc nbsphinx sphinx_gallery sphinx_rtd_theme ipython
2222
- name: Install adapt dependencies
2323
run: |
2424
python -m pip install --upgrade pip

.github/workflows/publish-doc-to-remote.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
run: |
1919
sudo apt install pandoc
2020
python -m pip install --upgrade pip
21-
pip install sphinx numpydoc nbsphinx sphinx_gallery sphinx_rtd_theme ipython
21+
pip install jinja2==3.0.3 sphinx numpydoc nbsphinx sphinx_gallery sphinx_rtd_theme ipython
2222
- name: Install adapt dependencies
2323
run: |
2424
python -m pip install --upgrade pip

adapt/_tree_utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
11
import copy
22
import numpy as np
33

4+
5+
def _bootstrap_(size,class_wise=False,y=None):
6+
if class_wise:
7+
if y is None:
8+
print("Error : need labels to apply class wise bootstrap.")
9+
else:
10+
inds = []
11+
oob_inds = []
12+
classes_ = set(y)
13+
ind_classes_ = np.zeros(len(classes_),dtype=object)
14+
15+
for j,c in enumerate(classes_):
16+
ind_classes_[j] = np.where(y==c)[0]
17+
s = ind_classes_[j].size
18+
inds += list(np.random.choice(ind_classes_[j], s, replace=True))
19+
oob_inds += list(set(ind_classes_[j]) - set(inds))
20+
21+
inds,oob_inds = np.array(inds),np.array(oob_inds)
22+
else:
23+
inds = np.random.choice(np.arange(size), size, replace=True)
24+
oob_inds = set(np.arange(size)) - set(inds)
25+
oob_inds = np.array(list(oob_inds))
26+
27+
return inds, oob_inds
28+
429
def depth_tree(dt,node=0):
530

631
if dt.tree_.feature[node] == -2:
@@ -464,8 +489,7 @@ def coherent_new_split(phi,th,rule):
464489
return 0,1
465490
else:
466491
return 1,0
467-
468-
492+
469493
def all_coherent_splits(rule,all_splits):
470494

471495
inds = np.zeros(all_splits.shape[0],dtype=bool)
@@ -534,4 +558,3 @@ def bounds_rule(rule,n_features):
534558

535559
return bound_infs,bound_sups
536560

537-

adapt/parameter_based/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN
66
from ._finetuning import FineTuning
77
from ._transfer_tree import TransferTreeClassifier
8+
from ._transfer_tree import TransferForestClassifier
89

910
__all__ = ["RegularTransferLR",
1011
"RegularTransferLC",
1112
"RegularTransferNN",
1213
"FineTuning",
13-
"TransferTreeClassifier"]
14+
"TransferTreeClassifier",
15+
"TransferForestClassifier"]

0 commit comments

Comments
 (0)