pytorch 如何使用1DTensor而不是标量切片2DTensor

mmvthczy  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(111)

通常,您可以像这样对2DTensor进行切片slice = t[:, :k],其中k是整数。有没有可能这样做,但是k是一个一维的整数向量,每一行都有我想得到的项数?
用0或NaN屏蔽这些项也可以。
举例来说:

  1. k = torch.Tensor([1,2,3])
  2. t = torch.Tensor([1,1,1], [2,2,2], [3,3,3])
  3. # perform some operations and the result should be
  4. # 1 - -
  5. # 2 2 -
  6. # 3 3 3

字符串

emeijp43

emeijp431#

我想我是通过使用一个由unsqueeze创建的遮罩来做到这一点的。

  1. k = torch.Tensor([1,3,2])
  2. t = torch.Tensor([[1,1,1], [2,2,2], [3,3,3]])
  3. range_tensor = torch.arange(1, t.size(1)+1)
  4. mask = range_tensor > k.unsqueeze(1)
  5. t[mask] = 0 # Or NaN or whatever
  6. # t is now equal to
  7. # tensor([[1., 0., 0.],
  8. # [2., 2., 2.],
  9. # [3., 3., 0.]])

字符串

相关问题