计算复数numpy ndarray的abs()**2的最节省内存的方法

new9mtju  于 2024-01-08  发布在  其他
关注(0)|答案(6)|浏览(144)

我正在寻找最节省内存的方法来计算复杂numpy ndarray的绝对平方值

arr = np.empty((250000, 150), dtype='complex128')  # common size

字符串
我还没有找到一个ufunc可以精确地执行np.abs()**2
由于这种大小和类型的数组占用大约半GB,我正在寻找一种主要是内存效率高的方法。
我也希望它是便携式的,所以理想的是一些ufuncs的组合。
到目前为止我的理解是这应该是最好的

result = np.abs(arr)
result **= 2


它将不必要地计算(**0.5)**2,但应该就地计算**2。总的来说,峰值内存需求仅为原始数组大小+结果数组大小,应为1.5 * 原始数组大小,因为结果是真实的。
如果我想摆脱无用的**2调用,我必须这样做

result = arr.real**2
result += arr.imag**2


但如果我没弄错的话,这意味着我必须为**真实的和虚部计算分配内存,因此内存使用的峰值将是2.0 * 原始数组大小。arr.real属性也返回一个非连续数组(但这不是那么重要)。
有什么我遗漏的吗?有更好的方法吗?

  • 编辑1*:很抱歉没有说清楚,我不想覆盖它,所以我不能使用它作为输出。
yvgpqqbh

yvgpqqbh1#

感谢numba最新版本中的numba.vectorize,为任务创建一个numpy通用函数非常容易:

@numba.vectorize([numba.float64(numba.complex128),numba.float32(numba.complex64)])
def abs2(x):
    return x.real**2 + x.imag**2

字符串
在我的机器上,我发现与创建中间数组的纯numpy版本相比,速度提高了三倍:

>>> x = np.random.randn(10000).view('c16')
>>> y = abs2(x)
>>> np.all(y == x.real**2 + x.imag**2)   # exactly equal, being the same operation
True
>>> %timeit np.abs(x)**2
10000 loops, best of 3: 81.4 µs per loop
>>> %timeit x.real**2 + x.imag**2
100000 loops, best of 3: 12.7 µs per loop
>>> %timeit abs2(x)
100000 loops, best of 3: 4.6 µs per loop

nfs0ujit

nfs0ujit2#

编辑:这个解决方案有两倍的最低内存需求,只是稍微快一点。2评论中的讨论是很好的参考。
这里有一个更快的解决方案,结果存储在res中:

import numpy as np
res = arr.conjugate()
np.multiply(arr,res,out=res)

字符串
其中我们利用了复数的绝对值的性质,即abs(z) = sqrt(z*z.conjugate),因此abs(z)**2 = z*z.conjugate

wtlkbnrh

wtlkbnrh3#

如果你的主要目标是节省内存,NumPy的ufuncs采用一个可选的out参数,让你将输出定向到你选择的数组。当你想就地执行操作时,它可能很有用。
如果你对第一个方法做了这个小小的修改,那么你就可以在arr上完全执行操作了:

np.abs(arr, out=arr)
arr **= 2

字符串
一种只使用 * 一点点 * 额外内存的复杂方法是修改arr,计算新的真实的值数组,然后恢复arr
这意味着存储符号的信息(除非你知道你的复数都有正的真实的和虚部)。每个真实的或虚值的符号只需要一个比特,所以这使用了arr的内存(除了你创建的新浮点数组)。

>>> signs_real = np.signbit(arr.real) # store information about the signs
>>> signs_imag = np.signbit(arr.imag)
>>> arr.real **= 2 # square the real and imaginary values
>>> arr.imag **= 2
>>> result = arr.real + arr.imag
>>> arr.real **= 0.5 # positive square roots of real and imaginary values
>>> arr.imag **= 0.5
>>> arr.real[signs_real] *= -1 # restore the signs of the real and imagary values
>>> arr.imag[signs_imag] *= -1


以存储符号位为代价,arr保持不变,result保持我们想要的值。

gojuced7

gojuced74#

arr.realarr.imag只是复杂数组的视图。因此没有分配额外的内存。

t5zmwmid

t5zmwmid5#

如果你不想要sqrt(应该比乘法重得多),那么就不要abs
如果你不想要双内存,那么没有real**2 + imag**2
那么你可以试试这个(使用索引技巧)

N0 = 23
np0 = (np.random.randn(N0) + 1j*np.random.randn(N0)).astype(np.complex128)
ret_ = np.abs(np0)**2
tmp0 = np0.view(np.float64)
ret0 = np.matmul(tmp0.reshape(N0,1,2), tmp0.reshape(N0,2,1)).reshape(N0)
assert np.abs(ret_-ret0).max()<1e-7

字符串
无论如何,我更喜欢numba解决方案

xqkwcwgp

xqkwcwgp6#

如果你想计算一个复杂的numpy数组(即abs()**2)的平方大小,你可以使用numpy.abs()函数来计算绝对值,然后使用numpy.square()函数来计算平方值。下面是一个例子:

import numpy as np

# Create a complex numpy array
complex_array = np.array([1 + 2j, 3 + 4j, 5 + 6j])

# Compute abs()**2
squared_magnitude = np.square(np.abs(complex_array))

print(squared_magnitude)

字符串
这个例子演示了如何有效地计算一个复杂的numpy数组的平方大小。np.abs()函数被应用于逐元素计算绝对值,然后np.square()被用于计算平方值。这种方法既简洁又节省内存。
请记住,np.abs()函数返回一个实值数组,因此后续的np.square()操作不涉及复数。如果您希望将结果保持为复数数组,则可以直接使用np.square(complex_array),因为对复数进行平方可以保留其复数性质。strong text

相关问题