在Pytorch中,是否有类似于Scipy's的三次样条插值?给定一维输入Tensorx和y,我想通过这些点进行插值,并在xs处对它们求值,以获得ys。此外,我还想使用一个积分函数来找到Ys,即从x[0]到xs的样条插值的积分。
x
y
xs
ys
Ys
x[0]
cigdeys31#
这是我在Pytorch中高效地使用Cubic Hermite Splines并使用autograd支持创建的gist。为了方便起见,我也将代码放在这里。
import torch as T def h_poly_helper(tt): A = T.tensor([ [1, 0, -3, 2], [0, 1, -2, 1], [0, 0, 3, -2], [0, 0, -1, 1] ], dtype=tt[-1].dtype) return [ sum( A[i, j]*tt[j] for j in range(4) ) for i in range(4) ] def h_poly(t): tt = [ None for _ in range(4) ] tt[0] = 1 for i in range(1, 4): tt[i] = tt[i-1]*t return h_poly_helper(tt) def H_poly(t): tt = [ None for _ in range(4) ] tt[0] = t for i in range(1, 4): tt[i] = tt[i-1]*t*i/(i+1) return h_poly_helper(tt) def interp_func(x, y): "Returns integral of interpolating function" if len(y)>1: m = (y[1:] - y[:-1])/(x[1:] - x[:-1]) m = T.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]]) def f(xs): if len(y)==1: # in the case of 1 point, treat as constant function return y[0] + T.zeros_like(xs) I = T.searchsorted(x[1:], xs) dx = (x[I+1]-x[I]) hh = h_poly((xs-x[I])/dx) return hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx return f def interp(x, y, xs): return interp_func(x,y)(xs) def integ_func(x, y): "Returns interpolating function" if len(y)>1: m = (y[1:] - y[:-1])/(x[1:] - x[:-1]) m = T.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]]) Y = T.zeros_like(y) Y[1:] = (x[1:]-x[:-1])*( (y[:-1]+y[1:])/2 + (m[:-1] - m[1:])*(x[1:]-x[:-1])/12 ) Y = Y.cumsum(0) def f(xs): if len(y)==1: return y[0]*(xs - x[0]) I = T.searchsorted(x[1:], xs) dx = (x[I+1]-x[I]) hh = H_poly((xs-x[I])/dx) return Y[I] + dx*( hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx ) return f def integ(x, y, xs): return integ_func(x,y)(xs) # Example if __name__ == "__main__": import matplotlib.pylab as P # for plotting x = T.linspace(0, 6, 7) y = x.sin() xs = T.linspace(0, 6, 101) ys = interp(x, y, xs) Ys = integ(x, y, xs) P.scatter(x, y, label='Samples', color='purple') P.plot(xs, ys, label='Interpolated curve') P.plot(xs, xs.sin(), '--', label='True Curve') P.plot(xs, Ys, label='Spline Integral') P.plot(xs, 1-xs.cos(), '--', label='True Integral') P.legend() P.show()
3ks5zfa02#
这是对@chausies回答的评论,但太长了,无法发布。只是想发布一个稍微缩小了的版本,主要是为了我自己将来的参考:
import torch def h_poly(t): tt = t[None, :]**torch.arange(4, device=t.device)[:, None] A = torch.tensor([ [1, 0, -3, 2], [0, 1, -2, 1], [0, 0, 3, -2], [0, 0, -1, 1] ], dtype=t.dtype, device=t.device) return A @ tt def interp(x, y, xs): m = (y[1:] - y[:-1]) / (x[1:] - x[:-1]) m = torch.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]]) idxs = torch.searchsorted(x[1:], xs) dx = (x[idxs + 1] - x[idxs]) hh = h_poly((xs - x[idxs]) / dx) return hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx
2条答案
按热度按时间cigdeys31#
这是我在Pytorch中高效地使用Cubic Hermite Splines并使用autograd支持创建的gist。
为了方便起见,我也将代码放在这里。
3ks5zfa02#
这是对@chausies回答的评论,但太长了,无法发布。
只是想发布一个稍微缩小了的版本,主要是为了我自己将来的参考: