GAN состоит из двух частей: генератора и дискриминатора. Две сети конкурируют друг с другом за постоянное улучшение своих возможностей, в конечном итоге генерируя реалистичные данные.
Процесс обучения GAN похож на игру: генератор пытается сделать дискриминатор неспособным различать настоящие и поддельные данные, а дискриминатор изо всех сил старается правильно отличать реальные данные от сгенерированных. Цель GAN — сделать сгенерированные генератором выборки всё ближе и ближе к реальному распределению выборок и в конечном итоге добиться того эффекта, при котором сгенерированные данные практически неотличимы от реальных данных.
Генератор GAN принимает на вход случайный шум, поэтому генерируемые данные каждый раз разные. Шум обычно отбирается из простого распределения, такого как стандартное нормальное распределение или равномерное распределение:
Целью выборки случайного шума является введение разнообразия, которое позволяет генератору генерировать различные типы выборок во время обучения, тем самым изучая более подробную информацию о распределении выборки.
noise = np.random.normal(0, 1, (batch_size, noise_dim))
Генератор GGG представляет собой нейронную сеть, которая получает вектор шума zzz и посредством серии нелинейных преобразований генерирует выборки, аналогичные реальному распределению данных. Задача генератора — сгенерировать максимально реалистичные образцы, чтобы обмануть дискриминатор. Выходные данные генератора должны быть очень близки к реальным данным по форме, характеристикам и распределению.
генераторизвходитьимеет низкую размерностьизслучайный шум,И этовыходЭто многомерныйиз Генерировать данные (например, изображение или аудио). На ранних этапах обучения образцы генераторвыхода могут быть такими же, как реальные. данные сильно различаются, но по мере обучения генератор учится улавливать реальные функции данных и генерировать реалистичные поддельные образцы.
генераториз Основная цель –Максимизируйте частоту ошибок дискриминатора,То есть снижается способность дискриминатора различать истину и ложь за счет создания более реалистичных образцов.
generated_samples = generator.predict(noise)
Задача дискриминатора DDD — классифицировать входные данные и определить, является ли это реальной выборкой или сгенерированной выборкой. Он принимает два типа ввода:
Дискриминатор выводит значение вероятности D(x)D(x)D(x), которое представляет вероятность того, что выборка получена из реальных данных. В идеале дискриминатор может точно различать эти два типа выборок:
Функция потерь дискриминатора обычно использует двоичную кросс-энтропийную потерю, рассчитываемую отдельно для реальных данных и сгенерированных данных. Целью оптимизации дискриминатора является максимизация точности классификации, то есть правильная идентификация реальных выборок и правильное обнаружение поддельных выборок, генерируемых генератором.
real_loss = discriminator.train_on_batch(real_data, real_labels)
fake_loss = discriminator.train_on_batch(generated_samples, fake_labels)
функция потерь генератора
Цель генератора — заставить дискриминатор думать, что сгенерированные им данные реальны, поэтому он выполняет обратное распространение ошибки, чтобы минимизировать Генератор. данныеизпотеря。функция потерь генератора Разработан какМаксимизируйте вероятность ошибки дискриминатора。поэтому,Потери генератора определяются как:
LG=−log(D(G(z)))L_G = - \log(D(G(z)))LG=−log(D(G(z)))
в D(G(z))D(G(z))D(G(z)) представляет прогнозируемое значение дискриминатора для поддельной выборки, сгенерированной генератором. генератор надеется, что дискриминатор поверит, что эти фейковые образцы настоящие,поэтомуоно пытаетсяминимизироватьэто значение。
Функция потерь дискриминатора
Задача дискриминатора — отличить реальные данные от сгенерированных, поэтому его функция потерь состоит из двух частей:
Конечная функция потерь дискриминатора представляет собой взвешенную сумму этих двух частей потерь:
LD=−(log(D(x))+log(1−D(G(z))))L_D = - \left( \log(D(x)) + \log(1 - D(G(z))) \right)LD=−(log(D(x))+log(1−D(G(z))))
Процесс оптимизации
GANизтренироватьсяиспользоватьАлгоритм обратного распространениягенератор Масса обновлений и дискриминаторов. Процесс обучения обычно делится на два этапа:
Процесс обучения GAN представляет собой процесс попеременного обновления, а генератор и дискриминатор постоянно совершенствуются посредством этого конфронтационного обучения. В идеале обучение продолжается до тех пор, пока данные, сгенерированные генератором, не станут неразличимы дискриминатором.
# Обновить дискриминатор
discriminator.trainable = True
d_loss_real = discriminator.train_on_batch(real_samples, real_labels)
d_loss_fake = discriminator.train_on_batch(generated_samples, fake_labels)
# генератор обновлений
discriminator.trainable = False
g_loss = gan.train_on_batch(noise, real_labels)
В ГАН спортивный процесссередина,Баланс между генератором и дискриминаторомэто ключевой вопрос。тренироватьсяиз Идеальный результат – этогенераторгенерироватьиз Образцы становятся более реалистичными,дискриминатор Не могу сказатьреальные данныеи Генерировать данные. Однако в ходе реального обучения часто встречаются следующие проблемы:
свернуть Режим — распространенная проблема при обучении GAN, что означает, что генератор начинает концентрироваться на генерации определенного типа данных, игнорируя при этом другие закономерности в распределении данных. Несмотря на то, что результат генератора выглядит аутентично, ему не хватает разнообразия, чтобы охватить реальные данныеизвесь дистрибутив。Чтобы решить эту проблему,Исследователи предложили множество способов улучшить,Если используетсяпакетная регуляризацияИли используйте большегенератор Архитектура。
Обучение GAN очень чувствительно к настройкам параметров. Неправильная настройка скорости обучения, сложности модели и веса функции потерь генератора и дискриминатора может привести к нестабильному обучению или даже к сбою. Общие решения включают использование WGAN (Wasserstein GAN) для уменьшения нестабильности обучения и создание более сбалансированной конкуренции между генератором и дискриминатором за счет соответствующей настройки гиперпараметров.
Либо дискриминатор слишком силен, либо генератор слишком слаб, что приводит к сбою обучения. Если дискриминатор слишком мощный, он быстро отличит реальные данные от сгенерированных, оставляя генератору мало шансов на обучение. В это время баланс обучения можно улучшить, ограничив количество шагов обновления дискриминатора или скорректировав структуру модели.
Благодаря широкому применению и углубленному исследованию GAN было предложено множество улучшенных версий, устраняющих его ограничения, такие как:
Эти варианты решают различные проблемы обучения GAN, еще больше расширяя возможности и эффекты GAN в практических приложениях.
Ниже приведен простой пример кода GAN с использованием платформы TensorFlow и Keras на Python, показывающий, как обучить GAN генерировать изображения рукописных цифр (на основе набора данных MNIST).
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
# Загрузка набора данных MNIST
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train,axis=-1)
# создаватьгенератор
def build_generator():
model = tf.keras.Sequential()
model.add(layers.Dense(256, input_dim=100))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.BatchNormalization(momentum=0.8))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.BatchNormalization(momentum=0.8))
model.add(layers.Dense(1024))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.BatchNormalization(momentum=0.8))
model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
model.add(layers.Reshape((28, 28,1)))
return model
# создаватьдискриминатор
def build_discriminator():
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=(28, 28, 1)))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(256))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(1, activation='sigmoid'))
return model
# Определить модель GAN
def build_gan(generator, discriminator):
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.trainable = False
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
return gan
generator = build_generator()
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)
# Обучение ГАН
def train_gan(epochs,batch_size=128):
for epoch in range(epochs):
# тренироватьсядискриминатор noise = np.random.normal(0, 1, (batch_size, 100))
generated_images = generator.predict(noise)
real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]
labels_real = np.ones((batch_size, 1))
labels_fake = np.zeros((batch_size, 1))
d_loss_real = discriminator.train_on_batch(real_images, labels_real)
d_loss_fake = discriminator.train_on_batch(generated_images, labels_fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# обучающий генератор
noise = np.random.normal(0, 1, (batch_size, 100))
labels = np.ones((batch_size, 1))
g_loss = gan.train_on_batch(noise, labels)
if epoch % 100 == 0:
print(f"Epoch {epoch}, D loss: {d_loss[0]}, G loss: {g_loss}")
# Начать обучение
train_gan(epochs=10000)
Генеративно-состязательные сети (GAN) открыли совершенно новую область машинного обучения и особенно хороши в создании высококачественных изображений, видео и других форм данных. Благодаря состязательному обучению двух нейронных сетей GAN способна генерировать поддельные данные, которые практически неотличимы от реальных данных. Несмотря на то, что в процессе обучения существуют проблемы, благодаря постоянным улучшениям, таким как WGAN, условный GAN и т. д., потенциал GAN был проверен во многих областях. Ожидается, что в будущем GAN будет играть более важную роль в более практических приложениях, от генерации изображений до творческих областей искусственного интеллекта, и принесет нам больше сюрпризов.