我有一个由args
参数化的函数族
f(x, args)
字符串
并希望确定args
的N = 1000
值的f
在x
上的最小值。我可以访问函数及其导数。我的第一次尝试是循环args
的不同值,并在每次迭代时使用scipy.optimizer,但是它太长了。我相信操作可以通过矢量化来加速。我的下一个尝试是在jax.scipy.optimize.minimize
或jaxopt.ScipyMinimize
中使用jax.vmap
,但我似乎不能为args
传递多个值。
或者,我可以编写自己的向量化优化方法,例如二分法,其中向量化的意思是在数组上进行固定次数的迭代,并且如果其中一个优化问题已经提前达到一定的容错水平,则不会提前停止。
我希望使用一些已经优化,现成的算法,如果一个实现是在this线程是相关的,但args
没有改变。
1条答案
按热度按时间y3bcpkx11#
您可以定义一个函数来找到给定特定
args
的最小值,然后将其 Package 在jax.vmap
中以自动对其进行向量化。例如:字符串