Skip to content

Commit 1f2a011

Browse files
author
Mounir
committed
Solving bugs + paths + example
1 parent d3f3354 commit 1f2a011

File tree

3 files changed

+632
-164
lines changed

3 files changed

+632
-164
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Thu Mar 3 11:23:55 2022
5+
6+
@author: mounir
7+
"""
8+
9+
import sys
10+
import copy
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
from sklearn.tree import DecisionTreeClassifier
14+
sys.path.insert(0, '../')
15+
import transfer_tree as TL
16+
17+
methods = [
18+
'ser',
19+
'strut',
20+
'ser_nr',
21+
# 'strut_hi'
22+
]
23+
labels = [
24+
'SER',
25+
'STRUT',
26+
'SER$^{*}$',
27+
# 'STRUT$^{*}$',
28+
'STRUT$^{*}$',
29+
]
30+
31+
np.random.seed(0)
32+
33+
plot_step = 0.01
34+
# Generate training source data
35+
ns = 200
36+
ns_perclass = ns // 2
37+
mean_1 = (1, 1)
38+
var_1 = np.diag([1, 1])
39+
mean_2 = (3, 3)
40+
var_2 = np.diag([2, 2])
41+
Xs = np.r_[np.random.multivariate_normal(mean_1, var_1, size=ns_perclass),
42+
np.random.multivariate_normal(mean_2, var_2, size=ns_perclass)]
43+
ys = np.zeros(ns)
44+
ys[ns_perclass:] = 1
45+
# Generate training target data
46+
nt = 50
47+
# imbalanced
48+
nt_0 = nt // 10
49+
mean_1 = (6, 3)
50+
var_1 = np.diag([4, 1])
51+
mean_2 = (5, 5)
52+
var_2 = np.diag([1, 3])
53+
Xt = np.r_[np.random.multivariate_normal(mean_1, var_1, size=nt_0),
54+
np.random.multivariate_normal(mean_2, var_2, size=nt - nt_0)]
55+
yt = np.zeros(nt)
56+
yt[nt_0:] = 1
57+
# Generate testing target data
58+
nt_test = 1000
59+
nt_test_perclass = nt_test // 2
60+
Xt_test = np.r_[np.random.multivariate_normal(mean_1, var_1, size=nt_test_perclass),
61+
np.random.multivariate_normal(mean_2, var_2, size=nt_test_perclass)]
62+
yt_test = np.zeros(nt_test)
63+
yt_test[nt_test_perclass:] = 1
64+
65+
# Source classifier
66+
clf_source = DecisionTreeClassifier(max_depth=None)
67+
clf_source.fit(Xs, ys)
68+
score_src_src = clf_source.score(Xs, ys)
69+
score_src_trgt = clf_source.score(Xt_test, yt_test)
70+
print('Training score Source model: {:.3f}'.format(score_src_src))
71+
print('Testing score Source model: {:.3f}'.format(score_src_trgt))
72+
clfs = []
73+
scores = []
74+
# Transfer with SER
75+
#clf_transfer = copy.deepcopy(clf_source)
76+
#transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,Xt=Xt,yt=yt)
77+
78+
for method in methods:
79+
clf_transfer = copy.deepcopy(clf_source)
80+
if method == 'ser':
81+
transferred_dt = TL.TransferTreeClassifier(estimator=clf_transfer,Xt=Xt,yt=yt,algo="ser")
82+
transferred_dt._ser(Xt, yt, node=0, original_ser=True)
83+
#ser.SER(0, clf_transfer, Xt, yt, original_ser=True)
84+
if method == 'ser_nr':
85+
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
86+
#ser.SER(0, clf_transfer, Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
87+
if method == 'strut':
88+
transferred_dt._strut(Xt, yt,node=0)
89+
#if method == 'strut_hi':
90+
#transferred_dt._strut(Xt, yt,node=0,no_prune_on_cl=False,adapt_prop=True,coeffs=[0.2, 1])
91+
#strut.STRUT(clf_transfer, 0, Xt, yt, Xt, yt,pruning_updated_node=True,no_prune_on_cl=False,adapt_prop=True,simple_weights=False,coeffs=[0.2, 1])
92+
score = transferred_dt.estimator.score(Xt_test, yt_test)
93+
#score = clf_transfer.score(Xt_test, yt_test)
94+
print('Testing score transferred model ({}) : {:.3f}'.format(method, score))
95+
clfs.append(transferred_dt.estimator)
96+
#clfs.append(clf_transfer)
97+
scores.append(score)
98+
99+
## Plot decision functions
100+
#
101+
## Data on which to plot source
102+
#x_min, x_max = Xs[:, 0].min() - 1, Xs[:, 0].max() + 1
103+
#y_min, y_max = Xs[:, 1].min() - 1, Xs[:, 1].max() + 1
104+
#xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
105+
# np.arange(y_min, y_max, plot_step))
106+
## Plot source model
107+
#Z = clf_source.predict(np.c_[xx.ravel(), yy.ravel()])
108+
#Z = Z.reshape(xx.shape)
109+
#fig, ax = plt.subplots(nrows=1, ncols=len(methods) + 1, figsize=(13, 3))
110+
#ax[0].contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
111+
#ax[0].scatter(Xs[0, 0], Xs[0, 1],
112+
# marker='o',
113+
# edgecolor='black',
114+
# color='white',
115+
# label='source data',
116+
# )
117+
#ax[0].scatter(Xs[:ns_perclass, 0], Xs[:ns_perclass, 1],
118+
# marker='o',
119+
# edgecolor='black',
120+
# color='blue',
121+
# )
122+
#ax[0].scatter(Xs[ns_perclass:, 0], Xs[ns_perclass:, 1],
123+
# marker='o',
124+
# edgecolor='black',
125+
# color='red',
126+
# )
127+
#ax[0].set_title('Model: Source\nAcc on source data: {:.2f}\nAcc on target data: {:.2f}'.format(score_src_src, score_src_trgt),
128+
# fontsize=11)
129+
#ax[0].legend()
130+
#
131+
## Data on which to plot target
132+
#x_min, x_max = Xt[:, 0].min() - 1, Xt[:, 0].max() + 1
133+
#y_min, y_max = Xt[:, 1].min() - 1, Xt[:, 1].max() + 1
134+
#xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
135+
# np.arange(y_min, y_max, plot_step))
136+
## Plot transfer models
137+
#for i, (method, label, score) in enumerate(zip(methods, labels, scores)):
138+
# clf_transfer = clfs[i]
139+
# Z_transfer = clf_transfer.predict(np.c_[xx.ravel(), yy.ravel()])
140+
# Z_transfer = Z_transfer.reshape(xx.shape)
141+
# ax[i + 1].contourf(xx, yy, Z_transfer, cmap=plt.cm.coolwarm, alpha=0.8)
142+
# ax[i + 1].scatter(Xt[0, 0], Xt[0, 1],
143+
# marker='o',
144+
# edgecolor='black',
145+
# color='white',
146+
# label='target data',
147+
# )
148+
# ax[i + 1].scatter(Xt[:nt_0, 0], Xt[:nt_0, 1],
149+
# marker='o',
150+
# edgecolor='black',
151+
# color='blue',
152+
# )
153+
# ax[i + 1].scatter(Xt[nt_0:, 0], Xt[nt_0:, 1],
154+
# marker='o',
155+
# edgecolor='black',
156+
# color='red',
157+
# )
158+
# ax[i + 1].set_title('Model: {}\nAcc on target data: {:.2f}'.format(label, score),
159+
# fontsize=11)
160+
# ax[i + 1].legend()
161+
#
162+
## fig.savefig('../images/ser_strut.png')
163+
#plt.show()
164+
#

0 commit comments

Comments
 (0)