Skip to content

Commit c107fd3

Browse files
committed
Modify estimate_params
1 parent 5a8458d commit c107fd3

File tree

2 files changed: +88 additions, -71 deletions

bayesml/metatree/_metatree.py

Lines changed: 83 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -1804,10 +1804,8 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
18041804
randomforest = RandomForestRegressor(n_estimators=n_estimators,max_depth=self.c_max_depth,**kwargs)
18051805

18061806
x = np.empty([y.shape[0],self.c_dim_features])
1807-
if self.c_dim_continuous > 0:
1808-
x[:,:self.c_dim_continuous] = x_continuous
1809-
if self.c_dim_categorical > 0:
1810-
x[:,-self.c_dim_categorical:] = x_categorical
1807+
x[:,:self.c_dim_continuous] = x_continuous
1808+
x[:,self.c_dim_continuous:] = x_categorical
18111809

18121810
randomforest.fit(x,y)
18131811

@@ -1828,18 +1826,9 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
18281826
tmp_metatree_list,tmp_metatree_prob_vec = self._marge_metatrees(tmp_metatree_list,tmp_metatree_prob_vec)
18291827

18301828
log_metatree_posteriors = np.log(tmp_metatree_prob_vec)
1831-
if self.c_dim_continuous > 0 and self.c_dim_categorical > 0:
1832-
for i,metatree in enumerate(tmp_metatree_list):
1833-
for j in range(y.shape[0]):
1834-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],x_categorical[j],y[j]))
1835-
elif self.c_dim_continuous > 0:
1836-
for i,metatree in enumerate(tmp_metatree_list):
1837-
for j in range(y.shape[0]):
1838-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],None,y[j]))
1839-
else:
1840-
for i,metatree in enumerate(tmp_metatree_list):
1841-
for j in range(y.shape[0]):
1842-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,None,x_categorical[j],y[j]))
1829+
for i,metatree in enumerate(tmp_metatree_list):
1830+
for j in range(y.shape[0]):
1831+
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],x_categorical[j],y[j]))
18431832
tmp_metatree_prob_vec[:] = np.exp(log_metatree_posteriors - log_metatree_posteriors.max())
18441833
tmp_metatree_prob_vec[:] /= tmp_metatree_prob_vec.sum()
18451834
return tmp_metatree_list,tmp_metatree_prob_vec
@@ -1868,18 +1857,9 @@ def _given_MT(self,x_continuous,x_categorical,y):
18681857
if not self.hn_metatree_list:
18691858
raise(ParameterFormatError("given_MT is supported only when len(self.hn_metatree_list) > 0."))
18701859
log_metatree_posteriors = np.log(self.hn_metatree_prob_vec)
1871-
if self.c_dim_continuous > 0 and self.c_dim_categorical > 0:
1872-
for i,metatree in enumerate(self.hn_metatree_list):
1873-
for j in range(y.shape[0]):
1874-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],x_categorical[j],y[j]))
1875-
elif self.c_dim_continuous > 0:
1876-
for i,metatree in enumerate(self.hn_metatree_list):
1877-
for j in range(y.shape[0]):
1878-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],None,y[j]))
1879-
else:
1880-
for i,metatree in enumerate(self.hn_metatree_list):
1881-
for j in range(y.shape[0]):
1882-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,None,x_categorical[j],y[j]))
1860+
for i,metatree in enumerate(self.hn_metatree_list):
1861+
for j in range(y.shape[0]):
1862+
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],x_categorical[j],y[j]))
18831863
self.hn_metatree_prob_vec[:] = np.exp(log_metatree_posteriors - log_metatree_posteriors.max())
18841864
self.hn_metatree_prob_vec[:] /= self.hn_metatree_prob_vec.sum()
18851865
return self.hn_metatree_list,self.hn_metatree_prob_vec
@@ -1958,7 +1938,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19581938
ParameterFormatError
19591939
)
19601940

1961-
x_categorical = None
1941+
x_categorical = np.empty([y.shape[0],0]) # dummy
19621942

19631943
elif self.c_dim_categorical > 0:
19641944
_check.nonneg_int_vecs(x_categorical,'x_categorical',DataFormatError)
@@ -1981,7 +1961,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19811961
ParameterFormatError
19821962
)
19831963

1984-
x_continuous = None
1964+
x_continuous = np.empty([y.shape[0],0]) # dummy
19851965

