Skip to content

Commit c68d531

Browse files
committed
Update _metatree.py
1 parent ddfdae2 commit c68d531

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

bayesml/metatree/_metatree.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ class GenModel(base.Generative):
109109
generated between ``c_ranges[k,0]`` and ``c_ranges[k,1]``.
110110
By default, [[-3,3],[-3,3],...,[-3,3]].
111111
SubModel : class, optional
112-
bernoulli, poisson, normal, or exponential,
113-
by default bernoulli
112+
bernoulli, categorical, poisson, normal, exponential,
113+
or linearregression, by default bernoulli
114114
sub_constants : dict, optional
115115
constants for self.SubModel.GenModel, by default {}
116116
root : metatree._Node, optional
@@ -202,8 +202,8 @@ def __init__(
202202

203203
if SubModel not in MODELS:
204204
raise(ParameterFormatError(
205-
"SubModel must be bernoulli, "
206-
+"poisson, normal, or exponential."
205+
"SubModel must be bernoulli, categorical"
206+
+"poisson, normal, exponential, or linearregression."
207207
))
208208
self.SubModel = SubModel
209209

@@ -1068,8 +1068,8 @@ class LearnModel(base.Posterior,base.PredictiveMixin):
10681068
generated between ``c_ranges[k,0]`` and ``c_ranges[k,1]``.
10691069
By default, [[-3,3],[-3,3],...,[-3,3]].
10701070
SubModel : class, optional
1071-
bernoulli, poisson, normal, or exponential,
1072-
by default bernoulli
1071+
bernoulli, categorical, poisson, normal, exponential,
1072+
or linearregression, by default bernoulli
10731073
sub_constants : dict, optional
10741074
constants for self.SubModel.LearnModel, by default {}
10751075
h0_k_weight_vec : numpy.ndarray, optional
@@ -1166,8 +1166,8 @@ def __init__(
11661166

11671167
if SubModel not in MODELS:
11681168
raise(ParameterFormatError(
1169-
"SubModel must be bernoulli, "
1170-
+"poisson, normal, or exponential."
1169+
"SubModel must be bernoulli, categorical"
1170+
+"poisson, normal, exponential, or linearregression."
11711171
))
11721172
self.SubModel = SubModel
11731173

@@ -2403,21 +2403,23 @@ def make_prediction(self,loss="squared"):
24032403
tmp_pred_vec[i] = self._make_prediction_recursion_squared(metatree)
24042404
return self.hn_metatree_prob_vec @ tmp_pred_vec
24052405
elif loss == "0-1":
2406-
if self.SubModel is not bernoulli:
2407-
raise(CriteriaError("Unsupported loss function! "
2408-
+"\"0-1\" is supported only when self.SubModel is bernoulli."))
2409-
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2410-
for i,metatree in enumerate(self.hn_metatree_list):
2411-
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
2412-
return np.argmax(self.hn_metatree_prob_vec @ tmp_pred_dist_vec)
2406+
if self.SubModel in CLF_MODELS:
2407+
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2408+
for i,metatree in enumerate(self.hn_metatree_list):
2409+
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
2410+
return np.argmax(self.hn_metatree_prob_vec @ tmp_pred_dist_vec)
2411+
else:
2412+
raise(CriteriaError("Unsupported loss function! \"0-1\" is supported "
2413+
+"only when self.SubModel is bernoulli or categorical."))
24132414
elif loss == "KL":
2414-
if self.SubModel is not bernoulli:
2415-
raise(CriteriaError("Unsupported loss function! "
2416-
+"\"KL\" is supported only when self.SubModel is bernoulli."))
2417-
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2418-
for i,metatree in enumerate(self.hn_metatree_list):
2419-
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
2420-
return self.hn_metatree_prob_vec @ tmp_pred_dist_vec
2415+
if self.SubModel in CLF_MODELS:
2416+
tmp_pred_dist_vec = np.empty([len(self.hn_metatree_list),2])
2417+
for i,metatree in enumerate(self.hn_metatree_list):
2418+
tmp_pred_dist_vec[i] = self._make_prediction_recursion_kl(metatree)
2419+
return self.hn_metatree_prob_vec @ tmp_pred_dist_vec
2420+
else:
2421+
raise(CriteriaError("Unsupported loss function! \"KL\" is supported "
2422+
+"only when self.SubModel is bernoulli or categorical."))
24212423
else:
24222424
raise(CriteriaError("Unsupported loss function! "
24232425
+"This function supports \"squared\", \"0-1\", and \"KL\"."))

0 commit comments

Comments
 (0)