一、GAN概述
GAN(Generative Adversarial Networks)是一種生成模型,由生成器和判別器兩大部分組成,目的是學習真實數據分布,並且從噪聲中生成與真實數據相似的樣本。
二、GAN的損失函數
GAN的損失函數包含兩個部分:生成器損失和判別器損失。
1. 生成器損失
生成器的任務是生成與真實數據相似的數據樣本,因此其損失可以定義為判別器無法正確辨別生成數據的概率,即:
G_loss = -log(D(G(z)))
其中,G(z)表示噪聲z經過生成器G生成的樣本,D表示判別器,G_loss越小,則生成樣本越接近真實數據。
2. 判別器損失
判別器的任務是辨別真實數據和生成數據,因此其損失可以定義為正確分類真實樣本的概率和正確分類生成樣本的概率的平均數,即:
D_loss = -log(D(x)) -log(1-D(G(z)))
其中,x表示真實數據,D(x)表示判別器將真實數據判為真實數據的概率,D(G(z))表示判別器將生成數據判為真實數據的概率,D_loss越小,則判別器越能夠準確地分辨真實數據和生成數據。
三、GAN損失函數的訓練過程
GAN的訓練過程是博弈過程,即生成器和判別器不斷地相互博弈,訓練流程如下:
1. 初始化參數
生成器和判別器都需要初始化參數,生成器的參數通常以隨機噪聲z作為輸入,輸出與真實數據相似的樣本數據,判別器的參數通常以真實數據或者生成器生成的數據作為輸入,輸出為0(真實數據)或1(生成數據)。
2. 訓練判別器
首先固定生成器的參數,訓練判別器的參數,讓判別器能夠準確地分辨真實數據和生成數據,即最小化判別器損失函數:
min(D_loss)
3. 訓練生成器
接着固定判別器參數,訓練生成器的參數,讓生成器生成與真實數據相似的樣本數據,即最小化生成器損失函數:
min(G_loss)
4. 不斷交替訓練
在訓練過程中,生成器和判別器不斷交替訓練,直到生成器生成的樣本無法被判別器辨別為止。
四、完整代碼示例
import tensorflow as tf
from tensorflow import keras
import numpy as np
# 定義生成器網絡
def make_generator_model():
model = keras.Sequential()
model.add(keras.layers.Dense(256, input_shape=(100,), use_bias=False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(512, use_bias=False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(28*28*1, use_bias=False, activation='tanh'))
model.add(keras.layers.Reshape((28, 28, 1)))
return model
# 定義判別器網絡
def make_discriminator_model():
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28,28,1)))
model.add(keras.layers.Dense(512))
model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(256))
model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(1, activation='sigmoid'))
return model
# 定義損失函數
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
# 判別器損失函數
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
# 生成器損失函數
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
# 定義優化器
generator_optimizer = keras.optimizers.Adam(1e-4)
discriminator_optimizer = keras.optimizers.Adam(1e-4)
# 定義訓練步驟
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, 100])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
# 定義訓練過程
def train(dataset, epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
# 加載數據集
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 將像素範圍縮放到[-1, 1]之間
BUFFER_SIZE = 60000
BATCH_SIZE = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# 構造生成器和判別器
generator = make_generator_model()
discriminator = make_discriminator_model()
# 定義訓練參數
EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16 # 每輪生成的樣本數量
# 開始訓練
train(train_dataset, EPOCHS)
# 生成樣本
noise = tf.random.normal([num_examples_to_generate, noise_dim])
generated_images = generator(noise, training=False)
# 展示生成的樣本
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(4,4))
for i in range(generated_images.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow((generated_images[i, :, :, 0] + 1)/2, cmap='gray')
plt.axis('off')
plt.show()
原創文章,作者:DLLFS,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/371239.html
微信掃一掃
支付寶掃一掃