python 将'torch.nn.MultiheadAttention'的标头应用于同一输入

wn9m85ua  于 2023-02-02  发布在  Python
关注(0)|答案(1)|浏览(424)

我的问题肯定有一个简单的答案,但我找不到它。我希望将MultiheadAttention应用于同一个序列,而不复制该序列。我的数据是具有维度(批处理、时间、通道)的时态数据。我将“通道”维度视为嵌入,将时间维度视为序列维度。例如:

N, C, T = 2, 3, 5
n_heads = 7
X = torch.rand(N, T, C)

现在,我想将7个不同的头部作为自我注意力应用于同一个输入X,但据我所知,它需要我复制数据7次:

attn = torch.nn.MultiheadAttention(C * n_heads, n_heads, batch_first=True)
X_ = X.repeat(1, 1, n_heads)
attn(X_, X_, X_)

有没有办法做到这一点,而不复制数据7次?谢谢!

2uluyalo

2uluyalo1#

在Multihead Attention的Pytorch实现中(以及我所知道的所有其他实现中),该类将在提供的QueriesKeysValuesTensor上创建n_heads不同的注意头,而不需要重复输入Tensor。每个 i 头都是使用相同的输入Tensor创建的。
重复X_,虽然不会导致错误,但只会扩大注意力权重所需的Tensor大小,而不会提供任何好处。

相关问题