我正在使用torch summary。我想在打印模型摘要时传递多个参数,其中一个参数只是整数。但是,我得到了一个错误。我遵循this question建议,但它不起作用。
我的网络看起来像
import torch
from torch import nn
from torchsummary import summary
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x,t):
return t * x
我尝试以summary(model,[(3, 64, 64),(1)])
的形式运行summary,但得到的是TypeError: rand() argument after * must be an iterable, not int
。
通过执行summary(model,[(3, 64, 64),(1,)])
“解决”了这个问题,但仍然得到了另一个TypeError: can't multiply sequence by non-int of type 'tuple'
。
那么,如何才能获得模型摘要?
1条答案
按热度按时间hwamh0ep1#
从
(3, 64, 64),(1)
更改为(3, 64, 64),(1,)
有所帮助,因为list的所有元素都应该是Tuple
类型。但是,当它试图向前运行时,它无法执行元组与元组的乘法(因为在任何python代码中都不可能)。您可以从
Tuple
更改为torch.Tensor
summary(model,[torch.Tensor((3, 64, 64)),torch.Tensor((1,))])