Skip to content

Commit 0dbbac9

Browse files
authored
Gaussian mixture lower bounds (scikit-learn#28559)
1 parent 6d7ff73 commit 0dbbac9

File tree

5 files changed

+19
-0
lines changed

5 files changed

+19
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- Added an attribute `lower_bounds_` in the :class:`mixture.BaseMixture`
2+
class to save the list of lower bounds for each iteration, thereby providing
3+
insights into the convergence behavior of mixture models like
4+
:class:`mixture.GaussianMixture`.
5+
By :user:`Manideep Yenugula <myenugula>`

sklearn/mixture/_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def fit_predict(self, X, y=None):
224224
n_init = self.n_init if do_init else 1
225225

226226
max_lower_bound = -np.inf
227+
best_lower_bounds = []
227228
self.converged_ = False
228229

229230
random_state = check_random_state(self.random_state)
@@ -236,6 +237,7 @@ def fit_predict(self, X, y=None):
236237
self._initialize_parameters(X, random_state)
237238

238239
lower_bound = -np.inf if do_init else self.lower_bound_
240+
current_lower_bounds = []
239241

240242
if self.max_iter == 0:
241243
best_params = self._get_parameters()
@@ -248,6 +250,7 @@ def fit_predict(self, X, y=None):
248250
log_prob_norm, log_resp = self._e_step(X)
249251
self._m_step(X, log_resp)
250252
lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)
253+
current_lower_bounds.append(lower_bound)
251254

252255
change = lower_bound - prev_lower_bound
253256
self._print_verbose_msg_iter_end(n_iter, change)
@@ -262,6 +265,7 @@ def fit_predict(self, X, y=None):
262265
max_lower_bound = lower_bound
263266
best_params = self._get_parameters()
264267
best_n_iter = n_iter
268+
best_lower_bounds = current_lower_bounds
265269
self.converged_ = converged
266270

267271
# Should only warn about convergence if max_iter > 0, otherwise
@@ -280,6 +284,7 @@ def fit_predict(self, X, y=None):
280284
self._set_parameters(best_params)
281285
self.n_iter_ = best_n_iter
282286
self.lower_bound_ = max_lower_bound
287+
self.lower_bounds_ = best_lower_bounds
283288

284289
# Always do a final e-step to guarantee that the labels returned by
285290
# fit_predict(X) are always consistent with fit(X).predict(X)

sklearn/mixture/_bayesian_mixture.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ class BayesianGaussianMixture(BaseMixture):
254254
Lower bound value on the model evidence (of the training data) of the
255255
best fit of inference.
256256
257+
lower_bounds_ : array-like of shape (`n_iter_`,)
258+
The list of lower bound values on the model evidence from each iteration
259+
of the best fit of inference.
260+
257261
weight_concentration_prior_ : tuple or float
258262
The dirichlet concentration of each component on the weight
259263
distribution (Dirichlet). The type depends on

sklearn/mixture/_gaussian_mixture.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,10 @@ class GaussianMixture(BaseMixture):
669669
Lower bound value on the log-likelihood (of the training data with
670670
respect to the model) of the best fit of EM.
671671
672+
lower_bounds_ : array-like of shape (`n_iter_`,)
673+
The list of lower bound values on the log-likelihood from each
674+
iteration of the best fit of EM.
675+
672676
n_features_in_ : int
673677
Number of features seen during :term:`fit`.
674678

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,7 @@ def test_gaussian_mixture_setting_best_params():
12361236
"precisions_cholesky_",
12371237
"n_iter_",
12381238
"lower_bound_",
1239+
"lower_bounds_",
12391240
]:
12401241
assert hasattr(gmm, attr)
12411242

0 commit comments

Comments
 (0)