How to train your GAN
Strategies for training generative models
If you've heard about algorithms that generate photorealistic images from scratch (human faces wearing fake expressions, daytime photos that look like night, cat sketches that are turned into realistic cat images), you've heard of GANs.
Generative Adversarial Networks, which I described a couple posts ago while trying to generate real-looking faces from random vectors of numbers, have huge potential. But there's a catch, of course. They're fiendishly hard to train.
In this post, I'll write a bit about my experience training them, trying out the latest GAN-training fad diets. I'll also link to better resources to help you train your GAN.
Why are GANs so awful?
Depending on your GAN's hyperparameters, your GAN might train flawlessly. More likely, though, you'll fall victim to one of the well-documented GAN failure modes, and your model won't learn anything.
You might get stuck with mode collapse. Your generator's job is to produce outputs that fool the discriminator. It has a little conversation in its head that goes like this:
Woah! This photo is easy to draw but the discriminator is totally fooled by it. Sucker! My work here is done! I'll just output this image, every single time, regardless of the random vector I've been sent.
⏰ Five minutes later
Fuck! Looks like the discriminator caught on about this image being fake! Gotta make the image a little different and keep pumping it out.
Because the generator doesn't output diverse samples, the discriminator doesn't learn much that's useful, except to flag as fake the specific image ("mode") that the generator is currently outputting. I don't have a solid mathematical intuition for what exactly triggers the generator to start producing a bunch of very similar images, and I'd like to know.
There are plenty of supposed solutions to mode collapse, and I'll talk about them in a bit.
Another common failure mode is vanishing gradients. This one's pretty simple: the discriminator gets really good, faster than the generator, and is able to mark the generator's output as fake with super-high probability. If your discriminator is marking fakes as 99.9999% fake, there's only a tiny sliver of "realness." The generator is trained to increase the "realness" of its samples according to the discriminator, so if all we have is a tiny sliver of realness, the gradients are tiny and the training signal is weak. We end up in a sad, sad world where the poor generator doesn't have the feedback (in the form of gradients) to learn much of anything. You'll see tiny discriminator loss, and huge generator loss.
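To put rough numbers on that "tiny sliver," here is a minimal sketch. It assumes the original minimax generator loss, log(1 - D(G(z))), with D being a sigmoid over a raw discriminator score (a logit). Under that assumption the gradient of the loss with respect to the logit works out to exactly -D(G(z)), so it shrinks along with the discriminator's belief that the sample is real. The function names are made up for illustration.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_generator_loss(logit):
    # Original minimax generator objective for one fake sample: log(1 - D(G(z))).
    return math.log(1.0 - sigmoid(logit))

def saturating_generator_grad(logit):
    # d/d(logit) of log(1 - sigmoid(logit)) simplifies to -sigmoid(logit).
    return -sigmoid(logit)

# The more confidently the discriminator says "fake" (very negative logit),
# the smaller the gradient that flows back toward the generator.
for logit in [0.0, -2.0, -6.0, -14.0]:
    realness = sigmoid(logit)
    grad = saturating_generator_grad(logit)
    print(f"D(G(z)) = {realness:.7f}   gradient wrt logit = {grad:.7f}")
```

Other generator losses behave differently; this only illustrates the formulation assumed above.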
WGAN to the rescue
In early 2017, three researchers published a paper introducing the Wasserstein GAN, claiming it alleviated all mode-collapse issues and made training GANs significantly more stable.
Most of the paper is spent establishing a theoretical model of why GANs are difficult to train. In GANs, we attempt to train the generator to match the probability distribution of the real data. To accomplish this with gradient descent, we need to minimize some "distance metric" between the generator's current output distribution and the real data distribution. The WGAN paper interprets the traditional GAN algorithm as minimizing the Jensen-Shannon divergence (a symmetrized relative of KL divergence), which stops changing smoothly once the two distributions barely overlap. At that point it offers almost no useful gradient, which makes training really difficult.
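A toy comparison makes the problem concrete. The sketch below is my own illustration, not code from the paper. It puts all of the "real" probability mass on one grid cell and all of the "fake" mass on another, then measures the gap three ways: KL divergence, Jensen-Shannon divergence, and the earth-mover distance that the paper proposes instead. The helper names and the 8-cell grid are assumptions.

```python
import math

def kl(p, q):
    # KL(p || q) for discrete distributions; infinite if q misses mass that p has.
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:
            if qi == 0.0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total

def js(p, q):
    # Jensen-Shannon divergence: average KL of p and q to their midpoint mixture.
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def earth_mover_1d(p, q):
    # On a 1-D grid with unit spacing, earth-mover distance is the sum of
    # absolute differences between the two cumulative distributions.
    cdf_gap, total = 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        total += abs(cdf_gap)
    return total

def point_mass(position, size=8):
    return [1.0 if i == position else 0.0 for i in range(size)]

real = point_mass(0)
for shift in [1, 3, 6]:
    fake = point_mass(shift)
    print(shift, kl(real, fake), round(js(real, fake), 4), earth_mover_1d(real, fake))
```

KL is infinite and JS sits at log 2 no matter how far apart the two spikes are, so neither tells the generator which way to move. Only the earth-mover distance shrinks as the fake spike approaches the real one.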
They suggest a different metric, "earth-mover distance," which is continuous everywhere and differentiable almost everywhere, and offers good gradients. The incredible thing about the paper is that changing the GAN algorithm to be equivalent to minimizing earth-mover distance requires just three weird tricks:
Rather than outputting classification probabilities (using something like softmax cross-entropy real/fake), the discriminator should output unbounded real numbers. Train this discriminator (they call it a critic instead) to return a large positive number for real inputs and a large negative number for fake inputs.
Clip the critic's weights after each training iteration. The paper's authors suggest clipping weights to [-0.01, 0.01].
Rather than trying to balance generator/discriminator training, just train the critic close to convergence before each generator update (the paper runs five critic iterations per generator iteration). The critic should still give good gradients, even if it's really strong.
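The three tricks can be sketched without any deep-learning framework. This is a toy under loud assumptions, not the paper's implementation: the critic is a single-weight linear function f(x) = w * x, the generator just shifts its noise by a learned offset theta, and plain gradient steps stand in for a real optimizer. Only the clipping range comes from the text above; every name here is hypothetical.

```python
def mean(xs):
    return sum(xs) / len(xs)

def critic_loss(real_scores, fake_scores):
    # Trick 1: raw scores, no probabilities. The critic wants high scores on
    # real data and low scores on fakes, written here as a loss to minimize.
    return mean(fake_scores) - mean(real_scores)

def generator_loss(fake_scores):
    # The generator wants the critic to score its samples highly.
    return -mean(fake_scores)

def clip_weights(weights, c=0.01):
    # Trick 2: clamp every critic weight into [-c, c] after each update.
    return [max(-c, min(c, w)) for w in weights]

def wgan_round(w, theta, real, noise, n_critic=5, lr=0.05, c=0.01):
    # Trick 3: several critic updates for every single generator update.
    for _ in range(n_critic):
        fake = [z + theta for z in noise]
        grad_w = mean(fake) - mean(real)   # d critic_loss / d w when f(x) = w * x
        (w,) = clip_weights([w - lr * grad_w], c)
    grad_theta = -w                        # d generator_loss / d theta
    theta = theta - lr * grad_theta
    return w, theta

w, theta = wgan_round(0.0, 0.0, real=[1.8, 2.0, 2.2], noise=[-0.2, 0.0, 0.2])
print(w, theta)
```

In this toy the critic weight hits the clipping bound almost immediately, so the generator only creeps toward the real data.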
I played around with implementing a WGAN to generate synthetic samples based on MNIST digits and CelebA human faces.
Generating MNIST digits was a breeze. Without any hyperparameter tuning, WGAN generated great samples: [📸 See the code]
CelebA didn't work too well. The generator didn't collapse and give terrible results (the sample quality clearly improved over time), but training was incredibly slow, far slower than a DCGAN. Here's what I got before I stopped training:
Least-squares GAN
Another paper promising an improved GAN model is "Least Squares Generative Adversarial Networks." This post by Augustinus Kristiani gives a good overview of the paper, and makes some bold claims: that it's as stable as WGAN but not nearly as slow, and also generates higher-quality samples.
The basic idea is even simpler than WGAN's: rather than using softmax cross-entropy classification loss in the discriminator, use least-squares predictions instead (i.e. train the discriminator to output 1 for real samples and -1 for fake samples, and train using L2 loss). This makes sense because it forces the discriminator to output reasonably-sized numbers like -1 and 1, no matter how "good" or "confident" it is, which should give better gradients than a discriminator trained to convergence using softmax cross-entropy.
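As a rough sketch of those losses, using the 1 / -1 targets described above: one assumption here is that the generator is trained to push its fake scores toward the "real" target of 1. The LSGAN paper discusses more than one choice of targets, so treat this as one plausible variant. The function names are made up.

```python
def mean(xs):
    return sum(xs) / len(xs)

def lsgan_discriminator_loss(real_scores, fake_scores, real_target=1.0, fake_target=-1.0):
    # L2 loss pulling real scores toward 1 and fake scores toward -1.
    real_term = mean([(s - real_target) ** 2 for s in real_scores])
    fake_term = mean([(s - fake_target) ** 2 for s in fake_scores])
    return 0.5 * (real_term + fake_term)

def lsgan_generator_loss(fake_scores, real_target=1.0):
    # Assumed generator objective: make fake scores look like the real target.
    return 0.5 * mean([(s - real_target) ** 2 for s in fake_scores])

def lsgan_score_gradient(score, target=1.0):
    # d/d(score) of 0.5 * (score - target)^2 is linear in the distance to the target.
    return score - target

# Even when the discriminator is "perfect" and scores a fake at exactly -1,
# the generator still gets a healthy gradient instead of a vanishing one.
print(lsgan_generator_loss([-1.0]), lsgan_score_gradient(-1.0))
```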
I wasn't able to achieve particularly miraculous results using LSGAN on CelebA. I needed a bit of hyperparameter optimization to get anywhere reasonable, and my generator still tended to "collapse" occasionally and stop outputting good images.
That being said, the recent (and super cool) paper "Unsupervised Image-to-Image Translation Networks" uses least-squares loss instead of the traditional GAN formulation, so it's clearly useful.
Improved WGAN (with Gradient Penalty)
A recent paper finds theoretical issues with one of the tricks WGAN uses to work. WGAN uses "weight clipping" to enforce a "Lipschitz constraint" on the critic (I have no idea what this means). The paper suggests that weight clipping is a suboptimal way to enforce Lipschitz-ness, and ends up biasing the critic towards simpler models of the true distribution. Instead of clipping weights, they suggest augmenting the critic loss function with a penalty that encourages the norm of the critic's gradient with respect to its input images to be close to 1.
Implementing the gradient penalty isn't that difficult. I was able to get an improved WGAN working on MNIST pretty quickly, and here's the code for that.
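Separately from the linked code, here is a minimal, framework-free sketch of the penalty term itself. It makes several assumptions for the sake of a runnable toy. The critic is a tiny hand-written function, tanh(w . x), so its gradient with respect to the input can be written analytically (a real implementation would get this from automatic differentiation). The penalty is evaluated at random points between paired real and fake samples and weighted by a coefficient of 10, as the paper does. All names are hypothetical.

```python
import math
import random

def critic(x, w):
    # Toy critic: tanh of a dot product between weights and input.
    return math.tanh(sum(wi * xi for wi, xi in zip(w, x)))

def critic_input_grad(x, w):
    # Analytic gradient of tanh(w . x) with respect to x: (1 - tanh^2(w . x)) * w.
    scale = 1.0 - critic(x, w) ** 2
    return [scale * wi for wi in w]

def gradient_penalty(real_batch, fake_batch, w, lam=10.0, rng=random):
    total = 0.0
    for x_real, x_fake in zip(real_batch, fake_batch):
        eps = rng.random()  # random point on the line between a real and a fake sample
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(x_real, x_fake)]
        grad = critic_input_grad(x_hat, w)
        norm = math.sqrt(sum(g * g for g in grad))
        total += (norm - 1.0) ** 2  # penalize gradient norms that drift away from 1
    return lam * total / len(real_batch)

real = [[1.0, 2.0], [0.5, 1.5]]
fake = [[0.0, 0.1], [-0.2, 0.3]]
print(gradient_penalty(real, fake, w=[0.3, -0.4], rng=random.Random(0)))
```

This value would be added to the usual critic loss, replacing the weight-clipping step.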
Other tricks
There are lots and lots of people offering strategies for improving GAN training and stability. Many of these work for WGAN and LSGAN as well, but may not be as useful, given those algorithms' stability promises.
Here are some I find interesting:
Store a "replay buffer" of previous generator outputs. Occasionally, rather than training the discriminator on the latest generator outputs, train it on some old generator outputs. It should still know they're fake!
Use optimizers other than Adam (which is usually my go-to) in the discriminator. Apparently, momentum might cause instability.
"Principled" attempts to balance discriminator and generator strength. I've tried keeping track of discriminator accuracy and pausing discriminator training while its accuracy is > 80%, giving the generator a chance to "catch up."
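Two of these tricks are easy to sketch in plain Python. Everything below is a hypothetical illustration rather than code from any of the papers. The buffer capacity, the replay probability, and the random-replacement policy are all assumptions. The accuracy gate is written so that the discriminator sits out a step whenever its recent accuracy is above a threshold, which is what lets the generator catch up.

```python
import random

class ReplayBuffer:
    """Keeps a bounded pool of past generator outputs."""

    def __init__(self, capacity=50, rng=None):
        self.capacity = capacity
        self.items = []
        self.rng = rng or random.Random()

    def push(self, sample):
        if len(self.items) < self.capacity:
            self.items.append(sample)
        else:
            # Buffer is full: overwrite a random stored sample.
            self.items[self.rng.randrange(self.capacity)] = sample

def discriminator_fakes(buffer, fresh_fakes, replay_prob=0.5):
    # Build the fake half of a discriminator batch, sometimes swapping a fresh
    # generator output for one that was stored earlier.
    batch = []
    for fake in fresh_fakes:
        if buffer.items and buffer.rng.random() < replay_prob:
            batch.append(buffer.rng.choice(buffer.items))
        else:
            batch.append(fake)
        buffer.push(fake)
    return batch

def should_train_discriminator(recent_accuracy, threshold=0.8):
    # Skip the discriminator update while it is already winning comfortably.
    return recent_accuracy <= threshold

buffer = ReplayBuffer(capacity=3, rng=random.Random(0))
print(discriminator_fakes(buffer, ["img_a", "img_b", "img_c", "img_d"]))
print(should_train_discriminator(0.95), should_train_discriminator(0.6))
```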
There are plenty of GAN-training resources that provide more (and better motivated) tricks than these. Here are some of my favorites:
GAN Hacks by Soumith Chintala
Improved GANs from OpenAI












