@@ -107,10 +107,10 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node):
107
107
else :
108
108
node .h_g = self .h_g
109
109
node .h_beta_vec [:] = self .h_beta_vec
110
- if node .depth == self .c_d_max or self .rng .random () > self .h_g : # 葉ノード
110
+ if node .depth == self .c_d_max or self .rng .random () > self .h_g : # leaf node
111
111
node .theta_vec [:] = self .rng .dirichlet (self .h_beta_vec )
112
112
node .leaf = True
113
- else : # 内部ノード
113
+ else : # inner node
114
114
node .leaf = False
115
115
for i in range (self .c_k ):
116
116
if node .children [i ] is None :
@@ -122,10 +122,10 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node):
122
122
else :
123
123
node .h_g = h_node .h_g
124
124
node .h_beta_vec [:] = h_node .h_beta_vec
125
- if node .depth == self .c_d_max or self .rng .random () > h_node .h_g : # 葉ノード
125
+ if node .depth == self .c_d_max or self .rng .random () > h_node .h_g : # leaf node
126
126
node .theta_vec [:] = self .rng .dirichlet (h_node .h_beta_vec )
127
127
node .leaf = True
128
- else : # 内部ノード
128
+ else : # inner node
129
129
node .leaf = False
130
130
for i in range (self .c_k ):
131
131
if node .children [i ] is None :
@@ -140,9 +140,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
140
140
else :
141
141
node .h_g = self .h_g
142
142
node .h_beta_vec [:] = self .h_beta_vec
143
- if node .leaf : # 葉ノード
143
+ if node .leaf : # leaf node
144
144
node .theta_vec [:] = self .rng .dirichlet (self .h_beta_vec )
145
- else : # 内部ノード
145
+ else : # inner node
146
146
for i in range (self .c_k ):
147
147
if node .children [i ] is not None :
148
148
self ._gen_params_recursion_tree_fix (node .children [i ],None )
@@ -152,9 +152,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
152
152
else :
153
153
node .h_g = h_node .h_g
154
154
node .h_beta_vec [:] = h_node .h_beta_vec
155
- if node .leaf : # 葉ノード
155
+ if node .leaf : # leaf node
156
156
node .theta_vec [:] = self .rng .dirichlet (h_node .h_beta_vec )
157
- else : # 内部ノード
157
+ else : # inner node
158
158
for i in range (self .c_k ):
159
159
if node .children [i ] is not None :
160
160
self ._gen_params_recursion_tree_fix (node .children [i ],h_node .children [i ])
@@ -170,7 +170,7 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
170
170
a object from _Node class
171
171
"""
172
172
node .theta_vec [:] = original_tree_node .theta_vec
173
- if original_tree_node .leaf or node .depth == self .c_d_max : # 葉ノード
173
+ if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
174
174
node .leaf = True
175
175
else :
176
176
node .leaf = False
@@ -201,7 +201,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
201
201
else :
202
202
node .h_g = original_tree_node .h_g
203
203
node .h_beta_vec [:] = original_tree_node .h_beta_vec
204
- if original_tree_node .leaf or node .depth == self .c_d_max : # 葉ノード
204
+ if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
205
205
node .leaf = True
206
206
if node .depth == self .c_d_max :
207
207
node .h_g = 0
@@ -223,7 +223,7 @@ def _gen_sample_recursion(self,node,x):
223
223
x : numpy ndarray
224
224
1 dimensional array whose elements are 0 or 1.
225
225
"""
226
- if node .leaf : # 葉ノード
226
+ if node .leaf : # leaf node
227
227
return self .rng .choice (self .c_k ,p = node .theta_vec )
228
228
else :
229
229
return self ._gen_sample_recursion (node .children [x [- node .depth - 1 ]],x )
@@ -433,15 +433,14 @@ def visualize_model(self,filename=None,format=None,sample_length=10):
433
433
--------
434
434
graphviz.Digraph
435
435
"""
436
- #例外処理
437
436
_check .pos_int (sample_length ,'sample_length' ,DataFormatError )
438
437
439
438
try :
440
439
import graphviz
441
440
tree_graph = graphviz .Digraph (filename = filename ,format = format )
442
441
tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
443
442
self ._visualize_model_recursion (tree_graph , self .root , 0 , None , None , 1.0 )
444
- # コンソール上で表示できるようにした方がいいかもしれない.
443
+ # Can we show the image on the console without saving the file?
445
444
tree_graph .view ()
446
445
except ImportError as e :
447
446
print (e )
@@ -532,7 +531,7 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
532
531
else :
533
532
node .h_g = original_tree_node .h_g
534
533
node .h_beta_vec [:] = original_tree_node .h_beta_vec
535
- if original_tree_node .leaf or node .depth == self .c_d_max : # 葉ノード
534
+ if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
536
535
node .leaf = True
537
536
if node .depth == self .c_d_max :
538
537
node .h_g = 0
@@ -565,7 +564,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
565
564
else :
566
565
node .h_g = original_tree_node .h_g
567
566
node .h_beta_vec [:] = original_tree_node .h_beta_vec
568
- if original_tree_node .leaf or node .depth == self .c_d_max : # 葉ノード
567
+ if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
569
568
node .leaf = True
570
569
if node .depth == self .c_d_max :
571
570
node .h_g = 0
@@ -610,7 +609,8 @@ def set_h0_params(self,
610
609
raise (ParameterFormatError (
611
610
"h0_root must be an instance of contexttree._Node"
612
611
))
613
- self .h0_root = _Node (0 ,self .c_k )
612
+ if self .h0_root is None :
613
+ self .h0_root = _Node (0 ,self .c_k )
614
614
self ._set_h0_params_recursion (self .h0_root ,h0_root )
615
615
616
616
self .reset_hn_params ()
@@ -663,7 +663,8 @@ def set_hn_params(self,
663
663
raise (ParameterFormatError (
664
664
"hn_root must be an instance of contexttree._Node"
665
665
))
666
- self .hn_root = _Node (0 ,self .c_k )
666
+ if self .hn_root is None :
667
+ self .hn_root = _Node (0 ,self .c_k )
667
668
self ._set_hn_params_recursion (self .hn_root ,hn_root )
668
669
669
670
self .calc_pred_dist (np .zeros (self .c_d_max ,dtype = int ))
@@ -688,7 +689,7 @@ def _update_posterior_leaf(self,node:_Node,x,i):
688
689
return tmp
689
690
690
691
def _update_posterior_recursion (self ,node :_Node ,x ,i ):
691
- if node .depth < self .c_d_max and i - 1 - node .depth >= 0 : # 内部ノード
692
+ if node .depth < self .c_d_max and i - 1 - node .depth >= 0 : # inner node
692
693
if node .children [x [i - node .depth - 1 ]] is None :
693
694
node .children [x [i - node .depth - 1 ]] = _Node (
694
695
node .depth + 1 ,
@@ -703,7 +704,7 @@ def _update_posterior_recursion(self,node:_Node,x,i):
703
704
tmp2 = (1 - node .h_g ) * self ._update_posterior_leaf (node ,x ,i ) + node .h_g * tmp1
704
705
node .h_g = node .h_g * tmp1 / tmp2
705
706
return tmp2
706
- else : # 葉ノード
707
+ else : # leaf node
707
708
return self ._update_posterior_leaf (node ,x ,i )
708
709
709
710
def update_posterior (self ,x ):
@@ -728,11 +729,11 @@ def update_posterior(self,x):
728
729
self ._update_posterior_recursion (self .hn_root ,x ,i )
729
730
730
731
def _map_recursion_add_nodes (self ,node :_Node ):
731
- if node .depth == self .c_d_max : # 葉ノード
732
+ if node .depth == self .c_d_max : # leaf node
732
733
node .h_g = 0.0
733
734
node .leaf = True
734
735
node .map_leaf = True
735
- else : # 内部ノード
736
+ else : # inner node
736
737
for i in range (self .c_k ):
737
738
node .children [i ] = _Node (node .depth + 1 ,self .c_k )
738
739
node .children [i ].h_g = self .hn_g
@@ -817,6 +818,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
817
818
tree_graph = graphviz .Digraph (filename = filename ,format = format )
818
819
tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
819
820
self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , 1.0 )
821
+ # Can we show the image on the console without saving the file?
820
822
tree_graph .view ()
821
823
return map_root
822
824
else :
@@ -851,6 +853,38 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
851
853
852
854
return node_id
853
855
856
+ def _visualize_model_recursion_none (self ,tree_graph ,depth ,node_id ,parent_id ,sibling_num ,p_v ):
857
+ tmp_id = node_id
858
+ tmp_p_v = p_v
859
+
860
+ # add node information
861
+ if depth == self .c_d_max :
862
+ label_string = 'hn_g=0\\ l'
863
+ else :
864
+ label_string = f'hn_g={ self .hn_g :.2f} \\ l'
865
+ label_string += f'p_v={ tmp_p_v :.2f} \\ ltheta_vec\\ l='
866
+ label_string += '['
867
+ for i in range (self .c_k ):
868
+ theta_vec_hat = self .hn_beta_vec / self .hn_beta_vec .sum ()
869
+ label_string += f'{ theta_vec_hat [i ]:.2f} '
870
+ if i < self .c_k - 1 :
871
+ label_string += ','
872
+ label_string += ']'
873
+
874
+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
875
+ if tmp_p_v > 0.65 :
876
+ tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
877
+
878
+ # add edge information
879
+ if parent_id is not None :
880
+ tree_graph .edge (f'{ parent_id } ' , f'{ tmp_id } ' , label = f'{ sibling_num } ' )
881
+
882
+ if depth < self .c_d_max :
883
+ for i in range (self .c_k ):
884
+ node_id = self ._visualize_model_recursion_none (tree_graph ,depth + 1 ,node_id + 1 ,tmp_id ,i ,tmp_p_v * self .hn_g )
885
+
886
+ return node_id
887
+
854
888
def visualize_posterior (self ,filename = None ,format = None ):
855
889
"""Visualize the posterior distribution for the parameter.
856
890
@@ -883,8 +917,11 @@ def visualize_posterior(self,filename=None,format=None):
883
917
import graphviz
884
918
tree_graph = graphviz .Digraph (filename = filename ,format = format )
885
919
tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
886
- self ._visualize_model_recursion (tree_graph , self .hn_root , 0 , None , None , 1.0 )
887
- # コンソール上で表示できるようにした方がいいかもしれない.
920
+ if self .hn_root is None :
921
+ self ._visualize_model_recursion_none (tree_graph , 0 , 0 , None , None , 1.0 )
922
+ else :
923
+ self ._visualize_model_recursion (tree_graph , self .hn_root , 0 , None , None , 1.0 )
924
+ # Can we show the image on the console without saving the file?
888
925
tree_graph .view ()
889
926
except ImportError as e :
890
927
print (e )
@@ -905,7 +942,7 @@ def _calc_pred_dist_leaf(self,node:_Node):
905
942
return node .h_beta_vec / node .h_beta_vec .sum ()
906
943
907
944
def _calc_pred_dist_recursion (self ,node :_Node ,x ,i ):
908
- if node .depth < self .c_d_max and i - 1 - node .depth >= 0 : # 内部ノード
945
+ if node .depth < self .c_d_max and i - 1 - node .depth >= 0 : # inner node
909
946
if node .children [x [i - node .depth - 1 ]] is None :
910
947
node .children [x [i - node .depth - 1 ]] = _Node (
911
948
node .depth + 1 ,
@@ -919,7 +956,7 @@ def _calc_pred_dist_recursion(self,node:_Node,x,i):
919
956
tmp1 = self ._calc_pred_dist_recursion (node .children [x [i - node .depth - 1 ]],x ,i )
920
957
tmp2 = (1 - node .h_g ) * self ._calc_pred_dist_leaf (node ) + node .h_g * tmp1
921
958
return tmp2
922
- else : # 葉ノード
959
+ else : # leaf node
923
960
return self ._calc_pred_dist_leaf (node )
924
961
925
962
def calc_pred_dist (self ,x ):
@@ -936,11 +973,9 @@ def calc_pred_dist(self,x):
936
973
i = x .shape [0 ] - 1
937
974
938
975
if self .hn_root is None :
939
- self .hn_root = _Node (0 ,self .c_k )
940
- self .hn_root .h_g = self .hn_g
941
- self .hn_root .h_beta_vec [:] = self .hn_beta_vec
942
-
943
- self .p_theta_vec [:] = self ._calc_pred_dist_recursion (self .hn_root ,x ,i )
976
+ self .p_theta_vec [:] = self .hn_beta_vec / self .hn_beta_vec .sum ()
977
+ else :
978
+ self .p_theta_vec [:] = self ._calc_pred_dist_recursion (self .hn_root ,x ,i )
944
979
945
980
def make_prediction (self ,loss = "KL" ):
946
981
"""Predict a new data point under the given criterion.
@@ -973,7 +1008,7 @@ def _pred_and_update_leaf(self,node:_Node,x,i):
973
1008
return tmp
974
1009
975
1010
def _pred_and_update_recursion (self ,node :_Node ,x ,i ):
976
- if node .depth < self .c_d_max and i - 1 - node .depth >= 0 : # 内部ノード
1011
+ if node .depth < self .c_d_max and i - 1 - node .depth >= 0 : # inner node
977
1012
if node .children [x [i - node .depth - 1 ]] is None :
978
1013
node .children [x [i - node .depth - 1 ]] = _Node (
979
1014
node .depth + 1 ,
@@ -988,7 +1023,7 @@ def _pred_and_update_recursion(self,node:_Node,x,i):
988
1023
tmp2 = (1 - node .h_g ) * self ._pred_and_update_leaf (node ,x ,i ) + node .h_g * tmp1
989
1024
node .h_g = node .h_g * tmp1 [x [i ]] / tmp2 [x [i ]]
990
1025
return tmp2
991
- else : # 葉ノード
1026
+ else : # leaf node
992
1027
return self ._pred_and_update_leaf (node ,x ,i )
993
1028
994
1029
def pred_and_update (self ,x ,loss = "KL" ):
0 commit comments