问题:
- 分布式进程计算错误并将其与
float
索引一起返回 - 当从不同的等级中收集错误时,这些索引上会发生冲突
- 因此,如果数据集具有100个样本,并且GPU的数量为4,则所得到的索引集的长度将为25,而不是预期的100
- 当我将每个等级的数据(预收集)写入文件时,我可以验证索引是100%不相交的
- 当我将每个等级的数据(后收集)写入文件时,问题消失了
- 注解掉后收集调试数据文件写入,问题返回
注:打印后收集结果也可以“修复”该问题,但对后收集结果进行排序则不能。
因此,将后收集数据写入文件可以解决一些分布式问题。..我被提醒需要flush
流来避免意外的结果,但我在文档中没有看到任何类型的推论。
下面是一个最小的例子,展示了我的代码中发生了什么:
# setup_distributed_stuff()
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
# Data returned from distributed computation.
# Note that there's no overlap between the different ranks.
data = torch.arange(
0 + (rank * 100 // world_size),
(rank + 1) * 100 // world_size,
)
# `data` is confirmed to be disjoint across ranks by writing to file here.
# Gather data from all ranks.
if world_size > 1:
all_data = [torch.zeros_like(data) for _ in range(world_size)]
torch.distributed.all_gather(all_data, data)
data = torch.cat(all_data, dim=0)
# By writing "data" to file for debugging, the problem goes away...
# i.e. len(set(data.numpy())) == 100!
# If I comment this out, then my gathered data collides...
# i.e. len(set(data.numpy())) == 100 // world_size
with open("debug_data.pt", "wb") as _file:
torch.save(data, _file)
# I can also simply print the indices and get the same effect...
logger.info(
"Gathered result indices: {}...{}".format(
data[:10, -1], data[-10:, -1]
)
)
# However, sorting the indices doesn't do me any good...
data = data[data[:, -1].argsort(dim=0)]
if rank == 0:
# do_something(data)
1条答案
按热度按时间3df52oht1#
在
all_gather()
调用之后添加torch.distributed.barrier()
调用,以更令人满意的方式解决了这个问题。我没有想到要这样做,因为文档中指出all_gather()
是一个阻塞调用。也许它们的意思是阻塞,如notasync
;与torch.distributed
不同。我认为,将结果记录并写入文件“修复”问题而
sort
没有的原因是,前者不是强制同步的torch操作(因此,不由分布式进程组管理)。