在Keras中拟合VAE时获得NAN损失

ars1skjm  于 2023-03-23  发布在  其他
关注(0)|答案(1)|浏览(220)

我正在尝试使用Keras在cifar10图像上构建一个变分自动编码器。它在mnist数据上工作得很好。但是使用cifar10,当我调用方法fit时,我的损失(重建损失和KL损失)是NAN,正如你在这里看到的:NAN loss
以下是我的VAE自定义训练步骤:
注:cifar10图像形状=(32,32,3),潜在维度= 2

class VAE(Model):
  
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)

        # encoder and decoder
        self.encoder = encoder
        self.decoder = decoder

        # losses metrics
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
      with tf.GradientTape() as tape:
        # see 4. Encoder
        z_mu, z_sigma, z = self.encoder(data)
        z_decoded = self.decoder(z)

        # compute the losses
        reconstruction_loss = tf.reduce_mean(
                  tf.reduce_sum(
                      keras.losses.binary_crossentropy(data, z_decoded), axis=(1, 2)
                  )
              )
        kl_loss = -(1 + z_sigma - z_mu**2 - tf.exp(z_sigma)) / 2
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss

        # gradients
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # update losses
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        #  return the final losses
        return {
              "loss": self.total_loss_tracker.result(),
              "reconstruction_loss": self.reconstruction_loss_tracker.result(),
              "kl_loss": self.kl_loss_tracker.result(),
          }

我的编码器:encoder graph
我的解码器:decoder graph
有谁知道吗?

gg0vcinb

gg0vcinb1#

我认为问题出在训练阶跃函数上,这是由于使用了错误的损失函数。将其改为categorical_crossentropy而不是binary_crossentropy将有效。

相关问题