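I am trying to use a new optimizer from the pytorch optimizer GitHub repository for my model, but when I use it I get an AssertionError. I don't know what I'm doing wrong.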
Error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-98-44e1eacbb7ce> in <module>()
     69
     70 p = get_params(OPT_OVER, net, net_input)
---> 71 optimize(OPTIMIZER, p, closure, LR, num_iter)

/content/utils/common_utils.py in optimize(optimizer_type, parameters, closure, LR, num_iter)
    230             optimizer.step()
    231
--> 232     elif optimizer_type == 'A2GradExp':
    233         print('Starting optimization with A2GradExp')
    234         optimizer = optim.A2GradExp(parameters(), lr=LR)
AssertionError:
Optimizer code:
import torch
import torch_optimizer as optim  # assumed: the pytorch-optimizer package; the import is not shown in the question

def optimize(optimizer_type, parameters, closure, LR, num_iter):
    """Runs optimization loop.

    Args:
        optimizer_type: 'LBFGS', 'adam' or 'A2GradExp'
        parameters: list of Tensors to optimize over
        closure: function that returns the loss variable
        LR: learning rate
        num_iter: number of iterations
    """
    if optimizer_type == 'LBFGS':
        # Do several steps with adam first
        optimizer = torch.optim.Adam(parameters, lr=0.001)
        for j in range(100):
            optimizer.zero_grad()
            closure()
            optimizer.step()

        print('Starting optimization with LBFGS')

        def closure2():
            optimizer.zero_grad()
            return closure()

        optimizer = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1)
        optimizer.step(closure2)

    elif optimizer_type == 'adam':
        print('Starting optimization with ADAM')
        optimizer = torch.optim.Adam(parameters, lr=LR)
        for j in range(num_iter):
            optimizer.zero_grad()
            closure()
            optimizer.step()

    elif optimizer_type == 'A2GradExp':
        print('Starting optimization with A2GradExp')
        optimizer = optim.A2GradExp(parameters, lr=LR)
        for j in range(num_iter):
            optimizer.zero_grad()
            closure()
            optimizer.step()

    else:
        # Reached only when optimizer_type matches none of the branches above
        assert False
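A bare `assert False` only says that no branch matched, not which value arrived. A minimal sketch of a more informative check, assuming the three names handled above are the only valid ones (the helper `check_optimizer_type` and the tuple `SUPPORTED_OPTIMIZERS` are hypothetical, not part of the original code):

```python
# Hypothetical list mirroring the branches in optimize() above.
SUPPORTED_OPTIMIZERS = ("LBFGS", "adam", "A2GradExp")


def check_optimizer_type(optimizer_type):
    """Return optimizer_type unchanged, or fail with a message naming the bad value."""
    if optimizer_type not in SUPPORTED_OPTIMIZERS:
        # Unlike `assert False`, the message shows the value that was passed,
        # e.g. a typo such as 'A2gradExp'.
        raise ValueError(
            f"Unknown optimizer_type {optimizer_type!r}; "
            f"expected one of {SUPPORTED_OPTIMIZERS}"
        )
    return optimizer_type


try:
    check_optimizer_type("A2gradExp")  # deliberately misspelled
except ValueError as err:
    message = str(err)

print(message)
```

Putting the same `raise ValueError(...)` in the final `else` of `optimize` would make the traceback name the value of `OPTIMIZER` that was actually received.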
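I have tried other optimizers and my code runs perfectly fine. But when I use one of the optimizers from this link, I get this error. What changes should I make to get rid of this error?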
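One plausible cause worth checking: the traceback shows `parameters()` on line 234, while the posted code has `parameters`, which suggests `common_utils.py` was edited after it was imported. Python keeps executing the module already loaded in memory, but the traceback prints source lines read from the file on disk, so line 232 can be displayed even though the old `assert False` is what actually ran. The self-contained sketch below reproduces that effect and shows `importlib.reload` picking up the edit (the module name `demo_utils` is hypothetical, standing in for `utils/common_utils.py`). Restarting the Colab runtime has the same effect as the reload.

```python
import importlib
import pathlib
import sys
import tempfile

sys.dont_write_bytecode = True  # keep the demo from writing .pyc files

tmp = pathlib.Path(tempfile.mkdtemp())
module_path = tmp / "demo_utils.py"  # hypothetical stand-in for common_utils.py

# Version 1 on disk: no A2GradExp branch, so that name falls through to assert False.
module_path.write_text(
    "def optimize(optimizer_type):\n"
    "    if optimizer_type == 'adam':\n"
    "        return 'adam'\n"
    "    else:\n"
    "        assert False\n"
)
sys.path.insert(0, str(tmp))
importlib.invalidate_caches()
import demo_utils  # the module is now cached in sys.modules

# Version 2 on disk: the file is edited to add the new branch.
module_path.write_text(
    "def optimize(optimizer_type):\n"
    "    if optimizer_type == 'adam':\n"
    "        return 'adam'\n"
    "    elif optimizer_type == 'A2GradExp':\n"
    "        return 'A2GradExp'\n"
    "    else:\n"
    "        assert False\n"
)


def try_optimize(name):
    """Call the currently loaded optimize() and report an AssertionError as a string."""
    try:
        return demo_utils.optimize(name)
    except AssertionError:
        return "AssertionError"


before = try_optimize("A2GradExp")  # still runs version 1 from memory
importlib.reload(demo_utils)         # re-executes the edited file
after = try_optimize("A2GradExp")    # now runs version 2

print(before, after)
```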