numpy 如何反向索引二维数组

sf6xfgos  于 2023-01-13  发布在  其他
关注(0)|答案(3)|浏览(148)

我有一个2d MxN数组A,其中的每一行都是一个索引序列,在末尾用-1填充,例如:

[[ 2 1 -1 -1 -1]
 [ 1 4  3 -1 -1]
 [ 3 1  0 -1 -1]]

我有另一个浮点值的MxN数组B:

[[ 0.7 0.4 1.5 2.0 4.4 ]
 [ 0.8 4.0  0.3 0.11 0.53]
 [ 0.6 7.4  0.22 0.71 0.06]]

我想使用A中的索引来过滤B,即对于每一行,只有A中存在的索引保留其值,所有其他位置的值都设置为0.0,即结果如下所示:

[[ 0.0 0.4 1.5 0.0 0.0 ]
 [ 0.0 4.0  0.0 0.11 0.53 ]
 [ 0.6 7.4  0.0 0.71 0.0]]

在“pure”numpy中有什么好的方法来做这个呢?(我想在pure numpy中做这个,这样我就可以在jax中jit它了。

8zzbczxx

8zzbczxx1#

可以使用broadcasting,但注意它会创建一个(M, N, N)形状的大型中间数组(至少在纯numpy中):

import numpy as np

A = ...
B = ...

M, N = A.shape

out = np.where(np.any(A[..., None] == np.arange(N), axis=1), B, 0.0)

输出:

array([[0.  , 0.4 , 1.5 , 0.  , 0.  ],
       [0.  , 4.  , 0.  , 0.11, 0.53],
       [0.6 , 7.4 , 0.  , 0.71, 0.  ]])
jgovgodb

jgovgodb2#

Numpy支持花哨的索引,暂时忽略“-1”条目,你可以这样做:

index = (np.arange(B.shape[0]).reshape(-1, 1), A)
result = np.zeros_like(B)
result[index] = B[index]

这是因为索引是广播的,列np.arange(B.shape[0]).reshape(-1, 1)A中给定行的所有元素与Bresult中对应的行进行匹配。
此示例未解决-1是有效numpy索引这一事实。当4(最后一列)不在A行中时,需要清除该行中与-1对应的元素:

mask = (A == -1).any(axis=1) & (A != A.shape[1] - 1).all(axis=1)
result[mask, -1] = 0.0

这里,掩码是[True, False, True],表示即使第二行中有-1,它也包含4
这种方法相当有效,它只会为掩码创建两个与A形状相同的布尔数组。

1rhkuytd

1rhkuytd3#

另一种可能的解决方案:

maxr = np.max(A, axis=1)
A = np.where(A == -1, maxr.reshape(-1,1), A)
mask = np.zeros(np.shape(B), dtype=bool)
np.put_along_axis(mask, A, True, axis=1) 
np.where(mask, B, 0)

输出:

array([[0.  , 0.4 , 1.5 , 0.  , 0.  ],
       [0.  , 4.  , 0.  , 0.11, 0.53],
       [0.6 , 7.4 , 0.  , 0.71, 0.  ]])

相关问题