detectron2.layers package

class detectron2.layers.FrozenBatchNorm2d(num_features, eps=1e-05)[source]

Bases: torch.nn.modules.module.Module

BatchNorm2d where the batch statistics and the affine parameters are fixed.

It contains non-trainable buffers called “weight”, “bias”, “running_mean”, and “running_var”, initialized to perform an identity transformation.

The pre-trained backbone models from Caffe2 only contain “weight” and “bias”, which are computed from the original four parameters of BN. The affine transform x * weight + bias performs a computation equivalent to (x - running_mean) / sqrt(running_var) * weight + bias. When loading a backbone model from Caffe2, “running_mean” and “running_var” are left unchanged as an identity transformation.

Other pre-trained backbone models may contain all 4 parameters.

The forward is implemented by F.batch_norm(…, training=False).
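To make the equivalence concrete, here is a minimal sketch (the helper name is illustrative, not part of the API) of folding the four BN parameters into the two-buffer affine form:

    import torch

    def fold_bn(gamma, beta, running_mean, running_var, eps=1e-5):
        # (x - mean) / sqrt(var + eps) * gamma + beta  ==  x * scale + shift
        scale = gamma / torch.sqrt(running_var + eps)  # becomes the "weight" buffer
        shift = beta - running_mean * scale            # becomes the "bias" buffer
        return scale, shift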

forward(x)[source]
classmethod convert_frozen_batchnorm(module)[source]

Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

Parameters:
  • module (torch.nn.Module) –
Returns:
If module is BatchNorm/SyncBatchNorm, returns a new module. Otherwise, converts module in place and returns it.

Similar to convert_sync_batchnorm in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
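For example, freezing all BN layers in an existing model:

    import torch.nn as nn
    from detectron2.layers import FrozenBatchNorm2d

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    # recursively replaces every BatchNorm/SyncBatchNorm with FrozenBatchNorm2d
    model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)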

detectron2.layers.get_norm(norm, out_channels)[source]
Parameters:
  • norm (str or callable) –
Returns:
nn.Module or None – the normalization layer
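A usage sketch, assuming the usual string keys such as “BN”, “GN”, and “FrozenBN” are accepted:

    from detectron2.layers import get_norm

    gn = get_norm("GN", 64)     # a group-norm layer over 64 channels
    nothing = get_norm("", 64)  # an empty norm string is expected to yield None
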
class detectron2.layers.NaiveSyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]

Bases: detectron2.layers.wrappers.BatchNorm2d

torch.nn.SyncBatchNorm has known bugs: it produces significantly worse AP (and sometimes NaNs) when the batch size on each worker is quite different (e.g., when scale augmentation is used, or when it is applied to the mask head).

Use this implementation until nn.SyncBatchNorm is fixed. It is slower than nn.SyncBatchNorm.

forward(input)[source]
class detectron2.layers.DeformConv(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=False, norm=None, activation=None)[source]

Bases: torch.nn.modules.module.Module

__init__(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=False, norm=None, activation=None)[source]

Deformable convolution.

Arguments are similar to Conv2D. Extra arguments:

Parameters:
  • deformable_groups (int) – number of groups used in deformable convolution.
  • norm (nn.Module, optional) – a normalization layer
  • activation (callable(Tensor) -> Tensor) – a callable activation function
forward(x, offset)[source]
extra_repr()[source]
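A usage sketch (offset_pred is a hypothetical predictor layer; the underlying op generally requires a CUDA build):

    import torch
    import torch.nn as nn
    from detectron2.layers import DeformConv

    # offsets carry 2 coordinates per kernel tap:
    # 2 * deformable_groups * kH * kW channels
    deform = DeformConv(64, 64, kernel_size=3, padding=1).cuda()
    offset_pred = nn.Conv2d(64, 2 * 1 * 3 * 3, kernel_size=3, padding=1).cuda()

    x = torch.randn(2, 64, 32, 32, device="cuda")
    y = deform(x, offset_pred(x))  # (2, 64, 32, 32)
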
class detectron2.layers.ModulatedDeformConv(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=True, norm=None, activation=None)[source]

Bases: torch.nn.modules.module.Module

__init__(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=True, norm=None, activation=None)[source]

Modulated deformable convolution.

Arguments are similar to Conv2D. Extra arguments:

Parameters:
  • deformable_groups (int) – number of groups used in deformable convolution.
  • norm (nn.Module, optional) – a normalization layer
  • activation (callable(Tensor) -> Tensor) – a callable activation function
forward(x, offset, mask)[source]
extra_repr()[source]
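A similar sketch for the modulated variant, which additionally takes a mask (pred is a hypothetical layer emitting both offsets and masks):

    import torch
    import torch.nn as nn
    from detectron2.layers import ModulatedDeformConv

    mdconv = ModulatedDeformConv(64, 64, kernel_size=3, padding=1).cuda()
    # 2 * kH * kW offset channels plus kH * kW mask channels = 27 for a 3x3 kernel
    pred = nn.Conv2d(64, 27, kernel_size=3, padding=1).cuda()

    x = torch.randn(2, 64, 32, 32, device="cuda")
    offset, mask = torch.split(pred(x), [18, 9], dim=1)
    y = mdconv(x, offset, mask.sigmoid())  # (2, 64, 32, 32)
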
detectron2.layers.paste_masks_in_image(masks, boxes, image_shape, threshold=0.5)[source]

Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image. The location, height, and width for pasting each mask are determined by its corresponding bounding box in boxes.

Parameters:
  • masks (tensor) – Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of detected object instances in the image and Hmask, Wmask are the mask height and width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
  • boxes (Boxes) – A Boxes of length Bimg. boxes.tensor[i] and masks[i] correspond to the same object instance.
  • image_shape (tuple) – height, width
  • threshold (float) – A threshold in [0, 1] for converting the (soft) masks to binary masks.
Returns:

img_masks (Tensor) – A tensor of shape (Bimg, Himage, Wimage), where Bimg is the number of detected object instances and Himage, Wimage are the image height and width. img_masks[i] is a binary mask for object instance i.
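A usage sketch with made-up boxes and random soft masks:

    import torch
    from detectron2.layers import paste_masks_in_image
    from detectron2.structures import Boxes

    masks = torch.rand(2, 28, 28)  # two soft masks, values in [0, 1]
    boxes = Boxes(torch.tensor([[ 10.,  10.,  60.,  70.],
                                [100.,  40., 180., 120.]]))
    img_masks = paste_masks_in_image(masks, boxes, (200, 300), threshold=0.5)
    # img_masks has shape (2, 200, 300): one binary image-sized mask per instance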

detectron2.layers.nms(boxes, scores, iou_threshold)[source]

Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).

NMS iteratively removes lower scoring boxes which have an IoU greater than iou_threshold with another (higher scoring) box.

Parameters:
  • boxes (Tensor[N, 4]) – boxes to perform NMS on. They are expected to be in (x1, y1, x2, y2) format.
  • scores (Tensor[N]) – scores for each one of the boxes
  • iou_threshold (float) – discards all overlapping boxes with IoU > iou_threshold
Returns:

keep (Tensor) – int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores
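For example (values chosen so the second box overlaps the first with IoU ≈ 0.68):

    import torch
    from detectron2.layers import nms

    boxes = torch.tensor([[ 0.,  0., 10., 10.],
                          [ 1.,  1., 11., 11.],
                          [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep = nms(boxes, scores, iou_threshold=0.5)
    # the second box is suppressed by the first: keep == tensor([0, 2])
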
detectron2.layers.batched_nms(boxes, scores, idxs, iou_threshold)[source]

Same as torchvision.ops.boxes.batched_nms, but safer.
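A usage sketch showing that suppression only happens within a category:

    import torch
    from detectron2.layers import batched_nms

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.]])
    scores = torch.tensor([0.9, 0.8])
    idxs = torch.tensor([0, 1])  # two different categories
    keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
    # both boxes survive despite overlapping, because their categories differ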

detectron2.layers.batched_nms_rotated(boxes, scores, idxs, iou_threshold)[source]

Performs non-maximum suppression in a batched fashion.

Each index value corresponds to a category, and NMS will not be applied between elements of different categories.

Parameters:
  • boxes (Tensor[N, 5]) – boxes where NMS will be performed. They are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
  • scores (Tensor[N]) – scores for each one of the boxes
  • idxs (Tensor[N]) – indices of the categories for each one of the boxes.
  • iou_threshold (float) – discards all overlapping boxes with IoU > iou_threshold
Returns:

Tensor – int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores
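A usage sketch, analogous to batched_nms but with 5-column rotated boxes:

    import torch
    from detectron2.layers import batched_nms_rotated

    boxes = torch.tensor([[5., 5., 10., 10., 0.],
                          [5., 5., 10., 10., 0.]])
    scores = torch.tensor([0.9, 0.8])
    idxs = torch.tensor([0, 1])  # different categories: no cross-suppression
    keep = batched_nms_rotated(boxes, scores, idxs, iou_threshold=0.5)  # keeps both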

detectron2.layers.nms_rotated(boxes, scores, iou_threshold)[source]

Performs non-maximum suppression (NMS) on the rotated boxes according to their intersection-over-union (IoU).

Rotated NMS iteratively removes lower scoring rotated boxes which have an IoU greater than iou_threshold with another (higher scoring) rotated box.

Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they can represent completely different objects in certain tasks, e.g., OCR.

Whether rotated NMS should treat them as faraway boxes even though their IoU is 1 depends on the application and/or the ground truth annotation.

As an extreme example, consider a single character v and the square box around it.

If the angle is 0 degrees, the object (text) would be read as ‘v’;

If the angle is 90 degrees, the object (text) would become ‘>’;

If the angle is 180 degrees, the object (text) would become ‘^’;

If the angle is 270/-90 degrees, the object (text) would become ‘<’.

All of these cases have an IoU of 1 with each other, and a rotated NMS that only uses IoU as its criterion would keep only the one with the highest score - which, practically, still makes sense in most cases, because typically only one of these orientations is the correct one. It also matters less if the box is only used to classify the object later (instead of transcribing it with a sequential OCR recognition model).

On the other hand, when we use IoU to filter proposals that are close to the ground truth during training, we should definitely take the angle into account if we know the ground truth is labeled with the strictly correct orientation (as in, upside-down words are annotated with -180 degrees even though they can be covered with a 0/90/-90 degree box, etc.)

The way the original dataset is annotated also matters. For example, if the dataset is a 4-point polygon dataset that does not enforce vertex ordering or orientation, we can estimate a minimum rotated bounding box for the polygon, but there is no way to tell the correct angle with 100% confidence (as shown above, there could be 4 different rotated boxes, with angles differing by 90 degrees from each other, covering exactly the same region). In that case we have to use IoU alone to determine box proximity (as many detection benchmarks, even for text, do), unless there are other assumptions we can make (e.g., that the width is always larger than the height, or that the object is not rotated by more than 90 degrees CCW/CW, etc.).

In summary, not considering angles in rotated NMS seems to be a good option for now, but we should be aware of its implications.

Parameters:
  • boxes (Tensor[N, 5]) – Rotated boxes to perform NMS on. They are expected to be in (x_center, y_center, width, height, angle_degrees) format.
  • scores (Tensor[N]) – Scores for each one of the rotated boxes
  • iou_threshold (float) – Discards all overlapping rotated boxes with IoU > iou_threshold
Returns:

keep (Tensor) – int64 tensor with the indices of the elements that have been kept by Rotated NMS, sorted in decreasing order of scores
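For example, using the two equal-region boxes from the note above:

    import torch
    from detectron2.layers import nms_rotated

    boxes = torch.tensor([[ 5.,  3., 4., 2., -90.],
                          [ 5.,  3., 4., 2.,  90.],
                          [20., 20., 4., 2.,   0.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep = nms_rotated(boxes, scores, iou_threshold=0.5)
    # the first two boxes cover the same region (IoU 1), so only the
    # higher-scoring one survives: keep == tensor([0, 2])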

detectron2.layers.roi_align()
class detectron2.layers.ROIAlign(output_size, spatial_scale, sampling_ratio, aligned=True)[source]

Bases: torch.nn.modules.module.Module

__init__(output_size, spatial_scale, sampling_ratio, aligned=True)[source]
Parameters:
  • output_size (tuple) – h, w
  • spatial_scale (float) – scale the input boxes by this number
  • sampling_ratio (int) – number of input samples to take for each output sample. 0 to take samples densely.
  • aligned (bool) – if False, use the legacy implementation in Detectron. If True, align the results more precisely.

Note

The meaning of aligned=True:

Given a continuous coordinate c, its two neighboring pixel indices (in our pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled from the underlying signal at continuous coordinates 0.5 and 1.5). But the original roi_align (aligned=False) does not subtract the 0.5 when computing neighboring pixel indices and therefore it uses pixels with a slightly incorrect alignment (relative to our pixel model) when performing bilinear interpolation.

With aligned=True, we first appropriately scale the ROI and then shift it by -0.5 prior to calling roi_align. This produces the correct neighbors; see detectron2/tests/test_roi_align.py for verification.

This difference does not affect the model’s performance if ROIAlign is used together with conv layers.

forward(input, rois)[source]
Parameters:
  • input – NCHW images
  • rois – Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
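A usage sketch (the feature map and box values are made up; spatial_scale = 1/16 maps image-space boxes onto a stride-16 feature map):

    import torch
    from detectron2.layers import ROIAlign

    pooler = ROIAlign(output_size=(7, 7), spatial_scale=1.0 / 16,
                      sampling_ratio=0, aligned=True)
    features = torch.randn(1, 256, 32, 32)             # NCHW
    rois = torch.tensor([[0., 40., 40., 200., 200.]])  # (batch index, x1, y1, x2, y2)
    out = pooler(features, rois)                       # (1, 256, 7, 7)
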
detectron2.layers.roi_align_rotated()
class detectron2.layers.ROIAlignRotated(output_size, spatial_scale, sampling_ratio)[source]

Bases: torch.nn.modules.module.Module

__init__(output_size, spatial_scale, sampling_ratio)[source]
Parameters:
  • output_size (tuple) – h, w
  • spatial_scale (float) – scale the input boxes by this number
  • sampling_ratio (int) – number of input samples to take for each output sample. 0 to take samples densely.

Note

ROIAlignRotated supports continuous coordinate by default: Given a continuous coordinate c, its two neighboring pixel indices (in our pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled from the underlying signal at continuous coordinates 0.5 and 1.5).

forward(input, rois)[source]
Parameters:
  • input – NCHW images
  • rois – Bx6 boxes. First column is the index into N. The other 5 columns are (x_ctr, y_ctr, width, height, angle_degrees).
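A usage sketch with a single rotated ROI (values made up):

    import torch
    from detectron2.layers import ROIAlignRotated

    pooler = ROIAlignRotated(output_size=(7, 7), spatial_scale=1.0 / 16,
                             sampling_ratio=0)
    features = torch.randn(1, 256, 32, 32)
    # (batch index, x_ctr, y_ctr, width, height, angle_degrees)
    rois = torch.tensor([[0., 120., 120., 160., 80., 30.]])
    out = pooler(features, rois)  # (1, 256, 7, 7)
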
class detectron2.layers.ShapeSpec[source]

Bases: detectron2.layers.shape_spec._ShapeSpec

A simple structure that contains basic shape specification about a tensor. It is often used as the auxiliary inputs/outputs of models, to complement the lack of shape inference ability among PyTorch modules.

channels
height
width
stride
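For example, describing a backbone feature map (fields left unspecified default to None):

    from detectron2.layers import ShapeSpec

    spec = ShapeSpec(channels=256, stride=16)  # e.g., a 1/16-resolution feature map
    spec.channels, spec.stride                 # (256, 16); height/width are None
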
class detectron2.layers.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]

Bases: torch.nn.modules.batchnorm.BatchNorm2d

A wrapper around torch.nn.BatchNorm2d to support zero-size tensor.

forward(x)[source]
class detectron2.layers.Conv2d(*args, **kwargs)[source]

Bases: torch.nn.modules.conv.Conv2d

A wrapper around torch.nn.Conv2d to support zero-size tensor and more features.

__init__(*args, **kwargs)[source]

Extra keyword arguments supported in addition to those in torch.nn.Conv2d:

Parameters:
  • norm (nn.Module, optional) – a normalization layer
  • activation (callable(Tensor) -> Tensor) – a callable activation function

It assumes that the norm layer is used before the activation.
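For example, a conv -> norm -> activation block (using get_norm from above):

    import torch.nn.functional as F
    from detectron2.layers import Conv2d, get_norm

    conv = Conv2d(64, 128, kernel_size=3, padding=1, bias=False,
                  norm=get_norm("GN", 128), activation=F.relu)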

forward(x)[source]
class detectron2.layers.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros')[source]

Bases: torch.nn.modules.conv.ConvTranspose2d

A wrapper around torch.nn.ConvTranspose2d to support zero-size tensor.

forward(x)[source]
detectron2.layers.cat(tensors, dim=0)[source]

Efficient version of torch.cat that avoids a copy if there is only a single element in the list.
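The idea in sketch form (essentially what such a wrapper needs to do):

    import torch

    def cat(tensors, dim=0):
        # returning the lone tensor avoids the copy torch.cat would make
        if len(tensors) == 1:
            return tensors[0]
        return torch.cat(tensors, dim)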

detectron2.layers.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)[source]

A wrapper around torch.nn.functional.interpolate() to support zero-size tensor.