从任意尺寸的PyTorchTensor的特定尺寸/轴获取最终值

os8fio9y  于 2024-01-09  发布在  其他
关注(0)|答案(2)|浏览(154)

假设我有一个PyTorchTensor,例如:

  1. import torch
  2. x = torch.randn([3, 4, 5])

字符串
我想得到一个新的Tensor,具有相同的维数,包含从维数1的最终值开始的所有内容。我可以这样做:

  1. x[:, -1:, :]


但是,如果x有任意数量的维度,我想从特定的维度中获得最终值,那么最好的方法是什么?

hmtdttj4

hmtdttj41#

可以使用index_select

  1. torch.index_select(x, dim=dim, index=torch.tensor(x.size(dim) - 1))

字符串
输出Tensor将包含与输入相同数量的维度。您可以在dim上使用squeeze来消除额外的维度:

  1. torch.index_select(x, dim=dim, index=torch.tensor(x.size(dim) - 1)).squeeze(dim=dim)

  • 注意:* 虽然select返回输入Tensor的 * 视图 *,但index_select返回新Tensor。
    示例:
  1. In [1]: dim = 1
  2. In [2]: x = torch.randn([3, 4, 5])
  3. In [3]: torch.index_select(x, dim=1, index=torch.tensor(x.size(1) - 1)).shape
  4. Out[3]: torch.Size([3, 1, 5])
  5. In [4]: torch.index_select(x, dim=1, index=torch.tensor(x.size(1) - 1)).squeeze(dim=dim).shape
  6. Out[4]: torch.Size([3, 5])

展开查看全部
r6vfmomb

r6vfmomb2#

您可以使用select函数(或Tensor的等效方法),例如,

  1. dim = 1 # the dimension from which to extract the final values
  2. y = x.select(dim, -1).unsqueeze(dim)

字符串
其中unsqueeze用于保持与原始Tensor相同的维数。

相关问题