diff --git a/code/demo_v.ipynb b/code/demo_v.ipynb
new file mode 100644
index 0000000..f0e5dc3
--- /dev/null
+++ b/code/demo_v.ipynb
@@ -0,0 +1,363 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jUA8NsH6G4Td"
+ },
+ "source": [
+ "# Implicit reparametrization trick: demonstration"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a6KZPv2ZG4Te"
+ },
+ "source": [
+ "Hear we demonstrate usage of library **torch.distributions.implicit** that we implemented within group of four students at MIPT."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GKRLg4TVG4Tf"
+ },
+ "source": [
+ "### Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-fRLPlRTG4Tf"
+ },
+ "source": [
+ "The vanilla ELBO estimation in generative models like VAE uses reparametrization trick: a method of sampling random variables with low variance of the gradient of parameter distribution. However, this method is available only for the limited number of distributions. For instance, it cannot be applied to some important continuous standard distributions such as Dirichlet, Beta, Gamma and mixture of distributions.\n",
+ "\n",
+ "The implicit reparameterization trick (IRT), being a modification of standart reparametrization trick, is much more expressive and applicable to a wider class of distributions. For a demonstration of its application in VAE, see [vae_experiment.ipynb](https://github.com/intsystems/implicit-reparameterization-trick/blob/main/code/vae_experiment.ipynb).\n",
+ "\n",
+ "Let's define the problem more precisely. Suppose we would like to optimize the expectation $\\mathbb{E}_{q_{\\phi}(z)}[f(z)]$ of some function $f(z)$ w.r.t. the parameters $\\phi$ of the distribution. We will assume that there exists special standardization function $S_{\\phi}(z)$ that eliminates dependence on the distribution's parameters:\n",
+ "\n",
+ "$$\n",
+ "S_{\\phi}(z)=\\varepsilon \\sim q(\\varepsilon),\\quad z = S_{\\phi}^{-1}(\\varepsilon).\n",
+ "$$\n",
+ "\n",
+ "The table below outlines the optimization scheme when applying different reparameterization techniques to solve an optimization problem.\n",
+ "\n",
+ "\n",
+ "\n",
+ "For more mathematical details, please refer to the [blog post](https://habr.com/ru/articles/860686/) for a comprehensive description of these two reparameterization trick methods."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jvTWFGoRG4Tf"
+ },
+ "source": [
+ "### Installation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-uCIbKv_G4Tf"
+ },
+ "source": [
+ "To install IRT library use the following command:\n",
+ "\n",
+ "```\n",
+ "!git clone https://github.com/intsystems/implicit-reparameterization-trick.git\n",
+ "!pip install implicit-reparameterization-trick/src\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CxdX6MesG4Tf"
+ },
+ "source": [
+ "### Scope"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tn1U9QG0G4Tf"
+ },
+ "source": [
+ "We implemented several distributions with reparametrized sampling ability using implicit reparametrisation trick:\n",
+ "\n",
+ "- Gaussian normal distribution\n",
+ "\n",
+ "- Dirichlet distribution (Beta distrbutioin)\n",
+ "\n",
+ "- Gamma distrbutioin\n",
+ "\n",
+ "- Mixture of distributions\n",
+ "\n",
+ "- Student-t distrbutioin\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gpPxugOiG4Tf"
+ },
+ "source": [
+ "### Usage example"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LZNqMDGpG4Tf"
+ },
+ "source": [
+ "Here is an example of sampling from the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) distribution:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Import standard libraries\n",
+ "import sys\n",
+ "import os\n",
+ "import torch\n",
+ "from torch import nn"
+ ],
+ "metadata": {
+ "id": "irRTc6dHKSF3"
+ },
+ "execution_count": 63,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Import IRT library\n",
+ "!git clone https://github.com/intsystems/implicit-reparameterization-trick.git\n",
+ "!pip install implicit-reparameterization-trick/src"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "collapsed": true,
+ "id": "TJNhlmrG7EMJ",
+ "outputId": "8e38045d-5816-4e60-b8bc-8042827c0cd2"
+ },
+ "execution_count": 64,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "fatal: destination path 'implicit-reparameterization-trick' already exists and is not an empty directory.\n",
+ "Processing ./implicit-reparameterization-trick/src\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from irt==0.0.1) (2.5.1+cu121)\n",
+ "Requirement already satisfied: torchvision>=0.16.0 in /usr/local/lib/python3.10/dist-packages (from irt==0.0.1) (0.20.1+cu121)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->irt==0.0.1) (3.16.1)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->irt==0.0.1) (4.12.2)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->irt==0.0.1) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->irt==0.0.1) (3.1.4)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->irt==0.0.1) (2024.10.0)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->irt==0.0.1) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=2.1.0->irt==0.0.1) (1.3.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision>=0.16.0->irt==0.0.1) (1.26.4)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision>=0.16.0->irt==0.0.1) (11.0.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->irt==0.0.1) (3.0.2)\n",
+ "Building wheels for collected packages: irt\n",
+ " Building wheel for irt (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for irt: filename=irt-0.0.1-py3-none-any.whl size=8045 sha256=157a18dcbfe5fa3c389497575c660350cc60cd2e66108d385844a8cc104cf308\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-07cz0fkb/wheels/ba/1a/17/61d086780df8ffccaf66dc930d4af3090f18898673bd63f4ca\n",
+ "Successfully built irt\n",
+ "Installing collected packages: irt\n",
+ " Attempting uninstall: irt\n",
+ " Found existing installation: irt 0.0.1\n",
+ " Uninstalling irt-0.0.1:\n",
+ " Successfully uninstalled irt-0.0.1\n",
+ "Successfully installed irt-0.0.1\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 692
+ },
+ "id": "XQBX7iMWG4Tg",
+ "outputId": "b452fec2-315a-4a59-fff4-4cce91f6490f"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor([0.1060, 0.3866]) tensor([-0.0333, -0.2790])\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 65
+ }
+ ],
+ "source": [
+ "# Define parameters for distribution\n",
+ "alpha, beta = torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([2.0, 3.0], requires_grad=True)\n",
+ "# Define Gamma distribution from IRT library\n",
+ "z = irt.Gamma(alpha, beta).rsample()\n",
+ "# Calculate loss\n",
+ "loss = torch.mean(z ** 2)\n",
+ "# Backpropogate loss through distribution\n",
+ "loss.backward()\n",
+ "\n",
+ "print(alpha.grad, beta.grad)\n",
+ "torchviz.make_dot(loss, params = {'alpha': alpha, 'beta': beta})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cwlwbQegG4Tg"
+ },
+ "source": [
+ "Below is a demonstration of using [Beta](https://en.wikipedia.org/wiki/Beta_distribution) distribution to train a simple stochastic model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {
+ "id": "rLb0DCYnG4Tg"
+ },
+ "outputs": [],
+ "source": [
+ "# Define simple encoder for model\n",
+ "class SimpleEncoder(nn.Module):\n",
+ " def __init__(self, input_dim, hidden_dim):\n",
+ " super().__init__()\n",
+ " super(SimpleEncoder, self).__init__()\n",
+ " self.fc = nn.Linear(input_dim, hidden_dim)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.fc(x)\n",
+ "\n",
+ "# Define simple decoder for model\n",
+ "class SimpleDecoder(nn.Module):\n",
+ " def __init__(self, hidden_dim, output_dim):\n",
+ " super().__init__()\n",
+ " super(SimpleDecoder, self).__init__()\n",
+ " self.fc = nn.Linear(hidden_dim, output_dim)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.fc(x)\n",
+ "\n",
+ "# Define simple model\n",
+ "class SimpleModel(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.encoder = SimpleEncoder(20, 10)\n",
+ " self.decoder = SimpleEncoder(5, 20)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = nn.functional.softmax(self.encoder(x), dim=0)\n",
+ " alpha, beta = x.chunk(2, dim=0)\n",
+ " y = irt.Beta(alpha, beta).rsample()\n",
+ " y = self.decoder(y)\n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {
+ "id": "6L0djQVCG4Th"
+ },
+ "outputs": [],
+ "source": [
+ "model = SimpleModel()\n",
+ "x = torch.randn(20)\n",
+ "y = model(x)\n",
+ "loss = y.mean()\n",
+ "loss.backward()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WlD1oeQOG4Th"
+ },
+ "source": [
+ "### Development"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0KTd4OzEG4Th"
+ },
+ "source": [
+ "If you want to use the code of the IRT library in your project, it is essential to ensure that your implementation successfully passes all the necessary [tests](https://github.com/intsystems/implicit-reparameterization-trick/blob/main/code/run_unittest.py). "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bnTEpaGGG4Th"
+ },
+ "source": [
+ "### Documentation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a4ROycAmG4Th"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "For comprehensive information and guidance regarding the features and functionalities of IRT library, you can find the full documentation by following this [link](https://intsystems.github.io/implicit-reparameterization-trick/)."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.15"
+ },
+ "colab": {
+ "provenance": []
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file