detectron2.layers

class detectron2.layers.FrozenBatchNorm2d(num_features, eps=1e-05)[source]

Bases: torch.nn.Module

BatchNorm2d where the batch statistics and the affine parameters are fixed.

It contains non-trainable buffers called “weight”, “bias”, “running_mean”, and “running_var”, initialized to perform an identity transformation.

The pre-trained backbone models from Caffe2 only contain “weight” and “bias”, which are computed from the original four parameters of BN. With these folded values, the affine transform x * weight + bias is equivalent to computing (x - running_mean) / sqrt(running_var) * weight + bias with the original four parameters. When loading a backbone model from Caffe2, “running_mean” and “running_var” are left unchanged at their identity-transformation initial values.

Other pre-trained backbone models may contain all 4 parameters.

The forward is implemented by F.batch_norm(…, training=False).
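
The folding described above for Caffe2 backbones can be sketched with plain Python scalars for a single channel. The helper name fold_batchnorm and the numeric values are hypothetical, and eps is included as in the constructor signature even though the formula above omits it. This is an illustrative sketch, not the detectron2 implementation.

```python
import math

def fold_batchnorm(weight, bias, running_mean, running_var, eps=1e-5):
    """Fold the four BN parameters of one channel into a (scale, shift) pair."""
    scale = weight / math.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

# Hypothetical per-channel values, chosen only for illustration.
weight, bias, mean, var, eps = 1.5, 0.2, 0.4, 4.0, 1e-5
scale, shift = fold_batchnorm(weight, bias, mean, var, eps)

x = 3.0
# Full four-parameter normalization.
full = (x - mean) / math.sqrt(var + eps) * weight + bias
# Two-parameter affine transform with the folded values.
folded = x * scale + shift
```

Both expressions give the same result up to floating-point rounding, which is why a checkpoint that stores only the folded pair can leave the running statistics at their identity values.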

forward(x)[source]
classmethod convert_frozen_batchnorm(module)[source]

Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

Parameters

module (torch.nn.Module) –

Returns

If module is BatchNorm/SyncBatchNorm, returns a new module. Otherwise, in-place convert module and return it.

Similar to convert_sync_batchnorm in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
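
The return contract above (a new object when the root itself is a BatchNorm layer, in-place conversion otherwise) can be sketched without PyTorch. The classes Node, BN, and FrozenBN and the function convert_frozen are hypothetical stand-ins for illustration; they do not copy parameters or buffers as a real conversion would have to.

```python
class Node:
    """Hypothetical stand-in for torch.nn.Module holding named children."""
    def __init__(self, **children):
        self.children = dict(children)

class BN(Node):
    """Stand-in for a BatchNorm/SyncBatchNorm layer."""

class FrozenBN(Node):
    """Stand-in for FrozenBatchNorm2d."""

def convert_frozen(module):
    # A BN root cannot be modified in place, so a new object is returned.
    if isinstance(module, BN):
        return FrozenBN()
    # Otherwise recurse and replace converted children in place.
    for name, child in list(module.children.items()):
        converted = convert_frozen(child)
        if converted is not child:
            module.children[name] = converted
    return module

net = Node(conv=Node(), norm=BN(), block=Node(norm=BN()))
result = convert_frozen(net)        # same object, children replaced

lone = BN()
lone_result = convert_frozen(lone)  # a new object
```

Because of the first case, callers should always use the returned value rather than assume the argument was modified.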

classmethod convert_frozenbatchnorm2d_to_batchnorm2d(module: torch.nn.Module) → torch.nn.Module[source]

Convert all FrozenBatchNorm2d to BatchNorm2d

Parameters

module (torch.nn.Module) –

Returns

If module is FrozenBatchNorm2d, returns a new module. Otherwise, in-place convert module and return it.

This is needed for quantization:

https://fb.workplace.com/groups/1043663463248667/permalink/1296330057982005/

training: bool
detectron2.layers.get_norm(norm, out_channels)[source]
Parameters

norm (str or callable) – either one of BN, SyncBN, FrozenBN, GN; or a callable that takes a channel number and returns the normalization layer as a nn.Module.

Returns

nn.Module or None – the normalization layer

class detectron2.layers.NaiveSyncBatchNorm(*args, stats_mode='', **kwargs)[source]

Bases: torch.nn.BatchNorm2d

In PyTorch<=1.5, nn.SyncBatchNorm has an incorrect gradient when the batch size on each worker is different (e.g., when scale augmentation is used, or when it is applied to the mask head).

This is a slower but correct alternative to nn.SyncBatchNorm.

Note

There isn’t a single definition of Sync BatchNorm.

When stats_mode=="", this module computes overall statistics by using statistics of each worker with equal weight. The result is true statistics of all samples (as if they are all on one worker) only when all workers have the same (N, H, W). This mode does not support inputs with zero batch size.

When stats_mode=="N", this module computes overall statistics by weighting the statistics of each worker by their N. The result is true statistics of all samples (as if they are all on one worker) only when all workers have the same (H, W). It is slower than stats_mode=="".

Even though the result of this module may not be the true statistics of all samples, it may still be reasonable because it might be preferable to assign equal weights to all workers, regardless of their (H, W) dimension, instead of putting larger weight on larger images. From preliminary experiments, little difference is found between such a simplified implementation and an accurate computation of overall mean & variance.
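
The difference between the two stats_mode settings can be illustrated with plain Python on per-worker means. The function names and the toy data below are hypothetical; the sketch assumes all workers share the same (H, W), so that only N differs, and it covers the mean only, not the variance.

```python
def equal_weight_mean(worker_means):
    """Combine per-worker means with equal weight (like stats_mode == "")."""
    return sum(worker_means) / len(worker_means)

def n_weighted_mean(worker_means, worker_batch_sizes):
    """Combine per-worker means weighted by batch size (like stats_mode == "N")."""
    total = sum(worker_batch_sizes)
    return sum(m * n for m, n in zip(worker_means, worker_batch_sizes)) / total

# Two workers with different batch sizes N (1 and 3), one value per sample.
worker_samples = [[0.0], [4.0, 4.0, 4.0]]
worker_means = [sum(s) / len(s) for s in worker_samples]
worker_sizes = [len(s) for s in worker_samples]

true_mean = sum(sum(s) for s in worker_samples) / sum(worker_sizes)
equal = equal_weight_mean(worker_means)
weighted = n_weighted_mean(worker_means, worker_sizes)
```

Here the equal-weight result is 2.0 while the N-weighted result and the true mean over all samples are both 3.0, matching the statement that stats_mode=="N" is exact when only N differs between workers.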

forward(input)[source]
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
class detectron2.layers.CycleBatchNormList(length: int, bn_class=<class 'torch.nn.BatchNorm2d'>, **kwargs)[source]

Bases: torch.nn.ModuleList

Implement domain-specific BatchNorm by cycling.

When a BatchNorm layer is used for multiple input domains or input features, it might need to maintain separate test-time statistics for each domain. See Sec 5.2 in Rethinking “Batch” in BatchNorm.

This module implements it by using N separate BN layers and it cycles through them every time a forward() is called.

NOTE: The caller of this module MUST guarantee that it is always called a multiple of N times. Otherwise its test-time statistics will be incorrect.
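
The cycling behaviour and the reason for the multiple-of-N requirement can be sketched in plain Python. RunningMean and CycleList are hypothetical stand-ins: the first plays the role of one BN layer by tracking the mean of its inputs, the second advances to the next layer on every call.

```python
class RunningMean:
    """Hypothetical stand-in for one BN layer: tracks the mean of its inputs."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, x):
        self.total += x
        self.count += 1
        return x

    @property
    def mean(self):
        return self.total / self.count

class CycleList:
    """Cycle through N layers, advancing one position per call."""
    def __init__(self, length):
        self.layers = [RunningMean() for _ in range(length)]
        self._pos = 0

    def __call__(self, x):
        out = self.layers[self._pos](x)
        self._pos = (self._pos + 1) % len(self.layers)
        return out

cycle = CycleList(length=2)
# Two domains are fed alternately; 6 calls in total, a multiple of N = 2.
for a, b in [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]:
    cycle(a)
    cycle(b)
domain_means = [layer.mean for layer in cycle.layers]
```

Each layer only ever sees one domain here. If the caller made an odd number of calls, the next round would start at the wrong position and the two domains would be mixed.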

__init__(length: int, bn_class=<class 'torch.nn.BatchNorm2d'>, **kwargs)[source]
Parameters
  • length – number of BatchNorm layers to cycle.

  • bn_class – the BatchNorm class to use

  • kwargs – arguments of the BatchNorm class, such as num_features.

forward(x)[source]
extra_repr()[source]
training: bool
class detectron2.layers.DeformConv(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=False, norm=None, activation=None)[source]

Bases: torch.nn.Module

__init__(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=False, norm=None, activation=None)[source]

Deformable convolution from Deformable Convolutional Networks.

Arguments are similar to Conv2D. Extra arguments:

Parameters
  • deformable_groups (int) – number of groups used in deformable convolution.

  • norm (nn.Module, optional) – a normalization layer

  • activation (callable(Tensor) -> Tensor) – a callable activation function

forward(x, offset)[source]
extra_repr()[source]
training: bool
class detectron2.layers.ModulatedDeformConv(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=True, norm=None, activation=None)[source]

Bases: torch.nn.Module

__init__(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=True, norm=None, activation=None)[source]

Modulated deformable convolution from Deformable ConvNets v2: More Deformable, Better Results.

Arguments are similar to Conv2D. Extra arguments:

Parameters
  • deformable_groups (int) – number of groups used in deformable convolution.

  • norm (nn.Module, optional) – a normalization layer

  • activation (callable(Tensor) -> Tensor) – a callable activation function

forward(x, offset, mask)[source]
extra_repr()[source]
training: bool
detectron2.layers.paste_masks_in_image(masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5)[source]

Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image. The location, height, and width for pasting each mask are determined by their corresponding bounding boxes in boxes.

Note

This is a complicated but more accurate implementation. In actual deployment, it is often enough to use a faster but less accurate implementation. See paste_mask_in_image_old() in this file for an alternative implementation.

Parameters
  • masks (tensor) – Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of detected object instances in the image and Hmask, Wmask are the mask height and mask width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].

  • boxes (Boxes or Tensor) – A Boxes of length Bimg or Tensor of shape (Bimg, 4). boxes[i] and masks[i] correspond to the same object instance.

  • image_shape (tuple) – height, width

  • threshold (float) – A threshold in [0, 1] for converting the (soft) masks to binary masks.

Returns

img_masks (Tensor) – A tensor of shape (Bimg, Himage, Wimage), where Bimg is the number of detected object instances and Himage, Wimage are the image height and width. img_masks[i] is a binary mask for object instance i.

detectron2.layers.nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) → torch.Tensor[source]

Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).

NMS iteratively removes lower scoring boxes which have an IoU greater than iou_threshold with another (higher scoring) box.

If multiple boxes have the exact same score and satisfy the IoU criterion with respect to a reference box, the selected box is not guaranteed to be the same between CPU and GPU. This is similar to the behavior of argsort in PyTorch when repeated values are present.

Parameters
  • boxes (Tensor[N, 4]) – boxes to perform NMS on. They are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.

  • scores (Tensor[N]) – scores for each one of the boxes

  • iou_threshold (float) – discards all overlapping boxes with IoU > iou_threshold

Returns

keep (Tensor) – int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores
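
The greedy procedure described above can be sketched in plain Python on axis-aligned boxes. The function names iou and greedy_nms are hypothetical, and this is a readable reference sketch under those assumptions, not the optimized implementation behind detectron2.layers.nms.

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def greedy_nms(boxes, scores, iou_threshold):
    """Return indices of kept boxes, sorted by decreasing score."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap any higher-scoring kept box
        # by more than the threshold.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep

boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
keep = greedy_nms(boxes, scores, iou_threshold=0.5)
```

The second box overlaps the first with IoU 81/119 (about 0.68), so it is discarded, while the third box does not overlap anything and is kept.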

detectron2.layers.batched_nms(boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float)[source]

Same as torchvision.ops.boxes.batched_nms, but with float().

detectron2.layers.batched_nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float)[source]

Performs non-maximum suppression in a batched fashion.

Each index value corresponds to a category, and NMS will not be applied between elements of different categories.

Parameters
  • boxes (Tensor[N, 5]) – boxes where NMS will be performed. They are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format

  • scores (Tensor[N]) – scores for each one of the boxes

  • idxs (Tensor[N]) – indices of the categories for each one of the boxes.

  • iou_threshold (float) – discards all overlapping boxes with IoU > iou_threshold

Returns

Tensor – int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores

detectron2.layers.nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float)[source]

Performs non-maximum suppression (NMS) on the rotated boxes according to their intersection-over-union (IoU).

Rotated NMS iteratively removes lower scoring rotated boxes which have an IoU greater than iou_threshold with another (higher scoring) rotated box.

Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they can be representing completely different objects in certain tasks, e.g., OCR.

As for the question of whether rotated NMS should treat them as faraway boxes even though their IoU is 1, it depends on the application and/or ground truth annotation.

As an extreme example, consider a single character v and the square box around it.

If the angle is 0 degree, the object (text) would be read as ‘v’;

If the angle is 90 degrees, the object (text) would become ‘>’;

If the angle is 180 degrees, the object (text) would become ‘^’;

If the angle is 270/-90 degrees, the object (text) would become ‘<’

All of these cases have an IoU of 1 to each other, and rotated NMS that only uses IoU as its criterion would keep only the one with the highest score, which, practically, still makes sense in most cases because typically only one of these orientations is the correct one. Also, it does not matter as much if the box is only used to classify the object (instead of transcribing it with a sequential OCR recognition model) later.

On the other hand, when we use IoU to filter proposals that are close to the ground truth during training, we should definitely take the angle into account if we know the ground truth is labeled with the strictly correct orientation (as in, upside-down words are annotated with -180 degrees even though they can be covered with a 0/90/-90 degree box, etc.)

The way the original dataset is annotated also matters. For example, if the dataset is a 4-point polygon dataset that does not enforce ordering of vertices/orientation, we can estimate a minimum rotated bounding box for this polygon, but there’s no way we can tell the correct angle with 100% confidence (as shown above, there could be 4 different rotated boxes, with angles differing by 90 degrees from each other, covering exactly the same region). In that case we have to just use IoU to determine the box proximity (as many detection benchmarks, even for text, do) unless there are other assumptions we can make (like width is always larger than height, or the object is not rotated by more than 90 degrees CCW/CW, etc.)

In summary, not considering angles in rotated NMS seems to be a good option for now, but we should be aware of its implications.
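
The claim that RotatedBox (5, 3, 4, 2, -90) and RotatedBox (5, 3, 4, 2, 90) cover the same region can be checked by computing corner points in plain Python. The helper rotated_corners is hypothetical, and the rotation sign convention used here is an assumption; the comparison below does not depend on it, because the two angles differ by 180 degrees.

```python
import math

def rotated_corners(x_ctr, y_ctr, width, height, angle_degrees):
    """Return the sorted corner points of a rotated box (rounded for comparison)."""
    theta = math.radians(angle_degrees)
    c, s = math.cos(theta), math.sin(theta)
    half_w, half_h = width / 2.0, height / 2.0
    corners = []
    for dx, dy in ((-half_w, -half_h), (half_w, -half_h),
                   (half_w, half_h), (-half_w, half_h)):
        # Rotate the offset around the center, then translate.
        x = x_ctr + c * dx - s * dy
        y = y_ctr + s * dx + c * dy
        corners.append((round(x, 6), round(y, 6)))
    return sorted(corners)

minus_90 = rotated_corners(5, 3, 4, 2, -90)
plus_90 = rotated_corners(5, 3, 4, 2, 90)
zero = rotated_corners(5, 3, 4, 2, 0)
```

The -90 and 90 degree boxes share the same four corners, so an IoU-only criterion cannot tell them apart, whereas the 0 degree box of the same width and height covers a different region.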

Parameters
  • boxes (Tensor[N, 5]) – Rotated boxes to perform NMS on. They are expected to be in (x_center, y_center, width, height, angle_degrees) format.

  • scores (Tensor[N]) – Scores for each one of the rotated boxes

  • iou_threshold (float) – Discards all overlapping rotated boxes with IoU > iou_threshold

Returns

keep (Tensor) – int64 tensor with the indices of the elements that have been kept by Rotated NMS, sorted in decreasing order of scores

detectron2.layers.roi_align(input: torch.Tensor, boxes: torch.Tensor, output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) → torch.Tensor[source]

Performs Region of Interest (RoI) Align operator described in Mask R-CNN

Parameters
  • input (Tensor[N, C, H, W]) – input tensor

  • boxes (Tensor[K, 5] or List[Tensor[L, 4]]) – the box coordinates in (x1, y1, x2, y2) format where the regions will be taken from. The coordinates must satisfy 0 <= x1 < x2 and 0 <= y1 < y2. If a single Tensor is passed, then the first column should contain the batch index. If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i in a batch

  • output_size (int or Tuple[int, int]) – the size of the output after the cropping is performed, as (height, width)

  • spatial_scale (float) – a scaling factor that maps the input coordinates to the box coordinates. Default: 1.0

  • sampling_ratio (int) – number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. If > 0, then exactly sampling_ratio x sampling_ratio grid points are used. If <= 0, then an adaptive number of grid points are used (computed as ceil(roi_width / pooled_w), and likewise for height). Default: -1

  • aligned (bool) – If False, use the legacy implementation. If True, shift the box coordinates by -0.5 pixels so that they align better with the two neighboring pixel indices. This is the version used in Detectron2.

Returns

output (Tensor[K, C, output_size[0], output_size[1]])
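
The sampling_ratio rule above can be written out as a small plain-Python helper. The name sampling_grid and the example sizes are hypothetical; the sketch only reproduces the grid-size rule stated in the parameter description, not the interpolation itself.

```python
import math

def sampling_grid(roi_height, roi_width, pooled_h, pooled_w, sampling_ratio=-1):
    """Return (grid_h, grid_w), the sampling points used per output bin."""
    if sampling_ratio > 0:
        # Fixed grid: exactly sampling_ratio x sampling_ratio points.
        return sampling_ratio, sampling_ratio
    # Adaptive grid: depends on the RoI size relative to the pooled size.
    return math.ceil(roi_height / pooled_h), math.ceil(roi_width / pooled_w)

fixed = sampling_grid(30.0, 45.0, 7, 7, sampling_ratio=2)
adaptive = sampling_grid(30.0, 45.0, 7, 7)
```

For a 30 x 45 RoI pooled to 7 x 7, the adaptive rule yields a 5 x 7 grid per bin, so larger RoIs are sampled more densely than with a small fixed ratio.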

class detectron2.layers.ROIAlign(output_size, spatial_scale, sampling_ratio, aligned=True)[source]

Bases: torch.nn.Module

__init__(output_size, spatial_scale, sampling_ratio, aligned=True)[source]
Parameters
  • output_size (tuple) – h, w

  • spatial_scale (float) – scale the input boxes by this number

  • sampling_ratio (int) – number of input samples to take for each output sample. 0 to take samples densely.

  • aligned (bool) – if False, use the legacy implementation in Detectron. If True, align the results more perfectly.

Note

The meaning of aligned=True:

Given a continuous coordinate c, its two neighboring pixel indices (in our pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled from the underlying signal at continuous coordinates 0.5 and 1.5). But the original roi_align (aligned=False) does not subtract the 0.5 when computing neighboring pixel indices and therefore it uses pixels with a slightly incorrect alignment (relative to our pixel model) when performing bilinear interpolation.

With aligned=True, we first appropriately scale the ROI and then shift it by -0.5 prior to calling roi_align. This produces the correct neighbors; see detectron2/tests/test_roi_align.py for verification.

This difference does not affect the model’s performance if ROIAlign is used together with conv layers.
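
The neighbor computation from the note above can be reproduced in a few lines of plain Python. The function neighbor_indices is a hypothetical helper that only covers the index arithmetic, not the bilinear interpolation.

```python
import math

def neighbor_indices(c, aligned):
    """Return the two neighboring pixel indices of continuous coordinate c."""
    # With aligned=True the coordinate is shifted by -0.5 first, because pixel
    # index i is sampled at continuous coordinate i + 0.5.
    shifted = c - 0.5 if aligned else c
    return math.floor(shifted), math.ceil(shifted)

aligned_idx = neighbor_indices(1.3, aligned=True)
legacy_idx = neighbor_indices(1.3, aligned=False)
```

For c=1.3 the aligned variant picks indices 0 and 1 as in the note, while the legacy variant picks 1 and 2, which is the half-pixel misalignment being described.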

forward(input, rois)[source]
Parameters
  • input – NCHW images

  • rois – Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.

training: bool
detectron2.layers.roi_align_rotated()
class detectron2.layers.ROIAlignRotated(output_size, spatial_scale, sampling_ratio)[source]

Bases: torch.nn.Module

__init__(output_size, spatial_scale, sampling_ratio)[source]
Parameters
  • output_size (tuple) – h, w

  • spatial_scale (float) – scale the input boxes by this number

  • sampling_ratio (int) – number of input samples to take for each output sample. 0 to take samples densely.

Note

ROIAlignRotated supports continuous coordinate by default: Given a continuous coordinate c, its two neighboring pixel indices (in our pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled from the underlying signal at continuous coordinates 0.5 and 1.5).

forward(input, rois)[source]
Parameters
  • input – NCHW images

  • rois – Bx6 boxes. First column is the index into N. The other 5 columns are (x_ctr, y_ctr, width, height, angle_degrees).

training: bool
class detectron2.layers.ShapeSpec(channels: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, stride: Optional[int] = None)[source]

Bases: object

A simple structure that contains basic shape specification about a tensor. It is often used as the auxiliary inputs/outputs of models, to complement the lack of shape inference ability among pytorch modules.

channels: Optional[int] = None
height: Optional[int] = None
width: Optional[int] = None
stride: Optional[int] = None
class detectron2.layers.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

Bases: torch.nn.modules.batchnorm._BatchNorm

Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the input size). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False).

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.

If track_running_stats is set to False, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well.

Note

This momentum argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.

Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) slices, it’s common terminology to call this Spatial Batch Normalization.
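
The running-statistics update rule from the note above can be traced with plain Python scalars. The function name update_running_stat and the batch values are hypothetical; the formula is the one given in the note.

```python
def update_running_stat(running, observed, momentum=0.1):
    """Apply x_new = (1 - momentum) * x_hat + momentum * x_t."""
    return (1 - momentum) * running + momentum * observed

running_mean = 0.0  # initial value of the running estimate
history = []
for batch_mean in [1.0, 1.0, 1.0]:
    running_mean = update_running_stat(running_mean, batch_mean)
    history.append(running_mean)
```

With momentum 0.1 the estimate moves only 10% of the way toward each new observation (roughly 0.1, 0.19, 0.271 here), which is why this momentum is the complement of the usual exponential-moving-average decay.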

Parameters
  • num_features – \(C\) from an expected input of size \((N, C, H, W)\)

  • eps – a value added to the denominator for numerical stability. Default: 1e-5

  • momentum – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

  • affine – a boolean value that when set to True, this module has learnable affine parameters. Default: True

  • track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Default: True

Shape:
  • Input: \((N, C, H, W)\)

  • Output: \((N, C, H, W)\) (same shape as input)

Examples:

>>> # With Learnable Parameters
>>> m = nn.BatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm2d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45)
>>> output = m(input)
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
class detectron2.layers.Conv2d(*args, **kwargs)[source]

Bases: torch.nn.Conv2d

A wrapper around torch.nn.Conv2d to support empty inputs and more features.

__init__(*args, **kwargs)[source]

Extra keyword arguments supported in addition to those in torch.nn.Conv2d:

Parameters
  • norm (nn.Module, optional) – a normalization layer

  • activation (callable(Tensor) -> Tensor) – a callable activation function

It assumes that norm layer is used before activation.

forward(x)[source]
bias: Optional[torch.Tensor]
out_channels: int
kernel_size: Tuple[int, ...]
stride: Tuple[int, ...]
padding: Tuple[int, ...]
dilation: Tuple[int, ...]
transposed: bool
output_padding: Tuple[int, ...]
groups: int
padding_mode: str
weight: torch.Tensor
class detectron2.layers.ConvTranspose2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, bias: bool = True, dilation: int = 1, padding_mode: str = 'zeros')

Bases: torch.nn.modules.conv._ConvTransposeNd

Applies a 2D transposed convolution operator over an input image composed of several input planes.

This module can be seen as the gradient of Conv2d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation).

This module supports TensorFloat32.

  • stride controls the stride for the cross-correlation.

  • padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.

  • output_padding controls the additional size added to one side of the output shape. See note below for details.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.

  • groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups=in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out\_channels}}{\text{in\_channels}}\)).

The parameters kernel_size, stride, padding, output_padding can either be:

  • a single int – in which case the same value is used for the height and width dimensions

  • a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension

Note

The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a Conv2d and a ConvTranspose2d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, Conv2d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output.

Note

In some circumstances when given tensors on a CUDA device and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See /notes/randomness for more information.

Parameters
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution. Default: 1

  • padding (int or tuple, optional) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of each dimension in the input. Default: 0

  • output_padding (int or tuple, optional) – Additional size added to one side of each dimension in the output shape. Default: 0

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

  • dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1

Shape:
  • Input: \((N, C_{in}, H_{in}, W_{in})\)

  • Output: \((N, C_{out}, H_{out}, W_{out})\) where

\[H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1\]
\[W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1\]
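
The output-shape formula above, and the ambiguity that output_padding resolves, can be checked with integer arithmetic in plain Python. The helper names are hypothetical; the Conv2d size formula is not part of this page and is taken from the torch.nn.Conv2d documentation.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size of Conv2d along one dimension (torch.nn.Conv2d formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output size of ConvTranspose2d along one dimension (formula above)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# With stride 2, inputs of size 11 and 12 both map to size 6.
down_11 = conv2d_out(11, kernel_size=3, stride=2, padding=1)
down_12 = conv2d_out(12, kernel_size=3, stride=2, padding=1)

# The transposed convolution needs output_padding to choose between them.
up_default = conv_transpose2d_out(6, kernel_size=3, stride=2, padding=1)
up_padded = conv_transpose2d_out(6, kernel_size=3, stride=2, padding=1,
                                 output_padding=1)
```

Going back from size 6 gives 11 by default and 12 with output_padding=1.
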
weight

the learnable weights of the module of shape \((\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}}, \text{kernel\_size}[0], \text{kernel\_size}[1])\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}\)

Type

Tensor

bias

the learnable bias of the module of shape (out_channels). If bias is True, then the values of the bias are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}\)

Type

Tensor

Examples:

>>> # With square kernels and equal stride
>>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> input = torch.randn(20, 16, 50, 100)
>>> output = m(input)
>>> # exact output size can be also specified as an argument
>>> input = torch.randn(1, 16, 12, 12)
>>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
>>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12])
forward(input: torch.Tensor, output_size: Optional[List[int]] = None) → torch.Tensor
bias: Optional[torch.Tensor]
out_channels: int
kernel_size: Tuple[int, ...]
stride: Tuple[int, ...]
padding: Tuple[int, ...]
dilation: Tuple[int, ...]
transposed: bool
output_padding: Tuple[int, ...]
groups: int
padding_mode: str
weight: torch.Tensor
detectron2.layers.cat(tensors: List[torch.Tensor], dim: int = 0)[source]

Efficient version of torch.cat that avoids a copy if there is only a single element in the list.

detectron2.layers.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)[source]

Down/up samples the input to either the given size or the given scale_factor

The algorithm used for interpolation is determined by mode.

Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape.

The input dimensions are interpreted in the form: mini-batch x channels x [optional depth] x [optional height] x width.

The modes available for resizing are: nearest, linear (3D-only), bilinear, bicubic (4D-only), trilinear (5D-only), area

Parameters
  • input (Tensor) – the input tensor

  • size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]) – output spatial size.

  • scale_factor (float or Tuple[float]) – multiplier for spatial size. Has to match input size if it is a tuple.

  • mode (str) – algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'. Default: 'nearest'

  • align_corners (bool, optional) – Geometrically, we consider the pixels of the input and output as squares rather than points. If set to True, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels. If set to False, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values, making this operation independent of input size when scale_factor is kept the same. This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: False

  • recompute_scale_factor (bool, optional) – recompute the scale_factor for use in the interpolation calculation. When scale_factor is passed as a parameter, it is used to compute the output_size. If recompute_scale_factor is False or not specified, the passed-in scale_factor will be used in the interpolation computation. Otherwise, a new scale_factor will be computed based on the output and input sizes for use in the interpolation computation (i.e. the computation will be identical to if the computed output_size were passed-in explicitly). Note that when scale_factor is floating-point, the recomputed scale_factor may differ from the one passed in due to rounding and precision issues.
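
The effect of align_corners can be illustrated in one dimension by computing, for each output index, the source coordinate it samples from. The function source_coord is a hypothetical helper, and the two mapping formulas are assumptions based on the usual definitions of the two alignment modes; the sketch omits the clamping of out-of-boundary coordinates (edge value padding) mentioned above.

```python
def source_coord(i, in_size, out_size, align_corners):
    """Map output index i to a continuous source index (1-D, linear modes)."""
    if align_corners:
        # Centers of the corner pixels coincide: index 0 -> 0, last -> last.
        if out_size == 1:
            return 0.0
        return i * (in_size - 1) / (out_size - 1)
    # Corner points of the corner pixels coincide: pixel centers sit at
    # i + 0.5, so scale in that frame and shift back by 0.5.
    return (i + 0.5) * in_size / out_size - 0.5

in_size, out_size = 4, 8
corners_true = [source_coord(i, in_size, out_size, True) for i in range(out_size)]
corners_false = [source_coord(i, in_size, out_size, False) for i in range(out_size)]
```

With align_corners=True the first and last outputs land exactly on source indices 0 and 3, so corner values are preserved. With align_corners=False they fall at -0.25 and 3.25, outside the valid range, which is where edge value padding applies.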

Note

With mode='bicubic', it’s possible to cause overshoot, in other words it can produce negative values or values greater than 255 for images. Explicitly call result.clamp(min=0, max=255) if you want to reduce the overshoot when displaying the image.

Warning

With align_corners = True, the linearly interpolating modes (linear, bilinear, and trilinear) don’t proportionally align the output and input pixels, and thus the output values can depend on the input size. This was the default behavior for these modes up to version 0.3.1. Since then, the default behavior is align_corners = False. See Upsample for concrete examples on how this affects the outputs.

Warning

When scale_factor is specified, if recompute_scale_factor=True, scale_factor is used to compute the output_size which will then be used to infer new scales for the interpolation. The default behavior for recompute_scale_factor changed to False in 1.6.0, and scale_factor is used in the interpolation calculation.
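
The difference between the passed-in and the recomputed scale factor can be shown with plain arithmetic. The helper names are hypothetical, and the sketch assumes the output size is floor(input_size * scale_factor), as in the PyTorch Upsample documentation.

```python
import math

def output_size(in_size, scale_factor):
    """Output size along one dimension, assuming floor(in_size * scale_factor)."""
    return math.floor(in_size * scale_factor)

def effective_scale(in_size, scale_factor, recompute_scale_factor=False):
    """Scale used in the interpolation computation."""
    if recompute_scale_factor:
        # Infer the scale from the integer output size instead.
        return output_size(in_size, scale_factor) / in_size
    return scale_factor

out = output_size(5, 1.5)
passed_in = effective_scale(5, 1.5)
recomputed = effective_scale(5, 1.5, recompute_scale_factor=True)
```

For an input of size 5 and scale_factor 1.5 the output size is 7, so the recomputed scale is 1.4 rather than 1.5, and the two settings sample the input at slightly different positions.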

Note

This operation may produce nondeterministic gradients when given tensors on a CUDA device. See /notes/randomness for more information.

class detectron2.layers.Linear(in_features: int, out_features: int, bias: bool = True)

Bases: torch.nn.Module

Applies a linear transformation to the incoming data: \(y = xA^T + b\)

This module supports TensorFloat32.

Parameters
  • in_features – size of each input sample

  • out_features – size of each output sample

  • bias – If set to False, the layer will not learn an additive bias. Default: True

Shape:
  • Input: \((N, *, H_{in})\) where \(*\) means any number of additional dimensions and \(H_{in} = \text{in\_features}\)

  • Output: \((N, *, H_{out})\) where all but the last dimension are the same shape as the input and \(H_{out} = \text{out\_features}\).

weight

the learnable weights of the module of shape \((\text{out\_features}, \text{in\_features})\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{in\_features}}\)

bias

the learnable bias of the module of shape \((\text{out\_features})\). If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{in\_features}}\)

Examples:

>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
extra_repr() → str
forward(input: torch.Tensor) → torch.Tensor
reset_parameters() → None
in_features: int
out_features: int
weight: torch.Tensor
detectron2.layers.nonzero_tuple(x)[source]

An ‘as_tuple=True’ version of torch.nonzero that supports TorchScript. It is needed because of https://github.com/pytorch/pytorch/issues/38718

detectron2.layers.cross_entropy(input, target, *, reduction='mean', **kwargs)

Same as the wrapped loss_func (here, torch.nn.functional.cross_entropy), but returns 0 (instead of nan) for empty inputs.

detectron2.layers.empty_input_loss_func_wrapper(loss_func)[source]
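
The wrapping pattern described for cross_entropy above can be sketched without torch. This is an assumption-laden illustration: `mean_nll` is a hypothetical stand-in loss operating on Python lists, and `empty_input_wrapper` only mimics the idea of short-circuiting empty inputs under mean reduction. The real wrapper operates on tensors and may also keep the result attached to the autograd graph.

```python
import math


def mean_nll(probs, targets, *, reduction="mean"):
    """Toy negative log-likelihood over lists; yields nan for an empty mean."""
    losses = [-math.log(p[t]) for p, t in zip(probs, targets)]
    if reduction == "none":
        return losses
    total = sum(losses)
    if reduction == "sum":
        return total
    return total / len(losses) if losses else float("nan")


def empty_input_wrapper(loss_func):
    """Wrap loss_func so that an empty input with mean reduction gives 0."""

    def wrapped(input, target, *, reduction="mean", **kwargs):
        if len(target) == 0 and reduction == "mean":
            return 0.0
        return loss_func(input, target, reduction=reduction, **kwargs)

    return wrapped


safe_nll = empty_input_wrapper(mean_nll)
print(mean_nll([], []))   # nan
print(safe_nll([], []))   # 0.0
```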
detectron2.layers.shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) → torch.Tensor[source]

Turn a list of integer scalars or integer Tensor scalars into a vector, in a way that’s both traceable and scriptable.

When tracing, x should be a list of scalar Tensors, so that the output can be traced back to the inputs. When scripting or running in eager mode, x should be a list of int.

detectron2.layers.move_device_like(src: torch.Tensor, dst: torch.Tensor) → torch.Tensor[source]

Tracing-friendly way to move a tensor to another tensor’s device. The device is treated as a constant during tracing; scripting the casting process as a whole works around this issue.

class detectron2.layers.CNNBlockBase(in_channels, out_channels, stride)[source]

Bases: torch.nn.Module

A CNN block is assumed to have input channels, output channels and a stride. The input and output of forward() method must be NCHW tensors. The method can perform arbitrary computation but must match the given channels and stride specification.

Attributes
  • in_channels (int) –

  • out_channels (int) –

  • stride (int) –

__init__(in_channels, out_channels, stride)[source]

The __init__ method of any subclass should also contain these arguments.

Parameters
  • in_channels (int) –

  • out_channels (int) –

  • stride (int) –

freeze()[source]

Make this block not trainable. This method sets requires_grad=False on all parameters and converts all BatchNorm layers to FrozenBatchNorm.

Returns

the block itself

training: bool
class detectron2.layers.DepthwiseSeparableConv2d(in_channels, out_channels, kernel_size=3, padding=1, dilation=1, *, norm1=None, activation1=None, norm2=None, activation2=None)[source]

Bases: torch.nn.Module

A kxk depthwise convolution + a 1x1 convolution.

In Xception: Deep Learning with Depthwise Separable Convolutions, norm & activation are applied on the second conv. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications uses norm & activation on both convs.
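
To give a sense of why this factorization is used, the following sketch compares weight counts for a regular kxk convolution and a depthwise separable one. The helper names are hypothetical, and biases, normalization and activation parameters are ignored.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Weights in a regular kxk convolution."""
    return kernel_size * kernel_size * in_channels * out_channels


def depthwise_separable_weights(in_channels, out_channels, kernel_size):
    """Weights in a kxk depthwise conv followed by a 1x1 conv."""
    depthwise = kernel_size * kernel_size * in_channels  # one kxk filter per input channel
    pointwise = in_channels * out_channels  # 1x1 conv mixes channels
    return depthwise + pointwise


regular = conv_weights(256, 256, 3)
separable = depthwise_separable_weights(256, 256, 3)
print(regular, separable, round(regular / separable, 2))
```

For 256 input and output channels with a 3x3 kernel, the regular convolution has 589824 weights and the separable version has 67840.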

__init__(in_channels, out_channels, kernel_size=3, padding=1, dilation=1, *, norm1=None, activation1=None, norm2=None, activation2=None)[source]
Parameters
  • norm1 (str or callable) – normalization for the first (depthwise) conv layer.

  • norm2 (str or callable) – normalization for the second (1x1) conv layer.

  • activation1 (callable(Tensor) -> Tensor) – activation function for the first (depthwise) conv layer.

  • activation2 (callable(Tensor) -> Tensor) – activation function for the second (1x1) conv layer.

forward(x)[source]
training: bool
class detectron2.layers.ASPP(in_channels, out_channels, dilations, *, norm, activation, pool_kernel_size=None, dropout: float = 0.0, use_depthwise_separable_conv=False)[source]

Bases: torch.nn.Module

Atrous Spatial Pyramid Pooling (ASPP).

__init__(in_channels, out_channels, dilations, *, norm, activation, pool_kernel_size=None, dropout: float = 0.0, use_depthwise_separable_conv=False)[source]
Parameters
  • in_channels (int) – number of input channels for ASPP.

  • out_channels (int) – number of output channels.

  • dilations (list) – a list of 3 dilations in ASPP.

  • norm (str or callable) – normalization for all conv layers. See layers.get_norm() for supported format. norm is applied to all conv layers except the conv following global average pooling.

  • activation (callable) – activation function.

  • pool_kernel_size (tuple, list) – the average pooling size (kh, kw) for image pooling layer in ASPP. If set to None, it always performs global average pooling. If not None, the spatial shape of the inputs in forward() must be divisible by it. It is recommended to use a fixed input feature size in training, and set this option to match this size, so that it performs global average pooling in training, and the size of the pooling window stays consistent in inference.

  • dropout (float) – apply dropout on the output of ASPP. It is used in the official DeepLab implementation with a rate of 0.1: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532

  • use_depthwise_separable_conv (bool) – use DepthwiseSeparableConv2d for 3x3 convs in ASPP, proposed in Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
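
The constraint on pool_kernel_size can be sketched as a simple check. This is a minimal illustration, assuming the constraint means that each spatial dimension of the input is a whole multiple of the corresponding kernel dimension; the helper name `input_is_compatible` is hypothetical and is not part of the library.

```python
def input_is_compatible(input_hw, pool_kernel_size):
    """Return True if an input of spatial size (H, W) fits pool_kernel_size."""
    if pool_kernel_size is None:
        # Global average pooling accepts any input size.
        return True
    h, w = input_hw
    kh, kw = pool_kernel_size
    return h % kh == 0 and w % kw == 0


# Training with a fixed 32x64 feature map and a matching pooling window.
print(input_is_compatible((32, 64), (32, 64)))   # True: window covers the whole map
# Inference on a feature map twice as large in each dimension.
print(input_is_compatible((64, 128), (32, 64)))  # True
print(input_is_compatible((33, 64), (32, 64)))   # False
```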

forward(x)[source]
training: bool
detectron2.layers.ciou_loss(boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = 'none', eps: float = 1e-07) → torch.Tensor[source]

Complete Intersection over Union Loss (Zhaohui Zheng et al.) https://arxiv.org/abs/1911.08287

Parameters
  • boxes1 (Tensor) – box locations in XYXY format, shape (N, 4) or (4,).

  • boxes2 (Tensor) – box locations in XYXY format, shape (N, 4) or (4,).

  • reduction – ‘none’ | ‘mean’ | ‘sum’. ‘none’: No reduction will be applied to the output. ‘mean’: The output will be averaged. ‘sum’: The output will be summed.

  • eps (float) – small number to prevent division by zero

detectron2.layers.diou_loss(boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = 'none', eps: float = 1e-07) → torch.Tensor[source]

Distance Intersection over Union Loss (Zhaohui Zheng et al.) https://arxiv.org/abs/1911.08287

Parameters
  • boxes1 (Tensor) – box locations in XYXY format, shape (N, 4) or (4,).

  • boxes2 (Tensor) – box locations in XYXY format, shape (N, 4) or (4,).

  • reduction – ‘none’ | ‘mean’ | ‘sum’. ‘none’: No reduction will be applied to the output. ‘mean’: The output will be averaged. ‘sum’: The output will be summed.

  • eps (float) – small number to prevent division by zero
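
The two losses can be illustrated for a single pair of XYXY boxes. The following is a torch-free sketch based on the formulas in the cited paper, not the library's vectorized implementation; the function names are hypothetical, and boxes are assumed to have positive width and height.

```python
import math


def _iou_and_center_penalty(box1, box2, eps=1e-7):
    """Return (IoU, squared center distance / squared enclosing-box diagonal)."""
    x1, y1, x2, y2 = box1
    x1g, y1g, x2g, y2g = box2
    # Intersection area (zero when the boxes do not overlap).
    iw = min(x2, x2g) - max(x1, x1g)
    ih = min(y2, y2g) - max(y1, y1g)
    inter = iw * ih if iw > 0 and ih > 0 else 0.0
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - inter + eps
    iou = inter / union
    # Squared diagonal of the smallest box enclosing both boxes.
    cw = max(x2, x2g) - min(x1, x1g)
    ch = max(y2, y2g) - min(y1, y1g)
    diag2 = cw ** 2 + ch ** 2 + eps
    # Squared distance between the two box centers.
    dist2 = ((x1 + x2) / 2 - (x1g + x2g) / 2) ** 2 + ((y1 + y2) / 2 - (y1g + y2g) / 2) ** 2
    return iou, dist2 / diag2


def diou_loss_single(box1, box2, eps=1e-7):
    """DIoU loss: 1 - IoU + normalized center distance."""
    iou, penalty = _iou_and_center_penalty(box1, box2, eps)
    return 1.0 - iou + penalty


def ciou_loss_single(box1, box2, eps=1e-7):
    """CIoU loss: DIoU loss plus an aspect-ratio consistency term alpha * v."""
    iou, penalty = _iou_and_center_penalty(box1, box2, eps)
    w, h = box1[2] - box1[0], box1[3] - box1[1]
    wg, hg = box2[2] - box2[0], box2[3] - box2[1]
    v = (4 / math.pi ** 2) * (math.atan(wg / hg) - math.atan(w / h)) ** 2
    alpha = v / (1.0 - iou + v + eps)
    return 1.0 - iou + penalty + alpha * v


print(diou_loss_single((0, 0, 1, 1), (2, 0, 3, 1)))
print(ciou_loss_single((0, 0, 2, 1), (0, 0, 1, 2)))
```

When the two boxes share the same aspect ratio, the extra CIoU term vanishes and both sketches return the same value.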