给定Tensor***A***shape(d0,d1,...,dn,dn +1)和排序索引Tensor***I***shape(d0,d1,...,dn),我想使用I中的排序索引重新排序A的索引。
Tensor***A***和***I***的前n维相等,TensorA的第(n +1)维可以是任意大小。
示例
鉴于***A***和***I***:
>>> A.shape
torch.Size([8, 8, 4])
>>> A
tensor([[[5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
[9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
[8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
[4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
[4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
[9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02],
[5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
[7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01]],
...])
>>> I.shape
torch.Size([8, 8])
>>> I
tensor([[2, 7, 4, 6, 1, 3, 0, 5],
...])
重新排序后,A的倒数第二个维度的元素应如下所示:
>>> A
tensor([[[8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
[7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01],
[4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
[5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
[9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
[4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
[5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
[9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02]],
...])
为简单起见,我只包括Tensor***A***和***I***的第一行。
溶液
基于公认的答案,我实现了一个广义版本,它可以对任何数量或维度(d0,d1,...,dn,dn +1,dn +2,...,dn + k)的任何Tensor进行排序,给定排序索引(d0,d1,...,dn)的Tensor。
下面是代码片段:
import torch
from torch import LongTensor, Tensor
def sort_by_indices(values: Tensor, indices: LongTensor) -> Tensor:
num_dims = indices.dim()
new_shape = tuple(indices.shape) + tuple(
1
for _ in range(values.dim() - num_dims)
)
repeats = tuple(
1
for _ in range(num_dims)
) + tuple(values.shape[num_dims:])
repeated_indices = indices.reshape(*new_shape).repeat(*repeats)
return torch.gather(values, num_dims - 1, repeated_indices)
1条答案
按热度按时间c9qzyr3d1#
您可以使用
torch.gather
,但需要如下所示重新调整和tile
索引:输出: