我希望基于整数数组和其他以真实的作为函数输入的数组来高效地执行条件函数求值。我希望找到一种基于JAX的解决方案,它能提供比下面描述的for循环方法更显著的性能改进:
import jax
from jax import vmap;
import jax.numpy as jnp
import jax.random as random
def g_0(x, y, z, u):
return x + y + z + u
def g_1(x, y, z, u):
return x * y * z * u
def g_2(x, y, z, u):
return x - y + z - u
def g_3(x, y, z, u):
return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]
g_i_jit = [jax.jit(func) for func in g_i]
def g_git(i, x, y, z, u):
return g_i_jit[i](x=x, y=y, z=z, u=u)
def g(i, x, y, z, u):
return g_i[i](x=x, y=y, z=z, u=u)
len_xyz = 3000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))
len_u = 1000
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)
u_ar = u_1.reshape(len_u, len_xyz)
len_i = 50
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval= len(g_i)) #related to g_range-1
total = jnp.zeros((len_u, len_xyz))
for i in range(len_i):
total= total + g_git(i_ar[i], x_ar, y_ar, z_ar, u_ar)
“i_ar”的作用是充当从列表g_i中选择四个函数之一的索引。“i_ar”是整数数组,每个整数表示g_i列表中的一个索引。另一方面,x_ar、y_ar、z_ar和u_ar是真实的数组,它们是i_ar所选函数的输入。
我怀疑i_ar与x_ar、y_ar、z_ar和u_ar之间本质上的差异是很难找到一种JAX方法来更有效地替换上面的for循环的原因。有什么想法吗?如何使用JAX(或其他东西)来替换foor循环,以更有效地获取'total'?
我曾天真地尝试过使用vmap:
g_git_vmap = jax.vmap(g_git)
total = jnp.zeros((len_u, len_xyz))
total = jnp.sum(g_git_vmap(i_ar, x_ar, y_ar, z_ar, u_ar), axis=0)
但是这导致错误消息并且没有结果。
1条答案
按热度按时间lsmd5eda1#
最好的方法可能是使用
lax.switch
,它允许基于索引数组在多个函数之间动态切换。以下是您的原始函数与基于
lax.switch
的方法的比较,在Colab GPU运行时上的计时: