Skip to content

Commit 1249653

Browse files
committed
Modify gen_params
1 parent de99d95 commit 1249653

File tree

2 files changed

+172
-68
lines changed

2 files changed

+172
-68
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 159 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,21 @@ class _Node:
7979
def __init__(self,
8080
depth,
8181
c_num_children,
82+
k_candidates=None,
83+
h_g=0.5,
84+
k=None,
85+
sub_model=None,
86+
leaf=False,
87+
map_leaf=False
8288
):
8389
self.depth = depth
8490
self.children = [None for i in range(c_num_children)] # child nodes
85-
self.k_candidates = None
86-
self.h_g = 0.5
87-
self.k = None
88-
self.sub_model = None
89-
self.leaf = False
90-
self.map_leaf = False
91+
self.k_candidates = k_candidates
92+
self.h_g = h_g
93+
self.k = k
94+
self.sub_model = sub_model
95+
self.leaf = leaf
96+
self.map_leaf = map_leaf
9197

9298
class GenModel(base.Generative):
9399
""" The stochastice data generative model and the prior distribution
@@ -168,16 +174,18 @@ def __init__(
168174
)
169175

170176
# params
171-
self.root = _Node(0,self.c_num_children)
172-
self.root.k_candidates = list(range(self.c_k))
173-
self.root.h_g = self.h_g
174-
self.root.k = 0
175-
self.root.sub_model = self.SubModel.GenModel(**self.sub_h_params)
176-
self.root.leaf = True
177+
self.root = _Node(
178+
0,
179+
self.c_num_children,
180+
list(range(self.c_k)),
181+
self.h_g,
182+
sub_model=self.SubModel.GenModel(**self.sub_h_params),
183+
leaf=True
184+
)
177185

178186
self.set_params(root)
179187

180-
def _gen_params_recursion(self,node:_Node,feature_fix):
188+
def _gen_params_recursion(self,node:_Node,h_node:_Node,feature_fix):
181189
""" generate parameters recursively
182190
183191
Parameters
@@ -187,29 +195,65 @@ def _gen_params_recursion(self,node:_Node,feature_fix):
187195
feature_fix : bool
188196
a bool parameter show the feature is fixed or not
189197
"""
190-
if node.depth == self.c_d_max or node.depth == self.c_k or self.rng.random() > node.h_g: # leaf node
191-
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
192-
node.sub_model.gen_params()
198+
if h_node is None:
193199
if node.depth == self.c_d_max:
194200
node.h_g = 0
195-
node.leaf = True
196-
else: # inner node
197-
if feature_fix == False or node.k is None:
198-
node.k = self.rng.choice(node.k_candidates,
199-
p=self.h_k_prob_vec[node.k_candidates]/self.h_k_prob_vec[node.k_candidates].sum())
200-
child_k_candidates = copy.copy(node.k_candidates)
201-
child_k_candidates.remove(node.k)
202-
node.leaf = False
203-
for i in range(self.c_num_children):
204-
if feature_fix == False or node.children[i] is None:
205-
node.children[i] = _Node(node.depth+1,self.c_num_children)
201+
else:
202+
node.h_g = self.h_g
203+
# node.sub_model.set_h_params(**self.sub_h_params)
204+
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
205+
if node.depth == self.c_d_max or node.depth == self.c_k or self.rng.random() > self.h_g: # leaf node
206+
node.sub_model.gen_params()
207+
node.leaf = True
208+
else: # inner node
209+
if feature_fix == False or node.k is None:
210+
node.k = self.rng.choice(node.k_candidates,
211+
p=self.h_k_prob_vec[node.k_candidates]/self.h_k_prob_vec[node.k_candidates].sum())
212+
child_k_candidates = copy.copy(node.k_candidates)
213+
child_k_candidates.remove(node.k)
214+
node.leaf = False
215+
for i in range(self.c_num_children):
216+
if node.children[i] is None:
217+
node.children[i] = _Node(
218+
node.depth+1,
219+
self.c_num_children,
220+
h_g=self.h_g,
221+
sub_model=self.SubModel.GenModel(**self.sub_h_params),
222+
)
206223
node.children[i].k_candidates = child_k_candidates
207-
node.children[i].h_g = self.h_g
208-
else:
224+
self._gen_params_recursion(node.children[i],None,feature_fix)
225+
else:
226+
if node.depth == self.c_d_max:
227+
node.h_g = 0
228+
else:
229+
node.h_g = h_node.h_g
230+
try:
231+
sub_h_params = h_node.sub_model.get_h_params()
232+
except:
233+
sub_h_params = h_node.sub_model.get_hn_params()
234+
node.sub_model.set_h_params(*sub_h_params.values())
235+
if node.depth == self.c_d_max or node.depth == self.c_k or self.rng.random() > h_node.h_g: # leaf node
236+
node.sub_model.gen_params()
237+
node.leaf = True
238+
else: # inner node
239+
if feature_fix == False or node.k is None:
240+
node.k = self.rng.choice(node.k_candidates,
241+
p=self.h_k_prob_vec[node.k_candidates]/self.h_k_prob_vec[node.k_candidates].sum())
242+
child_k_candidates = copy.copy(node.k_candidates)
243+
child_k_candidates.remove(node.k)
244+
node.leaf = False
245+
for i in range(self.c_num_children):
246+
if node.children[i] is None:
247+
node.children[i] = _Node(
248+
node.depth+1,
249+
self.c_num_children,
250+
h_g=self.h_g,
251+
sub_model=self.SubModel.GenModel(**self.sub_h_params),
252+
)
209253
node.children[i].k_candidates = child_k_candidates
210-
self._gen_params_recursion(node.children[i],feature_fix)
254+
self._gen_params_recursion(node.children[i],h_node.children[i],feature_fix)
211255

212-
def _gen_params_recursion_tree_fix(self,node:_Node,feature_fix):
256+
def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node,feature_fix):
213257
""" generate parameters recursively for fixed tree
214258
215259
Parameters
@@ -219,27 +263,51 @@ def _gen_params_recursion_tree_fix(self,node:_Node,feature_fix):
219263
feature_fix : bool
220264
a bool parameter show the feature is fixed or not
221265
"""
222-
if node.leaf: # leaf node
266+
if h_node is None:
267+
if node.depth == self.c_d_max:
268+
node.h_g = 0
269+
else:
270+
node.h_g = self.h_g
271+
# node.sub_model.set_h_params(**self.sub_h_params)
223272
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
224-
node.sub_model.gen_params()
273+
if node.leaf: # leaf node
274+
node.sub_model.gen_params()
275+
node.leaf = True
276+
else: # inner node
277+
if feature_fix == False or node.k is None:
278+
node.k = self.rng.choice(node.k_candidates,
279+
p=self.h_k_prob_vec[node.k_candidates]/self.h_k_prob_vec[node.k_candidates].sum())
280+
child_k_candidates = copy.copy(node.k_candidates)
281+
child_k_candidates.remove(node.k)
282+
node.leaf = False
283+
for i in range(self.c_num_children):
284+
if node.children[i] is not None:
285+
node.children[i].k_candidates = child_k_candidates
286+
self._gen_params_recursion_tree_fix(node.children[i],None,feature_fix)
287+
else:
225288
if node.depth == self.c_d_max:
226289
node.h_g = 0
227-
node.leaf = True
228-
else: # inner node
229-
if feature_fix == False or node.k is None:
230-
node.k = self.rng.choice(node.k_candidates,
231-
p=self.h_k_prob_vec[node.k_candidates]/self.h_k_prob_vec[node.k_candidates].sum())
232-
child_k_candidates = copy.copy(node.k_candidates)
233-
child_k_candidates.remove(node.k)
234-
node.leaf = False
235-
for i in range(self.c_num_children):
236-
if feature_fix == False or node.children[i] is None:
237-
node.children[i] = _Node(node.depth+1,self.c_num_children)
238-
node.children[i].k_candidates = child_k_candidates
239-
node.children[i].h_g = self.h_g
240-
else:
241-
node.children[i].k_candidates = child_k_candidates
242-
self._gen_params_recursion_tree_fix(node.children[i],feature_fix)
290+
else:
291+
node.h_g = h_node.h_g
292+
try:
293+
sub_h_params = h_node.sub_model.get_h_params()
294+
except:
295+
sub_h_params = h_node.sub_model.get_hn_params()
296+
node.sub_model.set_h_params(*sub_h_params.values())
297+
if node.leaf: # leaf node
298+
node.sub_model.gen_params()
299+
node.leaf = True
300+
else: # inner node
301+
if feature_fix == False or node.k is None:
302+
node.k = self.rng.choice(node.k_candidates,
303+
p=self.h_k_prob_vec[node.k_candidates]/self.h_k_prob_vec[node.k_candidates].sum())
304+
child_k_candidates = copy.copy(node.k_candidates)
305+
child_k_candidates.remove(node.k)
306+
node.leaf = False
307+
for i in range(self.c_num_children):
308+
if node.children[i] is not None:
309+
node.children[i].k_candidates = child_k_candidates
310+
self._gen_params_recursion_tree_fix(node.children[i],h_node,feature_fix)
243311

244312
def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
245313
""" copy parameters from a fixed tree
@@ -260,9 +328,14 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
260328
node.k = original_tree_node.k
261329
child_k_candidates = copy.copy(node.k_candidates)
262330
child_k_candidates.remove(node.k)
331+
node.leaf = False
263332
for i in range(self.c_num_children):
264-
node.children[i] = _Node(node.depth+1,self.c_num_children)
265-
node.children[i].k_candidates = child_k_candidates
333+
node.children[i] = _Node(
334+
node.depth+1,
335+
self.c_num_children,
336+
child_k_candidates,
337+
self.h_g,
338+
)
266339
self._set_params_recursion(node.children[i],original_tree_node.children[i])
267340

