我正试图了解参数(https://github.com/pytorch/pytorch/blob/329a9a90c0a579a3c67370702454f254109f1c9c/torch/nn/parameter.py)是如何在PyTorch中实现的。
我看到Parameter继承了Tensor,它允许您轻松执行Tensor操作,例如。如果p
是一个参数,那么你可以简单地调用p * 2
或p.tanh()
,而不需要检索底层的Tensor。
让我困惑的是Parameter如何存储其私有变量data
,因为它似乎指向一个不存在的父Tensor对象。
它只是一个引用self
?它是一个围绕子参数类的Tensor Package 器吗?它是一个属性,可以使用一些魔法来返回Tensor数据?
1条答案
按热度按时间bksxznpy1#
PyTorch的许多底层实现细节都使用了C++和PyObjects。您可以看到您引用的
__new__
中的代码返回torch.Tensor._make_subclass(cls, data, requires_grad)
。该实现可以在python_variable.cpp中找到,作为THPVariable_make_subclass
(https://github.com/pytorch/pytorch/blob/bb 6 b157458 a34 d8 b3499932035381 fdb 12683703/torch/csrc/autograd/python_variable.cpp#L567)。