Skip to content

Commit 3e871b4

Browse files
committed
Bug fix of make_prediction when SubModel is categorical
1 parent 868a592 commit 3e871b4

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

bayesml/metatree/_metatree.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2416,13 +2416,17 @@ def make_prediction(self,loss="squared"):
24162416
The predicted value under the given loss function.
24172417
"""
24182418
if loss == "squared":
2419-
tmp_pred_vec = np.empty(len(self.hn_metatree_list))
2419+
if self.SubModel is categorical:
2420+
tmp_pred_vec = np.empty([len(self.hn_metatree_list),self.sub_constants['c_degree']])
2421+
else:
2422+
tmp_pred_vec = np.empty(len(self.hn_metatree_list))
24202423
for i,metatree in enumerate(self.hn_metatree_list):
24212424
tmp_pred_vec[i] = self._make_prediction_recursion_squared(metatree)
24222425
return self.hn_metatree_prob_vec @ tmp_pred_vec
24232426
elif loss == "0-1":
24242427
if self.SubModel in CLF_MODELS:
2425-
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2428+
degree = 2 if self.SubModel is bernoulli else self.sub_constants['c_degree']
2429+
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),degree])
24262430
for i,metatree in enumerate(self.hn_metatree_list):
24272431
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
24282432
return np.argmax(self.hn_metatree_prob_vec @ tmp_pred_dist_vec)
@@ -2431,7 +2435,8 @@ def make_prediction(self,loss="squared"):
24312435
+"only when self.SubModel is bernoulli or categorical."))
24322436
elif loss == "KL":
24332437
if self.SubModel in CLF_MODELS:
2434-
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2438+
degree = 2 if self.SubModel is bernoulli else self.sub_constants['c_degree']
2439+
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),degree])
24352440
for i,metatree in enumerate(self.hn_metatree_list):
24362441
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
24372442
return self.hn_metatree_prob_vec @ tmp_pred_dist_vec

0 commit comments

Comments
 (0)