Skip to content

Commit 5f762db

Browse files
committed
Add update and visualize posterior
1 parent 6bd27fa commit 5f762db

File tree

3 files changed

+147
-122
lines changed

3 files changed

+147
-122
lines changed

bayesml/contexttree/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from ._contexttree import GenModel
2-
# from ._contexttree import LearnModel
2+
from ._contexttree import LearnModel
33

4-
__all__ = ["GenModel"]#, "LearnModel"]
4+
__all__ = ["GenModel", "LearnModel"]

bayesml/contexttree/_contexttree.py

Lines changed: 131 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,7 @@ def _gen_sample_recursion(self,node,x):
186186
return self._gen_sample_recursion(node.children[x[-node.depth-1]],x)
187187

188188
def _visualize_model_recursion(self,tree_graph,node,node_id,parent_id,sibling_num,p_v):
189-
"""Visualize the stochastic data generative model and generated samples.
190-
191-
"""
189+
"""Visualize the stochastic data generative model and generated samples."""
192190
tmp_id = node_id
193191
tmp_p_v = p_v
194192

@@ -348,7 +346,7 @@ def gen_sample(self,sample_length,initial_values=None):
348346
x[:self.c_d_max] = initial_values
349347

350348
for i in range(self.c_d_max,sample_length+self.c_d_max):
351-
x[i] = self._gen_sample_recursion(self.root,x[i-self.c_d_max:i])
349+
x[i] = self._gen_sample_recursion(self.root,x[:i])
352350

353351
return x[self.c_d_max:]
354352

