AssertionError when trying to implement a new optimizer

8xiog9wr · posted 2021-09-08 in Java

I am trying to use a new optimizer from the pytorch-optimizer GitHub repository with my model, but when I use it I get an AssertionError. I don't know where I went wrong.
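
For reference, here is a minimal sketch of how I understand the library is meant to be used (assuming the package is torch_optimizer from the pytorch-optimizer GitHub repo, imported as optim; the toy model and data are just to show the construction pattern):

import torch
import torch_optimizer as optim  # pip install torch_optimizer

model = torch.nn.Linear(4, 1)                # toy model
x, y = torch.randn(8, 4), torch.randn(8, 1)  # toy data

optimizer = optim.A2GradExp(model.parameters())  # default hyperparameters

for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
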
Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-98-44e1eacbb7ce> in <module>()
     69 
     70 p = get_params(OPT_OVER, net, net_input)
---> 71 optimize(OPTIMIZER, p, closure, LR, num_iter)

/content/utils/common_utils.py in optimize(optimizer_type, parameters, closure, LR, num_iter)
    230             optimizer.step()
    231 
--> 232     elif optimizer_type == 'A2GradExp':
    233         print('Starting optimization with A2GradExp')
    234         optimizer = optim.A2GradExp(parameters(), lr=LR)

AssertionError:


Optimizer code:

import torch
import torch_optimizer as optim  # assumed: the pytorch-optimizer package, imported as `optim`

def optimize(optimizer_type, parameters, closure, LR, num_iter):
    """Runs the optimization loop.

    Args:
        optimizer_type: 'LBFGS', 'adam', or 'A2GradExp'
        parameters: list of Tensors to optimize over
        closure: function that returns the loss variable
        LR: learning rate
        num_iter: number of iterations
    """
    if optimizer_type == 'LBFGS':
        # Do several steps with adam first
        optimizer = torch.optim.Adam(parameters, lr=0.001)
        for j in range(100):
            optimizer.zero_grad()
            closure()
            optimizer.step()

        print('Starting optimization with LBFGS')        
        def closure2():
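            # LBFGS may re-evaluate the closure several times per step,
            # so gradients are zeroed inside it rather than in an outer loop.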
            optimizer.zero_grad()
            return closure()
        optimizer = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1)
        optimizer.step(closure2)

    elif optimizer_type == 'adam':
        print('Starting optimization with ADAM')
        optimizer = torch.optim.Adam(parameters, lr=LR)

        for j in range(num_iter):
            optimizer.zero_grad()
            closure()
            optimizer.step()

    elif optimizer_type == 'A2GradExp':
        print('Starting optimization with A2GradExp')
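        # A2GradExp lives in torch_optimizer, not torch.optim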
        optimizer = optim.A2GradExp(parameters, lr=LR)

        for j in range(num_iter):
            optimizer.zero_grad()
            closure()
            optimizer.step()
    else:
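        # Reached when optimizer_type matches none of the branches above;
        # a bare `assert False` raises AssertionError with no message.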
        assert False
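
For reference, this is roughly how the function is called in my notebook (a simplified sketch; the concrete values of OPTIMIZER, LR, and num_iter below are placeholders, and get_params, net, net_input, and closure come from the notebook setup shown in the traceback):

OPTIMIZER = 'A2GradExp'  # placeholder; 'adam' and 'LBFGS' work fine here
LR = 0.01                # placeholder value
num_iter = 2000          # placeholder value

p = get_params(OPT_OVER, net, net_input)  # as in the traceback above
optimize(OPTIMIZER, p, closure, LR, num_iter)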

I have tried other optimizers and my code runs perfectly fine, but when I use one of the optimizers from this link I get this error. What changes should I make to get rid of it?
