python Tensor与多个标量的逐元素乘法

hgncfbus  于 2023-06-04  发布在  Python
关注(0)|答案(1)|浏览(174)

假设以下输入:

x = torch.randint(1, 5, size=(2, 3, 3))
print(x.shape)
torch.Size([2, 3, 3])

我想用多个标量执行元素乘法。标量在此Tensor中可用:

weights = torch.tensor([2, 2, 2, 1])
print(weights.shape)
torch.Size([4])

所以,基本上,我需要4个操作:

result_1 = x * weights[0]
result_2 = x * weights[1]
result_3 = x * weights[2]
result_4 = x * weights[3]

打包在一个Tensor中。然而,单纯做

result = x * weights

将不起作用,因为尺寸不适合广播。我目前的解决方案相当丑陋,我认为效率不高:

x = x.unsqueeze(0).repeat_interleave(4, 0)
result = x * weights[:, None, None, None]

我在寻找更好的方法!

ivqmmu1c

ivqmmu1c1#

你的解决方案很好!如果你想要其他的东西,试试.view()方法。它重塑了Tensor:

x = x.view(1, *x.shape)  # Shape: (1, 2, 3, 3)
weights = weights.view(-1, 1, 1, 1)  # Shape: (4, 1, 1, 1)

如果你想使用.unsqueeze()方法,那么也许你可以这样写,但我还是更喜欢你的方法(我没有测试这个方法):

x = x.unsqueeze(0)  # Shape: (1, 2, 3, 3)
weights = weights.unsqueeze(1).unsqueeze(2).unsqueeze(3)  # Shape: (4, 1, 1, 1)

请让我知道它是如何进行的!

相关问题