我尝试在一个 flask 应用程序中提供一个pytorch模型。这段代码在我之前在jupyter笔记本上运行时是有效的,但现在我在一个虚拟环境中运行它,显然它不能获得属性'Net',即使类定义就在那里。所有其他类似的问题告诉我在同一个脚本中添加保存的模型的类定义。但它仍然不起作用。torch版本是1.0.1(保存的模型和virtualenv都在这里训练)我做错了什么?
import os
import numpy as np
from flask import Flask, request, jsonify
import requests
import torch
from torch import nn
from torch.nn import functional as F
MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'
r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = torch.sigmoid(self.fc1(x))
x = torch.sigmoid(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=-1)
model = torch.load('model.pth')
app = Flask(__name__)
@app.route("/")
def hello():
return "Binary classification example\n"
@app.route('/predict', methods=['GET'])
def predict():
x_data = request.args['x_data']
x_data = x_data.split()
x_data = list(map(float, x_data))
sample = np.array(x_data)
sample_tensor = torch.from_numpy(sample).float()
out = model(sample_tensor)
_, predicted = torch.max(out.data, -1)
if predicted.item() == 0:
pred_class = "Has no liver damage - ", predicted.item()
elif predicted.item() == 1:
pred_class = "Has liver damage - ", predicted.item()
return jsonify(pred_class)
以下是完整的回溯:
Traceback (most recent call last):
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
sys.exit(main())
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
cli.main(args=args, prog_name=name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
return super(FlaskGroup, self).main(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
rv = self.invoke(ctx)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
return callback(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
return ctx.invoke(f, obj, *args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
return callback(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
self._load_unlocked()
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
self._app = rv = self.loader()
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
app = locate_app(self, import_name, name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
__import__(module_name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
model = torch.load('model.pth')
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
return _load(f, map_location, pickle_module)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>
This没有解决我的问题。我不想改变我持久化模型的方式。torch.save()在虚拟环境之外对我来说工作得很好。我不介意将类定义添加到脚本中。尽管如此,我还是想看看是什么导致了错误。
6条答案
按热度按时间jm81lzqq1#
(This是部分答案)
我不认为
torch.save(model,'model.pt')
在命令提示符下工作,或者当模型从一个以'__main__'
运行的脚本保存并从另一个脚本加载时。原因是torch必须自动加载用于保存文件的模块,并且它从
__name__
获取模块名称。现在来看部分:目前还不清楚如何解决这个问题,特别是当您在混合中有virtualenv时。
感谢Jatentaki在这个方向开始对话。
sxpgvts32#
首先,我初始化了一个空模型,然后加载了保存的模型,这出于某种原因解决了问题。
lyr7nygr3#
我知道我回答这个问题已经晚了。但是我找到了一种方法,可以从另一个包而不是“main”加载模型
在加载模块之前,如果该属性是动态设置的,如下所示,它将起作用。
注意:如果“main”是二进制文件,则此黑客攻击无效。
ctrmrzij4#
解决这个问题的一个简单方法是,在加载模型之前,需要定义“class Net(nn.Module):“。
3j86kqsm5#
简单的解决方案:
1.您只需要创建
Net(nn.Module)
类的一个示例,如下所示,然后它就可以正常运行了。1.我也遇到过同样的问题,并通过这些简单的步骤解决了。
0lvr5msh6#
这可能不是一个很流行的答案,但是,我发现
dill
包在使我的代码工作方面是非常一致的。对我来说,我甚至没有试图加载一个模型,我试图解包一个自定义对象,帮助我的东西,但它不能找到它出于某种原因。我不知道为什么,但在我的经验中,dill似乎是一个更好的选择: