Keras变分自动编码器损失函数

rnmwe5a2  于 2023-04-12  发布在  其他
关注(0)|答案(3)|浏览(197)

我读过Keras关于VAE实现的this blog,其中VAE损失是这样定义的:

def vae_loss(x, x_decoded_mean):
    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return xent_loss + kl_loss

我看了一下Keras documentation,VAE损失函数是这样定义的:在这个实现中,reconstruction_loss乘以original_dim,这在第一个实现中没有看到!

if args.mse:
        reconstruction_loss = mse(inputs, outputs)
    else:
        reconstruction_loss = binary_crossentropy(inputs,
                                                  outputs)

    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)

有人能解释一下为什么吗?谢谢!

cotxawn7

cotxawn71#

first_one:CE + mean(kl, axis=-1) = CE + sum(kl, axis=-1) / d
第二个_一个:d * CE + sum(kl, axis=-1)
因此:first_one = second_one / d
注意,第二个返回所有样本的平均损失,但第一个返回所有样本的损失向量。

8i9zcol2

8i9zcol22#

在VAE中,重建损失函数可以表示为:

reconstruction_loss = - log(p ( x | z))

如果解码器输出分布被假设为高斯分布,则损失函数归结为MSE,因为:

reconstruction_loss = - log(p( x | z)) = - log ∏ ( N(x(i), x_out(i), sigma**2) = − ∑ log ( N(x(i), x_out(i), sigma**2) . alpha . ∑ (x(i), x_out(i))**2

相比之下,MSE损失的等式为:

L(x,x_out) = MSE = 1/m ∑ (x(i) - x_out(i)) **2

其中m是输出尺寸。例如,在MNIST中,m =宽×高×通道= 28 × 28 × 1 = 784
因此,

reconstruction_loss = mse(inputs, outputs)

应乘以m(即原始尺寸)以等于VAE公式中的原始重建损失。

vzgqcmou

vzgqcmou3#

在变分自动编码器(VAE)中,损失函数是负证据下限ELBO,它是两项之和:

# simplified formula
VAE_loss = reconstruction_loss + B*KL_loss

KL_loss也被称为regularization_loss。最初,B被设置为1.0,但它可以用作超参数,如beta-VAE(source 1source 2)。
当在图像上训练时,考虑输入Tensor的形状为(batch_size, height, width, channels)。然而,VAE_loss是一个标量值,它是沿着批量大小平均的,你应该对所有其他维度的损失函数求和。也就是说,您应该计算批次中每个训练样本的损失,以获得形状为(batch_size, )的向量,然后取平均值作为VAE_loss
使用均方误差时(MSE)或二进制交叉熵(BCE)计算重建损失,您得到的结果是平均值而不是总和。因此,您应该将结果乘以总维数,例如np.prod(INPUT_DIM),其中INPUT_DIM是输入Tensor的形状。但是,请注意,如果你忘记这样做并将重建损失作为BCE或MSE,你实际上是在VAE损失中为超参数B应用了一个较小的值,所以它可能会起作用。
例如,当您调用tensorflow二进制交叉熵损失函数时,它将计算此总和并除以项目数(check here for a detailed example):

这个公式中的n项将是沿着指定轴求和的项目数,而不是批次大小的数量。但是,您的损失应该是沿着所有维度的总和,对批次中的不同样本进行平均。在计算VAE损失时,您应该注意Tensor的形状。
让我们更详细地看看这一点:
KL_loss或正则化损失测量了潜在或编码变量的分布与假设的先验分布(通常是标准正态分布)之间的差异。您可以使用以下代码计算此ir:

from keras import backend as K

    # z_mean and z_log_var have shape: (batch_size, latent_dim)
    # Regularization loss or KL loss:
    regularization_loss = 1 + z_log_var - K.square(z_mean) -
                          K.exp(z_log_var)
    # After the sum, regularization loss has shape: (batch_size, )
    regularization_loss = -0.5 * K.sum(regularization_loss, axis=-1)

请注意,regularization_loss来自编码器:对于每个输入,编码器正在计算向量z_mean和z_log_var的值。正则化损失度量值与(mean=0,variance=1)的差异。在此损失中,您沿着潜在变量的维度求和:自动编码器中的潜在变量的维数越大,这种损失就越大。
重建损失采用不同的形式,因为它基于输出或预测变量的预期分布。来自original variational autoencoder paper中的附录C:
在变分自动编码器中,神经网络被用作概率编码器和解码器。根据数据和模型的类型,编码器和解码器有许多可能的选择。
您可以使用具有连续变量或二进制变量的变分自动编码器(VAE)。您需要对数据的分布进行一些假设,以便选择重建损失函数。设X为输入变量,并设m为其维度(对于MNIST图像,m = 28*28*1 = 784)。两个常见的假设是:

  • X是连续的:你可以假设输出是正态分布的(每个像素都是独立的),重建损失是L-2范数,即平方和,你可以计算为:m*MSE
  • X是一个二进制变量(例如0/1,根据输出层中的激活函数):你可以假设输出服从伯努利分布,重建损失为m*BCE

对于离散的情况,这段代码将工作:

INPUT_DIM= (28,28,1)
    # Reconstruction loss for binary variables, shape=(batch_size, )
    reconstruction_loss = keras.losses.binary_crossentropy(inputs,
                                                           outputs,
                                                           axis=[1,2,3])
    reconstruction_loss *= K.constant(np.prod(INPUT_DIM))

注意BCE是如何沿着轴1、2和3应用的,而不是沿轴0应用的,轴0是批次中的样本数量。
对于连续的情况,可能的代码是:

# Reconstruction loss for continuous variables, shape=(batch_size, )
    reconstruction_loss = K.mean(K.square(outputs - inputs), axis=[1,2,3])                                 
    reconstruction_loss *= K.constant(np.prod(INPUT_DIM))

请注意,Keras loss MeanSquaredError不接受axis参数,因此我们无法使用它来检索MSE。此外,您可以简单地用途:

reconstruction_loss = K.sum(K.square(outputs - inputs), axis=[1,2,3])

重建损失是所有维度沿着总和,这意味着数据的维度越大,这个总和就越大。也就是说,28 x28图像产生的重建损失比100 x100图像产生的重建损失要小。在实践中,您可能需要调整超参数B的值。
最后,您可以沿着样本求和并取平均值以获得VAE损失:

# Total VAE loss (-ELBO)
    VAE_loss= K.mean(reconstruction_loss +
                        regularization_loss*K.constant(B))

不同损失函数的更详细解释可以在here中找到。

相关问题