我想使用pytorch来优化一个目标函数,它使用了torch.autograd无法跟踪的操作。我用torch.autograd.Function类的一个自定义forward()封装了这样的操作(如建议的here和here)。由于我知道这样的操作的梯度,我也可以编写backward()。一切看起来像这样:
class Projector(torch.autograd.Function):
# non_torch_var are constant values needed by the operation
@staticmethod
def forward(ctx, vertices, non_torch_var1, non_torch_var2, non_torch_var3):
ctx.save_for_backward(vertices)
vertices2=vertices.detach().cpu().numpy()
ctx.non_torch_var1 = non_torch_var1
ctx.non_torch_var2 = non_torch_var2
ctx.non_torch_var3 = non_torch_var3
out = project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
out = torch.tensor(out, requires_grad=True)
return out
@staticmethod
def backward(ctx, grad_out):
vertices = ctx.saved_tensors[0]
vertices2 = vertices.detach().cpu().numpy()
non_torch_var1 = ctx.non_torch_var1
non_torch_var2 = ctx.non_torch_var2
non_torch_var3 = ctx.non_torch_var3
grad_vertices = grad_project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
grad_vertices = torch.tensor(grad_vertices, requires_grad=True)
return grad_vertices, None, None, None
不过,这种实现似乎不起作用。我使用torchviz包来绘制下面几行代码的内容
import torchviz
out = Projector.apply(*input)
grad_x, = torch.autograd.grad(out.sum(), vertices, create_graph=True)
torchviz.make_dot((grad_x, vertices, out), params={"grad_x": grad_x, "vertices": vertices, "out": out}).render("attached", format="png")
我得到了这个graph,它显示grad_x没有连接到任何东西。
你知道这样的代码出了什么问题吗?
1条答案
按热度按时间rta7y2nd1#
该图正确显示了如何根据
vertices
计算out
(在代码中似乎等于input
)。变量grad_x
正确地显示为断开连接,因为它没有用于计算out
。换句话说,out
不是grad_x
的函数。grad_x
断开并不意味着渐变不流动,也不意味着您的自定义backward
实现不工作。相反,图中存在从vertices
到out
的路径意味着梯度应该流动,即autograd引擎可以计算out
相对于vertices
的梯度。要检查自定义backward
实现的正确性,需要检查grad_x
的值是否正确。简而言之,梯度应该流动,因为存在从
vertices
到out
的路径,并且应该通过检查其值而不是通过查看计算图来验证其正确性。