Skip to content

Commit 80e55cd

Browse files
Update README.md
1 parent 27bf05a commit 80e55cd

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

README.md

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,41 @@
1919

2020
This repository implements an educational project for the Bayesian Multimodeling course. It implements algorithms for sampling from various distributions, using the implicit reparameterization trick.
2121

22-
## Описание
23-
24-
В этом репозитории реализован учебный проект для курса Байесовское мультимоделирование. В нем реализуются алгоритмы сэмплирования из различных распределений, используя implicit reparametriation trick.
22+
## Scope
23+
We plan to implement the following distributions in our library:
24+
- Gaussian normal distribution (*)
25+
- Dirichlet distribution (Beta distributions)(\*)
26+
- Sampling from a mixture of distributions
27+
- Sampling from the Student's t-distribution (**) (\*)
28+
- Sampling from an arbitrary factorized distribution (***)
29+
30+
(\*) - this distribution is already implemented in torch using the explicit reparameterization trick, we will implement it for comparison
31+
32+
(\*\*) - this distribution is added as a backup, their inclusion is questionable
33+
34+
(\*\*\*) - this distribution is not very clear in implementation, its inclusion is questionable
35+
36+
## Stack
37+
38+
We plan to inherit from the torch.distribution.Distribution class, so we need to implement all the methods that are present in that class.
39+
40+
## Usage
41+
In this example, we demonstrate the application of our library using a Variational Autoencoder (VAE) model, where the latent layer is modified by a normal distribution.
42+
```
43+
>>> import torch.distributions.implicit as irt
44+
>>> params = Encoder(inputs)
45+
>>> gauss = irt.Normal(*params)
46+
>>> deviated = gauss.rsample()
47+
>>> outputs = Decoder(deviated)
48+
```
49+
In this example, we demonstrate the use of a mixture of distributions using our library.
50+
```
51+
>>> import irt
52+
>>> params = Encoder(inputs)
53+
>>> mix = irt.Mixture([irt.Normal(*params), irt.Dirichlet(*params)])
54+
>>> deviated = mix.rsample()
55+
>>> outputs = Decoder(deviated)
56+
```
2557

2658
## Links
2759
- [LinkReview](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/linkreview.md)

0 commit comments

Comments
 (0)