@@ -2416,13 +2416,17 @@ def make_prediction(self,loss="squared"):
2416
2416
The predicted value under the given loss function.
2417
2417
"""
2418
2418
if loss == "squared" :
2419
- tmp_pred_vec = np .empty (len (self .hn_metatree_list ))
2419
+ if self .SubModel is categorical :
2420
+ tmp_pred_vec = np .empty ([len (self .hn_metatree_list ),self .sub_constants ['c_degree' ]])
2421
+ else :
2422
+ tmp_pred_vec = np .empty (len (self .hn_metatree_list ))
2420
2423
for i ,metatree in enumerate (self .hn_metatree_list ):
2421
2424
tmp_pred_vec [i ] = self ._make_prediction_recursion_squared (metatree )
2422
2425
return self .hn_metatree_prob_vec @ tmp_pred_vec
2423
2426
elif loss == "0-1" :
2424
2427
if self .SubModel in CLF_MODELS :
2425
- tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2428
+ degree = 2 if self .SubModel is bernoulli else self .sub_constants ['c_degree' ]
2429
+ tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),degree ])
2426
2430
for i ,metatree in enumerate (self .hn_metatree_list ):
2427
2431
tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2428
2432
return np .argmax (self .hn_metatree_prob_vec @ tmp_pred_dist_vec )
@@ -2431,7 +2435,8 @@ def make_prediction(self,loss="squared"):
2431
2435
+ "only when self.SubModel is bernoulli or categorical." ))
2432
2436
elif loss == "KL" :
2433
2437
if self .SubModel in CLF_MODELS :
2434
- tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2438
+ degree = 2 if self .SubModel is bernoulli else self .sub_constants ['c_degree' ]
2439
+ tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),degree ])
2435
2440
for i ,metatree in enumerate (self .hn_metatree_list ):
2436
2441
tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2437
2442
return self .hn_metatree_prob_vec @ tmp_pred_dist_vec
0 commit comments