pytorch 如何实现torch.autograd.函数的自定义前进后退函数?

l5tcr1uw  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(171)

我想使用pytorch来优化一个目标函数,它使用了torch.autograd无法跟踪的操作。我用torch.autograd.Function类的一个自定义forward()封装了这样的操作(如建议的herehere)。由于我知道这样的操作的梯度,我也可以编写backward()。一切看起来像这样:

class Projector(torch.autograd.Function):

    # non_torch_var are constant values needed by the operation
    @staticmethod
    def forward(ctx, vertices, non_torch_var1, non_torch_var2, non_torch_var3):

        ctx.save_for_backward(vertices)
        vertices2=vertices.detach().cpu().numpy()
        ctx.non_torch_var1  = non_torch_var1 
        ctx.non_torch_var2  = non_torch_var2  
        ctx.non_torch_var3  = non_torch_var3 
        out = project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        out = torch.tensor(out, requires_grad=True)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        vertices  = ctx.saved_tensors[0]
        vertices2 = vertices.detach().cpu().numpy()
        non_torch_var1 = ctx.non_torch_var1
        non_torch_var2 = ctx.non_torch_var2 
        non_torch_var3 = ctx.non_torch_var3

        grad_vertices = grad_project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        grad_vertices = torch.tensor(grad_vertices, requires_grad=True)
        return grad_vertices, None, None, None

不过,这种实现似乎不起作用。我使用torchviz包来绘制下面几行代码的内容

import torchviz
out = Projector.apply(*input)
grad_x, = torch.autograd.grad(out.sum(), vertices, create_graph=True)
torchviz.make_dot((grad_x, vertices, out), params={"grad_x": grad_x, "vertices": vertices, "out": out}).render("attached", format="png")

我得到了这个graph,它显示grad_x没有连接到任何东西。
你知道这样的代码出了什么问题吗?

rta7y2nd

rta7y2nd1#

该图正确显示了如何根据vertices计算out(在代码中似乎等于input)。变量grad_x正确地显示为断开连接,因为它没有用于计算out。换句话说,out不是grad_x的函数。grad_x断开并不意味着渐变不流动,也不意味着您的自定义backward实现不工作。相反,图中存在从verticesout的路径意味着梯度应该流动,即autograd引擎可以计算out相对于vertices的梯度。要检查自定义backward实现的正确性,需要检查grad_x的值是否正确。
简而言之,梯度应该流动,因为存在从verticesout的路径,并且应该通过检查其值而不是通过查看计算图来验证其正确性。

相关问题