VAEGAN: A Comprehensive Guide

1. VAEGAN Code

VAEGAN (Variational Autoencoder with Generative Adversarial Networks) is a generative model built on top of a VAE and a GAN. Its defining feature is that it combines the VAE's encoder and decoder with a GAN discriminator, and trains with both the adversarial loss between generator and discriminator and the VAE's reconstruction loss.

import torch.nn.functional as F
import torch.nn as nn
import torch

class VAEGAN(nn.Module):
    def __init__(self, latent_dim: int):
        super(VAEGAN, self).__init__()
        self.latent_dim = latent_dim
        # encoder: two conv layers followed by fully connected layers
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc21 = nn.Linear(128, self.latent_dim)  # mean of q(z|x)
        self.fc22 = nn.Linear(128, self.latent_dim)  # log-variance of q(z|x)
        # decoder: fully connected layers followed by two transposed convs
        self.fc3 = nn.Linear(self.latent_dim, 128)
        self.fc4 = nn.Linear(128, 64 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(64, 32, 3, 2, 1, 1)
        self.deconv2 = nn.ConvTranspose2d(32, 1, 3, 2, 1, 1)
        # discriminator head: a small MLP over latent codes
        self.ld1 = nn.Linear(self.latent_dim, 32)
        self.ld2 = nn.Linear(32, 1)
        # alternative convolutional generator layers (unused in this example)
        self.gconv1 = nn.ConvTranspose2d(self.latent_dim, 32, 3, 2, 1, 1)
        self.gconv2 = nn.ConvTranspose2d(32, 1, 3, 2, 1, 1)

    def encode(self, x):
        # 1x28x28 -> 32x14x14 -> 64x7x7, then flatten
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64 * 7 * 7)
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, log_var):
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(log_var / 2)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def decode(self, z):
        # latent -> 64x7x7 -> 32x14x14 -> 1x28x28, pixel values in [0, 1]
        x = F.relu(self.fc3(z))
        x = F.relu(self.fc4(x))
        x = x.view(-1, 64, 7, 7)
        x = F.relu(self.deconv1(x))
        x = torch.sigmoid(self.deconv2(x))
        return x

    def forward(self, x):
        # full VAE pass: encode, sample a latent code, decode
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decode(z)
        return x_hat, mu, log_var

    def discriminate(self, x):
        # real/fake score in (0, 1): encode the input and classify its latent
        # code with the small ld1/ld2 head (one simple choice of discriminator)
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        h = F.relu(self.ld1(z))
        return torch.sigmoid(self.ld2(h))

    def d_loss(self, real, fake):
        # discriminator loss: push scores on real data toward 1
        # and scores on generated data toward 0
        criterion = nn.BCELoss()
        valid_labels = torch.ones(real.size(0), 1, device=real.device)
        fake_labels = torch.zeros(fake.size(0), 1, device=fake.device)
        real_loss = criterion(real, valid_labels)
        fake_loss = criterion(fake, fake_labels)
        return real_loss + fake_loss

    def g_loss_vae(self, x_recon, x, mu, log_var):
        # VAE objective: pixel-wise BCE reconstruction + KL divergence
        # between q(z|x) = N(mu, sigma^2) and the prior N(0, I)
        x_recon = x_recon.view(x_recon.size(0), -1)
        x = x.view(x.size(0), -1)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kld_loss

    def g_loss_gan(self, fake):
        # adversarial generator loss: push discriminator scores
        # on generated data toward 1
        valid_labels = torch.ones(fake.size(0), 1, device=fake.device)
        loss = nn.BCELoss()
        return loss(fake, valid_labels)

    def g_loss_total(self, g_loss_vae, g_loss_gan):
        # total generator objective
        return g_loss_vae + g_loss_gan
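
As a quick sanity check, the sketch below (not from the original article) runs one forward pass through the model, assuming MNIST-sized 1x28x28 inputs, which the layer shapes above imply:

# a minimal shape check on random data
model = VAEGAN(latent_dim=20)
x = torch.randn(8, 1, 28, 28)                # a fake batch of 8 "images"
x_hat, mu, log_var = model(x)
print(x_hat.shape, mu.shape, log_var.shape)  # (8, 1, 28, 28), (8, 20), (8, 20)
print(model.discriminate(x).shape)           # (8, 1) real/fake scores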
        

2. Generating Time Series with VAEGAN

VAEGAN is typically used to generate images, but it can also generate time-series data. Trained unsupervised on a set of sequences, the model learns a good approximation of the data distribution, and new series can then be sampled from it.

Below is example code for generating time series with a VAEGAN:

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

class VAEGAN(nn.Module):
    # skeleton of the time-series VAEGAN; the architecture details are
    # omitted here, since trained weights are loaded from disk below
    def __init__(self, latent_dim, num_hidden, num_samples):
        super(VAEGAN, self).__init__()
        self.latent_dim = latent_dim
        # ... encoder, decoder and discriminator for sequences go here ...

    def decode(self, z):
        # maps latent codes to time series
        ...

    def forward(self, x, y=None):
        ...
    
class VAEGANTimeSeriesGenerator:
    def __init__(self, model_path, latent=20, num_hidden=100, num_samples=50):
        # load a trained model and switch it to inference mode
        self.model = VAEGAN(latent, num_hidden, num_samples)
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        self.model.eval()

    def generate(self, num_samples):
        # sample latent codes from the prior and decode them into series
        with torch.no_grad():
            z = torch.randn(num_samples, self.model.latent_dim)
            fake_ts = self.model.decode(z).numpy()
        return fake_ts

if __name__ == '__main__':
    model_path = 'vaegan_ts.pt'
    vaegan_generator = VAEGANTimeSeriesGenerator(model_path)

    fake_ts = vaegan_generator.generate(num_samples=10)
    print(fake_ts)

3. VAEGAN Losses

The VAEGAN loss combines the VAE's reconstruction loss and KL divergence with the GAN's adversarial loss. The reconstruction loss measures how similar the generator's output is to the original image; the KL divergence measures how far the encoder's latent distribution deviates from the prior. The adversarial loss measures how well the discriminator tells real data apart from fake data (the generator's output), and training against it pushes the generator to produce more realistic samples.
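
Written out (a standard formulation matching the code below; $D$ is the discriminator, $G$ the decoder/generator, and $(\mu, \sigma^2)$ the encoder outputs):

$$
\begin{aligned}
\mathcal{L}_D &= -\,\mathbb{E}_{x}\!\left[\log D(x)\right] - \mathbb{E}_{z}\!\left[\log\left(1 - D(G(z))\right)\right],\\
\mathcal{L}_{\mathrm{VAE}} &= \mathrm{BCE}(\hat{x}, x) \;-\; \tfrac{1}{2}\textstyle\sum_i\left(1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2\right),\\
\mathcal{L}_G &= \mathcal{L}_{\mathrm{VAE}} - \mathbb{E}_{z}\!\left[\log D(G(z))\right].
\end{aligned}
$$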

    def d_loss(self, real, fake):
        # discriminator loss: push scores on real data toward 1
        # and scores on generated data toward 0
        criterion = nn.BCELoss()
        valid_labels = torch.ones(real.size(0), 1, device=real.device)
        fake_labels = torch.zeros(fake.size(0), 1, device=fake.device)
        real_loss = criterion(real, valid_labels)
        fake_loss = criterion(fake, fake_labels)
        return real_loss + fake_loss

    def g_loss_vae(self, x_recon, x, mu, log_var):
        # VAE objective: pixel-wise BCE reconstruction + KL divergence
        x_recon = x_recon.view(x_recon.size(0), -1)
        x = x.view(x.size(0), -1)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kld_loss

    def g_loss_gan(self, fake):
        # adversarial generator loss: push discriminator scores
        # on generated data toward 1
        valid_labels = torch.ones(fake.size(0), 1, device=fake.device)
        loss = nn.BCELoss()
        return loss(fake, valid_labels)

    def g_loss_total(self, g_loss_vae, g_loss_gan):
        # total generator objective
        return g_loss_vae + g_loss_gan

4. VAEGAN KL Divergence

The KL divergence in VAEGAN comes from the VAE part of the model: the encoder outputs the mean and log-variance of a Gaussian approximate posterior over the latent code, and the KL term measures how far that distribution deviates from the standard normal prior. Because both distributions are Gaussian, the KL divergence can be computed in closed form from the mean and variance.
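
For a diagonal Gaussian posterior against the standard normal prior, the closed form used in the code below is:

$$
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \sigma^2 I)\,\big\|\,\mathcal{N}(0, I)\right) = -\frac{1}{2}\sum_{i=1}^{d}\left(1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2\right),
$$

which is exactly the `kld_loss` line in `g_loss_vae`, with `log_var` standing for $\log\sigma^2$.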

    def reparameterize(self, mu, log_var):
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(log_var / 2)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def encode(self, x):
        # 1x28x28 -> 32x14x14 -> 64x7x7, then flatten
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64 * 7 * 7)
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def g_loss_vae(self, x_recon, x, mu, log_var):
        # VAE objective: pixel-wise BCE reconstruction + KL divergence
        x_recon = x_recon.view(x_recon.size(0), -1)
        x = x.view(x.size(0), -1)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kld_loss

5. VAEGAN Implementation Code

The implementation below uses the MNIST dataset; a VAEGAN can be trained on other datasets and used to generate data in the same way. Only a simple implementation is given here.

import torch
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from vaegan import *  # the VAEGAN class from section 1 (assumed saved as vaegan.py)

train_dataset = dset.MNIST(root='./data/dset', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

vaegan = VAEGAN(latent_dim=20).cuda()
# a single Adam over all parameters, for brevity; in practice the discriminator
# and the encoder/decoder are usually given separate optimizers
optimizer = optim.Adam(vaegan.parameters(), lr=0.0005)
epochs = 10

for epoch in range(epochs):
    for i, (real_inputs, real_labels) in enumerate(train_loader):
        real_inputs = real_inputs.cuda()
        optimizer.zero_grad()
        # ---- discriminator step ----
        noise = torch.randn(real_inputs.size(0), vaegan.latent_dim).cuda()
        fake_inputs = vaegan.decode(noise)
        real_output = vaegan.discriminate(real_inputs)
        # detach so this step does not update the generator through the fakes
        fake_output = vaegan.discriminate(fake_inputs.detach())
        d_loss = vaegan.d_loss(real_output, fake_output)
        d_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # ---- generator (encoder/decoder) step ----
        noise = torch.randn(real_inputs.size(0), vaegan.latent_dim).cuda()
        fake_inputs = vaegan.decode(noise)
        fake_output = vaegan.discriminate(fake_inputs)
        x_recon, mu, log_var = vaegan(real_inputs)
        g_loss_vae = vaegan.g_loss_vae(x_recon, real_inputs, mu, log_var)
        g_loss_gan = vaegan.g_loss_gan(fake_output)
        g_loss_total = vaegan.g_loss_total(g_loss_vae, g_loss_gan)
        g_loss_total.backward()
        optimizer.step()
        if i % 50 == 0 and epoch % 2 == 0:
            print("epoch: [%d/%d], batch: [%d/%d], d_loss: %.4f, g_loss_vae: %.4f, g_loss_gan: %.4f" % (
                epoch + 1, epochs, i + 1, len(train_loader), d_loss.item(), g_loss_vae.item(), g_loss_gan.item()))

    torch.save(vaegan.state_dict(), 'vaegan.pt')
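
Once training finishes, new digits can be generated by decoding random latent vectors. A minimal sampling sketch, assuming the `vaegan.pt` checkpoint written by the loop above:

import torch
from vaegan import VAEGAN  # the class from section 1 (assumed saved as vaegan.py)

model = VAEGAN(latent_dim=20)
model.load_state_dict(torch.load('vaegan.pt', map_location='cpu'))
model.eval()
with torch.no_grad():
    z = torch.randn(16, model.latent_dim)  # 16 random latent codes
    samples = model.decode(z)              # -> (16, 1, 28, 28), pixel values in [0, 1]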

6. VAEGAN Python Code

VAEGAN can be implemented in Python; the following is example code for a VAEGAN implemented with Keras:

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Lambda, Concatenate, Dense, Reshape, Flatten, BatchNormalization, Activation, Conv2D, Conv2DTranspose
from keras.losses import binary_crossentropy, mean_squared_error
from keras.datasets import mnist
from keras.optimizers import Adam

import numpy as np

def clip_images(X_train, clip_size):
    """
    Clip pixel values to [-clip_size, clip_size] and rescale them to [0, 1]
    """
    X_train = np.clip(X_train, -clip_size, clip_size)
    X_train = (X_train + clip_size) / (clip_size * 2)
    return X_train.reshape(-1, 28, 28, 1)

def sampler(args):
    # reparameterization trick: z = mean + sigma * eps
    z_mean, z_log_var = args
    eps = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var / 2) * eps

latent_dim = 20

inputs = Input(shape=(28, 28, 1))

h = Conv2D(32, kernel_size=(3, 3), strides=(2, 2), padding="same")(inputs)
h = BatchNormalization()(h)
h = Activation("relu")(h)

h = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)
h = BatchNormalization()(h)
h = Activation("relu")(h)

h = Flatten()(h)
h = Dense(128)(h)
h = BatchNormalization()(h)
h = Activation("relu")(h)

z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z = Lambda(sampler)([z_mean, z_log_var])

h = Dense(64 * 7 * 7)(z)
h = BatchNormalization()(h)
h = Activation("relu")(h)
h = Reshape((7, 7, 64))(h)

h = Conv2DTranspose(64, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)
h = BatchNormalization()(h)
h = Activation("relu")(h)

h = Conv2DTranspose(32, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)
h = BatchNormalization()(h)
h = Activation("relu")(h)

# assumed completion (not in the original, which breaks off above): map back
# to a single channel and build the model, mirroring the PyTorch decoder in section 1
outputs = Conv2D(1, kernel_size=(3, 3), padding="same", activation="sigmoid")(h)
vae = Model(inputs, outputs)
