# Source code for fvcore.nn.weight_init

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch.nn as nn


def c2_xavier_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize. It must expose a
            ``weight`` tensor. If it has no ``bias`` attribute, or its
            ``bias`` is ``None``, the bias step is skipped.
    """
    # Caffe2 implementation of XavierFill in fact
    # corresponds to kaiming_uniform_ in PyTorch.
    # With a=1 the leaky-ReLU gain is 1, so weights are drawn from
    # U(-sqrt(3 / fan_in), sqrt(3 / fan_in)).
    nn.init.kaiming_uniform_(module.weight, a=1)
    # getattr lets bias-less custom modules (no `bias` attribute at all)
    # be initialized instead of raising AttributeError.
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)
def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize. It must expose a
            ``weight`` tensor. If it has no ``bias`` attribute, or its
            ``bias`` is ``None``, the bias step is skipped.
    """
    # MSRAFill corresponds to kaiming_normal_ scaled by fan_out with the
    # ReLU gain sqrt(2), i.e. weights ~ N(0, 2 / fan_out).
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    # getattr lets bias-less custom modules (no `bias` attribute at all)
    # be initialized instead of raising AttributeError.
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)