我有一个2d MxN数组A,其中的每一行都是一个索引序列,在末尾用-1填充,例如:
[[ 2 1 -1 -1 -1]
[ 1 4 3 -1 -1]
[ 3 1 0 -1 -1]]
我有另一个浮点值的MxN数组B:
[[ 0.7 0.4 1.5 2.0 4.4 ]
[ 0.8 4.0 0.3 0.11 0.53]
[ 0.6 7.4 0.22 0.71 0.06]]
我想使用A中的索引来过滤B,即对于每一行,只有A中存在的索引保留其值,所有其他位置的值都设置为0.0,即结果如下所示:
[[ 0.0 0.4 1.5 0.0 0.0 ]
[ 0.0 4.0 0.0 0.11 0.53 ]
[ 0.6 7.4 0.0 0.71 0.0]]
在“pure”numpy中有什么好的方法来做这个呢?(我想在pure numpy中做这个,这样我就可以在jax中jit它了。
3条答案
按热度按时间8zzbczxx1#
可以使用broadcasting,但注意它会创建一个
(M, N, N)
形状的大型中间数组(至少在纯numpy中):输出:
jgovgodb2#
Numpy支持花哨的索引,暂时忽略“-1”条目,你可以这样做:
这是因为索引是广播的,列
np.arange(B.shape[0]).reshape(-1, 1)
将A
中给定行的所有元素与B
和result
中对应的行进行匹配。此示例未解决
-1
是有效numpy索引这一事实。当4
(最后一列)不在A
行中时,需要清除该行中与-1
对应的元素:这里,掩码是
[True, False, True]
,表示即使第二行中有-1
,它也包含4
。这种方法相当有效,它只会为掩码创建两个与
A
形状相同的布尔数组。1rhkuytd3#
另一种可能的解决方案:
输出: