tensorflow Shape=()打印时包含多个值的向量

eqqqjvef  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(161)

我使用Tensorflow来开发VAE,所以我使用的成本是ELBO(证据下界)。为了将误差应用到梯度,在最后一步中必须使用reduce_mean(),以便成本函数返回标量。

def vae_cost(x_true, model, analytic_kl=False, kl_weight=4):
  x_true = tf.cast(x_true, tf.float32)
  z_sample, mu, sd = model.encode(x_true)
  x_recons_logits = model.decoder(z_sample)
  # compute mean squared error
  recons_error = tf.cast(
      tf.reduce_mean((x_true - x_recons_logits) ** 2, axis=[1, 2, 3]),
      tf.float32)
  # compute reverse KL divergence, either analytically 
  # or through MC approximation with one sample
  if analytic_kl:
    kl_divergence = -0.5 * tf.math.reduce_sum(
        1 + tf.math.log(tf.math.square(sd)) - tf.math.square(mu) - tf.math.square(sd),
        axis=1) # shape=(batch_size,)
  else:
    log_pz = normal_log_pdf(z_sample, 0., 1.) # shape=(batch_size,)
    logqz_x = normal_log_pdf(z_sample, mu, tf.math.square(sd))
    kl_divergence = logqz_x - log_pz

  elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
  return -elbo

(Note:这是我从here中提取的代码,几乎没有修改)
模型训练完美;从这个意义上说没有问题。我有问题的是 * 打印 * 错误的事实。我对tensorflow的内部工作原理知之甚少,但我知道你不能使用python的内置函数print(),因为如果我没弄错的话,它会打印计算图。因此,tf.print()似乎是解决方案。但控制台中显示的不是单个值,而是:

2.72147369
2.37455082
3.83512926
2.00962853
2.3469491
3.15436459
2.25914431
2.40686131
2.98925924
2.75991917
1.94956458
3.1419673
2.06783676
2.53439474
2.18458319
2.31454301
1.79345393
1.81354737
2.27693963
1.60603094
2.71092319
1.90332329
2.64296
1.94370067
2.07476187
2.32125258

然后,如果我使用python的print()

<tf.Tensor 'Neg:0' shape=() dtype=float32>

如果向量有shape=(),那么如何用tf.print()得到这么多的值呢?我是否真的混淆了这个函数是如何工作的?在这种情况下,我实际上如何打印错误呢?如果你能解释一下"Neg:0"的含义,我将不胜感激。提前感谢。

ruyhziif

ruyhziif1#

tf.print()的输出是一个值列表,每个值对应于输入Tensor中的一个元素。在本例中,输入Tensor是一个shape=()的向量,因此输出是一个值列表,每个值对应于输入向量中的一个元素。
下面是修改后的代码:

def vae_cost(x_true, model, analytic_kl=False, kl_weight=4):
  x_true = tf.cast(x_true, tf.float32)
  z_sample, mu, sd = model.encode(x_true)
  x_recons_logits = model.decoder(z_sample)
  # compute mean squared error
  recons_error = tf.cast(
      tf.reduce_mean((x_true - x_recons_logits) ** 2, axis=[1, 2, 3]),
      tf.float32)
  # compute reverse KL divergence, either analytically 
  # or through MC approximation with one sample
  if analytic_kl:
    kl_divergence = -0.5 * tf.math.reduce_sum(
        1 + tf.math.log(tf.math.square(sd)) - tf.math.square(mu) - tf.math.square(sd),
        axis=1) # shape=(batch_size,)
  else:
    log_pz = normal_log_pdf(z_sample, 0., 1.) # shape=(batch_size,)
    logqz_x = normal_log_pdf(z_sample, mu, tf.math.square(sd))
    kl_divergence = logqz_x - log_pz

  elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
  return -elbo, kl_divergence, recons_error

相关问题