19861966
if alg_type == 'MTRF':
19871967
self.hn_metatree_list, self.hn_metatree_prob_vec = self._MTRF(x_continuous,x_categorical,y,**kwargs)
@@ -1991,37 +1971,57 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19911971
def _map_recursion_add_nodes(self,node:_Node):
19921972
if node.depth == self.c_max_depth or not node.k_candidates: # leaf node
19931973
node.h_g = 0.0
1974+
node.sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
19941975
node.leaf = True
19951976
node.map_leaf = True
19961977
else: # inner node
1997-
node.k = node.k_candidates[0]
1978+
node.k = node.k_candidates[self.hn_k_weight_vec[node.k_candidates].argmax()]
1979+
node.children = [None for i in range(self.c_num_children_vec[node.k])]
1980+
if node.k < self.c_dim_continuous:
1981+
node.thresholds = np.linspace(
1982+
node.ranges[node.k,0],
1983+
node.ranges[node.k,1],
1984+
self.c_num_children_vec[node.k]+1
1985+
)
1986+
else:
1987+
node.thresholds = None
19981988
child_k_candidates = node.k_candidates.copy()
19991989
child_k_candidates.remove(node.k)
2000-
for i in range(self.c_num_children):
1990+
node.leaf = False
1991+
for i in range(self.c_num_children_vec[node.k]):
20011992
node.children[i] = _Node(
20021993
node.depth+1,
2003-
self.c_num_children,
20041994
child_k_candidates,
20051995
self.hn_g,
2006-
sub_model=self.SubModel.LearnModel(**self.sub_h0_params)
1996+
ranges=np.array(node.ranges)
20071997
)
1998+
if node.thresholds is not None:
1999+
node.children[i].ranges[node.k,0] = node.thresholds[i]
2000+
node.children[i].ranges[node.k,1] = node.thresholds[i+1]
20082001
self._map_recursion_add_nodes(node.children[i])
20092002

20102003
def _map_recursion(self,node:_Node):
20112004
if node.leaf:
20122005
if node.depth == self.c_max_depth or not node.k_candidates:
20132006
node.map_leaf = True
20142007
return 1.0
2015-
elif 1.0 - self.hn_g > self.hn_g ** ((self.c_num_children ** (self.c_max_depth - node.depth) - 1)/(self.c_num_children-1)):
2016-
node.map_leaf = True
2017-
return 1.0 - self.hn_g
20182008
else:
2019-
self._map_recursion_add_nodes(node)
2020-
return self.hn_g ** ((self.c_num_children ** (self.c_max_depth - node.depth) - 1)/(self.c_num_children-1))
2009+
sum_nodes = 0
2010+
num_nodes = 1
2011+
rest_num_children_vec = np.sort(self.c_num_children_vec[node.k_candidates])
2012+
for i in range(min(self.c_max_depth-node.depth,len(node.k_candidates))):
2013+
sum_nodes += num_nodes
2014+
num_nodes *= rest_num_children_vec[i]
2015+
if 1.0 - node.h_g > node.h_g * self.hn_g ** (sum_nodes-1):
2016+
node.map_leaf = True
2017+
return 1.0 - node.h_g
2018+
else:
2019+
self._map_recursion_add_nodes(node)
2020+
return node.h_g * self.hn_g ** (sum_nodes-1)
20212021
else:
20222022
tmp1 = 1.0-node.h_g
2023-
tmp_vec = np.empty(self.c_num_children)
2024-
for i in range(self.c_num_children):
2023+
tmp_vec = np.empty(self.c_num_children_vec[node.k])
2024+
for i in range(self.c_num_children_vec[node.k]):
20252025
tmp_vec[i] = self._map_recursion(node.children[i])
20262026
if tmp1 > node.h_g*tmp_vec.prod():
20272027
node.map_leaf = True
@@ -2030,22 +2030,31 @@ def _map_recursion(self,node:_Node):
20302030
node.map_leaf = False
20312031
return node.h_g*tmp_vec.prod()
20322032

