我正在尝试使用列上的逻辑索引来切片PyTorchTensor。我想要索引向量中对应于1值的列。切片和逻辑索引都是可能的,但是它们可以一起使用吗?如果是,如何做到?我的尝试不断抛出无用的错误
TypeError:使用ByteTensor类型的对象索引Tensor。唯一支持的类型是整数、切片、numpy标量和torch。LongTensor或torch。ByteTensor作为唯一参数。
MCVE
期望输出
import torch
C = torch.LongTensor([[1, 3], [4, 6]])
# 1 3
# 4 6
逻辑索引仅对列:
A_log = torch.ByteTensor([1, 0, 1]) # the logical index
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B[:, A_log] # Throws error
如果向量大小相同,则逻辑索引工作:
B_truncated = torch.LongTensor([1, 2, 3])
C = B_truncated[A_log]
我可以通过重复逻辑索引来获得所需的结果,使其具有与我索引的Tensor相同的大小,但之后我还必须重新调整输出。
C = B[A_log.repeat(2, 1)] # [torch.LongTensor of size 4]
C = C.resize_(2, 2)
我还尝试使用索引列表:
A_idx = torch.LongTensor([0, 2]) # the index vector
C = B[:, A_idx] # Throws error
如果我想要索引的连续范围,切片工作:
C = B[:, 1:2]
3条答案
按热度按时间luaexgnf1#
我想这是作为
index_select
函数实现的,可以试试7gcisfzg2#
在PyTorch 1中。5.0中,用作索引的Tensor必须是long、byte或boolTensor。
下面是一个索引作为长的Tensor。
下面是一个布尔(逻辑索引)的Tensor:
qnyhuwrf3#
我尝试了这段代码,并将结果作为注解写在它旁边。