下面是源代码,我用它加载一个.pth
文件并做一个多类图像分类预测。
model = Classifier() # The Model Class.
model.load_state_dict(torch.load('<PTH-FILE-HERE>.pth'))
model = model.to(device)
model.eval()
# prediction function to test images
def predict(img_path):
image = Image.open(img_path)
resize = transforms.Compose(
[ transforms.Resize((256,256)), transforms.ToTensor()])
image = resize(image)
image = image.to(device)
y_result = model(image.unsqueeze(0))
result_idx = y_result.argmax(dim=1)
print(result_idx)
我使用torch.onnx.export
将.pth
文件转换为ONNX文件。
现在,我怎样才能写一个预测脚本类似于上面的一个单独使用ONNX文件,而不使用.pth
文件。?有可能这样做吗?
1条答案
按热度按时间w8f9ii691#
您可以使用ONNX运行时。
您可以按照tutorial了解详细说明。
通常,使用onnx的目的是在不同的框架中加载模型并运行推理,例如PyTorch -〉ONNX -〉TensorRT。