python 逐层转发剩余块,结果错误

5lhxktic  于 2023-02-21  发布在  Python
关注(0)|答案(1)|浏览(98)

我编写了一个带有残余连接的pytorch代码,如下所示

all_module = []
for i in range(3):
  layer = nn.Sequential(
      nn.Conv1d(n_hidden_channels, n_hidden_channels),
      nn.LeakyReLU(),
      nn.Conv1d(n_hidden_channels, n_hidden_channels),
      nn.LeakyReLU()
      )
  all_module.append(layer)
module_list = nn.ModuleList(all_module)

# method 1
for layer in module_list:
  x = x + layer(x)
print(x)

# method 2
for layer in module_list:
  y = torch.clone(x)
  for m in layer:
    y = m(y)
  x = x + y
print(x)

为什么方法1和方法2的输出不同?
不知道为什么会这样。

qybjjes1

qybjjes11#

两个方法做的是相同的事情,如果你顺序运行两个方法,第一个方法将更新x;因此,当你运行第二个方法时,你会得到不同的结果。2如果你在第一个方法之前复制x,你会看到两个方法会产生相同的结果。

相关问题