为什么pytorch在m1 pro 10核心上比linux CPU慢?

mspsb9vt  于 2023-08-05  发布在  Linux
关注(0)|答案(1)|浏览(141)

bounty还有6天到期。回答此问题可获得+50声望奖励。piedpiper吸引更多的注意力这个问题:只是想有一些方法来加快我的m1 mac上的事情

我从here运行了以下基准测试。

#!/usr/bin/env python3

import torch

def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)

# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

字符串
我得到了

mul_sum(x, x):  1065.9 us
bmm(x, x):      134.5 us


在Mac上
和/或

mul_sum(x, x):   52.3 us
bmm(x, x):      120.1 us


在Linux CPU上
我看到了巨大的性能差异,这是预期的吗?
我第一次注意到这种差异是在一个更严肃的程序上,我正试图在这里复制它。

pw136qt2

pw136qt21#

性能差异可归因于明显影响计算执行的几个关键因素:
1.优化进度:PyTorch对Apple Silicon架构的适应仍在完善中。M1 Pro的优化可能不如Linux系统中常见的成熟x86_64 CPU的优化成熟。因此,代码执行效率当前可能倾向于更成熟的架构。
1.架构调整:PyTorch及其依赖的BLAS和LAPACK库针对特定的处理器架构进行了复杂的调整。对于M1 Pro的ARM架构和Linux CPU的典型x86_64架构,这些优化可能不会同样完善。因此,针对x86_64优化的操作可能无法在M1 Pro上高效执行。
1.指令集变体:指令集架构显著地影响各种操作的执行效率。基于ARM的处理器,如M1 Pro,与Linux系统中的传统x86_64处理器相比,部署了不同的指令集。指令集的这种差异可能固有地导致执行不同任务时的不同效率,从而影响所观察到的性能差异。

相关问题