numpy 在大型2d数组上使用JAX查找最大的n值

8hhllhi2  于 2023-05-29  发布在  其他
关注(0)|答案(1)|浏览(206)

我试图使用JAX来加速我的行和列比较/选择,我有一个2d数组NxN,每个单元格都是一个数字,我试图从一行中获得最高的4个数字,然后将1放入具有相同索引的不同矩阵C中。这是两个限制:每行和列最多可以有4个1。举一个更简单的例子:

array a = [[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15],
[16,17,18,19,20],
[21,22,23,24,25]]
C:
[[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[1,0,0,0,0]]

我们不能选择最后一行的最后4个单元格,即使这4个单元格的值最高,所以我们只能选择第一个单元格并将其放入C矩阵中
这是我尝试了一个不同的二维数组8x8和正确的C矩阵,仍然是一个例子,真实的的矩阵是4000 x4000,我的代码将需要10+分钟才能完成

a_array = jnp.array([[0,3,4,0,0,12,19,22],
            [7,0,0,10,0,0,0,15],
            [12,0,0,15,16,19,0,31],
            [17,18,0,0,21,23,78,89],
            [22,2,78,0,0,1111,12,33],
            [123,0,122,10,14,0,50,60],
            [10,110,0,1231,0,110,0,61],
            [0,17,0,141,0,166,16,0]])

array_len = len(a_array)
c_matrix = jnp.zeros((array_len,array_len))
test_dict = {}
for i in range(array_len):
  test_dict[i] = 0
for i in range(array_len):
  for j in jnp.flip(jnp.argsort(jnp.array(a_array[i]))[-4:]):
    if test_dict[int(j)] < 4:
      if a_array[i][int(j)] != 0:
        test_dict[int(j)] +=1
        c_matrix = c_matrix.at[i,int(j)].set(1)
    if test_dict[int(j)] == 4:
      a_array = a_array.at[:,int(j)].set(0)

并且C矩阵为:

[[0. 0. 1. 0. 0. 1. 1. 1.]
 [1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 0. 0. 1. 1. 1. 0. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1.]
 [1. 0. 1. 0. 0. 1. 1. 0.]
 [1. 0. 1. 0. 1. 0. 1. 0.]
 [1. 1. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 0.]]

我在这里做的是首先在每列中有一个1的数量的dict跟踪,如果已经有4个1,更新2d数组,该列为全零,因此,当jnp.argsort(...)试图找到最高的4个单元格时,它不会考虑0,我还检查了cell == 0,以摆脱这种边缘情况

jnp.argsort(jnp.array([0,0,0,0,0,0,12,0]))[-4:]

输出:

[4, 5, 7, 6]

提前感谢大家。

omjgkv6w

omjgkv6w1#

在Python中使用JAX、NumPy或类似的数组库编写代码时,一个很好的经验法则是,如果在数组值上编写循环,结果会很慢。相反,您应该尝试用本机向量化操作来表达逻辑。
在这里,您不能对整个操作进行向量化,因为每列数约束意味着每行的输出依赖于所有前面行的输出。在这种情况下,lax.scan是一个很好的选择。
以下是我如何解决你的问题,并记住这些事情:

import jax

def scan_fun(count, row):
  row = jnp.where(count >= 4, 0, row)
  _, indices = jax.lax.top_k(row, 4)
  c_row = jnp.zeros_like(row).at[indices].set(1)
  c_row = jnp.where(row == 0, 0, c_row)
  count += (c_row > 0)
  return count, c_row

_, c_matrix = jax.lax.scan(scan_fun, jnp.zeros_like(a_array[0]), a_array)
print(c_matrix)
# [[0 0 1 0 0 1 1 1]
#  [1 0 0 1 0 0 0 1]
#  [0 0 0 1 1 1 0 1]
#  [0 0 0 0 1 1 1 1]
#  [1 0 1 0 0 1 1 0]
#  [1 0 1 0 1 0 1 0]
#  [1 1 0 1 0 0 0 0]
#  [0 1 0 1 0 0 0 0]]

相关问题