scipy 向量化极小化与求根

rhfm7lfc  于 2023-11-19  发布在  其他
关注(0)|答案(1)|浏览(99)

我有一个由args参数化的函数族

f(x, args)

字符串
并希望确定argsN = 1000值的fx上的最小值。我可以访问函数及其导数。我的第一次尝试是循环args的不同值,并在每次迭代时使用scipy.optimizer,但是它太长了。我相信操作可以通过矢量化来加速。我的下一个尝试是在jax.scipy.optimize.minimizejaxopt.ScipyMinimize中使用jax.vmap,但我似乎不能为args传递多个值。
或者,我可以编写自己的向量化优化方法,例如二分法,其中向量化的意思是在数组上进行固定次数的迭代,并且如果其中一个优化问题已经提前达到一定的容错水平,则不会提前停止。
我希望使用一些已经优化,现成的算法,如果一个实现是在this线程是相关的,但args没有改变。

y3bcpkx1

y3bcpkx11#

您可以定义一个函数来找到给定特定args的最小值,然后将其 Package 在jax.vmap中以自动对其进行向量化。例如:

import jax
import jax.numpy as jnp
from jax.scipy import optimize

def f(x, args):
  a, b = args
  return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
  x0 = jnp.array([1.0])
  args = (a, b)
  return optimize.minimize(f, x0, (args,), method="BFGS")

a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))

results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())

print(results.success)
# [ True  True  True  True  True  True  True  True  True  True  True  True
#   True  True  True  True  True  True  True  True  True  True  True  True
#   True]

print(results.x.T)
# [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
#   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]

字符串

相关问题