pytorch Torch :如何按行重排Tensor?

nwo49xxi  于 2022-11-09  发布在  其他
关注(0)|答案(4)|浏览(366)

我目前在torch中工作,对一些输入数据实现随机洗牌(在行上,在本例中是第一维)。我是torch的新手,所以我在弄清楚置换是如何工作的方面遇到了一些麻烦。
下面的代码将打乱数据:

if argshuffle then 
    local perm = torch.randperm(sids:size(1)):long()
    print("\n\n\nSize of X and y before")
    print(X:view(-1, 1000, 128):size())
    print(y:size())
    print(sids:size())
    print("\nPerm size is: ")
    print(perm:size())
    X = X:view(-1, 1000, 128)[{{perm},{},{}}]
    y = y[{{perm},{}}]
    print(sids[{{1}, {}}])
    sids = sids[{{perm},{}}]
    print(sids[{{1}, {}}])
    print(X:size())
    print(y:size())
    print(sids:size())
    os.exit(69)
end

这会打印出来

Size of X and y before 
99 
1000
128
[torch.LongStorage of size 3]

99 
1
[torch.LongStorage of size 2]

99 
1
[torch.LongStorage of size 2]

Perm size is: 
99 
[torch.LongStorage of size 1]
5
[torch.LongStorage of size 1x1]
5
[torch.LongStorage of size 1x1]

99 
1000
128
[torch.LongStorage of size 3]

99 
1
[torch.LongStorage of size 2]

99 
1
[torch.LongStorage of size 2]

从值中,我可以暗示函数没有对数据进行shuffle。我如何才能使它正确shuffle,lua/torch中常见的解决方案是什么?

t5fffqht

t5fffqht1#

我也遇到了类似的问题。在文档中,没有用于Tensor的shuffle函数(有用于数据集加载器的shuffle函数)。我用torch.randperm找到了解决这个问题的方法。

>>> a=torch.rand(3,5)
>>> print(a)
tensor([[0.4896, 0.3708, 0.2183, 0.8157, 0.7861],
        [0.0845, 0.7596, 0.5231, 0.4861, 0.9237],
        [0.4496, 0.5980, 0.7473, 0.2005, 0.8990]])
>>> # Row shuffling
... 
>>> a=a[torch.randperm(a.size()[0])]
>>> print(a)
tensor([[0.4496, 0.5980, 0.7473, 0.2005, 0.8990],
        [0.0845, 0.7596, 0.5231, 0.4861, 0.9237],
        [0.4896, 0.3708, 0.2183, 0.8157, 0.7861]])
>>> # column shuffling
... 
>>> a=a[:,torch.randperm(a.size()[1])]
>>> print(a)
tensor([[0.2005, 0.7473, 0.5980, 0.8990, 0.4496],
        [0.4861, 0.5231, 0.7596, 0.9237, 0.0845],
        [0.8157, 0.2183, 0.3708, 0.7861, 0.4896]])

我希望它能回答这个问题!

v1l68za4

v1l68za42#

dim = 0
idx = torch.randperm(t.shape[dim])

t_shuffled = t[idx]

例如,如果Tensor的形状为CxNxF(通道、行和特征),则可以沿着第二维进行洗牌,如下所示:

dim=1
idx = torch.randperm(t.shape[dim])

t_shuffled = t[:,idx]
tpgth1q7

tpgth1q73#

一个直接的解决方案是使用置换矩阵(线性代数中常见的矩阵)。由于您似乎对3d情况感兴趣,我们将不得不首先展平您的3dTensor。因此,这里是我想出的一个示例代码(即用型)

data=torch.floor(torch.rand(5,3,2)*100):float()
reordered_data=data:view(5,-1)

perm=torch.randperm(5);
perm_rep=torch.repeatTensor(perm,5,1):transpose(1,2)

indexes=torch.range(1,5);
indexes_rep=torch.repeatTensor(indexes,5,1)

permutation_matrix=indexes_rep:eq(perm_rep):float()
permuted=permutation_matrix*reordered_data

print("perm")
print(perm)
print("before permutation")
print(data)
print("after permutation")
print(permuted:view(5,3,2))

正如您将从一个执行中看到的,它根据perm中给定的行索引对Tensordata重新排序。

hfyxw5xn

hfyxw5xn4#

根据您的语法,我假设您使用lua而不是PyTorch来进行torch。torch.Tensor.index是您的函数,它的工作方式如下:

x = torch.rand(4, 4)
p = torch.randperm(4)
print(x)
print(p)
print(x:index(1,p:long())

相关问题