深度學習中的AI繪畫技術——探究VAE模型

一、什麼是VAE模型

VAE全稱為Variational Autoencoder,是一種生成模型。VAE通過將輸入數據映射到潛在空間中,實現對樣本的壓縮和重構,並且通過引入潛在變量來控制生成數據的分佈,從而可以生成新的數據樣本。

VAE模型的主要特點是使用了變分下界來優化模型,從而讓模型在訓練過程中更加穩定,同時可以利用VAE學到的潛在空間進行插值、生成多個樣本等任務。

下面是使用PyTorch實現VAE模型的示例代碼:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class VAE(nn.Module):
        def __init__(self, input_size, hidden_size, latent_size):
            super(VAE, self).__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.latent_size = latent_size
            
            # encoder
            self.enc_fc1 = nn.Linear(input_size, hidden_size)
            self.enc_fc2_mean = nn.Linear(hidden_size, latent_size)
            self.enc_fc2_logvar = nn.Linear(hidden_size, latent_size)
            
            # decoder
            self.dec_fc1 = nn.Linear(latent_size, hidden_size)
            self.dec_fc2 = nn.Linear(hidden_size, input_size)
            
        def encode(self, x):
            h = F.relu(self.enc_fc1(x))
            mean = self.enc_fc2_mean(h)
            logvar = self.enc_fc2_logvar(h)
            return mean, logvar
        
        def reparameterize(self, mean, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + eps * std
        
        def decode(self, z):
            h = F.relu(self.dec_fc1(z))
            x_hat = torch.sigmoid(self.dec_fc2(h))
            return x_hat
        
        def forward(self, x):
            mean, logvar = self.encode(x)
            z = self.reparameterize(mean, logvar)
            x_hat = self.decode(z)
            return x_hat, mean, logvar

二、VAE模型在圖像生成中的應用

VAE模型在圖像生成中的應用是在潛在空間中生成新的樣本。通常情況下,我們可以使用VAE將輸入圖片編碼成一個低維的向量,然後在潛在空間中隨機採樣,最後將採樣到的向量解碼成新的圖片。

下面是使用VAE模型在MNIST數據集上進行圖片生成的示例代碼:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    
    # load data
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])
    
    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=2)
    
    
    # train model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = VAE(input_size=784, hidden_size=512, latent_size=20).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    def loss_function(x_hat, x, mean, logvar):
        BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return BCE + KLD
    
    num_epochs = 20
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, _ = data
            inputs = inputs.view(inputs.size(0), -1).to(device)
            optimizer.zero_grad()
            x_hat, mean, logvar = model(inputs)
            loss = loss_function(x_hat, inputs, mean, logvar)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    
        print('[%d] loss: %.3f' %
              (epoch + 1, running_loss / len(trainset)))
        
    # generate new images
    with torch.no_grad():
        z = torch.randn(10, 20).to(device)
        samples = model.decode(z).cpu()
    
    fig, axs = plt.subplots(1, 10, figsize=(20, 2))
    for i in range(10):
        axs[i].imshow(samples[i].view(28, 28), cmap='gray')
        axs[i].axis('off')
    
    plt.show()

三、VAE模型在圖像修復中的應用

VAE模型在圖像修復中的應用是利用VAE學習到的潛在空間對圖片進行修復。可以將待修復圖片編碼成潛在空間中的向量,對缺失的部分進行插值,然後解碼成新的圖片。

