Skip to content

Commit a6e6e0d

Browse files
committed
make_prediction and pred_and_update
1 parent df0f3f2 commit a6e6e0d

File tree

1 file changed

+56
-54
lines changed

1 file changed

+56
-54
lines changed

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,59 +1104,61 @@ def calc_pred_dist(self):
11041104
self.p_lambda_mats_inv[:] = np.linalg.inv(self.p_lambda_mats)
11051105

11061106
def make_prediction(self,loss="squared"):
1107-
pass
1108-
# """Predict a new data point under the given criterion.
1109-
1110-
# Parameters
1111-
# ----------
1112-
# loss : str, optional
1113-
# Loss function underlying the Bayes risk function, by default \"squared\".
1114-
# This function supports \"squared\" and \"0-1\".
1115-
1116-
# Returns
1117-
# -------
1118-
# Predicted_value : {float, numpy.ndarray}
1119-
# The predicted value under the given loss function.
1120-
# """
1121-
# if loss == "squared":
1122-
# return np.sum(self.p_pi_vec[:,np.newaxis] * self.p_mu_vecs, axis=0)
1123-
# elif loss == "0-1":
1124-
# tmp_max = -1.0
1125-
# tmp_argmax = np.empty([self.degree])
1126-
# for k in range(self.num_classes):
1127-
# val = ss_multivariate_t.pdf(x=self.p_mu_vecs[k],
1128-
# loc=self.p_mu_vecs[k],
1129-
# shape=self.p_lambda_mats_inv[k],
1130-
# df=self.p_nus[k])
1131-
# if val * self.p_pi_vec[k] > tmp_max:
1132-
# tmp_argmax[:] = self.p_mu_vecs[k]
1133-
# tmp_max = val * self.p_pi_vec[k]
1134-
# return tmp_argmax
1135-
# else:
1136-
# raise(CriteriaError("Unsupported loss function! "
1137-
# +"This function supports \"squared\", \"0-1\", and \"KL\"."))
1107+
"""Predict a new data point under the given criterion.
1108+
1109+
Parameters
1110+
----------
1111+
loss : str, optional
1112+
Loss function underlying the Bayes risk function, by default \"squared\".
1113+
This function supports \"squared\" and \"0-1\".
1114+
1115+
Returns
1116+
-------
1117+
Predicted_value : {float, numpy.ndarray}
1118+
The predicted value under the given loss function.
1119+
"""
1120+
if loss == "squared":
1121+
return np.sum(self.p_pi_vec[:,np.newaxis] * self.p_mu_vecs, axis=0)
1122+
elif loss == "0-1":
1123+
tmp_max = -1.0
1124+
tmp_argmax = np.empty([self.degree])
1125+
for k in range(self.num_classes):
1126+
val = ss_multivariate_t.pdf(x=self.p_mu_vecs[k],
1127+
loc=self.p_mu_vecs[k],
1128+
shape=self.p_lambda_mats_inv[k],
1129+
df=self.p_nus[k])
1130+
if val * self.p_pi_vec[k] > tmp_max:
1131+
tmp_argmax[:] = self.p_mu_vecs[k]
1132+
tmp_max = val * self.p_pi_vec[k]
1133+
return tmp_argmax
1134+
else:
1135+
raise(CriteriaError(f"loss={loss} is unsupported. "
1136+
+"This function supports \"squared\" and \"0-1\"."))
11381137

11391138
def pred_and_update(self,x,loss="squared"):
1140-
pass
1141-
# """Predict a new data point and update the posterior sequentially.
1142-
1143-
# Parameters
1144-
# ----------
1145-
# x : numpy.ndarray
1146-
# It must be a degree-dimensional vector
1147-
# loss : str, optional
1148-
# Loss function underlying the Bayes risk function, by default \"squared\".
1149-
# This function supports \"squared\", \"0-1\", and \"KL\".
1150-
1151-
# Returns
1152-
# -------
1153-
# Predicted_value : {float, numpy.ndarray}
1154-
# The predicted value under the given loss function.
1155-
# """
1156-
# _check.float_vec(x,'x',DataFormatError)
1157-
# if x.shape != (self.degree,):
1158-
# raise(DataFormatError(f"x must be a 1-dimensional float array whose size is degree: {self.degree}."))
1159-
# self.calc_pred_dist()
1160-
# prediction = self.make_prediction(loss=loss)
1161-
# self.update_posterior(x[np.newaxis,:])
1162-
# return prediction
1139+
"""Predict a new data point and update the posterior sequentially.
1140+
1141+
h0_params will be overwritten by current hn_params
1142+
before updating hn_params by x
1143+
1144+
Parameters
1145+
----------
1146+
x : numpy.ndarray
1147+
It must be a `degree`-dimensional vector
1148+
loss : str, optional
1149+
Loss function underlying the Bayes risk function, by default \"squared\".
1150+
This function supports \"squared\" and \"0-1\".
1151+
1152+
Returns
1153+
-------
1154+
Predicted_value : {float, numpy.ndarray}
1155+
The predicted value under the given loss function.
1156+
"""
1157+
_check.float_vec(x,'x',DataFormatError)
1158+
if x.shape != (self.degree,):
1159+
raise(DataFormatError(f"x must be a 1-dimensional float array whose size is degree: {self.degree}."))
1160+
self.calc_pred_dist()
1161+
prediction = self.make_prediction(loss=loss)
1162+
self.overwrite_h0_params()
1163+
self.update_posterior(x[np.newaxis,:])
1164+
return prediction

0 commit comments

Comments
 (0)