Pytorch模型摘要- forward函数有多个参数,其中一个参数是整数

y1aodyip  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(175)

我正在使用torch summary。我想在打印模型摘要时传递多个参数,其中一个参数只是整数。但是,我得到了一个错误。我遵循this question建议,但它不起作用。
我的网络看起来像

import torch
from torch import nn
from torchsummary import summary

class SimpleNet(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, x,t):
    return t * x

我尝试以summary(model,[(3, 64, 64),(1)])的形式运行summary,但得到的是TypeError: rand() argument after * must be an iterable, not int
通过执行summary(model,[(3, 64, 64),(1,)])“解决”了这个问题,但仍然得到了另一个TypeError: can't multiply sequence by non-int of type 'tuple'
那么,如何才能获得模型摘要?

hwamh0ep

hwamh0ep1#

(3, 64, 64),(1)更改为(3, 64, 64),(1,)有所帮助,因为list的所有元素都应该是Tuple类型。但是,当它试图向前运行时,它无法执行元组与元组的乘法(因为在任何python代码中都不可能)。
您可以从Tuple更改为torch.Tensorsummary(model,[torch.Tensor((3, 64, 64)),torch.Tensor((1,))])

相关问题