Numpy:检查3D数组中的任何数组是否在另一个具有重复的较短3D数组中

k7fdbhmy  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(125)

我有一个像这样的Numpy数组:

source = np.array([[[0,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,1],[1,1,0],[1,1,1]]])

字符串
我试着将它与另一个数组进行比较,它的Axis2较短,而Axis3中有重复:

values = np.array([[[0,1,0],[1,0,0],[1,1,1],[1,1,1],[0,1,0]]])


我的目标是有一个布尔数组,只要最长的:

[False, False,True,True,False,False,True]


我试过这些命令:

np.isin(source,values).all(axis=2)


但是它显示了一个七个True的数组。像numpy.in1d()这样的函数似乎是一个很好的选择,但是我没有实现将它用于3D数组。

lstz6jyr

lstz6jyr1#

单程:

np.in1d(np.apply_along_axis(''.join, 2, source.astype(str)), 
        np.apply_along_axis(''.join, 2, values.astype(str)))

array([False, False,  True,  True, False, False,  True])

字符串
另一种方式,虽然可能是内存密集型的:

(source.transpose(1,0,2) == values).all(2).any(1)
array([False, False,  True,  True, False, False,  True])

dfddblmv

dfddblmv2#

解决方案1:将值中的每一行与源中的每一行进行比较

result = (source[:,:,:] == values[:,:,None]).all(axis=-1).any(axis=1)[0]

字符串
解决方案2:转换it(n,1)形状

source_2d = source.squeeze(0)
values_2d = values.squeeze(0)
dtype = [(f'{i}', source_2d.dtype) for i in range(source_2d.shape[1])]
source_struct = source_2d.view(dtype)
values_struct = values_2d.view(dtype)
result = np.isin(source_struct, values_struct)
print(result.squeeze(1).tolist())

相关问题