我正在阅读TorchVision's GoogLeNet的源代码,我发现这些行很奇怪,无法理解。
def _transform_input(self, x: Tensor) -> Tensor:
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
return x
字符串
我知道ImageNet数据集有mean = [0.485, 0.456, 0.406]
和std = [0.229, 0.224, 0.225]
,它看起来像一些“归一化”,但它显然不是(x - mean) / std
,而更像x * std + mean
。
谁能解释一下这些代码?
1条答案
按热度按时间xqnpmsa81#
这样做是为了匹配TensorFlow对输入图像进行预处理的方式。在将GoogLeNet添加到TorchVision的pull请求中,作者解释说他匹配了TensorFlow完成的处理。下面是在问题中添加规范化的提交。
为TorchVision贡献GoogLeNet的作者写道,
我更新了代码,以匹配TensorFlow权重所需的结构。还添加了用于Inception v3模型的输入规范化。