我使用Jax来做矩阵的grad。例如,我有一个函数f(A),其中A是一个像A = \[\[a,b\], \[c,d\]\]的矩阵。我想只做a,c和d的f(A)的梯度(更具体地说,对于下三角形部分)。我怎么做呢?也适用于一般的NxN矩阵,而不仅仅是2x2。我试着将常规grad转换为下三角形,但我不确定这是否相同,如果输出是正确的。
A = \[\[a,b\], \[c,d\]\]
332nm8kg1#
JAX没有提供任何方法来获取单个矩阵元素的梯度。有两种方法可以继续;第一,你可以获取整个数组的梯度并提取你感兴趣的元素;例如:
import jax import jax.numpy as jnp def f(A): return (A ** 2).sum() A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) df_dA = jax.grad(f)(A) print(df_dA[0, 0], df_dA[0, 1], df_dA[1, 2])
个字符或者,您可以将数组的条目拆分为单独的函数参数,然后使用argnums仅针对您感兴趣的那些进行梯度:
argnums
def f(a, b, c, d): A = jnp.array([[a, b], [c, d]]) return (A ** 2).sum() df_da, df_db, df_dc = jax.grad(f, argnums=(0, 1, 2))(1.0, 2.0, 3.0, 4.0) print(df_da, df_db, df_dc)
2.0 4.0 8.0
的字符串一般来说,你可能会发现第一种方法在实践中更容易使用,而且效率更高。它确实有一些浪费的计算,但坚持矢量化计算通常会带来净收益,特别是如果你在GPU或TPU等加速器上运行。
1条答案
按热度按时间332nm8kg1#
JAX没有提供任何方法来获取单个矩阵元素的梯度。有两种方法可以继续;第一,你可以获取整个数组的梯度并提取你感兴趣的元素;例如:
个字符
或者,您可以将数组的条目拆分为单独的函数参数,然后使用
argnums
仅针对您感兴趣的那些进行梯度:的字符串
一般来说,你可能会发现第一种方法在实践中更容易使用,而且效率更高。它确实有一些浪费的计算,但坚持矢量化计算通常会带来净收益,特别是如果你在GPU或TPU等加速器上运行。