1. What Is a VAE Model
VAE stands for Variational Autoencoder, a type of generative model. A VAE maps input data into a latent space, compressing and reconstructing samples; by introducing latent variables it controls the distribution of the generated data, which allows it to produce new data samples.
The key feature of the VAE is that it is trained by optimizing a variational lower bound, which makes training more stable. The learned latent space can then be used for tasks such as interpolation and generating multiple samples.
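Concretely, the variational lower bound (ELBO) that a VAE maximizes can be written, in its standard form, as

ELBO = E_{q(z|x)}[ log p(x|z) ] - KL( q(z|x) || p(z) )

where q(z|x) is the encoder's approximate posterior and p(z) is a standard normal prior. The training loss is the negative ELBO, i.e. a reconstruction term plus a KL-divergence term; this corresponds to the BCE + KLD loss used in the code examples below.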
Below is example code that implements a VAE model in PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        # encoder
        self.enc_fc1 = nn.Linear(input_size, hidden_size)
        self.enc_fc2_mean = nn.Linear(hidden_size, latent_size)
        self.enc_fc2_logvar = nn.Linear(hidden_size, latent_size)
        # decoder
        self.dec_fc1 = nn.Linear(latent_size, hidden_size)
        self.dec_fc2 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        # map the input to the mean and log-variance of q(z|x)
        h = F.relu(self.enc_fc1(x))
        mean = self.enc_fc2_mean(h)
        logvar = self.enc_fc2_logvar(h)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        # reparameterization trick: z = mean + eps * std, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        h = F.relu(self.dec_fc1(z))
        x_hat = torch.sigmoid(self.dec_fc2(h))  # outputs in [0, 1]
        return x_hat

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar
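Before moving on, here is a minimal sanity check of the class above (the batch size of 16 and the layer sizes are arbitrary values chosen only for illustration):

# quick check of the VAE class with random inputs (illustrative sizes)
model = VAE(input_size=784, hidden_size=512, latent_size=20)
x = torch.rand(16, 784)  # a fake batch of 16 flattened 28x28 images with values in [0, 1]
x_hat, mean, logvar = model(x)
print(x_hat.shape, mean.shape, logvar.shape)
# torch.Size([16, 784]) torch.Size([16, 20]) torch.Size([16, 20])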
2. Applications of the VAE Model in Image Generation
In image generation, a VAE is used to produce new samples from the latent space. Typically, we use the VAE to encode an input image into a low-dimensional vector, sample randomly in the latent space, and then decode the sampled vector into a new image.
Below is example code that uses the VAE model to generate images on the MNIST dataset:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# load data (no normalization to [-1, 1]: the BCE loss below expects targets in [0, 1])
transform = transforms.ToTensor()
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
# train model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VAE(input_size=784, hidden_size=512, latent_size=20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def loss_function(x_hat, x, mean, logvar):
    # negative ELBO: reconstruction term (BCE) plus KL divergence to the N(0, I) prior
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KLD

num_epochs = 20
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1).to(device)  # flatten 28x28 images to 784
        optimizer.zero_grad()
        x_hat, mean, logvar = model(inputs)
        loss = loss_function(x_hat, inputs, mean, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(trainset)))
# generate new images by sampling z ~ N(0, I) and decoding
with torch.no_grad():
    z = torch.randn(10, 20).to(device)
    samples = model.decode(z).cpu()

fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for i in range(10):
    axs[i].imshow(samples[i].view(28, 28), cmap='gray')
    axs[i].axis('off')
plt.show()
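As mentioned earlier, the learned latent space also supports interpolation between samples. The following is a minimal sketch that reuses the trained MNIST model from above; the two endpoint images and the number of interpolation steps are arbitrary choices for illustration:

# interpolate between the latent codes of two training images
with torch.no_grad():
    x0 = trainset[0][0].view(1, -1).to(device)  # first training image, flattened
    x1 = trainset[1][0].view(1, -1).to(device)  # second training image, flattened
    mean0, _ = model.encode(x0)
    mean1, _ = model.encode(x1)
    steps = 8
    fig, axs = plt.subplots(1, steps, figsize=(16, 2))
    for i in range(steps):
        alpha = i / (steps - 1)
        z = (1 - alpha) * mean0 + alpha * mean1  # linear interpolation in latent space
        img = model.decode(z).cpu().view(28, 28)
        axs[i].imshow(img, cmap='gray')
        axs[i].axis('off')
    plt.show()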
3. Applications of the VAE Model in Image Inpainting
In image inpainting, the latent space learned by the VAE is used to repair an image: the image to be repaired is encoded into a latent vector, the missing portion is interpolated, and the result is decoded into a new image.
Below is example code that uses the VAE model for image inpainting on the CelebA dataset:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

# load data (again without normalization to [-1, 1], since the BCE loss expects targets in [0, 1])
transform = transforms.Compose([
    transforms.CenterCrop((128, 128)),
    transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(root='./celeba_train', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=24, shuffle=True, num_workers=2)
# train model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VAE(input_size=3*128*128, hidden_size=1024, latent_size=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def loss_function(x_hat, x, mean, logvar):
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KLD

num_epochs = 20
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1).to(device)
        optimizer.zero_grad()
        x_hat, mean, logvar = model(inputs)
        loss = loss_function(x_hat, inputs, mean, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(dataset)))
# image inpainting
test_img_path = './test.jpg'
img = Image.open(test_img_path)
img = transform(img).unsqueeze(0).to(device)              # shape: (1, 3, 128, 128)
img_label = torch.zeros_like(img)
img_label[:, :, 50:78, 60:88] = img[:, :, 50:78, 60:88]   # keep only the unmasked window

with torch.no_grad():
    mean, logvar = model.encode(img_label.view(1, -1))     # encode() returns (mean, logvar)
    z = mean                                               # use the posterior mean as the latent code
    z[:, 256:] = 0  # set the second half of z to 0
    fixed_img = model.decode(z).view(3, 128, 128)          # reshape the flat output back to an image

img = img.cpu().squeeze().numpy().transpose(1, 2, 0)
img_label = img_label.cpu().squeeze().numpy().transpose(1, 2, 0)
fixed_img = fixed_img.cpu().numpy().transpose(1, 2, 0)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(img)
axs[0].set_title('original image')
axs[0].axis('off')
axs[1].imshow(img_label)
axs[1].set_title('image with mask')
axs[1].axis('off')
axs[2].imshow(fixed_img)
axs[2].set_title('fixed image')
axs[2].axis('off')
plt.show()