Here is the training function and the call that starts training.
def train(d_model, g_model, gan_model, data, target_dir, n_epochs=100, n_batch=16):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    blue_photo = data[0]
    blue_sketch = data[1]
    checkpoint_dir = 'Models/checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_opt=opt,
                                     discriminator_opt=opt,
                                     gan_opt=opt,
                                     generator=g_model,
                                     discriminator=d_model,
                                     GAN=gan_model)
    ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')
    else:
        print("Latest checkpoint not found!! Training will begin anew!!")
    for i in range(n_epochs):
        print(' ========== Epoch', i+1, '========== ')
        blue_photo, blue_sketch = shuffle(blue_photo, blue_sketch)
        for j in range(int(len(blue_photo)/n_batch)):
            start = int(j*n_batch)
            end = int(min(len(blue_photo), (j*n_batch)+n_batch))
            dataset = [load_images(blue_photo[start:end]), load_images(blue_sketch[start:end])]
            # select a batch of real samples
            [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
            # generate a batch of fake samples
            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
            # update discriminator for real samples
            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
            # update discriminator for generated samples
            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
            d_loss = 0.5 * np.add(d_loss1, d_loss2)
            # update the generator
            g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
            # summarize performance
            print('Batch : %d, D Loss : %.3f | G Loss : %.3f' % (j+1, d_loss, g_loss))
    # save checkpoints every even epoch
    if n_epochs % 2 == 0:
        ckpt_manager.save()
    # summarize model performance
    summarize_performance(i, g_model, d_model, dataset, target_dir)

train(d_model, g_model, gan_model, [blue_sketch, blue_photo], 'Models/checkpoint/', n_epochs=50, n_batch=16)
This GAN translates sketch images into realistic images. I have tried training without checkpoints, but it always fails with a ResourceExhausted error around the 18th epoch (is there a fix without checkpoints?).
So, to avoid training everything in one go, I set up checkpoints. The checkpoint is saved and restored when training restarts. But even after 24 epochs there is no visible difference: it gives me the same images it produced 11 epochs earlier. Thanks in advance.
1 Answer
The problem with the posted code is that the checkpoint is saved outside the training loop, so every restart restores the same old checkpoint and you get the same output each time.
To carry progress across runs, move the checkpoint saving into the training loop so a checkpoint is written every epoch (or every couple of epochs); restoring the latest checkpoint once before the loop is enough. You can also introduce randomness by reshuffling the training data at the start of every epoch.
Here is an updated version of the code with these changes:
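This is a sketch rather than a drop-in replacement: it reuses the models, the optimizer opt, and the helper functions from your code (shuffle, load_images, generate_real_samples, generate_fake_samples, summarize_performance) unchanged, and only moves the checkpoint save inside the epoch loop (every second epoch here, keyed on the epoch counter i) while keeping the restore before the loop and the per-epoch reshuffle.

def train(d_model, g_model, gan_model, data, target_dir, n_epochs=100, n_batch=16):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    blue_photo, blue_sketch = data[0], data[1]
    checkpoint_dir = 'Models/checkpoints'
    checkpoint = tf.train.Checkpoint(generator_opt=opt,
                                     discriminator_opt=opt,
                                     gan_opt=opt,
                                     generator=g_model,
                                     discriminator=d_model,
                                     GAN=gan_model)
    ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
    # restore once, before training, so a crashed run resumes from its most recent save
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')
    else:
        print("Latest checkpoint not found!! Training will begin anew!!")
    for i in range(n_epochs):
        print(' ========== Epoch', i + 1, '========== ')
        # reshuffle the paired images at the start of every epoch
        blue_photo, blue_sketch = shuffle(blue_photo, blue_sketch)
        for j in range(len(blue_photo) // n_batch):
            start = j * n_batch
            end = min(len(blue_photo), start + n_batch)
            dataset = [load_images(blue_photo[start:end]), load_images(blue_sketch[start:end])]
            # real batch, fake batch, then discriminator and generator updates
            [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
            d_loss = 0.5 * np.add(d_loss1, d_loss2)
            g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
            print('Batch : %d, D Loss : %.3f | G Loss : %.3f' % (j + 1, d_loss, g_loss))
        # save inside the epoch loop, every second epoch, based on the epoch counter i
        if (i + 1) % 2 == 0:
            ckpt_manager.save()
        # summarize model performance once per epoch
        summarize_performance(i, g_model, d_model, dataset, target_dir)

train(d_model, g_model, gan_model, [blue_sketch, blue_photo], 'Models/checkpoint/', n_epochs=50, n_batch=16)

With the save inside the loop, restarting after the ResourceExhausted crash resumes from the latest completed epoch instead of an old checkpoint, so later runs should keep improving the generated images.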
Hope this helps.