2033-
def _copy_map_tree_recursion(self,copyed_node:_Node,original_node:_Node):
2034-
copyed_node.h_g = original_node.h_g
2035-
if original_node.map_leaf == False:
2036-
copyed_node.k = original_node.k
2037-
child_k_candidates = copyed_node.k_candidates.copy()
2038-
child_k_candidates.remove(copyed_node.k)
2039-
for i in range(self.c_num_children):
2040-
copyed_node.children[i] = _Node(
2041-
copyed_node.depth+1,
2042-
self.c_num_children,
2033+
def _copy_map_tree_recursion(self,copied_node:_Node,original_node:_Node):
2034+
copied_node.h_g = original_node.h_g
2035+
if original_node.map_leaf:
2036+
copied_node.sub_model = copy.deepcopy(original_node.sub_model)
2037+
copied_node.leaf = True
2038+
else:
2039+
copied_node.k = original_node.k
2040+
copied_node.children = [None for i in range(self.c_num_children_vec[copied_node.k])]
2041+
if copied_node.k < self.c_dim_continuous:
2042+
copied_node.thresholds = np.array(original_node.thresholds)
2043+
else:
2044+
copied_node.thresholds = None
2045+
child_k_candidates = copied_node.k_candidates.copy()
2046+
child_k_candidates.remove(copied_node.k)
2047+
copied_node.leaf = False
2048+
for i in range(self.c_num_children_vec[copied_node.k]):
2049+
copied_node.children[i] = _Node(
2050+
copied_node.depth+1,
20432051
child_k_candidates,
2052+
ranges=np.array(copied_node.ranges),
20442053
)
2045-
self._copy_map_tree_recursion(copyed_node.children[i],original_node.children[i])
2046-
else:
2047-
copyed_node.sub_model = copy.deepcopy(original_node.sub_model)
2048-
copyed_node.leaf = True
2054+
if copied_node.thresholds is not None:
2055+
copied_node.children[i].ranges[copied_node.k,0] = copied_node.thresholds[i]
2056+
copied_node.children[i].ranges[copied_node.k,1] = copied_node.thresholds[i+1]
2057+
self._copy_map_tree_recursion(copied_node.children[i],original_node.children[i])
20492058

20502059
def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
20512060
"""Estimate the parameter under the given criterion.
@@ -2088,13 +2097,18 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
20882097
if prob > map_prob:
20892098
map_index = i
20902099
map_prob = prob
2091-
map_root = _Node(0,self.c_num_children,list(range(self.c_k)))
2100+
map_root = _Node(
2101+
0,
2102+
self._root_k_candidates,
2103+
self.hn_g,
2104+
ranges=self.c_ranges,
2105+
)
20922106
self._copy_map_tree_recursion(map_root,self.hn_metatree_list[map_index])
20932107
if visualize:
20942108
import graphviz
20952109
tree_graph = graphviz.Digraph(filename=filename,format=format)
20962110
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
2097-
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, 1.0)
2111+
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, None, 1.0)
20982112
tree_graph.view()
20992113
return {'root':map_root}
21002114
else:
@@ -2119,17 +2133,20 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,pare
21192133
label_string += f',{node.thresholds[i+1]:.2f}'
21202134
label_string += '}\\l'
21212135
label_string += f'hn_g={node.h_g:.2f}\\lp_v={tmp_p_v:.2f}\\lsub_params={{'
2122-
try:
2123-
sub_params = node.sub_model.estimate_params(loss='0-1',dict_out=True)
2124-
except:
2125-
sub_params = node.sub_model.estimate_params(dict_out=True)
2126-
2127-
for key,value in sub_params.items():
2136+
if node.sub_model is not None:
21282137
try:
2129-
label_string += f'\\l{key}:{value:.2f}'
2138+
sub_params = node.sub_model.estimate_params(loss='0-1',dict_out=True)
21302139
except:
2131-
label_string += f'\\l{key}:{value}'
2132-
label_string += '}'
2140+
sub_params = node.sub_model.estimate_params(dict_out=True)
2141+
2142+
for key,value in sub_params.items():
2143+
try:
2144+
label_string += f'\\l{key}:{value:.2f}'
2145+
except:
2146+
label_string += f'\\l{key}:{value}'
2147+
label_string += '}'
2148+
else:
2149+
label_string += '\\lNone}'
21332150

21342151
tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
21352152
if tmp_p_v > 0.65:

bayesml/metatree/metatree_test.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,8 @@
66
import copy
77

88
gen_model = metatree.GenModel(
9-
c_dim_continuous=2,
10-
c_dim_categorical=0,
9+
c_dim_continuous=0,
10+
c_dim_categorical=2,
1111
h_g=0.75,
1212
sub_h_params={'h_alpha':0.1,'h_beta':0.1})
1313
gen_model.gen_params(threshold_type='random')
@@ -16,10 +16,10 @@
1616
x_continuous,x_categorical,y = gen_model.gen_sample(100)
1717

1818
learn_model = metatree.LearnModel(
19-
c_dim_continuous=2,
20-
c_dim_categorical=0,
19+
c_dim_continuous=0,
20+
c_dim_categorical=2,
2121
c_num_children_vec=2,
2222
sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
2323
learn_model.update_posterior(x_continuous,x_categorical,y)
24-
learn_model.update_posterior(x_continuous,x_categorical,y,alg_type='given_MT')
2524
learn_model.visualize_posterior(filename='tree2.pdf')
25+
learn_model.estimate_params(filename='tree3.pdf')

0 commit comments

Comments
 (0)