如何在使用自定义损失函数时在PyTorch中执行内存高效的反向传播?

ibps3vxo  于 2023-11-19  发布在  其他
关注(0)|答案(2)|浏览(97)

简介

我正在使用PyTorch进行一个大型深度学习项目,在反向传播过程中面临内存问题。我实现了一个自定义损失函数,我需要知道是否有一种更节省内存的方法来执行反向传播,而不影响自定义损失计算。

验证码

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def forward(self, x, y):
        return torch.sum(x * y)

# My neural network
class Net(nn.Module):
    # ...

字符串
我尝试使用PyTorch的内置方法进行反向传播,但它们消耗了大量内存。我希望自定义损失函数可以优化以更好地利用内存。
到底发生了什么?
反向传播期间内存消耗激增,导致我的脚本崩溃。

2ul0zpep

2ul0zpep1#

请考虑:
1.使batch_size更小。
1.使用torch.utils.checkpoint.checkpoint创建检查点(在计算和内存之间进行权衡,请参见here
1.尽可能使用现场操作。

06odsfpq

06odsfpq2#

反向传播过程中的内存消耗是由于需要为每个模型参数存储一个grad参数,加上为优化器状态存储的参数。它与损失函数无关。
您可以通过使用较小的批处理大小和使用梯度累积进行训练来减少内存使用。

相关问题