假设我有一个PyTorchTensor,例如:
import torchx = torch.randn([3, 4, 5])
import torch
x = torch.randn([3, 4, 5])
字符串我想得到一个新的Tensor,具有相同的维数,包含从维数1的最终值开始的所有内容。我可以这样做:
x[:, -1:, :]
型但是,如果x有任意数量的维度,我想从特定的维度中获得最终值,那么最好的方法是什么?
x
hmtdttj41#
可以使用index_select:
index_select
torch.index_select(x, dim=dim, index=torch.tensor(x.size(dim) - 1))
字符串输出Tensor将包含与输入相同数量的维度。您可以在dim上使用squeeze来消除额外的维度:
dim
squeeze
torch.index_select(x, dim=dim, index=torch.tensor(x.size(dim) - 1)).squeeze(dim=dim)
型
select
In [1]: dim = 1In [2]: x = torch.randn([3, 4, 5])In [3]: torch.index_select(x, dim=1, index=torch.tensor(x.size(1) - 1)).shapeOut[3]: torch.Size([3, 1, 5])In [4]: torch.index_select(x, dim=1, index=torch.tensor(x.size(1) - 1)).squeeze(dim=dim).shapeOut[4]: torch.Size([3, 5])
In [1]: dim = 1
In [2]: x = torch.randn([3, 4, 5])
In [3]: torch.index_select(x, dim=1, index=torch.tensor(x.size(1) - 1)).shape
Out[3]: torch.Size([3, 1, 5])
In [4]: torch.index_select(x, dim=1, index=torch.tensor(x.size(1) - 1)).squeeze(dim=dim).shape
Out[4]: torch.Size([3, 5])
r6vfmomb2#
您可以使用select函数(或Tensor的等效方法),例如,
dim = 1 # the dimension from which to extract the final valuesy = x.select(dim, -1).unsqueeze(dim)
dim = 1 # the dimension from which to extract the final values
y = x.select(dim, -1).unsqueeze(dim)
字符串其中unsqueeze用于保持与原始Tensor相同的维数。
unsqueeze
2条答案
按热度按时间hmtdttj41#
可以使用
index_select
:字符串
输出Tensor将包含与输入相同数量的维度。您可以在
dim
上使用squeeze
来消除额外的维度:型
select
返回输入Tensor的 * 视图 *,但index_select
返回新Tensor。示例:
型
r6vfmomb2#
您可以使用
select
函数(或Tensor的等效方法),例如,字符串
其中
unsqueeze
用于保持与原始Tensor相同的维数。