Skip to content

Commit 868a592

Browse files
Merge pull request #59 from yuta-nakahara/release-0.2.2
Release 0.2.2
2 parents 56ccd9d + 9e22b0a commit 868a592

34 files changed

+148
-119
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ Plain text
137137

138138
```
139139
Y. Nakahara, N. Ichijo, K. Shimada, Y. Iikubo,
140-
S. Saito, K. Kazama, T. Matsushima, BayesML Developers, ``BayesML 0.2.1,''
140+
S. Saito, K. Kazama, T. Matsushima, BayesML Developers, ``BayesML 0.2.2,''
141141
[Online] https://github.com/yuta-nakahara/BayesML
142142
```
143143

@@ -148,7 +148,7 @@ BibTeX
148148
author = {Nakahara Yuta and Ichijo Naoki and Shimada Koshi and
149149
Iikubo Yuji and Saito Shota and Kazama Koki and
150150
Matsushima Toshiyasu and {BayesML Developers}},
151-
title = {BayesML 0.2.1},
151+
title = {BayesML 0.2.2},
152152
howpublished = {\url{https://github.com/yuta-nakahara/BayesML}},
153153
year = {2022}
154154
}

README_jp.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ BayesMLへのコントリビューションを考えてくださってありが
134134

135135
```
136136
Y. Nakahara, N. Ichijo, K. Shimada, Y. Iikubo,
137-
S. Saito, K. Kazama, T. Matsushima, BayesML Developers, ``BayesML 0.2.1,''
137+
S. Saito, K. Kazama, T. Matsushima, BayesML Developers, ``BayesML 0.2.2,''
138138
[Online] https://github.com/yuta-nakahara/BayesML
139139
```
140140

@@ -145,7 +145,7 @@ BibTeX
145145
author = {Nakahara Yuta and Ichijo Naoki and Shimada Koshi and
146146
Iikubo Yuji and Saito Shota and Kazama Koki and
147147
Matsushima Toshiyasu and {BayesML Developers}},
148-
title = {BayesML 0.2.1},
148+
title = {BayesML 0.2.2},
149149
howpublished = {\url{https://github.com/yuta-nakahara/BayesML}},
150150
year = {2022}
151151
}

bayesml/contexttree/_contexttree.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# Document Author
44
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
55
import warnings
6-
import copy
7-
import pickle
86
import numpy as np
97
import matplotlib.pyplot as plt
108
from matplotlib.colors import rgb2hex
@@ -202,19 +200,19 @@ def _gen_sample_recursion(self,node,x):
202200
else:
203201
return self._gen_sample_recursion(node.children[x[-node.depth-1]],x)
204202

205-
def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibling_num,p_v):
203+
def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibling_num,p_s):
206204
tmp_id = node_id
207-
tmp_p_v = p_v
205+
tmp_p_s = p_s
208206

209207
# add node information
210-
label_string = f'h_g={node.h_g:.2f}\\lp_v={tmp_p_v:.2f}\\ltheta_vec\\l='
208+
label_string = f'h_g={node.h_g:.2f}\\lp_s={tmp_p_s:.2f}\\ltheta_vec\\l='
211209
if node.leaf:
212210
label_string += f'{np.array2string(node.theta_vec,precision=2,max_line_width=11)}\\l'
213211
else:
214212
label_string += 'None\\l'
215213

216-
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
217-
if tmp_p_v > 0.65:
214+
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_s))}')
215+
if tmp_p_s > 0.65:
218216
tree_graph.node(name=f'{tmp_id}',fontcolor='white')
219217

220218
# add edge information
@@ -223,7 +221,7 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
223221

224222
if node.leaf != True:
225223
for i in range(self.c_k):
226-
node_id = self._visualize_model_recursion(tree_graph,node.children[i],node_id+1,tmp_id,i,tmp_p_v*node.h_g)
224+
node_id = self._visualize_model_recursion(tree_graph,node.children[i],node_id+1,tmp_id,i,tmp_p_s*node.h_g)
227225

228226
return node_id
229227

