Skip to content

Commit 16a9351

Browse files
Merge pull request #50 from yuta-nakahara/develop-contexttree
Update visualization
2 parents 6c3574c + ae41e62 commit 16a9351

File tree

1 file changed

+67
-32
lines changed

1 file changed

+67
-32
lines changed

bayesml/contexttree/_contexttree.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node):
107107
else:
108108
node.h_g = self.h_g
109109
node.h_beta_vec[:] = self.h_beta_vec
110-
if node.depth == self.c_d_max or self.rng.random() > self.h_g: # 葉ノード
110+
if node.depth == self.c_d_max or self.rng.random() > self.h_g: # leaf node
111111
node.theta_vec[:] = self.rng.dirichlet(self.h_beta_vec)
112112
node.leaf = True
113-
else: # 内部ノード
113+
else: # inner node
114114
node.leaf = False
115115
for i in range(self.c_k):
116116
if node.children[i] is None:
@@ -122,10 +122,10 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node):
122122
else:
123123
node.h_g = h_node.h_g
124124
node.h_beta_vec[:] = h_node.h_beta_vec
125-
if node.depth == self.c_d_max or self.rng.random() > h_node.h_g: # 葉ノード
125+
if node.depth == self.c_d_max or self.rng.random() > h_node.h_g: # leaf node
126126
node.theta_vec[:] = self.rng.dirichlet(h_node.h_beta_vec)
127127
node.leaf = True
128-
else: # 内部ノード
128+
else: # inner node
129129
node.leaf = False
130130
for i in range(self.c_k):
131131
if node.children[i] is None:
@@ -140,9 +140,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
140140
else:
141141
node.h_g = self.h_g
142142
node.h_beta_vec[:] = self.h_beta_vec
143-
if node.leaf: # 葉ノード
143+
if node.leaf: # leaf node
144144
node.theta_vec[:] = self.rng.dirichlet(self.h_beta_vec)
145-
else: # 内部ノード
145+
else: # inner node
146146
for i in range(self.c_k):
147147
if node.children[i] is not None:
148148
self._gen_params_recursion_tree_fix(node.children[i],None)
@@ -152,9 +152,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
152152
else:
153153
node.h_g = h_node.h_g
154154
node.h_beta_vec[:] = h_node.h_beta_vec
155-
if node.leaf: # 葉ノード
155+
if node.leaf: # leaf node
156156
node.theta_vec[:] = self.rng.dirichlet(h_node.h_beta_vec)
157-
else: # 内部ノード
157+
else: # inner node
158158
for i in range(self.c_k):
159159
if node.children[i] is not None:
160160
self._gen_params_recursion_tree_fix(node.children[i],h_node.children[i])
@@ -170,7 +170,7 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
170170
a object from _Node class
171171
"""
172172
node.theta_vec[:] = original_tree_node.theta_vec
173-
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
173+
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
174174
node.leaf = True
175175
else:
176176
node.leaf = False
@@ -201,7 +201,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
201201
else:
202202
node.h_g = original_tree_node.h_g
203203
node.h_beta_vec[:] = original_tree_node.h_beta_vec
204-
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
204+
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
205205
node.leaf = True
206206
if node.depth == self.c_d_max:
207207
node.h_g = 0
@@ -223,7 +223,7 @@ def _gen_sample_recursion(self,node,x):
223223
x : numpy ndarray
224224
1 dimensional array whose elements are 0 or 1.
225225
"""
226-
if node.leaf: # 葉ノード
226+
if node.leaf: # leaf node
227227
return self.rng.choice(self.c_k,p=node.theta_vec)
228228
else:
229229
return self._gen_sample_recursion(node.children[x[-node.depth-1]],x)
@@ -433,15 +433,14 @@ def visualize_model(self,filename=None,format=None,sample_length=10):
433433
--------
434434
graphviz.Digraph
435435
"""
436-
#例外処理
437436
_check.pos_int(sample_length,'sample_length',DataFormatError)
438437

439438
try:
440439
import graphviz
441440
tree_graph = graphviz.Digraph(filename=filename,format=format)
442441
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
443442
self._visualize_model_recursion(tree_graph, self.root, 0, None, None, 1.0)
444-
# コンソール上で表示できるようにした方がいいかもしれない.
443+
# Can we show the image on the console without saving the file?
445444
tree_graph.view()
446445
except ImportError as e:
447446
print(e)
@@ -532,7 +531,7 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
532531
else:
533532
node.h_g = original_tree_node.h_g
534533
node.h_beta_vec[:] = original_tree_node.h_beta_vec
535-
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
534+
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
536535
node.leaf = True
537536
if node.depth == self.c_d_max:
538537
node.h_g = 0
@@ -565,7 +564,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
565564
else:
566565
node.h_g = original_tree_node.h_g
567566
node.h_beta_vec[:] = original_tree_node.h_beta_vec
568-
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
567+
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
569568
node.leaf = True
570569
if node.depth == self.c_d_max:
571570
node.h_g = 0
@@ -610,7 +609,8 @@ def set_h0_params(self,
610609
raise(ParameterFormatError(
611610
"h0_root must be an instance of contexttree._Node"
612611
))
613-
self.h0_root = _Node(0,self.c_k)
612+
if self.h0_root is None:
613+
self.h0_root = _Node(0,self.c_k)
614614
self._set_h0_params_recursion(self.h0_root,h0_root)
615615

616616
self.reset_hn_params()
@@ -663,7 +663,8 @@ def set_hn_params(self,
663663
raise(ParameterFormatError(
664664
"hn_root must be an instance of contexttree._Node"
665665
))
666-
self.hn_root = _Node(0,self.c_k)
666+
if self.hn_root is None:
667+
self.hn_root = _Node(0,self.c_k)
667668
self._set_hn_params_recursion(self.hn_root,hn_root)
668669

669670
self.calc_pred_dist(np.zeros(self.c_d_max,dtype=int))
@@ -688,7 +689,7 @@ def _update_posterior_leaf(self,node:_Node,x,i):
688689
return tmp
689690

690691
def _update_posterior_recursion(self,node:_Node,x,i):
691-
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
692+
if node.depth < self.c_d_max and i-1-node.depth >= 0: # inner node
692693
if node.children[x[i-node.depth-1]] is None:
693694
node.children[x[i-node.depth-1]] = _Node(
694695
node.depth+1,
@@ -703,7 +704,7 @@ def _update_posterior_recursion(self,node:_Node,x,i):
703704
tmp2 = (1 - node.h_g) * self._update_posterior_leaf(node,x,i) + node.h_g * tmp1
704705
node.h_g = node.h_g * tmp1 / tmp2
705706
return tmp2
706-
else: # 葉ノード
707+
else: # leaf node
707708
return self._update_posterior_leaf(node,x,i)
708709

709710
def update_posterior(self,x):
@@ -728,11 +729,11 @@ def update_posterior(self,x):
728729
self._update_posterior_recursion(self.hn_root,x,i)
729730

730731
def _map_recursion_add_nodes(self,node:_Node):
731-
if node.depth == self.c_d_max: # 葉ノード
732+
if node.depth == self.c_d_max: # leaf node
732733
node.h_g = 0.0
733734
node.leaf = True
734735
node.map_leaf = True
735-
else: # 内部ノード
736+
else: # inner node
736737
for i in range(self.c_k):
737738
node.children[i] = _Node(node.depth+1,self.c_k)
738739
node.children[i].h_g = self.hn_g
@@ -817,6 +818,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
817818
tree_graph = graphviz.Digraph(filename=filename,format=format)
818819
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
819820
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, 1.0)
821+
# Can we show the image on the console without saving the file?
820822
tree_graph.view()
821823
return map_root
822824
else:
@@ -851,6 +853,38 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
851853

852854
return node_id
853855

856+
def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibling_num,p_v):
857+
tmp_id = node_id
858+
tmp_p_v = p_v
859+
860+
# add node information
861+
if depth == self.c_d_max:
862+
label_string = 'hn_g=0\\l'
863+
else:
864+
label_string = f'hn_g={self.hn_g:.2f}\\l'
865+
label_string += f'p_v={tmp_p_v:.2f}\\ltheta_vec\\l='
866+
label_string += '['
867+
for i in range(self.c_k):
868+
theta_vec_hat = self.hn_beta_vec / self.hn_beta_vec.sum()
869+
label_string += f'{theta_vec_hat[i]:.2f}'
870+
if i < self.c_k-1:
871+
label_string += ','
872+
label_string += ']'
873+
874+
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
875+
if tmp_p_v > 0.65:
876+
tree_graph.node(name=f'{tmp_id}',fontcolor='white')
877+
878+
# add edge information
879+
if parent_id is not None:
880+
tree_graph.edge(f'{parent_id}', f'{tmp_id}', label=f'{sibling_num}')
881+
882+
if depth < self.c_d_max:
883+
for i in range(self.c_k):
884+
node_id = self._visualize_model_recursion_none(tree_graph,depth+1,node_id+1,tmp_id,i,tmp_p_v*self.hn_g)
885+
886+
return node_id
887+
854888
def visualize_posterior(self,filename=None,format=None):
855889
"""Visualize the posterior distribution for the parameter.
856890
@@ -883,8 +917,11 @@ def visualize_posterior(self,filename=None,format=None):
883917
import graphviz
884918
tree_graph = graphviz.Digraph(filename=filename,format=format)
885919
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
886-
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0)
887-
# コンソール上で表示できるようにした方がいいかもしれない.
920+
if self.hn_root is None:
921+
self._visualize_model_recursion_none(tree_graph, 0, 0, None, None, 1.0)
922+
else:
923+
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0)
924+
# Can we show the image on the console without saving the file?
888925
tree_graph.view()
889926
except ImportError as e:
890927
print(e)
@@ -905,7 +942,7 @@ def _calc_pred_dist_leaf(self,node:_Node):
905942
return node.h_beta_vec / node.h_beta_vec.sum()
906943

