Variational Autoencoders
- Variational autoencoders (VAEs) are probabilistic generative models.
- They aim to learn a distribution \(p(x)\) over the data.
- After training, it is possible to draw (generate) samples from this distribution.
- Unfortunately, it is not possible to compute the log-likelihood of a sample \(x^*\) exactly.
- It is possible to define a lower bound on the likelihood.
- The VAE approximates this bound using a Monte Carlo sampling approach.
Image credits: Understanding Deep Learning by Simon J. D. Prince, [CC BY 4.0]
Latent variable models
- Latent variable models (LVMs) model a joint distribution \(p(\bm{x}, \bm{z})\) of the data \(\bm{x}\) and a latent variable \(\bm{z}\).
- They then describe \(p(\bm{x})\) as the marginal distribution of the joint distribution: \[ p(\bm{x}) = \int p(\bm{x}, \bm{z}) d\bm{z} = \int p(\bm{x} | \bm{z}) p(\bm{z}) d\bm{z}. \]
- This is a rather indirect approach to describing \(p(\bm{x})\).
- Useful because expressions for \(p(\bm{x} | \bm{z})\) and \(p(\bm{z})\) are often much simpler than that for \(p(\bm{x})\).
Example: Gaussian Mixture Models
- Let’s take a \(1\)D mixture of Gaussians.
- \(z\) is discrete: \(p(z)\) is a categorical distribution with probability \(\lambda_n\) for every possible value of \(z\).
- The likelihood \(p(x \mid z = n)\) of the data is normally distributed with mean \(\mu_n\) and variance \(\sigma_n^2\). \[ \begin{aligned} p(z = n) &= \lambda_n \\ p(x \mid z = n) &= \mathcal{N}(x \mid \mu_n, \sigma_n^2). \end{aligned} \]
- The data likelihood is given by marginalizing over the latent variable \(z\) \[ p(x) = \sum_{n=1}^N p(x, z=n) = \sum_{n=1}^N p(x \mid z=n) p(z=n) = \sum_{n=1}^N \lambda_n \mathcal{N}(x \mid \mu_n, \sigma_n^2). \]
- Note that the likelihood and prior are both simple expressions, but the resulting data likelihood is multimodal!
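This marginalization can be sketched numerically. The weights, means, and variances below are illustrative choices, not values from the text:

```python
import numpy as np

# Hypothetical 1D mixture of 3 Gaussians: weights lambda_n, means mu_n, stds sigma_n.
lam = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])
sigma = np.array([0.5, 1.0, 0.8])

def normal_pdf(x, m, s):
    return np.exp(-0.5 * ((x - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))

def p_x(x):
    # p(x) = sum_n lambda_n N(x | mu_n, sigma_n^2): marginalize the discrete latent z.
    return sum(l * normal_pdf(x, m, s) for l, m, s in zip(lam, mu, sigma))

# Ancestral sampling: draw z from the categorical prior, then x from p(x | z = n).
rng = np.random.default_rng(0)
z = rng.choice(3, size=5, p=lam)
x = rng.normal(mu[z], sigma[z])
```

Evaluating `p_x` on a grid shows the multimodality: the density peaks near each \(\mu_n\) and dips between them, even though both the prior and each component likelihood are simple unimodal expressions.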
Nonlinear latent variable models
In a nonlinear latent variable model, both the data \(\bm{x}\) and the latent variable \(\bm{z}\) are continuous and multivariate.
- Both the prior \(p(\bm{z})\) and the likelihood \(p(\bm{x} \mid \bm{z})\) are normally distributed.
- \(p(\bm{z}) = \mathcal{N}(\bm{z} \mid \bm{0}, \bm{I})\)
- \(p(\bm{x} \mid \bm{z}) = \mathcal{N}(\bm{x} \mid \mu(\bm{z}), \Sigma(\bm{z}))\)
- In general, both the mean and the covariance are given by neural networks
- Often the covariance is fixed and assumed to be spherical: \(\Sigma(\bm{z}) = \sigma^2 \bm{I}\) (\(\sigma\) can be learned, too)
- The mean is given by a neural network: \(\mu(\bm{z}) = f(\bm{z}; \bm{\phi})\)
- The data probability \(p(\bm{x} \mid \bm{\phi})\) is found by marginalizing over the latent variable \(\bm{z}\) \[ p(\bm{x} \mid \bm{\phi}) = \int p(\bm{x} \mid \bm{z}, \bm{\phi}) p(\bm{z}) d\bm{z} = \int \mathcal{N}\left(\bm{x} \mid f(\bm{z}; \bm{\phi}), \sigma^2 \bm{I}\right) \mathcal{N}(\bm{z} \mid \bm{0}, \bm{I}) d\bm{z} \]
- This can be viewed as an infinite weighted sum (i.e., an infinite mixture)
- Mixture is of spherical Gaussians with different means
- The weights are \(p(\bm{z})\) and the means are \(f(\bm{z}; \bm{\phi})\)
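The infinite-mixture view suggests a simple Monte Carlo estimate of the marginal: draw latents from the prior and average the spherical Gaussian densities centred at \(f(\bm{z}; \bm{\phi})\). In this sketch the decoder is a hypothetical fixed nonlinearity standing in for a trained network:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical decoder f(z; phi): a fixed nonlinear map from a 1D latent to a 2D mean.
def f(z):
    return np.stack([np.sin(z), z ** 2], axis=-1)

sigma2 = 0.1  # fixed spherical observation variance sigma^2

def spherical_gaussian_pdf(x, mean, var):
    d = x.shape[-1]
    diff = x - mean
    return np.exp(-0.5 * np.sum(diff ** 2, axis=-1) / var) / (2 * np.pi * var) ** (d / 2)

def p_x_mc(x, n_samples=100_000):
    # p(x) = E_{z ~ N(0, I)}[ N(x | f(z), sigma^2 I) ]: an infinite mixture of
    # spherical Gaussians with means f(z), approximated here by Monte Carlo.
    z = rng.standard_normal(n_samples)
    return spherical_gaussian_pdf(x, f(z), sigma2).mean()
```

Points near the image of the decoder (e.g. near \(f(0)\)) get high estimated density; points far from any mixture mean get essentially none.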
Generation
A new example \(\bm{x}^*\) is generated by ancestral sampling.
- Draw \(\bm{z}^* \sim p(\bm{z})\)
- Draw \(\bm{x}^* \sim p(\bm{x} \mid \bm{z}^*)\)
- Pass \(\bm{z}^*\) through the decoder network \(f(\bm{z}^*; \bm{\phi})\) to compute the mean of \(p(\bm{x} \mid \bm{z}^*)\)
- Draw \(\bm{x}^*\) from \(\mathcal{N}(\bm{x}^* \mid f(\bm{z}^*; \bm{\phi}), \sigma^2 \bm{I})\)
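The two sampling steps above can be sketched as follows; the decoder weights here are arbitrary placeholders for a trained network \(f(\bm{z}; \bm{\phi})\):

```python
import numpy as np

rng = np.random.default_rng(2)

# Placeholder decoder network: maps a 2D latent z to a 3D mean f(z; phi).
def decoder(z):
    return np.tanh(z @ np.array([[1.0, -0.5, 0.3],
                                 [0.2, 0.8, -1.0]]))

sigma = 0.1  # fixed observation standard deviation

def generate(n):
    # 1. Draw z* ~ p(z) = N(0, I).
    z = rng.standard_normal((n, 2))
    # 2. Pass z* through the decoder to get the mean of p(x | z*).
    mean = decoder(z)
    # 3. Draw x* ~ N(f(z*; phi), sigma^2 I).
    return mean + sigma * rng.standard_normal(mean.shape)

samples = generate(1000)
```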
Training
- We want to maximize the log-likelihood over a training dataset \(\{ \bm{x}_i \}_{i=1}^I\) w.r.t. \(\bm{\phi}\) \[ \bm{\phi}^* = \arg\max_{\bm{\phi}} \sum_{i=1}^I \log p(\bm{x}_i \mid \bm{\phi}) = \arg\max_{\bm{\phi}} \sum_{i=1}^I \log \int p(\bm{x}_i \mid \bm{z}, \bm{\phi}) p(\bm{z}) d\bm{z}, \] where \[ p(\bm{x}_i \mid \bm{\phi}) = \int \mathcal{N}\left(\bm{x}_i \mid f(\bm{z}; \bm{\phi}), \sigma^2 \bm{I}\right) \mathcal{N}(\bm{z} \mid \bm{0}, \bm{I}) d\bm{z}. \]
- Unfortunately, this is intractable to compute directly.
- No closed-form expression for the integral
- No easy way to evaluate it for a particular value of \(\bm{x}\)
- To make progress, we define a lower bound on the log-likelihood.
- Always less than or equal to the log-likelihood for a given value of \(\bm{\phi}\)
- Depends on some other parameters \(\bm{\theta}\)
- We will build a network to compute this lower bound and optimize it.
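As a sanity check on the lower-bound idea, consider a toy 1D case with a linear decoder \(f(z) = z\) (my choice for illustration), where the marginal \(p(x) = \mathcal{N}(x \mid 0, 1 + \sigma^2)\) happens to be tractable. For a Gaussian \(q(z) = \mathcal{N}(m, s^2)\), playing the role of the \(\bm{\theta}\)-parameterized distribution, the bound \(\mathbb{E}_q[\log p(x \mid z)] - \mathrm{KL}(q \,\|\, p(z))\) can be evaluated analytically and never exceeds \(\log p(x)\):

```python
import numpy as np

sigma2 = 0.5  # observation variance of p(x | z) = N(x | z, sigma2)

def log_p_x(x):
    # With f(z) = z, the marginal is N(x | 0, 1 + sigma2) in closed form.
    var = 1.0 + sigma2
    return -0.5 * (np.log(2 * np.pi * var) + x ** 2 / var)

def elbo(x, m, s2):
    # Lower bound = E_q[log p(x | z)] - KL(q(z) || p(z)), q(z) = N(m, s2), p(z) = N(0, 1).
    # E_q[(x - z)^2] = (x - m)^2 + s2, so the expected log-likelihood is analytic.
    expected_loglik = -0.5 * (np.log(2 * np.pi * sigma2) + ((x - m) ** 2 + s2) / sigma2)
    kl = 0.5 * (s2 + m ** 2 - 1.0 - np.log(s2))
    return expected_loglik - kl
```

For any choice of \((m, s^2)\) the bound sits at or below \(\log p(x)\), and it is tight exactly when \(q(z)\) equals the true posterior \(p(z \mid x)\); this is the gap a VAE's encoder network is trained to close.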