我看到一些代码块需要tf.GradientTape().watch()才能工作,有些代码块似乎不需要它也能工作,这让我很困惑。
例如,这段代码需要watch()函数:
with tf.GradientTape() as t:
# Record the actions performed on tensor x with `watch`
t.watch(x)
# Define y as the sum of the elements in x
y = tf.reduce_sum(x)
# Let z be the square of y
z = tf.square(y)
# Get the derivative of z wrt the original input tensor x
dz_dx = t.gradient(z, x)
但是,此块不会:
with tf.GradientTape() as tape:
logits = model(images, images = True)
loss_value = loss_object(labels, logits)
loss_history.append(loss_value.numpy.mean()
grads = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradient(zip(grads, model.trainable_vaiables))
这两种情况有什么不同?
1条答案
按热度按时间pu82cl6c1#
watch的文档说明如下:
确保此磁带正在跟踪
tensor
。默认情况下,在磁带环境中访问的任何可训练变量都会被监视。这意味着我们可以通过调用
t.gradient(loss, variable)
来计算该可训练变量的梯度。检查以下示例:因此在上面的代码中不需要使用
tape.watch
,但是有时我们需要计算一些不可训练变量的梯度,在这种情况下我们需要使用watch
。在上面的代码中,
images
是模型的输入,不是可训练变量。我需要计算图像的损失梯度,因此我需要watch
它。