pytorch 为什么TorchVision的GoogLeNet有这种奇怪的“正常化”?

mrphzbgm  于 2024-01-09  发布在  Go
关注(0)|答案(1)|浏览(196)

我正在阅读TorchVision's GoogLeNet的源代码,我发现这些行很奇怪,无法理解。

def _transform_input(self, x: Tensor) -> Tensor:
    if self.transform_input:
        x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    return x

字符串
我知道ImageNet数据集有mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225],它看起来像一些“归一化”,但它显然不是(x - mean) / std,而更像x * std + mean
谁能解释一下这些代码?

xqnpmsa8

xqnpmsa81#

这样做是为了匹配TensorFlow对输入图像进行预处理的方式。在将GoogLeNet添加到TorchVision的pull请求中,作者解释说他匹配了TensorFlow完成的处理。下面是在问题中添加规范化的提交。
为TorchVision贡献GoogLeNet的作者写道,
我更新了代码,以匹配TensorFlow权重所需的结构。还添加了用于Inception v3模型的输入规范化。

相关问题