GAN的損失函數

一、GAN概述

GAN(Generative Adversarial Networks)是一種生成模型,由生成器和判別器兩大部分組成,目的是學習真實數據分布,並且從雜訊中生成與真實數據相似的樣本。

二、GAN的損失函數

GAN的損失函數包含兩個部分:生成器損失和判別器損失。

1. 生成器損失

生成器的任務是生成與真實數據相似的數據樣本,因此其損失可以定義為判別器無法正確辨別生成數據的概率,即:

G_loss = -log(D(G(z)))

其中,G(z)表示雜訊z經過生成器G生成的樣本,D表示判別器,G_loss越小,則生成樣本越接近真實數據。

2. 判別器損失

判別器的任務是辨別真實數據和生成數據,因此其損失可以定義為正確分類真實樣本的概率和正確分類生成樣本的概率的平均數,即:

D_loss = -log(D(x)) -log(1-D(G(z)))

其中,x表示真實數據,D(x)表示判別器將真實數據判為真實數據的概率,D(G(z))表示判別器將生成數據判為真實數據的概率,D_loss越小,則判別器越能夠準確地分辨真實數據和生成數據。

三、GAN損失函數的訓練過程

GAN的訓練過程是博弈過程,即生成器和判別器不斷地相互博弈,訓練流程如下:

1. 初始化參數

生成器和判別器都需要初始化參數,生成器的參數通常以隨機雜訊z作為輸入,輸出與真實數據相似的樣本數據,判別器的參數通常以真實數據或者生成器生成的數據作為輸入,輸出為0(真實數據)或1(生成數據)。

2. 訓練判別器

首先固定生成器的參數,訓練判別器的參數,讓判別器能夠準確地分辨真實數據和生成數據,即最小化判別器損失函數:

min(D_loss)

3. 訓練生成器

接著固定判別器參數,訓練生成器的參數,讓生成器生成與真實數據相似的樣本數據,即最小化生成器損失函數:

min(G_loss)

4. 不斷交替訓練

在訓練過程中,生成器和判別器不斷交替訓練,直到生成器生成的樣本無法被判別器辨別為止。

四、完整代碼示例

import tensorflow as tf
from tensorflow import keras
import numpy as np

# 定義生成器網路
def make_generator_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(256, input_shape=(100,), use_bias=False))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.LeakyReLU())

    model.add(keras.layers.Dense(512, use_bias=False))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.LeakyReLU())

    model.add(keras.layers.Dense(28*28*1, use_bias=False, activation='tanh'))
    model.add(keras.layers.Reshape((28, 28, 1)))

    return model

# 定義判別器網路
def make_discriminator_model():
    model = keras.Sequential()
    model.add(keras.layers.Flatten(input_shape=(28,28,1)))
    model.add(keras.layers.Dense(512))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.Dense(256))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.Dense(1, activation='sigmoid'))

    return model

# 定義損失函數
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

# 判別器損失函數
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss

    return total_loss

# 生成器損失函數
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# 定義優化器
generator_optimizer = keras.optimizers.Adam(1e-4)
discriminator_optimizer = keras.optimizers.Adam(1e-4)

# 定義訓練步驟
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# 定義訓練過程
def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)

# 載入數據集
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # 將像素範圍縮放到[-1, 1]之間
BUFFER_SIZE = 60000
BATCH_SIZE = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# 構造生成器和判別器
generator = make_generator_model()
discriminator = make_discriminator_model()

# 定義訓練參數
EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16  # 每輪生成的樣本數量

# 開始訓練
train(train_dataset, EPOCHS)

# 生成樣本
noise = tf.random.normal([num_examples_to_generate, noise_dim])
generated_images = generator(noise, training=False)

# 展示生成的樣本
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4,4))

for i in range(generated_images.shape[0]):
    plt.subplot(4, 4, i+1)
    plt.imshow((generated_images[i, :, :, 0] + 1)/2, cmap='gray')
    plt.axis('off')

plt.show()

原創文章,作者:DLLFS,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/371239.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
DLLFS的頭像DLLFS
上一篇 2025-04-23 00:48
下一篇 2025-04-23 00:48

相關推薦

  • Python中引入上一級目錄中函數

    Python中經常需要調用其他文件夾中的模塊或函數,其中一個常見的操作是引入上一級目錄中的函數。在此,我們將從多個角度詳細解釋如何在Python中引入上一級目錄的函數。 一、加入環…

    編程 2025-04-29
  • Python中capitalize函數的使用

    在Python的字元串操作中,capitalize函數常常被用到,這個函數可以使字元串中的第一個單詞首字母大寫,其餘字母小寫。在本文中,我們將從以下幾個方面對capitalize函…

    編程 2025-04-29
  • Python中set函數的作用

    Python中set函數是一個有用的數據類型,可以被用於許多編程場景中。在這篇文章中,我們將學習Python中set函數的多個方面,從而深入了解這個函數在Python中的用途。 一…

    編程 2025-04-29
  • 三角函數用英語怎麼說

    三角函數,即三角比函數,是指在一個銳角三角形中某一角的對邊、鄰邊之比。在數學中,三角函數包括正弦、餘弦、正切等,它們在數學、物理、工程和計算機等領域都得到了廣泛的應用。 一、正弦函…

    編程 2025-04-29
  • 單片機列印函數

    單片機列印是指通過串口或並口將一些數據列印到終端設備上。在單片機應用中,列印非常重要。正確的列印數據可以讓我們知道單片機運行的狀態,方便我們進行調試;錯誤的列印數據可以幫助我們快速…

    編程 2025-04-29
  • Python3定義函數參數類型

    Python是一門動態類型語言,不需要在定義變數時顯示的指定變數類型,但是Python3中提供了函數參數類型的聲明功能,在函數定義時明確定義參數類型。在函數的形參後面加上冒號(:)…

    編程 2025-04-29
  • Python定義函數判斷奇偶數

    本文將從多個方面詳細闡述Python定義函數判斷奇偶數的方法,並提供完整的代碼示例。 一、初步了解Python函數 在介紹Python如何定義函數判斷奇偶數之前,我們先來了解一下P…

    編程 2025-04-29
  • Python實現計算階乘的函數

    本文將介紹如何使用Python定義函數fact(n),計算n的階乘。 一、什麼是階乘 階乘指從1乘到指定數之間所有整數的乘積。如:5! = 5 * 4 * 3 * 2 * 1 = …

    編程 2025-04-29
  • Python函數名稱相同參數不同:多態

    Python是一門面向對象的編程語言,它強烈支持多態性 一、什麼是多態多態是面向對象三大特性中的一種,它指的是:相同的函數名稱可以有不同的實現方式。也就是說,不同的對象調用同名方法…

    編程 2025-04-29
  • 分段函數Python

    本文將從以下幾個方面詳細闡述Python中的分段函數,包括函數基本定義、調用示例、圖像繪製、函數優化和應用實例。 一、函數基本定義 分段函數又稱為條件函數,指一條直線段或曲線段,由…

    編程 2025-04-29

發表回復

登錄後才能評論