keras GAN奇怪输出ETA

q5iwbnjs  于 2023-04-12  发布在  其他
关注(0)|答案(1)|浏览(98)

这是GAN的代码。

# Load the dataset
(X_train, _), (_, _) = mnist.load_data()

# Rescale -1 to 1
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

设置发电机网络latent_dim = 100

generator = Sequential()

generator.add(Dense(256 * 7 * 7, input_dim=latent_dim))
generator.add(Reshape((7, 7, 256)))
generator.add(UpSampling2D())
generator.add(Conv2D(128, kernel_size=3, padding="same"))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(alpha=0.2))
generator.add(UpSampling2D())
generator.add(Conv2D(64, kernel_size=3, padding="same"))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Conv2D(1, kernel_size=3, padding="same"))
generator.add(Activation("tanh"))

设置鉴别器网络鉴别器= Sequential()

discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(28, 28, 1), padding="same"))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
discriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))

定义GAN网络

gan_input = Input(shape=(latent_dim,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = Model(inputs=gan_input, outputs=gan_output)

# Compile the discriminator and GAN networks
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(0.0002, 0.5),
                      metrics=['accuracy'])

gan.compile(loss='binary_crossentropy', optimizer='adam')

训练GAN

epochs = 10000
batch_size = 512
steps_per_epoch = int(X_train.shape[0] / batch_size)
print(steps_per_epoch)

for epoch in range(epochs):
    # Train the discriminator
    for step in range(steps_per_epoch):
        # Geta batch of real images from the training data
        real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]

        # Generate a batch of fake images using the generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_images = generator.predict(noise)

        # Train the discriminator on the real and fake images
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    # Print the progress and save the generated images
    print("Epoch {} Discriminator Loss: {} Generator Loss: {}".format(epoch, discriminator_loss[0], generator_loss))

    if epoch % 100 == 0:
        # Save the generated images
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, latent_dim))
        generated_images = generator.predict(noise)
        generated_images = 0.5 * generated_images + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(generated_images[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("gan_mnist_epoch_{}.png".format(epoch))
        plt.close()

下面是输出:

Epoch 21 Discriminator Loss: 19.256301978603005 Generator Loss: 0.010846754536032677
16/16 [==============================] - 0s 22ms/step
16/16 [==============================] - 0s 21ms/step
16/16 [==============================] - 0s 22ms/step
...
16/16 [==============================] - 0s 22ms/step
16/16 [==============================] - 0s 23ms/step
16/16 [==============================] - ETA: 0s

在这样的“ETA:0s”其他输出停止。代码正常,保存了100个epochs的图像,但没有输出行。可能是jupyter notebook的问题。如何修复?

kjthegm6

kjthegm61#

解决办法很简单:缓冲区溢出。我们可以用“print(“Epoch {}鉴别器损失:{}发电机损耗:{}".format(epoch,discriminator_loss[0],generator_loss),flush=True)”

相关问题