Paddle multi-machine incremental training: load_persistables fails

Posted by swvgeqrz on 2021-11-30 in Java

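I set up multi-machine training following the official manual: http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/save_load_variables.html
On a single machine, loading the previously saved model with the default program works fine, and training continues: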

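import paddle.fluid as fluid

main_program = fluid.default_main_program()
exe = fluid.Executor(place)  # place: e.g. fluid.CPUPlace() or fluid.CUDAPlace(0)

startup_prog = fluid.default_startup_program()
exe.run(startup_prog)

# Restore the persistable variables saved by the previous run, then keep training
fluid.io.load_persistables(exe, 'thirdparty/continue_model-pass-0', startup_prog)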

With the multi-machine setup using get_pserver_program, however, it fails:

logger.info("run pserver with continue model")
prog = t.get_pserver_program(current_endpoint)
startup = t.get_startup_program(current_endpoint, pserver_program=prog)
exe.run(startup)
# This call fails with the EnforceNotMet error shown below
fluid.io.load_persistables(exe, args.continue_model_path, startup)
exe.run(prog)
logger.info("pserver starting")

The error is raised when execution reaches fluid.io.load_persistables(exe, args.continue_model_path, startup):

  File "train.py", line 269, in <module>
    train()
  File "train.py", line 254, in train
    fluid.io.load_persistables(exe, args.continue_model_path, startup)
  File "/home/work/anaconda2/lib/python2.7/site-packages/paddle/fluid/io.py", line 503, in load_persistables
    filename=filename)
  File "/home/work/anaconda2/lib/python2.7/site-packages/paddle/fluid/io.py", line 377, in load_vars
    filename=filename)
  File "/home/work/anaconda2/lib/python2.7/site-packages/paddle/fluid/io.py", line 387, in load_vars
    new_var = _clone_var_in_block_(load_block, each_var)
  File "/home/work/anaconda2/lib/python2.7/site-packages/paddle/fluid/io.py", line 85, in _clone_var_in_block_
    lod_level=var.lod_level,
  File "/home/work/anaconda2/lib/python2.7/site-packages/paddle/fluid/framework.py", line 418, in lod_level
    return self.desc.lod_level()
paddle.fluid.core.EnforceNotMet: Getting 'lod_level' is not supported by the type of var seg_lr_Factors@GRAD. at [/paddle/paddle/fluid/framework/var_desc.cc:173]
PaddlePaddle Call Stacks: 
0       0x7fa5a5f0eab6p paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int) + 486
1       0x7fa5a5fb3b42p paddle::framework::VarDesc::GetLoDLevel() const + 162
2       0x7fa5a5f705cfp void pybind11::cpp_function::initialize<pybind11::cpp_function::initialize<int, paddle::framework::VarDesc, , pybind11::name, pybind11::is_method, pybind11::sibling>(int (paddle::framework::VarDesc::*)() const, pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::VarDesc const*)#1}, int, paddle::framework::VarDesc const*, pybind11::name, pybind11::is_method, pybind11::sibling>(pybind11::cpp_function::initialize<int, paddle::framework::VarDesc, , pybind11::name, pybind11::is_method, pybind11::sibling>(int (paddle::framework::VarDesc::*)() const, pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::VarDesc const*)#1}&&, int (*)(paddle::framework::VarDesc const*), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) + 143
3       0x7fa5a5f43cb4p pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 2596
4       0x7fa5cb840eecp PyEval_EvalFrameEx + 33468
5       0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
6       0x7fa5cb7cafdap
7       0x7fa5cb7a6773p PyObject_Call + 67
8       0x7fa5cb7a685cp
9       0x7fa5cb7a6952p PyObject_CallFunction + 146
10      0x7fa5cb7e2c94p _PyObject_GenericGetAttrWithDict + 180
11      0x7fa5cb83d9cbp PyEval_EvalFrameEx + 19867
12      0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
13      0x7fa5cb83f482p PyEval_EvalFrameEx + 26706
14      0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
15      0x7fa5cb83f482p PyEval_EvalFrameEx + 26706
16      0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
17      0x7fa5cb83f482p PyEval_EvalFrameEx + 26706
18      0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
19      0x7fa5cb83f482p PyEval_EvalFrameEx + 26706
20      0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
21      0x7fa5cb83f482p PyEval_EvalFrameEx + 26706
22      0x7fa5cb8424e9p PyEval_EvalCodeEx + 2025
23      0x7fa5cb84270ap PyEval_EvalCode + 26
24      0x7fa5cb85b93dp
25      0x7fa5cb85cab8p PyRun_FileExFlags + 120
26      0x7fa5cb85dcd8p PyRun_SimpleFileExFlags + 232
27      0x7fa5cb86fd3cp Py_Main + 2988
28      0x7fa5caa8e445p __libc_start_main + 245
29      0x560601c1e87fp

Why can the single-machine program load the model and continue training, while the pserver approach cannot?

x33g5p2x #1

To add: seg_lr_Factors is an embedding.
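A pure-Python model of why the load fails, as a sketch rather than Paddle code. The traceback shows `_clone_var_in_block_` reading `var.lod_level` for every variable it loads, and in `VarDesc::GetLoDLevel` only LoDTensor and LoDTensorArray variables carry a LoD level. Assuming seg_lr_Factors was created with `is_sparse=True`, its gradient would be a SELECTED_ROWS variable; the pserver program also appears to declare received gradients as persistable, so `load_persistables` on the pserver startup program tries to clone the gradient and hits the exception. `FakeVarDesc` and `clone_for_loading` are hypothetical names for illustration only.

```python
# Pure-Python model, not Paddle source: only tensor-like variables have a
# LoD level, mirroring the type switch in VarDesc::GetLoDLevel.
LOD_TENSOR = "LOD_TENSOR"
LOD_TENSOR_ARRAY = "LOD_TENSOR_ARRAY"
SELECTED_ROWS = "SELECTED_ROWS"


class FakeVarDesc(object):
    """Hypothetical stand-in for a variable description."""

    def __init__(self, name, var_type, lod_level=0):
        self.name = name
        self.var_type = var_type
        self._lod_level = lod_level

    def lod_level(self):
        # Only tensor-like types carry a LoD level; anything else raises,
        # like the EnforceNotMet in the traceback.
        if self.var_type in (LOD_TENSOR, LOD_TENSOR_ARRAY):
            return self._lod_level
        raise RuntimeError(
            "Getting 'lod_level' is not supported by the type of var {}.".format(self.name))


def clone_for_loading(desc):
    """Models _clone_var_in_block_, which reads lod_level for every variable."""
    return desc.name, desc.lod_level()


print(clone_for_loading(FakeVarDesc("seg_lr_Factors", LOD_TENSOR)))  # ('seg_lr_Factors', 0)
try:
    clone_for_loading(FakeVarDesc("seg_lr_Factors@GRAD", SELECTED_ROWS))
except RuntimeError as err:
    print(err)
```

In this model the parameter itself clones fine; only the sparse gradient trips the check, which matches the variable named in the error message.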

holgip5t #2

Hello, we have received the issue and will look into the cause as soon as possible.

tkclm6bt #3

The pserver startup program differs somewhat from the single-machine one; I'll find you an example.

ddarikpa #4

seg_lr_Factors@GRAD is a gradient, so why is it being loaded back? How did you save the model on the single machine? Could you post a screenshot of the saved contents?

ymzxtsji #5

fluid.io.save_persistables(exe, model_dir, train_program)
That is the only statement I used to save it.
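Instead of a screenshot, the directory written by `save_persistables` can be listed with a small stdlib sketch. It assumes the default `filename=None`, where each variable is written to its own file named after the variable; `list_saved_vars` and `gradient_files` are hypothetical helpers, not Paddle APIs.

```python
import os


def list_saved_vars(model_dir):
    """Sorted names of the regular files in a save_persistables directory.

    Assumes filename=None was used, so each variable is its own file
    named after the variable.
    """
    return sorted(
        name for name in os.listdir(model_dir)
        if os.path.isfile(os.path.join(model_dir, name)))


def gradient_files(model_dir):
    """Saved entries whose names mark them as gradients (contain '@GRAD')."""
    return [name for name in list_saved_vars(model_dir) if "@GRAD" in name]
```

For example, `gradient_files("thirdparty/continue_model-pass-0")` would be expected to come back empty for a single-machine save, since gradients are normally not persistable in a single-machine training program.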

w8biq8rn #6

Is there any follow-up on this?

pjngdqdw #7

A few things need to be confirmed:

  1. Which Paddle Fluid version are you using?
  2. Are you using a distributed lookup table?
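For the first question, the installed version can be reported with a short sketch that also runs under the Python 2.7 interpreter seen in the traceback. `paddle_version` is a hypothetical helper; it relies only on `paddle.__version__`.

```python
def paddle_version():
    """Return the installed Paddle version string, or None if Paddle is missing."""
    try:
        import paddle
    except ImportError:
        return None
    return paddle.__version__


print(paddle_version())
```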
yxyvkwin #8

There are some problems in the load path here; they have been fixed in the latest develop branch. Two workarounds:

  1. Define a custom load_persistables method:
import time

from paddle.fluid import core, io


def _load_persistable_vars(executor, dirname, program):
    def _is_checkpoint_var(var):
        """
        The checkpoint does not save or load every variable: variables whose
        type is FEED_MINIBATCH/FETCH_LIST/RAW, or whose name contains @GRAD,
        .trainer_, .block or tmp_, are skipped.

        :param var(Variable): the variable to check
        """
        if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
                var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
                var.desc.type() == core.VarDesc.VarType.RAW:
            return False
        # @GRAD marks gradient variables; the checkpoint does not save them.
        if "@GRAD" in var.name:
            return False
        # .trainer_ marks distributed-training variables; the checkpoint does not save them.
        if ".trainer_" in var.name:
            return False

        # .block marks distributed-training variables; the checkpoint does not save them.
        if ".block" in var.name:
            return False

        # tmp_ marks temporary variables; skip them as well.
        if "tmp_" in var.name:
            return False

        return var.persistable

    print("Start Load persistable vars from {}, time = {}".format(dirname, time.time()))

    # Load only the variables accepted by the predicate, one file per variable.
    io.load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)

    print("Finish Load persistable vars from {}, time = {}".format(dirname, time.time()))
  2. Build from the latest develop branch, which fixes this problem.
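The predicate in option 1 can be dry-run without Paddle. This sketch mirrors `_is_checkpoint_var` over hypothetical stand-in variables (`FakeVar`, with a plain string in place of `core.VarDesc.VarType`) to show which of them the custom loader would try to load; the variable list is made up for illustration, not taken from the reporter's program.

```python
from collections import namedtuple

# Hypothetical stand-in for a Fluid Variable, holding only the fields the
# predicate reads; var_type is a plain string instead of core.VarDesc.VarType.
FakeVar = namedtuple("FakeVar", ["name", "var_type", "persistable"])

SKIPPED_TYPES = ("FEED_MINIBATCH", "FETCH_LIST", "RAW")
SKIPPED_NAME_PARTS = ("@GRAD", ".trainer_", ".block", "tmp_")


def is_checkpoint_var(var):
    """Mirror of _is_checkpoint_var: skip special types and marked names."""
    if var.var_type in SKIPPED_TYPES:
        return False
    if any(part in var.name for part in SKIPPED_NAME_PARTS):
        return False
    return var.persistable


def vars_to_load(variables):
    """Names of the variables the custom loader would try to load."""
    return [v.name for v in variables if is_checkpoint_var(v)]


# Illustrative pserver-side variables (made-up names).
example_vars = [
    FakeVar("seg_lr_Factors", "LOD_TENSOR", True),
    FakeVar("seg_lr_Factors@GRAD", "SELECTED_ROWS", True),
    FakeVar("learning_rate_0", "LOD_TENSOR", True),
    FakeVar("fc_0.tmp_0", "LOD_TENSOR", False),
]

print(vars_to_load(example_vars))  # ['seg_lr_Factors', 'learning_rate_0']
```

Because the gradient is filtered out by name before `load_vars` clones anything, its unsupported `lod_level` is never read.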
xytpbqjk #9

  1. fluid 1.0
  2. Not using it.

I'll give it a try, thanks.
