numpy JAXgrad:对矩阵中的特定变量进行推导

wkftcu5l  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(190)

我使用Jax来做矩阵的grad。例如,我有一个函数f(A),其中A是一个像A = \[\[a,b\], \[c,d\]\]的矩阵。我想只做a,c和d的f(A)的梯度(更具体地说,对于下三角形部分)。我怎么做呢?也适用于一般的NxN矩阵,而不仅仅是2x2。
我试着将常规grad转换为下三角形,但我不确定这是否相同,如果输出是正确的。

332nm8kg

332nm8kg1#

JAX没有提供任何方法来获取单个矩阵元素的梯度。有两种方法可以继续;第一,你可以获取整个数组的梯度并提取你感兴趣的元素;例如:

import jax
import jax.numpy as jnp

def f(A):
  return (A ** 2).sum()

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
df_dA = jax.grad(f)(A)
print(df_dA[0, 0], df_dA[0, 1], df_dA[1, 2])

个字符
或者,您可以将数组的条目拆分为单独的函数参数,然后使用argnums仅针对您感兴趣的那些进行梯度:

def f(a, b, c, d):
  A = jnp.array([[a, b], [c, d]])
  return (A ** 2).sum()

df_da, df_db, df_dc = jax.grad(f, argnums=(0, 1, 2))(1.0, 2.0, 3.0, 4.0)
print(df_da, df_db, df_dc)
2.0 4.0 8.0

的字符串
一般来说,你可能会发现第一种方法在实践中更容易使用,而且效率更高。它确实有一些浪费的计算,但坚持矢量化计算通常会带来净收益,特别是如果你在GPU或TPU等加速器上运行。

相关问题