
Generative Adversarial Networks (GANs)

See the code on GitHub

These notes assume you have read [1], and only cover GAN topics that were confusing to me after gaining an initial understanding of GANs. This is not a good place to start learning about GANs. There are many excellent articles introducing GANs on the web. Come back after you've got a grasp on the basics!

Introduction and Definitions

Generative Adversarial Networks (GANs) follow the work of [1], playing the two-player minimax game:

\[ \min\limits_{G} \max\limits_{D} V(D,G) = E_{x \sim p(x)}\log\Big[\sigma \big( D(x) \big) \Big] + E_{z \sim p(z)}\log\Big[ 1- \sigma \big(D(G(z))\big) \Big] \]

where

  • \(D\) and \(G\) are neural networks.
  • \(\sigma\) is the sigmoid function (\(\frac{1}{1+e^{-x}}\)).
  • \(p(z)\) is a multivariate normal distribution with mean \(0\) and diagonal covariance (\(p(z) = \mathcal{N}(0,I)\)).
  • \(p(x)\) is the data distribution we wish \(G\) to model. We have \(\{x_1, ..., x_N\}\), real samples drawn from \(p(x)\).

In [1], the sigmoid function is implicitly included in \(D\); it is included explicitly here to highlight the vanishing gradient problem discussed below.

The networks can be trained via stochastic gradient descent (SGD) by minimizing the loss functions

\[ \begin{aligned} L_{D} &= E_{x \sim p(x)}-\log\Big[ \sigma \big( D(x) \big) \Big] + E_{z \sim p(z)}-\log\Big[ 1- \sigma \big(D(G(z))\big) \Big] \\ L_{G} &= E_{z \sim p(z)}\log\Big[ 1 - \sigma \big(D(G(z))\big) \Big] \end{aligned} \]

Note the change in the sign of the $\log$ terms in $L_D$ relative to the first equation. This is required to use stochastic gradient descent; otherwise, stochastic gradient ascent is required (as noted in Algorithm 1 of [1]). Ideally, $L_D$ would be fully minimized before minimizing $L_G$; however, the derivative of the sigmoid function ($\frac{d\sigma(x)}{dx} = \sigma(x)\big(1 - \sigma(x)\big)$) approaches $0$ as $x$ approaches \(-\infty\) or \(\infty\).

Thus, gradients are maximal when \(\sigma(D(x)) = 0.5\), and diminish as $D$ becomes better at classifying samples from \(G(z)\) as fake (\(D(G(z)) \ll 0\)). As a result, training $D$ to optimality causes very little information to flow through the gradient to $G$. To counteract this, the weights of $D$ and $G$ are both updated on each mini-batch. In practice, this simultaneous training allows \(G\) to stay good enough at tricking \(D\) to avoid diminished gradients. However, you may have noticed that $L_G$ is a little more complicated than just a sigmoid: $L_G = \log(1 - \sigma(x))$. Let's see what $\frac{d \log(1 - \sigma(x))}{dx}$ looks like (blue curve in the sketch below):
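Below is a minimal sketch (numpy and matplotlib assumed; not the repo's plotting code) of both curves, using $\frac{d}{dx}\log(1-\sigma(x)) = -\sigma(x)$ for the blue curve and, for the red curve, the gradient of the modified loss introduced below, $\frac{d}{dx}\big[-\log\sigma(x)\big] = \sigma(x)-1$:

```python
# Sketch of the gradient curves for the two generator losses.
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-10, 10, 500)          # x stands in for D(G(z))
sig = 1.0 / (1.0 + np.exp(-x))         # sigma(x)

grad_orig = -sig                       # d/dx log(1 - sigma(x)); blue curve
grad_modified = sig - 1.0              # d/dx -log(sigma(x)); red curve

plt.plot(x, grad_orig, "b", label=r"$\frac{d}{dx}\log(1-\sigma(x))$")
plt.plot(x, grad_modified, "r", label=r"$\frac{d}{dx}-\log(\sigma(x))$")
plt.xlabel(r"$x = D(G(z))$")
plt.ylabel("gradient")
plt.legend()
plt.show()
```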

The gradient of $L_G$ diminishes as $D$ gets better at classifying samples from \(G(z)\) as fake. This can cause training to fail when $D$ can easily distinguish $p(x)$ from \(G(z)\). This is often the case early in training, when samples generated by $G$ are close to just random noise. To help alleviate diminishing gradients early in training, [1] recommends a modified $G$ loss function:

\[ \hat{L}_{G} = E_{z \sim p(z)}-\log\Big[ \sigma \big(D(G(z))\big) \Big] \]

\(\hat{L}_{G}\) provides strong gradients (see red line in plot above) when $D$ is good at classifying samples from \(G(z)\) as fake, but diminishes as $G$ learns to fool $D$. \(\hat{L}_{G}\) and \(L_{G}\) are used interchangeably in this text, but \(\hat{L}_{G}\) is used in the implementation. (NOTE: Some newer literature suggests \(\hat{L}_{G}\) is in part the cause of some of the failure modes discussed below; adding a discussion of that literature is on the TODO list.)
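As a concrete sketch (PyTorch assumed; the helper names are mine, not the repo's), all three losses can be written with `binary_cross_entropy_with_logits`, which applies the sigmoid internally in a numerically stable way:

```python
import torch
import torch.nn.functional as F

def d_loss(d_real_logits, d_fake_logits):
    # L_D: push real logits toward 1 and fake logits toward 0
    real = F.binary_cross_entropy_with_logits(
        d_real_logits, torch.ones_like(d_real_logits))
    fake = F.binary_cross_entropy_with_logits(
        d_fake_logits, torch.zeros_like(d_fake_logits))
    return real + fake

def g_loss_saturating(d_fake_logits):
    # L_G: mean of log(1 - sigma(D(G(z)))); vanishes when D confidently rejects fakes
    return -F.binary_cross_entropy_with_logits(
        d_fake_logits, torch.zeros_like(d_fake_logits))

def g_loss_nonsaturating(d_fake_logits):
    # modified loss: mean of -log(sigma(D(G(z)))); strong gradients while D rejects fakes
    return F.binary_cross_entropy_with_logits(
        d_fake_logits, torch.ones_like(d_fake_logits))
```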

The training algorithm is then:

\[ \begin{aligned} &\text{ for each mini-batch } \{x_1,...,x_m\} \text{ do:} \\ &\hspace{20pt} \text{Sample } \{z_1,...,z_m\} \text{ from } p(z) \\ &\hspace{20pt} L_{D} = \frac{1}{m} \sum_{i=1}^m -\log\Big[ \sigma \big( D(x_i) \big) \Big] - \log\Big[ 1- \sigma \big(D(G(z_i))\big) \Big] \\ &\hspace{20pt} w_D = w_D - lr\nabla_{w_D} L_D \\ \\ &\hspace{20pt} \text{Sample } \{z_1,...,z_m\} \text{ from } p(z) \\ &\hspace{20pt} \hat{L}_{G} = \frac{1}{m} \sum_{i=1}^m -\log\Big[ \sigma \big(D(G(z_i))\big) \Big] \\ &\hspace{20pt} w_G = w_G - lr\nabla_{w_G} \hat{L}_G \\ \end{aligned} \]

where \(w_G\) and \(w_D\) are the weights of networks \(G\) and \(D\), and $lr$ is the learning rate. [1] suggests that \(D\) could be updated many times per \(G\) update. However, I'm aware of very few examples of multiple \(D\) updates being used in practice, and points 11 and 14 of [3] also suggest it isn't worth trying.
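A minimal PyTorch sketch of this loop, using the loss helpers above. `G`, `D`, `z_dim`, and `loader` are assumed placeholders, and the Adam settings are the ones suggested in [2], not necessarily the values used here:

```python
import torch

opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

for x, _ in loader:                        # mini-batches {x_1, ..., x_m}
    m = x.size(0)

    # D step: minimize L_D; detach() so no gradient flows into G here
    z = torch.randn(m, z_dim)              # sample {z_1, ..., z_m} from p(z)
    loss_d = d_loss(D(x), D(G(z).detach()))
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # G step: minimize the modified loss on a fresh sample from p(z)
    z = torch.randn(m, z_dim)
    loss_g = g_loss_nonsaturating(D(G(z)))
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
```

The `detach()` in the $D$ step reflects that only $w_D$ is updated there; the $G$ step then backpropagates through $D$ without stepping $w_D$.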

This implementation contains two versions, differing in their network architecture. The "linear" version uses a two-layer linear network similar to that used in [1]. The "conv" version uses convolutional networks similar to those used in [2]. All networks were trained on the MNIST dataset. The results from the convolutional networks are discussed below.

Convolutional GAN

[3] recommends using a GAN with convolutional networks because they "just work", so a deep convolutional GAN (DCGAN) seems like a good place to start. The original DCGAN paper [2] has a nice little box on the third page outlining how to build networks that will be stable during training. The networks used here follow those guidelines fairly closely, with the following exceptions (a sketch of the resulting networks follows the list):

  • One fewer layer, to account for the smaller MNIST images (28x28) compared to 32x32 in [2].
  • Fewer features per layer, just to reduce training time for experiments; a large network isn't needed for MNIST.
  • LeakyReLU in both $G$ and $D$; [2] uses them only in $D$.
  • A sigmoid instead of Tanh on the output of $G$. This was only changed because my MNIST data is in the range [0,1]. A discussion of the (most likely minor) differences Tanh vs. sigmoid causes during training is for another day.
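A rough sketch of the resulting networks (PyTorch assumed; the layer widths are my own stand-ins, not the exact values used here). The generator expects $z$ with shape `(m, z_dim, 1, 1)`:

```python
import torch.nn as nn

def make_g(z_dim=7, nf=64):
    return nn.Sequential(
        # z -> 7x7 feature map, then two stride-2 upsamples to 28x28
        nn.ConvTranspose2d(z_dim, nf * 2, 7, 1, 0, bias=False),
        nn.BatchNorm2d(nf * 2), nn.LeakyReLU(0.2, inplace=True),
        nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(nf), nn.LeakyReLU(0.2, inplace=True),
        nn.ConvTranspose2d(nf, 1, 4, 2, 1, bias=False),
        nn.Sigmoid(),                                # outputs in [0,1] to match the data
    )

def make_d(nf=64):
    return nn.Sequential(
        # 28x28 -> 14x14 -> 7x7 -> 1x1; no batchnorm on the input layer, per [2]
        nn.Conv2d(1, nf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(nf * 2), nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(nf * 2, 1, 7, 1, 0, bias=False),   # raw logits; sigma applied in the loss
        nn.Flatten(0),
    )
```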

During training, I tracked the following (a logging sketch for the norms follows the list):

  • loss/G - The loss used to train $G$: \(\hat{L}_{G} = \frac{1}{m} \sum_{i=1}^m -\log\Big[ \sigma \big(D(G(z_i))\big) \Big]\)
  • loss/real - The $p(x)$ half of the loss used to train $D$: \(\frac{1}{m} \sum_{i=1}^m -\log\Big[ \sigma \big( D(x_i) \big) \Big]\)
  • loss/fake - The $G(z)$ half of the loss used to train $D$: \(\frac{1}{m} \sum_{i=1}^m -\log\Big[1 - \sigma \big(D(G(z_i))\big) \Big]\)
  • norms/G: \(\left\|\frac{d\hat{L}_G}{dw_G}\right\|\)
  • norms/D: \(\left\|\frac{d\hat{L}_G}{dw_D}\right\|\)
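A sketch of how the gradient norms could be computed (PyTorch assumed, continuing the loop above); called after `loss_g.backward()`, so both norms come from the backward pass of $\hat{L}_G$:

```python
def grad_norm(net):
    # total L2 norm over all parameter gradients currently stored on net
    total = 0.0
    for p in net.parameters():
        if p.grad is not None:
            total += p.grad.pow(2).sum().item()
    return total ** 0.5

g_norm = grad_norm(G)   # norms/G
d_norm = grad_norm(D)   # norms/D
```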

The models were trained with the dimensionality of $p(z)$ (zDim) equal to $2$ (orange curves) and $7$ (blue curves). Network training appeared fairly stable:

Training curves from DCGAN, zDim=2 (orange) and zDim=7 (blue). x axis is number of images trained on.

Training was stopped sooner on the zDim=2 training run after it became clear that while $G$ was learning numbers, it wasn't learning a static mapping between $p(z)$ and $p(x)$. Instead, it cycled through mapping numbers to different parts of $p(z)$ as training progressed:

A grid of images generated by $G(z)$, where the $z$ are spaced at uniform probability intervals of \(\mathcal{N}(0,I)\), spanning approximately $[-3,3]$ per dimension. Images were generated at uniform intervals during training. Left: zDim=2, Right: zDim=7.

It may be that there is too little variability in a 2 dimensional $p(z)$ for $G$ to find a mapping from $p(z)$ to $p(x)$ that covers all the variability in $p(x)$. In agreement with this hypothesis, increasing zDim to 7 stabilizes the mapping and allows $G$ to generate all 10 digits from $p(z)$.
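One way such a grid of $z$ values could be built (numpy and scipy assumed; the original code may differ): take evenly spaced quantiles of the normal CDF, which are uniform in probability mass and span roughly $[-3, 3]$:

```python
import numpy as np
from scipy.stats import norm

n = 20
quantiles = np.linspace(0.001, 0.999, n)   # uniform steps in probability mass
grid_1d = norm.ppf(quantiles)              # maps to roughly [-3.1, 3.1]

# for zDim=2, the cartesian product gives an n x n grid of z vectors
zz = np.stack(np.meshgrid(grid_1d, grid_1d), axis=-1).reshape(-1, 2)
```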

Is training stable?

Papers rarely show loss/gradient curves of the networks during training, and I was surprised to find very little information on what the curves of a correctly trained GAN look like in practice. Ideally, losses from $D$ and $G$ would be perfectly stable for the length of the training, with \(\sigma\big(D(G(z))\big) = 0.5\) (and therefore \(L_G = -\log(0.5) \approx 0.69\)) for all $z$. In the training curves shown above, $D$ is clearly slowly winning out over $G$. The question then becomes: is it worth adjusting hyperparameters to further stabilize the losses? Without a criterion to test how well the images generated by $G$ match images in \(p(x)\), I do not think there is an easy answer. We do know, however, that at some point $\sigma$ will become so saturated that its gradient will become 0, despite even the modified $\hat{L}_G$ loss function. To get an idea of what a saturating $\sigma$ looks like, I trained $D$ while keeping the weights of $G$ constant (a sketch of this setup follows the figure).

Training curves from DCGAN in which the weights of $G$ are updated (orange) or held constant (red). zDim=2 for both. x axis is number of images trained on.
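A sketch of the frozen-$G$ setup (continuing the assumed PyTorch loop above):

```python
# freeze G so only D learns
for p in G.parameters():
    p.requires_grad_(False)

# then run the same training loop, skipping the G update step entirely
```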

As expected, the gradient flowing out of $D$ quickly drops to $0$, but this requires $D$ to achieve very high classification confidence (\(L_G \gtrsim 16\)). As long as the gradient flowing out of $D$ does not start to diminish, there should be a signal for $G$ to continue to learn. For both the zDim=2 and zDim=7 models I trained, neither was close to $D$ having confidence great enough to diminish the gradients. The norm of the gradient flowing out of $D$ was actually increasing in both cases, presumably due to the use of $\hat{L}_G$, which has larger gradients as $D$ gets better (up to some cutoff). I don't think I would spend more time trying to stabilize the losses unless the gradients begin to diminish or the generated images look visibly worse (suggesting the gradients are not providing a useful signal).

HOMEWORK

  • Train networks longer to see how losses and gradients evolve later in training.
  • Recreate the vanishing gradients of $L_{G} = E_{z \sim p(z)}\log\Big[ 1 - \sigma \big(D(G(z))\big) \Big]$ in training.
  • Overview literature exploring the downfalls of GAN training.
  • Train DCGAN with the exact networks used in [2].
  • Discuss results of linear network GANs.

REFERENCES

[1] Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley,
      Sherjil Ozair, Aaron Courville, and Yoshua Bengio. 2014. "Generative
      Adversarial Nets." In Advances in Neural Information Processing Systems,
      2672–2680. http://papers.nips.cc/paper/5423-generative-adversarial-nets.
[2] Radford, Alec, Luke Metz, and Soumith Chintala. 2015. "Unsupervised
      Representation Learning with Deep Convolutional Generative Adversarial
      Networks." arXiv preprint arXiv:1511.06434.
[3] Chintala, Soumith, et al. "How to Train a GAN? Tips and Tricks to Make
      GANs Work." https://github.com/soumith/ganhacks.