Skip to content

Commit d7b2ea0

Browse files
committed
Modify calc_pred_dist
1 parent c107fd3 commit d7b2ea0

File tree

2 files changed

+104
-43
lines changed

2 files changed

+104
-43
lines changed

bayesml/metatree/_metatree.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -739,17 +739,17 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
739739
by default None.
740740
x_categorical : numpy ndarray, optional
741741
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``,
742-
by default None. Each element x[i,j] must satisfy
743-
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
742+
by default None. Each element x_categorical[i,j] must satisfy
743+
0 <= x_categorical[i,j] < self.c_num_children_vec[self.c_dim_continuous+j].
744744
745745
Returns
746746
-------
747747
x_continuous : numpy ndarray
748748
2 dimensional float array whose size is ``(sample_size,c_dim_continuous)``.
749749
x_categorical : numpy ndarray, optional
750750
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``.
751-
Each element x[i,j] must satisfies
752-
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
751+
Each element x_categorical[i,j] satisfies
752+
0 <= x_categorical[i,j] < self.c_num_children_vec[self.c_dim_continuous+j].
753753
y : numpy ndarray
754754
1 dimensional array whose size is ``sample_size``.
755755
"""
@@ -789,7 +789,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
789789
for i in range(self.c_dim_categorical):
790790
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
791791
raise(DataFormatError(
792-
f"x_categorical[{i}].max() must smaller than "
792+
f"x_categorical[:,{i}].max() must smaller than "
793793
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
794794
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
795795
else:
@@ -825,7 +825,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
825825
for i in range(self.c_dim_categorical):
826826
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
827827
raise(DataFormatError(
828-
f"x_categorical[{i}].max() must smaller than "
828+
f"x_categorical[:,{i}].max() must smaller than "
829829
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
830830
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
831831
else:
@@ -844,9 +844,9 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
844844
)
845845
x_categorical = x_categorical.reshape(-1,self.c_dim_categorical)
846846
for i in range(self.c_dim_categorical):
847-
if x_categorical[i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
847+
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
848848
raise(DataFormatError(
849-
f"x_categorical[{i}].max() must smaller than "
849+
f"x_categorical[:,{i}].max() must smaller than "
850850
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
851851
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
852852

@@ -887,8 +887,8 @@ def save_sample(self,filename,sample_size,x=None):
887887
by default None.
888888
x_categorical : numpy ndarray, optional
889889
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``,
890-
by default None. Each element x[i,j] must satisfy
891-
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
890+
by default None. Each element x_categorical[i,j] must satisfy
891+
0 <= x_categorical[i,j] < self.c_num_children_vec[self.c_dim_continuous+j].
892892
893893
See Also
894894
--------
@@ -1782,8 +1782,8 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
17821782
by default None.
17831783
x_categorical : numpy ndarray, optional
17841784
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``,
1785-
by default None. Each element x[i,j] must satisfy
1786-
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
1785+
by default None. Each element x_categorical[i,j] must satisfy
1786+
0 <= x_categorical[i,j] < self.c_num_children_vec[self.c_dim_continuous+j].
17871787
y : numpy ndarray
17881788
values of objective variable whose dtype may be int or float
17891789
n_estimators : int, optional
@@ -1843,8 +1843,8 @@ def _given_MT(self,x_continuous,x_categorical,y):
18431843
by default None.
18441844
x_categorical : numpy ndarray, optional
18451845
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``,
1846-
by default None. Each element x[i,j] must satisfy
1847-
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
1846+
by default None. Each element x_categorical[i,j] must satisfy
1847+
0 <= x_categorical[i,j] < self.c_num_children_vec[self.c_dim_continuous+j].
18481848
y : numpy ndarray
18491849
values of objective variable whose dtype may be int or float
18501850
@@ -1874,8 +1874,8 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
18741874
by default None.
18751875
x_categorical : numpy ndarray, optional
18761876
2 dimensional int array whose size is ``(sample_size,c_dim_categorical)``,
1877-
by default None. Each element x[i,j] must satisfy
1878-
0 <= x[i,j] < self.c_num_children_vec[self.c_dim_continuous+i].
1877+
by default None. Each element x_categorical[i,j] must satisfy
1878+
0 <= x_categorical[i,j] < self.c_num_children_vec[self.c_dim_continuous+j].
18791879
y : numpy ndarray
18801880
values of objective variable whose dtype may be int or float
18811881
alg_type : {'MTRF', 'given_MT'}, optional
@@ -1907,7 +1907,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19071907
for i in range(self.c_dim_categorical):
19081908
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
19091909
raise(DataFormatError(
1910-
f"x_categorical[{i}].max() must smaller than "
1910+
f"x_categorical[:,{i}].max() must smaller than "
19111911
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
19121912
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
19131913
x_categorical.reshape([-1,self.c_dim_categorical])
@@ -1950,7 +1950,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19501950
for i in range(self.c_dim_categorical):
19511951
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
19521952
raise(DataFormatError(
1953-
f"x_categorical[{i}].max() must smaller than "
1953+
f"x_categorical[:,{i}].max() must smaller than "
19541954
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
19551955
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
19561956
x_categorical.reshape([-1,self.c_dim_categorical])
@@ -2292,33 +2292,84 @@ def get_p_params(self):
22922292
return None
22932293

22942294
def _calc_pred_dist_leaf(self,node:_Node,x):
2295-
try:
2296-
node.sub_model.calc_pred_dist(x)
2297-
except:
2298-
node.sub_model.calc_pred_dist()
2295+
try:
2296+
node.sub_model.calc_pred_dist(x)
2297+
except:
2298+
node.sub_model.calc_pred_dist()
22992299

2300-
def _calc_pred_dist_recursion(self,node:_Node,x):
2301-
self._calc_pred_dist_leaf(node,x)
2300+
def _calc_pred_dist_recursion(self,node:_Node,x_continuous,x_categorical):
2301+
self._calc_pred_dist_leaf(node,x_continuous)
23022302
if not node.leaf: # inner node
2303-
self._calc_pred_dist_recursion(node.children[x[node.k]],x)
2303+
if node.k < self.c_dim_continuous:
2304+
for i in range(self.c_num_children_vec[node.k]):
2305+
if node.thresholds[i] < x_continuous[node.k] and x_continuous[node.k] < node.thresholds[i+1]:
2306+
index = i
2307+
break
2308+
else:
2309+
index = x_categorical[node.k-self.c_dim_continuous]
2310+
self._calc_pred_dist_recursion(node.children[index],x_continuous,x_categorical)
23042311

2305-
def calc_pred_dist(self,x_continuous,x_categorical):
2312+
def calc_pred_dist(self,x_continuous=None,x_categorical=None):
23062313
"""Calculate the parameters of the predictive distribution.
23072314
23082315
Parameters
23092316
----------
2310-
x : numpy ndarray
2311-
values of explanatory variables whose dtype is int
2317+
x_continuous : numpy ndarray, optional
2318+
A float vector whose length is ``self.c_dim_continuous``,
2319+
by default None.
2320+
x_categorical : numpy ndarray, optional
2321+
An int vector whose length is ``self.c_dim_categorical``,
2322+
by default None. Each element x_categorical[i] must satisfy
2323+
0 <= x_categorical[i] < self.c_num_children_vec[self.c_dim_continuous+i].
23122324
"""
2313-
pass
2314-
# _check.nonneg_int_vec(x,'x',DataFormatError)
2315-
# if x.shape[0] != self.c_k:
2316-
# raise(DataFormatError(f"x.shape[0] must equal to c_k:{self.c_k}"))
2317-
# if x.max() >= self.c_num_children:
2318-
# raise(DataFormatError(f"x.max() must smaller than c_num_children:{self.c_num_children}"))
2319-
# self._tmp_x[:] = x
2320-
# for root in self.hn_metatree_list:
2321-
# self._calc_pred_dist_recursion(root,self._tmp_x)
2325+
if self.c_dim_continuous > 0 and self.c_dim_categorical > 0:
2326+
_check.float_vec(x_continuous,'x_continuous',DataFormatError)
2327+
_check.shape_consistency(
2328+
x_continuous.shape[0],'x_continuous.shape[0]',
2329+
self.c_dim_continuous,'self.c_dim_continuous',
2330+
ParameterFormatError
2331+
)
2332+
_check.nonneg_int_vec(x_categorical,'x_categorical',DataFormatError)
2333+
_check.shape_consistency(
2334+
x_categorical.shape[0],'x_categorical.shape[0]',
2335+
self.c_dim_categorical,'self.c_dim_categorical',
2336+
ParameterFormatError
2337+
)
2338+
for i in range(self.c_dim_categorical):
2339+
if x_categorical[i] >= self.c_num_children_vec[self.c_dim_continuous+i]:
2340+
raise(DataFormatError(
2341+
f"x_categorical[{i}] must smaller than "
2342+
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
2343+
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
2344+
2345+
elif self.c_dim_continuous > 0:
2346+
_check.float_vec(x_continuous,'x_continuous',DataFormatError)
2347+
_check.shape_consistency(
2348+
x_continuous.shape[0],'x_continuous.shape[0]',
2349+
self.c_dim_continuous,'self.c_dim_continuous',
2350+
ParameterFormatError
2351+
)
2352+
x_categorical = np.empty(0) # dummy
2353+
2354+
elif self.c_dim_categorical > 0:
2355+
_check.nonneg_int_vec(x_categorical,'x_categorical',DataFormatError)
2356+
_check.shape_consistency(
2357+
x_categorical.shape[0],'x_categorical.shape[0]',
2358+
self.c_dim_categorical,'self.c_dim_categorical',
2359+
ParameterFormatError
2360+
)
2361+
for i in range(self.c_dim_categorical):
2362+
if x_categorical[i] >= self.c_num_children_vec[self.c_dim_continuous+i]:
2363+
raise(DataFormatError(
2364+
f"x_categorical[{i}] must smaller than "
2365+
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
2366+
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
2367+
x_continuous = np.empty(0) # dummy
2368+
2369+
self._tmp_x_continuous[:] = x_continuous
2370+
self._tmp_x_categorical[:] = x_categorical
2371+
for root in self.hn_metatree_list:
2372+
self._calc_pred_dist_recursion(root,self._tmp_x_continuous,self._tmp_x_categorical)
23222373

23232374
def _make_prediction_recursion_squared(self,node:_Node):
23242375
if node.leaf == False: # inner node

bayesml/metatree/metatree_test.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,31 @@
55
import numpy as np
66
import copy
77

8+
dim_continuous = 0
9+
dim_categorical = 2
10+
811
gen_model = metatree.GenModel(
9-
c_dim_continuous=0,
10-
c_dim_categorical=2,
12+
c_dim_continuous=dim_continuous,
13+
c_dim_categorical=dim_categorical,
1114
h_g=0.75,
12-
sub_h_params={'h_alpha':0.1,'h_beta':0.1})
15+
SubModel=normal,
16+
)
17+
# sub_h_params={'h_alpha':0.1,'h_beta':0.1})
1318
gen_model.gen_params(threshold_type='random')
1419
gen_model.visualize_model(filename='tree.pdf')
1520

1621
x_continuous,x_categorical,y = gen_model.gen_sample(100)
1722

1823
learn_model = metatree.LearnModel(
19-
c_dim_continuous=0,
20-
c_dim_categorical=2,
24+
c_dim_continuous=dim_continuous,
25+
c_dim_categorical=dim_categorical,
2126
c_num_children_vec=2,
22-
sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
27+
SubModel=normal,
28+
)
29+
# sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
2330
learn_model.update_posterior(x_continuous,x_categorical,y)
31+
learn_model.calc_pred_dist(
32+
np.zeros(dim_continuous,dtype=float),
33+
np.zeros(dim_categorical,dtype=int))
2434
learn_model.visualize_posterior(filename='tree2.pdf')
2535
learn_model.estimate_params(filename='tree3.pdf')

0 commit comments

Comments
 (0)