
Commit f0862f7

tguillemot authored and ogrisel committed
[MRG+1] Bayesian Gaussian Mixture (Integration of GSoC2015 -- second step) (scikit-learn#6651)
* Add the new BayesianGaussianMixture class. Add the test file for the BayesianGaussianMixture.
* Add the use of the cholesky decomposition of the precision matrix.
* Fix some bugs.
* Modification of GaussianMixture class. The purpose here is to prepare the integration of BayesianGaussianMixture.
* Fix comments.
* Modification of the Docstring.
* Add license and author.
* Fix pb typo of eq 10.64 and 10.62.
* Correct VBGMM bugs.
* Fix full version.
* Fix the precision normalisation pb.
* Fix all cov_type algo for BayesianGaussianMixture.
* Optimisation of spherical and diag computation.
* Code simplification.
* Check the Gaussian Mixture tests are ok.
* Add test.
* Add new tests for BayesianGaussianMixture and GaussianMixture.
* Add the bayesian_gaussian_example and the doc.
* Fix comments.
* Fix review comments and add license and author.
* Fix test compare covar type.
* Fix reviews.
* Fix tests.
* Fix review comments.
* Correct reviews.
* Fix travis pb.
* Fix circleci pb.
* Fix review comments.
* Fix typo.
* Fix comments. Add reg_covar and what's new.
* Fix comments.
* Fix comments.
* [ci skip] Correct legend.
1 parent e2648b1 commit f0862f7

File tree

12 files changed: +1274, -72 lines

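As a quick orientation before the per-file diffs, here is a minimal usage sketch of the estimator this commit introduces. It is a sketch only: it sticks to ``n_components``, ``max_iter`` and ``random_state``, since other constructor arguments used in this PR were renamed in later releases.

import numpy as np
from sklearn.mixture import BayesianGaussianMixture

# Two well-separated blobs; extra components should end up with near-zero weight.
rng = np.random.RandomState(0)
X = np.vstack([rng.randn(100, 2) + [-3, 0], rng.randn(100, 2) + [3, 0]])

bgmm = BayesianGaussianMixture(n_components=5, max_iter=200,
                               random_state=0).fit(X)
print(np.round(bgmm.weights_, 3))  # mixture weights learned by variational inference
labels = bgmm.predict(X)           # hard assignments to the most probable component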

doc/modules/classes.rst

Lines changed: 1 addition & 1 deletion
@@ -954,8 +954,8 @@ See the :ref:`metrics` section of the user guide for further details.
    :template: class.rst

    mixture.GaussianMixture
+   mixture.BayesianGaussianMixture
    mixture.DPGMM
-   mixture.VBGMM


 .. _multiclass_ref:

doc/modules/mixture.rst

Lines changed: 49 additions & 35 deletions
@@ -133,40 +133,13 @@ parameters to maximize the likelihood of the data given those
 assignments. Repeating this process is guaranteed to always converge
 to a local optimum.

-.. _vbgmm:
+.. _bgmm:

-VBGMM: variational Gaussian mixtures
-====================================
+Bayesian Gaussian Mixture
+=========================

-The :class:`VBGMM` object implements a variant of the Gaussian mixture
-model with :ref:`variational inference <variational_inference>` algorithms.
-
-Pros and cons of class :class:`VBGMM`: variational inference
-------------------------------------------------------------
-
-Pros
-.....
-
-:Regularization: due to the incorporation of prior information,
-   variational solutions have less pathological special cases than
-   expectation-maximization solutions. One can then use full
-   covariance matrices in high dimensions or in cases where some
-   components might be centered around a single point without
-   risking divergence.
-
-Cons
-.....
-
-:Bias: to regularize a model one has to add biases. The
-   variational algorithm will bias all the means towards the origin
-   (part of the prior information adds a "ghost point" in the origin
-   to every mixture component) and it will bias the covariances to
-   be more spherical. It will also, depending on the concentration
-   parameter, bias the cluster structure either towards uniformity
-   or towards a rich-get-richer scenario.
-
-:Hyperparameters: this algorithm needs an extra hyperparameter
-   that might need experimental tuning via cross-validation.
+The :class:`BayesianGaussianMixture` object implements a variant of the Gaussian
+mixture model with variational inference algorithms.

 .. _variational_inference:

@@ -175,7 +148,7 @@ Estimation algorithm: variational inference

 Variational inference is an extension of expectation-maximization that
 maximizes a lower bound on model evidence (including
-priors) instead of data likelihood. The principle behind
+priors) instead of data likelihood. The principle behind
 variational methods is the same as expectation-maximization (that is
 both are iterative algorithms that alternate between finding the
 probabilities for each point to be generated by each mixture and

@@ -188,13 +161,54 @@ much so as to render usage unpractical.

 Due to its Bayesian nature, the variational algorithm needs more
 hyper-parameters than expectation-maximization, the most
-important of these being the concentration parameter ``alpha``. Specifying
-a high value of alpha leads more often to uniformly-sized mixture
+important of these being the concentration parameter ``dirichlet_concentration_prior``. Specifying
+a high value of the Dirichlet concentration prior leads more often to uniformly-sized mixture
 components, while specifying small (between 0 and 1) values will lead
 to some mixture components getting almost all the points while most
 mixture components will be centered on just a few of the remaining
 points.

+.. figure:: ../auto_examples/mixture/images/sphx_glr_plot_bayesian_gaussian_mixture_001.png
+   :target: ../auto_examples/mixture/plot_bayesian_gaussian_mixture.html
+   :align: center
+   :scale: 50%
+
+.. topic:: Examples:
+
+   * See :ref:`plot_bayesian_gaussian_mixture.py` for a comparison of
+     the results of the ``BayesianGaussianMixture`` for different values
+     of the parameter ``dirichlet_concentration_prior``.
+
+Pros and cons of variational inference with :class:`BayesianGaussianMixture`
+-----------------------------------------------------------------------------
+
+Pros
+.....
+
+:Regularization: due to the incorporation of prior information,
+   variational solutions have less pathological special cases than
+   expectation-maximization solutions.
+
+:Automatic selection: when ``dirichlet_concentration_prior`` is small enough and
+   ``n_components`` is larger than what is found necessary by the model, the
+   Variational Bayesian mixture model has a natural tendency to set some mixture
+   weight values close to zero. This makes it possible to let the model choose a
+   suitable number of effective components automatically.
+
+Cons
+.....
+
+:Bias: to regularize a model one has to add biases. The
+   variational algorithm will bias all the means towards the origin
+   (part of the prior information adds a "ghost point" in the origin
+   to every mixture component) and it will bias the covariances to
+   be more spherical. It will also, depending on the concentration
+   parameter, bias the cluster structure either towards uniformity
+   or towards a rich-get-richer scenario.
+
+:Hyperparameters: this algorithm needs an extra hyperparameter
+   that might need experimental tuning via cross-validation.
+
 .. _dpgmm:

 DPGMM: Infinite Gaussian mixtures
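
To make the documented effect of the concentration prior concrete, here is a small sketch. It assumes the PR-era keyword ``dirichlet_concentration_prior``; released scikit-learn versions later renamed this parameter, so adapt the name to the installed version.

import numpy as np
from sklearn.mixture import BayesianGaussianMixture

rng = np.random.RandomState(1)
X = np.vstack([rng.randn(300, 2) * .5 + offset
               for offset in ([-2, 0], [0, 0], [2, 0])])

for prior in (1e-3, 1e3):
    bgmm = BayesianGaussianMixture(n_components=6,
                                   dirichlet_concentration_prior=prior,
                                   max_iter=500, random_state=1).fit(X)
    # A small prior drives unneeded mixture weights towards zero; a large one
    # keeps the weights closer to uniform.
    print("prior=%g -> %d components with non-negligible weight"
          % (prior, np.sum(bgmm.weights_ > 1e-2)))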

doc/whats_new.rst

Lines changed: 18 additions & 9 deletions
@@ -64,13 +64,13 @@ Model Selection Enhancements and API Changes

 - **Parameters ``n_folds`` and ``n_iter`` renamed to ``n_splits``**

-  Some parameter names have changed:
-  The ``n_folds`` parameter in :class:`model_selection.KFold`,
-  :class:`model_selection.LabelKFold`, and
+  Some parameter names have changed:
+  The ``n_folds`` parameter in :class:`model_selection.KFold`,
+  :class:`model_selection.LabelKFold`, and
   :class:`model_selection.StratifiedKFold` is now renamed to ``n_splits``.
   The ``n_iter`` parameter in :class:`model_selection.ShuffleSplit`,
-  :class:`model_selection.LabelShuffleSplit`,
-  and :class:`model_selection.StratifiedShuffleSplit` is now renamed
+  :class:`model_selection.LabelShuffleSplit`,
+  and :class:`model_selection.StratifiedShuffleSplit` is now renamed
   to ``n_splits``.


@@ -141,8 +141,8 @@ New features
   <https://github.com/scikit-learn/scikit-learn/pull/6954>`_) by `Nelson
   Liu`_

-  - Added new cross-validation splitter
-    :class:`model_selection.TimeSeriesSplit` to handle time series data.
+  - Added new cross-validation splitter
+    :class:`model_selection.TimeSeriesSplit` to handle time series data.
     (`#6586
     <https://github.com/scikit-learn/scikit-learn/pull/6586>`_) by `YenChen
     Lin`_

@@ -402,10 +402,19 @@ API changes summary
 - Access to public attributes ``.X_`` and ``.y_`` has been deprecated in
   :class:`isotonic.IsotonicRegression`. By `Jonathan Arfa`_.

+- The old :class:`VBGMM` is deprecated in favor of the new
+  :class:`BayesianGaussianMixture`. The new class solves the computational
+  problems of the old class and computes the Variational Bayesian Gaussian
+  mixture faster than before.
+  Ref :ref:`bgmm` for more information.
+  (`#6651 <https://github.com/scikit-learn/scikit-learn/pull/6651>`_) by
+  `Wei Xue`_ and `Thierry Guillemot`_.
+
 - The old :class:`GMM` is deprecated in favor of the new
   :class:`GaussianMixture`. The new class computes the Gaussian mixture
   faster than before and some of computational problems have been solved.
-  By `Wei Xue`_ and `Thierry Guillemot`_.
+  (`#6666 <https://github.com/scikit-learn/scikit-learn/pull/6666>`_) by
+  `Wei Xue`_ and `Thierry Guillemot`_.

 - The ``grid_scores_`` attribute of :class:`model_selection.GridSearchCV`
   and :class:`model_selection.RandomizedSearchCV` is deprecated in favor of

@@ -415,7 +424,7 @@ API changes summary
   `Raghav R V`_.

 - The parameters ``n_iter`` or ``n_folds`` in old CV splitters are replaced
-  by the new parameter ``n_splits`` since it can provide a consistent
+  by the new parameter ``n_splits`` since it can provide a consistent
   and unambiguous interface to represent the number of train-test splits.
   (`#7187 <https://github.com/scikit-learn/scikit-learn/pull/7187>`_)
   by `YenChen Lin`_.
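
For users hit by the deprecation above, a rough migration sketch follows. The parameter correspondence is approximate (``alpha`` in the old :class:`VBGMM` and the new Dirichlet concentration prior play similar roles but are not identical), and ``dirichlet_concentration_prior`` is the PR-era keyword name, so treat both as assumptions to check against the installed version.

from sklearn.mixture import VBGMM, BayesianGaussianMixture

# Deprecated estimator: concentration controlled by ``alpha``.
old = VBGMM(n_components=5, alpha=1.0, covariance_type='diag')

# Replacement introduced by this PR: concentration controlled by a
# Dirichlet concentration prior on the mixture weights.
new = BayesianGaussianMixture(n_components=5,
                              dirichlet_concentration_prior=1.0,
                              covariance_type='diag')
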
examples/mixture/plot_bayesian_gaussian_mixture.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
"""
======================================================
Bayesian Gaussian Mixture Concentration Prior Analysis
======================================================

Plot the resulting ellipsoids of a mixture of three Gaussians with
variational Bayesian Gaussian Mixture for three different values of the
prior of the Dirichlet concentration.

For all models, the Variational Bayesian Gaussian Mixture adapts its number of
mixture components automatically. The parameter `dirichlet_concentration_prior`
has a direct link with the resulting number of components. Specifying a high
value of `dirichlet_concentration_prior` leads more often to uniformly-sized
mixture components, while specifying small (under 0.1) values will lead to some
mixture components getting almost all the points while most mixture components
will be centered on just a few of the remaining points.
"""
# Author: Thierry Guillemot <thierry.guillemot.work@gmail.com>
# License: BSD 3 clause

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from sklearn.mixture import BayesianGaussianMixture

print(__doc__)


def plot_ellipses(ax, weights, means, covars):
    for n in range(means.shape[0]):
        # Eigendecomposition of the 2x2 covariance gives the ellipse axes.
        v, w = np.linalg.eigh(covars[n][:2, :2])
        u = w[0] / np.linalg.norm(w[0])
        angle = np.arctan2(u[1], u[0])
        angle = 180 * angle / np.pi  # convert to degrees
        v = 2 * np.sqrt(2) * np.sqrt(v)
        ell = mpl.patches.Ellipse(means[n, :2], v[0], v[1], 180 + angle)
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(weights[n])
        ax.add_artist(ell)


def plot_results(ax1, ax2, estimator, dirichlet_concentration_prior, X, y,
                 plot_title=False):
    estimator.dirichlet_concentration_prior = dirichlet_concentration_prior
    estimator.fit(X)
    ax1.set_title("Bayesian Gaussian Mixture for "
                  r"$dc_0=%.1e$" % dirichlet_concentration_prior)
    # ax1.axis('equal')
    ax1.scatter(X[:, 0], X[:, 1], s=5, marker='o', color=colors[y], alpha=0.8)
    ax1.set_xlim(-2., 2.)
    ax1.set_ylim(-3., 3.)
    ax1.set_xticks(())
    ax1.set_yticks(())
    plot_ellipses(ax1, estimator.weights_, estimator.means_,
                  estimator.covariances_)

    ax2.get_xaxis().set_tick_params(direction='out')
    ax2.yaxis.grid(True, alpha=0.7)
    for k, w in enumerate(estimator.weights_):
        ax2.bar(k - .45, w, width=0.9, color='royalblue', zorder=3)
        ax2.text(k, w + 0.007, "%.1f%%" % (w * 100.),
                 horizontalalignment='center')
    ax2.set_xlim(-.6, 2 * n_components - .4)
    ax2.set_ylim(0., 1.1)
    ax2.tick_params(axis='y', which='both', left='off',
                    right='off', labelleft='off')
    ax2.tick_params(axis='x', which='both', top='off')

    if plot_title:
        ax1.set_ylabel('Estimated Mixtures')
        ax2.set_ylabel('Weight of each component')

# Parameters
random_state = 2
n_components, n_features = 3, 2
colors = np.array(['mediumseagreen', 'royalblue', 'r', 'gold',
                   'orchid', 'indigo', 'darkcyan', 'tomato'])
dirichlet_concentration_prior = np.logspace(-3, 3, 3)
covars = np.array([[[.7, .0], [.0, .1]],
                   [[.5, .0], [.0, .1]],
                   [[.5, .0], [.0, .1]]])
samples = np.array([200, 500, 200])
means = np.array([[.0, -.70],
                  [.0, .0],
                  [.0, .70]])


# Here we put mean_precision_prior to 0.8 to minimize the influence of the
# prior for this dataset
estimator = BayesianGaussianMixture(n_components=2 * n_components,
                                    init_params='random', max_iter=1500,
                                    mean_precision_prior=.8, tol=1e-9,
                                    random_state=random_state)

# Generate data
rng = np.random.RandomState(random_state)
X = np.vstack([
    rng.multivariate_normal(means[j], covars[j], samples[j])
    for j in range(n_components)])
y = np.concatenate([j * np.ones(samples[j], dtype=int)
                    for j in range(n_components)])

# Plot Results
plt.figure(figsize=(4.7 * 3, 8))
plt.subplots_adjust(bottom=.04, top=0.95, hspace=.05, wspace=.05,
                    left=.03, right=.97)

gs = gridspec.GridSpec(3, len(dirichlet_concentration_prior))
for k, dc in enumerate(dirichlet_concentration_prior):
    plot_results(plt.subplot(gs[0:2, k]), plt.subplot(gs[2, k]),
                 estimator, dc, X, y, plot_title=k == 0)

plt.show()

sklearn/mixture/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,7 @@
 from .dpgmm import DPGMM, VBGMM

 from .gaussian_mixture import GaussianMixture
+from .bayesian_mixture import BayesianGaussianMixture


 __all__ = ['DPGMM',
@@ -17,4 +18,5 @@
            'distribute_covar_matrix_to_match_covariance_type',
            'log_multivariate_normal_density',
            'sample_gaussian',
-           'GaussianMixture']
+           'GaussianMixture',
+           'BayesianGaussianMixture']

sklearn/mixture/base.py

Lines changed: 7 additions & 4 deletions
@@ -237,7 +237,6 @@ def fit(self, X, y=None):

        return self

-   @abstractmethod
    def _e_step(self, X):
        """E step.

@@ -248,12 +247,14 @@ def _e_step(self, X):
        Returns
        -------
        log_prob_norm : array, shape (n_samples,)
-           log p(X)
+           Logarithm of the probability of each sample in X.

        log_responsibility : array, shape (n_samples, n_components)
-           logarithm of the responsibilities
+           Logarithm of the posterior probabilities (or responsibilities) of
+           the point of each sample in X.
        """
-       pass
+       log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
+       return np.mean(log_prob_norm), log_resp

    @abstractmethod
    def _m_step(self, X, log_resp):

@@ -264,6 +265,8 @@ def _m_step(self, X, log_resp):
        X : array-like, shape (n_samples, n_features)

        log_resp : array-like, shape (n_samples, n_components)
+           Logarithm of the posterior probabilities (or responsibilities) of
+           the point of each sample in X.
        """
        pass
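
For readers unfamiliar with the helper the shared ``_e_step`` now delegates to, here is a self-contained sketch (not the library code itself) of what ``_estimate_log_prob_resp`` computes: per-component weighted log densities are normalized with a log-sum-exp into log responsibilities. Names below are illustrative.

import numpy as np
from scipy.special import logsumexp  # scipy.misc.logsumexp in older SciPy releases


def estimate_log_prob_resp(weighted_log_prob):
    # weighted_log_prob[i, k] = log(weight_k) + log p(x_i | component k),
    # however the concrete mixture subclass chooses to compute it.
    log_prob_norm = logsumexp(weighted_log_prob, axis=1)         # log p(x_i)
    log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]  # log responsibilities
    return log_prob_norm, log_resp

# The shared ``_e_step`` above then returns (np.mean(log_prob_norm), log_resp).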
