從多個方面詳解WGAN-GP

一、WGAN-GP的簡介

WGAN-GP代表了深度學習中生成對抗網絡(GANs)的一次重要改進。WGAN-GP的全稱是Wasserstein GAN with Gradient Penalty(帶梯度懲罰的Wasserstein GAN),它於2017年由Ishaan Gulrajani等人首次提出。相對於傳統GAN的損失函數(Jensen-Shannon Divergence和KL散度),WGAN-GP採用了Wasserstein距離:WGAN-GP通過最大化生成模型和真實樣本之間的Wasserstein距離與Lipschitz約束而獲得更加穩定的訓練過程。

換句話說,WGAN-GP更容易將GAN的優化問題轉化為用深度網絡去估計兩個分佈之間的Wasserstein距離這一問題。WGAN-GP是一個星形結構,其中生成器網絡與判別器網絡相互作用。令人驚奇的是,WGAN-GP的生成器網絡可以生成更真實的樣本,而判別器網絡可以更好地辨別這些樣本。

二、WGAN-GP的原理解釋

Wasserstein GAN是一種使用Wasserstein距離作為並發學習中損失函數的GAN模型。Wasserstein距離在計算兩個概率分佈之間的距離時,可以比其他距離標準更加準確。

對於比較真實分佈p和生成分佈q,Wasserstein距離定義為:W(p, q)=inf (E[f(x)-f(y)]),其中f是Lipschitz 連續函數,||f||L<=1。事實上,Wasserstein距離比KL散度更適用於GAN模型,因為Wasserstein距離是可微分和連續的,並且在深度學習的訓練過程中可以更好地反映兩個分佈之間的差異。

同時,對於Wasserstein GAN,也需要考慮梯度截斷,確保生成器和判別器網絡的權重在一定的範圍內。為了實現Lipschitz連續性,Wasserstein GAN需要確保W的梯度是有限且權重也有限的,這種限制導致WGAN的梯度消失和模型崩潰問題得到了緩解。

三、WGAN-GP的優點

相比於傳統的GAN,WGAN-GP帶來了以下四個顯著的優點:

1. 避免模式崩潰:傳統GAN經常會出現「模式崩潰」問題,即生成器趨向於生成相同的樣本。WGAN-GP的梯度懲罰機制可有效避免這種情況出現。

2. 更穩定的訓練過程:WGAN-GP使用Wasserstein距離是可微分和連續的,因此其訓練過程更加穩定。

3. 更快的收斂速度:對於某些數據集,WGAN-GP收斂速度比傳統GAN更快。

4. 實現神經元級別的控制:WGAN-GP中的梯度懲罰機制可以提供更加準確的梯度信息,使得我們能夠更加精確地控制生成器的特徵輸出。

四、WGAN-GP的代碼實現

下面給出WGAN-GP的PyTorch實現示例:

# WGAN-GP代碼實現:

import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 128) # 輸入層-->中間層
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 28*28) # 中間層-->輸出層

    def forward(self, x):
        x = nn.LeakyReLU(0.2)(self.fc1(x))
        x = nn.LeakyReLU(0.2)(self.fc2(x))
        x = nn.Tanh()(self.fc3(x)) # Tanh函數壓縮至0~1
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(28*28, 256) # 輸入層-->中間層
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1) # 中間層-->輸出層

    def forward(self, x):
        x = nn.LeakyReLU(0.2)(self.fc1(x))
        x = nn.LeakyReLU(0.2)(self.fc2(x))
        x = nn.Sigmoid()(self.fc3(x)) # Sigmoid函數壓縮至0~1
        return x

def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(real_data.size()).cuda()

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    interpolates = Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10

    return gradient_penalty

discriminator = Discriminator().cuda()
generator = Generator().cuda()

batch_size = 64

real_data = torch.Tensor()
fake_data = torch.Tensor()

optimizer_D = RMSprop(discriminator.parameters(), lr=0.00005)
optimizer_G = RMSprop(generator.parameters(), lr=0.00005)

dataset_zh = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([
    transforms.ToTensor(), # Tensor化
    transforms.Normalize((0.1307,), (0.3081,)) # 正則化
]))

# 訓練過程
for epoch in range(100):
    for idx_batch, (real, _) in enumerate(torch.utils.data.DataLoader(dataset_zh,
                                batch_size=batch_size, shuffle=True, num_workers=4)):
        real_data.resize_(real.size()).copy_(real)

        fake = generator(torch.randn(batch_size, 100).cuda())
        fake_data.resize_(fake.size()).copy_(fake)

        critic_loss = nn.ReLU()(1 + discriminator(fake_data).mean() - discriminator(real_data).mean())
        critic_loss.backward(retain_graph=True)

        optimizer_D.step()

        # 判別器的權重限制
        for param in discriminator.parameters():
            param.data.clamp_(-0.01, 0.01)

        gradient_penalty = calc_gradient_penalty(discriminator, real_data, fake_data)

        optimizer_D.zero_grad()
        (0.1 * gradient_penalty + critic_loss).backward()
        optimizer_D.step()

        if idx_batch % 10 == 0:
            generator.zero_grad()
            g_loss = -discriminator(generator(torch.randn(batch_size, 100).cuda())).mean()
            generator.zero_grad()
            g_loss.backward()
            optimizer_G.step()

    print(epoch)

五、總結

WGAN-GP是GAN中的一種非常有用的改進型模型。相比於傳統GAN,它具有更加穩定的訓練過程、更快的收斂速度以及更加精準的生成特徵輸出控制能力。同時,WGAN-GP的代碼實現過程比較簡單,便於初學者在實踐中運用。

原創文章,作者:SAYTL,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/334725.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
SAYTL的頭像SAYTL
上一篇 2025-02-05 13:05
下一篇 2025-02-05 13:05

相關推薦

發表回復

登錄後才能評論