CycleGAN演算法

CycleGAN是一種用於圖像轉換的無監督深度學習演算法。與傳統的圖像轉換方法不同,CycleGAN可以在沒有對應訓練集的情況下進行圖像轉換。本文將從演算法原理、模型結構、實現細節、代碼示例等多個方面詳細介紹CycleGAN演算法。

一、演算法原理

CycleGAN演算法基於生成對抗網路(GAN)的思想。GAN包括一個生成器和一個判別器,生成器用於生成假圖片,判別器用於判斷真假。CycleGAN在此基礎上引入了兩個生成器和兩個判別器,用於進行兩個域之間的圖像轉換。

具體地,假設有兩個域X和Y,其中X為照片域,Y為素描域。CycleGAN需要學習兩個映射函數:G:X→Y和F:Y→X。G將X域中的照片轉換為Y域中的素描,F將Y域中的素描轉換為X域中的照片。同時,CycleGAN還需要學習兩個判別器DY和DX,DY用於判斷Y域中的圖片是否為真實素描,DX用於判斷X域中的圖片是否為真實照片。

為了實現這種域之間的圖像轉換,CycleGAN使用了對抗性損失和循環一致性損失。對抗性損失是指生成器G和判別器DY之間的對抗性損失,這意味著G在產生Y域圖像時需要儘可能欺騙DY,而DY需要儘可能地將真實的Y域圖片和G生成的假圖片區分開。循環一致性損失是指,在圖像從X域到Y域的轉換結束後,再反向回到X域時,轉換的結果應該與最初的X域圖像儘可能一致;同理,在圖像從Y域到X域的轉換結束後,再反向回到Y域時轉換的結果應該與最初的Y域圖像儘可能一致。這樣可以避免在進行域之間的圖像轉換時出現不一致的情況。

二、模型結構

CycleGAN的模型結構主要包括兩個生成器和兩個判別器。其中,每個生成器包括一些卷積層、反卷積層和殘差塊;每個判別器包括一些卷積層。

以將照片轉換為素描為例(即從域X到域Y的轉換),生成器G包括從X域圖像到素描圖像的轉換部分和從素描圖像到X域圖像的反轉換部分。G的轉換部分包括5個卷積層和6個殘差塊,反轉換部分包括5個反卷積層和6個殘差塊。

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        conv_block = [  nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.InstanceNorm2d(in_channels, affine=True, track_running_stats=True),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.InstanceNorm2d(in_channels, affine=True, track_running_stats=True) ]
        self.conv_block = nn.Sequential(*conv_block)
        
    def forward(self, x):
        return x + self.conv_block(x)
        
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
          super(Generator, self).__init__()
          # Convolutional layers
          self.conv1 = nn.Conv2d(input_nc, 64, kernel_size=7, stride=1, padding=3, bias=False)
          self.norm1 = nn.InstanceNorm2d(64, affine=True, track_running_stats=True)
          self.relu1 = nn.ReLU(inplace=True)

          # Residual blocks
          self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(n_residual_blocks)])

          # Deconvolutional layers
          self.deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
          self.norm2 = nn.InstanceNorm2d(64, affine=True, track_running_stats=True)
          self.relu2 = nn.ReLU(inplace=True)

          self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
          self.norm3 = nn.InstanceNorm2d(32, affine=True, track_running_stats=True)
          self.relu3 = nn.ReLU(inplace=True)

          self.deconv3 = nn.ConvTranspose2d(32, output_nc, kernel_size=7, stride=1, padding=3, bias=False)
          self.tanh = nn.Tanh()

    def forward(self, x):
          # Encoder
          x = self.conv1(x)
          x = self.norm1(x)
          x = self.relu1(x)

          # Residual blocks
          x = self.res_blocks(x)

          # Decoder
          x = self.deconv1(x)
          x = self.norm2(x)
          x = self.relu2(x)

          x = self.deconv2(x)
          x = self.norm3(x)
          x = self.relu3(x)

          x = self.deconv3(x)
          x = self.tanh(x)

          return x

判別器D包括5個卷積層,每個卷積層後跟一個LeakyReLU層(用於允許梯度向後傳播)。所有卷積核大小為4×4,輸入圖像大小為256×256,通道數為3(即RGB圖像)。

class Discriminator(nn.Module):
      def __init__(self, input_nc):
          super(Discriminator, self).__init__()

          # Convolution layers
          self.conv1 = nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1)
          self.leaky_relu1 = nn.LeakyReLU(0.2, inplace=True)

          self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
          self.norm1 = nn.InstanceNorm2d(128, affine=True, track_running_stats=True)
          self.leaky_relu2 = nn.LeakyReLU(0.2, inplace=True)

          self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
          self.norm2 = nn.InstanceNorm2d(256, affine=True, track_running_stats=True)
          self.leaky_relu3 = nn.LeakyReLU(0.2, inplace=True)

          self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False)
          self.norm3 = nn.InstanceNorm2d(512, affine=True, track_running_stats=True)
          self.leaky_relu4 = nn.LeakyReLU(0.2, inplace=True)

          self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=False)

      def forward(self, x):
          x = self.conv1(x)
          x = self.leaky_relu1(x)

          x = self.conv2(x)
          x = self.norm1(x)
          x = self.leaky_relu2(x)

          x = self.conv3(x)
          x = self.norm2(x)
          x = self.leaky_relu3(x)

          x = self.conv4(x)
          x = self.norm3(x)
          x = self.leaky_relu4(x)

          x = self.conv5(x)

          return x

三、實現細節

在實現CycleGAN時,需要注意以下幾個細節:

1、在訓練過程中,需要分別訓練G、F、DX和DY。在每個訓練輪次中,需要隨機選擇一個batch的X域圖片和一個batch的Y域圖片,分別作為G和F的輸入,生成出G(X)和F(Y),然後進行對抗性訓練。

