keras — Can we avoid specifying "num_features" in PyTorch's batch norm, as in TensorFlow?

ttisahbt · asked on 2023-01-17

Here is batch norm in TF:

model = BatchNormalization(momentum=0.15, axis=-1)(model)

And here is batch norm in Torch:

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

As you can see, there is one extra parameter: num_features. This is annoying.
Suppose I don't want to use affine in Torch; then batch norm in TF and Torch should behave the same. Is there a way to avoid specifying "num_features" in PyTorch's batch norm, the way TensorFlow does?


4uqofj5v · answer #1

If you really don't want to specify this parameter, you can look at lazy batch norm; a sketch follows below.
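For reference, a minimal sketch using torch.nn.LazyBatchNorm1d (available in recent PyTorch versions), which infers num_features from the channel dimension of the first input it sees:

import torch

# No num_features argument: it is inferred on the first forward pass.
bn = torch.nn.LazyBatchNorm1d(momentum=0.1)

x = torch.randn(8, 32)   # batch of 8 samples, 32 features
out = bn(x)              # parameters are materialized here
print(bn.num_features)   # 32, inferred from x.shape[1]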
Otherwise, you can pass an arbitrary value for num_features (None?), as long as affine and track_running_stats are both False. If you look at the base class of the batch norm modules, _NormBase (defined in torch/nn/modules/batchnorm.py):

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module, Parameter


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

You can see that num_features is used to set up self.weight and self.bias when affine is True, and to set up running_mean and running_var when track_running_stats is True.
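To illustrate (a minimal sketch): with affine=False and track_running_stats=False, num_features is stored but never used to allocate anything, so a placeholder value works for inputs of any width:

import torch

# num_features=1 is a dummy value: with affine=False and
# track_running_stats=False, no weight/bias or running buffers
# are allocated from it, and the forward pass never checks it
# against the input's feature dimension.
bn = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=False)

out = bn(torch.randn(8, 32))   # fine, despite num_features=1
out = bn(torch.randn(8, 64))   # any feature count works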