268341
def _gen_sample_recursion(self,node:_Node,x):
@@ -291,7 +364,11 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
291364
tmp_p_v = p_v
292365

293366
# add node information
294-
label_string = f'k={node.k}\\lh_g={node.h_g:.2f}\\lp_v={tmp_p_v:.2f}\\lsub_params={{'
367+
if node.leaf:
368+
label_string = 'k=None\\l'
369+
else:
370+
label_string = f'k={node.k}\\l'
371+
label_string += f'h_g={node.h_g:.2f}\\lp_v={tmp_p_v:.2f}\\lsub_params={{'
295372
if node.leaf:
296373
sub_params = node.sub_model.get_params()
297374
for key,value in sub_params.items():
@@ -332,28 +409,32 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
332409
node.h_g = 0
333410
else:
334411
node.h_g = self.h_g
335-
node.sub_model.set_h_params(**self.sub_h_params)
412+
# node.sub_model.set_h_params(**self.sub_h_params)
413+
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
336414
for i in range(self.c_k):
337415
if node.children[i] is not None:
338416
self._set_h_params_recursion(node.children[i],None)
339417
else:
340-
node.h_g = original_tree_node.h_g
418+
if node.depth == self.c_d_max:
419+
node.h_g = 0
420+
else:
421+
node.h_g = original_tree_node.h_g
341422
try:
342423
sub_h_params = node.sub_model.get_h_params()
343424
except:
344425
sub_h_params = node.sub_model.get_hn_params()
345-
node.sub_model.set_h_params(
346-
*sub_h_params.values()
347-
)
426+
node.sub_model.set_h_params(*sub_h_params.values())
348427
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
349428
node.leaf = True
350-
if node.depth == self.c_d_max:
351-
node.h_g = 0
352429
else:
353430
node.leaf = False
354431
for i in range(self.c_k):
355432
if node.children[i] is None:
356-
node.children[i] = _Node(node.depth+1,self.c_k)
433+
node.children[i] = _Node(
434+
node.depth+1,
435+
self.c_k,
436+
sub_model=self.SubModel.GenModel(**self.sub_h_params),
437+
)
357438
self._set_h_params_recursion(node.children[i],original_tree_node.children[i])
358439

