39
39
# linearregression,
40
40
exponential ,
41
41
}
42
- # GEN_MODELS = {
43
- # bernoulli.GenModel,
44
- # # categorical.GenModel,
45
- # normal.GenModel,
46
- # # multivariate_normal.GenModel,
47
- # # linearregression.GenModel,
48
- # poisson.GenModel,
49
- # exponential.GenModel,
50
- # }
51
- # DISCRETE_GEN_MODELS = {
52
- # bernoulli.GenModel,
53
- # # categorical.GenModel,
54
- # poisson.GenModel,
55
- # }
56
- # CONTINUOUS_GEN_MODELS = {
57
- # normal.GenModel,
58
- # # multivariate_normal.GenModel,
59
- # # linearregression.GenModel,
60
- # exponential.GenModel,
61
- # }
62
42
LEARN_MODELS = {
63
43
bernoulli .LearnModel ,
64
44
# categorical.LearnModel,
@@ -98,18 +78,14 @@ class _Node:
98
78
"""
99
79
def __init__ (self ,
100
80
depth ,
101
- k_candidates ,
102
- c_num_children = 2 ,
103
- h_g = 0.5 ,
104
- k = None ,
105
- sub_model = None
81
+ c_num_children ,
106
82
):
107
83
self .depth = depth
108
84
self .children = [None for i in range (c_num_children )] # child nodes
109
- self .k_candidates = k_candidates
110
- self .h_g = h_g
111
- self .k = k
112
- self .sub_model = sub_model
85
+ self .k_candidates = None
86
+ self .h_g = 0.5
87
+ self .k = None
88
+ self .sub_model = None
113
89
self .leaf = False
114
90
self .map_leaf = False
115
91
@@ -192,7 +168,11 @@ def __init__(
192
168
)
193
169
194
170
# params
195
- self .root = _Node (0 ,list (range (self .c_k )),self .c_num_children ,self .h_g ,0 ,self .SubModel .GenModel (** self .sub_h_params ))
171
+ self .root = _Node (0 ,self .c_num_children )
172
+ self .root .k_candidates = list (range (self .c_k ))
173
+ self .root .h_g = self .h_g
174
+ self .root .k = 0
175
+ self .root .sub_model = self .SubModel .GenModel (** self .sub_h_params )
196
176
self .root .leaf = True
197
177
198
178
self .set_params (root )
@@ -222,7 +202,9 @@ def _gen_params_recursion(self,node:_Node,feature_fix):
222
202
node .leaf = False
223
203
for i in range (self .c_num_children ):
224
204
if feature_fix == False or node .children [i ] is None :
225
- node .children [i ] = _Node (node .depth + 1 ,child_k_candidates ,self .c_num_children ,self .h_g ,None ,None )
205
+ node .children [i ] = _Node (node .depth + 1 ,self .c_num_children )
206
+ node .children [i ].k_candidates = child_k_candidates
207
+ node .children [i ].h_g = self .h_g
226
208
else :
227
209
node .children [i ].k_candidates = child_k_candidates
228
210
self ._gen_params_recursion (node .children [i ],feature_fix )
@@ -252,7 +234,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,feature_fix):
252
234
node .leaf = False
253
235
for i in range (self .c_num_children ):
254
236
if feature_fix == False or node .children [i ] is None :
255
- node .children [i ] = _Node (node .depth + 1 ,child_k_candidates ,self .c_num_children ,self .h_g ,None ,None )
237
+ node .children [i ] = _Node (node .depth + 1 ,self .c_num_children )
238
+ node .children [i ].k_candidates = child_k_candidates
239
+ node .children [i ].h_g = self .h_g
256
240
else :
257
241
node .children [i ].k_candidates = child_k_candidates
258
242
self ._gen_params_recursion_tree_fix (node .children [i ],feature_fix )
@@ -277,7 +261,8 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
277
261
child_k_candidates = copy .copy (node .k_candidates )
278
262
child_k_candidates .remove (node .k )
279
263
for i in range (self .c_num_children ):
280
- node .children [i ] = _Node (node .depth + 1 ,child_k_candidates ,self .c_num_children ,self .h_g ,None ,None )
264
+ node .children [i ] = _Node (node .depth + 1 ,self .c_num_children )
265
+ node .children [i ].k_candidates = child_k_candidates
281
266
self ._set_params_recursion (node .children [i ],original_tree_node .children [i ])
282
267
283
268
def _gen_sample_recursion (self ,node :_Node ,x ):
0 commit comments