Strategies for training generative models
If you've heard about algorithms that generate photorealistic images from scratch (human faces wearing fake expressions, daytime photos that look like night, cat sketches turned into realistic cat images), you've heard of GANs.
Generative Adversarial Networks, which I described a couple of posts ago while trying to generate real-looking faces from random vectors of numbers, have huge potential. But there's a catch, of course: they're fiendishly hard to train.
In this post, I'll write a bit about my experience training them and trying out the latest GAN-training fad diets. I'll also link to better resources to help you train your GAN.
Depending on your GAN's hyperparameters, it might train flawlessly. More likely, though, you'll fall victim to one of GANs' well-documented failure modes, and you won't learn anything.
You might get stuck with mode collapse. Your generator's job is to produce outputs that fool the discriminator. It has a little conversation in its head that goes like this:
Woah! This photo is easy to draw, but the discriminator is totally fooled by it. Sucker! My work here is done! I'll just output this image, every single time, regardless of the random vector I've been sent.
Five minutes later:
Fuck! Looks like the discriminator caught on about this image being fake! Gotta make the image a little different and keep pumping it out.
Because the generator doesn't output diverse samples, the discriminator doesn't learn anything very useful, except to flag the specific image (the "mode") that the generator is currently outputting as fake. I don't have a solid mathematical intuition for what exactly triggers the generator to start producing a bunch of very similar images, and I'd like to know.
There are plenty of supposed solutions to mode collapse, and I'll talk about them in a bit.
Another common failure mode is vanishing gradients. This one's pretty simple: the discriminator gets really good faster than the generator does, and is able to mark the generator's output as fake with super-high probability. If your discriminator is marking fakes as 99.9999% fake, there's only a tiny sliver of "realness" left. The generator is trained to increase the "realness" of its samples according to the discriminator, so if all we have is a tiny sliver of realness, the gradients are tiny and the training signal is weak. We end up in a sad, sad world where the poor generator doesn't have the feedback (in the form of gradients) to learn much of anything. You'll see a tiny discriminator loss and a huge generator loss.
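To put a number on "tiny gradients," here is a minimal sketch in plain Python (my own illustration, not code from any paper). It assumes the discriminator ends in a sigmoid, so D(G(z)) = sigmoid(logit), and that the generator minimizes log(1 - D(G(z))) as in the original minimax formulation. For contrast it also computes the gradient of the commonly used "non-saturating" alternative, -log D(G(z)). The helper names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit_for(prob):
    # Inverse sigmoid: the raw discriminator output that yields D(x) = prob.
    return math.log(prob / (1.0 - prob))


def saturating_grad(logit):
    # Generator term of the original minimax loss: log(1 - D(G(z))).
    # d/dlogit log(1 - sigmoid(logit)) = -sigmoid(logit) = -D(G(z)).
    return -sigmoid(logit)


def non_saturating_grad(logit):
    # Common alternative generator loss: -log(D(G(z))).
    # d/dlogit -log(sigmoid(logit)) = -(1 - sigmoid(logit)).
    return -(1.0 - sigmoid(logit))


# The discriminator says this fake is 99.9999% fake, i.e. D(G(z)) = 1e-6.
confident_logit = logit_for(1e-6)
print(f"saturating gradient:     {saturating_grad(confident_logit):.2e}")
print(f"non-saturating gradient: {non_saturating_grad(confident_logit):.2e}")
```

With D(G(z)) = 1e-6, the first gradient comes out around -1e-6 while the second is close to -1: the same fake image yields a roughly million-fold weaker learning signal under the saturating loss.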
In early 2017, three researchers published a paper introducing the Wasserstein GAN (WGAN), reporting no evidence of mode collapse in any of their experiments and claiming it made training GANs significantly more stable.
Most of the paper is spent establishing a theoretical model of why GANs are difficult to train. In GANs, we attempt to train the generator to match the probability distribution of the real data. To accomplish this with gradient descent, we need to minimize some "distance metric" between the generator's current output distribution and the real data distribution. The WGAN paper analyzes the traditional GAN algorithm as effectively minimizing the Jensen-Shannon divergence (a symmetrized relative of KL divergence). When the generated and real distributions barely overlap, which the paper argues is common for high-dimensional data like images, that divergence sits at a constant value and provides essentially no useful gradient, making it really difficult to minimize successfully.
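The WGAN paper makes this concrete with a toy example; below is my own minimal reconstruction of it in plain Python, not the authors' code. Put all of the real data at the point 0 and all of the generated data at the point theta. The Jensen-Shannon divergence (the quantity the original GAN objective corresponds to when the discriminator is optimal) is stuck at log 2 for every nonzero theta, so it says nothing about which direction to move the generator, whereas the earth mover distance is simply |theta|. Distributions are represented here as hypothetical dicts mapping points to probabilities.

```python
import math


def js_divergence(p, q):
    # Jensen-Shannon divergence (natural log) between two discrete
    # distributions given as {point: probability} dicts.
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}

    def kl(a, b):
        # KL(a || b); points where a has zero mass contribute nothing.
        return sum(a[x] * math.log(a[x] / b[x]) for x in a if a[x] > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def earth_mover_1d(p, q):
    # In 1D, the earth mover (Wasserstein-1) distance is the area
    # between the two cumulative distribution functions.
    points = sorted(set(p) | set(q))
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for left, right in zip(points, points[1:]):
        cdf_p += p.get(left, 0.0)
        cdf_q += q.get(left, 0.0)
        total += abs(cdf_p - cdf_q) * (right - left)
    return total


real = {0.0: 1.0}
for theta in [0.5, 1.0, 2.0, 4.0]:
    fake = {theta: 1.0}
    print(theta, js_divergence(real, fake), earth_mover_1d(real, fake))
```

The JS column stays at log 2 (about 0.693) no matter how close or far the fake point is, while the earth mover column shrinks smoothly as theta approaches 0, which is exactly the kind of signal gradient descent needs.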
They suggest a different metric, the "earth mover distance" (also called the Wasserstein-1 distance), which under mild assumptions is continuous everywhere and differentiable almost everywhere, and so offers good gradients. The incredible thing about the paper is that changing the GAN algorithm to minimize earth mover distance requires just three weird tricks:
rather than outputting classification probabilities (using something like softmax cross-entropy for real/fake), the discriminator should output unbounded scores. Train this discriminator (they call it a critic instead) to return high scores for real inputs and low, very negative scores for fake inputs; concretely, it maximizes the average score on real samples minus the average score on fake samples.
clip the critic's weights after each training iteration; the paper's authors suggest clipping them to [-0.01, 0.01].
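rather than trying to balance generator/discriminator training, train the critic much harder than the generator, ideally close to convergence before each generator update; the critic should still give good gradients, even if it's really strong. (The paper's algorithm runs 5 critic updates per generator update by default.)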
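The three tricks are easy to see in a deliberately tiny setting. This is a minimal sketch under heavy simplifying assumptions that are mine, not the paper's: 1D real data drawn from N(3, 1), a generator that just shifts standard normal noise by a learnable offset `theta`, a linear critic f(x) = w * x, and plain SGD instead of the RMSProp optimizer the paper uses. The learning rates, batch size, and step count are arbitrary picks for this toy; only the clip value of 0.01 and the 5 critic steps per generator step follow the paper's defaults.

```python
import random


def mean(xs):
    return sum(xs) / len(xs)


def train_toy_wgan(steps=300, n_critic=5, clip=0.01, lr_critic=0.05,
                   lr_gen=5.0, batch=64, real_mean=3.0, seed=0):
    """Toy 1D WGAN: linear critic f(x) = w * x, generator G(z) = z + theta."""
    rng = random.Random(seed)
    w, theta = 0.0, 0.0
    for _ in range(steps):
        # Trick 3: train the critic several times per generator update.
        for _ in range(n_critic):
            real = [rng.gauss(real_mean, 1.0) for _ in range(batch)]
            fake = [rng.gauss(0.0, 1.0) + theta for _ in range(batch)]
            # Trick 1: no sigmoid, no log. The critic maximizes
            # mean f(real) - mean f(fake); for this linear critic the
            # gradient w.r.t. w is mean(real) - mean(fake).
            w += lr_critic * (mean(real) - mean(fake))
            # Trick 2: clip the critic's weight into [-clip, clip].
            w = max(-clip, min(clip, w))
        # The generator minimizes -mean f(G(z)) = -w * (mean(z) + theta),
        # whose gradient w.r.t. theta is -w; take a descent step.
        theta -= lr_gen * (-w)
    return theta, w


theta, w = train_toy_wgan()
print(f"theta = {theta:.2f} (real mean is 3.0), critic weight w = {w:+.4f}")
```

Because the critic is linear, clipping `w` is precisely what bounds how steep the critic can be; in a real network the same clipping is applied to every weight tensor after each critic step. A linear critic can only notice a difference in means, which happens to be the only thing this toy generator can fix.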
I played around with implementing a WGAN to generate synthetic samples based on MNIST digits and CelebA human faces.
Generating MNIST digits was a breeze: without any hyperparameter tuning, WGAN generated great samples. [See the code]
CelebA didn't work too well. The generator didn't collapse or give terrible results, and the sample quality clearly improved over time, but training was incredibly slow, far slower than a DCGAN. Here's what I got before I stopped training:
Another paper promising an improved GAN model is "Least Squares Generative Adversarial Networks" (LSGAN). This post by Agustinus Kristiadi gives a good overview of the paper and makes some bold claims: that it's as stable as WGAN but not nearly as slow, and that it also generates higher-quality samples.
The basic idea is even simpler than WGAN's: rather than using a softmax cross-entropy classification loss in the discriminator, use least-squares regression instead (i.e., train the discriminator to output 1 for real samples and -1 for fake samples, using an L2 loss). This makes sense because it pushes the discriminator toward reasonably-sized outputs like -1 and 1 no matter how "good" or "confident" it is, and it keeps penalizing outputs that overshoot those targets, which should give better gradients than a discriminator trained to convergence using softmax cross-entropy.
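Here is a minimal sketch of those least-squares losses in plain Python, operating on lists of raw discriminator scores (the function names are hypothetical). The 1/-1 targets follow the description above; setting the generator's target to 1 (score fakes like reals) is my assumption, and the LSGAN paper also discusses other target choices. The 1/2 factors only rescale the losses.

```python
def lsgan_discriminator_loss(real_scores, fake_scores, real_target=1.0, fake_target=-1.0):
    # Pull scores on real samples toward real_target and scores on generated
    # samples toward fake_target; each squared-error term is averaged and halved.
    real_term = sum((s - real_target) ** 2 for s in real_scores) / len(real_scores)
    fake_term = sum((s - fake_target) ** 2 for s in fake_scores) / len(fake_scores)
    return 0.5 * real_term + 0.5 * fake_term


def lsgan_generator_loss(fake_scores, gen_target=1.0):
    # The generator wants its samples to be scored like real ones.
    return 0.5 * sum((s - gen_target) ** 2 for s in fake_scores) / len(fake_scores)


def lsgan_generator_grad(score, gen_target=1.0):
    # Derivative of 0.5 * (score - gen_target) ** 2 w.r.t. one score:
    # it grows with distance from the target instead of saturating.
    return score - gen_target


# A discriminator that hits its targets exactly has zero loss...
print(lsgan_discriminator_loss([1.0, 1.0], [-1.0, -1.0]))
# ...yet the generator still gets a usable gradient on those same fakes,
print(lsgan_generator_grad(-1.0))
# and an even larger one if the discriminator overshoots to -5.
print(lsgan_generator_grad(-5.0))
```

Compare this with the sigmoid case, where a very confident "fake" verdict drives the generator's gradient toward zero: here the gradient is proportional to how far the score is from the target.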
I wasn't able to achieve particularly miraculous results using LSGAN on CelebA: I needed a bit of hyperparameter optimization to get anywhere reasonable, and my generator still tended to "collapse" occasionally and stop outputting good images.
That being said, the recent (and super cool) paper "Unsupervised Image-to-Image Translation Networks" uses a least-squares loss instead of the traditional GAN formulation, so it's clearly useful.
Improved WGAN (with Gradient Penalty)
A recent paper finds theoretical issues with one of the tricks WGAN relies on. WGAN uses weight clipping to enforce a "Lipschitz constraint" on the critic (roughly, a limit on how quickly the critic's output can change as its input changes). The paper suggests that weight clipping is a suboptimal way to enforce Lipschitz-ness, and that it ends up biasing the critic towards overly simple functions of the data. Instead of clipping weights, they suggest adding a penalty term to the critic's loss function that encourages the norm of the critic's gradient with respect to its input images to be close to 1.
Implementing the gradient penalty isn't that difficult: I was able to get an improved WGAN working on MNIST pretty quickly. Here's the code for that.
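Here is a minimal sketch of the penalty term in plain Python, assuming a critic that maps a list of floats to a single float. A real implementation would get the gradient from autograd; central finite differences stand in for it here so the example has no dependencies. The weight of 10 and the uniform random interpolation between paired real and fake samples follow the improved-WGAN paper as I understand it; `numerical_grad` and `gradient_penalty` are hypothetical helper names.

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    # Central-difference estimate of the gradient of scalar function f at point x.
    grad = []
    for i in range(len(x)):
        up = list(x)
        up[i] += h
        down = list(x)
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def gradient_penalty(critic, real_batch, fake_batch, weight=10.0, seed=0):
    # weight * mean over sample pairs of (||grad critic(x_hat)|| - 1)^2, where
    # x_hat is a random point on the line between a real and a fake sample.
    rng = random.Random(seed)
    total = 0.0
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.random()  # one uniform draw in [0, 1) per pair
        x_hat = [eps * xr + (1.0 - eps) * xf for xr, xf in zip(real, fake)]
        grad_norm = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        total += (grad_norm - 1.0) ** 2
    return weight * total / len(real_batch)


def steep_critic(x):
    return 2.0 * x[0]  # gradient norm is 2 everywhere


def unit_critic(x):
    return 0.6 * x[0] + 0.8 * x[1]  # gradient norm is 1 everywhere


real = [[1.0, 2.0], [0.5, -1.0]]
fake = [[0.0, 0.0], [3.0, 1.0]]
print(gradient_penalty(steep_critic, real, fake))  # penalized
print(gradient_penalty(unit_critic, real, fake))   # essentially zero
```

In training, this penalty is added to the usual critic loss (mean critic score on fakes minus mean critic score on reals), and the weight-clipping step is dropped.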
There are lots and lots of people offering strategies for improving GAN training and stability. Many of these work for WGAN and LSGAN as well, but may not be as useful there, given those algorithms' stability promises.
Here are some I find interesting:
Store a "replay buffer" of previous generator outputs. Occasionally, rather than training the discriminator on the latest generator outputs, train it on some old generator outputs; it should still know they're fake!
Try optimizers other than Adam (which is usually my go-to) for the discriminator: apparently, momentum might cause instability. The WGAN paper, for example, reports unstable critic training with Adam and uses RMSProp instead.
"Principled" attempts to balance discriminator and generator strength. I've tried keeping track of discriminator accuracy and pausing discriminator training while its accuracy is above 80%, giving the generator a chance to "catch up."
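The replay-buffer tip is the easiest of these to sketch. The version below is a minimal, framework-agnostic illustration; the class name, `capacity`, and `old_fraction` are hypothetical choices of mine. Each discriminator batch of fakes mixes freshly generated samples with samples replayed from earlier in training, and a bounded deque quietly drops the oldest ones.

```python
import random
from collections import deque


class ReplayBuffer:
    """Keeps recent generator outputs so the discriminator keeps seeing old fakes."""

    def __init__(self, capacity=1000, old_fraction=0.5, seed=0):
        self.samples = deque(maxlen=capacity)  # oldest samples fall off automatically
        self.old_fraction = old_fraction
        self.rng = random.Random(seed)

    def mix(self, fresh_fakes):
        """Build a batch of fakes the same size as fresh_fakes, replacing part of it
        with samples replayed from the buffer, then remember the fresh ones."""
        fresh_fakes = list(fresh_fakes)
        n_old = min(int(len(fresh_fakes) * self.old_fraction), len(self.samples))
        replayed = self.rng.sample(list(self.samples), n_old)
        batch = replayed + fresh_fakes[: len(fresh_fakes) - n_old]
        self.samples.extend(fresh_fakes)
        return batch


buffer = ReplayBuffer(capacity=4, old_fraction=0.5)
for step in range(3):
    fakes = [f"step{step}_img{i}" for i in range(2)]
    # Every item in the returned batch gets the "fake" label for the discriminator.
    print(buffer.mix(fakes))
```

With `old_fraction=0.5`, half of every fake batch is replayed once the buffer has something in it; a smaller value is closer to the "occasionally" in the tip.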
There are plenty of GAN-training resources that provide more (and better-motivated) tricks than these. Here are some of my favorites:
GAN Hacks by Soumith Chintala
Improved GANs from OpenAI