Skip to content

Commit 6e8bc2c

Browse files
committed
test file corrections
1 parent 3709b97 commit 6e8bc2c

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

adapt/parameter_based/_transfer_tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ def __init__(self,
10621062
if not isinstance(estimator, RandomForestClassifier):
10631063
raise ValueError("`estimator` argument must be a ``RandomForestClassifier`` instance, got %s."%str(type(estimator)))
10641064

1065-
if not hasattr(estimator, ".estimators_"):
1065+
if not hasattr(estimator, "estimators_"):
10661066
raise ValueError("`estimator` argument has no ``estimators_`` attribute, "
10671067
"please call `fit` on `estimator`.")
10681068

tests/test_treeutils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,21 @@
5050
clf_source_dt.fit(Xs, ys)
5151
clf_source_rf.fit(Xs, ys)
5252

53+
Nkmin = sum(yt == 0 )
54+
root_source_values = clf_source_dt.tree_.value[0].reshape(-1)
55+
props_s = root_source_values
56+
props_s = props_s / sum(props_s)
57+
props_t = np.zeros(props_s.size)
58+
for k in range(props_s.size):
59+
props_t[k] = np.sum(yt == k) / yt.size
60+
61+
coeffs = np.divide(props_t, props_s)
62+
5363
def test_depth():
5464
ut.depth_tree(clf_source_dt)
5565
ut.depth_rf(clf_source_rf)
5666
ut.depth(clf_source_dt,node_test)
57-
ut.depth_array(clf_source_dt,np.arange(clf_source_dt.tree_.node_counte_count))
67+
ut.depth_array(clf_source_dt,np.arange(clf_source_dt.tree_.node_count))
5868

5969
def test_rules():
6070

@@ -96,8 +106,8 @@ def test_splits():
96106
ut.new_random_split(np.ones(s)/s,coh_splits)
97107

98108
def test_error():
99-
e = ut.error(clf_source_dt,node_test)
100-
le = ut.leaf_error(clf_source_dt,node_test)
109+
e = ut.error(clf_source_dt.tree_,node_test)
110+
le = ut.leaf_error(clf_source_dt.tree_,node_test)
101111
return e,le
102112

103113
def test_distribution():
@@ -110,8 +120,8 @@ def test_distribution():
110120
ut.compute_Q_children_target(Xs,ys,phi,threshold,classes_test)
111121

112122
def test_pruning_risk():
113-
ut.compute_LLR_estimates_homog(clf_source_dt)
114-
ut.contain_leaf_to_not_prune(clf_source_dt)
123+
ut.compute_LLR_estimates_homog(clf_source_dt,Nkmin=Nkmin)
124+
ut.contain_leaf_to_not_prune(clf_source_dt,Nkmin=Nkmin,coeffs=coeffs)
115125

116126
def test_divergence_computation():
117127
phi = clf_source_dt.tree_.feature[0]

0 commit comments

Comments
 (0)