vllm [Misc]:如何修复测试中的正确清理问题

utugiqy6  于 2个月前  发布在  其他
关注(0)|答案(2)|浏览(17)

关于vllm,我们目前在正确清理资源方面遇到了很大的困难,尤其是在分布式推理方面。这使得我们的测试最近受到了影响。
要理解这个问题,我们需要了解pytest的进程模型:对于这个简单的测试:

import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
def test_pass(arg):
    import os
    print((arg, os.getpid()))

pytest 会创建一个进程,依次运行这三个测试。
所以输出是(注意三个测试的进程id相同):

testf.py::test_pass[1] (1, 15068)
PASSED
testf.py::test_pass[2] (2, 15068)
PASSED
testf.py::test_pass[3] (3, 15068)
PASSED

这三个测试共享同一个进程,使得一些低级处理变得困难。

  1. 当某个测试发生段错误时,后面的测试将无法运行,因为进程已经死亡。例如:
import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
def test_pass(arg):
    import os
    print((arg, os.getpid()))

    if arg == 2:
        import ctypes
        func_ptr = ctypes.CFUNCTYPE(ctypes.c_int)(0)
        # Calling the function pointer with an invalid address will cause a segmentation fault
        func_ptr()

这个测试产生的输出是:

Running 3 items in this shard: testf.py::test_pass[1], testf.py::test_pass[2], testf.py::test_pass[3]

testf.py::test_pass[1] (1, 24492)
PASSED
testf.py::test_pass[2] (2, 24492)
Fatal Python error: Segmentation fault

因此,在这种情况下,测试3不会被执行。可以认为这是可以接受的。如果测试2因段错误而失败,我们肯定应该调查原因并修复它。

  1. 当某些测试使进程变脏时,该进程不能再用于后续的测试。例如:
import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
def test_pass(arg):
    import torch
    assert not torch.cuda.is_initialized()
    data = torch.ones(10, 10).cuda()
    print(data)
    assert data.sum().item() == 100

在这个例子中,每个测试都需要一个干净的进程(没有初始化CUDA),但在测试完成后会使进程变脏。
进程也可能变脏,如果某些对象没有被垃圾回收,且GPU内存没有被释放。
解决方案是为每个测试创建一个新的进程,即使用pytest --forked -s test.py运行测试。这基本上是有效的,但有一个限制:输出无法被捕获。请注意,我添加了print(data)来打印一些内容,但pytest --forked会丢弃输出。这对开发者来说并不友好。
我们探索过的另一种替代方案是,手动为每个测试用例创建一个进程:

import os
arg = int(os.environ['arg'])
def test_pass():
    import torch
    assert not torch.cuda.is_initialized()
    data = torch.ones(10, 10).cuda()
    print(data)
    assert data.sum().item() == 100
    if arg == 2:
        raise RuntimeError("invalid arg")

并使用环境变量启动每个测试:

arg=1 pytest -v -s test.py
arg=2 pytest -v -s test.py
arg=3 pytest -v -s test.py

例如:
vllm/.buildkite/test-pipeline.yaml
第85行 6a11fdf
| | - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py |
这很繁琐,当我们想要测试多种参数组合时,也无法扩展。
提议的解决方案是手动分叉:

import functools

def fork_new_process_for_each_test(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        import os
        pid = os.fork()
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            _pid, _exitcode = os.waitpid(pid, 0)
            assert _exitcode == 0, f"function {f} failed when called with args {args} and kwargs {kwargs}"
    return wrapper

import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
@fork_new_process_for_each_test
def test_pass(arg):
    import torch
    assert not torch.cuda.is_initialized()
    data = torch.ones(10, 10).cuda()
    print(data)
    assert data.sum().item() == 100
    if arg == 2:
        raise RuntimeError("invalid arg")

输出是:

================================================================== test session starts ===================================================================
platform linux -- Python 3.9.19, pytest-8.2.2, pluggy-1.5.0 -- /data/youkaichao/miniconda/envs/vllm/bin/python
cachedir: .pytest_cache
rootdir: /data/youkaichao/vllm
configfile: pyproject.toml
plugins: asyncio-0.23.7, shard-0.1.2, anyio-4.4.0, rerunfailures-14.0, forked-1.6.0
asyncio: mode=strict
collected 3 items                                                                                                                                        
Running 3 items in this shard: teste.py::test_pass[1], teste.py::test_pass[2], teste.py::test_pass[3]

teste.py::test_pass[1] tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
PASSED
teste.py::test_pass[2] tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
Traceback (most recent call last):
  File "/data/youkaichao/vllm/teste.py", line 10, in wrapper
    f(*args, **kwargs)
  File "/data/youkaichao/vllm/teste.py", line 32, in test_pass
    raise RuntimeError("invalid arg")
RuntimeError: invalid arg
FAILED
teste.py::test_pass[3] tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
PASSED

======================================================================== FAILURES ========================================================================
______________________________________________________________________ test_pass[2] ______________________________________________________________________

args = (), kwargs = {'arg': 2}, os = <module 'os' from '/data/youkaichao/miniconda/envs/vllm/lib/python3.9/os.py'>, pid = 2986722, _pid = 2986722
_exitcode = 256, @py_assert2 = 0, @py_assert1 = False, @py_format4 = '256 == 0'
@py_format6 = "function <function test_pass at 0x7fc791fe9ee0> failed when called with args () and kwargs {'arg': 2}\n>assert 256 == 0"

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        import os
        pid = os.fork()
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            _pid, _exitcode = os.waitpid(pid, 0)
>           assert _exitcode == 0, f"function {f} failed when called with args {args} and kwargs {kwargs}"
E           AssertionError: function <function test_pass at 0x7fc791fe9ee0> failed when called with args () and kwargs {'arg': 2}
E           assert 256 == 0

teste.py:19: AssertionError
================================================================ short test summary info =================================================================
FAILED teste.py::test_pass[2] - AssertionError: function <function test_pass at 0x7fc791fe9ee0> failed when called with args () and kwargs {'arg': 2}
============================================================== 1 failed, 2 passed in 32.03s ==============================================================

需要注意的是:

  • 每个测试都会得到一个新的干净进程
  • 输出被捕获(打印语句)
  • 测试2失败,但不会阻塞测试3。
  • 即使测试2无法清理资源,因为它是一个新进程,直接调用os._exit(0)退出,也不应该影响测试3。

唯一需要确保的是,当我们进入测试时,进程是干净的。
当然,完美的解决方案是在引入多进程、ray和asyncio时优雅地处理资源,有明确的清理过程。但是这非常困难。没有完美清理资源的预期完成时间。

ycl3bljg

ycl3bljg1#

这个解决方案相当优雅,我喜欢它!

31moq8wy

31moq8wy2#

第一步是合并到#7054。明天我会在主节点监控ci状态,看看pp测试和2个节点测试中是否还有失败。如果没有,我会继续更改剩余的测试。

相关问题