Skip to content

Commit ab9eefb

Browse files
author
Mounir
committed
modif subfunctions + relabeling method + example udpate
1 parent 1f2a011 commit ab9eefb

File tree

3 files changed

+166
-175
lines changed

3 files changed

+166
-175
lines changed

adapt/parameter_based/decision_trees/example.py

Lines changed: 118 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,27 @@
1515
import transfer_tree as TL
1616

1717
methods = [
18+
'relab',
1819
'ser',
1920
'strut',
2021
'ser_nr',
22+
'ser_nr_lambda',
23+
'strut_nd',
24+
'strut_lambda',
25+
'strut_lambda_np'
2126
# 'strut_hi'
2227
]
2328
labels = [
24-
'SER',
25-
'STRUT',
26-
'SER$^{*}$',
29+
'relab',
30+
'$SER$',
31+
'$STRUT$',
32+
'$SER_{NP}$',
33+
'$SER_{NP}(\lambda)$',
34+
'$STRUT_{ND}$',
35+
'$STRUT(\lambda)$',
36+
'$STRUT_{NP}(\lambda)$'
2737
# 'STRUT$^{*}$',
28-
'STRUT$^{*}$',
38+
#'STRUT$^{*}$',
2939
]
3040

3141
np.random.seed(0)
@@ -76,16 +86,50 @@
7686
#transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,Xt=Xt,yt=yt)
7787

