我目前在torch中工作,对一些输入数据实现随机洗牌(在行上,在本例中是第一维)。我是torch的新手,所以我在弄清楚置换是如何工作的方面遇到了一些麻烦。
下面的代码将打乱数据:
if argshuffle then
local perm = torch.randperm(sids:size(1)):long()
print("\n\n\nSize of X and y before")
print(X:view(-1, 1000, 128):size())
print(y:size())
print(sids:size())
print("\nPerm size is: ")
print(perm:size())
X = X:view(-1, 1000, 128)[{{perm},{},{}}]
y = y[{{perm},{}}]
print(sids[{{1}, {}}])
sids = sids[{{perm},{}}]
print(sids[{{1}, {}}])
print(X:size())
print(y:size())
print(sids:size())
os.exit(69)
end
这会打印出来
Size of X and y before
99
1000
128
[torch.LongStorage of size 3]
99
1
[torch.LongStorage of size 2]
99
1
[torch.LongStorage of size 2]
Perm size is:
99
[torch.LongStorage of size 1]
5
[torch.LongStorage of size 1x1]
5
[torch.LongStorage of size 1x1]
99
1000
128
[torch.LongStorage of size 3]
99
1
[torch.LongStorage of size 2]
99
1
[torch.LongStorage of size 2]
从值中,我可以暗示函数没有对数据进行shuffle。我如何才能使它正确shuffle,lua/torch中常见的解决方案是什么?
4条答案
按热度按时间t5fffqht1#
我也遇到了类似的问题。在文档中,没有用于Tensor的shuffle函数(有用于数据集加载器的shuffle函数)。我用
torch.randperm
找到了解决这个问题的方法。我希望它能回答这个问题!
v1l68za42#
例如,如果Tensor的形状为CxNxF(通道、行和特征),则可以沿着第二维进行洗牌,如下所示:
tpgth1q73#
一个直接的解决方案是使用置换矩阵(线性代数中常见的矩阵)。由于您似乎对3d情况感兴趣,我们将不得不首先展平您的3dTensor。因此,这里是我想出的一个示例代码(即用型)
正如您将从一个执行中看到的,它根据
perm
中给定的行索引对Tensordata
重新排序。hfyxw5xn4#
根据您的语法,我假设您使用lua而不是PyTorch来进行torch。torch.Tensor.index是您的函数,它的工作方式如下: