python 使用检查点进行GAN训练,但模型根本无法学习

46scxncf  于 2023-03-11  发布在  Python
关注(0)|答案(1)|浏览(144)

这是培训函数和开始培训的语句。

def train(d_model, g_model, gan_model, data, target_dir, n_epochs=100, n_batch=16):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    
    blue_photo = data[0]
    blue_sketch = data[1]
    
    checkpoint_dir = 'Models/checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_opt=opt,
                                  discriminator_opt=opt,
                                  gan_opt=opt,
                                  generator=g_model,
                                  discriminator=d_model,
                                  GAN = gan_model
                                  )

    ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
      checkpoint.restore(ckpt_manager.latest_checkpoint)
      print ('Latest checkpoint restored!!')
    else:
        print("Latest checkpoint not found!! Training will begin anew!!")
    
    for i in range(n_epochs):
        print(' ========== Epoch', i+1, '========== ')
        
        blue_photo, blue_sketch = shuffle(blue_photo, blue_sketch)

        for j in range(int(len(blue_photo)/n_batch)):
            
            start = int(j*n_batch)
            end = int(min(len(blue_photo), (j*n_batch)+n_batch))
            
            dataset = [load_images(blue_photo[start:end]), load_images(blue_sketch[start:end])]

            # select a batch of real samples
            [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
            
            # generate a batch of fake samples
            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
            
            # update discriminator for real samples
            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
            
            # update discriminator for generated samples
            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
            
            d_loss = 0.5 * np.add(d_loss1, d_loss2)
            
            # update the generator
            g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
            
            # summarize performance
            print('Batch : %d, D Loss : %.3f | G Loss : %.3f' % (j+1, d_loss, g_loss))
        
        #save checkpoints every even epoch    
        if n_epochs%2 == 0:
            ckpt_manager.save()
        
        # summarize model performance
        summarize_performance(i, g_model, d_model, dataset, target_dir)

train(d_model, g_model, gan_model, [blue_sketch, blue_photo], 'Models/checkpoint/', n_epochs = 50, n_batch=16)

这个GAN将草图图像转换为逼真的图像。我尝试过不使用检查点进行训练,但它总是在第18个历元时出现ResourceExhausted错误(是否有修复w/o检查点?)。
所以为了避免一次性训练,我设置了检查点。当训练重新开始时,检查点会被保存和恢复。但是即使在24个历元之后,也没有观察到明显的差异。它给我的图像和11个历元之前给我的一样。提前感谢。

5lhxktic

5lhxktic1#

所提供代码中的问题是,检查点是在训练循环之外保存和恢复的,因此它将始终恢复相同的检查点,从而导致每次都有相同的输出。
要引入变化,您可以将检查点保存和恢复操作移到训练循环内,还可以通过在每个历元开始时重排训练数据来引入数据随机性。
以下是包含这些更改的代码更新版本:

def train(d_model, g_model, gan_model, data, target_dir, n_epochs=100, n_batch=16):                                                                   
# determine the output square shape of the discriminator
n_patch = d_model.output_shape[1]

blue_photo = data[0]
blue_sketch = data[1]

checkpoint_dir = 'Models/checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_opt=opt,
                              discriminator_opt=opt,
                              gan_opt=opt,
                              generator=g_model,
                              discriminator=d_model,
                              GAN = gan_model
                              )

ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

for i in range(n_epochs):
    print(' ========== Epoch', i+1, '========== ')
    
    # shuffle the data
    blue_photo, blue_sketch = shuffle(blue_photo, blue_sketch)

    for j in range(int(len(blue_photo)/n_batch)):
        
        start = int(j*n_batch)
        end = int(min(len(blue_photo), (j*n_batch)+n_batch))
        
        dataset = [load_images(blue_photo[start:end]), load_images(blue_sketch[start:end])]

        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        
        d_loss = 0.5 * np.add(d_loss1, d_loss2)
        
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        
        # summarize performance
        print('Batch : %d, D Loss : %.3f | G Loss : %.3f' % (j+1, d_loss, g_loss))
    
    #save checkpoints every even epoch    
    if (i+1) % 2 == 0:
        ckpt_manager.save()
    
    # summarize model performance
    summarize_performance(i, g_model, d_model, dataset, target_dir)
    
    # restore from checkpoint
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        print ('Latest checkpoint restored!!')
    else:
        print("Latest checkpoint not found!! Training will begin anew!!")

希望能有所帮助。

相关问题