Intro
Generative Adversarial Networks (GANs) are very popular nowadays and show exciting results. But anyone who has worked with GANs knows that the training process is not simple. At the same time, there are a lot of papers describing techniques for training such networks.
I want to write about an approach to training GANs that uses an ensemble of networks. The article is based on "SGAN: An Alternative Training of Generative Adversarial Networks" (CVPR 2018).
SGAN aims at:
- faster convergence;
- reducing "mode collapse";
- more realistic generated data.
Let's briefly take a look at popular paradigms for getting a more stable training process and more realistic generated data:
- The vanilla loss function for GAN is based on the Kullback-Leibler divergence, which tries to make the distribution of generated data match the real data distribution. But it has a drawback: it is not symmetric. When the real distribution p(x) is close to zero and the distribution of generated data q(x) is non-zero, the behaviour of q(x) in that region is simply ignored, and we can end up with wrong results. As a solution to this problem we can use the Jensen-Shannon divergence, which is symmetric and smoother (both definitions are recalled right after this list). Check this post by Lil'Log for more detailed information. Still, switching to this divergence does not work as well as the Earth Mover's distance, better known as the Wasserstein distance. There are a lot of papers that explain and use the Wasserstein distance as a GAN loss function (Arjovsky et al. 2017; Gulrajani et al. 2017).
- Unfortunately, WGAN with weight clipping still does not provide stable training. A useful addition to WGAN is the gradient penalty: in Gulrajani et al. 2017 the authors discuss replacing weight clipping with a gradient penalty to get a more stable training process and faster convergence (see the code sketch after this list).
- Another approach is DCGAN (deep convolutional generative adversarial network), whose architectural guidelines also help stabilize training.
- Finally, there are multi-network GAN methods. One of them proposes to train multiple generators against a single discriminator, where all generators share their parameters except for the last one; to keep diversity between the generators, an additional penalty term based on a similarity function is added. Another one also uses multiple generators against a single discriminator, but adds a classifier that takes a generator's output and predicts which generator produced the given fake input; the classifier output is used as an additional penalty term to keep diversity between the generators.
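As a quick reference, the two divergences mentioned above are defined as follows (these are the standard textbook definitions, nothing specific to the SGAN paper):

```latex
D_{\mathrm{KL}}(p \,\|\, q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx,
\qquad
D_{\mathrm{JS}}(p \,\|\, q) = \tfrac{1}{2} D_{\mathrm{KL}}\!\left(p \,\middle\|\, \tfrac{p+q}{2}\right)
                            + \tfrac{1}{2} D_{\mathrm{KL}}\!\left(q \,\middle\|\, \tfrac{p+q}{2}\right)
```

Wherever p(x) is close to zero, the integrand p(x) log(p(x)/q(x)) is close to zero no matter how q(x) behaves there, which is exactly the asymmetry problem described above; the Jensen-Shannon divergence is symmetric by construction.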
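For readers who want to see what the gradient penalty looks like in code, here is a minimal PyTorch sketch (PyTorch because that is what my implementation below uses; the function and variable names are mine, and device handling is kept trivial):

```python
import torch

def gradient_penalty(D, real, fake, lambda_gp=10.0):
    """Push the critic's gradient norm towards 1 at random points
    sampled between real and generated batches (WGAN-GP style)."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims.
    alpha = torch.rand([batch_size] + [1] * (real.dim() - 1), device=real.device)
    interpolated = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_out = D(interpolated)
    grads = torch.autograd.grad(outputs=d_out, inputs=interpolated,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True)[0]
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()
```

The returned term is simply added to the critic (discriminator) loss; lambda_gp = 10 is the value used in the WGAN-GP paper.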
SGAN follows the last approach discussed above in the sense that there is more than one generator, but the overall idea is quite different. SGAN proposes to create several local Discriminator-Generator pairs. They exist independently, i.e. they do not share trainable variables and can be trained without influencing each other. At the same time, there is a global Discriminator-Generator pair that updates its parameters using the local pairs.
The authors say that the main idea of this multi-network setup is to maintain statistical independence between the individual pairs; the non-shared parameters help prevent "mode collapse" and keep a single degraded pair from affecting the training of the global pair.
Method
The N local pairs are formed by the Generators G = {G_1, …, G_N} and the Discriminators D = {D_1, …, D_N}, which are trained individually. The global pair (D_0, G_0) is trained using the local pairs.
- Local pairs are trained independently, following the standard GAN procedure.
- The Discriminator D_0 from the global pair is trained against each already trained Generator G_i from the local pairs, where i = 1, …, N.
- Similarly, the Generator G_0 from the global pair is trained against each Discriminator D_i from the local pairs, BUT instead of using the real D_i, each D_i is copied into another network D_i^msg, called a "messenger discriminator", and G_0 is trained against D_i^msg. The D_i^msg are re-created from the current D_i at every iteration, so training G_0 never touches the local discriminators.
The algorithm parallelizes well.
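To make the three steps above concrete, here is a rough PyTorch sketch of one SGAN training iteration. It reflects my reading of the procedure, not the paper's exact algorithm: the helper functions, the choice of the vanilla BCE loss and the omitted device handling are all my own simplifications.

```python
import copy
import torch

def d_step(D, G, real, opt_D, criterion, z_dim):
    """One discriminator update against a (frozen) generator."""
    opt_D.zero_grad()
    z = torch.randn(real.size(0), z_dim)
    fake = G(z).detach()                       # the generator is not updated here
    # Assumes D outputs one logit per sample, shape (batch, 1).
    loss = criterion(D(real), torch.ones(real.size(0), 1)) + \
           criterion(D(fake), torch.zeros(real.size(0), 1))
    loss.backward()
    opt_D.step()

def g_step(G, D, batch_size, opt_G, criterion, z_dim):
    """One generator update against a (frozen) discriminator."""
    opt_G.zero_grad()
    z = torch.randn(batch_size, z_dim)
    loss = criterion(D(G(z)), torch.ones(batch_size, 1))
    loss.backward()
    opt_G.step()                               # only G's parameters are stepped

def sgan_iteration(local_pairs, global_pair, real, criterion, z_dim=100):
    """One iteration: local pairs first, then the global pair (D_0, G_0)."""
    D0, G0, opt_D0, opt_G0 = global_pair

    # 1. Local pairs are trained independently, as ordinary GANs.
    for D_i, G_i, opt_Di, opt_Gi in local_pairs:
        d_step(D_i, G_i, real, opt_Di, criterion, z_dim)
        g_step(G_i, D_i, real.size(0), opt_Gi, criterion, z_dim)

    # 2. The global discriminator D_0 is trained against every local generator.
    for _, G_i, _, _ in local_pairs:
        d_step(D0, G_i, real, opt_D0, criterion, z_dim)

    # 3. The global generator G_0 is trained against "messenger" copies of the
    #    local discriminators, so the local D_i themselves stay untouched.
    for D_i, _, _, _ in local_pairs:
        D_msg = copy.deepcopy(D_i)             # re-created at every iteration
        g_step(G0, D_msg, real.size(0), opt_G0, criterion, z_dim)
```

A pair here is just a tuple (D, G, opt_D, opt_G); for example, with five local pairs (the number I use in the experiments below) and any Generator/Discriminator modules, such as the fully connected ones sketched later in the post:

```python
import torch.nn as nn

def make_pair(z_dim=100, lr=2e-4):
    D, G = Discriminator(), Generator(z_dim)   # hypothetical module classes
    return (D, G,
            torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999)),
            torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999)))

local_pairs = [make_pair() for _ in range(5)]
global_pair = make_pair()
criterion = nn.BCEWithLogitsLoss()
# for real, _ in dataloader:
#     sgan_iteration(local_pairs, global_pair, real.view(real.size(0), -1), criterion)
```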
The loss functions tested in the paper are the vanilla GAN loss (binary cross-entropy), the Wasserstein distance (WGAN) and WGAN with gradient penalty. They also look at DRAGAN (Deep Regret Analytic GAN), a different gradient penalty scheme that enforces the constraint on the Discriminator's gradients only in local regions around the real data points.
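To illustrate how this differs from WGAN-GP, a DRAGAN-style penalty interpolates between real samples and slightly perturbed real samples rather than towards generated ones, so the constraint only acts around the real data. A tiny sketch reusing the gradient_penalty helper from earlier (the 0.5 * std noise scale is a commonly used value, an assumption on my side rather than a detail taken from the SGAN paper):

```python
def dragan_penalty(D, real, lambda_gp=10.0):
    # Perturb the real batch and penalize gradients only in that neighbourhood.
    noisy = real + 0.5 * real.std() * torch.rand_like(real)
    return gradient_penalty(D, real, noisy, lambda_gp)
```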
As network architectures they used a Deep Convolutional GAN (DCGAN) and, for the Wasserstein distance experiments, a simpler network of 4 fully connected layers.
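For intuition, a 4-layer fully connected pair could look something like the sketch below; the layer widths, activations and the 28x28 image size are my guesses for an MNIST-scale setup, not the exact configuration from the paper.

```python
import torch.nn as nn

class Generator(nn.Module):
    """Maps a noise vector to a flattened image in [-1, 1]."""
    def __init__(self, z_dim=100, img_dim=28 * 28):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, img_dim), nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """Maps a flattened image to a single raw score (logit)."""
    def __init__(self, img_dim=28 * 28):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        return self.net(x)
```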
Discussion
The authors showed that such a setup, where N local pairs are trained individually alongside the global pair, "exhibits higher stability and faster convergence speed": SGAN uses the local models as "supervising" models and prevents one degraded pair from affecting all the others. There is, however, a trade-off between the number of local pairs and the required computational resources.
Inception Score (IS), Fréchet Inception Distance (FID) and Entropy are commonly used metrics for evaluating GAN results. More details about these metrics can be found in Xu et al. 2018.
Their results on different datasets show that SGAN outperforms existing GAN approaches:
You can find my PyTorch implementation here. My experiments showed that generated images start to look realistic after fewer iterations (epochs) than with a vanilla GAN.
Let’s take a look. I trained 5 local pairs.
MNIST
After the 1st epoch (batch size 128) the local pairs perform worse than the global pair:
After the 14th epoch:
CelebA
After the 1st epoch (batch size 128):
After the 2nd epoch:
P.S. I first tried to write this in TensorFlow, but it was difficult to find a way to train several networks, copy the trainable variable values of one network into another, and still train those networks independently. After some investigation I found a way, and in this post I share the instructions for that kind of training.
A useful resource that helps with training GANs: https://github.com/soumith/ganhacks