359440
def set_h_params(self,
@@ -398,7 +479,6 @@ def set_h_params(self,
398479
for h_root in self.h_metatree_list:
399480
self._set_h_params_recursion(h_root,None)
400481

401-
402482
if sub_h_params is not None:
403483
self.SubModel.GenModel(**sub_h_params)
404484
self.sub_h_params = copy.deepcopy(sub_h_params)
@@ -474,7 +554,7 @@ def get_h_params(self):
474554
"h_metatree_list":self.h_metatree_list,
475555
"h_metatree_prob_vec":self.h_metatree_prob_vec}
476556

477-
def gen_params(self,feature_fix=False,tree_fix=False,from_list=False):
557+
def gen_params(self,feature_fix=False,tree_fix=False):
478558
"""Generate the parameter from the prior distribution.
479559
480560
The generated vaule is set at ``self.root``.
@@ -486,13 +566,17 @@ def gen_params(self,feature_fix=False,tree_fix=False,from_list=False):
486566
tree_fix : bool
487567
If ``True``, tree shape will be fixed, by default ``False``.
488568
"""
489-
if from_list == True and len(self.h_metatree_list) > 0:
569+
if self.h_metatree_list:
490570
tmp_root = self.rng.choice(self.h_metatree_list,p=self.h_metatree_prob_vec)
491-
self.set_params(tmp_root)
492-
elif tree_fix:
493-
self._gen_params_recursion_tree_fix(self.root,feature_fix)
571+
if tree_fix:
572+
self._gen_params_recursion_tree_fix(self.root,tmp_root,feature_fix)
573+
else:
574+
self._gen_params_recursion(self.root,tmp_root,feature_fix)
494575
else:
495-
self._gen_params_recursion(self.root,feature_fix)
576+
if tree_fix:
577+
self._gen_params_recursion_tree_fix(self.root,None,feature_fix)
578+
else:
579+
self._gen_params_recursion(self.root,None,feature_fix)
496580

497581
def set_params(self,root=None):
498582
"""Set the parameter of the sthocastic data generative model.
@@ -507,6 +591,14 @@ def set_params(self,root=None):
507591
raise(ParameterFormatError(
508592
"root must be an instance of metatree._Node"
509593
))
594+
self.root = _Node(
595+
0,
596+
self.c_num_children,
597+
list(range(self.c_k)),
598+
self.h_g,
599+
sub_model=self.SubModel.GenModel(**self.sub_h_params),
600+
leaf=True
601+
)
510602
self._set_params_recursion(self.root,root)
511603

512604
def get_params(self):

bayesml/metatree/metatree_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
from bayesml import metatree
22
from bayesml import normal
33
import numpy as np
4+
import copy
45

56
gen_model = metatree.GenModel(3,3,h_g=0.7)
6-
print(gen_model.get_h_params())
7+
gen_model.gen_params()
8+
gen_model.visualize_model(filename='tree.pdf')
9+
params1 = copy.deepcopy(gen_model.get_params())
10+
gen_model.gen_params(feature_fix=True)
11+
gen_model.visualize_model(filename='tree2.pdf')
12+
params2 = copy.deepcopy(gen_model.get_params())
13+
14+
# gen_model2 = metatree.GenModel(3,3)
15+
# gen_model2.visualize_model(filename='tree3.pdf')
16+
# gen_model2.set_h_params(h_metatree_list=[params1['root'],params2['root']])
17+
# gen_model2.gen_params(feature_fix=True)
18+
# gen_model2.visualize_model(filename='tree4.pdf')

0 commit comments

Comments
 (0)