How can I get all of a PyTorch model's layers in the order they are used in forward?

hec6srdp · asked on 2023-06-06
1 answer · 191 views

I want to get all the layers of a PyTorch model. There is an existing question, PyTorch get all layers of model, but every method there iterates over children() or named_modules().
However, when I tried those approaches on resnet50, I found that the source code of ResNet's Bottleneck block contains only one relu attribute, yet that single relu module is used three times in the forward function.
Every method I have found therefore reports just one relu layer, which is not what I want. I am looking for a way to get all layers sorted in forward execution order; a minimal reproduction of the issue follows the Bottleneck code below.

from typing import Callable, Optional

import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import conv1x1, conv3x3


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)  # first use of self.relu

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)  # second use of the same module

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)  # third use

        return out
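
For reference, a minimal reproduction of the problem (using torchvision's stock resnet50): named_modules() lists each Bottleneck's relu exactly once, with nothing to indicate that forward() applies it three times.

import torch.nn as nn
import torchvision

model = torchvision.models.resnet50()
# One entry per block, e.g. ['relu', 'layer1.0.relu', 'layer1.1.relu', ...],
# even though each Bottleneck.forward calls its relu three times.
relu_names = [name for name, m in model.named_modules() if isinstance(m, nn.ReLU)]
print(relu_names)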

1yjd4xko 1#

ReLU is not a layer in the sense of having weights: it holds no parameters, which is why a single nn.ReLU module can be reused several times in forward. But for displaying the architecture when inspecting a network: enter link description here
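
If what is wanted is the execution order rather than the parameter list, one way (a sketch, not part of the original answer) is to register forward hooks on every leaf module and record each call; a reused module such as Bottleneck's self.relu then appears once per invocation. (And indeed, list(nn.ReLU().parameters()) is empty, which is the point made above.)

import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet50().eval()

order = []  # module names in the order forward() executes them
hooks = []
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only
        # bind `name` via a default argument so each hook records its own module
        hooks.append(module.register_forward_hook(
            lambda mod, inp, out, name=name: order.append(name)))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
for h in hooks:
    h.remove()

print(order)  # 'layer1.0.relu' shows up three times, once per use in forward

Alternatively, torch.fx.symbolic_trace(model) yields a graph whose call_module nodes are already in execution order, so the repeated relu calls show up without running any input through the model.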
