我正在尝试使用Keras在cifar10图像上构建一个变分自动编码器。它在mnist数据上工作得很好。但是使用cifar10,当我调用方法fit时,我的损失(重建损失和KL损失)是NAN,正如你在这里看到的:NAN loss
以下是我的VAE自定义训练步骤:
注:cifar10图像形状=(32,32,3),潜在维度= 2
class VAE(Model):
def __init__(self, encoder, decoder, **kwargs):
super().__init__(**kwargs)
# encoder and decoder
self.encoder = encoder
self.decoder = decoder
# losses metrics
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.kl_loss_tracker,
]
def train_step(self, data):
with tf.GradientTape() as tape:
# see 4. Encoder
z_mu, z_sigma, z = self.encoder(data)
z_decoded = self.decoder(z)
# compute the losses
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, z_decoded), axis=(1, 2)
)
)
kl_loss = -(1 + z_sigma - z_mu**2 - tf.exp(z_sigma)) / 2
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
# gradients
grads = tape.gradient(total_loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
# update losses
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
# return the final losses
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
我的编码器:encoder graph
我的解码器:decoder graph
有谁知道吗?
1条答案
按热度按时间gg0vcinb1#
我认为问题出在训练阶跃函数上,这是由于使用了错误的损失函数。将其改为categorical_crossentropy而不是binary_crossentropy将有效。