3
3
# Document Author
4
4
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
5
5
import warnings
6
- import copy
7
- import pickle
8
6
import numpy as np
9
7
import matplotlib .pyplot as plt
10
8
from matplotlib .colors import rgb2hex
@@ -202,19 +200,19 @@ def _gen_sample_recursion(self,node,x):
202
200
else :
203
201
return self ._gen_sample_recursion (node .children [x [- node .depth - 1 ]],x )
204
202
205
- def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_v ):
203
+ def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_s ):
206
204
tmp_id = node_id
207
- tmp_p_v = p_v
205
+ tmp_p_s = p_s
208
206
209
207
# add node information
210
- label_string = f'h_g={ node .h_g :.2f} \\ lp_v= { tmp_p_v :.2f} \\ ltheta_vec\\ l='
208
+ label_string = f'h_g={ node .h_g :.2f} \\ lp_s= { tmp_p_s :.2f} \\ ltheta_vec\\ l='
211
209
if node .leaf :
212
210
label_string += f'{ np .array2string (node .theta_vec ,precision = 2 ,max_line_width = 11 )} \\ l'
213
211
else :
214
212
label_string += 'None\\ l'
215
213
216
- tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
217
- if tmp_p_v > 0.65 :
214
+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_s ))} ' )
215
+ if tmp_p_s > 0.65 :
218
216
tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
219
217
220
218
# add edge information
@@ -223,7 +221,7 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
223
221
224
222
if node .leaf != True :
225
223
for i in range (self .c_k ):
226
- node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_v * node .h_g )
224
+ node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_s * node .h_g )
227
225
228
226
return node_id
229
227
@@ -307,7 +305,7 @@ def set_params(self,root=None):
307
305
if root is not None :
308
306
if type (root ) is not _Node :
309
307
raise (ParameterFormatError (
310
- "root must be an instance of metatree ._Node"
308
+ "root must be an instance of contexttree ._Node"
311
309
))
312
310
self ._set_params_recursion (self .root ,root )
313
311
return self
@@ -794,7 +792,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
794
792
Loss function underlying the Bayes risk function, by default ``\" 0-1\" ``.
795
793
This function supports only ``\" 0-1\" ``.
796
794
visualize : bool, optional
797
- If ``True``, the estimated metatree will be visualized, by default ``True``.
795
+ If ``True``, the estimated context tree model will be visualized, by default ``True``.
798
796
This visualization requires ``graphviz``.
799
797
filename : str, optional
800
798
Filename for saving the figure, by default ``None``
@@ -804,8 +802,8 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
804
802
805
803
Returns
806
804
-------
807
- map_root : metatree ._Node
808
- The root node of the estimated meta- tree
805
+ map_root : contexttree ._Node
806
+ The root node of the estimated context tree model
809
807
that also contains the estimated parameters in each node.
810
808
811
809
See Also
@@ -824,32 +822,34 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
824
822
import graphviz
825
823
tree_graph = graphviz .Digraph (filename = filename ,format = format )
826
824
tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
827
- self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , 1.0 , True )
825
+ self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , 1.0 , True , False )
828
826
# Can we show the image on the console without saving the file?
829
827
tree_graph .view ()
830
828
return map_root
831
829
else :
832
830
raise (CriteriaError ("Unsupported loss function! "
833
831
+ "This function supports only \" 0-1\" ." ))
834
832
835
- def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_v ,map_tree ):
833
+ def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_s ,map_tree , h_params ):
836
834
tmp_id = node_id
837
- tmp_p_v = p_v
835
+ tmp_p_s = p_s
838
836
839
837
# add node information
840
- label_string = f'hn_g={ node .h_g :.2f} \\ lp_v= { tmp_p_v :.2f} \\ ltheta_vec \\ l= '
838
+ label_string = f'hn_g={ node .h_g :.2f} \\ lp_s= { tmp_p_s :.2f} \\ l '
841
839
if map_tree and not node .leaf :
842
- label_string += 'None\\ l'
840
+ label_string += 'theta_vec \\ l= None\\ l'
843
841
else :
844
- if np .all (node .h_beta_vec > 1 ):
842
+ if h_params :
843
+ label_string += f'hn_beta_vec\\ l={ np .array2string (node .h_beta_vec ,precision = 2 ,max_line_width = 11 )} \\ l'
844
+ elif np .all (node .h_beta_vec > 1 ):
845
845
theta_vec_hat = (node .h_beta_vec - 1 ) / (np .sum (node .h_beta_vec ) - self .c_k )
846
- label_string += f'{ np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
846
+ label_string += f'theta_vec \\ l= { np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
847
847
else :
848
- warnings .warn ("MAP estimate of theta_vec doesn't exist for the current h_beta_vec ." ,ResultWarning )
849
- label_string += 'None\\ l'
848
+ warnings .warn ("MAP estimate of theta_vec doesn't exist for the current hn_beta_vec ." ,ResultWarning )
849
+ label_string += 'theta_vec \\ l= None\\ l'
850
850
851
- tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
852
- if tmp_p_v > 0.65 :
851
+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_s ))} ' )
852
+ if tmp_p_s > 0.65 :
853
853
tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
854
854
855
855
# add edge information
@@ -858,29 +858,31 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
858
858
859
859
for i in range (self .c_k ):
860
860
if node .children [i ] is not None :
861
- node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_v * node .h_g ,map_tree )
861
+ node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_s * node .h_g ,map_tree , h_params )
862
862
863
863
return node_id
864
864
865
- def _visualize_model_recursion_none (self ,tree_graph ,depth ,node_id ,parent_id ,sibling_num ,p_v ):
865
+ def _visualize_model_recursion_none (self ,tree_graph ,depth ,node_id ,parent_id ,sibling_num ,p_s , h_params ):
866
866
tmp_id = node_id
867
- tmp_p_v = p_v
867
+ tmp_p_s = p_s
868
868
869
869
# add node information
870
870
if depth == self .c_d_max :
871
871
label_string = 'hn_g=0.0\\ l'
872
872
else :
873
873
label_string = f'hn_g={ self .hn_g :.2f} \\ l'
874
- label_string += f'p_v={ tmp_p_v :.2f} \\ ltheta_vec\\ l='
875
- if np .all (self .hn_beta_vec > 1 ):
874
+ label_string += f'p_s={ tmp_p_s :.2f} \\ l'
875
+ if h_params :
876
+ label_string += f'hn_beta_vec\\ l={ np .array2string (self .hn_beta_vec ,precision = 2 ,max_line_width = 11 )} \\ l'
877
+ elif np .all (self .hn_beta_vec > 1 ):
876
878
theta_vec_hat = (self .hn_beta_vec - 1 ) / (np .sum (self .hn_beta_vec ) - self .c_k )
877
- label_string += f'{ np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
879
+ label_string += f'theta_vec \\ l= { np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
878
880
else :
879
- warnings .warn ("MAP estimate of theta_vec doesn't exist for the current h_beta_vec ." ,ResultWarning )
880
- label_string += 'None\\ l'
881
+ warnings .warn ("MAP estimate of theta_vec doesn't exist for the current hn_beta_vec ." ,ResultWarning )
882
+ label_string += 'theta_vec \\ l= None\\ l'
881
883
882
- tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
883
- if tmp_p_v > 0.65 :
884
+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_s ))} ' )
885
+ if tmp_p_s > 0.65 :
884
886
tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
885
887
886
888
# add edge information
@@ -889,11 +891,11 @@ def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibl
889
891
890
892
if depth < self .c_d_max :
891
893
for i in range (self .c_k ):
892
- node_id = self ._visualize_model_recursion_none (tree_graph ,depth + 1 ,node_id + 1 ,tmp_id ,i ,tmp_p_v * self .hn_g )
894
+ node_id = self ._visualize_model_recursion_none (tree_graph ,depth + 1 ,node_id + 1 ,tmp_id ,i ,tmp_p_s * self .hn_g , h_params )
893
895
894
896
return node_id
895
897
896
- def visualize_posterior (self ,filename = None ,format = None ):
898
+ def visualize_posterior (self ,filename = None ,format = None , h_params = False ):
897
899
"""Visualize the posterior distribution for the parameter.
898
900
899
901
This method requires ``graphviz``.
@@ -904,13 +906,16 @@ def visualize_posterior(self,filename=None,format=None):
904
906
Filename for saving the figure, by default ``None``
905
907
format : str, optional
906
908
Rendering output format (``\" pdf\" ``, ``\" png\" ``, ...).
909
+ h_params : bool, optional
910
+ If ``True``, hyperparameters at each node will be visualized.
911
+ if ``False``, estimated parameters at each node will be visulaized.
907
912
908
913
Examples
909
914
--------
910
915
>>> from bayesml import contexttree
911
916
>>> gen_model = contexttree.GenModel(c_k=2,c_d_max=3,h_g=0.75)
912
917
>>> gen_model.gen_params()
913
- >>> x = gen_model.gen_sample(50 )
918
+ >>> x = gen_model.gen_sample(500 )
914
919
>>> learn_model = contexttree.LearnModel(c_k=2,c_d_max=3,h0_g=0.75)
915
920
>>> learn_model.update_posterior(x)
916
921
>>> learn_model.visualize_posterior()
@@ -926,9 +931,9 @@ def visualize_posterior(self,filename=None,format=None):
926
931
tree_graph = graphviz .Digraph (filename = filename ,format = format )
927
932
tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
928
933
if self .hn_root is None :
929
- self ._visualize_model_recursion_none (tree_graph , 0 , 0 , None , None , 1.0 , False )
934
+ self ._visualize_model_recursion_none (tree_graph , 0 , 0 , None , None , 1.0 , h_params )
930
935
else :
931
- self ._visualize_model_recursion (tree_graph , self .hn_root , 0 , None , None , 1.0 , False )
936
+ self ._visualize_model_recursion (tree_graph , self .hn_root , 0 , None , None , 1.0 , False , h_params )
932
937
# Can we show the image on the console without saving the file?
933
938
tree_graph .view ()
934
939
except ImportError as e :
0 commit comments