907944
def _calc_pred_dist_recursion(self,node:_Node,x,i):
908-
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
945+
if node.depth < self.c_d_max and i-1-node.depth >= 0: # inner node
909946
if node.children[x[i-node.depth-1]] is None:
910947
node.children[x[i-node.depth-1]] = _Node(
911948
node.depth+1,
@@ -919,7 +956,7 @@ def _calc_pred_dist_recursion(self,node:_Node,x,i):
919956
tmp1 = self._calc_pred_dist_recursion(node.children[x[i-node.depth-1]],x,i)
920957
tmp2 = (1 - node.h_g) * self._calc_pred_dist_leaf(node) + node.h_g * tmp1
921958
return tmp2
922-
else: # 葉ノード
959+
else: # leaf node
923960
return self._calc_pred_dist_leaf(node)
924961

925962
def calc_pred_dist(self,x):
@@ -936,11 +973,9 @@ def calc_pred_dist(self,x):
936973
i = x.shape[0] - 1
937974

938975
if self.hn_root is None:
939-
self.hn_root = _Node(0,self.c_k)
940-
self.hn_root.h_g = self.hn_g
941-
self.hn_root.h_beta_vec[:] = self.hn_beta_vec
942-
943-
self.p_theta_vec[:] = self._calc_pred_dist_recursion(self.hn_root,x,i)
976+
self.p_theta_vec[:] = self.hn_beta_vec / self.hn_beta_vec.sum()
977+
else:
978+
self.p_theta_vec[:] = self._calc_pred_dist_recursion(self.hn_root,x,i)
944979

945980
def make_prediction(self,loss="KL"):
946981
"""Predict a new data point under the given criterion.
@@ -973,7 +1008,7 @@ def _pred_and_update_leaf(self,node:_Node,x,i):
9731008
return tmp
9741009

9751010
def _pred_and_update_recursion(self,node:_Node,x,i):
976-
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
1011+
if node.depth < self.c_d_max and i-1-node.depth >= 0: # inner node
9771012
if node.children[x[i-node.depth-1]] is None:
9781013
node.children[x[i-node.depth-1]] = _Node(
9791014
node.depth+1,
@@ -988,7 +1023,7 @@ def _pred_and_update_recursion(self,node:_Node,x,i):
9881023
tmp2 = (1 - node.h_g) * self._pred_and_update_leaf(node,x,i) + node.h_g * tmp1
9891024
node.h_g = node.h_g * tmp1[x[i]] / tmp2[x[i]]
9901025
return tmp2
991-
else: # 葉ノード
1026+
else: # leaf node
9921027
return self._pred_and_update_leaf(node,x,i)
9931028

9941029
def pred_and_update(self,x,loss="KL"):

0 commit comments

Comments
 (0)