VAE模型詳解

一、概述

Variational Autoencoder(VAE)是一種生成模型,廣泛應用於圖像與文本生成等領域。它可以將數據映射到一個潛在空間中,並通過解碼器從這個潛在空間重新生成出輸入數據。

相較於其他生成模型,VAE採用了貝葉斯推斷的方法,能夠更好地描述數據的不確定性。其目標是使數據在潛在空間中服從一個特定的分布,從而使得通過這個分布採樣的數據與真實數據盡量相似。

VAE包含兩個主要的部分:編碼器和解碼器。編碼器將輸入數據壓縮到潛在空間中,解碼器則從潛在空間中重建出數據。在這個過程中,中間的潛在空間起到了”過渡”的作用,即將輸入數據從原始空間映射到潛在空間,再從潛在空間映射回原始空間。通過對潛在空間的建模,我們可以生成與數據分布相似的新數據。

二、編碼器與解碼器

編碼器和解碼器是VAE的核心組成部分。編碼器將輸入數據x映射到潛在空間z中的一個概率分布,解碼器則從潛在空間z中採樣,並生成與原始數據x相似的新數據。

2.1 編碼器

編碼器的主要目的是將輸入數據x映射到潛在空間z中的一個概率分布,即求解p(z|x)。在VAE中,我們假設p(z|x)是一個高斯分布,其均值和方差可以用x計算得到:


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)
        
    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

在上面的代碼中,Encoder是編碼器的實現,其輸入為x,輸出為潛在空間的均值mu和對數方差logvar。兩個分布之間的KL散度可以用以下公式計算:


def kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

上述代碼中kl_loss函數用於計算KL散度。

2.2 解碼器

解碼器的主要目的是從潛在空間z中採樣出一組隨機向量z,並通過解碼器將其映射回到原始數據空間中,即求解p(x|z)。一般的,我們假設p(x|z)是一個高斯分布,其均值與方差可以用z計算得到:


class Decoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, z):
        hidden = F.relu(self.fc1(z))
        output = torch.sigmoid(self.fc2(hidden))
        return output

上述代碼中Decoder是解碼器的實現,其輸入為潛在空間的向量z,輸出為生成的圖像數據output。

三、VAE的訓練

訓練VAE的目標是最小化重建損失和KL散度。重建損失用於度量生成數據與真實數據之間的差異,即使得通過解碼器生成的數據盡量接近真實數據。KL散度用於度量從真實數據分布到潛在空間分布的距離,使得生成的數據與真實數據在潛在空間分布上盡量接近。

3.1 重建損失

重建損失的計算是相對簡單的,即通過解碼器從潛在空間中採樣隨機向量,並計算生成數據與真實數據之間的歐幾里得距離:


def reconstruction_loss(x, x_origin):
    return F.mse_loss(x_origin, x, reduction='sum')

上述代碼中reconstruction_loss函數用於計算重建損失。

3.2 KL散度

KL散度中的μ和logσ都是計算得到的,具體如下:


def loss_function(x, x_origin, mu, logvar):
    BCE = reconstruction_loss(x, x_origin)
    KLD = kl_loss(mu, logvar)
    return BCE + KLD, BCE, KLD

上述代碼中的loss_function函數是整個VAE的損失函數,其輸入為真實數據x和生成數據x_origin,以及潛在空間中的均值mu和對數方差logvar。在訓練時,我們將重建損失和KL散度加權相加,得到整個VAE的損失函數。其中,BCE代表重建損失,KLD代表KL散度。

四、代碼示例

下面是一個完整的VAE模型的代碼實現:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)
        
    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, z):
        hidden = F.relu(self.fc1(z))
        output = torch.sigmoid(self.fc2(hidden))
        return output

def kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

def reconstruction_loss(x, x_origin):
    return F.mse_loss(x_origin, x, reduction='sum')

def loss_function(x, x_origin, mu, logvar):
    BCE = reconstruction_loss(x, x_origin)
    KLD = kl_loss(mu, logvar)
    return BCE + KLD, BCE, KLD

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.encoder = Encoder(input_size, hidden_size, latent_size)
        self.decoder = Decoder(latent_size, hidden_size, input_size)
        
    def encode(self, x):
        return self.encoder(x)
    
    def decode(self, z):
        return self.decoder(z)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, input_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# train
