pytorch “输出通道”大小在nn.conv2d滤波器中的作用有多精确?

z9smfwbn  于 2022-12-13  发布在  其他
关注(0)|答案(1)|浏览(139)

我正在研究pytorch卷积滤波器Conv2d函数,并找到了Conv2d的内核(三、三、三)((意思是输入通道= 3,输出通道=3,内核大小= 3))是[3,3,3,3]的Tensor,如下所示。我不明白这个巨大的Tensor有什么用。我可以很容易地理解conv2d的内核(3,1,3)((tensor [1,3,3,3])),它有3个3 × 3矩阵,每个矩阵可以应用于输入图像的RGB矩阵,然后合并为一个矩阵进行灰度化,所以输出通道是一个简单的通道。
但是对于[3,3,3,3]有什么好处呢????3*3的3个矩阵应用到R G B矩阵并输出它们而不进行合并还不够吗?为什么我们需要额外的2个3,3,3Tensor??
首先,让我们假设conv2d(3,1,3)为3个输入通道(I1,I2,I3)和3个内核(K1,K2,K3)以及一个输出通道(O1)。

(I1*K1 + I2*K2 + I3*K3) /3 = O1.

如果我想要3个输出,计算可能如下所示

I1*K1 = O1     
I2*K2 = O2
I3*K3 = O3

但实际情况比较复杂,因为如果我们需要3个输出通道,那么核的数量是9,Conv2d(3,3,3)代表3个输入(I1,I2,I3),9个核(K11,K12,K13,K21,K22,K23,K31,K32,K33),那么计算对是什么?
下面是conv2d(3,3,3)情形的9个内核。

conv2d(3,3,3)
=>3,3,3,3  kernel
=>tensor([[[[-0.0904, -0.0924, -0.0892],
          [-0.0060, -0.0389, -0.1388],
          [ 0.1636, -0.0933, -0.0295]],

         [[-0.0742,  0.0426, -0.0662],
          [ 0.1625, -0.1485, -0.0169],
          [-0.1122, -0.0875,  0.1021]],

         [[ 0.1214, -0.0896,  0.1304],
          [ 0.0612,  0.0367, -0.0288],
          [-0.1868, -0.1356,  0.0869]]],

        [[[ 0.0792, -0.1562, -0.1878],
          [ 0.0373,  0.1162,  0.1224],
          [-0.1138,  0.0553,  0.1449]],

         [[ 0.1558, -0.1661, -0.0963],
          [ 0.0603, -0.1405,  0.0995],
          [-0.0644, -0.1151,  0.1422]],

         [[ 0.1534,  0.0399, -0.1709],
          [ 0.0765, -0.0665,  0.0119],
          [ 0.0586,  0.1424, -0.1755]]],

        [[[-0.0199, -0.0956, -0.0577],
          [ 0.1312, -0.0273,  0.0615],
          [-0.1037, -0.0247,  0.1915]],

         [[-0.1297,  0.0451,  0.0360],
          [-0.0462,  0.1846, -0.1615],
          [-0.0642,  0.0324,  0.1428]],

         [[ 0.1254, -0.0323,  0.1129],
          [ 0.0482,  0.0839,  0.0227],
          [ 0.0845,  0.1773, -0.1706]]]], requires_grad=True)
8hhllhi2

8hhllhi21#

您可以将Conv2d操作视为如下所示:
您可以在输入Tensor中的(C_in, K_H, K_W)块和相同大小的滤波器之间取矩阵内积,即(C_in,K_H,K_W)来得到输出Tensor的一个条目。(out_channels)输出Tensor中每个空间位置的通道数,您必须相应地创建C_out不同的(C_in, K_H, K_W)滤波器,这导致(C_out, C_in, K_H, K_W)的堆叠滤波器形状。
符号:

  • C_in:通道中
  • K_H:内核高度
  • K_W:内核宽度

相关问题