|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +""" |
| 4 | +Created on Thu Mar 3 11:23:55 2022 |
| 5 | +
|
| 6 | +@author: mounir |
| 7 | +""" |
| 8 | + |
| 9 | +import sys |
| 10 | +import copy |
| 11 | +import numpy as np |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +from sklearn.tree import DecisionTreeClassifier |
| 14 | +sys.path.insert(0, '../') |
| 15 | +import transfer_tree as TL |
| 16 | + |
# Transfer algorithms to run, in execution order.
# 'strut_hi' is currently disabled.
methods = [
    'ser',
    'strut',
    'ser_nr',
    # 'strut_hi',
]
# Display labels, index-aligned with `methods` (exactly one label per method).
labels = [
    'SER',
    'STRUT',
    'SER$^{*}$',
    # 'STRUT$^{*}$',  # presumably the label for 'strut_hi'; re-enable both together
]
| 30 | + |
# Fix the global NumPy RNG so every run draws the same samples.
np.random.seed(0)

# Grid resolution for decision-surface plots.
plot_step = 0.01


def _sample_two_classes(mean_a, cov_a, n_a, mean_b, cov_b, n_b):
    """Draw a two-class Gaussian data set from the global NumPy RNG.

    The first ``n_a`` rows come from N(mean_a, cov_a) and are labelled 0.0;
    the following ``n_b`` rows come from N(mean_b, cov_b) and are labelled 1.0.
    Class 0 is always sampled before class 1.

    Returns the stacked samples and the matching float label vector.
    """
    points = np.vstack([
        np.random.multivariate_normal(mean_a, cov_a, size=n_a),
        np.random.multivariate_normal(mean_b, cov_b, size=n_b),
    ])
    targets = np.concatenate([np.zeros(n_a), np.ones(n_b)])
    return points, targets


# --- Source training data: two balanced classes ---------------------------
ns = 200
ns_perclass = ns // 2
mean_1, var_1 = (1, 1), np.diag([1, 1])
mean_2, var_2 = (3, 3), np.diag([2, 2])
Xs, ys = _sample_two_classes(mean_1, var_1, ns_perclass,
                             mean_2, var_2, ns_perclass)

# --- Target training data: shifted distributions, imbalanced --------------
nt = 50
nt_0 = nt // 10  # only a tenth of the target training points belong to class 0
# The distribution parameters are rebound here to the target-domain values.
mean_1, var_1 = (6, 3), np.diag([4, 1])
mean_2, var_2 = (5, 5), np.diag([1, 3])
Xt, yt = _sample_two_classes(mean_1, var_1, nt_0,
                             mean_2, var_2, nt - nt_0)

# --- Target test data: same target distributions, balanced classes --------
nt_test = 1000
nt_test_perclass = nt_test // 2
Xt_test, yt_test = _sample_two_classes(mean_1, var_1, nt_test_perclass,
                                       mean_2, var_2, nt_test_perclass)
| 64 | + |
# Source model: a decision tree with no depth limit, trained on source data only.
clf_source = DecisionTreeClassifier(max_depth=None).fit(Xs, ys)

# Accuracy on the source training set and on the held-out target test set.
score_src_src = clf_source.score(Xs, ys)
score_src_trgt = clf_source.score(Xt_test, yt_test)
print(
    'Training score Source model: {:.3f}'.format(score_src_src),
    'Testing score Source model: {:.3f}'.format(score_src_trgt),
    sep='\n',
)
# Adapted estimators and their target-test accuracies, index-aligned with `methods`.
clfs = []
scores = []

# Value passed as `algo` to the wrapper for each method name.
# NOTE(review): the original only ever built the wrapper with algo="ser";
# "strut" is assumed to be an accepted value -- confirm against transfer_tree.
_ALGO_FOR_METHOD = {
    'ser': 'ser',
    'ser_nr': 'ser',
    'strut': 'strut',
}

for method in methods:
    if method not in _ALGO_FOR_METHOD:
        raise ValueError('Unknown transfer method: {!r}'.format(method))

    # Every method starts from its own untouched copy of the source tree and
    # its own wrapper. Building the wrapper only for 'ser' made the later
    # methods run on a tree that SER had already modified, and put the same
    # estimator object into `clfs` several times.
    clf_transfer = copy.deepcopy(clf_source)
    transferred_dt = TL.TransferTreeClassifier(
        estimator=clf_transfer,
        Xt=Xt,
        yt=yt,
        algo=_ALGO_FOR_METHOD[method],
    )

    if method == 'ser':
        # Original SER, applied from the root node.
        transferred_dt._ser(Xt, yt, node=0, original_ser=True)
    elif method == 'ser_nr':
        # SER variant with the "no reduction" options enabled for class 0,
        # the minority class of the target training set.
        transferred_dt._ser(
            Xt, yt,
            node=0,
            original_ser=False,
            no_red_on_cl=True,
            cl_no_red=[0],
            ext_cond=True,
        )
    elif method == 'strut':
        # STRUT, applied from the root node.
        transferred_dt._strut(Xt, yt, node=0)
    # Disabled variant, kept for reference (see 'strut_hi' in `methods`):
    # elif method == 'strut_hi':
    #     transferred_dt._strut(Xt, yt, node=0, no_prune_on_cl=False,
    #                           adapt_prop=True, coeffs=[0.2, 1])

    # Evaluate the adapted tree on the held-out target test set.
    score = transferred_dt.estimator.score(Xt_test, yt_test)
    print('Testing score transferred model ({}) : {:.3f}'.format(method, score))
    clfs.append(transferred_dt.estimator)
    scores.append(score)
