在pytorch中取零函数的导数,给出运行时错误

x6yk4ghg  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(83)

我有以下简单代码:

def f(x):
    return x[:,0] + x[:,1]

def g(x):
    return torch.zeros_like(x[:,0])

def main():    
    x = torch.tensor([[0.3, 0.3],
        [0.6, 0.3],
        [0.3, 0.6],
        [0.6, 0.6]])
    x.requires_grad_()
    grad_myf = autograd.grad(outputs=f(x), inputs=x, grad_outputs=torch.ones_like(f(x)), create_graph=True, retain_graph=True, only_inputs=True)[0]
    print(grad_myf)

字符串
这会输出正确的结果:

tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])


现在我想对g函数求导。g函数只需要返回0,不管x的值是多少。所以它的导数应该是零。所以我写

grad_myg = autograd.grad(outputs=g(x), inputs=x, grad_outputs=torch.ones_like(g(x)), create_graph=True, retain_graph=True, only_inputs=True)[0]
    print(grad_myg)


然后我得到错误消息“RuntimeError:Tensor的元素0不需要grad,也没有grad_fn”。
为什么不起作用?我需要用另一种方式重新定义g吗?比如说

def g(x):
    return 0*x


确实有效但我不知道这是不是最好的办法我定义g的方法是自然的方法。

hm2xizp9

hm2xizp91#

得到RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn是因为Tensortorch.zeros_like(x[:,0])默认有requires_grad=False。如果将其更改为True,则会得到另一个错误:这是因为g的结果是一个全新的Tensor。这个Tensor不在构建的图中(它就像一个孤立的节点),torch无法计算它的梯度。0 * x可以工作,因为torch可以找到它的衍生物。

相关问题