检查numpy数组是否改变符号或保持相同的符号

gblwokeq  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(100)

我有一个二维数组。我想检查数组中的列是否从左到右改变符号,或者它们保持相同的符号。如果它们确实改变了符号,我想如果它们变成正数然后变成负数就分配负数,反之亦然分配正数,如果它们变成负数然后变成正数。如果它们只保持正数,那么要分配正数,如果它们只保持负数,我想赋值负数。为了简单起见,假设列不包含零。没有太多的列,但有很多行。如何有效地做到这一点?

qmelpv7a

qmelpv7a1#

您应该逐列进行,以利用numpy的矢量化功能(也称为SIMD)。首先,将整个numpy数组转换为-1/1符号,以仅保留所需信息的必要位。然后在单个操作中计算连续的逐列差,并将其设置为1,0,或-1取决于差异。你可以更聪明一点,减少最后两个步骤在一个单一的。这里是一个片段演示如何做到这一点:

import numpy as np

# Create fake data (ranging from -1 to -1)
A = 2.0 * (np.random.rand(3, 4) - 0.5)

# First step: get the sign, scaled by a factor 0.5
A_sign = 0.5 * np.sign(A)

# Second step: compute the successive col-by-col difference at once
A_sign_diff = np.diff(A_sign, axis=1)

# Third step: Formatting (padding on the left and convert to int8)
# If you want floats, just remove args `casting="unsafe", dtype=np.int8`
R = np.concatenate(
    (np.zeros((len(A), 1)), A_sign_diff), axis=1, 
    casting="unsafe", dtype=np.int8)

个字符
[*]贝内酒店**:如果你需要额外的速度,我建议你看看numba,它提供了一个非常方便的JIT装饰器,与numpy兼容,非常适合这样的小数学代码。期望高达x10的加速。Here是文档!

相关问题