Skip to content

Commit bf27e4f

Browse files
committed
TransferForestClassifier
1 parent 666aed6 commit bf27e4f

File tree

3 files changed

+694
-348
lines changed

3 files changed

+694
-348
lines changed

adapt/_tree_utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
11
import copy
22
import numpy as np
33

4+
5+
def _bootstrap_(size,class_wise=False,y=None):
6+
if class_wise:
7+
if y is None:
8+
print("Error : need labels to apply class wise bootstrap.")
9+
else:
10+
inds = []
11+
oob_inds = []
12+
classes_ = set(y)
13+
ind_classes_ = np.zeros(len(classes_),dtype=object)
14+
15+
for j,c in enumerate(classes_):
16+
ind_classes_[j] = np.where(y==c)[0]
17+
s = ind_classes_[j].size
18+
inds += list(np.random.choice(ind_classes_[j], s, replace=True))
19+
oob_inds += list(set(ind_classes_[j]) - set(inds))
20+
21+
inds,oob_inds = np.array(inds),np.array(oob_inds)
22+
else:
23+
inds = np.random.choice(np.arange(size), size, replace=True)
24+
oob_inds = set(np.arange(size)) - set(inds)
25+
oob_inds = np.array(list(oob_inds))
26+
27+
return inds, oob_inds
28+
429
def depth_tree(dt,node=0):
530

631
if dt.tree_.feature[node] == -2:
@@ -464,8 +489,7 @@ def coherent_new_split(phi,th,rule):
464489
return 0,1
465490
else:
466491
return 1,0
467-
468-
492+
469493
def all_coherent_splits(rule,all_splits):
470494

471495
inds = np.zeros(all_splits.shape[0],dtype=bool)
@@ -534,4 +558,3 @@ def bounds_rule(rule,n_features):
534558

535559
return bound_infs,bound_sups
536560

537-

0 commit comments

Comments
 (0)