Skip to content

Commit 5a8458d

Browse files
committed
Modify _given_MT
1 parent c2a9b00 commit 5a8458d

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

bayesml/metatree/_metatree.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,9 +1825,7 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
18251825
for i in range(n_estimators):
18261826
self._copy_tree_from_sklearn_tree(tmp_metatree_list[i],randomforest.estimators_[i].tree_, 0)
18271827

1828-
print(f'before: {len(tmp_metatree_list)}')
18291828
tmp_metatree_list,tmp_metatree_prob_vec = self._marge_metatrees(tmp_metatree_list,tmp_metatree_prob_vec)
1830-
print(f'after: {len(tmp_metatree_list)}')
18311829

18321830
log_metatree_posteriors = np.log(tmp_metatree_prob_vec)
18331831
if self.c_dim_continuous > 0 and self.c_dim_categorical > 0:
@@ -1846,13 +1844,18 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
18461844
tmp_metatree_prob_vec[:] /= tmp_metatree_prob_vec.sum()
18471845
return tmp_metatree_list,tmp_metatree_prob_vec
18481846

1849-
def _given_MT(self,x,y):
1847+
def _given_MT(self,x_continuous,x_categorical,y):
18501848
"""make metatrees
18511849
18521850
Parameters
18531851
----------
1854-
x : numpy ndarray
1855-
values of explanatory variables whose dtype is int
1852+
x_continuous : numpy ndarray, optional
1853+
2 dimensional float array whose size is ``(sample_size,c_dim_continuous)``,
1854+
by default None.
1855+
x_categorical : numpy ndarray, optional
1856+
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``,
1857+
by default None. Each element x[i,j] must satisfy
1858+
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
18561859
y : numpy ndarray
18571860
values of objective variable whose dtype may be int or float
18581861
@@ -1865,9 +1868,18 @@ def _given_MT(self,x,y):
18651868
if not self.hn_metatree_list:
18661869
raise(ParameterFormatError("given_MT is supported only when len(self.hn_metatree_list) > 0."))
18671870
log_metatree_posteriors = np.log(self.hn_metatree_prob_vec)
1868-
for i,metatree in enumerate(self.hn_metatree_list):
1869-
for j in range(x.shape[0]):
1870-
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x[j],y[j]))
1871+
if self.c_dim_continuous > 0 and self.c_dim_categorical > 0:
1872+
for i,metatree in enumerate(self.hn_metatree_list):
1873+
for j in range(y.shape[0]):
1874+
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],x_categorical[j],y[j]))
1875+
elif self.c_dim_continuous > 0:
1876+
for i,metatree in enumerate(self.hn_metatree_list):
1877+
for j in range(y.shape[0]):
1878+
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,x_continuous[j],None,y[j]))
1879+
else:
1880+
for i,metatree in enumerate(self.hn_metatree_list):
1881+
for j in range(y.shape[0]):
1882+
log_metatree_posteriors[i] += np.log(self._update_posterior_recursion(metatree,None,x_categorical[j],y[j]))
18711883
self.hn_metatree_prob_vec[:] = np.exp(log_metatree_posteriors - log_metatree_posteriors.max())
18721884
self.hn_metatree_prob_vec[:] /= self.hn_metatree_prob_vec.sum()
18731885
return self.hn_metatree_list,self.hn_metatree_prob_vec

bayesml/metatree/metatree_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
h_g=0.75,
1212
sub_h_params={'h_alpha':0.1,'h_beta':0.1})
1313
gen_model.gen_params(threshold_type='random')
14-
# gen_model.visualize_model(filename='tree.pdf')
14+
gen_model.visualize_model(filename='tree.pdf')
1515

16-
x_continuous,x_categorical,y = gen_model.gen_sample(1000)
16+
x_continuous,x_categorical,y = gen_model.gen_sample(100)
1717

1818
learn_model = metatree.LearnModel(
1919
c_dim_continuous=2,
2020
c_dim_categorical=0,
2121
c_num_children_vec=2,
2222
sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
23-
learn_model.update_posterior(x_continuous,x_categorical,y,n_estimators=1)
23+
learn_model.update_posterior(x_continuous,x_categorical,y)
24+
learn_model.update_posterior(x_continuous,x_categorical,y,alg_type='given_MT')
2425
learn_model.visualize_posterior(filename='tree2.pdf')

0 commit comments

Comments
 (0)