我正在尝试使用PyTorch的ModuleList并使用批处理来训练它。
如果我正确理解了PyTorch通常的做法,那么使用下面的模板集为给定的模型/网络编写代码,编写一个forward
函数,并且PyTorch只在我们运行模型执行model(..)
时才处理初始批处理维度。
import torch as T
import torch.nn as nn
N = 10 # number of elements in ModuleList
H = 2 # input dimension
B = 5 # batch size
class MyModel(nn.Module):
def __init__(self, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.list_of_nets = nn.ModuleList([nn.Linear(H, H) for i in range(N)])
def forward(self, i, x):
return self.list_of_nets[i](x)
但是,如果我尝试在批量数据上运行此命令,我会得到TypeError: only integer tensors of a single element can be converted to an index
类型的错误
model = MyModel()
idx = T.randint(0, N, (B,))
x_input = T.rand((B, H))
# both give me the TypeError
model(idx, x_input)
model(idx.reshape(B, 1), x_input)
# this is fine, as expected
model(idx[0], x_input[0])
我检查了我的idx
输入的类型是整数(当我只取第一个数据点时,它确实有效),所以这不是问题的根源。
我哪里做错了?
谢谢!
2条答案
按热度按时间km0tfn4u1#
nn.ModuleList
不是这样工作的。它本质上是nn.Module
的 * 列表 *。它的__getitem__
需要一个整数,就像你在第三条语句中所做的那样。你通常会遇到的是使用列表解析。类似这样:
然后你可以用一个tensor调用forward函数来索引你的模块列表:
9njqaruj2#
ModuleList与批处理没有任何关系。下面是一个如何使用ModuleList的示例:
如果你想批量加载数据,你应该使用DataLoader,然后使用for循环来加载。