@@ -307,7 +305,7 @@ def set_params(self,root=None):
307305
if root is not None:
308306
if type(root) is not _Node:
309307
raise(ParameterFormatError(
310-
"root must be an instance of metatree._Node"
308+
"root must be an instance of contexttree._Node"
311309
))
312310
self._set_params_recursion(self.root,root)
313311
return self
@@ -794,7 +792,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
794792
Loss function underlying the Bayes risk function, by default ``\"0-1\"``.
795793
This function supports only ``\"0-1\"``.
796794
visualize : bool, optional
797-
If ``True``, the estimated metatree will be visualized, by default ``True``.
795+
If ``True``, the estimated context tree model will be visualized, by default ``True``.
798796
This visualization requires ``graphviz``.
799797
filename : str, optional
800798
Filename for saving the figure, by default ``None``
@@ -804,8 +802,8 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
804802
805803
Returns
806804
-------
807-
map_root : metatree._Node
808-
The root node of the estimated meta-tree
805+
map_root : contexttree._Node
806+
The root node of the estimated context tree model
809807
that also contains the estimated parameters in each node.
810808
811809
See Also
@@ -824,32 +822,34 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
824822
import graphviz
825823
tree_graph = graphviz.Digraph(filename=filename,format=format)
826824
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
827-
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, 1.0, True)
825+
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, 1.0, True, False)
828826
# Can we show the image on the console without saving the file?
829827
tree_graph.view()
830828
return map_root
831829
else:
832830
raise(CriteriaError("Unsupported loss function! "
833831
+"This function supports only \"0-1\"."))
834832

835-
def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibling_num,p_v,map_tree):
833+
def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibling_num,p_s,map_tree,h_params):
836834
tmp_id = node_id
837-
tmp_p_v = p_v
835+
tmp_p_s = p_s
838836

839837
# add node information
840-
label_string = f'hn_g={node.h_g:.2f}\\lp_v={tmp_p_v:.2f}\\ltheta_vec\\l='
838+
label_string = f'hn_g={node.h_g:.2f}\\lp_s={tmp_p_s:.2f}\\l'
841839
if map_tree and not node.leaf:
842-
label_string += 'None\\l'
840+
label_string += 'theta_vec\\l=None\\l'
843841
else:
844-
if np.all(node.h_beta_vec > 1):
842+
if h_params:
843+
label_string += f'hn_beta_vec\\l={np.array2string(node.h_beta_vec,precision=2,max_line_width=11)}\\l'
844+
elif np.all(node.h_beta_vec > 1):
845845
theta_vec_hat = (node.h_beta_vec - 1) / (np.sum(node.h_beta_vec) - self.c_k)
846-
label_string += f'{np.array2string(theta_vec_hat,precision=2,max_line_width=11)}\\l'
846+
label_string += f'theta_vec\\l={np.array2string(theta_vec_hat,precision=2,max_line_width=11)}\\l'
847847
else:
848-
warnings.warn("MAP estimate of theta_vec doesn't exist for the current h_beta_vec.",ResultWarning)
849-
label_string += 'None\\l'
848+
warnings.warn("MAP estimate of theta_vec doesn't exist for the current hn_beta_vec.",ResultWarning)
849+
label_string += 'theta_vec\\l=None\\l'
850850

851-
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
852-
if tmp_p_v > 0.65:
851+
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_s))}')
852+
if tmp_p_s > 0.65:
853853
tree_graph.node(name=f'{tmp_id}',fontcolor='white')
854854

855855
# add edge information
@@ -858,29 +858,31 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
858858

859859
for i in range(self.c_k):
860860
if node.children[i] is not None:
861-
node_id = self._visualize_model_recursion(tree_graph,node.children[i],node_id+1,tmp_id,i,tmp_p_v*node.h_g,map_tree)
861+
node_id = self._visualize_model_recursion(tree_graph,node.children[i],node_id+1,tmp_id,i,tmp_p_s*node.h_g,map_tree,h_params)
862862

863863
return node_id
864864

865-
def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibling_num,p_v):
865+
def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibling_num,p_s,h_params):
866866
tmp_id = node_id
867-
tmp_p_v = p_v
867+
tmp_p_s = p_s
868868

869869
# add node information
870870
if depth == self.c_d_max:
871871
label_string = 'hn_g=0.0\\l'
872872
else:
873873
label_string = f'hn_g={self.hn_g:.2f}\\l'
874-
label_string += f'p_v={tmp_p_v:.2f}\\ltheta_vec\\l='
875-
if np.all(self.hn_beta_vec > 1):
874+
label_string += f'p_s={tmp_p_s:.2f}\\l'
875+
if h_params:
876+
label_string += f'hn_beta_vec\\l={np.array2string(self.hn_beta_vec,precision=2,max_line_width=11)}\\l'
877+
elif np.all(self.hn_beta_vec > 1):
876878
theta_vec_hat = (self.hn_beta_vec - 1) / (np.sum(self.hn_beta_vec) - self.c_k)
877-
label_string += f'{np.array2string(theta_vec_hat,precision=2,max_line_width=11)}\\l'
879+
label_string += f'theta_vec\\l={np.array2string(theta_vec_hat,precision=2,max_line_width=11)}\\l'
878880
else:
879-
warnings.warn("MAP estimate of theta_vec doesn't exist for the current h_beta_vec.",ResultWarning)
880-
label_string += 'None\\l'
881+
warnings.warn("MAP estimate of theta_vec doesn't exist for the current hn_beta_vec.",ResultWarning)
882+
label_string += 'theta_vec\\l=None\\l'
881883

