找到不会导致损失的PyTorch模型参数

9rnv2umw  于 2023-06-23  发布在  其他
关注(0)|答案(2)|浏览(149)

在PyTorch(v1.10)Distibuted DataParallel中,模型中未使用的参数不会导致最终损失,可能会引发RuntimeError(如this other questionthis PyTorch forums thread中所述)。

  • “运行时错误:在开始新的迭代之前,预期已在前一个迭代中完成了减少。此错误表明您的模块具有未用于产生损失的参数。您可以通过将关键字参数find_unused_parameters=True传递给torch.nn.parallel.DistributedDataParallel并确保所有forward函数输出都参与计算损失来启用未使用参数检测。"*

虽然可以检查哪些参数在错误时受到影响(如上所述,或设置env var TORCH_DISTRIBUTED_DEBUG="INFO"),但似乎应该有一种方法来静态检查模型以定位(并可能修剪或禁用梯度)对当前损失目标没有贡献的参数?
因此,给定一个基于torch.nn.Modulemodel,其forward()函数返回一些lossTensor(可能与其他Tensor一起)-我们如何在开始训练之前以编程方式找到所有对loss没有贡献的参数(包括嵌套模块)?

xuo3flqw

xuo3flqw1#

默认情况下,作为某些计算结果的PyTorchTensor记录了它们的历史,即它们的祖先。这是向后传递计算梯度所需的。
我们可以利用这一点,通过遍历整个历史,找到所有对一些新Tensor有贡献的Tensor。
请注意,这适用于始终具有相同架构的静态网络。只要你有条件句,例如。依赖于一些中间值,这是行不通的,我认为在这种情况下,不可能提前找到Tensor。(这与停机问题类似。)

import torch
import torch.nn as nn
# Example of a simple network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.tensor([999999.0]))  # not contributing
        self.layers = nn.ModuleList([nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1)) for _ in range(3)])
    def forward(self, x):
        for m in self.layers: x = m(x) + x
        return x

net = Net()
x = torch.ones((1, 1))
# compute the forward pass to create the computation graph
y = net(x)

# use computation graph to find all contributing tensors
def get_contributing_params(y, top_level=True):
    nf = y.grad_fn.next_functions if top_level else y.next_functions
    for f, _ in nf:
        try:
            yield f.variable
        except AttributeError:
            pass  # node has no tensor
        if f is not None:
            yield from get_contributing_params(f, top_level=False)

contributing_parameters = set(get_contributing_params(y))
all_parameters = set(net.parameters())
non_contributing = all_parameters - contributing_parameters
print(non_contributing)  # returns the [999999.0] tensor
vptzau2j

vptzau2j2#

是的,渐变的旋转不起作用。如果不使用这些层,如何动态删除它们。例如,在逐渐增长的鉴别器中。

相关问题