这是GAN的代码。
# Fetch MNIST and keep only the training images; the labels and the test
# split are discarded.
(X_train, _), (_, _) = mnist.load_data()
# Shift and scale the pixel values towards [-1, 1] (the range of the
# generator's tanh output), then append a trailing axis (axis 3) so every
# image carries an explicit channel dimension.
X_train = np.expand_dims(X_train / 127.5 - 1., axis=3)
# Set up the generator network.
# Size of the random noise vector fed to the generator.
latent_dim = 100

# The generator maps a latent vector to a single-channel image with values
# in [-1, 1] (tanh output).
generator = Sequential()
# Project the latent vector and reshape it into a 7x7 feature map with
# 256 channels.
generator.add(Dense(256 * 7 * 7, input_dim=latent_dim))
generator.add(Reshape((7, 7, 256)))
# First upsampling stage (UpSampling2D with its default size; presumably
# 7x7 -> 14x14 -- confirm against the Keras version in use).
generator.add(UpSampling2D())
generator.add(Conv2D(128, kernel_size=3, padding="same"))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(alpha=0.2))
# Second upsampling stage (presumably 14x14 -> 28x28, matching the
# discriminator's 28x28x1 input).
generator.add(UpSampling2D())
generator.add(Conv2D(64, kernel_size=3, padding="same"))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(alpha=0.2))
# Collapse to one output channel; tanh keeps values in [-1, 1], the same
# range the training images are rescaled to.
generator.add(Conv2D(1, kernel_size=3, padding="same"))
generator.add(Activation("tanh"))
# Set up the discriminator network.
# The discriminator takes a 28x28x1 image and outputs a single sigmoid
# score, trained below with label 1 for real images and 0 for generated ones.
discriminator = Sequential()
# Strided convolution block 1.
discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(28, 28, 1), padding="same"))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
# Strided convolution block 2, followed by one row/column of zero padding
# on the bottom and right edges.
discriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
discriminator.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
# Strided convolution block 3.
discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
# Final convolution block (stride 1, so no further downsampling).
discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
# Classification head: flatten the feature map and emit one probability.
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
# Define the combined GAN network: latent noise -> generator -> discriminator.
gan_input = Input(shape=(latent_dim,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = Model(inputs=gan_input, outputs=gan_output)

# Compile the discriminator on its own first, while its weights are still
# trainable, so that discriminator.train_on_batch keeps updating them.
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(0.0002, 0.5),
                      metrics=['accuracy'])

# Freeze the discriminator inside the combined model. Without this, the
# generator step (gan.train_on_batch with "real" labels) would also push the
# discriminator's weights towards calling fake images real, which undermines
# the discriminator and destabilises training.
# NOTE(review): this relies on `trainable` being captured at compile time
# (Keras 2 / tf.keras behaviour) -- confirm for the Keras version in use.
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='adam')
# Train the GAN.
epochs = 10000
batch_size = 512
steps_per_epoch = int(X_train.shape[0] / batch_size)
print(steps_per_epoch)

# NOTE(review): the original indentation was lost; the structure below
# (discriminator and generator each updated once per step, progress printed
# once per epoch) is a reconstruction -- confirm against the intended loop.
for epoch in range(epochs):
    for step in range(steps_per_epoch):
        # --- Train the discriminator ---
        # Get a random batch of real images from the training data.
        real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]
        # Generate a batch of fake images using the generator.
        # verbose=0 suppresses the per-call progress bar ("16/16 [=====]"),
        # which otherwise floods the notebook output on every step and
        # buries the per-epoch progress line.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_images = generator.predict(noise, verbose=0)
        # Real images are labelled 1, generated images 0.
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
        # Element-wise mean of the two [loss, accuracy] results.
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

        # --- Train the generator ---
        # The generator is trained through the combined model with "real"
        # labels, i.e. it is rewarded when the discriminator is fooled.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    # Print the progress once per epoch (values from the last step).
    # flush=True forces the line out immediately instead of leaving it in
    # the stdout buffer.
    print("Epoch {} Discriminator Loss: {} Generator Loss: {}".format(epoch, discriminator_loss[0], generator_loss), flush=True)

    if epoch % 100 == 0:
        # Save a 5x5 grid of generated images.
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, latent_dim))
        generated_images = generator.predict(noise, verbose=0)
        # Map the tanh output from [-1, 1] back to [0, 1] for display.
        generated_images = 0.5 * generated_images + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("gan_mnist_epoch_{}.png".format(epoch))
        # Close this specific figure so figures do not accumulate in memory.
        plt.close(fig)
下面是输出:
Epoch 21 Discriminator Loss: 19.256301978603005 Generator Loss: 0.010846754536032677
16/16 [==============================] - 0s 22ms/step
16/16 [==============================] - 0s 21ms/step
16/16 [==============================] - 0s 22ms/step
...
16/16 [==============================] - 0s 22ms/step
16/16 [==============================] - 0s 23ms/step
16/16 [==============================] - ETA: 0s
输出在出现这样的“ETA: 0s”之后就停止了。代码运行正常,也保存了 100 个 epoch 的图像,但不再有输出行。可能是 Jupyter Notebook 的问题。如何修复?
1条答案
按热度按时间kjthegm61#
解决办法很简单:这是输出缓冲区的问题。我们可以用 `print("Epoch {} Discriminator Loss: {} Generator Loss: {}".format(epoch, discriminator_loss[0], generator_loss), flush=True)` 来强制立即刷新输出。