我有一个resnet 50模型,它输出一个类预测(1,2或3)。根据分类器的输出,我想做另一个预测,根据类预测选择下一个模型。
这是我目前所知道的。
import torch
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.model1 = torch.nn.Linear(1, 1, bias=False)
torch.nn.init.ones_(self.model1.weight)
self.model2 = torch.nn.Linear(1, 1, bias=False)
torch.nn.init.ones_(self.model2.weight)
self.model3 = torch.nn.Linear(1, 1, bias=False)
torch.nn.init.ones_(self.model3.weight)
def forward(self, x):
# Get batch_size
batch_size = x.size(1)
output = torch.zeros(batch_size, 1, device=x.device)
# Loop over every value in batch
for i in range(batch_size):
value = x[:, i]
if value == 1:
output[i] = self.model1(value)
elif value == 2:
output[i] = self.model2(value)
else:
output[i] = self.model3(value)
return output
model = SimpleModel()
output = model(torch.tensor([[1,2,3]], dtype=torch.float32))
output
我担心的是,我在循环的每次迭代中只计算一次向前传递,这似乎非常低效。如果我将批处理大小增加到64,会发生什么?前向传递是否会并行计算?
任何想法/想法都将受到赞赏。
1条答案
按热度按时间ymzxtsji1#
您可以按以下方式操作。代码只运行三个模型中的每一个一次,使用掩码作为条件,而不使用任何for循环: