In MIWAE_Pytorch_exercises_demo_ProbAI, a base model is implemented and trained on MNAR (missing not at random) data. We will modify this model with an Energy-Based Model (EBM) to handle the MNAR data more effectively.
- 1.1) Train a VAE on MNAR data.
The objective of this hackathon is to modify the VAE using an EBM to handle MNAR data.
VAE: the base model is trained by maximising the evidence lower bound (ELBO)

$$\mathcal{L}_{VAE} = \mathbf{E}_{q_{\psi}(z|x)}\left[\log p_{\phi}(x|z)\right] - \mathrm{KL}\left(q_{\psi}(z|x)\,\|\,p(z)\right) \leq \log p_{\phi}(x),$$

where we defined $q_{\psi}(z|x)$ the encoder (approximate posterior), $p_{\phi}(x|z)$ the decoder, and $p(z) = \mathcal{N}(0, I)$ the prior over the latent space.

As usual in the VAE setting, we will use the reparameterization trick to sample from $q_{\psi}(z|x)$:

$$z = \mu_{\psi}(x) + \sigma_{\psi}(x) \odot \epsilon,$$

where $\mu_{\psi}(x)$ and $\sigma_{\psi}(x)$ are the mean and standard deviation produced by the encoder, and $\epsilon \sim \mathcal{N}(0, I)$.
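As a reference point, here is a minimal PyTorch sketch of such a VAE with the reparameterization trick and a per-sample ELBO. The architecture and helper names (`elbo`, `sample`) are illustrative assumptions of these notes, not the ones used in MIWAE_Pytorch_exercises_demo_ProbAI; the later sketches reuse them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Minimal VAE: Gaussian encoder q_psi(z|x), Bernoulli decoder p_phi(x|z)."""
    def __init__(self, x_dim=784, z_dim=2, h_dim=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, 2 * z_dim))
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))

    def elbo(self, x):
        mu, log_var = self.enc(x).chunk(2, dim=-1)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        logits = self.dec(z)
        # ELBO = E_q[log p_phi(x|z)] - KL(q_psi(z|x) || p(z)), one value per data point
        log_px_z = -F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)
        kl = 0.5 * (mu.pow(2) + log_var.exp() - log_var - 1).sum(-1)
        return log_px_z - kl

    @torch.no_grad()
    def sample(self, n):
        # Decode draws from the prior p(z) = N(0, I)
        z = torch.randn(n, self.dec[0].in_features)
        return torch.sigmoid(self.dec(z))
```

Training then amounts to minimising `-vae.elbo(x).mean()` over batches of flattened, binarised MNIST images.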
- 1.2) The first idea is simply to tilt the VAE output in data space with an EBM. To that end, we consider a model of the form:

$$p_{\theta, \phi}(x) = \frac{1}{Z_{\theta}}\, e^{-E_{\theta}(x)}\, p_{\phi}(x),$$

where $E_{\theta}$ is the energy network and $Z_{\theta} = \int_{x} e^{-E_{\theta}(x)} p_{\phi}(x)\,\mathrm{d}x$ is the partition function.

We want to train this model to maximise the log-likelihood of the data under the EBM:

$$\log p_{\theta, \phi}(x) = -E_{\theta}(x) + \log p_{\phi}(x) - \log Z_{\theta}.$$

Using Jensen's inequality on $\log p_{\phi}(x)$, we can derive a lower bound for the log-likelihood:

$$\log p_{\theta, \phi}(x) \geq -E_{\theta}(x) + \mathbf{E}_{q_{\psi}(z|x)}\left[\log p_{\phi}(x|z) + \log p(z) - \log q_{\psi}(z|x)\right] - \log Z_{\theta}.$$

One can obtain the gradient in $\theta$ (only the energy terms depend on $\theta$):

$$\nabla_{\theta} \log p_{\theta, \phi}(x) \approx -\nabla_{\theta}E_{\theta}(x) + \mathbf{E}_{\tilde{x} \sim p_{\phi}(x)}\left[\nabla_{\theta}E_{\theta}(\tilde{x})\right].$$

The first term is the gradient of the energy function evaluated at the observed data, and the second term is the gradient of the energy function evaluated at samples from the VAE, which serve as a tractable approximation of samples from the tilted model.
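A minimal sketch of the corresponding training step, assuming an `energy` network that maps a batch of images to per-sample energies and reusing the hypothetical `vae.sample` helper from the sketch above (the pretrained VAE is kept frozen):

```python
import torch

def data_space_ebm_step(energy, vae, x, optimizer, n_model_samples=64):
    """One contrastive update for the data-space EBM tilting a frozen VAE.

    Positive phase: lower the energy of observed data x.
    Negative phase: raise the energy of VAE samples, used as a tractable
    stand-in for samples from the tilted model.
    """
    with torch.no_grad():
        x_tilde = vae.sample(n_model_samples)          # x~ ~ p_phi(x)
    # Gradient descent on this loss moves theta along -grad E(x) + E_x~[grad E(x~)],
    # i.e. ascent on the lower bound derived above.
    loss = energy(x).mean() - energy(x_tilde).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```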
- 1.3) Sampling from the resulting model can be done by importance sampling with the VAE $p_{\phi}(x)$ as proposal distribution. To that end, we draw samples $x_i \sim p_{\phi}(x)$ from the VAE and reweight them according to the EBM:

$$\tilde{w}_i = e^{-E_{\theta}(x_i)}.$$

Since the partition function $Z_{\theta}$ is intractable, we self-normalise the weights:

$$w_i = \frac{\tilde{w}_i}{\sum_j \tilde{w}_j}.$$

Then we can resample from the weighted samples according to $w_i$.
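A possible implementation of this resampling scheme (again with the assumed `vae.sample` helper):

```python
import torch

@torch.no_grad()
def snis_sample_data_space(energy, vae, n_proposals=1024, n_out=64):
    """Self-normalised importance resampling with the VAE p_phi(x) as proposal."""
    x = vae.sample(n_proposals)                      # x_i ~ p_phi(x)
    log_w = -energy(x).squeeze(-1)                   # log w~_i = -E_theta(x_i)
    w = torch.softmax(log_w, dim=0)                  # w_i = w~_i / sum_j w~_j
    idx = torch.multinomial(w, n_out, replacement=True)
    return x[idx]
```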
- 1.4) Sampling from the resulting model can also be done by sampling from the VAE and then running an MCMC chain to update the samples according to the EBM.

The gradient of the full log-likelihood guides the MCMC chain, e.g. with (unadjusted) Langevin dynamics:

$$x_{t+1} = x_{t} + \frac{\eta}{2}\,\nabla_{x}\left[-E_{\theta}(x_{t}) + \log p_{\phi}(x_{t})\right] + \sqrt{\eta}\,\epsilon_{t},$$

where $\eta$ is the step size and $\epsilon_{t} \sim \mathcal{N}(0, I)$; in practice, the intractable $\log p_{\phi}(x)$ has to be approximated, e.g. by its ELBO.
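A rough sketch of such a refinement, using unadjusted Langevin dynamics and the single-sample ELBO from the VAE sketch above as a surrogate for the intractable $\log p_{\phi}(x)$ (both helpers are assumptions of these notes, not the notebook's API):

```python
import torch

def langevin_refine(energy, vae, x_init, n_steps=50, step_size=1e-2):
    """Refine VAE samples with Langevin steps on -E_theta(x) + log p_phi(x),
    where log p_phi(x) is replaced by a single-sample ELBO estimate."""
    x = x_init.detach().clone().requires_grad_(True)
    for _ in range(n_steps):
        # Unnormalised (surrogate) log-density of the tilted model at the current samples
        log_p = -energy(x).sum() + vae.elbo(x).sum()
        grad, = torch.autograd.grad(log_p, x)
        x = (x + 0.5 * step_size * grad
             + step_size ** 0.5 * torch.randn_like(x)).detach().requires_grad_(True)
    return x.detach().clamp(0.0, 1.0)   # keep pixel values in [0, 1]
```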
- 1.1) The first part is the same: just train the VAE.
- 1.2) Tilt the prior with an EBM in latent space:

$$p_{\theta}(x) = \int_{z} \frac{1}{Z_{\theta}}\, e^{-E_{\theta}(z)}\, p_{\phi}(x|z)\, p(z)\,\mathrm{d}z,$$

where $E_{\theta}$ is now an energy network on the latent space and $Z_{\theta} = \int_{z} e^{-E_{\theta}(z)} p(z)\,\mathrm{d}z$ is the partition function.

We want to train this model to maximise the log-likelihood of the data:
$$
\begin{align}
\mathcal{L}_{EBM} & = \log p_{\theta}(x) \\
& = \log \int_{z} \frac{1}{Z_{\theta}} e^{-E_{\theta}(z)} p_{\phi}(x|z) p(z) \frac{q_{\psi}(z|x)}{q_{\psi}(z|x)}\,\mathrm{d}z \\
& = \log \mathbf{E}_{q_{\psi}(z|x)}\left[\frac{1}{Z_{\theta}} \frac{e^{-E_{\theta}(z)} p_{\phi}(x|z) p(z)}{q_{\psi}(z|x)}\right] \\
(\text{Jensen}) \quad & \geq \mathbf{E}_{q_{\psi}(z|x)}\left[ - E_{\theta}(z) + \log p(z) - \log q_{\psi}(z|x) + \log p_{\phi}(x|z) \right] - \log Z_{\theta} \\
& = \mathbf{E}_{q_{\psi}(z|x)}\left[ - E_{\theta}(z) + \log p(z) \right] + \mathbf{E}_{q_{\psi}(z|x)}\left[ - \log q_{\psi}(z|x) + \log p_{\phi}(x|z) \right] - \log Z_{\theta}
\end{align}
$$
Since $Z_{\theta}$ is intractable, we can estimate it with samples from the prior:

$$Z_{\theta} = \int_{z} e^{-E_{\theta}(z)} p(z)\,\mathrm{d}z = \mathbf{E}_{p(z)}\left[ e^{-E_{\theta}(z)} \right].$$

Applying Jensen once more, $\log Z_{\theta} \geq \mathbf{E}_{p(z)}\left[-E_{\theta}(z)\right]$, so we replace $-\log Z_{\theta}$ by the tractable surrogate $-\mathbb{E}_{p(\tilde{z})}\left[-E_{\theta}(\tilde{z})\right]$, which yields the usual contrastive form:

$$
\begin{align}
\mathcal{L}_{EBM} & = \mathbf{E}_{q_{\psi}(z|x)}\left[ - E_{\theta}(z) + \log p(z) \right] + \mathbf{E}_{q_{\psi}(z|x)}\left[ - \log q_{\psi}(z|x) + \log p_{\phi}(x|z) \right] - \log Z_{\theta} \\
& \approx \mathbf{E}_{q_{\psi}(z|x)}\left[ - E_{\theta}(z) + \log p(z) \right] + \mathbf{E}_{q_{\psi}(z|x)}\left[ - \log q_{\psi}(z|x) + \log p_{\phi}(x|z) \right] - \mathbb{E}_{p(\tilde{z})}\left[ - E_{\theta}(\tilde{z}) \right] \\
& = \mathbf{E}_{q_{\psi}(z|x)}\left[ - E_{\theta}(z) \right] - \mathbb{E}_{p(\tilde{z})}\left[ - E_{\theta}(\tilde{z}) \right] + \ldots
\end{align}
$$
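A minimal sketch of the resulting latent-space objective, keeping only the terms that depend on $\theta$ (the VAE is frozen); `vae.enc` is the assumed encoder from the sketch above:

```python
import torch

def latent_ebm_loss(energy, vae, x):
    """Contrastive objective for the EBM tilting the VAE prior in latent space.

    Positive phase: z ~ q_psi(z|x), obtained by encoding observed data.
    Negative phase: z~ ~ p(z) = N(0, I), the prior samples standing in for -log Z_theta.
    """
    with torch.no_grad():
        mu, log_var = vae.enc(x).chunk(2, dim=-1)
        z_pos = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)   # z ~ q_psi(z|x)
    z_neg = torch.randn_like(z_pos)                                    # z~ ~ p(z)
    # Gradient descent on this loss maximises E_q[-E(z)] - E_p[-E(z~)] from the derivation above
    return energy(z_pos).mean() - energy(z_neg).mean()
```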
- 1.3) Sampling from the resulting model can be done by self-normalised importance sampling / resampling in latent space, with the prior distribution $p(z)$ as proposal (see the sketch below).
- 1.4) Sampling can also be done with MCMC in latent space, e.g. Langevin dynamics on the tilted prior (see the sketch further below).
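A sketch of the latent-space resampling from 1.3, decoding the retained latents with the assumed `vae.dec` from the VAE sketch:

```python
import torch

@torch.no_grad()
def snis_sample_latent(energy, vae, z_dim=2, n_proposals=4096, n_out=64):
    """Self-normalised importance resampling in latent space with p(z) as proposal."""
    z = torch.randn(n_proposals, z_dim)                  # z_i ~ p(z)
    w = torch.softmax(-energy(z).squeeze(-1), dim=0)     # w_i proportional to exp(-E_theta(z_i))
    idx = torch.multinomial(w, n_out, replacement=True)
    return torch.sigmoid(vae.dec(z[idx]))                # decode with p_phi(x|z)
```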
Results with the EBM in data space have been unconvincing so far; results in the latent space are better (see 2d-small.ipynb). That notebook is extended into 2d.ipynb, which contains additional results measuring the degree of debiasing as well as a more advanced sampling method (Langevin MCMC). Lower-dimensional latent spaces appear to be easier to work with than higher-dimensional ones (see 10d-small.ipynb).
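For the Langevin MCMC variant mentioned above, the tilted prior $e^{-E_{\theta}(z)}p(z)$ is tractable up to a constant, so the update only needs $\nabla_{z}\left[-E_{\theta}(z) - \tfrac{1}{2}\lVert z\rVert^{2}\right]$. A minimal sketch, with the same assumed helper names as before:

```python
import torch

def langevin_sample_latent(energy, vae, n_samples=64, z_dim=2, n_steps=200, step_size=1e-2):
    """Sample the tilted prior exp(-E_theta(z)) p(z) with unadjusted Langevin dynamics,
    starting from prior draws, then decode the final latents."""
    z = torch.randn(n_samples, z_dim, requires_grad=True)
    for _ in range(n_steps):
        # Unnormalised log-target: -E_theta(z) + log N(z; 0, I), up to a constant
        log_p = (-energy(z).squeeze(-1) - 0.5 * z.pow(2).sum(-1)).sum()
        grad, = torch.autograd.grad(log_p, z)
        z = (z + 0.5 * step_size * grad
             + step_size ** 0.5 * torch.randn_like(z)).detach().requires_grad_(True)
    with torch.no_grad():
        return torch.sigmoid(vae.dec(z))
```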
We first compare the label distributions of biased and unbiased datasets to illustrate a setting of class imbalance between training and test scenarios.
Next, we train the VAE and encode the biased data into the 2D latent space to observe how the digit classes separate under the trained model.
We then sample from the standard normal prior and decode those latent vectors to inspect the raw generative quality of prior samples.
Next, we fit the EBM and plot the learned energy landscape as contours over the latent plane. Using the EBM, we can reweight prior samples so that they better fit the (limited) unbiased data.
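For reference, a small sketch of how such a contour plot over a 2D latent grid could be produced (assuming a 2D latent space and the energy network from the sketches above):

```python
import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_energy_contours(energy, lim=4.0, n=200):
    """Evaluate E_theta on a regular 2D latent grid and draw filled contours."""
    xs = torch.linspace(-lim, lim, n)
    zx, zy = torch.meshgrid(xs, xs, indexing="xy")
    grid = torch.stack([zx.reshape(-1), zy.reshape(-1)], dim=-1)
    e = energy(grid).reshape(n, n)
    plt.contourf(zx.numpy(), zy.numpy(), e.numpy(), levels=30, cmap="viridis")
    plt.colorbar(label=r"$E_\theta(z)$")
    plt.xlabel(r"$z_1$")
    plt.ylabel(r"$z_2$")
```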
Decoding the energy-weighted samples, we can see that the model manages to sample more images from the undersampled classes.
We overlay the original latent embeddings with both unweighted and energy-weighted prior samples to visualize how the EBM can help adjust class imbalance for generative models!
Notebook 2d.ipynb contains additional results.