@@ -19,29 +19,17 @@ class TransferTreeClassifier(BaseAdaptEstimator):
19
19
----------
20
20
estimator : sklearn DecsionTreeClassifier (default=None)
21
21
Source decision tree classifier.
22
-
23
- Xt : numpy array (default=None)
24
- Target input data.
25
-
26
- yt : numpy array (default=None)
27
- Target output data.
28
-
22
+
29
23
algo : str or callable (default="")
30
24
Leaves relabeling if "" or "relab".
31
25
"ser" and "strut" for SER and STRUT algorithms
32
-
33
- (pas la peine de commenter Xt, yt, copy, verbose et random_state)
26
+
34
27
35
28
Attributes
36
29
----------
37
- estimator_ : Same class as estimator
38
- Fitted Estimator.
39
- estimator : sklearn DecsionTreeClassifier
30
+ estimator_ : sklearn DecsionTreeClassifier
40
31
Transferred decision tree classifier using target data.
41
32
42
- source_model:
43
- Source decision tree classifier.
44
-
45
33
parents : numpy array of int.
46
34
47
35
bool_parents_lr : numpy array of {-1,0,1} values.
@@ -108,39 +96,16 @@ def __init__(self,
108
96
copy = self .copy ,
109
97
force_copy = True )
110
98
111
-
112
- # if not hasattr(estimator, "tree_"):
113
- # raise NotFittedError("`estimator` argument has no ``tree_`` attribute, "
114
- # "please call `fit` on `estimator` or use "
115
- # "another estimator.")
99
+
116
100
117
101
self .parents = np .zeros (estimator .tree_ .node_count ,dtype = int )
118
102
self .bool_parents_lr = np .zeros (estimator .tree_ .node_count ,dtype = int )
119
103
self .rules = np .zeros (estimator .tree_ .node_count ,dtype = object )
120
104
self .paths = np .zeros (estimator .tree_ .node_count ,dtype = object )
121
105
self .depths = np .zeros (estimator .tree_ .node_count ,dtype = int )
122
-
123
- self .estimator = estimator
124
- self .source_model = copy .deepcopy (self .estimator )
125
-
126
- self .Xt = Xt
127
- self .yt = yt
128
- self .algo = algo
129
- self .copy = copy
130
- self .verbose = verbose
131
- self .random_state = random_state
132
- self .params = params
133
106
134
107
#Init. meta params
135
108
self ._compute_params ()
136
-
137
- #Target model
138
- if Xt is not None and yt is not None :
139
- self ._relab (Xt ,yt )
140
- self .target_model = self .estimator
141
- self .estimator = copy .deepcopy (self .source_model )
142
- else :
143
- self .target_model = None
144
109
145
110
def fit (self , Xt = None , yt = None , ** fit_params ):
146
111
"""
@@ -164,26 +129,6 @@ def fit(self, Xt=None, yt=None, **fit_params):
164
129
Xt , yt = self ._get_target_data (Xt , yt )
165
130
Xt , yt = check_arrays (Xt , yt )
166
131
set_random_seed (self .random_state )
167
-
168
- #if self.estimator is None:
169
- #Pas d'arbre source
170
-
171
- #if self.estimator.node_count == 0:
172
- #Arbre vide
173
-
174
- #set_random_seed(self.random_state)
175
- #Xt, yt = check_arrays(Xt, yt)
176
-
177
- #self.estimator_ = check_estimator(self.estimator,copy=self.copy,force_copy=True)
178
-
179
- #Tree_ = self.estimator.tree_
180
-
181
- #Target model :
182
- if self .target_model is None :
183
- if Xt is not None and yt is not None :
184
- self ._relab (Xt ,yt )
185
- self .target_model = self .estimator
186
- self .estimator = copy .deepcopy (self .source_model )
187
132
188
133
self ._modify_tree (self .estimator , Xt , yt )
189
134
0 commit comments