Last active
May 5, 2023 00:42
-
-
Save nb312/15d27c93c0fef5db7664142c294d50e4 to your computer and use it in GitHub Desktop.
GAN的一个简单例子
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path

import keras
import numpy as np
from keras.layers import Dense, Input, LeakyReLU, Dropout
from keras.models import Sequential, Model
from keras.optimizers import Adam
# Hyperparameters.
# Shape of each training image: (28, 28) pixels with a single channel.
img_shape = (28, 28, 1)
# Length of the random noise vector the generator consumes.
latent_dim: int = 100
# Build the generator.
def build_generator():
    """Return the (uncompiled) MLP generator.

    The network maps a noise vector of length ``latent_dim`` through three
    Dense layers of 256, 512 and 1024 units, each followed by a LeakyReLU
    with ``alpha=0.2``. It finishes with a tanh-activated Dense layer that
    emits a flat vector of ``np.prod(img_shape)`` values in [-1, 1].
    """
    stack = []
    for position, width in enumerate((256, 512, 1024)):
        if position == 0:
            # Only the first layer declares the size of its input.
            stack.append(Dense(width, input_dim=latent_dim))
        else:
            stack.append(Dense(width))
        stack.append(LeakyReLU(alpha=0.2))
    stack.append(Dense(np.prod(img_shape), activation='tanh'))
    return Sequential(stack)
# Build the discriminator.
def build_discriminator():
    """Return the (uncompiled) MLP discriminator.

    The network takes a flat vector of ``np.prod(img_shape)`` values and
    passes it through Dense(512) and Dense(256) layers, each followed by a
    LeakyReLU with ``alpha=0.2``. A Dropout with rate 0.4 comes next, and
    the network ends in a single sigmoid unit. The training loop labels real
    images 1 and generated images 0, so the output reads as the estimated
    probability that the input is real.
    """
    return Sequential([
        Dense(512, input_dim=np.prod(img_shape)),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Dense(1, activation='sigmoid'),
    ])
# Create the generator and the discriminator.
generator = build_generator()
discriminator = build_discriminator()
# Compile the discriminator while it is still trainable, so that its own
# train_on_batch calls in the training loop update its weights.
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5),
    metrics=['accuracy'],
)

# Build the combined (stacked) model: noise -> generator -> discriminator.
z = Input(shape=(latent_dim,))
img = generator(z)
# Freeze the discriminator before compiling the combined model, so that
# training `combined` only adjusts the generator's weights.
# NOTE(review): this relies on `trainable` being captured at compile() time
# (Keras 2 semantics); confirm the behaviour if running under Keras 3.
discriminator.trainable = False
valid = discriminator(img)
combined = Model(inputs=z, outputs=valid)
combined.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5),
)
# Prepare the training data (MNIST dataset).
from keras.datasets import mnist

(X_train, _), (_, _) = mnist.load_data()
# Rescale pixel values from [0, 255] to [-1, 1], the range of the
# generator's tanh output, so real and fake samples are on the same scale.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Both networks are fully connected: the discriminator's first Dense layer
# declares input_dim=np.prod(img_shape). Flatten each image into a single
# row of that length; a 4-D (N, 28, 28, 1) array would not match the
# discriminator's input shape.
X_train = X_train.reshape(-1, int(np.prod(img_shape)))
# Train the GAN.
batch_size = 128
# Number of training iterations. Each iteration trains on a single random
# batch, so these are not full passes over the dataset.
epochs = 10000
# Save a set of generated samples every `sample_interval` iterations.
sample_interval = 1000


def save_imgs(epoch, rows=5, cols=5, out_dir="images"):
    """Sample ``rows * cols`` images from the generator and save them to disk.

    Args:
        epoch: Current training iteration; used in the output file name.
        rows: Number of rows in the sample grid.
        cols: Number of columns in the sample grid.
        out_dir: Directory to write into. It is created if missing.

    The samples are written with ``np.save`` to ``<out_dir>/gan_<epoch>.npy``.
    The array has shape ``(rows * cols, *img_shape)`` and values rescaled
    to [0, 1].
    """
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    samples = generator.predict(noise, verbose=0)
    # The generator's tanh output lies in [-1, 1]; map it back to [0, 1].
    samples = 0.5 * samples + 0.5
    samples = samples.reshape((rows * cols, *img_shape))
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    np.save(out_path / ("gan_%d.npy" % epoch), samples)


# Target labels are identical every iteration, so build them once.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # Pick a random batch of real images.
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    # The discriminator consumes flat vectors (input_dim=np.prod(img_shape)),
    # so flatten the batch regardless of how X_train is laid out.
    imgs = X_train[idx].reshape(batch_size, -1)

    # Generate a batch of fake images. verbose=0 avoids printing a progress
    # bar on every iteration.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    gen_imgs = generator.predict(noise, verbose=0)

    # Train the discriminator: real images are labelled 1, generated ones 0.
    d_loss_real = discriminator.train_on_batch(imgs, real_labels)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
    # Average [loss, accuracy] over the real and fake half-steps.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Draw fresh noise and train the generator through the combined model.
    # The targets are "real" (1), pushing the generator to fool the
    # discriminator.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = combined.train_on_batch(noise, real_labels)

    # Report training progress.
    print("Epoch %d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

    # Periodically save generated samples.
    if epoch % sample_interval == 0:
        save_imgs(epoch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment