我有两个PyTorchTensor,一个是3阶,另一个是4阶,有没有办法得到第一个Tensor的阶和形状?例如,在这个引起注意的部分:
q = torch.linspace(1, 192, steps=192)
q = q.reshape(2, 4, 3, 8)
k = torch.linspace(2, 193, steps=192)
k = k.reshape(2, 4, 3, 8)
v = torch.linspace(3, 194, steps=192)
v = v.reshape(2, 4, 24)
k = k.permute(0, 3, 2, 1)
attn = torch.einsum("nchw,nwhu->nchu", q, k)
# Below is what doesn't work. I would like to get it such that hidden_states is a tensor of rank 2, 4, 24
hidden_states = torch.einsum("chw,whu->chu", attn, v)
有没有一个置换/转置可以应用于q、k、v或attn,使我可以乘成(2,4,24)?我还没有找到。
我目前收到此错误:“运行时错误:einsum():等式(3)中下标的数目与操作数0的维数(4)不匹配,并且没有给出省略号”,所以我想知道在这种情况下如何使用省略号,如果这是一个解决方案的话。
任何解释为什么这是或不可能也将是一个例外的答案!
1条答案
按热度按时间g6ll5ycj1#
看起来你的
q
和k
是形状为batch
-channel
-height
-width
的四维Tensor(2 × 4 × 3 × 8)。然而,当考虑注意机制时,人们忽略了特征的空间排列,而仅仅将它们视为“特征袋,”即,而不是形状为2x 4x 3x 8x 1 m6n1x和k
,您应该具有2x 4x 24: