Skip to content

Commit e7347ad

Browse files
committed
Revise set_params and estimate_params
1 parent c06e23f commit e7347ad

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,17 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
300300
a object form GenNode class
301301
"""
302302
if original_tree_node.leaf: # leaf node
303-
node.sub_model = copy.deepcopy(original_tree_node.sub_model)
303+
try:
304+
sub_params = original_tree_node.sub_model.get_params()
305+
node.sub_model.set_params(**sub_params)
306+
except:
307+
try:
308+
sub_params = original_tree_node.sub_model.estimate_params(loss='0-1',dict_out=True)
309+
node.sub_model.set_params(**sub_params)
310+
except:
311+
sub_params = original_tree_node.sub_model.estimate_params(dict_out=True)
312+
node.sub_model.set_params(**sub_params)
313+
304314
if node.depth == self.c_d_max:
305315
node.h_g = 0
306316
node.leaf = True
@@ -315,6 +325,7 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
315325
self.c_num_children,
316326
child_k_candidates,
317327
self.h_g,
328+
sub_model=self.SubModel.GenModel(**self.sub_h_params)
318329
)
319330
self._set_params_recursion(node.children[i],original_tree_node.children[i])
320331

@@ -1416,7 +1427,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
14161427
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
14171428
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, 1.0)
14181429
tree_graph.view()
1419-
return map_root
1430+
return {'root':map_root}
14201431
else:
14211432
raise(CriteriaError("Unsupported loss function! "
14221433
+"This function supports only \"0-1\"."))
@@ -1537,7 +1548,7 @@ def visualize_posterior(self,filename=None,format=None):
15371548
self._visualize_model_recursion_none(tree_graph, 0, list(range(self.c_k)), 0, None, None, 1.0)
15381549
else:
15391550
MAP_index = np.argmax(self.hn_metatree_prob_vec)
1540-
print(f'MAP probability of metatree:{self.hn_metatree_prob_vec[MAP_index]}')
1551+
print(f'Approximate MAP probability of metatree:{self.hn_metatree_prob_vec[MAP_index]}')
15411552
self._visualize_model_recursion(tree_graph, self.hn_metatree_list[MAP_index], 0, None, None, 1.0)
15421553
# Can we show the image on the console without saving the file?
15431554
tree_graph.view()

bayesml/metatree/metatree_test.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
import numpy as np
44
import copy
55

6-
gen_model = metatree.GenModel(4,3,2,h_g=0.75)
6+
gen_model = metatree.GenModel(4,3,2,h_g=0.75,SubModel=normal)
77
gen_model.gen_params()
88
gen_model.visualize_model('tree.pdf')
99
x,y = gen_model.gen_sample(1000)
1010

11-
learn_model = metatree.LearnModel(4,3,2,h0_g=0.75)
12-
learn_model.update_posterior(x,y,n_estimators=1)
11+
learn_model = metatree.LearnModel(4,3,2,h0_g=0.75,SubModel=normal)
12+
learn_model.update_posterior(x,y)
1313
learn_model.visualize_posterior('tree2.pdf')
14+
params = learn_model.estimate_params(filename='tree3.pdf')
1415

15-
gen_model2 = metatree.GenModel(4,3,2,h_g=0.1)
16-
gen_model2.visualize_model('tree3.pdf')
17-
gen_model2.set_h_params(*learn_model.get_hn_params().values())
18-
gen_model2.gen_params()
16+
gen_model2 = metatree.GenModel(4,3,2,h_g=0.1,SubModel=normal)
1917
gen_model2.visualize_model('tree4.pdf')
18+
gen_model2.set_params(params)
19+
# gen_model2.gen_params()
20+
gen_model2.visualize_model('tree5.pdf')

0 commit comments

Comments
 (0)