pytorch 使用nn.Identity进行残差学习的想法是什么?

jc3wubiy  于 2024-01-09  发布在  其他
关注(0)|答案(2)|浏览(213)

所以,我已经读了大约一半的原始ResNet论文,并试图弄清楚如何使我的版本的表格数据。
我读过一些关于PyTorch如何工作的博客文章,我看到大量使用nn.Identity()。现在,这篇论文也经常使用术语 identity mapping。然而,它只是指以元素方式将层堆栈的输入添加到同一堆栈的输出。如果输入和输出维度不同,然后,本文讨论了用零填充输入或使用矩阵W_s将输入投影到不同的维度。
以下是我在一篇博客文章中发现的残差块的摘要:

  1. class ResidualBlock(nn.Module):
  2. def __init__(self, in_channels, out_channels, activation='relu'):
  3. super().__init__()
  4. self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
  5. self.blocks = nn.Identity()
  6. self.shortcut = nn.Identity()
  7. def forward(self, x):
  8. residual = x
  9. if self.should_apply_shortcut: residual = self.shortcut(x)
  10. x = self.blocks(x)
  11. x += residual
  12. return x
  13. @property
  14. def should_apply_shortcut(self):
  15. return self.in_channels != self.out_channels
  16. block1 = ResidualBlock(4, 4)

字符串
我自己对虚拟Tensor的应用:

  1. x = tensor([1, 1, 2, 2])
  2. block1 = ResidualBlock(4, 4)
  3. block2 = ResidualBlock(4, 6)
  4. x = block1(x)
  5. print(x)
  6. x = block2(x)
  7. print(x)
  8. >>> tensor([2, 2, 4, 4])
  9. >>> tensor([4, 4, 8, 8])


最后,x = nn.Identity(x),我不确定它的用途,除了模仿原始论文中的数学术语,但我肯定不是这样的,它有一些隐藏的用途,我只是还没有看到。它会是什么?

EDIT这里是另一个实现剩余学习的例子,这次是在Keras中。它做了我上面建议的事情,只是保留了一个输入的副本,以便添加到输出中:

  1. def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
  2. y = Conv2D(kernel_size=kernel_size,
  3. strides= (1 if not downsample else 2),
  4. filters=filters,
  5. padding="same")(x)
  6. y = relu_bn(y)
  7. y = Conv2D(kernel_size=kernel_size,
  8. strides=1,
  9. filters=filters,
  10. padding="same")(y)
  11. if downsample:
  12. x = Conv2D(kernel_size=1,
  13. strides=2,
  14. filters=filters,
  15. padding="same")(x)
  16. out = Add()([x, y])
  17. out = relu_bn(out)
  18. return out

h22fl7wq

h22fl7wq1#

使用nn.Identity进行残差学习的想法是什么?
没有(几乎没有,见文章的结尾),nn.Identity所做的就是转发给它的输入(基本上是no-op)。
PyTorch repo issue中所示,您在评论中链接了这个想法,由于其他用途,这个想法首先被拒绝,后来合并到PyTorch中(参见理由in this PR)。这个理由**与ResNet块本身无关,参见答案的结尾。

ResNet实现

我能想到的最通用的投影版本是沿着这些行的东西:

  1. class Residual(torch.nn.Module):
  2. def __init__(self, module: torch.nn.Module, projection: torch.nn.Module = None):
  3. super().__init__()
  4. self.module = module
  5. self.projection = projection
  6. def forward(self, inputs):
  7. output = self.module(inputs)
  8. if self.projection is not None:
  9. inputs = self.projection(inputs)
  10. return output + inputs

字符串
你可以将两个堆叠卷积作为module传递,并添加1x1卷积(带填充或步幅或其他东西)作为投影模块。
对于tabular数据,您可以将其用作module(假设您的输入具有50功能):

  1. torch.nn.Sequential(
  2. torch.nn.Linear(50, 50),
  3. torch.nn.ReLU(),
  4. torch.nn.Linear(50, 50),
  5. torch.nn.ReLU(),
  6. torch.nn.Linear(50, 50),
  7. )


基本上,你所要做的就是将input添加到某个模块的输出中,就是这样。

nn.Identity原理

构造神经网络(并在之后阅读它们)可能更容易,例如批处理范数(取自上述PR):

  1. batch_norm = nn.BatchNorm2d
  2. if dont_use_batch_norm:
  3. batch_norm = Identity


现在您可以轻松地将其与nn.Sequential一起使用:

  1. nn.Sequential(
  2. ...
  3. batch_norm(N, momentum=0.05),
  4. ...
  5. )


当打印网络时,它总是有相同数量的子模块(使用BatchNormIdentity),这也使整个事情变得更流畅。
这里提到的另一个用例可能是删除现有神经网络的一部分:

  1. net = tv.models.alexnet(pretrained=True)
  2. # Assume net has two parts
  3. # features and classifier
  4. net.classifier = Identity()


现在,您可以运行net(input),而不是运行net.features(input),这可能也更容易让其他人阅读。

展开查看全部
ruyhziif

ruyhziif2#

nn.Identity()的一个很好的用法是在jit脚本编写过程中。在非常模块化的模型中,脚本将搜索每个if语句,并向前检查所有路径,即使在初始化过程中,if语句设置为false。

  1. class MyModule(nn.Module):
  2. def __init__(self, extra=false):
  3. self.conv = nn.conv2d(3,3)
  4. self.extra = extra
  5. if extra:
  6. self.extra_layer = nn.Conv2d(3, 3)
  7. def forward(self, x):
  8. x = self.conv(x)
  9. if self.extra:
  10. x = self.extra_layer(x)
  11. return x

字符串
此模块不能编写脚本,但您可以执行以下操作

  1. class MyModule(nn.Module):
  2. def __init__(self, extra=false):
  3. self.conv = nn.conv2d(3,3)
  4. self.extra = extra
  5. self.extra_layer = nn.Conv2d(3, 3) if extra else nn.Identity()
  6. def forward(self, x):
  7. x = self.conv(x)
  8. if self.extra:
  9. x = self.extra_layer(x)
  10. return x

展开查看全部

相关问题