Skip to content

Commit ea54493

Browse files
Merge pull request #25 from atiqm/ser_strut
feat: first commit ser strut
2 parents f60b909 + ab9eefb commit ea54493

File tree

3 files changed

+1820
-0
lines changed

3 files changed

+1820
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Thu Mar 3 11:23:55 2022
5+
6+
@author: mounir
7+
"""
8+
9+
import sys
10+
import copy
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
from sklearn.tree import DecisionTreeClassifier
14+
sys.path.insert(0, '../')
15+
import transfer_tree as TL
16+
17+
methods = [
18+
'relab',
19+
'ser',
20+
'strut',
21+
'ser_nr',
22+
'ser_nr_lambda',
23+
'strut_nd',
24+
'strut_lambda',
25+
'strut_lambda_np'
26+
# 'strut_hi'
27+
]
28+
labels = [
29+
'relab',
30+
'$SER$',
31+
'$STRUT$',
32+
'$SER_{NP}$',
33+
'$SER_{NP}(\lambda)$',
34+
'$STRUT_{ND}$',
35+
'$STRUT(\lambda)$',
36+
'$STRUT_{NP}(\lambda)$'
37+
# 'STRUT$^{*}$',
38+
#'STRUT$^{*}$',
39+
]
40+
41+
np.random.seed(0)
42+
43+
plot_step = 0.01
44+
# Generate training source data
45+
ns = 200
46+
ns_perclass = ns // 2
47+
mean_1 = (1, 1)
48+
var_1 = np.diag([1, 1])
49+
mean_2 = (3, 3)
50+
var_2 = np.diag([2, 2])
51+
Xs = np.r_[np.random.multivariate_normal(mean_1, var_1, size=ns_perclass),
52+
np.random.multivariate_normal(mean_2, var_2, size=ns_perclass)]
53+
ys = np.zeros(ns)
54+
ys[ns_perclass:] = 1
55+
# Generate training target data
56+
nt = 50
57+
# imbalanced
58+
nt_0 = nt // 10
59+
mean_1 = (6, 3)
60+
var_1 = np.diag([4, 1])
61+
mean_2 = (5, 5)
62+
var_2 = np.diag([1, 3])
63+
Xt = np.r_[np.random.multivariate_normal(mean_1, var_1, size=nt_0),
64+
np.random.multivariate_normal(mean_2, var_2, size=nt - nt_0)]
65+
yt = np.zeros(nt)
66+
yt[nt_0:] = 1
67+
# Generate testing target data
68+
nt_test = 1000
69+
nt_test_perclass = nt_test // 2
70+
Xt_test = np.r_[np.random.multivariate_normal(mean_1, var_1, size=nt_test_perclass),
71+
np.random.multivariate_normal(mean_2, var_2, size=nt_test_perclass)]
72+
yt_test = np.zeros(nt_test)
73+
yt_test[nt_test_perclass:] = 1
74+
75+
# Source classifier
76+
clf_source = DecisionTreeClassifier(max_depth=None)
77+
clf_source.fit(Xs, ys)
78+
score_src_src = clf_source.score(Xs, ys)
79+
score_src_trgt = clf_source.score(Xt_test, yt_test)
80+
print('Training score Source model: {:.3f}'.format(score_src_src))
81+
print('Testing score Source model: {:.3f}'.format(score_src_trgt))
82+
clfs = []
83+
scores = []
84+
# Transfer with SER
85+
#clf_transfer = copy.deepcopy(clf_source)
86+
#transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,Xt=Xt,yt=yt)
87+
88+
for method in methods:
89+
Nkmin = sum(yt == 0 )
90+
root_source_values = clf_source.tree_.value[0].reshape(-1)
91+
props_s = root_source_values
92+
props_s = props_s / sum(props_s)
93+
props_t = np.zeros(props_s.size)
94+
for k in range(props_s.size):
95+
props_t[k] = np.sum(yt == k) / yt.size
96+
97+
coeffs = np.divide(props_t, props_s)
98+
99+
clf_transfer = copy.deepcopy(clf_source)
100+
if method == 'relab':
101+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="")
102+
transferred_dt.fit(Xt,yt)
103+
if method == 'ser':
104+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="ser")
105+
transferred_dt.fit(Xt,yt)
106+
#transferred_dt._ser(Xt, yt, node=0, original_ser=True)
107+
#ser.SER(0, clf_transfer, Xt, yt, original_ser=True)
108+
if method == 'ser_nr':
109+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="ser")
110+
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0])
111+
if method == 'ser_nr_lambda':
112+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="ser")
113+
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0],
114+
leaf_loss_quantify=True,leaf_loss_threshold=0.5,
115+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
116+
#ser.SER(0, clf_transfer, Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
117+
if method == 'strut':
118+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
119+
transferred_dt.fit(Xt,yt)
120+
#transferred_dt._strut(Xt, yt,node=0)
121+
if method == 'strut_nd':
122+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
123+
transferred_dt._strut(Xt, yt,node=0,use_divergence=False)
124+
if method == 'strut_lambda':
125+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
126+
transferred_dt._strut(Xt, yt,node=0,adapt_prop=True,root_source_values=root_source_values,
127+
Nkmin=Nkmin,coeffs=coeffs)
128+
if method == 'strut_lambda_np':
129+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,algo="strut")
130+
transferred_dt._strut(Xt, yt,node=0,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
131+
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
132+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
133+
#if method == 'strut_hi':
134+
#transferred_dt._strut(Xt, yt,node=0,no_prune_on_cl=False,adapt_prop=True,coeffs=[0.2, 1])
135+
#strut.STRUT(clf_transfer, 0, Xt, yt, Xt, yt,pruning_updated_node=True,no_prune_on_cl=False,adapt_prop=True,simple_weights=False,coeffs=[0.2, 1])
136+
score = transferred_dt.estimator.score(Xt_test, yt_test)
137+
#score = clf_transfer.score(Xt_test, yt_test)
138+
print('Testing score transferred model ({}) : {:.3f}'.format(method, score))
139+
clfs.append(transferred_dt.estimator)
140+
#clfs.append(clf_transfer)
141+
scores.append(score)
142+
143+
# Plot decision functions
144+
145+
# Data on which to plot source
146+
x_min, x_max = Xs[:, 0].min() - 1, Xs[:, 0].max() + 1
147+
y_min, y_max = Xs[:, 1].min() - 1, Xs[:, 1].max() + 1
148+
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
149+
np.arange(y_min, y_max, plot_step))
150+
# Plot source model
151+
Z = clf_source.predict(np.c_[xx.ravel(), yy.ravel()])
152+
Z = Z.reshape(xx.shape)
153+
fig, ax = plt.subplots(nrows=1, ncols=len(methods) + 1, figsize=(30, 3))
154+
ax[0].contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
155+
ax[0].scatter(Xs[0, 0], Xs[0, 1],
156+
marker='o',
157+
edgecolor='black',
158+
color='white',
159+
label='source data',
160+
)
161+
ax[0].scatter(Xs[:ns_perclass, 0], Xs[:ns_perclass, 1],
162+
marker='o',
163+
edgecolor='black',
164+
color='blue',
165+
)
166+
ax[0].scatter(Xs[ns_perclass:, 0], Xs[ns_perclass:, 1],
167+
marker='o',
168+
edgecolor='black',
169+
color='red',
170+
)
171+
ax[0].set_title('Model: Source\nAcc on source data: {:.2f}\nAcc on target data: {:.2f}'.format(score_src_src, score_src_trgt),
172+
fontsize=11)
173+
ax[0].legend()
174+
175+
# Data on which to plot target
176+
x_min, x_max = Xt[:, 0].min() - 1, Xt[:, 0].max() + 1
177+
y_min, y_max = Xt[:, 1].min() - 1, Xt[:, 1].max() + 1
178+
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
179+
np.arange(y_min, y_max, plot_step))
180+
# Plot transfer models
181+
for i, (method, label, score) in enumerate(zip(methods, labels, scores)):
182+
clf_transfer = clfs[i]
183+
Z_transfer = clf_transfer.predict(np.c_[xx.ravel(), yy.ravel()])
184+
Z_transfer = Z_transfer.reshape(xx.shape)
185+
ax[i + 1].contourf(xx, yy, Z_transfer, cmap=plt.cm.coolwarm, alpha=0.8)
186+
ax[i + 1].scatter(Xt[0, 0], Xt[0, 1],
187+
marker='o',
188+
edgecolor='black',
189+
color='white',
190+
label='target data',
191+
)
192+
ax[i + 1].scatter(Xt[:nt_0, 0], Xt[:nt_0, 1],
193+
marker='o',
194+
edgecolor='black',
195+
color='blue',
196+
)
197+
ax[i + 1].scatter(Xt[nt_0:, 0], Xt[nt_0:, 1],
198+
marker='o',
199+
edgecolor='black',
200+
color='red',
201+
)
202+
ax[i + 1].set_title('Model: {}\nAcc on target data: {:.2f}'.format(label, score),
203+
fontsize=11)
204+
ax[i + 1].legend()
205+
206+
# fig.savefig('../images/ser_strut.png')
207+
plt.show()
208+

0 commit comments

Comments
 (0)