scipy 处理给定到ODE求解器的函数的RHS的多个返回参数

lvmkulzt  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(99)

Python中的ODE求解器将函数的RHS作为参数。这个函数应该签名f(t, y *args)。通常,y可以是一个numpy数组,f将返回一个与y大小相同的数组ydot。然而,语法要求f应该只返回一个东西,并且只返回一个东西:ydot数组(或者float is y也是float)。ODE求解器在内部根据需要多次调用f,以达到收敛。
但是在我的例子中,这个f返回一个元组。不仅仅是ydot,还有更多的数组。现在我可以将这部分移到f之外,并在求解器达到收敛后调用它。但是我改变了行为,因为这些参数并没有在每个时间步计算时更新(这应该发生)。所以我需要这部分代码。但是像scipy.integrate_solve_ivp这样的求解器的语法要求f应该只返回ydot。因此,我没有办法“捕获”额外的参数,并在求解器再次调用它时传递回f
下面是一段代码,给出了大致的想法(注意,这段代码只是我的实际代码的替身,我不能分享,但要点是一样的,我的函数的RHS返回一个元组):

import numpy as np
from scipy.integrate import solve_ivp

# Define the RHS function
def f(t, y, param1, param2):
    # Calculate the derivative dy/dt
    ydot = np.sin(t) * y[0] + np.cos(t) * y[1]  # Example derivative equation

    # Calculate additional parameters
    param1 = some_func_1(ydot, param1)
    param2 = some_func_1(ydot, param2)

    # Return ydot and additional parameters as a tuple
    return ydot, param1, param2

# Define the time span and initial condition
t_span = (0, 5)
y0 = np.random.rand(2)  # Random initial condition

# Solve the ODE using solve_ivp
sol = solve_ivp(f, t_span, y0, args=(param1, param2))

字符串
some_func_1some_func_2是接受param1和param2并对其进行修改的函数。
我试着查看solve_ivpevents参数,但它不起作用。我想是为了别的事。
但这似乎是一个问题,一定有人遇到过,对不对?我可以看到的一个解决方法是使用全局变量,我不返回额外的参数,而是将它们存储在全局列表或全局变量或其他东西中。但我认为全局变量是危险的,因为它们可能会引入bug。所以我在寻找一种更接近的方法。
编辑:
我刚意识到一件事:argssolve_ivp是否可以自己执行功能?因此表达式变为:

sol = solve_ivp(f, t_span, y0, args=(some_func_1(param1), some_func_2(param2))


但即使在这里,通过some_func_1捕获返回的问题仍然存在。

plicqrtu

plicqrtu1#

小心地传递args对象给函数并修改它们是可能的。对象需要是可变的,比如数组或列表。
例如,以expoential_decay为例,让我们添加一种收集所有t值的方法。我将par定义为一个列表,并使用append就地修改它。我可以用数组做些什么,但这是最容易想到的事情。

In [316]: def exponential_decay(t, y): return -0.5 * y
     ...: def f(t,y,par):
     ...:     val = exponential_decay(t,y)
     ...:     par.append(t)
     ...:     return val
     ...: par = []
     ...: sol = integrate.solve_ivp(f, [0, 10], [2, 4, 8], args=(par,))
In [317]: sol
Out[317]: 
  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  1.149e-01  1.264e+00  3.061e+00  4.816e+00
             6.574e+00  8.333e+00  1.000e+01]
        y: [[ 2.000e+00  1.888e+00 ...  3.107e-02  1.351e-02]
            [ 4.000e+00  3.777e+00 ...  6.214e-02  2.702e-02]
            [ 8.000e+00  7.553e+00 ...  1.243e-01  5.403e-02]]
      sol: None
 t_events: None
 y_events: None
     nfev: 44
     njev: 0
      nlu: 0
In [318]: len(par)
Out[318]: 44
In [319]: np.array(par)
Out[319]: 
array([ 0.        ,  0.02      ,  0.02297531,  0.03446296,  0.09190123,
        0.10211248,  0.11487653,  0.11487653,  0.3446296 ,  0.45950614,
        1.03388881,  1.13600129,  1.26364188,  1.26364188,  1.62303707,
        1.80273466,  2.70122262,  2.86095382,  3.06061781,  3.06061781,
        3.41171646,  3.58726578,  4.4650124 ,  4.62105624,  4.81611105,
        4.81611105,  5.16778045,  5.34361515,  6.22278865,  6.37908617,
        6.57445806,  6.57445806,  6.92622442,  7.1021076 ,  7.98152352,
        8.13786412,  8.33328988,  8.33328988,  8.66663191,  8.83330292,
        9.66665798,  9.81480999, 10.        , 10.        ])

字符串

0x6upsns

0x6upsns2#

您的ODE似乎不受param1param2值的影响。因此,我建议将它们从ODE中删除,然后使用从solve_ivp获得的解决方案简单地计算它们。

import numpy as np
from scipy.integrate import solve_ivp

def f(t, y):
    return np.sin(t) * y[0] + np.cos(t) * y[1]

def compute_params(t, y):
    raise NotImplementedError

t_span = (0, 5)
y0 = np.random.default_rng(42).random(2)
sol = solve_ivp(f, t_span, y0)

param1, param2 = compute_params(sol.t, sol.y)

字符串

相关问题