@@ -434,16 +432,18 @@ class _LearnNode():
434432
"""
435433
def __init__(self,
436434
depth,
437-
c_k=2,
438-
h0_g=0.5,
439-
hn_g=0.5,
435+
c_k,
436+
h0_g,
437+
hn_g,
438+
h0_beta_vec,
439+
hn_beta_vec,
440440
):
441441
self.depth = depth
442442
self.children = [None for i in range(c_k)] # child nodes
443443
self.h0_g = h0_g
444444
self.hn_g = hn_g
445-
self.h0_beta_vec = np.ones(c_k) / 2
446-
self.hn_beta_vec = np.ones(c_k) / 2
445+
self.h0_beta_vec = np.copy(h0_beta_vec)
446+
self.hn_beta_vec = np.copy(hn_beta_vec)
447447
self.leaf = False
448448
self.map_leaf = False
449449

@@ -489,8 +489,6 @@ def __init__(
489489
self.c_k = _check.pos_int(c_k,'c_k',ParameterFormatError)
490490

491491
# h0_params
492-
if h0_g is not None:
493-
_check.float_in_closed01(h0_g,'h0_g',ParameterFormatError)
494492
self.h0_g = h0_g
495493
self.h0_beta_vec = np.ones(self.c_k) / 2
496494
self.h0_root = None
@@ -528,7 +526,14 @@ def _set_recursion(self,node:_LearnNode,original_tree_node:_LearnNode):
528526
else:
529527
node.leaf = False
530528
for i in range(self.c_k):
531-
node.children[i] = _GenNode(node.depth+1,self.c_k,self.h0_g,self.hn_g)
529+
node.children[i] = _LearnNode(
530+
node.depth+1,
531+
self.c_k,
532+
self.h0_g,
533+
self.hn_g,
534+
self.h0_beta_vec,
535+
self.hn_beta_vec,
536+
)
532537
self._set_recursion(node.children[i],original_tree_node.children[i])
533538

534539
def set_h0_params(self,
@@ -556,7 +561,6 @@ def set_h0_params(self,
556561
_check.pos_floats(h0_beta_vec,'h0_beta_vec',ParameterFormatError)
557562
self.h0_beta_vec[:] = h0_beta_vec
558563

559-
self.h0_root = _LearnNode(0,self.c_k,self.h0_g,self.hn_g)
560564
if h0_root is not None:
561565
if type(h0_root) is not _LearnNode:
562566
raise(ParameterFormatError(
@@ -605,7 +609,6 @@ def set_hn_params(self,
605609
_check.pos_floats(hn_beta_vec,'hn_beta_vec',ParameterFormatError)
606610
self.hn_beta_vec[:] = hn_beta_vec
607611

608-
self.hn_root = _LearnNode(0,self.c_k,self.hn_g,self.hn_g)
609612
if hn_root is not None:
610613
if type(hn_root) is not _LearnNode:
611614
raise(ParameterFormatError(
@@ -629,114 +632,125 @@ def get_hn_params(self):
629632
"hn_beta_vec":self.hn_beta_vec,
630633
"hn_root":self.hn_root}
631634

632-
def _update_posterior_leaf(self,node,x,y):
633-
try:
634-
node.sub_model.calc_pred_dist(x)
635-
except:
636-
node.sub_model.calc_pred_dist()
637-
pred_dist = node.sub_model.make_prediction(loss='KL') # Futurework: direct method to get marginal likelihood is better
638-
639-
try:
640-
node.sub_model.update_posterior(x,y)
641-
except:
642-
node.sub_model.update_posterior(y)
643-
644-
if type(pred_dist) is np.ndarray:
645-
return pred_dist[y]
646-
try:
647-
return pred_dist.pdf(y)
648-
except:
649-
return pred_dist.pmf(y)
650-
651-
def _update_posterior_recursion(self,node,x,y):
652-
if node.leaf == False: # 内部ノード
653-
tmp1 = self._update_posterior_recursion(node.children[x[node.k]],x,y)
654-
tmp2 = (1 - node.hn_g) * self._update_posterior_leaf(node,x,y) + node.hn_g * tmp1
635+
def _update_posterior_leaf(self,node:_LearnNode,x,i):
636+
tmp = node.hn_beta_vec[x[i]] / node.hn_beta_vec.sum()
637+
node.hn_beta_vec[x[i]] += 1
638+
return tmp
639+
640+
def _update_posterior_recursion(self,node:_LearnNode,x,i):
641+
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
642+
if node.children[x[i-node.depth-1]] is None:
643+
node.children[x[i-node.depth-1]] = _LearnNode(
644+
node.depth+1,
645+
self.c_k,
646+
self.h0_g,
647+
self.hn_g,
648+
self.h0_beta_vec,
649+
self.hn_beta_vec,
650+
)
651+
if node.depth + 1 == self.c_d_max:
652+
node.children[x[i-node.depth-1]].h0_g = 0.0
653+
node.children[x[i-node.depth-1]].hn_g = 0.0
654+
node.children[x[i-node.depth-1]].leaf = True
655+
tmp1 = self._update_posterior_recursion(node.children[x[i-node.depth-1]],x,i)
656+
tmp2 = (1 - node.hn_g) * self._update_posterior_leaf(node,x,i) + node.hn_g * tmp1
655657
node.hn_g = node.hn_g * tmp1 / tmp2
656658
return tmp2
657659
else: # 葉ノード
658-
return self._update_posterior_leaf(node,x,y)
660+
return self._update_posterior_leaf(node,x,i)
659661

660-
def update_posterior(self,x,y,alg_type='MTRF',**kwargs):
661-
"""Update the hyperparameters of the posterior distribution using traning data.
662+
def update_posterior(self,x):
663+
"""Update the hyperparameters using traning data.
662664
663665
Parameters
664666
----------
665667
x : numpy ndarray
666-
values of explanatory variables whose dtype is int
667-
y : numpy ndarray
668-
values of objective variable whose dtype may be int or float
669-
alg_type : {'MTRF', 'given_MT'}, optional
670-
type of algorithm, by default 'MTRF'
671-
**kwargs : dict, optional
672-
optional parameters of algorithms, by default {}
668+
1-dimensional int array
673669
"""
674-
_check.nonneg_int_vecs(x,'x',DataFormatError)
675-
if x.shape[-1] != self.c_k:
676-
raise(DataFormatError(f"x.shape[-1] must equal to c_k:{self.c_k}"))
670+
_check.nonneg_ints(x,'x',DataFormatError)
677671
if x.max() >= self.c_k:
678672
raise(DataFormatError(f"x.max() must smaller than c_k:{self.c_k}"))
679-
680-
if type(y) is np.ndarray:
681-
if x.shape[:-1] != y.shape:
682-
raise(DataFormatError(f"x.shape[:-1] and y.shape must be same."))
683-
elif x.shape[:-1] != ():
684-
raise(DataFormatError(f"If y is a scaler, x.shape[:-1] must be the empty tuple ()."))
685-
686-
x = x.reshape(-1,self.c_k)
687-
y = np.ravel(y)
688-
689-
if alg_type == 'MTRF':
690-
self.hn_metatree_list, self.hn_metatree_prob_vec = self._MTRF(x,y,**kwargs)
691-
elif alg_type == 'given_MT':
692-
self.hn_metatree_list, self.hn_metatree_prob_vec = self._given_MT(x,y)
693-
694-
def _map_recursion_add_nodes(self,node):
673+
x = np.ravel(x)
674+
675+
if self.hn_root is None:
676+
self.hn_root = _LearnNode(
677+
0,
678+
self.c_k,
679+
self.hn_g,
680+
self.hn_g,
681+
self.h0_beta_vec,
682+
self.hn_beta_vec,
683+
)
684+
685+
for i in range(x.shape[0]):
686+
self._update_posterior_recursion(self.hn_root,x,i)
687+
688+
def _map_recursion_add_nodes(self,node:_LearnNode):
695689
if node.depth == self.c_d_max or node.depth == self.c_k: # 葉ノード
690+
node.h0_g = 0.0
696691
node.hn_g = 0.0
697692
node.leaf = True
698693
node.map_leaf = True
699694
else: # 内部ノード
700695
for i in range(self.c_k):
701696
node.children[i] = _LearnNode(depth=node.depth+1,
702697
c_k=self.c_k,
703-
hn_g=self.h0_g,
704-
k=None,
705-
sub_model=self.SubModel(**self.sub_h0_params))
698+
h0_g=self.h0_g,
699+
hn_g=self.hn_g,
700+
h0_beta_vec=self.h0_beta_vec,
701+
hn_beta_vec=self.hn_beta_vec,
702+
)
706703
self._map_recursion_add_nodes(node.children[i])
707704

