Sampling a numpy array with float indices (similar to PyTorch grid_sample)

mwngjboj  posted on 2023-08-05 in Other

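Is there a way to sample a numpy array at floating-point indices, using bilinear interpolation to obtain the intermediate values? For example, given the 1D array: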

arr=np.array([0,1])

I would like arr[0.5] to return 0.5, because that index lies halfway between indices 0 and 1. For a 2D example:

arr=np.array([[0,1],[2,3]])


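arr[0.5, 0.5] should return 1.5. In PyTorch this functionality is provided by torch.nn.functional.grid_sample, and I would like to compare its performance against doing the same thing in numpy in my application.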

kpbpu008 #1

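I'm not sure whether this can be done in pure numpy. Personally, I use OpenCV's remap function as a replacement for PyTorch's grid_sample. It has Python bindings and works on numpy arrays.
See the OpenCV documentation on remap.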
Edit: SciPy's interpolation routines also look suitable.
SciPy interp2d
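Since `interp2d` is deprecated in recent SciPy releases, one possible SciPy route is `scipy.ndimage.map_coordinates`, which interpolates at fractional indices directly. This is a different function from the one linked above, swapped in here as an assumption about what fits the question best. With `order=1` it performs linear (bilinear in 2D) interpolation. A minimal sketch reproducing the question's values:

```python
import numpy as np
from scipy.ndimage import map_coordinates

# Float arrays are used on purpose: by default the output has the
# same dtype as the input, so integer input would yield integer output.
arr1d = np.array([0.0, 1.0])
arr2d = np.array([[0.0, 1.0], [2.0, 3.0]])

# coordinates has shape (ndim, n_points): one row per array axis.
v1 = map_coordinates(arr1d, [[0.5]], order=1)
v2 = map_coordinates(arr2d, [[0.5], [0.5]], order=1)

print(v1, v2)
```

For the 1D case alone, `np.interp(0.5, np.arange(arr1d.size), arr1d)` gives the same kind of result without SciPy.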

km0tfn4u #2

import numpy as np
import torch

def grid_sample(tensor, grid):
    """Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid.

    Args:
        tensor: (N, C, H_in, W_in) tensor
        grid: (N, H_out, W_out, 2) tensor in the range of [-1, 1]

    Returns:
        (N, C, H_out, W_out) tensor
    
    See `torch.nn.functional.grid_sample`.
    """
    b, c, h, w = tensor.shape
    b_, h_out, w_out, w_ = grid.shape
    assert b == b_ and w_ == 2
    out = []
    for t, g in zip(tensor, grid):
        # Map normalized coords in [-1, 1] to pixel coords (align_corners=True).
        x_ = 0.5 * (w - 1) * (g[..., 0].reshape(-1) + 1)
        y_ = 0.5 * (h - 1) * (g[..., 1].reshape(-1) + 1)
        # Top-left neighbor, clipped so that ix + 1 and iy + 1 stay in bounds.
        ix = np.floor(x_).astype(np.int32).clip(0, w - 2)
        iy = np.floor(y_).astype(np.int32).clip(0, h - 2)
        # Fractional offsets serve as the interpolation weights.
        dx = x_ - ix
        dy = y_ - iy
        # Weighted sum of the four neighbors; each term is (C, H_out * W_out).
        out.append((1 - dx) * (1 - dy) * t[..., iy, ix]
                   + dx * (1 - dy) * t[..., iy, ix + 1]
                   + (1 - dx) * dy * t[..., iy + 1, ix]
                   + dx * dy * t[..., iy + 1, ix + 1])
    return np.concatenate(out, axis=0).reshape(b, c, h_out, w_out)

if __name__ == "__main__":
    tensor = torch.randn((3, 32, 64, 64))
    grid = torch.distributions.Uniform(-1, 1).sample((3, 7, 8, 2))
    out = torch.nn.functional.grid_sample(tensor, grid, mode='bilinear',
                                          align_corners=True)

    out_np = grid_sample(tensor.numpy(), grid.numpy())

    diff = np.abs(out_np - out.numpy())
    
    print(np.max(diff))
    print(np.linalg.norm(diff))

The maximum difference is about 1e-7.
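To connect this back to the question's `arr[0.5, 0.5]` notation, here is a minimal numpy-only sketch that samples a single 2D array at fractional pixel indices, using the same four-neighbor weighting. The names `sample_bilinear` and `to_normalized` are hypothetical helpers introduced for illustration, and the sketch assumes the array has at least two rows and two columns:

```python
import numpy as np

def sample_bilinear(arr, ys, xs):
    """Bilinearly sample a 2D array at fractional (row, col) indices."""
    arr = np.asarray(arr, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    xs = np.asarray(xs, dtype=np.float64)
    h, w = arr.shape
    # Top-left neighbor, clipped so the +1 neighbor stays in bounds.
    iy = np.clip(np.floor(ys).astype(np.intp), 0, h - 2)
    ix = np.clip(np.floor(xs).astype(np.intp), 0, w - 2)
    dy = ys - iy
    dx = xs - ix
    return ((1 - dy) * (1 - dx) * arr[iy, ix]
            + (1 - dy) * dx * arr[iy, ix + 1]
            + dy * (1 - dx) * arr[iy + 1, ix]
            + dy * dx * arr[iy + 1, ix + 1])

def to_normalized(idx, size):
    """Convert a pixel index to the [-1, 1] range (align_corners=True)."""
    return 2.0 * np.asarray(idx, dtype=np.float64) / (size - 1) - 1.0

arr = np.array([[0, 1], [2, 3]])
print(sample_bilinear(arr, 0.5, 0.5))  # 1.5
print(to_normalized(0.5, 2))           # 0.0
```

`ys` and `xs` can also be arrays of equal shape, in which case the result has that shape. `to_normalized` shows how a pixel index would be converted before being passed as a grid coordinate to a grid_sample-style function with `align_corners=True`.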
