python 如何通过nn.ModuleList运行成批数据

k4aesqcs  于 2023-04-04  发布在  Python
关注(0)|答案(2)|浏览(141)

我正在尝试使用PyTorch的ModuleList并使用批处理来训练它。
如果我正确理解了PyTorch通常的做法,那么使用下面的模板集为给定的模型/网络编写代码,编写一个forward函数,并且PyTorch只在我们运行模型执行model(..)时才处理初始批处理维度。

import torch as T
import torch.nn as nn

N = 10 # number of elements in ModuleList
H = 2  # input dimension
B = 5  # batch size

class MyModel(nn.Module):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)

        self.list_of_nets = nn.ModuleList([nn.Linear(H, H) for i in range(N)])

    def forward(self, i, x):
        return self.list_of_nets[i](x)

但是,如果我尝试在批量数据上运行此命令,我会得到TypeError: only integer tensors of a single element can be converted to an index类型的错误

model = MyModel()
idx = T.randint(0, N, (B,))
x_input = T.rand((B, H))

# both give me the TypeError
model(idx, x_input)
model(idx.reshape(B, 1), x_input)

# this is fine, as expected
model(idx[0], x_input[0])

我检查了我的idx输入的类型是整数(当我只取第一个数据点时,它确实有效),所以这不是问题的根源。
我哪里做错了?
谢谢!

km0tfn4u

km0tfn4u1#

nn.ModuleList不是这样工作的。它本质上是nn.Module的 * 列表 *。它的__getitem__需要一个整数,就像你在第三条语句中所做的那样。
你通常会遇到的是使用列表解析。类似这样:

def forward(self, i, x):
    out = torch.stack([self.list_of_nets[_i](x) for _i in i])
    return out

然后你可以用一个tensor调用forward函数来索引你的模块列表:

>>> model(idx, x_input) # shaped (B,B,H)
9njqaruj

9njqaruj2#

ModuleList与批处理没有任何关系。下面是一个如何使用ModuleList的示例:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        ])
        
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

如果你想批量加载数据,你应该使用DataLoader,然后使用for循环来加载。

相关问题