input_size = 784
hidden_size = 256
latent_size = 10
epochs = 10
batch_size = 64
lr = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

model = VAE(input_size, hidden_size, latent_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs+1):
    model.train()
    train_loss = 0
    train_BCE_loss = 0
    train_KLD_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss, BCE_loss, KLD_loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        train_BCE_loss += BCE_loss.item()
        train_KLD_loss += KLD_loss.item()
        optimizer.step()
        
    print('Epoch [{}/{}], Loss: {:.4f}, BCE Loss: {:.4f}, KL Divergence: {:.4f}'.format(
                epoch, epochs, train_loss / len(train_loader.dataset), 
                train_BCE_loss / len(train_loader.dataset),
                train_KLD_loss / len(train_loader.dataset)))

原創文章,作者:LZYLY,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/368911.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
LZYLY的頭像LZYLY
上一篇 2025-04-12 13:00
下一篇 2025-04-12 13:00

相關推薦

  • TensorFlow Serving Java:實現開發全功能的模型服務

    TensorFlow Serving Java是作為TensorFlow Serving的Java API,可以輕鬆地將基於TensorFlow模型的服務集成到Java應用程序中。…

    編程 2025-04-29
  • Python訓練模型後如何投入應用

    Python已成為機器學習和深度學習領域中熱門的編程語言之一,在訓練完模型後如何將其投入應用中,是一個重要問題。本文將從多個方面為大家詳細闡述。 一、模型持久化 在應用中使用訓練好…

    編程 2025-04-29
  • ARIMA模型Python應用用法介紹

    ARIMA(自回歸移動平均模型)是一種時序分析常用的模型,廣泛應用於股票、經濟等領域。本文將從多個方面詳細闡述ARIMA模型的Python實現方式。 一、ARIMA模型是什麼? A…

    編程 2025-04-29
  • Python實現一元線性回歸模型

    本文將從多個方面詳細闡述Python實現一元線性回歸模型的代碼。如果你對線性回歸模型有一些了解,對Python語言也有所掌握,那麼本文將對你有所幫助。在開始介紹具體代碼前,讓我們先…

    編程 2025-04-29
  • VAR模型是用來幹嘛

    VAR(向量自回歸)模型是一種經濟學中的統計模型,用於分析並預測多個變數之間的關係。 一、多變數時間序列分析 VAR模型可以對多個變數的時間序列數據進行分析和建模,通過對變數之間的…

    編程 2025-04-28
  • 如何使用Weka下載模型?

    本文主要介紹如何使用Weka工具下載保存本地機器學習模型。 一、在Weka Explorer中下載模型 在Weka Explorer中選擇需要的分類器(Classifier),使用…

    編程 2025-04-28
  • Python實現BP神經網路預測模型

    BP神經網路在許多領域都有著廣泛的應用,如數據挖掘、預測分析等等。而Python的科學計算庫和機器學習庫也提供了很多的方法來實現BP神經網路的構建和使用,本篇文章將詳細介紹在Pyt…

    編程 2025-04-28
  • Python AUC:模型性能評估的重要指標

    Python AUC是一種用於評估建立機器學習模型性能的重要指標。通過計算ROC曲線下的面積,AUC可以很好地衡量模型對正負樣本的區分能力,從而指導模型的調參和選擇。 一、AUC的…

    編程 2025-04-28
  • 量化交易模型的設計與實現

    本文將從多個方面對量化交易模型進行詳細闡述,並給出對應的代碼示例。 一、量化交易模型的概念 量化交易模型是一種通過數學和統計學方法對市場進行分析和預測的手段,可以幫助交易者進行決策…

    編程 2025-04-27
  • Python決定係數0.8模型可行嗎

    Python決定係數0.8模型的可行性,是在機器學習領域被廣泛關注的問題之一。本篇文章將從多個方面對這個問題進行詳細的闡述,並且給出相應的代碼示例。 一、Python決定係數0.8…

    編程 2025-04-27

發表回復

登錄後才能評論