Unpickling a saved PyTorch model throws AttributeError: Can't get attribute 'Net' on <module '__main__'> despite adding the class definition inline

Posted by h5qlskok on 2022-12-13 in Other

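I'm trying to serve a PyTorch model in a Flask application. This code worked when I ran it earlier in a Jupyter notebook, but now I'm running it inside a virtual environment and it apparently can't get the attribute 'Net', even though the class definition is right there. All the other similar questions tell me to add the class definition of the saved model to the same script, but it still doesn't work. The torch version is 1.0.1 (both where the saved model was trained and in the virtualenv). What am I doing wrong?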

import os
import numpy as np
from flask import Flask, request, jsonify 
import requests

import torch
from torch import nn
from torch.nn import functional as F

MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'

r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = torch.load('model.pth')

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():

    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0:
        pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

Here is the full traceback:

Traceback (most recent call last):
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
    sys.exit(main())
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
    cli.main(args=args, prog_name=name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
    return super(FlaskGroup, self).main(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
    return ctx.invoke(f, obj, *args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
    app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
    self._load_unlocked()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
    self._app = rv = self.loader()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
    app = locate_app(self, import_name, name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
    __import__(module_name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
    model = torch.load('model.pth')
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
    return _load(f, map_location, pickle_module)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>

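This did not solve my problem. I don't want to change the way I persist the model. torch.save() worked fine for me outside the virtual environment. I don't mind adding the class definition to the script. Still, I'd like to understand what is causing the error.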

jm81lzqq #1

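(This is a partial answer)
I don't think torch.save(model, 'model.pt') works from the command prompt, or when the model is saved from a script run as '__main__' and then loaded from another script.
The reason is that torch must automatically import the module that was used when the file was saved, and it takes the module name from __name__.
Now for the "partial" part: it's not yet clear how to fix this, especially with a virtualenv in the mix.
Thanks to Jatentaki for starting the conversation in this direction.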
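
To make the mechanism described above concrete, here is a minimal sketch using only the standard-library pickle module, which torch.save uses by default. The module name train_script and the toy Net class are hypothetical stand-ins, not the asker's code. It suggests that a pickle stores only the module name and class name of an object, and that loading looks the class up again under that module. In the traceback above, __main__ is the flask launcher script rather than app.py, which would explain why Net is not found there.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the script that trained and saved the model.
train_mod = types.ModuleType("train_script")
sys.modules["train_script"] = train_mod


class Net:
    __module__ = "train_script"  # pretend Net was defined in train_script.py

    def __init__(self):
        self.hidden = 4


# Make the class reachable as train_script.Net, as it would be in a real module.
train_mod.Net = Net

payload = pickle.dumps(Net())

# The pickle holds the module name and the class name, not the class body.
print(b"train_script" in payload, b"Net" in payload)

# Loading resolves the class again: import train_script, then getattr(module, "Net").
clone = pickle.loads(payload)
print(type(clone) is Net, clone.hidden)
```

If train_script.Net were missing at load time, the lookup would fail with an AttributeError of the same kind as the one in the question.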

sxpgvts3 #2

I first initialized an empty model and then loaded the saved model, which for some reason solved the problem.

lyr7nygr #3

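I know I'm late answering this question, but I found a way to load the model from a package other than __main__.
If the attribute is set dynamically, as shown below, before the model is loaded, it works.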

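import os
import torch
import __main__

# Net is defined or imported above; expose it on __main__ so the unpickler can resolve __main__.Net.
setattr(__main__, "Net", Net)
model = torch.load(os.path.join(parent_dir,"<path to pickle>"), map_location=torch.device("cpu"))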

Note: this hack does not work if __main__ is a binary.
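
The effect of this workaround can be reproduced without torch. The following is a rough sketch with the standard-library pickle module and a toy Net class (both are illustrative assumptions, not the original model): it simulates a save from a script where Net lived in __main__, a load in a process whose __main__ lacks Net, and then the setattr fix.

```python
import pickle
import sys


class Net:
    """Toy stand-in for the model class defined in the serving module."""

    def __init__(self):
        self.hidden = 4


net_cls = Net  # keep a handle that survives the deletion below
main_mod = sys.modules["__main__"]

# 1. Simulate the training script: Net is a top-level name of __main__.
net_cls.__module__ = "__main__"
main_mod.Net = net_cls
payload = pickle.dumps(net_cls())

# 2. Simulate the server process: __main__ is a launcher that has no Net.
del main_mod.Net
try:
    pickle.loads(payload)
    failed = False
except AttributeError:
    failed = True

# 3. The workaround: attach the class to __main__ before loading.
setattr(main_mod, "Net", net_cls)
restored = pickle.loads(payload)

print(failed, isinstance(restored, net_cls))
```

Step 2 fails for the same reason as the traceback in the question, and step 3 succeeds only because the name Net is resolvable on __main__ again at load time.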

ctrmrzij #4

A simple way to solve this is to define "class Net(nn.Module):" before loading the model.

3j86kqsm #5

Simple solution:
1. Just create an instance of the Net(nn.Module) class, as shown below, and it runs fine.
2. I ran into the same problem and solved it with these simple steps.

import numpy as np
import requests
from flask import Flask, request, jsonify

import torch
from torch import nn
from torch.nn import functional as F

MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'

r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

# Note: Net() requires input_size, hidden_size and output_size to be defined beforehand.
model = Net()  # <-- extra line added
model = torch.load('model.pth', map_location=torch.device('cpu'))  # <-- 'cpu' if running on a CPU, else 'cuda'

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():

    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0:
        pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

0lvr5msh #6

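This may not be a very popular answer, but I've found the dill package to be very consistent at getting my code to work. In my case I wasn't even trying to load a model; I was trying to unpickle a custom helper object, and for some reason it could not be found. I don't know why, but in my experience dill seems to be the better choice: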

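    # (excerpt from inside a function; needs `import dill`, `import torch`, `from pathlib import Path`)
    # - path to files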
    path = Path(path2dataset).expanduser()
    path2file_data_prep = Path(path2file_data_prep).expanduser()
    # - create dag dataprep obj
    print(f'path to data set {path=}')
    dag_prep = SplitDagDataPreparation(path)
    # - save data prep splits object
    print(f'saving to {path2file_data_prep=}')
    torch.save({'data_prep': dag_prep}, path2file_data_prep, pickle_module=dill)
    # - load the data prep splits object to test it loads correctly
    db = torch.load(path2file_data_prep, pickle_module=dill)
    db['data_prep']
    print(db)
    return path2file_data_prep
