numpy 第一列滚动1,第二列滚动2,依此类推

t5fffqht  于 2022-12-13  发布在  其他
关注(0)|答案(2)|浏览(119)

我在numpy中有一个数组,我想把第一列滚动1,第二列滚动2,以此类推。
下面是一个例子。

>>> x = np.reshape(np.arange(15), (5, 3))
>>> x
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14]])

我想做的事:

>>> y = roll(x)
>>> y
array([[12, 10,  8],
       [ 0, 13, 11],
       [ 3,  1, 14],
       [ 6,  4,  2],
       [ 9,  7,  5]])

最好的方法是什么?
真实的的数组将是非常大的。我使用cupy,GPU版本的numpy。我会更喜欢解决方案最快的GPU,但当然,任何想法都是受欢迎的。

balp4ylt

balp4ylt1#

您可以使用高级索引:

import numpy as np

x = np.reshape(np.arange(15), (5, 3))

h, w = x.shape

rows, cols = np.arange(h), np.arange(w)
offsets = cols + 1
shifted = np.subtract.outer(rows, offsets) % h

y = x[shifted, cols]

Y:

array([[12, 10,  8],
       [ 0, 13, 11],
       [ 3,  1, 14],
       [ 6,  4,  2],
       [ 9,  7,  5]])
wn9m85ua

wn9m85ua2#

我实现了一个简单的解决方案(roll_for),并将其与@Chrysophylaxs的解决方案(roll_indexing)进行了比较。
结论:roll_indexing对于小阵列更快,但当阵列变大时,差异缩小,最终对于非常大的阵列比roll_for慢。
实施:

import numpy as np

def roll_for(x, shifts=None, axis=-1):
    if shifts is None:
        shifts = np.arange(1, x.shape[axis] + 1)  # OP requirement
    xt = x.swapaxes(axis, 0)  # https://stackoverflow.com/a/31094758/13636407
    yt = np.empty_like(xt)
    for idx, shift in enumerate(shifts):
        yt[idx] = np.roll(xt[idx], shift=shift)
    return yt.swapaxes(0, axis)

def roll_indexing(x):
    h, w = x.shape
    rows, cols = np.arange(h), np.arange(w)
    offsets = cols + 1
    shifted = np.subtract.outer(rows, offsets) % h  # fix
    return x[shifted, cols]

测试项目:

M, N = 5, 3
x = np.arange(M * N).reshape(M, N)
expected = np.array([[12, 10, 8], [0, 13, 11], [3, 1, 14], [6, 4, 2], [9, 7, 5]])

assert np.array_equal(expected, roll_for(x))
assert np.array_equal(expected, roll_indexing(x))

M, N = 100, 200
# roll_indexing did'nt work when M < N before fix
x = np.arange(M * N).reshape(M, N)
assert np.array_equal(roll_for(x), roll_indexing(x))

性能指标评测:

M, N = 100, 100
x = np.arange(M * N).reshape(M, N)
assert np.array_equal(roll_for(x), roll_indexing(x))
%timeit roll_for(x)       # 859 µs ± 2.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit roll_indexing(x)  # 81 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


M, N = 1_000, 1_000
x = np.arange(M * N).reshape(M, N)
assert np.array_equal(roll_for(x), roll_indexing(x))
%timeit roll_for(x)       # 12.7 ms ± 56.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit roll_indexing(x)  # 12.4 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


M, N = 10_000, 10_000
x = np.arange(M * N).reshape(M, N)
assert np.array_equal(roll_for(x), roll_indexing(x))
%timeit roll_for(x)       # 1.3 s ± 6.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit roll_indexing(x)  # 1.61 s ± 4.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

相关问题