关于Pytorch LSTM代码片段的查询

k4ymrczo  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(118)

在堆栈溢出线程How can i add a Bi-LSTM layer on top of bert model?中,有一行代码:

hidden = torch.cat((lstm_output[:,-1, :256],lstm_output[:,0, 256:]),dim=-1)

有没有人能解释一下为什么最后一个和第一个令牌的连接,而不是任何其他的?这两个令牌会包含什么,他们被选中?

bvn4nwqk

bvn4nwqk1#

在双向模型中,隐藏状态在每一步都被连接起来;因此,该行基本上将正方向上的最后一个隐藏状态的前:256个单元(-1)连接到负方向上的最后一个隐藏状态的最后256:个单元(0)。这样的位置包含输入序列的最“有趣”的 * 摘要 *。
我已经写了a longer and detailed answer,关于如何在PyTorch中为递归模块构造隐藏状态。

相关问题