@@ -108,15 +108,13 @@ def test_transfer_tree():
108
108
transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
109
109
transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_red_on_cl = True ,cl_no_red = [0 ])
110
110
111
- # WARNING! Error Raised with this test
112
111
if method == 'ser_no_ext' :
113
- pass
114
112
#decision tree
115
- # transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
116
- # transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_ext_on_cl=True,cl_no_red =[0],ext_cond=True)
113
+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
114
+ transferred_dt ._ser (Xt , yt ,node = 0 ,original_ser = False ,no_ext_on_cl = True ,cl_no_ext = [0 ],ext_cond = True )
117
115
#random forest
118
- # transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
119
- # transferred_rf._ser_rf(Xt, yt,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
116
+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
117
+ transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_ext_on_cl = True ,cl_no_ext = [0 ],ext_cond = True )
120
118
if method == 'ser_nr_lambda' :
121
119
#decision tree
122
120
transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
@@ -159,7 +157,7 @@ def test_transfer_tree():
159
157
root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
160
158
#random forest
161
159
transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
162
- transferred_rf ._strut (Xt , yt ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
160
+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
163
161
leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
164
162
root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
165
163
if method == 'strut_lambda_np' :
@@ -170,7 +168,7 @@ def test_transfer_tree():
170
168
root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
171
169
#random forest
172
170
transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
173
- transferred_rf ._strut (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
171
+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
174
172
leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
175
173
root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
176
174
if method == 'strut_lambda_np2' :
@@ -179,12 +177,11 @@ def test_transfer_tree():
179
177
transferred_dt ._strut (Xt , yt ,node = 0 ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
180
178
leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
181
179
root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
182
- # Warning! Error Raised because `strut` not in TransferForest
183
180
#random forest
184
- # transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
185
- # transferred_rf._strut (Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
186
- # leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=True,
187
- # root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
181
+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
182
+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
183
+ leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = True ,
184
+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
188
185
189
186
score = transferred_dt .estimator .score (Xt_test , yt_test )
190
187
#score = clf_transfer.score(Xt_test, yt_test)
0 commit comments