为什么当我的函数中有np.power时jax.grad不能给我派生呢?

6mw9ycah  于 2021-07-13  发布在  Java
关注(0)|答案(1)|浏览(444)

我想训练一个简单的线性模型。下面的x和y是我的数据。

import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)

f是计算所有数据均方误差的函数。

def f(params, x, y):
  return np.mean(np.power((params['w'] * x + params['b'])-y , 2))
from jax import grad
df = grad(f)
params = dict()

# initialize parameters

params['w'] = 2.4
params['b'] = 10.
df(params, x, y) # I will do this in a loop (implementing gradient decent part

这给了我一个错误:

FilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray

当我离开的时候 np.power 代码起作用。为什么?

vbkedwbf

vbkedwbf1#

jax无法计算 numpy 函数,但它可以计算 jax.numpy 功能。如果你重写你的代码 jax.numpy ,它应该适合您:

import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)

import jax.numpy as jnp
def f(params, x, y):
  return jnp.mean(jnp.power((params['w'] * x + params['b'])-y , 2))

from jax import grad
df = grad(f)
params = dict()

params['w'] = 2.4
params['b'] = 10.
df(params, x, y)

# {'b': DeviceArray(14.661432, dtype=float32),

# 'w': DeviceArray(7.3792152, dtype=float32)}

你可以在 TracerArrayConversionError 文档页。

相关问题