@@ -109,8 +109,8 @@ class GenModel(base.Generative):
109
109
generated between ``c_ranges[k,0]`` and ``c_ranges[k,1]``.
110
110
By default, [[-3,3],[-3,3],...,[-3,3]].
111
111
SubModel : class, optional
112
- bernoulli, poisson, normal, or exponential,
113
- by default bernoulli
112
+ bernoulli, categorical, poisson, normal, exponential,
113
+ or linearregression, by default bernoulli
114
114
sub_constants : dict, optional
115
115
constants for self.SubModel.GenModel, by default {}
116
116
root : metatree._Node, optional
@@ -202,8 +202,8 @@ def __init__(
202
202
203
203
if SubModel not in MODELS :
204
204
raise (ParameterFormatError (
205
- "SubModel must be bernoulli, "
206
- + "poisson, normal, or exponential ."
205
+ "SubModel must be bernoulli, categorical "
206
+ + "poisson, normal, exponential, or linearregression ."
207
207
))
208
208
self .SubModel = SubModel
209
209
@@ -1068,8 +1068,8 @@ class LearnModel(base.Posterior,base.PredictiveMixin):
1068
1068
generated between ``c_ranges[k,0]`` and ``c_ranges[k,1]``.
1069
1069
By default, [[-3,3],[-3,3],...,[-3,3]].
1070
1070
SubModel : class, optional
1071
- bernoulli, poisson, normal, or exponential,
1072
- by default bernoulli
1071
+ bernoulli, categorical, poisson, normal, exponential,
1072
+ or linearregression, by default bernoulli
1073
1073
sub_constants : dict, optional
1074
1074
constants for self.SubModel.LearnModel, by default {}
1075
1075
h0_k_weight_vec : numpy.ndarray, optional
@@ -1166,8 +1166,8 @@ def __init__(
1166
1166
1167
1167
if SubModel not in MODELS :
1168
1168
raise (ParameterFormatError (
1169
- "SubModel must be bernoulli, "
1170
- + "poisson, normal, or exponential ."
1169
+ "SubModel must be bernoulli, categorical "
1170
+ + "poisson, normal, exponential, or linearregression ."
1171
1171
))
1172
1172
self .SubModel = SubModel
1173
1173
@@ -2403,21 +2403,23 @@ def make_prediction(self,loss="squared"):
2403
2403
tmp_pred_vec [i ] = self ._make_prediction_recursion_squared (metatree )
2404
2404
return self .hn_metatree_prob_vec @ tmp_pred_vec
2405
2405
elif loss == "0-1" :
2406
- if self .SubModel is not bernoulli :
2407
- raise (CriteriaError ("Unsupported loss function! "
2408
- + "\" 0-1\" is supported only when self.SubModel is bernoulli." ))
2409
- tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2410
- for i ,metatree in enumerate (self .hn_metatree_list ):
2411
- tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2412
- return np .argmax (self .hn_metatree_prob_vec @ tmp_pred_dist_vec )
2406
+ if self .SubModel in CLF_MODELS :
2407
+ tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2408
+ for i ,metatree in enumerate (self .hn_metatree_list ):
2409
+ tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2410
+ return np .argmax (self .hn_metatree_prob_vec @ tmp_pred_dist_vec )
2411
+ else :
2412
+ raise (CriteriaError ("Unsupported loss function! \" 0-1\" is supported "
2413
+ + "only when self.SubModel is bernoulli or categorical." ))
2413
2414
elif loss == "KL" :
2414
- if self .SubModel is not bernoulli :
2415
- raise (CriteriaError ("Unsupported loss function! "
2416
- + "\" KL\" is supported only when self.SubModel is bernoulli." ))
2417
- tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2418
- for i ,metatree in enumerate (self .hn_metatree_list ):
2419
- tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2420
- return self .hn_metatree_prob_vec @ tmp_pred_dist_vec
2415
+ if self .SubModel in CLF_MODELS :
2416
+ tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2417
+ for i ,metatree in enumerate (self .hn_metatree_list ):
2418
+ tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2419
+ return self .hn_metatree_prob_vec @ tmp_pred_dist_vec
2420
+ else :
2421
+ raise (CriteriaError ("Unsupported loss function! \" KL\" is supported "
2422
+ + "only when self.SubModel is bernoulli or categorical." ))
2421
2423
else :
2422
2424
raise (CriteriaError ("Unsupported loss function! "
2423
2425
+ "This function supports \" squared\" , \" 0-1\" , and \" KL\" ." ))
0 commit comments