numpy 如何在M x N数组的行中找到所有值的组合

bhmjp9jg  于 11个月前  发布在  其他
关注(0)|答案(1)|浏览(116)

如果题目不是很清楚,我很抱歉。我在构思这个问题时遇到了一些困难。
我有两个N X M维的numpy数组。为了简单起见,假设它们都有shape(2,10)。第一个数组由浮点数组成。例如:

[[0.1,0.02,0.2,0.3,0.013,0.7,0.7,0.11,0.18,0.6],
 [0.23,0.02,0.1,0.1,0.011,0.3,0.4,0.4,0.4,0.5]]

字符串

第二个数组由0和1组成。例如:

[[0,1,0,0,0,0,1,1,0,1],
 [1,0,1,0,0,0,0,1,1,0]]


我正在尝试做以下工作:对于第二个数组中唯一值的给定配置,选择第一个数组中该位置的元素。举个例子,我们在第二个数组中有两行,所以1和0的4种可能配置。即(1,1),(0,0),(1,0),(0,1)。如果我们采用(1,1)的情况(也就是说,第一行和第二行中的元素都等于第二个数组中的“1”),我想找到这些值,并在第一个数组中查找它们的位置。这将从第一个数组返回(0.11,0.4)。
再次道歉,如果这是没有明确沟通。感谢任何反馈。谢谢。

lg40wkob

lg40wkob1#

IIUC,你想在arr1中搜索arr2中列中所有1的位置:

arr1 = np.array(
    [
        [0.1, 0.02, 0.2, 0.3, 0.013, 0.7, 0.7, 0.11, 0.18, 0.6],
        [0.23, 0.02, 0.1, 0.1, 0.011, 0.3, 0.4, 0.4, 0.4, 0.5],
    ]
)

arr2 = np.array([[0, 1, 0, 0, 0, 0, 1, 1, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 1, 0]])

out = arr1[:, np.all(arr2, axis=0)]
print(out)

字符串
印刷品:

[[0.11]
 [0.4 ]]


如果你想找到所有的组合:

unique = np.unique(arr2.T, axis=0)

for row in unique:
    print("Combination:")
    print(row)
    print()
    print(arr1[:, np.all(arr2 == row.reshape(arr2.shape[0], -1), axis=0)])
    print("-" * 80)


印刷品:

Combination:
[0 0]

[[0.3   0.013 0.7  ]
 [0.1   0.011 0.3  ]]
--------------------------------------------------------------------------------
Combination:
[0 1]

[[0.1  0.2  0.18]
 [0.23 0.1  0.4 ]]
--------------------------------------------------------------------------------
Combination:
[1 0]

[[0.02 0.7  0.6 ]
 [0.02 0.4  0.5 ]]
--------------------------------------------------------------------------------
Combination:
[1 1]

[[0.11]
 [0.4 ]]
--------------------------------------------------------------------------------

相关问题