假设以下输入:
x = torch.randint(1, 5, size=(2, 3, 3))
print(x.shape)
torch.Size([2, 3, 3])
我想用多个标量执行元素乘法。标量在此Tensor中可用:
weights = torch.tensor([2, 2, 2, 1])
print(weights.shape)
torch.Size([4])
所以,基本上,我需要4个操作:
result_1 = x * weights[0]
result_2 = x * weights[1]
result_3 = x * weights[2]
result_4 = x * weights[3]
打包在一个Tensor中。然而,单纯做
result = x * weights
将不起作用,因为尺寸不适合广播。我目前的解决方案相当丑陋,我认为效率不高:
x = x.unsqueeze(0).repeat_interleave(4, 0)
result = x * weights[:, None, None, None]
我在寻找更好的方法!
1条答案
按热度按时间ivqmmu1c1#
你的解决方案很好!如果你想要其他的东西,试试
.view()
方法。它重塑了Tensor:如果你想使用
.unsqueeze()
方法,那么也许你可以这样写,但我还是更喜欢你的方法(我没有测试这个方法):请让我知道它是如何进行的!