如何在PyTorch中实现Parameter?

dzhpxtsq  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(154)

我正试图了解参数(https://github.com/pytorch/pytorch/blob/329a9a90c0a579a3c67370702454f254109f1c9c/torch/nn/parameter.py)是如何在PyTorch中实现的。
我看到Parameter继承了Tensor,它允许您轻松执行Tensor操作,例如。如果p是一个参数,那么你可以简单地调用p * 2p.tanh(),而不需要检索底层的Tensor。
让我困惑的是Parameter如何存储其私有变量data,因为它似乎指向一个不存在的父Tensor对象。
它只是一个引用self?它是一个围绕子参数类的Tensor Package 器吗?它是一个属性,可以使用一些魔法来返回Tensor数据?

bksxznpy

bksxznpy1#

PyTorch的许多底层实现细节都使用了C++和PyObjects。您可以看到您引用的__new__中的代码返回torch.Tensor._make_subclass(cls, data, requires_grad)。该实现可以在python_variable.cpp中找到,作为THPVariable_make_subclass(https://github.com/pytorch/pytorch/blob/bb 6 b157458 a34 d8 b3499932035381 fdb 12683703/torch/csrc/autograd/python_variable.cpp#L567)。

相关问题