为什么pytorch的nn.ModuleList不支持向量化?下面的代码是否有性能问题?

qyuhtwio  于 2024-01-09  发布在  其他
关注(0)|答案(1)|浏览(210)

我的模型中有很多子网络,对于不同的样本,我想用不同的子网络来做一些计算。经过搜索,我发现了以下方法。为什么pytorch的nn.ModuleList不支持向量化?这种方法是否有性能问题?

import torch
import torch.nn as nn

class SingleVariableNetwork(nn.Module):
    def __init__(self, init_value):
        super(SingleVariableNetwork, self).__init__()
        self.v = torch.tensor([init_value], dtype=torch.int32) 
    def forward(self, x):
        return self.v * x

class IndexedNetwork(nn.Module):
    def __init__(self, networks):
        super(IndexedNetwork, self).__init__()
        self.networks = networks

    def forward(self, x, network_indices):
        outputs = []
        for i in range(x.size(0)):
            network_input = x[i]
            # only integer tensors of a single element can be converted to an index
            network_output = self.networks[network_indices[i]](network_input)
            outputs.append(network_output)
        return torch.stack(outputs)

networks = nn.ModuleList([SingleVariableNetwork(i) for i in range(5)]) 
indexedNetwork = IndexedNetwork(networks)
input = torch.tensor([1, 1, 1, 1, 1])
indices = torch.tensor([3, 0, 2, 1, 4])
result = indexedNetwork(input, indices)
print(result)

字符串
我试着像下面这样写IndexedNetwork类的forward函数:return self.experts[expert_index](expert_input),但是pytorch用消息“只有单个元素的整数Tensor可以转换为索引”引发错误。

baubqpgj

baubqpgj1#

你不能沿着nn.ModuleList广播输入。模块列表只是一个python列表,带有额外的参数跟踪功能。
如果你的子网络都是相同的架构,你应该能够在Tensor操作级别对输入进行向量化。例如,将self.v参数更改为向量(表示子模型),将输入x也更改为向量。

相关问题