| 98 | + |
| 99 | +## Plot decision functions |
| 100 | +# |
| 101 | +## Data on which to plot source |
| 102 | +#x_min, x_max = Xs[:, 0].min() - 1, Xs[:, 0].max() + 1 |
| 103 | +#y_min, y_max = Xs[:, 1].min() - 1, Xs[:, 1].max() + 1 |
| 104 | +#xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), |
| 105 | +# np.arange(y_min, y_max, plot_step)) |
| 106 | +## Plot source model |
| 107 | +#Z = clf_source.predict(np.c_[xx.ravel(), yy.ravel()]) |
| 108 | +#Z = Z.reshape(xx.shape) |
| 109 | +#fig, ax = plt.subplots(nrows=1, ncols=len(methods) + 1, figsize=(13, 3)) |
| 110 | +#ax[0].contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8) |
| 111 | +#ax[0].scatter(Xs[0, 0], Xs[0, 1], |
| 112 | +# marker='o', |
| 113 | +# edgecolor='black', |
| 114 | +# color='white', |
| 115 | +# label='source data', |
| 116 | +# ) |
| 117 | +#ax[0].scatter(Xs[:ns_perclass, 0], Xs[:ns_perclass, 1], |
| 118 | +# marker='o', |
| 119 | +# edgecolor='black', |
| 120 | +# color='blue', |
| 121 | +# ) |
| 122 | +#ax[0].scatter(Xs[ns_perclass:, 0], Xs[ns_perclass:, 1], |
| 123 | +# marker='o', |
| 124 | +# edgecolor='black', |
| 125 | +# color='red', |
| 126 | +# ) |
| 127 | +#ax[0].set_title('Model: Source\nAcc on source data: {:.2f}\nAcc on target data: {:.2f}'.format(score_src_src, score_src_trgt), |
| 128 | +# fontsize=11) |
| 129 | +#ax[0].legend() |
| 130 | +# |
| 131 | +## Data on which to plot target |
| 132 | +#x_min, x_max = Xt[:, 0].min() - 1, Xt[:, 0].max() + 1 |
| 133 | +#y_min, y_max = Xt[:, 1].min() - 1, Xt[:, 1].max() + 1 |
| 134 | +#xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), |
| 135 | +# np.arange(y_min, y_max, plot_step)) |
| 136 | +## Plot transfer models |
| 137 | +#for i, (method, label, score) in enumerate(zip(methods, labels, scores)): |
| 138 | +# clf_transfer = clfs[i] |
| 139 | +# Z_transfer = clf_transfer.predict(np.c_[xx.ravel(), yy.ravel()]) |
| 140 | +# Z_transfer = Z_transfer.reshape(xx.shape) |
| 141 | +# ax[i + 1].contourf(xx, yy, Z_transfer, cmap=plt.cm.coolwarm, alpha=0.8) |
| 142 | +# ax[i + 1].scatter(Xt[0, 0], Xt[0, 1], |
| 143 | +# marker='o', |
| 144 | +# edgecolor='black', |
| 145 | +# color='white', |
| 146 | +# label='target data', |
| 147 | +# ) |
| 148 | +# ax[i + 1].scatter(Xt[:nt_0, 0], Xt[:nt_0, 1], |
| 149 | +# marker='o', |
| 150 | +# edgecolor='black', |
| 151 | +# color='blue', |
| 152 | +# ) |
| 153 | +# ax[i + 1].scatter(Xt[nt_0:, 0], Xt[nt_0:, 1], |
| 154 | +# marker='o', |
| 155 | +# edgecolor='black', |
| 156 | +# color='red', |
| 157 | +# ) |
| 158 | +# ax[i + 1].set_title('Model: {}\nAcc on target data: {:.2f}'.format(label, score), |
| 159 | +# fontsize=11) |
| 160 | +# ax[i + 1].legend() |
| 161 | +# |
| 162 | +## fig.savefig('../images/ser_strut.png') |
| 163 | +#plt.show() |
| 164 | +# |
0 commit comments