pytorch 总是使用torch.tensor或torch.FloatTensor安全吗?还是我需要小心处理Ints?

stszievb  于 2023-03-18  发布在  其他
关注(0)|答案(2)|浏览(182)

https://pytorch.org/docs/stable/tensors.html
我试图理解tensor, FloatTensor, IntTensor之间的区别-我想知道我是否可以一直坚持tensor...或者FloatTensor
我将使用不同Tensor的混合,它们将是:

{integers:labels, floats:continuous, one-hot-encoded:categoricals}

我需要显式地把这些变量设置为不同类型的Tensor吗?,它们都能作为浮点数工作吗?,它们能相互结合工作吗?
这会给我下游带来麻烦吗?
一个一个一个一个一个一个x一个一个二个一个x一个一个三个一个x一个一个x一个一个x一个一个

dw1jzc5e

dw1jzc5e1#

CrossEntropyLoss(或NLLLoss)要求target类型为Long。例如,下面的代码引发RuntimeError

import torch.nn 

criterion = torch.nn.CrossEntropyLoss()

predicted = torch.rand(10, 10, dtype=torch.float)
target = torch.rand(10) #.to(torch.long)
criterion(predicted, target)
# RuntimeError: expected scalar type Long but found Float

您需要取消对转换的注解才能使其工作。我想不出更大的问题,但为什么要麻烦首先将整数转换为浮点型呢?
关于torch.tensortorch.FloatTensor的使用,我更喜欢前者。torch.FloatTensor似乎是遗留构造函数,并且它不接受device作为参数。同样,我不认为这是一个大问题,但使用torch.tensor仍然增加了代码的可读性。

wpx232ag

wpx232ag2#

我已经能够在所有类型的dtype上使用FloatTensor
不能求整数Tensor的平均值,这是有道理的。

a = th.tensor([1,1])
b = th.tensor([2,2])

th.mean(th.stack([a,b]))

RuntimeError: Can only calculate the mean of floating types. Got Long instead.

相关问题