Skip to content

Commit 2bffa9f

Browse files
committed
Modify make_prediction
1 parent d7b2ea0 commit 2bffa9f

File tree

2 files changed

+63
-70
lines changed

2 files changed

+63
-70
lines changed

bayesml/metatree/_metatree.py

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,21 +1896,21 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
18961896
self.c_dim_continuous,'self.c_dim_continuous',
18971897
ParameterFormatError
18981898
)
1899-
x_continuous.reshape([-1,self.c_dim_continuous])
1899+
x_continuous = x_continuous.reshape([-1,self.c_dim_continuous])
19001900

19011901
_check.nonneg_int_vecs(x_categorical,'x_categorical',DataFormatError)
19021902
_check.shape_consistency(
19031903
x_categorical.shape[-1],'x_categorical.shape[-1]',
19041904
self.c_dim_categorical,'self.c_dim_categorical',
19051905
ParameterFormatError
19061906
)
1907+
x_categorical = x_categorical.reshape([-1,self.c_dim_categorical])
19071908
for i in range(self.c_dim_categorical):
19081909
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
19091910
raise(DataFormatError(
19101911
f"x_categorical[:,{i}].max() must smaller than "
19111912
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
19121913
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
1913-
x_categorical.reshape([-1,self.c_dim_categorical])
19141914

19151915
_check.shape_consistency(
19161916
x_continuous.shape[0],'x_continuous.shape[0]',
@@ -1930,7 +1930,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19301930
self.c_dim_continuous,'self.c_dim_continuous',
19311931
ParameterFormatError
19321932
)
1933-
x_continuous.reshape([-1,self.c_dim_continuous])
1933+
x_continuous = x_continuous.reshape([-1,self.c_dim_continuous])
19341934

19351935
_check.shape_consistency(
19361936
x_continuous.shape[0],'x_continuous.shape[0]',
@@ -1947,13 +1947,13 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19471947
self.c_dim_categorical,'self.c_dim_categorical',
19481948
ParameterFormatError
19491949
)
1950+
x_categorical = x_categorical.reshape([-1,self.c_dim_categorical])
19501951
for i in range(self.c_dim_categorical):
19511952
if x_categorical[:,i].max() >= self.c_num_children_vec[self.c_dim_continuous+i]:
19521953
raise(DataFormatError(
19531954
f"x_categorical[:,{i}].max() must smaller than "
19541955
+f"self.c_num_children_vec[{self.c_dim_continuous+i}]: "
19551956
+f"{self.c_num_children_vec[self.c_dim_continuous+i]}"))
1956-
x_categorical.reshape([-1,self.c_dim_categorical])
19571957

19581958
_check.shape_consistency(
19591959
x_categorical.shape[0],'x_categorical.shape[0]',
@@ -2372,52 +2372,41 @@ def calc_pred_dist(self,x_continuous=None,x_categorical=None):
23722372
self._calc_pred_dist_recursion(root,self._tmp_x_continuous,self._tmp_x_categorical)
23732373

23742374
def _make_prediction_recursion_squared(self,node:_Node):
2375-
if node.leaf == False: # inner node
2376-
return ((1 - node.h_g) * node.sub_model.make_prediction(loss='squared')
2377-
+ node.h_g * self._make_prediction_recursion_squared(node.children[self._tmp_x[node.k]]))
2378-
else: # leaf node
2375+
if node.leaf: # leaf node
23792376
return node.sub_model.make_prediction(loss='squared')
2377+
else: # inner node
2378+
if node.k < self.c_dim_continuous:
2379+
for i in range(self.c_num_children_vec[node.k]):
2380+
if node.thresholds[i] < self._tmp_x_continuous[node.k] and self._tmp_x_continuous[node.k] < node.thresholds[i+1]:
2381+
index = i
2382+
break
2383+
else:
2384+
index = self._tmp_x_categorical[node.k-self.c_dim_continuous]
2385+
return ((1 - node.h_g) * node.sub_model.make_prediction(loss='squared')
2386+
+ node.h_g * self._make_prediction_recursion_squared(node.children[index]))
23802387

2381-
def _make_prediction_leaf_01(self,node:_Node):
2382-
mode = node.sub_model.make_prediction(loss='0-1')
2383-
pred_dist = node.sub_model.make_prediction(loss='KL')
2384-
if type(pred_dist) is np.ndarray:
2385-
mode_prob = pred_dist[mode]
2386-
else:
2387-
try:
2388-
mode_prob = pred_dist.pdf(mode)
2389-
except:
2390-
try:
2391-
mode_prob = pred_dist.pmf(mode)
2392-
except:
2393-
mode_prob = None
2394-
# elif hasattr(pred_dist,'pdf'):
2395-
# mode_prob = pred_dist.pdf(mode)
2396-
# elif hasattr(pred_dist,'pmf'):
2397-
# mode_prob = pred_dist.pmf(mode)
2398-
# else:
2399-
# mode_prob = None
2400-
return mode, mode_prob
2401-
2402-
def _make_prediction_recursion_01(self,node:_Node):
2403-
if node.leaf == False: # inner node
2404-
mode1,mode_prob1 = self._make_prediction_leaf_01(node)
2405-
mode2,mode_prob2 = self._make_prediction_recursion_01(node.children[self._tmp_x[node.k]])
2406-
if (1 - node.h_g) * mode_prob1 > node.h_g * mode_prob2:
2407-
return mode1,mode_prob1
2388+
def _make_prediction_recursion_kl(self,node:_Node):
2389+
if node.leaf: # leaf node
2390+
return node.sub_model.make_prediction(loss='KL')
2391+
else: # inner node
2392+
if node.k < self.c_dim_continuous:
2393+
for i in range(self.c_num_children_vec[node.k]):
2394+
if node.thresholds[i] < self._tmp_x_continuous[node.k] and self._tmp_x_continuous[node.k] < node.thresholds[i+1]:
2395+
index = i
2396+
break
24082397
else:
2409-
return mode2,mode_prob2
2410-
else: # leaf node
2411-
return self._make_prediction_leaf_01(node)
2398+
index = self._tmp_x_categorical[node.k-self.c_dim_continuous]
2399+
return ((1 - node.h_g) * node.sub_model.make_prediction(loss='KL')
2400+
+ node.h_g * self._make_prediction_recursion_kl(node.children[index]))
24122401

2413-
def make_prediction(self,loss="0-1"):
2402+
def make_prediction(self,loss="squared"):
24142403
"""Predict a new data point under the given criterion.
24152404
24162405
Parameters
24172406
----------
24182407
loss : str, optional
2419-
Loss function underlying the Bayes risk function, by default \"0-1\".
2420-
This function supports \"squared\", \"0-1\".
2408+
Loss function underlying the Bayes risk function, by default \"squared\".
2409+
This function supports \"squared\", \"0-1\", and \"KL\".
24212410
24222411
Returns
24232412
-------
@@ -2430,39 +2419,49 @@ def make_prediction(self,loss="0-1"):
24302419
tmp_pred_vec[i] = self._make_prediction_recursion_squared(metatree)
24312420
return self.hn_metatree_prob_vec @ tmp_pred_vec
24322421
elif loss == "0-1":
2433-
tmp_mode = np.empty(len(self.hn_metatree_list))
2434-
tmp_mode_prob_vec = np.empty(len(self.hn_metatree_list))
2422+
if self.SubModel is not bernoulli:
2423+
raise(CriteriaError("Unsupported loss function! "
2424+
+"\"0-1\" is supported only when self.SubModel is bernoulli."))
2425+
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2426+
for i,metatree in enumerate(self.hn_metatree_list):
2427+
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
2428+
return np.argmax(self.hn_metatree_prob_vec @ tmp_pred_dist_vec)
2429+
elif loss == "KL":
2430+
if self.SubModel is not bernoulli:
2431+
raise(CriteriaError("Unsupported loss function! "
2432+
+"\"KL\" is supported only when self.SubModel is bernoulli."))
2433+
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
24352434
for i,metatree in enumerate(self.hn_metatree_list):
2436-
tmp_mode[i],tmp_mode_prob_vec[i] = self._make_prediction_recursion_01(metatree)
2437-
return tmp_mode[np.argmax(self.hn_metatree_prob_vec * tmp_mode_prob_vec)]
2435+
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
2436+
return self.hn_metatree_prob_vec @ tmp_pred_dist_vec
24382437
else:
24392438
raise(CriteriaError("Unsupported loss function! "
2440-
+"This function supports \"squared\" and \"0-1\"."))
2439+
+"This function supports \"squared\", \"0-1\", and \"KL\"."))
24412440

2442-
def pred_and_update(self,x,y,loss="0-1"):
2441+
def pred_and_update(self,x_continuous=None,x_categorical=None,y=None,loss="squared"):
24432442
"""Predict a new data point and update the posterior sequentially.
24442443
24452444
Parameters
24462445
----------
2447-
x : numpy.ndarray
2448-
It must be a degree-dimensional vector
2446+
x_continuous : numpy ndarray, optional
2447+
A float vector whose length is ``self.c_dim_continuous``,
2448+
by default None.
2449+
x_categorical : numpy ndarray, optional
2450+
A int vector whose length is ``self.c_dim_categorical``,
2451+
by default None. Each element x_categorical[i] must satisfy
2452+
0 <= x_categorical[i] < self.c_num_children_vec[self.c_dim_continuous+i].
24492453
y : numpy ndarray
24502454
values of objective variable whose dtype may be int or float
24512455
loss : str, optional
2452-
Loss function underlying the Bayes risk function, by default \"0-1\".
2456+
Loss function underlying the Bayes risk function, by default \"squared\".
24532457
This function supports \"squared\", \"0-1\", and \"KL\".
24542458
24552459
Returns
24562460
-------
24572461
predicted_value : {float, numpy.ndarray}
24582462
The predicted value under the given loss function.
24592463
"""
2460-
_check.nonneg_int_vec(x,'x',DataFormatError)
2461-
if x.shape[-1] != self.c_k:
2462-
raise(DataFormatError(f"x.shape[-1] must equal to c_k:{self.c_k}"))
2463-
if x.max() >= self.c_num_children:
2464-
raise(DataFormatError(f"x.max() must smaller than c_num_children:{self.c_num_children}"))
2465-
self.calc_pred_dist(x)
2464+
self.calc_pred_dist(x_continuous,x_categorical)
24662465
prediction = self.make_prediction(loss=loss)
2467-
self.update_posterior(x,y,alg_type='given_MT')
2466+
self.update_posterior(x_continuous,x_categorical,y,alg_type='given_MT')
24682467
return prediction

bayesml/metatree/metatree_test.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,18 @@
1212
c_dim_continuous=dim_continuous,
1313
c_dim_categorical=dim_categorical,
1414
h_g=0.75,
15-
SubModel=normal,
16-
)
17-
# sub_h_params={'h_alpha':0.1,'h_beta':0.1})
18-
gen_model.gen_params(threshold_type='random')
15+
sub_h_params={'h_alpha':0.1,'h_beta':0.1})
16+
gen_model.gen_params(threshold_type='even')
1917
gen_model.visualize_model(filename='tree.pdf')
2018

21-
x_continuous,x_categorical,y = gen_model.gen_sample(100)
19+
x_continuous,x_categorical,y = gen_model.gen_sample(200)
20+
x_continuous_test,x_categorical_test,y_test = gen_model.gen_sample(10)
2221

2322
learn_model = metatree.LearnModel(
2423
c_dim_continuous=dim_continuous,
2524
c_dim_categorical=dim_categorical,
2625
c_num_children_vec=2,
27-
SubModel=normal,
28-
)
29-
# sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
26+
sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
3027
learn_model.update_posterior(x_continuous,x_categorical,y)
31-
learn_model.calc_pred_dist(
32-
np.zeros(dim_continuous,dtype=float),
33-
np.zeros(dim_categorical,dtype=int))
34-
learn_model.visualize_posterior(filename='tree2.pdf')
35-
learn_model.estimate_params(filename='tree3.pdf')
28+
for i in range(10):
29+
print(learn_model.pred_and_update(x_continuous_test[i],x_categorical_test[i],y_test[i],loss='0-1'))

0 commit comments

Comments (0)