python 在矩阵乘法之前整形矩阵

kgqe7b3p  于 2023-03-07  发布在  Python
关注(0)|答案(1)|浏览(174)

我有一个使用Pytorch的"Transformer Neural Network"的forward步骤的片段。

其中:

  • 第一个月第一个月,第一个月第二个月,第一个月第三个月,第一个月第四个月
  • self.toprobs(x):具有输入/输出特征(k, num_vocabs)nn.Linear层。
def forward(self, x):
        tokens = self.token_embedding(x)
        b, t, k = tokens.size()

        x = self.transformer_block(tokens)

        x = x.view(b * t, k)
        x = self.toprobs(x)
        x = x.view(b, t, self.num_vocabs)

        output = F.log_softmax(x, dim=2)

        return output

已知:b = 2, t = 2, k = 3, self.num_vocabs = 256

  • xx = self.transformer_block(tokens)之后的输出形状为(2, 2, 3)
  • x重新整形为(b * t, k) -> (4, 3),然后通过self.toprobs(x)我得到了(4, 256),然后再次重新整形回(2, 2, 256)

问题:

  • 为什么x需要重新整形为(b * t, k)?如果我将x的形状保持在(2, 2, 3)并通过self.toprobs(x),我仍然会得到相同的结果并整形(2, 2, 256)
  • 在矩阵乘法步骤的加速或内存使用方面有什么好处吗?
def forward(self, x):
        tokens = self.token_embedding(x)
        b, t, k = tokens.size()

        x = self.transformer_block(tokens)

        # Same result without matrix reshape
        x = self.toprobs(x)

        output = F.log_softmax(x, dim=2)

        return output
2guxujil

2guxujil1#

我认为这只是一个特定的人的实现决策。也许他们觉得这样考虑维度更容易,因为他们可能觉得要转换的特定维度更明确。就我个人而言,我可能更喜欢不更改批处理维度。
视图实际上不会更改Tensor上的基础数据(Tensor的视图共享相同的存储空间)。nn.Linear的行为仅基于最终维度,在代码片段的视图操作中不会更改该维度(链接文档中的H)。根据文档,我怀疑是否存在任何性能差异,对我来说你会看到同样的结果是合理的。

相关问题