在Pytorch中利用梯度下降法实现单位单形投影

bq3bfh9z  于 2022-12-23  发布在  其他
关注(0)|答案(1)|浏览(127)
    • bounty将在5天后过期**。回答此问题可获得+100的声誉奖励。Saeed正在寻找来自声誉良好来源的答案:请告诉我问题出在哪里,并修复没有得到概率向量的错误。

在博伊德教授homework solution投影到单位单纯形,他结束了以下方程:

g_of_nu = (1/2)*torch.norm(-relu(-(x-nu)))**2 + nu*(torch.sum(x) -1) - x.size()[0]*nu**2

如果计算nu*,则到单位单形的投影将是y*=relu(x-nu*1)
他的建议是找到g_of_nu,由于g_of_nu是严格凹的,我将它乘以一个负号(f_of_nu),然后使用梯度下降法找到它的全局最小值。

    • 问题**

最后一个向量y*,加起来不等于1,我做错了什么?

    • 复制代码**

x一个一个一个一个x一个一个二个x

    • 功能**
torch.manual_seed(1)
x = torch.randn(10)
nu = torch.linspace(-1, 1, steps=10000)

f = lambda x, nu: -( (1/2)*torch.norm(-relu(-(x-nu)))**2 + nu*(torch.sum(x) -1) - x.size()[0]*nu**2 )

f_value_list = np.asarray( [f(x, i) for i in nu.tolist()] )

i_min = np.argmin(f_value_list)
print(nu[i_min])

fig, ax = plt.subplots()

ax.plot(nu.cpu().detach().numpy(), f_value_list);

这是图中的最小值,它与梯度下降一致。

tensor(0.0665)

35g0bw71

35g0bw711#

误差来自公式的推导:

发件人:

如果你发展出这个表达式

你就会意识到

代替

简而言之,这个错误来自于在制定规范时忘记了1/2因素,一旦你做出了改变,一切都按预期进行:

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
x = torch.randn(10)

x_list = x.tolist()

nu_0 = torch.tensor(0., requires_grad = True)
nu = nu_0
optimizer = torch.optim.SGD([nu], lr=1e-1)

nu_old = torch.tensor(float('inf'))
steps = 100
eps = 1e-6
i = 1
while torch.norm(nu_old-nu) > eps:
  nu_old = nu.clone()
  optimizer.zero_grad()
  f_of_nu = -(0.5*torch.norm(-torch.relu(-(x-nu)))**2 + nu*(torch.sum(x) -1) -0.5*x.size()[0]*nu**2)
  f_of_nu.backward()
  optimizer.step()
  print(f'At step {i+1:2} the function value is {f_of_nu.item(): 1.4f} and nu={nu: 0.4f}' )
  i += 1

y_star = torch.relu((x-nu)).cpu().detach()
print(y_star)
print(list(map(lambda x: round(x, 4), y_star.tolist())))
print(y_star.sum())

输出为:

...
At step 25 the function value is -2.0721 and nu= 0.2328
tensor(0.2328, requires_grad=True)
tensor([0.4285, 0.0341, 0.0000, 0.3885, 0.0000, 0.0000, 0.0000, 0.1489, 0.0000,
        0.0000])
[0.4285, 0.0341, 0.0, 0.3885, 0.0, 0.0, 0.0, 0.1489, 0.0, 0.0]
tensor(1.0000)

相关问题