Skip to content

Commit 7908382

Browse files
Merge pull request #39 from atiqm/master
fix ser/strut options and parameters
2 parents 5be0904 + af9833c commit 7908382

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

tests/test_transfertree.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,13 @@ def test_transfer_tree():
108108
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
109109
transferred_rf._ser_rf(Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0])
110110

111-
# WARNING! Error Raised with this test
112111
if method == 'ser_no_ext':
113-
pass
114112
#decision tree
115-
#transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
116-
#transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_ext_on_cl=True,cl_no_red=[0],ext_cond=True)
113+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
114+
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
117115
#random forest
118-
#transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
119-
#transferred_rf._ser_rf(Xt, yt,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
116+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
117+
transferred_rf._ser_rf(Xt, yt,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
120118
if method == 'ser_nr_lambda':
121119
#decision tree
122120
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
@@ -159,7 +157,7 @@ def test_transfer_tree():
159157
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
160158
#random forest
161159
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
162-
transferred_rf._strut(Xt, yt,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
160+
transferred_rf._strut_rf(Xt, yt,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
163161
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
164162
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
165163
if method == 'strut_lambda_np':
@@ -170,7 +168,7 @@ def test_transfer_tree():
170168
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
171169
#random forest
172170
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
173-
transferred_rf._strut(Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
171+
transferred_rf._strut_rf(Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
174172
leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=False,
175173
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
176174
if method == 'strut_lambda_np2':
@@ -179,12 +177,11 @@ def test_transfer_tree():
179177
transferred_dt._strut(Xt, yt,node=0,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
180178
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
181179
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
182-
# Warning! Error Raised because `strut` not in TransferForest
183180
#random forest
184-
#transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
185-
#transferred_rf._strut(Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
186-
# leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=True,
187-
# root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
181+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
182+
transferred_rf._strut_rf(Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
183+
leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=True,
184+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
188185

189186
score = transferred_dt.estimator.score(Xt_test, yt_test)
190187
#score = clf_transfer.score(Xt_test, yt_test)

0 commit comments

Comments
 (0)