# Train G
optimizer_G.zero_grad()

# GAN loss(對抗性損失)
pred_fake = D_Y(G_X)
loss_GAN = criterion_GAN(pred_fake, target_real)
# Cycle loss(循環一致性損失)
recovered_X = F_Y(G_X)
loss_cycle_XY = criterion_cycle(recovered_X, real_X) * lambda_cycle

# Total loss
loss_G = loss_GAN + loss_cycle_XY
loss_G.backward()

optimizer_G.step()

2、在訓練過程中,需要對生成器的輸出進行歸一化(對於RGB圖片,需要將像素值從[0,255]歸一化到[-1,1])。

# Generate fake X
fake_X = G_Y(real_Y)
# Normalize outputs to be in [-1, 1]
fake_X = (fake_X + 1) / 2.0

3、對於CycleGAN中的循環一致性損失,一般設置一個權重係數lambda_cycle來進行平衡。lambda_cycle的值可以根據具體任務進行調整,不同的任務可能需要不同的lambda_cycle。

四、代碼示例

下面是CycleGAN的一個簡單代碼示例,用於將馬的圖像轉換為斑馬的圖像。

# Loss function
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()

# Optimizers
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
for epoch in range(num_epochs):
    for i, (real_A, real_B) in enumerate(dataloader):

        # Adversarial ground truths
        valid = torch.Tensor(np.ones((real_A.size(0), *D_output_shape)))
        fake = torch.Tensor(np.zeros((real_A.size(0), *D_output_shape)))

        # Configure input tensor
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = G(real_A)
        pred_fake = D(fake_B)
        loss_GAN_AB = criterion_GAN(pred_fake, valid)

        fake_A = F(real_B)
        pred_fake = D(fake_A)
        loss_GAN_BA = criterion_GAN(pred_fake, valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recovered_A = F(fake_B)
        loss_cycle_A = criterion_cycle(recovered_A, real_A)

        recovered_B = G(fake_A)
        loss_cycle_B = criterion_cycle(recovered_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + lambda_cycle * loss_cycle
        loss_G.backward()

        optimizer_G.step()

        # ----------------------
        #  Train Discriminators
        # ----------------------

        optimizer_D.zero_grad()

        # Real loss
        pred_real = D(real_B)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss
        pred_fake = D(fake_B.detach())
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/244622.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-12-12 13:03
下一篇 2024-12-12 13:03

相關推薦

  • 蝴蝶優化演算法Python版

    蝴蝶優化演算法是一種基於仿生學的優化演算法,模仿自然界中的蝴蝶進行搜索。它可以應用於多個領域的優化問題,包括數學優化、工程問題、機器學習等。本文將從多個方面對蝴蝶優化演算法Python版…

    編程 2025-04-29
  • Python實現爬樓梯演算法

    本文介紹使用Python實現爬樓梯演算法,該演算法用於計算一個人爬n級樓梯有多少種不同的方法。 有一樓梯,小明可以一次走一步、兩步或三步。請問小明爬上第 n 級樓梯有多少種不同的爬樓梯…

    編程 2025-04-29
  • AES加密解密演算法的C語言實現

    AES(Advanced Encryption Standard)是一種對稱加密演算法,可用於對數據進行加密和解密。在本篇文章中,我們將介紹C語言中如何實現AES演算法,並對實現過程進…

    編程 2025-04-29
  • Harris角點檢測演算法原理與實現

    本文將從多個方面對Harris角點檢測演算法進行詳細的闡述,包括演算法原理、實現步驟、代碼實現等。 一、Harris角點檢測演算法原理 Harris角點檢測演算法是一種經典的計算機視覺演算法…

    編程 2025-04-29
  • 數據結構與演算法基礎青島大學PPT解析

    本文將從多個方面對數據結構與演算法基礎青島大學PPT進行詳細的闡述,包括數據類型、集合類型、排序演算法、字元串匹配和動態規劃等內容。通過對這些內容的解析,讀者可以更好地了解數據結構與算…

    編程 2025-04-29
  • 瘦臉演算法 Python 原理與實現

    本文將從多個方面詳細闡述瘦臉演算法 Python 實現的原理和方法,包括該演算法的意義、流程、代碼實現、優化等內容。 一、演算法意義 隨著科技的發展,瘦臉演算法已經成為了人們修圖中不可缺少…

    編程 2025-04-29
  • 神經網路BP演算法原理

    本文將從多個方面對神經網路BP演算法原理進行詳細闡述,並給出完整的代碼示例。 一、BP演算法簡介 BP演算法是一種常用的神經網路訓練演算法,其全稱為反向傳播演算法。BP演算法的基本思想是通過正…

    編程 2025-04-29
  • 粒子群演算法Python的介紹和實現

    本文將介紹粒子群演算法的原理和Python實現方法,將從以下幾個方面進行詳細闡述。 一、粒子群演算法的原理 粒子群演算法(Particle Swarm Optimization, PSO…

    編程 2025-04-29
  • Python回歸演算法算例

    本文將從以下幾個方面對Python回歸演算法算例進行詳細闡述。 一、回歸演算法簡介 回歸演算法是數據分析中的一種重要方法,主要用於預測未來或進行趨勢分析,通過對歷史數據的學習和分析,建立…

    編程 2025-04-28
  • 象棋演算法思路探析

    本文將從多方面探討象棋演算法,包括搜索演算法、啟發式演算法、博弈樹演算法、神經網路演算法等。 一、搜索演算法 搜索演算法是一種常見的求解問題的方法。在象棋中,搜索演算法可以用來尋找最佳棋步。經典的…

    編程 2025-04-28

發表回復

登錄後才能評論