为什么在PyTorch中有两个不同的标志来禁用梯度计算

bsxbgnwa  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(162)

我是PyTorch的中级学习者,在最近的一些案例中,我看到人们使用torch.inference_mode()而不是著名的torch.no_grad()来验证强化学习中经过训练的代理(RL)实验。我查看了文档,他们有一个表,其中包含两个标记,用于禁用梯度计算。老实说,如果我读了描述,对我来说听起来完全一样。有人想出解释了吗?

vfh0ocws

vfh0ocws1#

所以我已经在网上搜索了几天,我想我得到了我的解释。torch.inference()模式已经被添加为一种更优化的使用PyTorch进行推理的方式(相对于torch.no_grad())。我听了PyTorch podcast,他们解释了为什么存在不同的标志。
1.**Tensor的版本控制:*比方说,你在PyTorch中有一个代码,你用它来训练一个代理。当你在训练的模型上运行torch.no_grad()和推理时,PyTorch仍然有一些功能,比如Tensor的版本计数,它仍然在起作用,每次创建Tensor时都会分配,当你改变特定的Tensor时会增加(版本碰撞)。检查所有Tensor的所有版本需要额外的计算成本,我们不能只是摆脱它们,因为我们必须留意Tensor的变化,要么(直接)到特定的Tensor,要么(间接)到其他Tensor的混叠,这是为了向后计算而保存的。
1.查看Tensor跟踪:Pytorch的Tensor是strided
,这意味着PyTorch在后端使用stride进行索引,如果你想直接访问内存块中的特定元素,就可以使用stride。但是在torch.autograd的情况下,如果你取一个Tensor并创建一个新的view,然后用一个与逆向计算相关的Tensor来改变它通过torch.no_grad,它们记录了一些view元数据,这些元数据是跟踪哪些Tensor需要梯度,哪些不需要梯度所必需的。这也增加了计算资源的额外开销。
因此,torch.autograd检查这些更改,当您切换到torch.inference_mode()(而不是torch.no_grad())时,这些更改不会被跟踪,如果您的代码没有利用上述两点,那么推理模式将工作,并减少代码执行时间。(PyTorch开发团队说,他们在Facebook的生产中部署模型时看到了5-10%的增长。

相关问题