708-
def _map_recursion(self,node):
709-
if node.leaf:
710-
if node.depth == self.c_d_max or node.depth == self.c_k:
711-
node.map_leaf = True
712-
return 1.0
713-
elif 1.0 - node.hn_g > node.hn_g * self.h0_g ** (self.c_k ** (self.c_d_max - node.depth)-2):
714-
node.map_leaf = True
715-
return 1.0 - node.hn_g
716-
else:
717-
self._map_recursion_add_nodes(node)
718-
return node.hn_g * self.h0_g ** (self.c_k ** (self.c_d_max - node.depth)-2)
705+
def _map_recursion(self,node:_LearnNode):
706+
if node.depth == self.c_d_max:
707+
node.map_leaf = True
708+
return 1.0
719709
else:
720710
tmp1 = 1.0-node.hn_g
721711
tmp_vec = np.empty(self.c_k)
722712
for i in range(self.c_k):
723-
tmp_vec[i] = self._map_recursion(node.children[i])
713+
if node.children[i] is not None:
714+
tmp_vec[i] = self._map_recursion(node.children[i])
715+
else:
716+
node.children[i] = _LearnNode(
717+
node.depth+1,
718+
self.c_k,
719+
self.h0_g,
720+
self.hn_g,
721+
self.h0_beta_vec,
722+
self.hn_beta_vec,
723+
)
724+
if 1.0 - node.h0_g > self.h0_g ** ((self.c_k ** (self.c_d_max - node.depth - 1) - 1)/(self.c_k-1)):
725+
node.children[i].map_leaf = True
726+
tmp_vec[i] = 1.0 - node.hn_g
727+
else:
728+
self._map_recursion_add_nodes(node.children[i])
729+
tmp_vec[i] = self.h0_g ** ((self.c_k ** (self.c_d_max - node.depth) - 1)/(self.c_k-1))
724730
if tmp1 > node.hn_g*tmp_vec.prod():
725731
node.map_leaf = True
726732
return tmp1
727733
else:
728734
node.map_leaf = False
729735
return node.hn_g*tmp_vec.prod()
730736

731-
def _copy_map_tree_recursion(self,copyed_node,original_node):
737+
def _copy_map_tree_recursion(self,copyed_node:_LearnNode,original_node:_LearnNode):
738+
copyed_node.h0_g = original_node.h0_g
732739
copyed_node.hn_g = original_node.hn_g
740+
copyed_node.h0_beta_vec[:] = original_node.h0_beta_vec
741+
copyed_node.hn_beta_vec[:] = original_node.hn_beta_vec
733742
if original_node.map_leaf == False:
734-
copyed_node.k = original_node.k
735743
for i in range(self.c_k):
736-
copyed_node.children[i] = _LearnNode(copyed_node.depth+1,self.c_k)
744+
copyed_node.children[i] = _LearnNode(
745+
copyed_node.depth+1,
746+
self.c_k,
747+
self.h0_g,
748+
self.hn_g,
749+
self.h0_beta_vec,
750+
self.hn_beta_vec,
751+
)
737752
self._copy_map_tree_recursion(copyed_node.children[i],original_node.children[i])
738753
else:
739-
copyed_node.sub_model = copy.deepcopy(original_node.sub_model)
740754
copyed_node.leaf = True
741755

