pytorch torchsummary的summary函数打印其输出两次

mwngjboj  于 2023-02-04  发布在  其他
关注(0)|答案(1)|浏览(237)

所以我一直在尝试实现本文中提出的基于CNN的分类解决方案(https://arxiv.org/pdf/1810.08923.pdf)。下面是我的代码。这是一个相当简单的实现,但我不明白为什么torchsummary会产生这样的结果。我也浏览了他们的GitHub问答,但到目前为止也没有提出这样的问题。

class CNN_Pred2D(nn.Module):
    def __init__(self, n_filters=[8,8,8], debug=True):
        super().__init__()
        self.debug = debug
        
        self.model = nn.Sequential(
            nn.Conv2d(1, n_filters[0], kernel_size=(1,82)),
            nn.ReLU(),
            nn.Conv2d(n_filters[0], n_filters[0], kernel_size=(3,1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,1)),
            
            nn.Conv2d(n_filters[0], n_filters[1], kernel_size=(3,1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,1)),
            
            nn.Flatten(),
            nn.Linear(104,1),
            nn.Sigmoid()
        )

        
    def forward(self, X):
        out = self.model(X)
#         print(out.shape)
        return out

model = CNN_Pred2D().to(device)

summary(model, [(1, 60,82)])

下面是它的输出:

iq0todco

iq0todco1#

您可以尝试使用torchinfo代替torchsummary

from torchinfo import summary

summary(model, (1, 60,82))

相关问题