PyTorch: splitting an image into patches

ymdaylpp · published 2023-10-20 in: Other
Follow (0) | Answers (2) | Views (169)

I want to implement the Vision Transformer model. In the paper, the authors say they split the input image into patches of a fixed resolution. For example, if the image is 64x64 and the patch resolution is 16x16, it is split into 16 patches, each of resolution 16x16, giving a final shape of (N, P, P, C), where N is the number of patches, P is the patch resolution, and C is the number of channels.
Here is what I tried, so that the splitting is vectorized:

def image_to_patches_fast(image, res_patch):

    (H, W, C) = get_image_shape(image)

    # convert grayscale images to RGB
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)

    P = res_patch
    N = (H*W)//(P**2)

    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1,2,0)
    image_patches = image_tensor.view(N,P,P,C)
    return image_patches

The function runs, but the output is not what I expected: when I try to visualize the patches something is wrong. The patches may not be positioned correctly, or something else is off. Here is an example:
Input image:

Visualization of the output patches:

The function used to visualize the patches:

def show_patches(patches):

    N, P = patches.shape[0], patches.shape[1]

    nrows, ncols = int(N**0.5), int(N**0.5)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    for row in range(nrows):
        for col in range(ncols):
            idx = col + (row * ncols)
            axes[row][col].imshow(patches[idx, :, :, :])
            axes[row][col].axis("off")

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.1, hspace=0.1)
    plt.show()
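The wrong positioning in the fast version comes from how `view` walks memory: it just re-chunks the flat row-major buffer, so each "patch" ends up being consecutive pixels from full image rows rather than a 2D tile. A minimal sketch on a hypothetical 4x4 single-channel example (not tied to the code above) shows the difference:

```python
import torch

# A tiny 4x4 single-channel "image" with row-major values 0..15,
# split into four 2x2 patches (P = 2, N = 4).
img = torch.arange(16).reshape(4, 4)

# Plain view: re-chunks the flat row-major data, so each "patch"
# is just consecutive pixels, not a 2x2 tile.
bad = img.view(4, 2, 2)

# Correct tiling: split the row and column axes first, then bring the
# two patch-grid axes together before flattening them into N.
good = img.reshape(2, 2, 2, 2).permute(0, 2, 1, 3).reshape(4, 2, 2)

print(bad[0])   # the first image row folded in half, not a tile
print(good[0])  # the true top-left 2x2 tile
```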

I tried another function to split the image. It is slower because it uses loops, but it works as expected:

def image_to_patches_slow(image, res_patch):

    (H, W, C) = get_image_shape(image)

    # convert grayscale images to RGB
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)

    P = res_patch
    N = (H*W)//(P**2)

    nrows, ncols = int(N**0.5), int(N**0.5)

    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1,2,0)
    image_patches = torch.zeros((N, P, P, C), dtype=torch.int)

    for row in range(nrows):
        # step by P, the patch resolution (the original code stepped by N,
        # which only coincides with P when N == P, e.g. 64x64 with P = 16)
        s_row = row * P
        e_row = s_row + P
        for col in range(ncols):
            idx = col + (row * ncols)
            s_col = col * P
            e_col = s_col + P
            image_patches[idx] = image_tensor[s_row:e_row, s_col:e_col]

    return image_patches

Its output:

Any help would be appreciated, since the slow version is the bottleneck during training.
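For a single image, one loop-free alternative (not from the thread, just a hedged sketch) is `Tensor.unfold`, which extracts the tiles directly; this assumes an (H, W, C) tensor with H and W divisible by P:

```python
import torch

def image_to_patches_unfold(image_tensor, P):
    """Split an (H, W, C) tensor into (N, P, P, C) patches.

    Sketch using Tensor.unfold; assumes H and W are divisible by P.
    """
    H, W, C = image_tensor.shape
    patches = (
        image_tensor
        .unfold(0, P, P)          # (H//P, W, C, P)
        .unfold(1, P, P)          # (H//P, W//P, C, P, P)
        .permute(0, 1, 3, 4, 2)   # (H//P, W//P, P, P, C)
        .reshape(-1, P, P, C)     # (N, P, P, C)
    )
    return patches
```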

w9apscun1#

This method does the patching with a one-line reshape operation, applied per channel.
If the image dimensions are not divisible by the patch width, it crops the image by trimming the ends. You would be better off replacing this basic crop with something smarter from torchvision, such as a center crop, a resize, or a combination (resize then center crop).
The example below breaks a 200x200 image into 50-pixel patches.

import torchvision, torch
import matplotlib.pyplot as plt

img = torchvision.io.read_image('../image.png').permute(1, 2, 0)

H, W, C = img.shape

patch_width = 50
n_rows = H // patch_width
n_cols = W // patch_width

cropped_img = img[:n_rows * patch_width, :n_cols * patch_width, :]

#
# Into patches
# [n_rows, n_cols, patch_width, patch_width, C]
#
patches = torch.empty(n_rows, n_cols, patch_width, patch_width, C)
for chan in range(C):
    patches[..., chan] = (
        cropped_img[..., chan]
        .reshape(n_rows, patch_width, n_cols, patch_width)
        .permute(0, 2, 1, 3)
    )
    
#
#Plot
#
f, axs = plt.subplots(n_rows, n_cols, figsize=(5, 5))

for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        axs[row_idx, col_idx].imshow(patches[row_idx, col_idx, ...] / 255)

for ax in axs.flatten():
    ax.set_xticks([])
    ax.set_yticks([])
f.subplots_adjust(wspace=0.05, hspace=0.05)
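As a quick sanity check (my addition, not part of the original answer), the same reshape/permute patching can be run on a tiny synthetic image and then inverted, confirming that no pixels are reordered:

```python
import torch

# Sanity check: patch a small synthetic image the same way as above,
# reassemble it, and confirm nothing was reordered.
P, n_rows, n_cols, C = 2, 2, 2, 3
img = torch.arange(n_rows * P * n_cols * P * C).reshape(n_rows * P, n_cols * P, C)

# Same per-channel reshape/permute as in the answer.
patches = torch.empty(n_rows, n_cols, P, P, C, dtype=img.dtype)
for chan in range(C):
    patches[..., chan] = (
        img[..., chan].reshape(n_rows, P, n_cols, P).permute(0, 2, 1, 3)
    )

# Invert: interleave the grid and within-patch axes again.
rebuilt = patches.permute(0, 2, 1, 3, 4).reshape(n_rows * P, n_cols * P, C)
assert torch.equal(rebuilt, img)
```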

gwbalxhn2#

I managed to parallelize the splitting across a batch of images using @some3128's answer. Here is the enhanced solution:

def image_to_patches(images, res_patch, H, W, C):
    
    N = images.shape[0]
    patch_width = res_patch
    n_rows = H // patch_width
    n_cols = W // patch_width

    cropped_img = images[:,:n_rows * patch_width, :n_cols * patch_width, :]

    #
    # Into patches
    # [n_rows, n_cols, patch_width, patch_width, C]
    #
    patches = torch.empty(N, n_rows, n_cols, patch_width, patch_width, C).to(int)
    for chan in range(C):
        patches[..., chan] = (
            cropped_img[..., chan]
            .reshape(N, n_rows, patch_width, n_cols, patch_width)
            .permute(0, 1, 3, 2, 4)
        )
       
    return patches.view(N, -1, patch_width, patch_width, C)

Training on the GPU is much faster now.
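A possible further refinement (my sketch, not from the thread): the remaining per-channel loop can itself be removed with a single reshape/permute over all channels at once, assuming the same (N, H, W, C) layout as above:

```python
import torch

def image_to_patches_noloop(images, res_patch):
    """Batched (N, H, W, C) -> (N, n_patches, P, P, C), no channel loop."""
    N, H, W, C = images.shape
    P = res_patch
    n_rows, n_cols = H // P, W // P

    # crop off any remainder, as in the answer above
    cropped = images[:, :n_rows * P, :n_cols * P, :]

    patches = (
        cropped
        .reshape(N, n_rows, P, n_cols, P, C)
        .permute(0, 1, 3, 2, 4, 5)            # (N, n_rows, n_cols, P, P, C)
        .reshape(N, n_rows * n_cols, P, P, C)
    )
    return patches
```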
