我试图使用JAX来加速我的行和列比较/选择,我有一个2d数组NxN,每个单元格都是一个数字,我试图从一行中获得最高的4个数字,然后将1放入具有相同索引的不同矩阵C中。这是两个限制:每行和列最多可以有4个1。举一个更简单的例子:
array a = [[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15],
[16,17,18,19,20],
[21,22,23,24,25]]
C:
[[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[1,0,0,0,0]]
我们不能选择最后一行的最后4个单元格,即使这4个单元格的值最高,所以我们只能选择第一个单元格并将其放入C矩阵中
这是我尝试了一个不同的二维数组8x8和正确的C矩阵,仍然是一个例子,真实的的矩阵是4000 x4000,我的代码将需要10+分钟才能完成
a_array = jnp.array([[0,3,4,0,0,12,19,22],
[7,0,0,10,0,0,0,15],
[12,0,0,15,16,19,0,31],
[17,18,0,0,21,23,78,89],
[22,2,78,0,0,1111,12,33],
[123,0,122,10,14,0,50,60],
[10,110,0,1231,0,110,0,61],
[0,17,0,141,0,166,16,0]])
array_len = len(a_array)
c_matrix = jnp.zeros((array_len,array_len))
test_dict = {}
for i in range(array_len):
test_dict[i] = 0
for i in range(array_len):
for j in jnp.flip(jnp.argsort(jnp.array(a_array[i]))[-4:]):
if test_dict[int(j)] < 4:
if a_array[i][int(j)] != 0:
test_dict[int(j)] +=1
c_matrix = c_matrix.at[i,int(j)].set(1)
if test_dict[int(j)] == 4:
a_array = a_array.at[:,int(j)].set(0)
并且C矩阵为:
[[0. 0. 1. 0. 0. 1. 1. 1.]
[1. 0. 0. 1. 0. 0. 0. 1.]
[0. 0. 0. 1. 1. 1. 0. 1.]
[0. 0. 0. 0. 1. 1. 1. 1.]
[1. 0. 1. 0. 0. 1. 1. 0.]
[1. 0. 1. 0. 1. 0. 1. 0.]
[1. 1. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 1. 0. 0. 0. 0.]]
我在这里做的是首先在每列中有一个1的数量的dict跟踪,如果已经有4个1,更新2d数组,该列为全零,因此,当jnp.argsort(...)试图找到最高的4个单元格时,它不会考虑0,我还检查了cell == 0,以摆脱这种边缘情况
jnp.argsort(jnp.array([0,0,0,0,0,0,12,0]))[-4:]
输出:
[4, 5, 7, 6]
提前感谢大家。
1条答案
按热度按时间omjgkv6w1#
在Python中使用JAX、NumPy或类似的数组库编写代码时,一个很好的经验法则是,如果在数组值上编写循环,结果会很慢。相反,您应该尝试用本机向量化操作来表达逻辑。
在这里,您不能对整个操作进行向量化,因为每列数约束意味着每行的输出依赖于所有前面行的输出。在这种情况下,
lax.scan
是一个很好的选择。以下是我如何解决你的问题,并记住这些事情: