清除当前python代码中pytorch使用的所有GPU内存,而无需退出python

vjhs03f7  于 2023-11-19  发布在  Python
关注(0)|答案(2)|浏览(237)

我运行的是一个修改过的第三方代码,它使用了pytorch和GPU。我通过改变模型参数多次运行同一个模型,我在python中这样做,即我有一个 Package 器python文件,它用不同的模型参数调用模型。但是我在运行第二个或第三个模型时得到了out-of-memory错误。也就是说,模型可以正常运行一次,没有任何内存问题。因此,如果我在运行第一个模型后结束代码,然后重新启动第二个模型,代码工作正常。但是,如果我在python中链接模型,我会遇到out-of-memory问题。
我怀疑第三方代码中有一些内存泄漏。在谷歌上,我发现了两个建议。一个是调用torch.cuda.empty_cache(),另一个是使用del tensor_name显式删除Tensor。然而,empty_cache()命令并不能帮助释放整个内存,而第三个-party代码中的Tensor太多,我无法单独删除所有Tensor。有没有办法在python代码中清除当前python程序使用的整个GPU内存本身?

f0brbegy

f0brbegy1#

在没有实际阅读代码的情况下,很难确定是什么导致了内存问题。但是,大多数情况下,empty_cache()无法完成清理是因为某些进程仍在运行。因此,请尝试在empty_cache()之后添加此内容

import gc
gc.collect()

字符串

jaxagkaj

jaxagkaj2#

垃圾收集器和del直接在模型和训练数据上工作,当使用循环中的模型时,我很少工作。通常,每次迭代都会创建一个新模型,而不会从内存中清除前一个模型,因此整个循环需要(model_size + training data) * n的内存容量,其中 n 是迭代次数。这在使用联邦学习工具(如Flower)或使用k-fold cross validation.如果你想使用一个多处理的方法,它应该总是工作,以清除子进程使用的GPU内存,这将工作:

import torch
import torch.multiprocessing as mp
from torch.multiprocessing import Pool, Manager

class nn_model(torch.nn.Module)
   # define class for model

def train_fn(model, train_ds, val_ds, queue)
    # do training stuff
    history = # save training metrics into history dictionary

    queue.put(model, history)

if __name__ == "__main__":
    train_ds = # load train dataset
    val_ds = # load val dataset
    test_ds = # load test dataset

    model = nn_model(*args)

    pool = Pool(1)  # or replace 1 with however many child processes you want
    queue = Manager().Queue()

    # assume training loop of three iterations
    [pool.apply_async(train_fn, args=(model, train_ds, val_ds, queue)) for _ in range(3)]
    pool.close()
    pool.join()

    # this next line will give a variable queue_results containing a tuple of (model, history)
    queue_results = [queue.get() for _ in range(queue.qsize())]

    # assuming queue_results contains only one element
    model, history = queue_results[0]
    
    for i, j in test_ds:
        with torch.no_grad():
            model = model.eval()
            preds = model(i)

字符串
如果你不想使用Pool,而想显式地杀死子进程,你可以使用(而不是Pool):

from torch.multiprocessing import Process

if __name__ == "__main__":

    ...

    # inside "_main__" replacing Pool, assuming three iterations
    procs = [mp.Process(target=train_fn, args=(model, train_ds, val_ds, queue)) for _ in range(3)]
    for p in procs:
        p.start()
        p.join()
    
    queue_results = [queue.get() for _ in range(queue.qsize())]

    for p in procs:
        try:
            p.close()
        except ValueError as e:
            print(f"Couldn't close process: {e}")
        del p

    # assuming queue_results contains only one element
    model, history = queue_results[0]
    
    for i, j in test_ds:
        with torch.no_grad():
            model = model.eval()
            preds = model(i)


我认为使用Pool比Process更方便。

相关问题