pytorch 不同大小或排列的形状的同一性

qnakjoqk  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(104)

我有两个PyTorchTensor,一个是3阶,另一个是4阶,有没有办法得到第一个Tensor的阶和形状?例如,在这个引起注意的部分:

q = torch.linspace(1, 192, steps=192)
q = q.reshape(2, 4, 3, 8)
k = torch.linspace(2, 193, steps=192)
k = k.reshape(2, 4, 3, 8)
v = torch.linspace(3, 194, steps=192)
v = v.reshape(2, 4, 24)

k = k.permute(0, 3, 2, 1)
attn = torch.einsum("nchw,nwhu->nchu", q, k)

# Below is what doesn't work. I would like to get it such that hidden_states is a tensor of rank 2, 4, 24

hidden_states = torch.einsum("chw,whu->chu", attn, v)

有没有一个置换/转置可以应用于q、k、v或attn,使我可以乘成(2,4,24)?我还没有找到。
我目前收到此错误:“运行时错误:einsum():等式(3)中下标的数目与操作数0的维数(4)不匹配,并且没有给出省略号”,所以我想知道在这种情况下如何使用省略号,如果这是一个解决方案的话。
任何解释为什么这是或不可能也将是一个例外的答案!

g6ll5ycj

g6ll5ycj1#

看起来你的qk是形状为batch-channel-height-width的四维Tensor(2 × 4 × 3 × 8)。然而,当考虑注意机制时,人们忽略了特征的空间排列,而仅仅将它们视为“特征袋,”即,而不是形状为2x 4x 3x 8x 1 m6n1x和k,您应该具有2x 4x 24:

q = torch.linspace(1, 192, steps=192)
q = q.reshape(2, 4, 3 * 8)  # collapse the spatial dimensions into a single one
k = torch.linspace(2, 193, steps=192)
k = k.reshape(2, 4, 3 * 8)  # collapse the spatial dimensions into a single one
v = torch.linspace(3, 194, steps=192)
v = v.reshape(2, 4, 24)

attn = torch.einsum("bcn,bcN->bnN", q, k)

# it is customary to convert the raw attn into probabilities using softmax

attn = torch.softmax(attn, dim=-1)
hidden_states = torch.einsum("bnN,bcN->bcn", attn, v)

相关问题