GAN Tutorial: Connecting BCE Loss to Minimax Game and Understanding Non-Saturating Loss
1. Introduction to GANs
Generative Adversarial Networks (GANs) are composed of two neural networks:
- A Generator (G): learns to generate fake samples $G(z)$ from random noise $z \sim p(z)$
- A Discriminator (D): learns to classify samples as real (from data) or fake (from the generator)
These networks are trained in a two-player minimax game.
2. Binary Cross-Entropy Loss for Discriminator and Generator
Binary Cross-Entropy (BCE)
Given a prediction $\hat{y} \in (0, 1)$ and a true label $y \in \{0, 1\}$, the BCE loss is: $\text{BCE}(y, \hat{y}) = -y \log(\hat{y}) - (1 - y) \log(1 - \hat{y})$
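As a quick sanity check, this formula matches PyTorch's built-in BCE. The snippet below is a minimal sketch; the tensor values are arbitrary illustrative numbers.

import torch
import torch.nn.functional as F

y_hat = torch.tensor([0.9, 0.2, 0.7])   # predicted probabilities
y = torch.tensor([1.0, 0.0, 1.0])       # true labels

# BCE computed directly from the formula above
manual_bce = (-y * torch.log(y_hat) - (1 - y) * torch.log(1 - y_hat)).mean()

# PyTorch's built-in BCE (averages over the batch by default)
builtin_bce = F.binary_cross_entropy(y_hat, y)

print(manual_bce.item(), builtin_bce.item())  # both ≈ 0.2284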
BCE Loss for Discriminator (D)
D is a binary classifier:
- For real data $x \sim p_{\text{data}}(x)$, label is 1
- For fake data $G(z)$, label is 0
So discriminator loss becomes:
\[\mathcal{L}_D = \mathbb{E}_{x \sim p_{\text{data}}}[-\log D(x)] + \mathbb{E}_{z \sim p(z)}[-\log(1 - D(G(z)))]\]
BCE Loss for Generator (G)
Original generator loss (from the minimax formulation): the generator tries to maximize the loss below:
\[\mathcal{L}_G^{\text{original}} = \mathbb{E}_{z \sim p(z)}[-\log(1 - D(G(z)))]\]
Maximizing this loss (i.e., maximizing the discriminator's cross-entropy on fake samples) pushes $D(G(z)) \rightarrow 1$, which is exactly what the generator wants. The problem arises early in training, when the discriminator confidently rejects fakes and $D(G(z)) \approx 0$: the loss is then close to $0$ and its gradient with respect to the generator nearly vanishes, so the generator stops improving. Therefore, we need a different way of representing the generator loss (see Section 4).
3. Minimax GAN Objective
The original GAN paper (Goodfellow et al. 2014) defines the objective as:
\[\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log (1 - D(G(z)))]\]
- $D$ maximizes this expression (equivalent to minimizing BCE loss for real/fake classification)
- $G$ minimizes the same expression (equivalent to minimizing BCE loss with label 0 for fake)
Thus, BCE losses for $D$ and $G$ naturally merge into this min-max framework.
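To make the connection explicit, write the bracketed expression as the value function $V(D, G)$; using the definitions from Section 2,
\[
V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log (1 - D(G(z)))] = -\mathcal{L}_D,
\]
so $\max_D V(D, G)$ is the same problem as $\min_D \mathcal{L}_D$, and $\min_G V(D, G)$ is the same problem as $\max_G \mathcal{L}_G^{\text{original}}$ (the generator term of $V$ differs from $\mathcal{L}_G^{\text{original}}$ only by a sign).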
4. Saturation Problem and Non-Saturating Generator Loss
Saturation
If $D(G(z)) \to 0$, then $\log(1 - D(G(z))) \to 0$ and the gradient $\nabla_G \mathcal{L}_G \to 0$.
This means:
The generator receives no learning signal when it’s weak (bad at fooling $D$). This is called loss saturation.
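This can be made precise by writing the discriminator output as $D(G(z)) = \sigma(a)$, where $a$ denotes the discriminator's pre-sigmoid logit (a notational assumption for this derivation):
\[
\frac{\partial}{\partial a} \log(1 - \sigma(a)) = -\sigma(a) \;\to\; 0 \quad \text{as } D(G(z)) = \sigma(a) \to 0,
\]
whereas
\[
\frac{\partial}{\partial a} \log \sigma(a) = 1 - \sigma(a) \;\to\; 1 \quad \text{as } \sigma(a) \to 0.
\]
So the original loss gives its weakest gradient exactly when the generator is failing, while the alternative introduced next gives its strongest.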
Non-Saturating Loss (Practical Generator Loss)
Instead of minimizing:
\[\mathbb{E}_{z} [\log(1 - D(G(z)))]\]
We minimize:
\[\mathcal{L}_G^{\text{non-saturating}} = -\mathbb{E}_{z} [\log D(G(z))]\]
This avoids saturation and gives large gradients when $D(G(z)) \approx 0$.
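A quick numerical check of this claim (a minimal sketch; the logit value $-4.0$ is an arbitrary choice representing a discriminator that confidently rejects a fake, $D(G(z)) \approx 0.018$):

import torch

# Pre-sigmoid logit of the discriminator on a fake sample; requires grad
# so we can inspect the gradient the generator would receive.
logit = torch.tensor(-4.0, requires_grad=True)

# Saturating (original) generator objective: minimize log(1 - D(G(z)))
loss_saturating = torch.log(1 - torch.sigmoid(logit))
grad_sat = torch.autograd.grad(loss_saturating, logit)[0]

# Non-saturating generator objective: minimize -log(D(G(z)))
loss_non_saturating = -torch.log(torch.sigmoid(logit))
grad_ns = torch.autograd.grad(loss_non_saturating, logit)[0]

print(f"D(G(z)) = {torch.sigmoid(logit).item():.4f}")
print(f"saturating gradient:     {grad_sat.item():.4f}")   # ~ -0.018 (nearly zero)
print(f"non-saturating gradient: {grad_ns.item():.4f}")    # ~ -0.982 (strong signal)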
5. PyTorch Code: Training a GAN on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Generator Network
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # Output in [-1, 1]
        )

    def forward(self, z):
        return self.model(z)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

# Setup
z_dim = 100
img_dim = 28 * 28
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = Generator(z_dim, img_dim).to(device)
D = Discriminator(img_dim).to(device)
opt_G = optim.Adam(G.parameters(), lr=2e-4)
opt_D = optim.Adam(D.parameters(), lr=2e-4)
criterion = nn.BCELoss()

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # Scale to [-1, 1]
])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(20):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.view(-1, img_dim).to(device)
        batch_size = real_imgs.size(0)

        # Labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator
        z = torch.randn(batch_size, z_dim).to(device)
        fake_imgs = G(z)
        real_preds = D(real_imgs)
        fake_preds = D(fake_imgs.detach())  # detach() avoids computing gradients for G
        d_loss = criterion(real_preds, real_labels) + criterion(fake_preds, fake_labels)

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # Train Generator (non-saturating loss)
        # Note: we do NOT use torch.no_grad() here,
        # because we want to compute gradients w.r.t. G.
        # Although D's weights won't be updated, gradients must flow through D(G(z)).
        fake_preds = D(fake_imgs)
        g_loss = criterion(fake_preds, real_labels)  # Pretend fakes are real

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    print(f"Epoch {epoch+1}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
6. Conclusion
- GANs are trained with BCE loss under a min-max framework.
- The discriminator uses standard BCE classification.
- The generator originally minimized $\log(1 - D(G(z)))$, but this saturates.
- The non-saturating loss $-\log D(G(z))$ is used in practice to ensure strong gradients.
- We do not use torch.no_grad() in the generator step because we need gradients to flow from D(G(z)) back to G. However, gradients with respect to D's parameters are still computed (and then discarded), which wastes some memory. A more efficient alternative is to freeze D's parameters with requires_grad_(False) during the generator update; a sketch is shown after this list.
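A minimal sketch of that alternative, replacing the generator step inside the training loop of Section 5 (variable names are reused from that code):

# Train Generator with D's parameters frozen
for p in D.parameters():
    p.requires_grad_(False)    # no gradients are stored for D's weights

fake_preds = D(fake_imgs)      # gradients still flow through D's activations back to G
g_loss = criterion(fake_preds, real_labels)

opt_G.zero_grad()
g_loss.backward()
opt_G.step()

for p in D.parameters():
    p.requires_grad_(True)     # unfreeze D for the next discriminator step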
This tutorial provides both the mathematical reasoning and practical code to understand and train GANs effectively.