如何通过args()在scipy中集成.quad?

bd1hkmkf  于 2022-11-10  发布在  其他
关注(0)|答案(1)|浏览(204)

I was doing a numerical integration where instead of using

scipy.integrate.dblquad

I tried to build the double integral directly using

scipy.integrate.quad

From camz's answer: https://stackoverflow.com/a/30786163/11141816

a=1
b=2

def integrand(theta, t): 
    return sin(t)*sin(theta);

def theta_integral(t):
    # Note scipy will pass args as the second argument
    return integrate.quad(integrand, b-a, t , args=(t))[0]

integrate.quad(theta_integral, b-a,b+a )

However, the part where it was confusing was how the args was passed through the function. For example, in the post args=(t) was pass through to the second argument t of the function integrand automatically, not the first argument theta , where in the scipy document https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.quad.html the only reference was

args: tuple, optional
    Extra arguments to pass to func.

and in fact their arguments had a comma , to be args=(1,) , instead of args=(1) directly.
What if I decide to pass several variables, i.e.

def integrand(theta, t,a,b): 
    return sin(t)*sin(theta)*a*b;

how would I know which args(t,a,b) or args(t,a,b,) were the correct args() for the function?
How to pass args() to integrate.quad in scipy?

z2acfund

z2acfund1#

In [214]: from scipy.integrate import quad

一个简单的例子。我没有使用args,只是展示它们:

In [215]: def foo(x,a,b,c):
     ...:     print(x,a,b,c)
     ...:     return x**2
     ...:     

In [216]: quad(foo, 0, 1, (1,(1,2,3),np.arange(2)))
0.5 1 (1, 2, 3) [0 1]
0.013046735741414128 1 (1, 2, 3) [0 1]
0.9869532642585859 1 (1, 2, 3) [0 1]
0.06746831665550773 1 (1, 2, 3) [0 1]
0.9325316833444923 1 (1, 2, 3) [0 1]
0.16029521585048778 1 (1, 2, 3) [0 1]
0.8397047841495122 1 (1, 2, 3) [0 1]
 ...
0.35280356864926987 1 (1, 2, 3) [0 1]
0.6471964313507301 1 (1, 2, 3) [0 1]
Out[216]: (0.33333333333333337, 3.700743415417189e-15)

这就是args的全部内容。
看一下quad的代码,第一行是

if not isinstance(args, tuple):
    args = (args,)

所以如果你使用args=(t),它将被修改为args=(t,)。记住,在python中(2)2是一样的;这个()函数只对复杂情况下的操作进行分组,否则它们将被忽略。但是(1,)是一个元组,如(1,2)。实际上是逗号定义了元组。
用你内在的功能,加上一个打印:

In [232]: def integrand(theta, t): 
     ...:     print(theta, t)
     ...:     return np.sin(t)*np.sin(theta)
     ...:     

In [233]: integrand(1,2)     # pass 2 numbers 
1 2
Out[233]: 0.7651474012342926

In [234]: integrand(1,np.arange(3))   # is still works if the 2nd is an array:

1 [0 1 2]
Out[234]: array([0.        , 0.70807342, 0.7651474 ])

t是一个数组,np.sin(t)也是一个数组,返回值也是一个数组。但是你没有传递数组。通过quad调用它,theta会变化,t不会:

In [235]: quad(integrand, 0, 1, args=(1,))
0.5 1
0.013046735741414128 1
0.9869532642585859 1
...
0.6471964313507301 1
Out[235]: (0.38682227139505565, 4.2945899214059184e-15)

In [236]: quad(integrand, 0, 1, args=(1,2))
---------------------------------------------------------------------------
...
TypeError: integrand() takes 2 positional arguments but 3 were given

实际的错误是在编译的代码中引发的,所以并不完全透明,但效果与我试图传递3个参数相同:

In [237]: integrand(1,1,2)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [237], in <cell line: 1>()
----> 1 integrand(1,1,2)

TypeError: integrand() takes 2 positional arguments but 3 were given

相关问题