keras 当输出通道> 1时,TF depth_to_space与Torch的PixelShuffle不同?

gab6jxml  于 2023-11-19  发布在  其他
关注(0)|答案(2)|浏览(190)

我在尝试将一个torch训练的模型移植到tensorflow(TF)时注意到了一些有趣的事情。当PixelShuffle操作的输出通道大于1时,TF中对应的depth_to_space函数是不等价的(注意:我将TF的输入转换为NHWC,输出转换回NCHW)。我想知道这是预期的行为还是存在误解?
具体地说,

# Torch
torch_out = nn.PixelShuffle(2)(input)

字符串

# TF/Keras
input = np.transpose(input, (0, 2, 3, 1)) // Convert to NHWC
keras_input = keras.layers.Input(shape=input.shape[1:])
keras_d2s = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(input)
...
keras_out = np.transpose(keras_d2s, (0, 3, 1, 2)) // Convert back to NCHW


keras_out != torch_out


下面是一个测试平台:

import numpy as np

import torch
import tensorflow as tf

from torch import nn
from tensorflow import keras

class Shuffle(nn.Module):
    def __init__(self, s, k, ic):
        super(Shuffle, self).__init__()
        self.shuffle = nn.PixelShuffle(s)

    def forward(self, inputs):
        return self.shuffle(inputs)

def main():
    sz = 4

    h = 3
    w = 3
    k = 3
    ic = 8
    s = 2

    input = np.arange(0, ic * h * w, dtype=np.float32).reshape(ic, h, w)
    input = input[np.newaxis]
    torch_input = torch.from_numpy(input)

    shuffle_model = Shuffle(s, k, ic)
    shuffle_out = shuffle_model(torch_input).detach().numpy()
    print('Shuffle out:', shuffle_out.shape)
    print(shuffle_out)

    input = np.transpose(input, (0, 2, 3, 1))

    keras_input = keras.layers.Input(shape=input.shape[1:])
    keras_d2s = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, s))(keras_input)
    keras_model = keras.Model(keras_input, keras_d2s)
    keras_out = keras_model.predict(input)
    
    keras_out = np.transpose(keras_out, (0, 3, 1, 2))
    
    print('Keras out:', keras_out.shape)
    print(keras_out)
    equal = np.allclose(shuffle_out, keras_out)
    print('Equal?', equal)

if __name__ == '__main__':
    main()

ecfsfe2w

ecfsfe2w1#

它们确实是不同的。如果你想让它们匹配,你需要对其中一个输入的通道进行 Shuffle 。或者如果pixelshuffle/depth_to_space层跟随卷积层,你可以对卷积的权重通道进行 Shuffle 。具体来说,如果oc是输出通道的数量,sblock_size,那么你需要使用[i + oc * j for i in range(oc) for j in range(s ** 2)](产生类似于[0,2,4,1,3,5]的值)。

x6yk4ghg

x6yk4ghg2#

PixelShuffle/PixelUnshuffle在数值上不等于depth_to_space/space_to_depth,除非上采样图像确实有一个通道。否则,需要重新排列输出通道以获得匹配的输出,如接受的答案中所述。
这不是一个bug,而只是惯例之间的差异。这可以说与PyTorch传统上是通道优先(即,NCHW是最常见的Tensor布局),而TensorFlow是通道最后(NHWC)的事实有关。
我面临着在Torch中使用depth_to_space/space_to_depth的需求,所以我最终将实现作为PyPi包发布:
https://github.com/lnstadrum/s2d2s

相关问题