Skip to content

Commit cde7eba

Browse files
committed
Modify h_params_recursion
1 parent 0d83bda commit cde7eba

File tree

2 files changed

+35
-30
lines changed

2 files changed

+35
-30
lines changed

bayesml/contexttree/_contexttree.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,13 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
190190
a object from _Node class
191191
"""
192192
if original_tree_node is None:
193-
node.h_g = self.h_g
194-
node.h_beta_vec[:] = self.h_beta_vec
195-
if node.depth == self.c_d_max: # 葉ノード
196-
node.leaf = True
197-
if node.depth == self.c_d_max:
198-
node.h_g = 0
193+
if node.depth == self.c_d_max:
194+
node.h_g = 0
199195
else:
200-
node.leaf = False
201-
for i in range(self.c_k):
202-
if node.children[i] is None:
203-
node.children[i] = _Node(node.depth+1,self.c_k)
196+
node.h_g = self.h_g
197+
node.h_beta_vec[:] = self.h_beta_vec
198+
for i in range(self.c_k):
199+
if node.children[i] is not None:
204200
self._set_h_params_recursion(node.children[i],None)
205201
else:
206202
node.h_g = original_tree_node.h_g
@@ -525,17 +521,13 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
525521
a object from _Node class
526522
"""
527523
if original_tree_node is None:
528-
node.h_g = self.h0_g
529-
node.h_beta_vec[:] = self.h0_beta_vec
530-
if node.depth == self.c_d_max: # 葉ノード
531-
node.leaf = True
532-
if node.depth == self.c_d_max:
533-
node.h_g = 0
524+
if node.depth == self.c_d_max:
525+
node.h_g = 0
534526
else:
535-
node.leaf = False
536-
for i in range(self.c_k):
537-
if node.children[i] is None:
538-
node.children[i] = _Node(node.depth+1,self.c_k)
527+
node.h_g = self.h0_g
528+
node.h_beta_vec[:] = self.h0_beta_vec
529+
for i in range(self.c_k):
530+
if node.children[i] is not None:
539531
self._set_h0_params_recursion(node.children[i],None)
540532
else:
541533
node.h_g = original_tree_node.h_g
@@ -562,17 +554,13 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
562554
a object from _Node class
563555
"""
564556
if original_tree_node is None:
565-
node.h_g = self.hn_g
566-
node.h_beta_vec[:] = self.hn_beta_vec
567-
if node.depth == self.c_d_max: # 葉ノード
568-
node.leaf = True
569-
if node.depth == self.c_d_max:
570-
node.h_g = 0
557+
if node.depth == self.c_d_max:
558+
node.h_g = 0
571559
else:
572-
node.leaf = False
573-
for i in range(self.c_k):
574-
if node.children[i] is None:
575-
node.children[i] = _Node(node.depth+1,self.c_k)
560+
node.h_g = self.hn_g
561+
node.h_beta_vec[:] = self.hn_beta_vec
562+
for i in range(self.c_k):
563+
if node.children[i] is not None:
576564
self._set_hn_params_recursion(node.children[i],None)
577565
else:
578566
node.h_g = original_tree_node.h_g
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from bayesml.contexttree import GenModel
2+
from bayesml.contexttree import LearnModel
3+
import numpy as np
4+
5+
gen_model = GenModel(2,3,h_g=0.7,h_beta_vec=1.0)
6+
gen_model.gen_params()
7+
gen_model.visualize_model(filename='tree1.pdf')
8+
9+
x = gen_model.gen_sample(1000)
10+
11+
learn_model = LearnModel(2,4)
12+
learn_model.visualize_posterior(filename='tree2.pdf')
13+
learn_model.reset_hn_params()
14+
learn_model.visualize_posterior(filename='tree3.pdf')
15+
learn_model.update_posterior(x)
16+
learn_model.visualize_posterior(filename='tree4.pdf')
17+
learn_model.estimate_params()

0 commit comments

Comments
 (0)