Skip to content

Commit 99cdd2e

Browse files
committed
Revise set_h_params and make_prediction
1 parent e7347ad commit 99cdd2e

File tree

2 files changed

+97
-25
lines changed

2 files changed

+97
-25
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
423423
if node.children[i] is None:
424424
node.children[i] = _Node(
425425
node.depth+1,
426-
self.c_k,
426+
self.c_num_children,
427427
sub_model=self.SubModel.GenModel(**self.sub_h_params),
428428
)
429429
self._set_h_params_recursion(node.children[i],original_tree_node.children[i])
@@ -818,6 +818,22 @@ def __init__(
818818
h0_metatree_prob_vec,
819819
)
820820

821+
def _set_h0_g_recursion(self,node:_Node):
822+
if node.depth == self.c_d_max:
823+
node.h_g = 0
824+
else:
825+
node.h_g = self.h0_g
826+
for i in range(self.c_num_children):
827+
if node.children[i] is not None:
828+
self._set_h0_g_recursion(node.children[i])
829+
830+
def _set_sub_h0_params_recursion(self,node:_Node):
831+
# node.sub_model.set_h0_params(**self.sub_h0_params)
832+
node.sub_model = self.SubModel.LearnModel(**self.sub_h0_params)
833+
for i in range(self.c_num_children):
834+
if node.children[i] is not None:
835+
self._set_sub_h0_params_recursion(node.children[i])
836+
821837
def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
822838
""" copy parameters from a fixed tree
823839
@@ -851,16 +867,36 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
851867
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
852868
node.leaf = True
853869
else:
870+
node.k = original_tree_node.k
871+
child_k_candidates = copy.copy(node.k_candidates)
872+
child_k_candidates.remove(node.k)
854873
node.leaf = False
855874
for i in range(self.c_num_children):
856875
if node.children[i] is None:
857876
node.children[i] = _Node(
858877
node.depth+1,
859-
self.c_k,
878+
self.c_num_children,
879+
child_k_candidates,
860880
sub_model=self.SubModel.LearnModel(**self.sub_h0_params),
861881
)
862882
self._set_h0_params_recursion(node.children[i],original_tree_node.children[i])
863883

884+
def _set_hn_g_recursion(self,node:_Node):
885+
if node.depth == self.c_d_max:
886+
node.h_g = 0
887+
else:
888+
node.h_g = self.hn_g
889+
for i in range(self.c_num_children):
890+
if node.children[i] is not None:
891+
self._set_hn_g_recursion(node.children[i])
892+
893+
def _set_sub_hn_params_recursion(self,node:_Node):
894+
# node.sub_model.set_hn_params(**self.sub_hn_params)
895+
node.sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
896+
for i in range(self.c_num_children):
897+
if node.children[i] is not None:
898+
self._set_sub_hn_params_recursion(node.children[i])
899+
864900
def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
865901
""" copy parameters from a fixed tree
866902
@@ -894,12 +930,16 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
894930
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
895931
node.leaf = True
896932
else:
933+
node.k = original_tree_node.k
934+
child_k_candidates = copy.copy(node.k_candidates)
935+
child_k_candidates.remove(node.k)
897936
node.leaf = False
898937
for i in range(self.c_num_children):
899938
if node.children[i] is None:
900939
node.children[i] = _Node(
901940
node.depth+1,
902-
self.c_k,
941+
self.c_num_children,
942+
child_k_candidates,
903943
sub_model=self.SubModel.LearnModel(**self.sub_hn_params),
904944
)
905945
self._set_hn_params_recursion(node.children[i],original_tree_node.children[i])
@@ -944,14 +984,14 @@ def set_h0_params(self,
944984
self.h0_g = _check.float_in_closed01(h0_g,'h0_g',ParameterFormatError)
945985
if self.h0_metatree_list:
946986
for h0_root in self.h0_metatree_list:
947-
self._set_h0_params_recursion(h0_root,None)
987+
self._set_h0_g_recursion(h0_root)
948988

949989
if sub_h0_params is not None:
950990
self.SubModel.LearnModel(**sub_h0_params)
951991
self.sub_h0_params = copy.deepcopy(sub_h0_params)
952992
if self.h0_metatree_list:
953993
for h0_root in self.h0_metatree_list:
954-
self._set_h0_params_recursion(h0_root,None)
994+
self._set_sub_h0_params_recursion(h0_root)
955995

956996
if h0_metatree_list is not None:
957997
if not isinstance(h0_metatree_list,list):
@@ -964,7 +1004,22 @@ def set_h0_params(self,
9641004
raise(ParameterFormatError(
9651005
"all elements of h0_metatree_list must be instances of metatree._Node or empty"
9661006
))
967-
self.h0_metatree_list = copy.deepcopy(h0_metatree_list)
1007+
diff = len(h0_metatree_list) - len(self.h0_metatree_list)
1008+
if diff < 0:
1009+
del self.h0_metatree_list[diff:]
1010+
elif diff > 0:
1011+
for i in range(diff):
1012+
self.h0_metatree_list.append(
1013+
_Node(
1014+
0,
1015+
self.c_num_children,
1016+
list(range(self.c_k)),
1017+
self.h0_g,
1018+
sub_model=self.SubModel.LearnModel(**self.sub_h0_params),
1019+
)
1020+
)
1021+
for i in range(len(self.h0_metatree_list)):
1022+
self._set_h0_params_recursion(self.h0_metatree_list[i],h0_metatree_list[i])
9681023
if h0_metatree_prob_vec is not None:
9691024
self.h0_metatree_prob_vec = np.copy(
9701025
_check.float_vec_sum_1(
@@ -1063,14 +1118,14 @@ def set_hn_params(self,
10631118
self.hn_g = _check.float_in_closed01(hn_g,'hn_g',ParameterFormatError)
10641119
if self.hn_metatree_list:
10651120
for hn_root in self.hn_metatree_list:
1066-
self._set_hn_params_recursion(hn_root,None)
1121+
self._set_hn_g_recursion(hn_root)
10671122

10681123
if sub_hn_params is not None:
10691124
self.SubModel.LearnModel(**sub_hn_params)
10701125
self.sub_hn_params = copy.deepcopy(sub_hn_params)
10711126
if self.hn_metatree_list:
10721127
for hn_root in self.hn_metatree_list:
1073-
self._set_hn_params_recursion(hn_root,None)
1128+
self._set_sub_hn_params_recursion(hn_root)
10741129

10751130
if hn_metatree_list is not None:
10761131
if not isinstance(hn_metatree_list,list):
@@ -1083,7 +1138,22 @@ def set_hn_params(self,
10831138
raise(ParameterFormatError(
10841139
"all elements of hn_metatree_list must be instances of metatree._Node or empty"
10851140
))
1086-
self.hn_metatree_list = copy.deepcopy(hn_metatree_list)
1141+
diff = len(hn_metatree_list) - len(self.hn_metatree_list)
1142+
if diff < 0:
1143+
del self.hn_metatree_list[diff:]
1144+
elif diff > 0:
1145+
for i in range(diff):
1146+
self.hn_metatree_list.append(
1147+
_Node(
1148+
0,
1149+
self.c_num_children,
1150+
list(range(self.c_k)),
1151+
self.hn_g,
1152+
sub_model=self.SubModel.LearnModel(**self.sub_hn_params),
1153+
)
1154+
)
1155+
for i in range(len(self.hn_metatree_list)):
1156+
self._set_hn_params_recursion(self.hn_metatree_list[i],hn_metatree_list[i])
10871157
if hn_metatree_prob_vec is not None:
10881158
self.hn_metatree_prob_vec = np.copy(
10891159
_check.float_vec_sum_1(
@@ -1610,13 +1680,14 @@ def _make_prediction_leaf_01(self,node:_Node):
16101680
pred_dist = node.sub_model.make_prediction(loss='KL')
16111681
if type(pred_dist) is np.ndarray:
16121682
mode_prob = pred_dist[mode]
1613-
try:
1614-
mode_prob = pred_dist.pdf(mode)
1615-
except:
1683+
else:
16161684
try:
1617-
mode_prob = pred_dist.pmf(mode)
1685+
mode_prob = pred_dist.pdf(mode)
16181686
except:
1619-
mode_prob = None
1687+
try:
1688+
mode_prob = pred_dist.pmf(mode)
1689+
except:
1690+
mode_prob = None
16201691
# elif hasattr(pred_dist,'pdf'):
16211692
# mode_prob = pred_dist.pdf(mode)
16221693
# elif hasattr(pred_dist,'pmf'):

bayesml/metatree/metatree_test.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
from bayesml import metatree
2-
from bayesml import normal
1+
from bayesml.metatree import GenModel
2+
from bayesml.metatree import LearnModel
3+
from bayesml import poisson
4+
from bayesml import bernoulli
35
import numpy as np
46
import copy
57

6-
gen_model = metatree.GenModel(4,3,2,h_g=0.75,SubModel=normal)
8+
gen_model = GenModel(4,3,2,h_g=0.75)
79
gen_model.gen_params()
810
gen_model.visualize_model('tree.pdf')
911
x,y = gen_model.gen_sample(1000)
1012

11-
learn_model = metatree.LearnModel(4,3,2,h0_g=0.75,SubModel=normal)
12-
learn_model.update_posterior(x,y)
13+
learn_model = LearnModel(4,3,2)
14+
learn_model.update_posterior(x,y,n_estimators=1)
1315
learn_model.visualize_posterior('tree2.pdf')
14-
params = learn_model.estimate_params(filename='tree3.pdf')
1516

16-
gen_model2 = metatree.GenModel(4,3,2,h_g=0.1,SubModel=normal)
17-
gen_model2.visualize_model('tree4.pdf')
18-
gen_model2.set_params(params)
19-
# gen_model2.gen_params()
20-
gen_model2.visualize_model('tree5.pdf')
17+
learn_model.calc_pred_dist(np.zeros(4,dtype=int))
18+
print(learn_model.make_prediction(loss='squared'))
19+
20+
learn_model.calc_pred_dist(np.ones(4,dtype=int))
21+
print(learn_model.make_prediction(loss='squared'))

0 commit comments

Comments
 (0)