我对以下代码片段有一些疑问:
>>> def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
apply()是pytorch.nn包的一部分。你可以在这个包的文档中找到代码。最后一个问题:1.为什么这个代码示例可以工作,尽管当它被赋予apply()时,init_weights()没有添加参数或括号?2.当它作为apply()的参数而没有括号和m时,函数init_weights(m)从哪里获得参数m?
2条答案
按热度按时间ubof19bj1#
我们在
torch.nn.Module.apply(
fn)
的文档中找到您问题的答案:递归地将
fn
应用于每个子模块(由.children()返回)以及self。典型的用法包括初始化模型的参数(另请参见torch-nn-init)。init_weights
在apply
调用之前没有被调用,这是因为没有圆括号,而是将对init_weights
的引用赋予apply
,并且只有在apply
内部之后才调用init_weights
。apply
中的每次调用中都得到它的参数,并且,正如文档所述,由于方法调用net.apply(…)
,它被调用m迭代net
的每个子模块(在本例中)以及net
本身。o4tp2gmn2#
[文档]
def apply(self: T, fn: Callable[['Module'], None]) -> T:
从https://pytorch.org/docs/master/_modules/torch/nn/modules/module.html#Module.apply阅读有关apply的源代码,它被称为fn(self)last