沿NumPy ndarray轴的可变折射率切片

ego6inou  于 2022-11-10  发布在  其他
关注(0)|答案(2)|浏览(116)

给定形状为(n, m)的ndarray ar,我想沿着长度为k的轴1用k<m提取子序列。在长度为k的子序列的已知起始索引start的情况下,这可以用new_ar = ar[:, start:end](或仅用start:start+k)来解决。
但是,如果我有一个列表start_list和一个长度为nend_list(或者只有start_list,因为子序列的长度无论如何都是已知的),它包含我想要提取的子序列的开始索引(和结束索引),该怎么办?直觉上,我尝试了ar[:, start_list:end_list],但这会抛出TypeError: slice indices must be integers or None or have an __index__ method
如果不使用循环并利用NumPys方法,这个问题**会有什么解决方案?对于我的问题,for循环花了30分钟,但这必须有一个NumPy风格的5ms解决方案,因为它只是索引。

[编辑]:由于使用代码可能会更好地理解问题(谢谢您的提示),我将尝试使其更紧凑,并通过循环显示我为解决问题所做的工作。

我有一个形状为(40450, 200000)的ndarray,代表每个长度为20000040450信号。信号发生了变化,我想让它们对齐。所以我想从每个40450序列中提取长度为190000的子序列。为此,我有一个长度为40450的列表start_list,包含子序列的起始索引(我要提取的每个40450子序列在长度为200000的原始序列中具有不同的起点)。
我可以使用for循环来解决这个问题(ar包含原始序列,start_list包含起始索引):

k = 190000
ar_new = np.zeros((40450, k))
for i in range(ar_new.shape[0]):
    ar_new[i] = ar[i, start_list[i]:start_list[i]+k]

例如,如果start_list[0]0,这意味着我需要ar[0, 0:190000],如果start_list[10000]1337,这意味着我需要ar[10000, 1337:1337+190000]等等。
但对于我的情况,这需要超过30分钟,我相信通过NumPy内置方法/一些切片魔法可以以某种方式解决它。

wmvff8tz

wmvff8tz1#

在经历了一些考验之后

In [14]: a = np.array(range(200000), dtype=float)
    ...: b = np.array(range(200000), dtype=float)
    ...: start, k = 100, 190000

In [15]: %timeit for _ in range(1000): a[:k] = a[s:s+k]
26.4 ms ± 9.04 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [16]: %timeit for _ in range(1000): b[:k] = a[s:s+k]
44.8 ms ± 902 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

我在想①如果你可以没有非对齐的数据,覆盖看起来更快②无论如何,如果进程包含在内存中,我希望我的结果在1?10秒的范围内,而不是30分钟③如果你的问题是交换,覆盖避免分配大约4*4E4*2E5 ⇒ 32E9字节的内存。

mspsb9vt

mspsb9vt2#

我们可以将原始的2维阵列视为3维结构。在as_strided的帮助下,我们可以创建数组的3D视图,其中第一个维度等于原始维度,第二个维度用于迭代子行的可能开始位置,第三个维度用于迭代子行中的值:

from numpy.lib.stride_tricks import as_strided

# test data

n, m = 5, 10
arr = np.arange(n*m).reshape(n, m)
k = 5
start_list = [0, 1, 2, 1, 0]

# main code

n, m = arr.shape
isize = arr.dtype.itemsize
x = 1 + m - k    # a supporting intermediate dimension

assert k < m
assert len(start_list) == n
assert all(0 <= i < x for i in start_list)

# create a view to the original data with modified shape and strides

arr_modified = as_strided(arr, shape=(n,x,k), strides=(m*isize, isize, isize))

# from each row in arr select a k-length part

# starting from the corresponding item in start_list

arr_new = arr_modified[range(n), start_list]

另见:

相关问题