下面是使用VAE模型在CelebA數據集上進行圖像修復的示例代碼:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    from PIL import Image
    
    # load data
    transform = transforms.Compose([
            transforms.CenterCrop((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    
    dataset = torchvision.datasets.ImageFolder(root='./celeba_train', transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=24, shuffle=True, num_workers=2)
    
    # train model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = VAE(input_size=3*128*128, hidden_size=1024, latent_size=512).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    def loss_function(x_hat, x, mean, logvar):
        BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return BCE + KLD
    
    num_epochs = 20
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, _ = data
            inputs = inputs.view(inputs.size(0), -1).to(device)
            optimizer.zero_grad()
            x_hat, mean, logvar = model(inputs)
            loss = loss_function(x_hat, inputs, mean, logvar)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    
        print('[%d] loss: %.3f' %
              (epoch + 1, running_loss / len(dataset)))
    
    # image inpainting
    test_img_path = './test.jpg'
    
    img = Image.open(test_img_path)
    img = transform(img).unsqueeze(0).to(device)
    
    img_label = torch.zeros_like(img)
    img_label[:,:,50:78,60:88] = img[:,:,50:78,60:88]
    
    with torch.no_grad():
        z, _, _ = model.encode(img_label.view(1, -1))
        z[:, 256:] = 0 # set the second half of z to 0
        
        fixed_img = model.decode(z)
    
    img = img.cpu().squeeze().numpy().transpose(1,2,0)
    img_label = img_label.cpu().squeeze().numpy().transpose(1,2,0)
    fixed_img = fixed_img.cpu().squeeze().numpy().transpose(1,2,0)
    
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    
    axs[0].imshow(img)
    axs[0].set_title('original image')
    axs[0].axis('off')
    
    axs[1].imshow(img_label)
    axs[1].set_title('image with mask')
    axs[1].axis('off')
    
    axs[2].imshow(fixed_img)
    axs[2].set_title('fixed image')
    axs[2].axis('off')
    
    plt.show()

原創文章,作者:AOJUK,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/331574.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
AOJUK的頭像AOJUK
上一篇 2025-01-20 14:10
下一篇 2025-01-20 14:10

相關推薦

  • TensorFlow Serving Java:實現開發全功能的模型服務

    TensorFlow Serving Java是作為TensorFlow Serving的Java API,可以輕鬆地將基於TensorFlow模型的服務集成到Java應用程序中。…

    編程 2025-04-29
  • Python訓練模型後如何投入應用

    Python已成為機器學習和深度學習領域中熱門的編程語言之一,在訓練完模型後如何將其投入應用中,是一個重要問題。本文將從多個方面為大家詳細闡述。 一、模型持久化 在應用中使用訓練好…

    編程 2025-04-29
  • Python熱重載技術

    Python熱重載技術是現代編程的關鍵功能之一。它可以幫助我們在程序運行的過程中,更新代碼而無需重新啟動程序。本文將會全方位地介紹Python熱重載的實現方法和應用場景。 一、實現…

    編程 2025-04-29
  • Python實現一元線性回歸模型

    本文將從多個方面詳細闡述Python實現一元線性回歸模型的代碼。如果你對線性回歸模型有一些了解,對Python語言也有所掌握,那麼本文將對你有所幫助。在開始介紹具體代碼前,讓我們先…

    編程 2025-04-29
  • ARIMA模型Python應用用法介紹

    ARIMA(自回歸移動平均模型)是一種時序分析常用的模型,廣泛應用於股票、經濟等領域。本文將從多個方面詳細闡述ARIMA模型的Python實現方式。 一、ARIMA模型是什麼? A…

    編程 2025-04-29
  • 深度查詢宴會的文化起源

    深度查詢宴會,是指通過對一種文化或主題的深度挖掘和探究,為參與者提供一次全方位的、深度體驗式的文化品嘗和交流活動。本文將從多個方面探討深度查詢宴會的文化起源。 一、宴會文化的起源 …

    編程 2025-04-29
  • Python包絡平滑技術解析

    本文將從以下幾個方面對Python包絡平滑技術進行詳細的闡述,包括: 什麼是包絡平滑技術? Python中使用包絡平滑技術的方法有哪些? 包絡平滑技術在具體應用中的實際效果 一、包…

    編程 2025-04-29
  • parent.$.dialog是什麼技術的語法

    parent.$.dialog是一種基於jQuery插件的彈出式對話框技術,它提供了一個方便快捷的方式來創建各種類型和樣式的彈出式對話框。它是對於在網站開發中常見的彈窗、提示框等交…

    編程 2025-04-28
  • 微信小程序重構H5技術方案設計 Github

    本文旨在探討如何在微信小程序中重構H5技術方案,以及如何結合Github進行代碼存儲和版本管理。我們將從以下幾個方面進行討論: 一、小程序與H5技術對比 微信小程序與H5技術都可以…

    編程 2025-04-28
  • VAR模型是用來幹嘛

    VAR(向量自回歸)模型是一種經濟學中的統計模型,用於分析並預測多個變量之間的關係。 一、多變量時間序列分析 VAR模型可以對多個變量的時間序列數據進行分析和建模,通過對變量之間的…

    編程 2025-04-28

發表回復

登錄後才能評論