7888
for method in methods:
89+
Nkmin = sum(yt == 0 )
90+
root_source_values = clf_source.tree_.value[0].reshape(-1)
91+
props_s = root_source_values
92+
props_s = props_s / sum(props_s)
93+
props_t = np.zeros(props_s.size)
94+
for k in range(props_s.size):
95+
props_t[k] = np.sum(yt == k) / yt.size
96+
97+
coeffs = np.divide(props_t, props_s)
98+
7999
clf_transfer = copy.deepcopy(clf_source)
100+
if method == 'relab':
101+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="")
102+
transferred_dt.fit(Xt,yt)
80103
if method == 'ser':
81-
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,Xt=Xt,yt=yt,algo="ser")
82-
transferred_dt._ser(Xt, yt, node=0, original_ser=True)
104+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="ser")
105+
transferred_dt.fit(Xt,yt)
106+
#transferred_dt._ser(Xt, yt, node=0, original_ser=True)
83107
#ser.SER(0, clf_transfer, Xt, yt, original_ser=True)
84108
if method == 'ser_nr':
85-
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
109+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="ser")
110+
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0])
111+
if method == 'ser_nr_lambda':
112+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="ser")
113+
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0],
114+
leaf_loss_quantify=True,leaf_loss_threshold=0.5,
115+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
86116
#ser.SER(0, clf_transfer, Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
87117
if method == 'strut':
88-
transferred_dt._strut(Xt, yt,node=0)
118+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
119+
transferred_dt.fit(Xt,yt)
120+
#transferred_dt._strut(Xt, yt,node=0)
121+
if method == 'strut_nd':
122+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
123+
transferred_dt._strut(Xt, yt,node=0,use_divergence=False)
124+
if method == 'strut_lambda':
125+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
126+
transferred_dt._strut(Xt, yt,node=0,adapt_prop=True,root_source_values=root_source_values,
127+
Nkmin=Nkmin,coeffs=coeffs)
128+
if method == 'strut_lambda_np':
129+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
130+
transferred_dt._strut(Xt, yt,node=0,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
131+
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
132+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
89133
#if method == 'strut_hi':
90134
#transferred_dt._strut(Xt, yt,node=0,no_prune_on_cl=False,adapt_prop=True,coeffs=[0.2, 1])
91135
#strut.STRUT(clf_transfer, 0, Xt, yt, Xt, yt,pruning_updated_node=True,no_prune_on_cl=False,adapt_prop=True,simple_weights=False,coeffs=[0.2, 1])
@@ -96,69 +140,69 @@
96140
#clfs.append(clf_transfer)
97141
scores.append(score)
98142

99-
## Plot decision functions
100-
#
101-
## Data on which to plot source
102-
#x_min, x_max = Xs[:, 0].min() - 1, Xs[:, 0].max() + 1
103-
#y_min, y_max = Xs[:, 1].min() - 1, Xs[:, 1].max() + 1
104-
#xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
105-
# np.arange(y_min, y_max, plot_step))
106-
## Plot source model
107-
#Z = clf_source.predict(np.c_[xx.ravel(), yy.ravel()])
108-
#Z = Z.reshape(xx.shape)
109-
#fig, ax = plt.subplots(nrows=1, ncols=len(methods) + 1, figsize=(13, 3))
110-
#ax[0].contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
111-
#ax[0].scatter(Xs[0, 0], Xs[0, 1],
112-
# marker='o',
113-
# edgecolor='black',
114-
# color='white',
115-
# label='source data',
116-
# )
117-
#ax[0].scatter(Xs[:ns_perclass, 0], Xs[:ns_perclass, 1],
118-
# marker='o',
119-
# edgecolor='black',
120-
# color='blue',
121-
# )
122-
#ax[0].scatter(Xs[ns_perclass:, 0], Xs[ns_perclass:, 1],
123-
# marker='o',
124-
# edgecolor='black',
125-
# color='red',
126-
# )
127-
#ax[0].set_title('Model: Source\nAcc on source data: {:.2f}\nAcc on target data: {:.2f}'.format(score_src_src, score_src_trgt),
128-
# fontsize=11)
129-
#ax[0].legend()
130-
#
131-
## Data on which to plot target
132-
#x_min, x_max = Xt[:, 0].min() - 1, Xt[:, 0].max() + 1
133-
#y_min, y_max = Xt[:, 1].min() - 1, Xt[:, 1].max() + 1
134-
#xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
135-
# np.arange(y_min, y_max, plot_step))
136-
## Plot transfer models
137-
#for i, (method, label, score) in enumerate(zip(methods, labels, scores)):
138-
# clf_transfer = clfs[i]
139-
# Z_transfer = clf_transfer.predict(np.c_[xx.ravel(), yy.ravel()])
140-
# Z_transfer = Z_transfer.reshape(xx.shape)
141-
# ax[i + 1].contourf(xx, yy, Z_transfer, cmap=plt.cm.coolwarm, alpha=0.8)
142-
# ax[i + 1].scatter(Xt[0, 0], Xt[0, 1],
143-
# marker='o',
144-
# edgecolor='black',
145-
# color='white',
146-
# label='target data',
147-
# )
148-
# ax[i + 1].scatter(Xt[:nt_0, 0], Xt[:nt_0, 1],
149-
# marker='o',
150-
# edgecolor='black',
151-
# color='blue',
152-
# )
153-
# ax[i + 1].scatter(Xt[nt_0:, 0], Xt[nt_0:, 1],
154-
# marker='o',
155-
# edgecolor='black',
156-
# color='red',
157-
# )
158-
# ax[i + 1].set_title('Model: {}\nAcc on target data: {:.2f}'.format(label, score),
159-
# fontsize=11)
160-
# ax[i + 1].legend()
161-
#
162-
## fig.savefig('../images/ser_strut.png')
163-
#plt.show()
164-
#
143+
# Plot decision functions
144+
145+
# Data on which to plot source
146+
x_min, x_max = Xs[:, 0].min() - 1, Xs[:, 0].max() + 1
147+
y_min, y_max = Xs[:, 1].min() - 1, Xs[:, 1].max() + 1
148+
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
149+
np.arange(y_min, y_max, plot_step))
150+
# Plot source model
151+
Z = clf_source.predict(np.c_[xx.ravel(), yy.ravel()])
152+
Z = Z.reshape(xx.shape)
153+
fig, ax = plt.subplots(nrows=1, ncols=len(methods) + 1, figsize=(30, 3))
154+
ax[0].contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
155+
ax[0].scatter(Xs[0, 0], Xs[0, 1],
156+
marker='o',
157+
edgecolor='black',
158+
color='white',
159+
label='source data',
160+
)
161+
ax[0].scatter(Xs[:ns_perclass, 0], Xs[:ns_perclass, 1],
162+
marker='o',
163+
edgecolor='black',
164+
color='blue',
165+
)
166+
ax[0].scatter(Xs[ns_perclass:, 0], Xs[ns_perclass:, 1],
167+
marker='o',
168+
edgecolor='black',
169+
color='red',
170+
)
171+
ax[0].set_title('Model: Source\nAcc on source data: {:.2f}\nAcc on target data: {:.2f}'.format(score_src_src, score_src_trgt),
172+
fontsize=11)
173+
ax[0].legend()
174+
175+
# Data on which to plot target
176+
x_min, x_max = Xt[:, 0].min() - 1, Xt[:, 0].max() + 1
177+
y_min, y_max = Xt[:, 1].min() - 1, Xt[:, 1].max() + 1
178+
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
179+
np.arange(y_min, y_max, plot_step))
180+
# Plot transfer models
181+
for i, (method, label, score) in enumerate(zip(methods, labels, scores)):
182+
clf_transfer = clfs[i]
183+
Z_transfer = clf_transfer.predict(np.c_[xx.ravel(), yy.ravel()])
184+
Z_transfer = Z_transfer.reshape(xx.shape)
185+
ax[i + 1].contourf(xx, yy, Z_transfer, cmap=plt.cm.coolwarm, alpha=0.8)
186+
ax[i + 1].scatter(Xt[0, 0], Xt[0, 1],
187+
marker='o',
188+
edgecolor='black',
189+
color='white',
190+
label='target data',
191+
)
192+
ax[i + 1].scatter(Xt[:nt_0, 0], Xt[:nt_0, 1],
193+
marker='o',
194+
edgecolor='black',
195+
color='blue',
196+
)
197+
ax[i + 1].scatter(Xt[nt_0:, 0], Xt[nt_0:, 1],
198+
marker='o',
199+
edgecolor='black',
200+
color='red',
201+
)
202+
ax[i + 1].set_title('Model: {}\nAcc on target data: {:.2f}'.format(label, score),
203+
fontsize=11)
204+
ax[i + 1].legend()
205+
206+
# fig.savefig('../images/ser_strut.png')
207+
plt.show()
208+

0 commit comments

Comments
 (0)