Skip to content

Add files via upload #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 329 additions & 0 deletions code/Demo_version.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "jUA8NsH6G4Td"
},
"source": [
"# Implicit reparametrization trick: demonstration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a6KZPv2ZG4Te"
},
"source": [
"Hear we demonstrate usage of library **torch.distributions.implicit** that we implemented within group of four students at MIPT."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GKRLg4TVG4Tf"
},
"source": [
"### Introduction"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-fRLPlRTG4Tf"
},
"source": [
"The vanilla ELBO estimation in generative models like VAE uses reparametrization trick: a method of sampling random variables with low variance of the gradient of parameter distribution. However, this method is available only for the limited number of distributions. For instance, it cannot be applied to some important continuous standard distributions such as Dirichlet, Beta, Gamma and mixture of distributions.\n",
"\n",
"The implicit reparameterization trick (IRT), being a modification of standart reparametrization trick, is much more expressive and applicable to a wider class of distributions. For a demonstration of its application in VAE, see [vae_experiment.ipynb](https://github.com/intsystems/implicit-reparameterization-trick/blob/main/code/vae_experiment.ipynb).\n",
"\n",
"For more mathematical details, please refer to the **[blog post]** for a comprehensive description of these two reparameterization trick methods.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvTWFGoRG4Tf"
},
"source": [
"### Installation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-uCIbKv_G4Tf"
},
"source": [
"To install library use the following command:\n",
"\n",
"```\n",
"pip install ..\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CxdX6MesG4Tf"
},
"source": [
"### Scope"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tn1U9QG0G4Tf"
},
"source": [
"We implemented several distributions with reparametrized sampling ability using implicit reparametrisation trick:\n",
"\n",
"- Gaussian normal distribution\n",
"\n",
"- Dirichlet distribution (Beta distrbutioin)\n",
"\n",
"- Gamma distrbutioin\n",
"\n",
"- Mixture of distributions\n",
"\n",
"- Student-t distrbutioin\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gpPxugOiG4Tf"
},
"source": [
"### Usage example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LZNqMDGpG4Tf"
},
"source": [
"Here is an example of sampling from the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) distribution:"
]
},
{
"cell_type": "code",
"source": [
"# Import standard libraries\n",
"import sys\n",
"import os\n",
"import torch\n",
"from torch import nn"
],
"metadata": {
"id": "irRTc6dHKSF3"
},
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Import IRT library\n",
"!git clone https://github.com/intsystems/implicit-reparameterization-trick.git /tmp/implicit-reparameterization-trick\n",
"sys.path.append(os.path.abspath('/tmp/implicit-reparameterization-trick/src/irt'))\n",
"import distributions as irt"
],
"metadata": {
"collapsed": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "N7qeViYtIF7A",
"outputId": "b975645b-d35b-4fba-8463-b4e1447a3723"
},
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"fatal: destination path '/tmp/implicit-reparameterization-trick' already exists and is not an empty directory.\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 692
},
"id": "XQBX7iMWG4Tg",
"outputId": "6da1b970-6972-48aa-92dd-05729fcd2e33"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([0.2359, 0.8033]) tensor([-0.0937, -0.7642])\n"
]
},
{
"output_type": "execute_result",
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.43.0 (0)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"273pt\" height=\"490pt\"\n viewBox=\"0.00 0.00 273.00 490.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 486)\">\n<title>%3</title>\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-486 269,-486 269,4 -4,4\"/>\n<!-- 136127602221968 -->\n<g id=\"node1\" class=\"node\">\n<title>136127602221968</title>\n<polygon fill=\"#caff70\" stroke=\"black\" points=\"169.5,-31 115.5,-31 115.5,0 169.5,0 169.5,-31\"/>\n<text text-anchor=\"middle\" x=\"142.5\" y=\"-7\" font-family=\"monospace\" font-size=\"10.00\"> ()</text>\n</g>\n<!-- 136127602963744 -->\n<g id=\"node2\" class=\"node\">\n<title>136127602963744</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"190,-86 95,-86 95,-67 190,-67 190,-86\"/>\n<text text-anchor=\"middle\" x=\"142.5\" y=\"-74\" font-family=\"monospace\" font-size=\"10.00\">MeanBackward0</text>\n</g>\n<!-- 136127602963744&#45;&gt;136127602221968 -->\n<g id=\"edge13\" class=\"edge\">\n<title>136127602963744&#45;&gt;136127602221968</title>\n<path fill=\"none\" stroke=\"black\" d=\"M142.5,-66.79C142.5,-60.07 142.5,-50.4 142.5,-41.34\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"146,-41.19 142.5,-31.19 139,-41.19 146,-41.19\"/>\n</g>\n<!-- 136127602956064 -->\n<g id=\"node3\" class=\"node\">\n<title>136127602956064</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"187,-141 98,-141 98,-122 187,-122 187,-141\"/>\n<text text-anchor=\"middle\" x=\"142.5\" y=\"-129\" font-family=\"monospace\" font-size=\"10.00\">PowBackward0</text>\n</g>\n<!-- 136127602956064&#45;&gt;136127602963744 -->\n<g id=\"edge1\" class=\"edge\">\n<title>136127602956064&#45;&gt;136127602963744</title>\n<path fill=\"none\" stroke=\"black\" d=\"M142.5,-121.75C142.5,-114.8 142.5,-104.85 142.5,-96.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"146,-96.09 142.5,-86.09 139,-96.09 146,-96.09\"/>\n</g>\n<!-- 136127602956352 -->\n<g id=\"node4\" class=\"node\">\n<title>136127602956352</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"187,-196 98,-196 98,-177 187,-177 187,-196\"/>\n<text text-anchor=\"middle\" x=\"142.5\" y=\"-184\" font-family=\"monospace\" font-size=\"10.00\">AddBackward0</text>\n</g>\n<!-- 136127602956352&#45;&gt;136127602956064 -->\n<g id=\"edge2\" class=\"edge\">\n<title>136127602956352&#45;&gt;136127602956064</title>\n<path fill=\"none\" stroke=\"black\" d=\"M142.5,-176.75C142.5,-169.8 142.5,-159.85 142.5,-151.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"146,-151.09 142.5,-141.09 139,-151.09 146,-151.09\"/>\n</g>\n<!-- 136127602956304 -->\n<g id=\"node5\" class=\"node\">\n<title>136127602956304</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"134,-251 45,-251 45,-232 134,-232 134,-251\"/>\n<text text-anchor=\"middle\" x=\"89.5\" y=\"-239\" font-family=\"monospace\" font-size=\"10.00\">DivBackward0</text>\n</g>\n<!-- 136127602956304&#45;&gt;136127602956352 -->\n<g id=\"edge3\" class=\"edge\">\n<title>136127602956304&#45;&gt;136127602956352</title>\n<path fill=\"none\" stroke=\"black\" d=\"M98.25,-231.75C105.97,-224.03 117.4,-212.6 126.72,-203.28\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"129.31,-205.64 133.91,-196.09 124.36,-200.69 129.31,-205.64\"/>\n</g>\n<!-- 136127602950352 -->\n<g id=\"node6\" class=\"node\">\n<title>136127602950352</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"149,-306 0,-306 0,-287 149,-287 149,-306\"/>\n<text text-anchor=\"middle\" x=\"74.5\" y=\"-294\" font-family=\"monospace\" font-size=\"10.00\">StandardGammaBackward0</text>\n</g>\n<!-- 136127602950352&#45;&gt;136127602956304 -->\n<g id=\"edge4\" class=\"edge\">\n<title>136127602950352&#45;&gt;136127602956304</title>\n<path fill=\"none\" stroke=\"black\" d=\"M76.98,-286.75C78.96,-279.72 81.82,-269.62 84.31,-260.84\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"87.71,-261.66 87.07,-251.09 80.98,-259.76 87.71,-261.66\"/>\n</g>\n<!-- 136127602958368 -->\n<g id=\"node7\" class=\"node\">\n<title>136127602958368</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"128,-361 21,-361 21,-342 128,-342 128,-361\"/>\n<text text-anchor=\"middle\" x=\"74.5\" y=\"-349\" font-family=\"monospace\" font-size=\"10.00\">ExpandBackward0</text>\n</g>\n<!-- 136127602958368&#45;&gt;136127602950352 -->\n<g id=\"edge5\" class=\"edge\">\n<title>136127602958368&#45;&gt;136127602950352</title>\n<path fill=\"none\" stroke=\"black\" d=\"M74.5,-341.75C74.5,-334.8 74.5,-324.85 74.5,-316.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"78,-316.09 74.5,-306.09 71,-316.09 78,-316.09\"/>\n</g>\n<!-- 136127602958848 -->\n<g id=\"node8\" class=\"node\">\n<title>136127602958848</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"125,-416 24,-416 24,-397 125,-397 125,-416\"/>\n<text text-anchor=\"middle\" x=\"74.5\" y=\"-404\" font-family=\"monospace\" font-size=\"10.00\">AccumulateGrad</text>\n</g>\n<!-- 136127602958848&#45;&gt;136127602958368 -->\n<g id=\"edge6\" class=\"edge\">\n<title>136127602958848&#45;&gt;136127602958368</title>\n<path fill=\"none\" stroke=\"black\" d=\"M74.5,-396.75C74.5,-389.8 74.5,-379.85 74.5,-371.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"78,-371.09 74.5,-361.09 71,-371.09 78,-371.09\"/>\n</g>\n<!-- 136127602052208 -->\n<g id=\"node9\" class=\"node\">\n<title>136127602052208</title>\n<polygon fill=\"lightblue\" stroke=\"black\" points=\"101.5,-482 47.5,-482 47.5,-452 101.5,-452 101.5,-482\"/>\n<text text-anchor=\"middle\" x=\"74.5\" y=\"-470\" font-family=\"monospace\" font-size=\"10.00\">alpha</text>\n<text text-anchor=\"middle\" x=\"74.5\" y=\"-459\" font-family=\"monospace\" font-size=\"10.00\"> (2)</text>\n</g>\n<!-- 136127602052208&#45;&gt;136127602958848 -->\n<g id=\"edge7\" class=\"edge\">\n<title>136127602052208&#45;&gt;136127602958848</title>\n<path fill=\"none\" stroke=\"black\" d=\"M74.5,-451.84C74.5,-444.21 74.5,-434.7 74.5,-426.45\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"78,-426.27 74.5,-416.27 71,-426.27 78,-426.27\"/>\n</g>\n<!-- 136127602959616 -->\n<g id=\"node10\" class=\"node\">\n<title>136127602959616</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"248,-251 159,-251 159,-232 248,-232 248,-251\"/>\n<text text-anchor=\"middle\" x=\"203.5\" y=\"-239\" font-family=\"monospace\" font-size=\"10.00\">SubBackward0</text>\n</g>\n<!-- 136127602959616&#45;&gt;136127602956352 -->\n<g id=\"edge8\" class=\"edge\">\n<title>136127602959616&#45;&gt;136127602956352</title>\n<path fill=\"none\" stroke=\"black\" d=\"M193.7,-231.98C184.68,-224.15 171.09,-212.34 160.18,-202.86\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"162.32,-200.09 152.48,-196.17 157.73,-205.37 162.32,-200.09\"/>\n</g>\n<!-- 136127602957312 -->\n<g id=\"node11\" class=\"node\">\n<title>136127602957312</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"256,-306 167,-306 167,-287 256,-287 256,-306\"/>\n<text text-anchor=\"middle\" x=\"211.5\" y=\"-294\" font-family=\"monospace\" font-size=\"10.00\">DivBackward0</text>\n</g>\n<!-- 136127602957312&#45;&gt;136127602959616 -->\n<g id=\"edge9\" class=\"edge\">\n<title>136127602957312&#45;&gt;136127602959616</title>\n<path fill=\"none\" stroke=\"black\" d=\"M210.18,-286.75C209.13,-279.8 207.63,-269.85 206.31,-261.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"209.75,-260.45 204.8,-251.09 202.83,-261.5 209.75,-260.45\"/>\n</g>\n<!-- 136127602959376 -->\n<g id=\"node12\" class=\"node\">\n<title>136127602959376</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"265,-361 158,-361 158,-342 265,-342 265,-361\"/>\n<text text-anchor=\"middle\" x=\"211.5\" y=\"-349\" font-family=\"monospace\" font-size=\"10.00\">ExpandBackward0</text>\n</g>\n<!-- 136127602959376&#45;&gt;136127602957312 -->\n<g id=\"edge10\" class=\"edge\">\n<title>136127602959376&#45;&gt;136127602957312</title>\n<path fill=\"none\" stroke=\"black\" d=\"M211.5,-341.75C211.5,-334.8 211.5,-324.85 211.5,-316.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"215,-316.09 211.5,-306.09 208,-316.09 215,-316.09\"/>\n</g>\n<!-- 136127602953664 -->\n<g id=\"node13\" class=\"node\">\n<title>136127602953664</title>\n<polygon fill=\"lightgrey\" stroke=\"black\" points=\"262,-416 161,-416 161,-397 262,-397 262,-416\"/>\n<text text-anchor=\"middle\" x=\"211.5\" y=\"-404\" font-family=\"monospace\" font-size=\"10.00\">AccumulateGrad</text>\n</g>\n<!-- 136127602953664&#45;&gt;136127602959376 -->\n<g id=\"edge11\" class=\"edge\">\n<title>136127602953664&#45;&gt;136127602959376</title>\n<path fill=\"none\" stroke=\"black\" d=\"M211.5,-396.75C211.5,-389.8 211.5,-379.85 211.5,-371.13\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"215,-371.09 211.5,-361.09 208,-371.09 215,-371.09\"/>\n</g>\n<!-- 136127601312288 -->\n<g id=\"node14\" class=\"node\">\n<title>136127601312288</title>\n<polygon fill=\"lightblue\" stroke=\"black\" points=\"238.5,-482 184.5,-482 184.5,-452 238.5,-452 238.5,-482\"/>\n<text text-anchor=\"middle\" x=\"211.5\" y=\"-470\" font-family=\"monospace\" font-size=\"10.00\">beta</text>\n<text text-anchor=\"middle\" x=\"211.5\" y=\"-459\" font-family=\"monospace\" font-size=\"10.00\"> (2)</text>\n</g>\n<!-- 136127601312288&#45;&gt;136127602953664 -->\n<g id=\"edge12\" class=\"edge\">\n<title>136127601312288&#45;&gt;136127602953664</title>\n<path fill=\"none\" stroke=\"black\" d=\"M211.5,-451.84C211.5,-444.21 211.5,-434.7 211.5,-426.45\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"215,-426.27 211.5,-416.27 208,-426.27 215,-426.27\"/>\n</g>\n</g>\n</svg>\n",
"text/plain": [
"<graphviz.graphs.Digraph at 0x7bcead6b0040>"
]
},
"metadata": {},
"execution_count": 34
}
],
"source": [
"# Define parameters for distribution\n",
"alpha, beta = torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([2.0, 3.0], requires_grad=True)\n",
"# Define Gamma distribution from IRT library\n",
"z = irt.Gamma(alpha, beta).rsample()\n",
"# Calculate loss\n",
"loss = torch.mean(z ** 2)\n",
"# Backpropogate loss through distribution\n",
"loss.backward()\n",
"\n",
"print(alpha.grad, beta.grad)\n",
"torchviz.make_dot(loss, params = {'alpha': alpha, 'beta': beta})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cwlwbQegG4Tg"
},
"source": [
"Below is a demonstration of using [Beta](https://en.wikipedia.org/wiki/Beta_distribution) distribution to train a simple stochastic model:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "rLb0DCYnG4Tg"
},
"outputs": [],
"source": [
"# Define simple encoder for model\n",
"class SimpleEncoder(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim):\n",
" super().__init__()\n",
" super(SimpleEncoder, self).__init__()\n",
" self.fc = nn.Linear(input_dim, hidden_dim)\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
"\n",
"# Define simple decoder for model\n",
"class SimpleDecoder(nn.Module):\n",
" def __init__(self, hidden_dim, output_dim):\n",
" super().__init__()\n",
" super(SimpleDecoder, self).__init__()\n",
" self.fc = nn.Linear(hidden_dim, output_dim)\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
"\n",
"# Define simple model\n",
"class SimpleModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.encoder = SimpleEncoder(20, 10)\n",
" self.decoder = SimpleEncoder(5, 20)\n",
"\n",
" def forward(self, x):\n",
" x = nn.functional.softmax(self.encoder(x), dim=0)\n",
" alpha, beta = x.chunk(2, dim=0)\n",
" y = irt.Beta(alpha, beta).rsample()\n",
" y = self.decoder(y)\n",
" return y"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"id": "6L0djQVCG4Th"
},
"outputs": [],
"source": [
"model = SimpleModel()\n",
"x = torch.randn(20)\n",
"y = model(x)\n",
"loss = y.mean()\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WlD1oeQOG4Th"
},
"source": [
"### Development"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0KTd4OzEG4Th"
},
"source": [
"If you want to use the code of the IRT library in your project, it is essential to ensure that your implementation successfully passes all the necessary [tests](https://github.com/intsystems/implicit-reparameterization-trick/blob/main/code/run_unittest.py). "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bnTEpaGGG4Th"
},
"source": [
"### Documentation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a4ROycAmG4Th"
},
"source": [
"\n",
"\n",
"For comprehensive information and guidance regarding the features and functionalities of IRT library, you can find the full documentation by following this [link](https://intsystems.github.io/implicit-reparameterization-trick/)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
},
"colab": {
"provenance": []
}
},
"nbformat": 4,
"nbformat_minor": 0
}