带梯度的Pytorch计数器

xlpyo6sf  于 2023-03-23  发布在  其他
关注(0)|答案(3)|浏览(185)

我尝试使用bincount从数组的一些索引的总和中获取梯度。然而,pytorch没有实现梯度。这可以通过循环和torch.sum实现,但它太慢了。在pytorch中可以有效地做到这一点吗(可能是einsum或index_add)?当然,我们可以循环遍历索引并逐个相加,然而,这将显著增加计算图的大小,并且性能非常低。

import torch
from torch import autograd
import numpy as np
tt = lambda x, grad=True: torch.tensor(x, requires_grad=grad)    
inds = tt([1, 5, 7, 1], False).long()
y = tt(np.arange(4) + 0.1).float()
sum_y_section = torch.bincount(inds, y * y, minlength=8)
#sum_y_section = torch.sum(y * y)
grad = autograd.grad(sum_y_section, y, create_graph=True, allow_unused=False)
print("sum_y_section", sum_y_section)
print("grad", grad)
fd3cxomn

fd3cxomn1#

我们可以使用Pytorch V1.11中的一个新特性scatter_reduce。

bincount = lambda inds, arr: torch.scatter_reduce(arr, 0, inds, reduce="sum")
6l7fqoea

6l7fqoea2#

我会尝试使用钩子来以自定义的方式操纵渐变

bgibtngc

bgibtngc3#

torch.scatter_reduce在Pytorch 1.13中有一个src位置参数。这个简单的例子演示了bincountscatter_reduce之间的等价性:

num_bins = 8
bins = torch.zeros(num_bins)

#generate indices and weights
num_indices = 100
indices = torch.randint(num_bins, size=(num_indices,))
weights = torch.rand(num_indices)

# Counting Indices

# with torch.bincount
counts1 = torch.bincount(indices, minlength=num_bins)

# with torch.scatter_reduce
counts2 = bins.scatter_reduce(0, indices, torch.ones(num_indices), reduce = 'sum')
print(counts1)
print(counts2)

# Binning Weights

# with torch.bincount
binned_wts1 =  indices.bincount(weights, minlength=num_bins)

# with torch.scatter_reduce
binned_wts2 = bins.scatter_reduce(0, indices, weights, reduce='sum')

print(binned_wts1)
print(binned_wts2)

相关问题