我尝试使用bincount从数组的一些索引的总和中获取梯度。然而,pytorch没有实现梯度。这可以通过循环和torch.sum实现,但它太慢了。在pytorch中可以有效地做到这一点吗(可能是einsum或index_add)?当然,我们可以循环遍历索引并逐个相加,然而,这将显著增加计算图的大小,并且性能非常低。
import torch
from torch import autograd
import numpy as np
tt = lambda x, grad=True: torch.tensor(x, requires_grad=grad)
inds = tt([1, 5, 7, 1], False).long()
y = tt(np.arange(4) + 0.1).float()
sum_y_section = torch.bincount(inds, y * y, minlength=8)
#sum_y_section = torch.sum(y * y)
grad = autograd.grad(sum_y_section, y, create_graph=True, allow_unused=False)
print("sum_y_section", sum_y_section)
print("grad", grad)
3条答案
按热度按时间fd3cxomn1#
我们可以使用Pytorch V1.11中的一个新特性scatter_reduce。
6l7fqoea2#
我会尝试使用钩子来以自定义的方式操纵渐变
bgibtngc3#
torch.scatter_reduce
在Pytorch 1.13中有一个src
位置参数。这个简单的例子演示了bincount
和scatter_reduce
之间的等价性: