How to implement BatchNorm2d in PyTorch?

cedebl8k · asked on 2023-08-05

I tried to implement a BatchNorm2d() layer with the following code:

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

class BatchNorm2d(nn.Module):

    def __init__(self, num_features):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eps = 1e-5
        self.momentum = 0.1
        self.first_run = True

    def forward(self, input):
        # input: [batch_size, num_feature_map, height, width]
        device = input.device
        if self.training:
            mean = torch.mean(input, dim=0, keepdim=True).to(device)  # [1, num_feature, height, width]
            var = torch.var(input, dim=0, unbiased=False, keepdim=True).to(device)  # [1, num_feature, height, width]
            if self.first_run:
                self.weight = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.bias = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.register_buffer('running_mean', torch.zeros(input.shape).to(input.device))
                self.register_buffer('running_var', torch.ones(input.shape).to(input.device))
                self.first_run = False
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            bn_init = (input - mean) / torch.sqrt(var + self.eps)
        else:
            bn_init = (input - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        return self.weight * bn_init + self.bias

But after training and testing, I found that the results with my layer don't match those with nn.BatchNorm2d(). Something must be wrong, and I suspect the problem has to do with initializing the parameters inside forward(). I did it there because I didn't know how to get the input's shape in __init__(); maybe there's a better way. I don't know how to fix it, please help. Thanks!!

nhjlsmyf #1

Found the answer HERE!
The shape of weight (and bias) should be (1, num_features, 1, 1), not (1, num_features, width, height): batch norm learns one scale and one shift per channel, broadcast across the spatial dimensions. That also means the parameters can be created in __init__() from num_features alone, without knowing the input shape.
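To make that concrete, here is a minimal sketch of a corrected layer (my own rewrite, not the exact code behind the link): weight and bias are created in __init__ with shape (1, num_features, 1, 1), so the input shape is never needed in forward(), and the batch statistics are reduced over the batch and spatial dimensions (0, 2, 3), giving one mean and variance per channel:

import torch
import torch.nn as nn

class MyBatchNorm2d(nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # One learnable scale/shift per channel, broadcast over batch, height, width.
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # Running statistics with the same per-channel shape.
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))

    def forward(self, input):
        # input: [batch_size, num_features, height, width]
        if self.training:
            # Per-channel statistics over batch and spatial dims, not dim=0 alone.
            mean = input.mean(dim=(0, 2, 3), keepdim=True)
            var = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            with torch.no_grad():
                # Note: nn.BatchNorm2d updates running_var with the *unbiased*
                # batch variance, so an exact match needs that correction too.
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (input - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_hat + self.bias

Initializing weight to ones and bias to zeros (rather than torch.randn) matches nn.BatchNorm2d, which starts out as an identity transform.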

e4yzc0pl #2

In case anyone stumbles over this point: you don't actually have to set the device inside the model as in the code above. Outside the model, you can simply do

device = torch.device('cuda:0')
model = model.to(device)

Not sure whether this is better than manually setting the device of the weights and biases inside the module, but I think it's definitely more standard: as long as the parameters and buffers are registered in __init__, model.to(device) moves all of them at once, as sketched below.
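For example (assuming the hypothetical MyBatchNorm2d sketch from the first answer), everything registered as a parameter or buffer in __init__ follows the module when it is moved:

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
bn = MyBatchNorm2d(num_features=16).to(device)  # moves parameters and buffers together
print(bn.weight.device, bn.running_mean.device)  # both report the same device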

pokxtpni #3

Andrej Karpathy's video has a very intuitive explanation.
Below is the code snippet for the 1D implementation, from the notebook that accompanies the video:

import torch

class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
  
  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      xmean = x.mean(0, keepdim=True) # batch mean
      xvar = x.var(0, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

PyTorch's own implementation is in C++.
However, this implementation plus the explanation from the Dive into Deep Learning site, as mentioned in the accepted answer, may help you understand the implementation differences between the 1D and 2D cases.
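As a quick sanity check (my own addition, not from the notebook), you can compare the snippet above with nn.BatchNorm1d on a reasonably large batch. The match is close but not exact, because the notebook normalizes with the unbiased batch variance while nn.BatchNorm1d normalizes with the biased one:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(256, 100)  # large batch, so the Bessel correction barely matters

mine = BatchNorm1d(100)                # the class from the notebook above
ref = nn.BatchNorm1d(100, eps=1e-5, momentum=0.1)

print(torch.allclose(mine(x), ref(x), atol=1e-2))  # expect: True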
