numpy 如何用Numba在三维数组中找到每个单元格的极值?

iyfamqjs  于 2023-10-19  发布在  其他
关注(0)|答案(2)|浏览(78)

我最近写了一个脚本来转换BGR数组的[0,1]浮点数到HSL和回来。我把它放在Code Review上。目前有一个答案,但它不会提高性能。
我已经对我的代码进行了cv2.cvtColor基准测试,发现我的代码效率很低,所以我想用Numba编译代码,使它运行得更快。
我尝试用@nb.njit(cache=True, fastmath=True) Package 所有函数,但这不起作用。
所以我测试了我单独使用过的每个NumPy语法和NumPy函数,发现有两个函数不适用于Numba。
我需要找到每个像素的最大通道(np.max(img, axis=-1))和每个像素的最小通道(np.max(img, axis=-1)),axis参数不适用于Numba。
我试着在谷歌上搜索这个,但我发现的唯一相关的东西是this,但它只实现了np.anynp.all,并且只适用于二维数组,而这里的数组是三维的。
我可以写一个基于for循环的解决方案,但我不会写它,因为它注定是低效的,并且首先违背了使用NumPy和Numba的目的。
最小可重现性示例:

import numba as nb
import numpy as np

@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
    return np.max(arr, axis=-1)

@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
    return np.min(arr, axis=-1)

img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)

例外情况:

In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:

 >>> amax(array(float64, 3d, C), axis=Literal[int](-1))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
    With argument(s): '(array(float64, 3d, C), axis=int64)':
   Rejected as the implementation raised a specific error:
     TypingError: got an unexpected keyword argument 'axis'
  raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)

File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
    return np.max(arr, axis=-1)
    ^

如何解决这一问题?

von4xj4u

von4xj4u1#

不使用np.max(),而是使用循环来实现它是相当简单的:

@nb.njit()
def max_per_cell_nb(arr):
    ret = np.empty(arr.shape[:-1], dtype=arr.dtype)
    n, m = ret.shape
    for i in range(n):
        for j in range(m):
            max_ = arr[i, j, 0]
            max_ = max(max_, arr[i, j, 1])
            max_ = max(max_, arr[i, j, 2])
            ret[i, j] = max_
    return ret

对它进行基准测试,结果证明它比np.max(arr, axis=-1)快16倍。

%timeit max_per_cell_nb(img)
4.88 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit max_per_cell(img)
81 ms ± 654 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

在进行基准测试时,我做了以下假设:

  • 图像为1920x1080x3。(换句话说,这是一个大的图像。
  • 图像数组是C顺序而不是Fortran顺序。如果是Fortran顺序,我的方法的速度下降到7ms,np.max()的速度更快,只需要15 ms。参见Check if numpy array is contiguous?了解如何判断数组是C顺序还是Fortran顺序。np.random.random((3, 4, 3))的例子是C连续的。
  • 我将此函数与关闭Numba JIT的np.max(arr, axis=-1)进行比较,因为它无法真正优化对NumPy函数的单个调用。
yrdbyhpb

yrdbyhpb2#

CHW实现

根据@NickODell回答的评论,这里是一个更快的SIMD友好型解决方案,当图像使用 *CWH布局 * 时(根据@NickODell的要求):

import numpy as np
import numba as nb

@nb.vectorize('(float32, float32, float32)')
def vec_max(a, b, c):
    return max(a, b, c)

@nb.njit('(float32[:,:,::1], float32[:,::1])')
def max_per_cell_chw_nb_faster(a, b):
    vec_max(a[0], a[1], a[2], b)

# Benchmark

img1 = np.random.rand(1920, 1080, 3).astype(np.float32)
img2 = np.random.rand(3, 1920, 1080).astype(np.float32)
out = np.empty((1920, 1080), dtype=np.float32)

%timeit out = max_per_cell_nb(img1)
%timeit max_per_cell_chw_nb_faster(img2, out)

以下是我的机器上的结果(i5- 9600 KF + 40 GiB/s RAM在Windows上):

3 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.08 ms ± 6.08 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

这意味着这个实现在32位浮点数上快了3倍。前者是标量,后者使用SIMD指令。前者也会因页面错误而变慢,而后者则不会。SIMD版本如果不是内存受限的话,可能会更快。数据类型越小,SIMD实现越快。使用8位无符号整数,SIMD版本大约快9倍,* 仍然内存受限 *:

1.83 ms ± 9.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
210 µs ± 3.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

HWC实现

请注意,只需将max_per_cell_nb_faster的单行内容替换为以下内容即可生成@NickODell版本,以支持HWC映像:

@nb.njit('(float32[:,:,::1], float32[:,::1])')
def max_per_cell_hwc_nb_faster(a, b):
    vec_max(a[:,:,0], a[:,:,1], a[:,:,2], b)

然而,这个版本比@NickODell的解决方案慢一点(尽管更简单),因为它在内部不使用SIMD指令:

3.35 ms ± 50.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

事实上,LLVM优化器AFAIK还不能对这种访问模式进行向量化(因为这非常困难,而且即使手动小心地完成,它的效率通常也低于使用CWH)。

相关问题