python all_gather中的分布式torch数据冲突(将all_gather结果写入文件“修复”问题)

k2arahey  于 2023-04-28  发布在  Python
关注(0)|答案(1)|浏览(387)

问题:

  • 分布式进程计算错误并将其与float索引一起返回
  • 当从不同的等级中收集错误时,这些索引上会发生冲突
  • 因此,如果数据集具有100个样本,并且GPU的数量为4,则所得到的索引集的长度将为25,而不是预期的100
  • 当我将每个等级的数据(预收集)写入文件时,我可以验证索引是100%不相交的
  • 当我将每个等级的数据(后收集)写入文件时,问题消失了
  • 注解掉后收集调试数据文件写入,问题返回

注:打印后收集结果也可以“修复”该问题,但对后收集结果进行排序则不能。
因此,将后收集数据写入文件可以解决一些分布式问题。..我被提醒需要flush流来避免意外的结果,但我在文档中没有看到任何类型的推论。
下面是一个最小的例子,展示了我的代码中发生了什么:

# setup_distributed_stuff()
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# Data returned from distributed computation.
# Note that there's no overlap between the different ranks.
data = torch.arange(
    0 + (rank * 100 // world_size),
    (rank + 1) * 100 // world_size,
)

# `data` is confirmed to be disjoint across ranks by writing to file here.

# Gather data from all ranks.
if world_size > 1:
    all_data = [torch.zeros_like(data) for _ in range(world_size)]
    torch.distributed.all_gather(all_data, data)
    data = torch.cat(all_data, dim=0)

    # By writing "data" to file for debugging, the problem goes away...
    #     i.e. len(set(data.numpy())) == 100!
    # If I comment this out, then my gathered data collides...
    #     i.e. len(set(data.numpy())) == 100 // world_size
    with open("debug_data.pt", "wb") as _file:
        torch.save(data, _file)

    # I can also simply print the indices and get the same effect...
    logger.info(
        "Gathered result indices: {}...{}".format(
            data[:10, -1], data[-10:, -1]
        )
    )

    # However, sorting the indices doesn't do me any good...
    data = data[data[:, -1].argsort(dim=0)]

if rank == 0:
    # do_something(data)
3df52oht

3df52oht1#

all_gather()调用之后添加torch.distributed.barrier()调用,以更令人满意的方式解决了这个问题。我没有想到要这样做,因为文档中指出all_gather()是一个阻塞调用。也许它们的意思是阻塞,如not async;与torch.distributed不同。
我认为,将结果记录并写入文件“修复”问题而sort没有的原因是,前者不是强制同步的torch操作(因此,不由分布式进程组管理)。

相关问题