我的问题肯定有一个简单的答案,但我找不到它。我希望将MultiheadAttention
应用于同一个序列,而不复制该序列。我的数据是具有维度(批处理、时间、通道)的时态数据。我将“通道”维度视为嵌入,将时间维度视为序列维度。例如:
N, C, T = 2, 3, 5
n_heads = 7
X = torch.rand(N, T, C)
现在,我想将7个不同的头部作为自我注意力应用于同一个输入X
,但据我所知,它需要我复制数据7次:
attn = torch.nn.MultiheadAttention(C * n_heads, n_heads, batch_first=True)
X_ = X.repeat(1, 1, n_heads)
attn(X_, X_, X_)
有没有办法做到这一点,而不复制数据7次?谢谢!
1条答案
按热度按时间2uluyalo1#
在Multihead Attention的Pytorch实现中(以及我所知道的所有其他实现中),该类将在提供的
Queries
、Keys
和Values
Tensor上创建n_heads
不同的注意头,而不需要重复输入Tensor。每个 i 头都是使用相同的输入Tensor创建的。重复
X_
,虽然不会导致错误,但只会扩大注意力权重所需的Tensor大小,而不会提供任何好处。