比较多个numpy数组

holgip5t  于 2022-12-23  发布在  其他
关注(0)|答案(6)|浏览(142)

我应该如何比较超过2个numpy数组?

import numpy 
a = numpy.zeros((512,512,3),dtype=numpy.uint8)
b = numpy.zeros((512,512,3),dtype=numpy.uint8)
c = numpy.zeros((512,512,3),dtype=numpy.uint8)
if (a==b==c).all():
     pass

这给予了一个valueError,我对一次比较两个数组不感兴趣。

s1ag04yj

s1ag04yj1#

对于三个数组,你可以检查第一个和第二个数组,然后是第二个和第三个数组的对应元素是否相等,给予两个布尔标量,最后看看这两个标量是否都是True,作为最终的标量输出,就像这样--

np.logical_and( (a==b).all(), (b==c).all() )

对于更多数量的数组,你可以把它们堆叠起来,沿着堆叠的轴得到微分,然后检查这些微分是否所有都等于零。如果是,我们就在所有输入数组中得到相等,否则就不是。实现看起来像这样--

L = [a,b,c]    # List of input arrays
out = (np.diff(np.vstack(L).reshape(len(L),-1),axis=0)==0).all()
vltsax25

vltsax252#

对于三个数组,您实际上应该一次比较两个数组:

if np.array_equal(a, b) and np.array_equal(b, c):
    do_whatever()

对于一个可变数量的数组,假设它们都被合并成一个大数组arrays

if np.all(arrays[:-1] == arrays[1:]):
    do_whatever()
vqlkdk9b

vqlkdk9b3#

为了扩展前面的答案,我将使用itertools中的combinations来构造所有对,然后对每一对进行比较。例如,如果我有三个数组,并希望确认它们都相等,我将用途:

from itertools import combinations

for pair in combinations([a, b, c], 2):
    assert np.array_equal(pair[0], pair[1])
z9smfwbn

z9smfwbn4#

支持不同形状和名称的解决方案

与数组列表的第一个元素比较:

import numpy as np

a = np.arange(3)
b = np.arange(3)
c = np.arange(3)
d = np.arange(4)

lst_eq = [a, b, c]
lst_neq = [a, b, d]

def all_equal(lst):
    for arr in lst[1:]:
        if not np.array_equal(lst[0], arr, equal_nan=True):
            return False
    return True

print('all_equal(lst_eq)=', all_equal(lst_eq))
print('all_equal(lst_neq)=', all_equal(lst_neq))
    • 输出**
all_equal(lst_eq)= True
all_equal(lst_neq)= False

表示形状相同且无纳米支撑

将所有元素合并到一个数组中,计算沿新轴的绝对差异,并检查沿新维度的最大元素是否等于0或小于某个阈值,这应该是相当快的。

import numpy as np

a = np.arange(3)
b = np.arange(3)
c = np.arange(3)
d = np.array([0, 1, 3])

lst_eq = [a, b, c]
lst_neq = [a, b, d]

def all_equal(lst, threshold = 0):
    arr = np.stack(lst, axis=0)

    return np.max(np.abs(np.diff(arr, axis=0))) <= threshold

print('all_equal(lst_eq)=', all_equal(lst_eq))
print('all_equal(lst_neq)=', all_equal(lst_neq))
    • 输出**
all_equal(lst_eq)= True
all_equal(lst_neq)= False
vqlkdk9b

vqlkdk9b5#

这可能有用。

import numpy

x = np.random.rand(10)
arrays = [x for _ in range(10)]

print(np.allclose(arrays[:-1], arrays[1:]))  # True

arrays.append(np.random.rand(10))

print(np.allclose(arrays[:-1], arrays[1:]))  # False
7ajki6be

7ajki6be6#

一行程序:

arrays = [a, b, c]    
all([np.array_equal(a, b) for a, b in zip(arrays, arrays[1:])])

相关问题