882-
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
883-
if tmp_p_v > 0.65:
884+
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_s))}')
885+
if tmp_p_s > 0.65:
884886
tree_graph.node(name=f'{tmp_id}',fontcolor='white')
885887

886888
# add edge information
@@ -889,11 +891,11 @@ def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibl
889891

890892
if depth < self.c_d_max:
891893
for i in range(self.c_k):
892-
node_id = self._visualize_model_recursion_none(tree_graph,depth+1,node_id+1,tmp_id,i,tmp_p_v*self.hn_g)
894+
node_id = self._visualize_model_recursion_none(tree_graph,depth+1,node_id+1,tmp_id,i,tmp_p_s*self.hn_g,h_params)
893895

894896
return node_id
895897

896-
def visualize_posterior(self,filename=None,format=None):
898+
def visualize_posterior(self,filename=None,format=None,h_params=False):
897899
"""Visualize the posterior distribution for the parameter.
898900
899901
This method requires ``graphviz``.
@@ -904,13 +906,16 @@ def visualize_posterior(self,filename=None,format=None):
904906
Filename for saving the figure, by default ``None``
905907
format : str, optional
906908
Rendering output format (``\"pdf\"``, ``\"png\"``, ...).
909+
h_params : bool, optional
910+
If ``True``, hyperparameters at each node will be visualized.
911+
if ``False``, estimated parameters at each node will be visulaized.
907912
908913
Examples
909914
--------
910915
>>> from bayesml import contexttree
911916
>>> gen_model = contexttree.GenModel(c_k=2,c_d_max=3,h_g=0.75)
912917
>>> gen_model.gen_params()
913-
>>> x = gen_model.gen_sample(50)
918+
>>> x = gen_model.gen_sample(500)
914919
>>> learn_model = contexttree.LearnModel(c_k=2,c_d_max=3,h0_g=0.75)
915920
>>> learn_model.update_posterior(x)
916921
>>> learn_model.visualize_posterior()
@@ -926,9 +931,9 @@ def visualize_posterior(self,filename=None,format=None):
926931
tree_graph = graphviz.Digraph(filename=filename,format=format)
927932
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
928933
if self.hn_root is None:
929-
self._visualize_model_recursion_none(tree_graph, 0, 0, None, None, 1.0, False)
934+
self._visualize_model_recursion_none(tree_graph, 0, 0, None, None, 1.0, h_params)
930935
else:
931-
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0, False)
936+
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0, False, h_params)
932937
# Can we show the image on the console without saving the file?
933938
tree_graph.view()
934939
except ImportError as e:

bayesml/metatree/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* :math:`\mathcal{S}` : a set of :math:`s`
1515
* :math:`\mathcal{I}(T)` : a set of inner nodes of :math:`T`
1616
* :math:`\mathcal{L}(T)` : a set of leaf nodes of :math:`T`
17-
* :math:`\boldsymbol{k}=(k_s)_{s \in \mathcal{S}}` : feature assignmet vector where :math:`k_s \in \{1, 2,\ldots,p+q\}`. If :math:`k_s \leq p`, the node :math:`s` has a threshold.
17+
* :math:`\boldsymbol{k}=(k_s)_{s \in \mathcal{S}}` : feature assignment vector where :math:`k_s \in \{1, 2,\ldots,p+q\}`. If :math:`k_s \leq p`, the node :math:`s` has a threshold.
1818
* :math:`\boldsymbol{\theta}=(\theta_s)_{s \in \mathcal{S}}` : a set of parameter
1919
* :math:`s(\boldsymbol{x}) \in \mathcal{L}(T)` : a leaf node of :math:`T` corresponding to :math:`\boldsymbol{x}`, which is determined according to :math:`\boldsymbol{k}` and the thresholds.
2020
@@ -80,7 +80,7 @@
8080
.. math::
8181
p(\boldsymbol{k}_b | \boldsymbol{x}^n, y^n)\propto \prod_{i=1}^n \tilde{q}_{s_{\lambda}}(y_{i}|\boldsymbol{x}_{i},\boldsymbol{x}^{i-1}, y^{i-1}, M_{T_b, \boldsymbol{k}_b}),
8282
83-
where :math:`s_{\lambda}` is the root node of :math:`M_{T, \boldsymbol{k}_b}`.
83+
where :math:`s_{\lambda}` is the root node of :math:`M_{T_b, \boldsymbol{k}_b}`.
8484
8585
The predictive distribution is as follows:
8686

0 commit comments

Comments
 (0)