一旦函数值低于阈值,立即终止scipy minimize

ttygqcqt  于 2022-11-09  发布在  其他
关注(0)|答案(3)|浏览(179)

我有一个函数,其中的一些参数将返回一个概率。我如何设置scipyminimize,使其在找到一些参数时立即终止,这些参数将返回一个低于某个阈值的概率(即使它是一个“大”概率,如0.1左右)?
多谢了!

ttvkxqim

ttvkxqim1#

第一个回答是:这取决于您使用的底层求解器。大多数时候,SciPy只是围绕其他语言中的高效实现进行 Package (例如Fortran中的SLSQP)。
trust-constr则不是这样,它是用Python实现的,它允许回调返回True来停止优化过程。更多细节请参见the callback argument of scipy.optimize.minimize的文档。
对于其他求解器,实现所需结果的最直接方法是实现自己的异常,类似于Andrew Nelson。您将无法获得求解器的内部状态,但Python脚本可以继续运行,并且在每个候选点只对函数求值一次。
下面是一个使用Nelder-Mead单纯形下坡算法的可重现示例:

from scipy.optimize import minimize
from numpy import inf

class Trigger(Exception):
    pass

class ObjectiveFunctionWrapper:

    def __init__(self, fun, fun_tol=None):
        self.fun = fun
        self.best_x = None
        self.best_f = inf
        self.fun_tol = fun_tol or -inf
        self.number_of_f_evals = 0

    def __call__(self, x):
        _f = self.fun(x)

        self.number_of_f_evals += 1

        if _f < self.best_f:
            self.best_x, self.best_f = x, _f

        return _f

    def stop(self, *args):
        if self.best_f < self.fun_tol:
            raise Trigger

if __name__ == "__main__":

    def f(x):
        return sum([xi**2 for xi in x])

    fun_tol = 1e-4
    f_wrapped = ObjectiveFunctionWrapper(f, fun_tol)

    try:
        minimize(
            f_wrapped,
            [10] * 5,  # problem dimension is 5, x0 is [1, ..., 1],
            method="Nelder-Mead",
            callback=f_wrapped.stop
        )
    except Trigger:
        print(f"Found f value below tolerance of {fun_tol}\
            in {f_wrapped.number_of_f_evals} f-evals:\
            \nx = {f_wrapped.best_x}\
            \nf(x) = {f_wrapped.best_f}")
    except Exception as e:  # catch other errors
        raise e

输出量:

Found f value below tolerance of 0.0001            in 239 f-evals:            
x = [ 0.00335493  0.00823628 -0.00356564 -0.00126547  0.00158183]            
f(x) = 9.590933918640515e-05
p8h8hvxi

p8h8hvxi2#

您可以使用回呼参数来最小化。这是一个函数,在每次最小化迭代时都会被呼叫。您可以使用这个参数来检查函数的值,如果它低于临界值,则终止最小化。

hivapdat

hivapdat3#

这有点混乱,但我会使用类似下面的内容来 Package 目标函数:

import numpy as np
class fun_tracker:
    def __init__(self, fun, fatol=None):
        self.fatol = fatol or -np.inf
        self.fun = fun
        self.bestx = None
        self.bestval = np.inf
        self.val = None

    def __call__(self, x, *args):
        self.val = self.fun(np.asarray(x), *args)
        if self.val < self.bestval:
            self.bestx, self.bestval = x, self.val

        if self.val < self.fatol:
            raise StopIteration
        else:
            return self.val

def quad(x):
    return np.sum(x**2)

相关问题