pytorch 数据累积usign .backward()函数

t1qtbnec  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(221)

首先,考虑一下Xueyter Notebook中的以下代码块

  1. #Block1
  2. import torch
  1. #Block2
  2. x = torch.tensor([1.], requires_grad=True)
  1. #Block3
  2. y = x**10
  3. y.backward()
  4. print(x.grad)

然后考虑以下代码(将块2和块3组合在一个单元中):

  1. #Block4
  2. x = torch.tensor([1.], requires_grad=True)
  3. y = x**10
  4. y.backward()
  5. print(x.grad)

每次我运行#Block3时,我都会得到打印(x.grad)的累积结果,即10、20、30等。
但是通过运行#Block4,答案总是10
这背后的原因是什么?

ar5n3qh5

ar5n3qh51#

案例一:当你看到不同的渐变值时

这是因为当你调用.grad时,它会迭代地累积值,即
第一次呼叫:

  1. >>> grad = None

第二次通话:

  1. >>> grad = 10 + grad
  2. >>> grad = 20

第三通电话:

  1. >>> grad = 20 + grad
  2. >>> grad = 30

这还在继续。。

案例二:当你看到相同的渐变值时

在第二种情况下,每次运行编程块时都会创建一个新的Tensorx。新的内存将被分配给这些Tensor,grads将被初始化为0。在这种情况下,你总是得到10。
我希望能帮上忙。谢谢你,谢谢

展开查看全部

相关问题