如何使用scipy.integrate.solve_ivp()在交互仿真中求解常微分方程?

dfty9e19  于 2023-01-13  发布在  其他
关注(0)|答案(1)|浏览(216)

我已经使用低级API scipy.integrate.RK45()实现了一个匀速圆周运动的简单模拟,如下所示。

import numpy as np
import scipy.integrate
import matplotlib.pyplot as plt

r = np.array([1, 0], 'float')
v = np.array([0, 1], 'float')
dt = 0.1

def motion_eq(t, y):
    r, v = y[0:2], y[2:4]
    return np.hstack([v, -r])
motion_solver = scipy.integrate.RK45(motion_eq, 0, np.hstack([r, v]),
    t_bound = np.inf, first_step = dt, max_step = dt)

particle, *_ = plt.plot(*r.T, 'o')
plt.gca().set_aspect(1)
plt.xlim([-2, 2])
plt.ylim([-2, 2])
def update():
    motion_solver.step()
    r = motion_solver.y[0:2]
    particle.set_data(*r.T)
    plt.draw()
timer = plt.gcf().canvas.new_timer(interval = 50)
timer.add_callback(update)
timer.start()

plt.show()

我一开始尝试了高级API scipy.integrate.solve_ivp(),但它似乎没有提供一个接口来创建包含系统状态的示例,并迭代地获取系统状态(我称之为交互式模拟,因为您可以暂停、更改系统状态和恢复,尽管示例代码中没有实现)。
这在solve_ivp()中可行吗?如果不可行,我在RK45中做得对吗?特别是在指定t_boundfirst_stepmax_step选项时?我可以在Internet上找到大量关于求解给定时间间隔的资源,但我找不到这样求解的资源。

oymdgrw7

oymdgrw71#

虽然我不是数值分析方面的Maven,但我在回答自己的问题,因为我是在分析solve_ivp()RK45实现并试验API之后得出结论的。
首先,高级API solve_ivp()没有提供接口来创建包含系统状态的示例并迭代地获取系统状态,因此,我应该这样做,即使这样做效率很低。

dt = 0.1
t = 0
r = np.array([1, 0], 'float')
v = np.array([0, 1], 'float')

def motion_eq(t, y):
    r, v = y[0:2], y[2:4]
    return np.hstack([v, -r])

...
def update():
    global t, r, v
    sol = scipy.integrate.solve_ivp(motion_eq, (t, t + dt), np.hstack([r, v]),
        t_eval = (t + dt,), method = 'RK45')
    t = sol.t[0]
    r, v = sol.y[0:2, 0], sol.y[2:4, 0]
    ...
...

其次,低级API RK45提供了创建包含系统状态的示例的接口(RK45()初始化器)并迭代地获得系统的状态(step()方法),但是它不提供接口来控制每次迭代的时间步长,但是如果期望的时间步长是恒定的并且足够小,你可以通过RK45()初始化器的first_stepmax_step参数在一定程度上控制时间步长,就像问题中的例子一样。
这可以从这个和这个看出,在_step_impl()结束时,self.h_abs比调用前的值大,但在下一次调用_step_impl()开始时,它被限制为self.max_step,但正如Lutz Lehmann所说,如果估计误差(error_norm)较大,则下一次调用时self.h_abs将小于self.max_step,所以在这种情况下,需要像下面这样循环插值(插值部分由Lutz Lehmann完成)。

...
dt = 0.1
t = 0
...
motion_solver = scipy.integrate.RK45(motion_eq, 0, np.hstack([r, v]),
    t_bound = np.inf, first_step = dt, max_step = dt)
...
def update():
    global t
    t += dt
    while motion_solver.t < t:
        motion_solver.step()
    sol = motion_solver.dense_output()
    y = sol(t) # interpolation
    r, v = y[0:2], y[2:4]
    ...
...

相关问题