我想训练一个简单的线性模型。下面的x和y是我的数据。
import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)
f是计算所有数据均方误差的函数。
def f(params, x, y):
return np.mean(np.power((params['w'] * x + params['b'])-y , 2))
from jax import grad
df = grad(f)
params = dict()
# initialize parameters
params['w'] = 2.4
params['b'] = 10.
df(params, x, y) # I will do this in a loop (implementing gradient decent part
这给了我一个错误:
FilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray
当我离开的时候 np.power
代码起作用。为什么?
1条答案
按热度按时间vbkedwbf1#
jax无法计算
numpy
函数,但它可以计算jax.numpy
功能。如果你重写你的代码jax.numpy
,它应该适合您:你可以在
TracerArrayConversionError
文档页。