python Tensorflow中何时需要watch()函数来启用梯度跟踪?

bzzcjhmw  于 2022-12-21  发布在  Python
关注(0)|答案(1)|浏览(183)

我看到一些代码块需要tf.GradientTape().watch()才能工作,有些代码块似乎不需要它也能工作,这让我很困惑。
例如,这段代码需要watch()函数:

with tf.GradientTape() as t:
    # Record the actions performed on tensor x with `watch`
    t.watch(x) 

    # Define y as the sum of the elements in x
    y =  tf.reduce_sum(x)

    # Let z be the square of y
    z = tf.square(y) 

# Get the derivative of z wrt the original input tensor x
dz_dx = t.gradient(z, x)

但是,此块不会:

with tf.GradientTape() as tape:
    logits = model(images, images = True)
    loss_value = loss_object(labels, logits)
    

loss_history.append(loss_value.numpy.mean()
grads = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradient(zip(grads, model.trainable_vaiables))

这两种情况有什么不同?

pu82cl6c

pu82cl6c1#

watch的文档说明如下:
确保此磁带正在跟踪tensor
默认情况下,在磁带环境中访问的任何可训练变量都会被监视。这意味着我们可以通过调用t.gradient(loss, variable)来计算该可训练变量的梯度。检查以下示例:

def grad(model, inputs, targets):
     with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets, training=True)
     return loss_value, tape.gradient(loss_value, model.trainable_variables)

因此在上面的代码中不需要使用tape.watch,但是有时我们需要计算一些不可训练变量的梯度,在这种情况下我们需要使用watch

with tf.GradientTape() as t:
     t.watch(images)
     predictions = cnn_model(images)
     loss = tf.keras.losses.categorical_crossentropy(expected_class_output, predictions)
gradients = t.gradient(loss, images)

在上面的代码中,images是模型的输入,不是可训练变量。我需要计算图像的损失梯度,因此我需要watch它。

相关问题