一、概述
Variational Autoencoder(VAE)是一種生成模型,廣泛應用於圖像與文本生成等領域。它可以將數據映射到一個潛在空間中,並通過解碼器從這個潛在空間重新生成出輸入數據。
相較於其他生成模型,VAE採用了貝葉斯推斷的方法,能夠更好地描述數據的不確定性。其目標是使數據在潛在空間中服從一個特定的分布,從而使得通過這個分布採樣的數據與真實數據盡量相似。
VAE包含兩個主要的部分:編碼器和解碼器。編碼器將輸入數據壓縮到潛在空間中,解碼器則從潛在空間中重建出數據。在這個過程中,中間的潛在空間起到了”過渡”的作用,即將輸入數據從原始空間映射到潛在空間,再從潛在空間映射回原始空間。通過對潛在空間的建模,我們可以生成與數據分布相似的新數據。
二、編碼器與解碼器
編碼器和解碼器是VAE的核心組成部分。編碼器將輸入數據x映射到潛在空間z中的一個概率分布,解碼器則從潛在空間z中採樣,並生成與原始數據x相似的新數據。
2.1 編碼器
編碼器的主要目的是將輸入數據x映射到潛在空間z中的一個概率分布,即求解p(z|x)。在VAE中,我們假設p(z|x)是一個高斯分布,其均值和方差可以用x計算得到:
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc21 = nn.Linear(hidden_size, latent_size)
self.fc22 = nn.Linear(hidden_size, latent_size)
def forward(self, x):
hidden = F.relu(self.fc1(x))
mu = self.fc21(hidden)
logvar = self.fc22(hidden)
return mu, logvar
在上面的代碼中,Encoder是編碼器的實現,其輸入為x,輸出為潛在空間的均值mu和對數方差logvar。兩個分布之間的KL散度可以用以下公式計算:
def kl_loss(mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
上述代碼中kl_loss函數用於計算KL散度。
2.2 解碼器
解碼器的主要目的是從潛在空間z中採樣出一組隨機向量z,並通過解碼器將其映射回到原始數據空間中,即求解p(x|z)。一般的,我們假設p(x|z)是一個高斯分布,其均值與方差可以用z計算得到:
class Decoder(nn.Module):
def __init__(self, latent_size, hidden_size, output_size):
super().__init__()
self.fc1 = nn.Linear(latent_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, z):
hidden = F.relu(self.fc1(z))
output = torch.sigmoid(self.fc2(hidden))
return output
上述代碼中Decoder是解碼器的實現,其輸入為潛在空間的向量z,輸出為生成的圖像數據output。
三、VAE的訓練
訓練VAE的目標是最小化重建損失和KL散度。重建損失用於度量生成數據與真實數據之間的差異,即使得通過解碼器生成的數據盡量接近真實數據。KL散度用於度量從真實數據分布到潛在空間分布的距離,使得生成的數據與真實數據在潛在空間分布上盡量接近。
3.1 重建損失
重建損失的計算是相對簡單的,即通過解碼器從潛在空間中採樣隨機向量,並計算生成數據與真實數據之間的歐幾里得距離:
def reconstruction_loss(x, x_origin):
return F.mse_loss(x_origin, x, reduction='sum')
上述代碼中reconstruction_loss函數用於計算重建損失。
3.2 KL散度
KL散度中的μ和logσ都是計算得到的,具體如下:
def loss_function(x, x_origin, mu, logvar):
BCE = reconstruction_loss(x, x_origin)
KLD = kl_loss(mu, logvar)
return BCE + KLD, BCE, KLD
上述代碼中的loss_function函數是整個VAE的損失函數,其輸入為真實數據x和生成數據x_origin,以及潛在空間中的均值mu和對數方差logvar。在訓練時,我們將重建損失和KL散度加權相加,得到整個VAE的損失函數。其中,BCE代表重建損失,KLD代表KL散度。
四、代碼示例
下面是一個完整的VAE模型的代碼實現:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc21 = nn.Linear(hidden_size, latent_size)
self.fc22 = nn.Linear(hidden_size, latent_size)
def forward(self, x):
hidden = F.relu(self.fc1(x))
mu = self.fc21(hidden)
logvar = self.fc22(hidden)
return mu, logvar
class Decoder(nn.Module):
def __init__(self, latent_size, hidden_size, output_size):
super().__init__()
self.fc1 = nn.Linear(latent_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, z):
hidden = F.relu(self.fc1(z))
output = torch.sigmoid(self.fc2(hidden))
return output
def kl_loss(mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
def reconstruction_loss(x, x_origin):
return F.mse_loss(x_origin, x, reduction='sum')
def loss_function(x, x_origin, mu, logvar):
BCE = reconstruction_loss(x, x_origin)
KLD = kl_loss(mu, logvar)
return BCE + KLD, BCE, KLD
class VAE(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super().__init__()
self.encoder = Encoder(input_size, hidden_size, latent_size)
self.decoder = Decoder(latent_size, hidden_size, input_size)
def encode(self, x):
return self.encoder(x)
def decode(self, z):
return self.decoder(z)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, logvar = self.encode(x.view(-1, input_size))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
# train
input_size = 784
hidden_size = 256
latent_size = 10
epochs = 10
batch_size = 64
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
model = VAE(input_size, hidden_size, latent_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(1, epochs+1):
model.train()
train_loss = 0
train_BCE_loss = 0
train_KLD_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss, BCE_loss, KLD_loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
train_BCE_loss += BCE_loss.item()
train_KLD_loss += KLD_loss.item()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}, BCE Loss: {:.4f}, KL Divergence: {:.4f}'.format(
epoch, epochs, train_loss / len(train_loader.dataset),
train_BCE_loss / len(train_loader.dataset),
train_KLD_loss / len(train_loader.dataset)))
原創文章,作者:LZYLY,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/368911.html