Skip to content

Commit a9942ec

Browse files
committed
Reduce arguments of _Node.__init__()
1 parent 7d294fb commit a9942ec

File tree

2 files changed

+19
-34
lines changed

2 files changed

+19
-34
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,6 @@
3939
# linearregression,
4040
exponential,
4141
}
42-
# GEN_MODELS = {
43-
# bernoulli.GenModel,
44-
# # categorical.GenModel,
45-
# normal.GenModel,
46-
# # multivariate_normal.GenModel,
47-
# # linearregression.GenModel,
48-
# poisson.GenModel,
49-
# exponential.GenModel,
50-
# }
51-
# DISCRETE_GEN_MODELS = {
52-
# bernoulli.GenModel,
53-
# # categorical.GenModel,
54-
# poisson.GenModel,
55-
# }
56-
# CONTINUOUS_GEN_MODELS = {
57-
# normal.GenModel,
58-
# # multivariate_normal.GenModel,
59-
# # linearregression.GenModel,
60-
# exponential.GenModel,
61-
# }
6242
LEARN_MODELS = {
6343
bernoulli.LearnModel,
6444
# categorical.LearnModel,
@@ -98,18 +78,14 @@ class _Node:
9878
"""
9979
def __init__(self,
10080
depth,
101-
k_candidates,
102-
c_num_children = 2,
103-
h_g = 0.5,
104-
k = None,
105-
sub_model = None
81+
c_num_children,
10682
):
10783
self.depth = depth
10884
self.children = [None for i in range(c_num_children)] # child nodes
109-
self.k_candidates = k_candidates
110-
self.h_g = h_g
111-
self.k = k
112-
self.sub_model = sub_model
85+
self.k_candidates = None
86+
self.h_g = 0.5
87+
self.k = None
88+
self.sub_model = None
11389
self.leaf = False
11490
self.map_leaf = False
11591

@@ -192,7 +168,11 @@ def __init__(
192168
)
193169

194170
# params
195-
self.root = _Node(0,list(range(self.c_k)),self.c_num_children,self.h_g,0,self.SubModel.GenModel(**self.sub_h_params))
171+
self.root = _Node(0,self.c_num_children)
172+
self.root.k_candidates = list(range(self.c_k))
173+
self.root.h_g = self.h_g
174+
self.root.k = 0
175+
self.root.sub_model = self.SubModel.GenModel(**self.sub_h_params)
196176
self.root.leaf = True
197177

198178
self.set_params(root)
@@ -222,7 +202,9 @@ def _gen_params_recursion(self,node:_Node,feature_fix):
222202
node.leaf = False
223203
for i in range(self.c_num_children):
224204
if feature_fix == False or node.children[i] is None:
225-
node.children[i] = _Node(node.depth+1,child_k_candidates,self.c_num_children,self.h_g,None,None)
205+
node.children[i] = _Node(node.depth+1,self.c_num_children)
206+
node.children[i].k_candidates = child_k_candidates
207+
node.children[i].h_g = self.h_g
226208
else:
227209
node.children[i].k_candidates = child_k_candidates
228210
self._gen_params_recursion(node.children[i],feature_fix)
@@ -252,7 +234,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,feature_fix):
252234
node.leaf = False
253235
for i in range(self.c_num_children):
254236
if feature_fix == False or node.children[i] is None:
255-
node.children[i] = _Node(node.depth+1,child_k_candidates,self.c_num_children,self.h_g,None,None)
237+
node.children[i] = _Node(node.depth+1,self.c_num_children)
238+
node.children[i].k_candidates = child_k_candidates
239+
node.children[i].h_g = self.h_g
256240
else:
257241
node.children[i].k_candidates = child_k_candidates
258242
self._gen_params_recursion_tree_fix(node.children[i],feature_fix)
@@ -277,7 +261,8 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
277261
child_k_candidates = copy.copy(node.k_candidates)
278262
child_k_candidates.remove(node.k)
279263
for i in range(self.c_num_children):
280-
node.children[i] = _Node(node.depth+1,child_k_candidates,self.c_num_children,self.h_g,None,None)
264+
node.children[i] = _Node(node.depth+1,self.c_num_children)
265+
node.children[i].k_candidates = child_k_candidates
281266
self._set_params_recursion(node.children[i],original_tree_node.children[i])
282267

283268
def _gen_sample_recursion(self,node:_Node,x):

bayesml/metatree/metatree_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from bayesml import normal
33
import numpy as np
44

5-
gen_model = metatree.GenModel(2,3,h_g=0.7,SubModel=normal)
5+
gen_model = metatree.GenModel(3,3,h_g=0.7)
66
gen_model.gen_params()
77
gen_model.visualize_model()

0 commit comments

Comments
 (0)