Commit 8881934

Add more info to README.md
1 parent 50f4f75 commit 8881934


README.md

Lines changed: 23 additions & 3 deletions
@@ -17,6 +17,8 @@ A minimal implementation of Gaussian Mixture Models in Jax

## Installation

+`gmmx` can be installed via pip:
+
```bash
pip install gmmx
```
@@ -29,26 +31,44 @@ from gmmx import GaussianMixtureModelJax, EMFitter
# Create a Gaussian Mixture Model with 16 components and 32 features
gmm = GaussianMixtureModelJax.create(n_components=16, n_features=32)

+# Draw samples from the model
n_samples = 10_000
x = gmm.sample(n_samples)

# Fit the model to the data
-em_fitter = EMFitter()
-gmm_fitted = em_fitter.fit(gmm, x)
+em_fitter = EMFitter(tol=1e-3, max_iter=100)
+gmm_fitted = em_fitter.fit(x=x, gmm=gmm)
```

+## Why Gaussian Mixture Models?
+
+What are Gaussian Mixture Models (GMMs) useful for in the age of deep learning? GMMs may have gone out of fashion for classification tasks, but they still have a few properties that make them useful in certain scenarios:
+
+- They are universal approximators, meaning that, given enough components, they can approximate any distribution.
+- Their likelihood can be evaluated in closed form, which makes them useful for generative modeling.
+- They are rather fast to train and evaluate.
+
+One such application is image reconstruction, where GMMs can be used to model the distribution and pixel correlations of local (patch-based) image features. This is useful for tasks like image denoising or inpainting; one method where I have used them is [Jolideco](https://github.com/jolideco/jolideco). Speeding up training on O(10^6) patches was the main motivation for `gmmx`.
+
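To make the closed-form likelihood in the list above concrete, this is the standard mixture density it refers to (a textbook formula, not anything specific to the `gmmx` API): for a model with $K$ components, mixing weights $\pi_k$ (with $\sum_k \pi_k = 1$), means $\mu_k$, and covariances $\Sigma_k$,

```math
p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k),
\qquad
\log p(x_1, \ldots, x_N) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k).
```

Both expressions can be evaluated exactly for any input, which is what makes GMMs convenient generative models; EM-style fitting, as the name `EMFitter` suggests, does not decrease this log-likelihood from one iteration to the next.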
## Benchmarks

Here are some results from the benchmarks in the `benchmarks` folder, comparing against Scikit-Learn. The benchmarks were run on a 2021 MacBook Pro with an M1 Pro chip.

-### Prediction Time
+### Prediction

| Time vs. Number of Components | Time vs. Number of Samples | Time vs. Number of Features |
| ------------------------------------------------------------------------------- | ------------------------------------------------------------------------- | --------------------------------------------------------------------------- |
| ![Time vs. Number of Components](docs/_static/time-vs-n-components-predict.png) | ![Time vs. Number of Samples](docs/_static/time-vs-n-samples-predict.png) | ![Time vs. Number of Features](docs/_static/time-vs-n-features-predict.png) |

+For prediction, the speedup is around 2x when varying the number of components or features. When varying the number of samples, the cross-over point is around O(10^4) samples.
+
### Training Time

| Time vs. Number of Components | Time vs. Number of Samples | Time vs. Number of Features |
| --------------------------------------------------------------------------- | --------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| ![Time vs. Number of Components](docs/_static/time-vs-n-components-fit.png) | ![Time vs. Number of Samples](docs/_static/time-vs-n-samples-fit.png) | ![Time vs. Number of Features](docs/_static/time-vs-n-features-fit.png) |
+
+For training, the speedup is around 10x on the same architecture. However, there is no guarantee that it converges to exactly the same solution as Scikit-Learn; there are tests in the `tests` folder that compare the results of the two implementations.
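To illustrate the kind of comparison the benchmarks and tests perform, here is a minimal timing sketch against Scikit-Learn's `GaussianMixture`. It reuses only the `gmmx` calls shown in the usage example above; the sample size, tolerance, and iteration cap are illustrative choices, and the actual scripts in the `benchmarks` folder may be set up differently:

```python
import time

import numpy as np
from sklearn.mixture import GaussianMixture

from gmmx import EMFitter, GaussianMixtureModelJax

# Synthetic data drawn from a reference model, as in the usage example above
gmm = GaussianMixtureModelJax.create(n_components=16, n_features=32)
x = gmm.sample(100_000)  # illustrative sample size

# Time the gmmx fit; a careful benchmark would also run a warm-up fit first
# so that JAX JIT compilation is not included in the measurement.
t0 = time.perf_counter()
gmm_fitted = EMFitter(tol=1e-3, max_iter=100).fit(x=x, gmm=gmm)
t_gmmx = time.perf_counter() - t0

# Time the Scikit-Learn fit on the same data, converted to a NumPy array
x_np = np.asarray(x)
t0 = time.perf_counter()
gmm_sklearn = GaussianMixture(n_components=16, tol=1e-3, max_iter=100).fit(x_np)
t_sklearn = time.perf_counter() - t0

print(f"gmmx: {t_gmmx:.2f} s, scikit-learn: {t_sklearn:.2f} s")
```

The result comparison done in the `tests` folder (checking that the fitted parameters of the two implementations agree) is not shown here, since it depends on the fitted model's attributes.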
