pytorch 如何求取基于导数的损耗的网络参数梯度?

46scxncf  于 2023-03-08  发布在  其他
关注(0)|答案(1)|浏览(93)

我有一个网络y(x),它有一个给定的数据集dy(x),也就是说,我知道y对某个x的导数,但不知道y本身,举一个最小的例子:

import torch

# Define network to predict y(x)
network = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)

# Define dataset dy(x) = x, which corresponds to y = 0.5x^2
x = torch.linspace(0,1,100).reshape(-1,1)
dy = x

# Calculate loss based on derivative of prediction for y
x.requires_grad=True
y_pred = network(x)
dy_pred = torch.autograd.grad(y_pred, x, grad_outputs=torch.ones_like(y_pred), create_graph=True)[0]
loss = torch.mean((dy-dy_pred)**2)

# This throws an error
gradients = torch.autograd.grad(loss, network.parameters())[0]

在最后一行,它给出了错误One of the differentiated Tensors appears to not have been used in the graph,尽管参数已经被明确地用于计算损失。有趣的是,当我使用torch.optim.Adam对损失与loss.backward(),没有错误发生。我如何修复我的错误?注意:定义一个网络来直接预测dy不是我的实际问题的一个选项

gcxthw6b

gcxthw6b1#

我发现问题在于输出神经元的偏差不会导致损失,因此在图中没有使用它,因此,我们必须通过将bias=False传递到最后一层来消除这个偏差参数。

# Define network to predict y(x)
network = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1, bias=False)
)

相关问题