pytorch 如何显示梯度下降期间错误的演变?

edqdpe6u  于 2022-11-29  发布在  其他
关注(0)|答案(1)|浏览(145)

我希望能够显示梯度下降过程中错误的演变。
我的代码很简单,我试图根据梯度下降来估计一个线性函数,我重复了几次。
我在PyTorch上工作,在包文档中寻找了一段时间的解决方案后,我没有找到任何非常确定的东西。
我的代码看起来像这样:

x <- df_tensor[,4,drop=FALSE]
y <- df_tensor[,1,drop=FALSE]

my_model <- nn_linear(1,1)

ds <- tensor_dataset(x, y)
dl <- dataloader(ds)

optimiser <- optim_sgd(my_model$parameters, lr = 0.01)
loss <- nnf_mse_loss

for (e in 1:10) {
coro::loop(for (b in dl) {
y_pred <- my_model(b[[1]])

c_loss <- loss(y_pred, b[[2]])
 
optimiser$zero_grad()
c_loss$backward()
optimiser$step()
})
}
1hdlvixo

1hdlvixo1#

我对PyTorch不太清楚,但是optim from base R可以通过trace参数输出步骤的细节。

# example from optim documentation with detailed output

fr <- function(x) {   ## Rosenbrock Banana function
  x1 <- x[1]
  x2 <- x[2]
  100 * (x2 - x1 * x1)^2 + (1 - x1)^2
}

optim(c(-1.2,1), fr, method = "BFGS", control = list(trace = 5, REPORT = 1))
#> initial  value 24.200000 
#> iter   2 value 20.227228
#> iter   3 value 8.606643
#> iter   4 value 3.122992
#> iter   5 value 2.830573
#> iter   6 value 2.634590
#> iter   7 value 2.006884
#> iter   8 value 1.890085
#> iter   9 value 1.520074
#> iter  10 value 1.370040
#> iter  11 value 1.173091
#> iter  12 value 0.916511
#> iter  13 value 0.866076
#> iter  14 value 0.747186
#> iter  15 value 0.607219
#> iter  16 value 0.432657
#> iter  17 value 0.355806
#> iter  18 value 0.279255
#> iter  19 value 0.199113
#> iter  20 value 0.132618
#> iter  21 value 0.114741
#> iter  22 value 0.066574
#> iter  23 value 0.066238
#> iter  24 value 0.047027
#> iter  25 value 0.036728
#> iter  26 value 0.026579
#> iter  27 value 0.013324
#> iter  28 value 0.013253
#> iter  29 value 0.004388
#> iter  30 value 0.001800
#> iter  31 value 0.000365
#> iter  32 value 0.000254
#> iter  33 value 0.000124
#> iter  34 value 0.000021
#> iter  35 value 0.000005
#> iter  36 value 0.000000
#> iter  37 value 0.000000
#> iter  38 value 0.000000
#> iter  38 value 0.000000
#> iter  38 value 0.000000
#> final  value 0.000000 
#> converged
#> $par
#> [1] 0.9998044 0.9996084
#> 
#> $value
#> [1] 3.827383e-08
#> 
#> $counts
#> function gradient 
#>      118       38 
#> 
#> $convergence
#> [1] 0
#> 
#> $message
#> NULL

创建于2022年11月16日,使用reprex v2.0.2

相关问题