742756
def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
@@ -768,15 +782,25 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
768782
"""
769783

770784
if loss == "0-1":
771-
map_index = 0
772-
map_prob = 0.0
773-
for i,metatree in enumerate(self.hn_metatree_list):
774-
prob = self.hn_metatree_prob_vec[i] * self._map_recursion(metatree)
775-
if prob > map_prob:
776-
map_index = i
777-
map_prob = prob
778-
map_root = _LearnNode(0,self.c_k)
779-
self._copy_map_tree_recursion(map_root,self.hn_metatree_list[map_index])
785+
if self.hn_root is None:
786+
self.hn_root = _LearnNode(
787+
0,
788+
self.c_k,
789+
self.hn_g,
790+
self.hn_g,
791+
self.h0_beta_vec,
792+
self.hn_beta_vec,
793+
)
794+
self._map_recursion(self.hn_root)
795+
map_root = _LearnNode(
796+
0,
797+
self.c_k,
798+
self.h0_g,
799+
self.hn_g,
800+
self.h0_beta_vec,
801+
self.hn_beta_vec,
802+
)
803+
self._copy_map_tree_recursion(map_root,self.hn_root)
780804
if visualize:
781805
import graphviz
782806
tree_graph = graphviz.Digraph(filename=filename,format=format)
@@ -788,27 +812,20 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
788812
raise(CriteriaError("Unsupported loss function! "
789813
+"This function supports only \"0-1\"."))
790814

791-
def _visualize_model_recursion(self,tree_graph,node,node_id,parent_id,sibling_num,p_v):
815+
def _visualize_model_recursion(self,tree_graph,node:_LearnNode,node_id,parent_id,sibling_num,p_v):
792816
tmp_id = node_id
793817
tmp_p_v = p_v
794818

795819
# add node information
796-
label_string = f'k={node.k}\\lhn_g={node.hn_g:.2f}\\lp_v={tmp_p_v:.2f}\\lsub_params={{'
797-
if node.sub_model is not None:
798-
try:
799-
sub_params = node.sub_model.estimate_params(loss='0-1',dict_out=True)
800-
except:
801-
sub_params = node.sub_model.estimate_params(dict_out=True)
820+
label_string = f'hn_g={node.hn_g:.2f}\\lp_v={tmp_p_v:.2f}\\ltheta_vec='
821+
label_string += '['
822+
for i in range(self.c_k):
823+
theta_vec_hat = node.hn_beta_vec / node.hn_beta_vec.sum()
824+
label_string += f'{theta_vec_hat[i]:.2f}'
825+
if i < self.c_k-1:
826+
label_string += ','
827+
label_string += ']'
802828

803-
for key,value in sub_params.items():
804-
try:
805-
label_string += f'\\l{key}:{value:.2f}'
806-
except:
807-
label_string += f'\\l{key}:{value}'
808-
label_string += '}'
809-
else:
810-
label_string += '\\lNone}'
811-
812829
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
813830
if tmp_p_v > 0.65:
814831
tree_graph.node(name=f'{tmp_id}',fontcolor='white')
@@ -817,8 +834,8 @@ def _visualize_model_recursion(self,tree_graph,node,node_id,parent_id,sibling_nu
817834
if parent_id is not None:
818835
tree_graph.edge(f'{parent_id}', f'{tmp_id}', label=f'{sibling_num}')
819836

820-
if node.leaf != True:
821-
for i in range(self.c_k):
837+
for i in range(self.c_k):
838+
if node.children[i] is not None:
822839
node_id = self._visualize_model_recursion(tree_graph,node.children[i],node_id+1,tmp_id,i,tmp_p_v*node.hn_g)
823840

824841
return node_id
@@ -850,13 +867,11 @@ def visualize_posterior(self,filename=None,format=None):
850867
--------
851868
graphviz.Digraph
852869
"""
853-
MAP_index = np.argmax(self.hn_metatree_prob_vec)
854-
print(f'MAP probability of metatree:{self.hn_metatree_prob_vec[MAP_index]}')
855870
try:
856871
import graphviz
857872
tree_graph = graphviz.Digraph(filename=filename,format=format)
858873
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
859-
self._visualize_model_recursion(tree_graph, self.hn_metatree_list[MAP_index], 0, None, None, 1.0)
874+
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0)
860875
# コンソール上で表示できるようにした方がいいかもしれない.
861876
tree_graph.view()
862877
except ImportError as e:
@@ -896,6 +911,7 @@ def calc_pred_dist(self,x):
896911
x : numpy ndarray
897912
values of explanatory variables whose dtype is int
898913
"""
914+
return
899915
_check.nonneg_int_vec(x,'x',DataFormatError)
900916
if x.shape[0] != self.c_k:
901917
raise(DataFormatError(f"x.shape[0] must equal to c_k:{self.c_k}"))

0 commit comments

Comments
 (0)