numpy JAX在基于整数数组的条件函数求值中的有效使用

6vl6ewon  于 2023-02-12  发布在  其他
关注(0)|答案(1)|浏览(134)

我希望基于整数数组和其他以真实的作为函数输入的数组来高效地执行条件函数求值。我希望找到一种基于JAX的解决方案,它能提供比下面描述的for循环方法更显著的性能改进:

import jax
from jax import vmap;
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]
g_i_jit = [jax.jit(func) for func in g_i]

def g_git(i, x, y, z, u):
    return g_i_jit[i](x=x, y=y, z=z, u=u)

def g(i, x, y, z, u):
    return g_i[i](x=x, y=y, z=z, u=u)

len_xyz = 3000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))

len_u = 1000
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)
u_ar = u_1.reshape(len_u, len_xyz)

len_i = 50
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval= len(g_i)) #related to g_range-1

total = jnp.zeros((len_u, len_xyz))

for i in range(len_i):
    total= total + g_git(i_ar[i], x_ar, y_ar, z_ar, u_ar)

“i_ar”的作用是充当从列表g_i中选择四个函数之一的索引。“i_ar”是整数数组,每个整数表示g_i列表中的一个索引。另一方面,x_ar、y_ar、z_ar和u_ar是真实的数组,它们是i_ar所选函数的输入。
我怀疑i_ar与x_ar、y_ar、z_ar和u_ar之间本质上的差异是很难找到一种JAX方法来更有效地替换上面的for循环的原因。有什么想法吗?如何使用JAX(或其他东西)来替换foor循环,以更有效地获取'total'?
我曾天真地尝试过使用vmap:

g_git_vmap = jax.vmap(g_git)
total = jnp.zeros((len_u, len_xyz))
total = jnp.sum(g_git_vmap(i_ar, x_ar, y_ar, z_ar, u_ar), axis=0)

但是这导致错误消息并且没有结果。

lsmd5eda

lsmd5eda1#

最好的方法可能是使用lax.switch,它允许基于索引数组在多个函数之间动态切换。
以下是您的原始函数与基于lax.switch的方法的比较,在Colab GPU运行时上的计时:

def f_original(i, x, y, z, u):
  total = jnp.zeros((len(u), len(x)))
  for i in range(len(i)):
    total= total + g_git(i_ar[i], x, y, z, u)
  return total

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
  return jax.vmap(g)(i).sum(0)

out1 = f_original(i_ar, x_ar, y_ar, z_ar, u_ar)
out2 = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
np.testing.assert_allclose(out1, out2, rtol=5E-3)

%timeit f_original(i_ar, x_ar, y_ar, z_ar, u_ar).block_until_ready()
# 71 ms ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit f_switch(i_ar, x_ar, y_ar, z_ar, u_ar).block_until_ready()
# 4.69 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

相关问题