如何在pytorch中创建一个高效的条件层?

bsxbgnwa  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(117)

我有一个resnet 50模型,它输出一个类预测(1,2或3)。根据分类器的输出,我想做另一个预测,根据类预测选择下一个模型。
这是我目前所知道的。

import torch

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model1.weight)

        self.model2 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model2.weight)

        self.model3 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model3.weight)

    def forward(self, x):
        
        # Get batch_size
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        
        # Loop over every value in batch
        for i in range(batch_size):
            value = x[:, i]
            if value == 1:
                output[i] = self.model1(value)
            elif value == 2:
                output[i] = self.model2(value)
            else:
                output[i] = self.model3(value)

        return output
model = SimpleModel()

output = model(torch.tensor([[1,2,3]], dtype=torch.float32))
output

我担心的是,我在循环的每次迭代中只计算一次向前传递,这似乎非常低效。如果我将批处理大小增加到64,会发生什么?前向传递是否会并行计算?
任何想法/想法都将受到赞赏。

ymzxtsji

ymzxtsji1#

您可以按以下方式操作。代码只运行三个模型中的每一个一次,使用掩码作为条件,而不使用任何for循环:

def forward(self, x):
        
        # Get batch_size
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        
        # Compute one mask for each condition

        value_mask_1 = (x == 1)
        value_mask_2 = (x == 2)
        value_mask_3 = (x == 3)
        
        # Then just run the model on the items selected by each condition's mask.
        # And then assign model's outputs to the corresponding positions in the output variable.

        output[value_mask_1.view_as(output)] = self.model1(x[value_mask_1])
        output[value_mask_2.view_as(output)] = self.model1(x[value_mask_2])
        output[value_mask_3.view_as(output)] = self.model1(x[value_mask_3])

        return output

相关问题