# Source code for detectron2.layers.batch_norm

# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.distributed as dist
from fvcore.nn.distributed import differentiable_all_reduce
from torch import nn
from torch.nn import functional as F

from detectron2.utils import comm, env

from .wrappers import BatchNorm2d


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    # Checkpoint format version; read back through ``local_metadata`` in
    # ``_load_from_state_dict``.
    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features: length of each of the four per-channel buffers.
            eps: constant added to ``running_var`` before taking the square root.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # Everything is a buffer (not a Parameter), so nothing here is trainable.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # ``running_var + eps == 1`` so that a freshly built layer is an identity.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """
        Normalize ``x`` with the stored statistics and affine parameters.

        Args:
            x: input tensor; channels are expected on dim 1 (the per-channel
                scale and bias are reshaped to ``(1, C, 1, 1)``).

        Returns:
            The normalized tensor.
        """
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Fill in missing running statistics for old checkpoints, then defer to
        ``nn.Module._load_from_state_dict``.

        Note that ``state_dict`` is modified in place when entries are added.
        """
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                # NOTE(review): __init__ stores ``ones - eps`` while this fallback
                # stores plain ones, so the filled-in value normalizes by
                # sqrt(1 + eps) instead of exactly 1 -- confirm this is intended.
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            # Copy the statistics instead of aliasing the source tensors, so that
            # later in-place updates on ``module`` (e.g. another training-mode
            # forward) cannot change the frozen layer.
            res.running_mean.data = module.running_mean.data.clone().detach()
            res.running_var.data = module.running_var.data.clone().detach()
            res.eps = module.eps
        else:
            # Not a BN layer: recurse and swap converted children in place.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels: channel count handed to the selected factory.

    Returns:
        nn.Module or None: the normalization layer. ``None`` is returned when
        ``norm`` is ``None`` or an empty string.
    """
    if norm is None:
        return None

    # A non-string is already a factory: call it directly.
    if not isinstance(norm, str):
        return norm(out_channels)

    if len(norm) == 0:
        return None

    factories = {
        "BN": BatchNorm2d,
        # Fixed in https://github.com/pytorch/pytorch/pull/36382
        "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
        "FrozenBN": FrozenBatchNorm2d,
        "GN": lambda channels: nn.GroupNorm(32, channels),
        # for debugging:
        "nnSyncBN": nn.SyncBatchNorm,
        "naiveSyncBN": NaiveSyncBatchNorm,
        # expose stats_mode N as an option to caller, required for zero-len inputs
        "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
    }
    # An unknown name surfaces as the KeyError raised by this lookup.
    return factories[norm](out_channels)
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to the parent constructor.
            stats_mode: ``""`` or ``"N"``; see the class docstring.
        """
        super().__init__(*args, **kwargs)
        assert stats_mode in ("", "N")
        self._stats_mode = stats_mode

    def forward(self, input):
        """
        Normalize ``input`` with statistics reduced across all workers.

        Falls back to the parent ``forward`` when there is a single worker or
        the module is not in training mode.
        """
        single_worker = comm.get_world_size() == 1
        if single_worker or not self.training:
            return super().forward(input)

        batch_size = input.shape[0]
        num_channels = input.shape[1]

        was_half = input.dtype == torch.float16
        if was_half:
            # fp16 does not have good enough numerics for the reduction here
            input = input.float()

        reduce_dims = [0, 2, 3]
        mean = torch.mean(input, dim=reduce_dims)
        meansqr = torch.mean(input * input, dim=reduce_dims)

        if self._stats_mode == "":
            assert batch_size > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            stats = torch.cat([mean, meansqr], dim=0)
            # Equal-weight average of the per-worker statistics.
            stats = differentiable_all_reduce(stats) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(stats, num_channels)
            momentum = self.momentum
        else:
            if batch_size == 0:
                stats = torch.zeros(
                    [2 * num_channels + 1], device=mean.device, dtype=mean.dtype
                )
                stats = stats + input.sum()  # make sure there is gradient w.r.t input
            else:
                # The trailing 1 becomes this worker's batch size after scaling below.
                one = torch.ones([1], device=mean.device, dtype=mean.dtype)
                stats = torch.cat([mean, meansqr, one], dim=0)
            stats = differentiable_all_reduce(stats * batch_size)

            total_batch = stats[-1].detach()
            # no update if total_batch is 0
            momentum = total_batch.clamp(max=1) * self.momentum
            # avoid div-by-zero
            mean, meansqr, _ = torch.split(stats / total_batch.clamp(min=1), num_channels)

        var = meansqr - mean * mean
        scale = self.weight * torch.rsqrt(var + self.eps)
        bias = self.bias - mean * scale

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)

        ret = input * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
        return ret.half() if was_half else ret
class CycleBatchNormList(nn.ModuleList):
    """
    Implement domain-specific BatchNorm by cycling.

    When a BatchNorm layer is used for multiple input domains or input
    features, it might need to maintain a separate test-time statistics
    for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module implements it by using N separate BN layers
    and it cycles through them every time a forward() is called.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
        """
        # "affine" is taken out of kwargs: the inner layers are always built
        # with affine=False and the affine transform, if any, lives on this list.
        self._affine = kwargs.pop("affine", True)
        layers = [bn_class(affine=False, **kwargs) for _ in range(length)]
        super().__init__(layers)
        if self._affine:
            # shared affine, domain-specific BN
            num_channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        # Index of the layer the next forward() call will use.
        self._pos = 0

    def forward(self, x):
        """Run ``x`` through the current layer, then advance to the next one."""
        out = self[self._pos](x)
        self._pos = (self._pos + 1) % len(self)

        if not self._affine:
            return out
        return out * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)

    def extra_repr(self):
        return f"affine={self._affine}"