How do I perform scatter and gather operations in NumPy?

Posted by n1bvdmb6 on 2023-10-19 in Other

I want to implement TensorFlow's or PyTorch's scatter and gather operations in NumPy.

Answer #1 (wn9m85ua)

There are two built-in NumPy functions that do what you want:

  • Use np.take_along_axis to implement torch.gather
  • Use np.put_along_axis to implement torch.scatter
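
A minimal sketch of both functions on a small 2-D example, assuming index has the same shape as the array in every dimension except the one being indexed (the case where the NumPy and PyTorch semantics line up). The arrays x, idx, idx_s, src and target are made up for illustration. Note that np.put_along_axis modifies its first argument in place and returns None.

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]])

# gather along axis 1: out[i][j] = x[i][idx[i][j]]  (like torch.gather(x, 1, idx))
idx = np.array([[2, 0],
                [1, 1]])
out = np.take_along_axis(x, idx, axis=1)
print(out.tolist())  # [[3, 1], [5, 5]]

# scatter along axis 1: target[i][idx_s[i][j]] = src[i][j]  (like target.scatter_(1, idx_s, src))
# idx_s has no repeated positions within a row, so the result does not depend on write order
idx_s = np.array([[2, 0],
                  [1, 0]])
src = np.array([[10, 20],
                [30, 40]])
target = np.zeros((2, 3), dtype=int)
np.put_along_axis(target, idx_s, src, axis=1)  # in place, returns None
print(target.tolist())  # [[20, 0, 10], [40, 30, 0]]
```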

Answer #2 (dnph8jn4)

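The scatter method turned out to be much more work than I expected, and I couldn't find a ready-made function for it in NumPy. I'm sharing it here for the benefit of anyone who needs to implement it with NumPy. (P.S. self is the destination, i.e. the output, of the method.)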

```python
import numpy as np


def scatter_numpy(self, dim, index, src):
    """
    Writes all values from the Tensor src into self at the indices specified in the index Tensor.
    :param dim: The axis along which to index
    :param index: The indices of elements to scatter
    :param src: The source element(s) to scatter
    :return: self
    """
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    if self.ndim != index.ndim:
        raise ValueError("Index should have the same number of dimensions as output")
    if dim >= self.ndim or dim < -self.ndim:
        raise IndexError("dim is out of range")
    if dim < 0:
        # Not sure why scatter should accept dim < 0, but that is the behavior in PyTorch's scatter
        dim = self.ndim + dim
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and output should be the same size")
    if (index >= self.shape[dim]).any() or (index < 0).any():
        raise IndexError("The values of index must be between 0 and (self.shape[dim] - 1)")

    def make_slice(arr, dim, i):
        # Select position i along axis dim; NumPy requires a tuple (not a list) of slices
        slc = [slice(None)] * arr.ndim
        slc[dim] = i
        return tuple(slc)

    # We use the index and dim parameters to create idx.
    # idx is in a form that can be used as a NumPy advanced index for scattering src into self
    idx = [[*np.indices(idx_xsection_shape).reshape(index.ndim - 1, -1),
            index[make_slice(index, dim, i)].reshape(1, -1)[0]] for i in range(index.shape[dim])]
    idx = list(np.concatenate(idx, axis=1))
    idx.insert(dim, idx.pop())

    if not np.isscalar(src):
        if index.shape[dim] > src.shape[dim]:
            raise IndexError("Dimension " + str(dim) + " of index can not be bigger than that of src")
        src_xsection_shape = src.shape[:dim] + src.shape[dim + 1:]
        if idx_xsection_shape != src_xsection_shape:
            raise ValueError("Except for dimension " +
                             str(dim) + ", all dimensions of index and src should be the same size")
        # src_idx is a NumPy advanced index for selecting the elements of src
        src_idx = list(idx)
        src_idx.pop(dim)
        src_idx.insert(dim, np.repeat(np.arange(index.shape[dim]), np.prod(idx_xsection_shape)))
        # Advanced indexing with several index arrays must use a tuple, not a list
        self[tuple(idx)] = src[tuple(src_idx)]
    else:
        self[tuple(idx)] = src
    return self
```

There may be a simpler solution for gather, but this is mine (here self is the ndarray the values are gathered from):

```python
import numpy as np


def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
    :param dim: The axis along which to index
    :param index: A tensor of indices of elements to gather
    :return: tensor of gathered values
    """
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    # Move dim to the front so np.choose can pick, for every position,
    # the entry data_swapped[index_swapped[...]]
    data_swapped = np.swapaxes(self, 0, dim)
    index_swapped = np.swapaxes(index, 0, dim)
    # Note: np.choose accepts a limited number of choices (32 in NumPy 1.x),
    # so this only works while self.shape[dim] stays within that limit
    gathered = np.choose(index_swapped, data_swapped)
    return np.swapaxes(gathered, 0, dim)
```
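
Because np.choose in gather_numpy only accepts a limited number of choice arrays (32 in NumPy 1.x), a variant built on np.take_along_axis may be safer for long axes. This is a sketch under the assumption that index has the same ndim as arr and matches it in every axis except dim; the helper name gather_take_along_axis is hypothetical.

```python
import numpy as np


def gather_take_along_axis(arr, dim, index):
    """torch.gather-style gather built on np.take_along_axis (illustrative helper).

    Assumes index.ndim == arr.ndim and index matches arr in every axis except dim.
    """
    return np.take_along_axis(arr, index, axis=dim)


# 40 entries along axis 0: more than the 32 choices np.choose accepts in NumPy 1.x
data = np.arange(80).reshape(40, 2)   # data[k, j] == 2 * k + j
index = np.array([[39, 0],
                  [5, 38]])
out = gather_take_along_axis(data, 0, index)   # out[i][j] = data[index[i][j]][j]
print(out.tolist())  # [[78, 1], [10, 77]]
```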

Answer #3 (xesrikrc)

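The scatter_nd operation can be implemented with the .at method of NumPy's ufuncs.
According to TF's scatter_nd documentation:
Calling tf.scatter_nd(indices, values, shape) is identical to calling tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values).
So you can reproduce tf.scatter_nd by applying np.add.at to an np.zeros array; see the MCVE below: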

```python
import tensorflow as tf
tf.enable_eager_execution()  # Remove this line if working in TF2
import numpy as np


def scatter_nd_numpy(indices, updates, shape):
    target = np.zeros(shape, dtype=updates.dtype)
    # Each innermost row of indices is one full coordinate into target;
    # transpose into a tuple of per-axis index arrays for advanced indexing
    indices = tuple(indices.reshape(-1, indices.shape[-1]).T)
    updates = updates.ravel()
    # np.add.at accumulates, so repeated coordinates are summed
    np.add.at(target, indices, updates)
    return target


indices = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
updates = np.array([[1, 2], [3, 4]])
shape = (2, 3)
scattered_tf = tf.scatter_nd(indices, updates, shape).numpy()
scattered_np = scatter_nd_numpy(indices, updates, shape)
assert np.allclose(scattered_tf, scattered_np)
```

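Note: as @denis pointed out, the solution above behaves differently when some indices are repeated. This can be solved by using a counter and keeping only the last update for each repeated index.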
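
The workaround described in the note (keep only the last update for each repeated index) can be sketched as below. This is an illustration under my own assumptions, not @denis's code: the helper name scatter_nd_last_wins is hypothetical, and instead of an explicit counter it uses np.unique on the reversed linear indices to find the last occurrence of each position. Plain fancy assignment (target[idx] = updates) is avoided because NumPy does not guarantee which value wins when an index is repeated in an advanced assignment.

```python
import numpy as np


def scatter_nd_last_wins(indices, updates, shape):
    # Hypothetical helper: like scatter_nd_numpy, but for repeated indices only the
    # last update (in flattened order) is kept instead of summing them.
    target = np.zeros(shape, dtype=updates.dtype)
    coords = indices.reshape(-1, indices.shape[-1])   # one full coordinate per row
    flat_updates = updates.ravel()
    # Turn each coordinate row into a single linear index into target
    linear = np.ravel_multi_index(tuple(coords.T), shape)
    # np.unique(..., return_index=True) gives the first occurrence of each value;
    # on the reversed array that is the last occurrence in the original order
    _, first_in_reversed = np.unique(linear[::-1], return_index=True)
    last_pos = len(linear) - 1 - first_in_reversed
    target.flat[linear[last_pos]] = flat_updates[last_pos]
    return target


indices = np.array([[0, 0], [0, 1], [0, 0]])   # position (0, 0) appears twice
updates = np.array([1, 2, 5])
result = scatter_nd_last_wins(indices, updates, (2, 3))
print(result.tolist())  # [[5, 2, 0], [0, 0, 0]]
```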

Answer #4 (gopyfrb3)

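For scatter, rather than using slice assignment as suggested by @DomJack, it is usually better to use np.add.at: unlike slice assignment, it has well-defined behavior when there are duplicate indices.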
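
A small illustration of the difference, with made-up data: with buffered fancy-index assignment a repeated index is effectively written only once, while np.add.at applies every occurrence.

```python
import numpy as np

idx = np.array([0, 0, 2])    # index 0 is repeated

buffered = np.zeros(3, dtype=int)
buffered[idx] += 1           # all reads happen before the writes, so index 0 gains only 1
print(buffered.tolist())     # [1, 0, 1]

unbuffered = np.zeros(3, dtype=int)
np.add.at(unbuffered, idx, 1)  # every occurrence of an index is applied
print(unbuffered.tolist())     # [2, 0, 1]
```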

Answer #5 (hivapdat)

Assuming ref and indices are NumPy arrays:
Scatter update:

```python
ref[indices] = updates          # tf.scatter_update(ref, indices, updates)
ref[:, indices] = updates       # tf.scatter_update-style update along axis=1
ref[..., indices, :] = updates  # tf.scatter_update-style update along axis=-2
ref[..., indices] = updates     # tf.scatter_update-style update along axis=-1
```
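
For the axis=1 case, updates should have (or broadcast to) the shape of ref[:, indices], i.e. (ref.shape[0], len(indices)). A small made-up example:

```python
import numpy as np

ref = np.zeros((2, 4), dtype=int)
indices = np.array([3, 1])
updates = np.array([[10, 20],
                    [30, 40]])    # shape (2, len(indices)) == ref[:, indices].shape
ref[:, indices] = updates         # column 3 <- updates[:, 0], column 1 <- updates[:, 1]
print(ref.tolist())  # [[0, 20, 0, 10], [0, 40, 0, 30]]
```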

Gather:

```python
ref[indices]          # tf.gather(ref, indices)
ref[:, indices]       # tf.gather(ref, indices, axis=1)
ref[..., indices, :]  # tf.gather(ref, indices, axis=-2)
ref[..., indices]     # tf.gather(ref, indices, axis=-1)
```

More in the NumPy docs on indexing.
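
np.take with an axis argument is another option whose call shape mirrors tf.gather(ref, indices, axis=...). On a small made-up array it gives the same results as the indexing forms:

```python
import numpy as np

ref = np.arange(12).reshape(3, 4)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
indices = np.array([2, 0])

rows = ref[indices]                  # like tf.gather(ref, indices)
cols = ref[..., indices]             # like tf.gather(ref, indices, axis=-1)
print(rows.tolist())  # [[8, 9, 10, 11], [0, 1, 2, 3]]
print(cols.tolist())  # [[2, 0], [6, 4], [10, 8]]

# np.take(a, indices, axis=...) gives the same results
assert (np.take(ref, indices, axis=0) == rows).all()
assert (np.take(ref, indices, axis=-1) == cols).all()
```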

Answer #6 (6pp0gazn)

I did something very similar:

```python
import numpy as np

def gather(a, dim, index):
    # arange along each other axis (shaped to broadcast), the index array along dim
    expanded_index = [index if dim == i else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)]) for i in range(a.ndim)]
    return a[tuple(expanded_index)]  # advanced indexing needs a tuple, not a list

def scatter(a, dim, index, b):  # modifies a in place
    expanded_index = [index if dim == i else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)]) for i in range(a.ndim)]
    a[tuple(expanded_index)] = b
```

Answer #8 (lnxxn5zx)

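If you just want the same functionality rather than implementing it from scratch,
numpy.insert() is a close enough contender for PyTorch's scatter_(dim, index, src) operation, but it only handles one dimension.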
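
One difference worth keeping in mind: np.insert returns a new, longer array with the values inserted before the given positions, whereas scatter_ overwrites existing positions in place. A made-up 1-D comparison:

```python
import numpy as np

a = np.array([0, 0, 0, 0])
inserted = np.insert(a, [1, 3], [7, 9])   # 7 goes before a[1], 9 before a[3]
print(inserted.tolist())   # [0, 7, 0, 0, 9, 0]  (length 6)

scattered = a.copy()
scattered[[1, 3]] = [7, 9]                # scatter_-style overwrite of positions 1 and 3
print(scattered.tolist())  # [0, 7, 0, 9]  (length 4)
```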
