detectron2.modeling

detectron2.modeling.build_anchor_generator(cfg, input_shape)

Build an anchor generator from cfg.MODEL.ANCHOR_GENERATOR.NAME.

class detectron2.modeling.FPN(bottom_up, in_features, out_channels, norm='', top_block=None, fuse_type='sum')

Bases: detectron2.modeling.Backbone

This module implements Feature Pyramid Networks for Object Detection. It creates pyramid features built on top of some input feature maps.

__init__(bottom_up, in_features, out_channels, norm='', top_block=None, fuse_type='sum')
Parameters
  • bottom_up (Backbone) – module representing the bottom up subnetwork. Must be a subclass of Backbone. The multi-scale feature maps generated by the bottom up network, and listed in in_features, are used to generate FPN levels.

  • in_features (list[str]) – names of the input feature maps coming from the backbone to which FPN is attached. For example, if the backbone produces [“res2”, “res3”, “res4”], any contiguous sublist of these may be used; order must be from high to low resolution.

  • out_channels (int) – number of channels in the output feature maps.

  • norm (str) – the normalization to use.

  • top_block (nn.Module or None) – if provided, an extra operation will be performed on the output of the last (smallest resolution) FPN output, and the result will extend the result list. The top_block further downsamples the feature map. It must have an attribute “num_levels”, meaning the number of extra FPN levels added by this block, and “in_feature”, which is a string representing its input feature (e.g., p5).

  • fuse_type (str) – types for fusing the top down features and the lateral ones. It can be “sum” (default), which sums up element-wise; or “avg”, which takes the element-wise mean of the two.
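
A minimal construction sketch is shown below. It assumes detectron2's default config keys and the LastLevelMaxPool block from detectron2.modeling.backbone.fpn; the channel counts and input size are illustrative, not required values.

import torch

from detectron2.config import get_cfg
from detectron2.layers import ShapeSpec
from detectron2.modeling import FPN, build_resnet_backbone
from detectron2.modeling.backbone.fpn import LastLevelMaxPool

cfg = get_cfg()
cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]  # expose all residual stages
bottom_up = build_resnet_backbone(cfg, ShapeSpec(channels=3))

fpn = FPN(
    bottom_up=bottom_up,
    in_features=["res2", "res3", "res4", "res5"],  # high to low resolution
    out_channels=256,
    norm="",
    top_block=LastLevelMaxPool(),  # adds "p6" on top of "p5"
    fuse_type="sum",
)
out = fpn(torch.zeros(1, 3, 256, 256))  # dict with keys "p2" ... "p6"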

forward(x)
Parameters

input (dict[str->Tensor]) – mapping feature map name (e.g., “res5”) to feature map tensor for each feature level in high to low resolution order.

Returns

dict[str->Tensor] – mapping from feature map name to FPN feature map tensor in high to low resolution order. Returned feature names follow the FPN paper convention: “p<stage>”, where stage has stride = 2 ** stage e.g., [“p2”, “p3”, …, “p6”].

output_shape()
property size_divisibility
class detectron2.modeling.Backbone

Bases: torch.nn.Module

Abstract base class for network backbones.

__init__()

The __init__ method of any subclass can specify its own set of arguments.

abstract forward()

Subclasses must override this method, but adhere to the same return type.

Returns

dict[str->Tensor] – mapping from feature name (e.g., “res2”) to tensor

output_shape()
Returns

dict[str->ShapeSpec]

property size_divisibility

Some backbones require the input height and width to be divisible by a specific integer. This is typically true for encoder / decoder type networks with lateral connection (e.g., FPN) for which feature maps need to match dimension in the “bottom up” and “top down” paths. Set to 0 if no specific input size divisibility is required.

training: bool
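
As an illustration, a minimal Backbone subclass might look like the following sketch. ToyBackbone is a hypothetical name; registering it with BACKBONE_REGISTRY is only needed if you want to select it via cfg.MODEL.BACKBONE.NAME.

import torch.nn as nn

from detectron2.layers import ShapeSpec
from detectron2.modeling import BACKBONE_REGISTRY, Backbone

@BACKBONE_REGISTRY.register()
class ToyBackbone(Backbone):
    def __init__(self, cfg, input_shape):
        # cfg is accepted to satisfy the registry calling convention, even if unused here
        super().__init__()
        # a single conv that downsamples the input by 4x
        self.conv1 = nn.Conv2d(input_shape.channels, 64, kernel_size=7, stride=4, padding=3)

    def forward(self, image):
        return {"conv1": self.conv1(image)}

    def output_shape(self):
        return {"conv1": ShapeSpec(channels=64, stride=4)}
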
class detectron2.modeling.ResNet(stem, stages, num_classes=None, out_features=None, freeze_at=0)

Bases: detectron2.modeling.Backbone

Implement Deep Residual Learning for Image Recognition.

__init__(stem, stages, num_classes=None, out_features=None, freeze_at=0)
Parameters
  • stem (nn.Module) – a stem module

  • stages (list[list[CNNBlockBase]]) – several (typically 4) stages, each contains multiple CNNBlockBase.

  • num_classes (None or int) – if None, will not perform classification. Otherwise, will create a linear layer.

  • out_features (list[str]) – names of the layers whose outputs should be returned in forward. Can be anything in “stem”, “linear”, or “res2” … If None, will return the output of the last layer.

  • freeze_at (int) – The number of stages at the beginning to freeze. see freeze() for detailed explanation.

forward(x)
Parameters

x – Tensor of shape (N,C,H,W). H, W must be a multiple of self.size_divisibility.

Returns

dict[str->Tensor] – names and the corresponding features

freeze(freeze_at=0)

Freeze the first several stages of the ResNet. Commonly used in fine-tuning.

Layers that produce the same feature map spatial size are defined as one “stage” by Feature Pyramid Networks for Object Detection.

Parameters

freeze_at (int) – number of stages to freeze. 1 means freezing the stem. 2 means freezing the stem and one residual stage, etc.

Returns

nn.Module – this ResNet itself

static make_default_stages(depth, block_class=None, **kwargs)[source]

Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152). If it doesn’t create the ResNet variant you need, please use make_stage() instead for fine-grained customization.

Parameters
  • depth (int) – depth of ResNet

  • block_class (type) – the CNN block class. Has to accept bottleneck_channels argument for depth > 50. By default it is BasicBlock or BottleneckBlock, based on the depth.

  • kwargs – other arguments to pass to make_stage. Should not contain stride and channels, as they are predefined for each depth.

Returns

list[list[CNNBlockBase]] – modules in all stages; see arguments of ResNet.__init__.
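
For example, a ResNet-50 backbone could be assembled from the default stages roughly as follows. This is a sketch; BasicStem is assumed to come from detectron2.modeling.backbone.resnet, and the chosen out_features are illustrative.

from detectron2.modeling import ResNet
from detectron2.modeling.backbone.resnet import BasicStem

resnet50 = ResNet(
    stem=BasicStem(in_channels=3, out_channels=64),
    stages=ResNet.make_default_stages(depth=50),
    out_features=["res2", "res3", "res4", "res5"],
    freeze_at=0,  # freeze nothing; see freeze() above
)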

static make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs)[source]

Create a list of blocks of the same type that forms one ResNet stage.

Parameters
  • block_class (type) – a subclass of CNNBlockBase that’s used to create all blocks in this stage. A module of this type must not change spatial resolution of inputs unless its stride != 1.

  • num_blocks (int) – number of blocks in this stage

  • in_channels (int) – input channels of the entire stage.

  • out_channels (int) – output channels of every block in the stage.

  • kwargs – other arguments passed to the constructor of block_class. If the argument name is “xx_per_block”, the argument is a list of values to be passed to each block in the stage. Otherwise, the same argument is passed to every block in the stage.

Returns

list[CNNBlockBase] – a list of block modules.

Examples:

stage = ResNet.make_stage(
    BottleneckBlock, 3, in_channels=16, out_channels=64,
    bottleneck_channels=16, num_groups=1,
    stride_per_block=[2, 1, 1],
    dilations_per_block=[1, 1, 2]
)

Usually, layers that produce the same feature map spatial size are defined as one “stage” (in Feature Pyramid Networks for Object Detection). Under such definition, stride_per_block[1:] should all be 1.

output_shape()
training: bool
detectron2.modeling.build_backbone(cfg, input_shape=None)

Build a backbone from cfg.MODEL.BACKBONE.NAME.

Returns

an instance of Backbone

detectron2.modeling.build_resnet_backbone(cfg, input_shape)

Create a ResNet instance from config.

Returns

ResNet – a ResNet instance.

class detectron2.modeling.GeneralizedRCNN(*args, **kwargs)

Bases: torch.nn.Module

Generalized R-CNN. Any model that contains the following three components:

  1. Per-image feature extraction (aka backbone)

  2. Region proposal generation

  3. Per-region feature extraction and prediction

__init__(*, backbone: detectron2.modeling.Backbone, proposal_generator: torch.nn.Module, roi_heads: torch.nn.Module, pixel_mean: Tuple[float], pixel_std: Tuple[float], input_format: Optional[str] = None, vis_period: int = 0)
Parameters
  • backbone – a backbone module, must follow detectron2’s backbone interface

  • proposal_generator – a module that generates proposals using backbone features

  • roi_heads – a ROI head that performs per-region computation

  • pixel_mean – list or tuple with #channels elements, representing the per-channel mean used to normalize the input image

  • pixel_std – list or tuple with #channels elements, representing the per-channel std used to normalize the input image

  • input_format – describe the meaning of channels of input. Needed by visualization

  • vis_period – the period to run visualization. Set to 0 to disable.

property device
forward(batched_inputs: List[Dict[str, torch.Tensor]])
Parameters

batched_inputs

a list, batched outputs of DatasetMapper . Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains:

  • image: Tensor, image in (C, H, W) format.

  • instances (optional): groundtruth Instances

  • proposals (optional): Instances, precomputed proposals.

Other information that’s included in the original dicts, such as:

  • ”height”, “width” (int): the output resolution of the model, used in inference. See postprocess() for details.

Returns

list[dict] – Each dict is the output for one input image. The dict contains one key “instances” whose value is an Instances object. The Instances object has the following keys: “pred_boxes”, “pred_classes”, “scores”, “pred_masks”, “pred_keypoints”
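
A minimal usage sketch of this input/output format at inference time, assuming the default config from get_cfg() (whose META_ARCHITECTURE is "GeneralizedRCNN"), random weights, and CPU execution:

import torch

from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.MODEL.DEVICE = "cpu"     # keep the sketch runnable without a GPU
model = build_model(cfg)     # weights are random; load a checkpoint in practice
model.eval()

inputs = [{"image": torch.zeros(3, 480, 640), "height": 480, "width": 640}]
with torch.no_grad():
    outputs = model(inputs)
print(outputs[0]["instances"].pred_boxes)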

classmethod from_config(cfg)[source]
inference(batched_inputs: List[Dict[str, torch.Tensor]], detected_instances: Optional[List[detectron2.structures.Instances]] = None, do_postprocess: bool = True)

Run inference on the given inputs.

Parameters
  • batched_inputs (list[dict]) – same as in forward()

  • detected_instances (None or list[Instances]) – if not None, it contains an Instances object per image. The Instances object contains “pred_boxes” and “pred_classes” which are known boxes in the image. The inference will then skip the detection of bounding boxes, and only predict other per-ROI outputs.

  • do_postprocess (bool) – whether to apply post-processing on the outputs.

Returns

When do_postprocess=True, same as in forward(). Otherwise, a list[Instances] containing raw network outputs.

preprocess_image(batched_inputs: List[Dict[str, torch.Tensor]])

Normalize, pad and batch the input images.

visualize_training(batched_inputs, proposals)

A function used to visualize images and proposals. It shows ground truth bounding boxes on the original image and up to 20 top-scoring predicted object proposals on the original image. Users can implement different visualization functions for different models.

Parameters
  • batched_inputs (list) – a list that contains input to the model.

  • proposals (list) – a list that contains predicted proposals. Both batched_inputs and proposals should have the same length.

training: bool
class detectron2.modeling.PanopticFPN(*args, **kwargs)

Bases: detectron2.modeling.GeneralizedRCNN

Implement the paper Panoptic Feature Pyramid Networks.

__init__(*, sem_seg_head: torch.nn.Module, combine_overlap_thresh: float = 0.5, combine_stuff_area_thresh: float = 4096, combine_instances_score_thresh: float = 0.5, **kwargs)

NOTE: this interface is experimental.

Parameters
  • sem_seg_head – a module for the semantic segmentation head.

  • combine_overlap_thresh – combine masks into one instance if they have enough overlap

  • combine_stuff_area_thresh – ignore stuff areas smaller than this threshold

  • combine_instances_score_thresh – ignore instances whose score is smaller than this threshold

Other arguments are the same as GeneralizedRCNN.

forward(batched_inputs)
Parameters

batched_inputs

a list, batched outputs of DatasetMapper. Each item in the list contains the inputs for one image.

For now, each item in the list is a dict that contains:

  • ”image”: Tensor, image in (C, H, W) format.

  • ”instances”: Instances

  • ”sem_seg”: semantic segmentation ground truth.

  • Other information that’s included in the original dicts, such as: “height”, “width” (int): the output resolution of the model, used in inference. See postprocess() for details.

Returns

list[dict] – each dict has the results for one image. The dict contains the following keys:

  • ”instances”: see GeneralizedRCNN.forward() for its format.

  • ”sem_seg”: see SemanticSegmentor.forward() for its format.

  • ”panoptic_seg”: see the return value of combine_semantic_and_instance_outputs() for its format.

classmethod from_config(cfg)[source]
inference(batched_inputs: List[Dict[str, torch.Tensor]], do_postprocess: bool = True)

Run inference on the given inputs.

Parameters
  • batched_inputs (list[dict]) – same as in forward()

  • do_postprocess (bool) – whether to apply post-processing on the outputs.

Returns

When do_postprocess=True, see docs in forward(). Otherwise, returns a (list[Instances], list[Tensor]) that contains the raw detector outputs, and raw semantic segmentation outputs.

training: bool
class detectron2.modeling.ProposalNetwork(*args, **kwargs)

Bases: torch.nn.Module

A meta architecture that only predicts object proposals.

__init__(*, backbone: detectron2.modeling.Backbone, proposal_generator: torch.nn.Module, pixel_mean: Tuple[float], pixel_std: Tuple[float])
Parameters
  • backbone – a backbone module, must follow detectron2’s backbone interface

  • proposal_generator – a module that generates proposals using backbone features

  • pixel_mean – list or tuple with #channels elements, representing the per-channel mean used to normalize the input image

  • pixel_std – list or tuple with #channels elements, representing the per-channel std used to normalize the input image

property device
forward(batched_inputs)

Parameters

Same as in GeneralizedRCNN.forward()

Returns

list[dict] – Each dict is the output for one input image. The dict contains one key “proposals” whose value is an Instances with keys “proposal_boxes” and “objectness_logits”.

classmethod from_config(cfg)[source]
training: bool
class detectron2.modeling.RetinaNet(*args, **kwargs)

Bases: torch.nn.Module

Implement RetinaNet in Focal Loss for Dense Object Detection.

__init__(*, backbone: detectron2.modeling.Backbone, head: torch.nn.Module, head_in_features, anchor_generator, box2box_transform, anchor_matcher, num_classes, focal_loss_alpha=0.25, focal_loss_gamma=2.0, smooth_l1_beta=0.0, box_reg_loss_type='smooth_l1', test_score_thresh=0.05, test_topk_candidates=1000, test_nms_thresh=0.5, max_detections_per_image=100, pixel_mean, pixel_std, vis_period=0, input_format='BGR')

NOTE: this interface is experimental.

Parameters
  • backbone – a backbone module, must follow detectron2’s backbone interface

  • head (nn.Module) – a module that predicts logits and regression deltas for each level from a list of per-level features

  • head_in_features (Tuple[str]) – Names of the input feature maps to be used in head

  • anchor_generator (nn.Module) – a module that creates anchors from a list of features. Usually an instance of AnchorGenerator

  • box2box_transform (Box2BoxTransform) – defines the transform from anchors boxes to instance boxes

  • anchor_matcher (Matcher) – label the anchors by matching them with ground truth.

  • num_classes (int) – number of classes. Used to label background proposals.

  Loss parameters:

  • focal_loss_alpha (float) – alpha parameter of the focal loss

  • focal_loss_gamma (float) – gamma parameter of the focal loss

  • smooth_l1_beta (float) – beta parameter of the smooth L1 regression loss

  • box_reg_loss_type (str) – Options are “smooth_l1”, “giou”

  Inference parameters:

  • test_score_thresh (float) – Inference classification score threshold; only anchors with score > test_score_thresh are considered during inference (to improve speed)

  • test_topk_candidates (int) – Select topk candidates before NMS

  • test_nms_thresh (float) – Overlap threshold used for non-maximum suppression (suppress boxes with IoU >= this threshold)

  • max_detections_per_image (int) – Maximum number of detections to return per image during inference (100 is based on the limit established for the COCO dataset).

  Input parameters:

  • pixel_mean (Tuple[float]) – Values to be used for image normalization (BGR order). To train on images of different number of channels, set different mean & std. Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]

  • pixel_std (Tuple[float]) – When using pre-trained models in Detectron1 or any MSRA models, std has been absorbed into its conv1 weights, so the std needs to be set to 1. Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)

  • vis_period (int) – The period (in terms of steps) for minibatch visualization at train time. Set to 0 to disable.

  • input_format (str) – Whether the model needs RGB, YUV, HSV etc.

property device
forward(batched_inputs: List[Dict[str, torch.Tensor]])
Parameters

batched_inputs

a list, batched outputs of DatasetMapper . Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains:

  • image: Tensor, image in (C, H, W) format.

  • instances: Instances

Other information that’s included in the original dicts, such as:

  • ”height”, “width” (int): the output resolution of the model, used in inference. See postprocess() for details.

Returns

In training, dict[str, Tensor] – mapping from a named loss to a tensor storing the loss. Used during training only. In inference, the standard output format, described in Use Models.

classmethod from_config(cfg)[source]
inference(anchors: List[detectron2.structures.Boxes], pred_logits: List[torch.Tensor], pred_anchor_deltas: List[torch.Tensor], image_sizes: List[Tuple[int, int]])
Parameters
  • anchors (list[Boxes]) – A list of #feature level Boxes. The Boxes contain anchors of this image on the specific feature level.

  • pred_logits – list[Tensor], one per level. Each has shape (N, Hi * Wi * Ai, K)

  • pred_anchor_deltas – list[Tensor], one per level. Each has shape (N, Hi * Wi * Ai, 4)

  • image_sizes (List[(h, w)]) – the input image sizes

Returns

results (List[Instances]) – a list of #images elements.

inference_single_image(anchors: List[detectron2.structures.Boxes], box_cls: List[torch.Tensor], box_delta: List[torch.Tensor], image_size: Tuple[int, int])

Single-image inference. Return bounding-box detection results by thresholding on scores and applying non-maximum suppression (NMS).

Parameters
  • anchors (list[Boxes]) – list of #feature levels. Each entry contains a Boxes object, which contains all the anchors in that feature level.

  • box_cls (list[Tensor]) – list of #feature levels. Each entry contains tensor of size (H x W x A, K)

  • box_delta (list[Tensor]) – Same shape as ‘box_cls’ except that K becomes 4.

  • image_size (tuple(H, W)) – a tuple of the image height and width.

Returns

Same as inference, but for only one image.

label_anchors(anchors, gt_instances)
Parameters
  • anchors (list[Boxes]) – A list of #feature level Boxes. The Boxes contains anchors of this image on the specific feature level.

  • gt_instances (list[Instances]) – a list of N Instances. The i-th Instances contains the ground-truth per-instance annotations for the i-th input image.

Returns

list[Tensor] – List of #img tensors. The i-th element is a vector of labels whose length is the total number of anchors across all feature maps (sum(Hi * Wi * A)). Label values are in {-1, 0, …, K}, where -1 means ignore and K means background.

list[Tensor]: i-th element is a Rx4 tensor, where R is the total number of anchors across feature maps. The values are the matched gt boxes for each anchor. Values are undefined for those anchors not labeled as foreground.

losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)
Parameters
  • anchors (list[Boxes]) – a list of #feature level Boxes

  • gt_labels – see output of RetinaNet.label_anchors(). Its shape is (N, R), where R is the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)

  • gt_boxes – see output of RetinaNet.label_anchors(). Its shape is (N, R, 4), where R is the total number of anchors across levels

  • pred_logits – list[Tensor], one element per level. Each element has shape (N, Hi * Wi * Ai, K), where K is the number of classes

  • pred_anchor_deltas – list[Tensor], one element per level. Each element has shape (N, Hi * Wi * Ai, 4)

Returns

dict[str, Tensor] – mapping from a named loss to a scalar tensor storing the loss. Used during training only. The dict keys are: “loss_cls” and “loss_box_reg”

preprocess_image(batched_inputs: List[Dict[str, torch.Tensor]])

Normalize, pad and batch the input images.

visualize_training(batched_inputs, results)

A function used to visualize ground truth images and final network predictions. It shows ground truth bounding boxes on the original image and up to 20 predicted object bounding boxes on the original image.

Parameters
  • batched_inputs (list) – a list that contains input to the model.

  • results (List[Instances]) – a list of #images elements.

training: bool
class detectron2.modeling.SemanticSegmentor(*args, **kwargs)

Bases: torch.nn.Module

Main class for semantic segmentation architectures.

__init__(*, backbone: detectron2.modeling.Backbone, sem_seg_head: torch.nn.Module, pixel_mean: Tuple[float], pixel_std: Tuple[float])
Parameters
  • backbone – a backbone module, must follow detectron2’s backbone interface

  • sem_seg_head – a module that predicts semantic segmentation from backbone features

  • pixel_mean – list or tuple with #channels elements, representing the per-channel mean used to normalize the input image

  • pixel_std – list or tuple with #channels elements, representing the per-channel std used to normalize the input image

property device
forward(batched_inputs)
Parameters

batched_inputs

a list, batched outputs of DatasetMapper. Each item in the list contains the inputs for one image.

For now, each item in the list is a dict that contains:

  • ”image”: Tensor, image in (C, H, W) format.

  • ”sem_seg”: semantic segmentation ground truth

  • Other information that’s included in the original dicts, such as: “height”, “width” (int): the output resolution of the model (may be different from input resolution), used in inference.

Returns

list[dict] – Each dict is the output for one input image. The dict contains one key “sem_seg” whose value is a Tensor that represents the per-pixel segmentation predicted by the head. The prediction has shape KxHxW that represents the logits of each class for each pixel.

classmethod from_config(cfg)[source]
training: bool
detectron2.modeling.build_model(cfg)

Build the whole model architecture, defined by cfg.MODEL.META_ARCHITECTURE. Note that it does not load any weights from cfg.

detectron2.modeling.build_sem_seg_head(cfg, input_shape)

Build a semantic segmentation head from cfg.MODEL.SEM_SEG_HEAD.NAME.

detectron2.modeling.detector_postprocess(results: detectron2.structures.Instances, output_height: int, output_width: int, mask_threshold: float = 0.5)

Resize the output instances. The input images are often resized when entering an object detector. As a result, we often need the outputs of the detector in a different resolution from its inputs.

This function will resize the raw outputs of an R-CNN detector to produce outputs according to the desired output resolution.

Parameters
  • results (Instances) – the raw outputs from the detector. results.image_size contains the input image resolution the detector sees. This object might be modified in-place.

  • output_height – the desired output resolution.

  • output_width – the desired output resolution.

Returns

Instances – the resized output from the model, based on the output resolution
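
A small sketch of rescaling raw outputs to the caller-requested resolution; the box, score, and resolutions below are arbitrary illustrative values:

import torch

from detectron2.modeling import detector_postprocess
from detectron2.structures import Boxes, Instances

raw = Instances((800, 1067))   # the resolution the detector actually saw
raw.pred_boxes = Boxes(torch.tensor([[100.0, 100.0, 300.0, 400.0]]))
raw.scores = torch.tensor([0.9])
raw.pred_classes = torch.tensor([0])

# Rescale boxes (and masks/keypoints, if present) to the original 480x640 image.
final = detector_postprocess(raw, output_height=480, output_width=640)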

detectron2.modeling.build_proposal_generator(cfg, input_shape)

Build a proposal generator from cfg.MODEL.PROPOSAL_GENERATOR.NAME. The name can be “PrecomputedProposals” to use no proposal generator.

detectron2.modeling.build_rpn_head(cfg, input_shape)

Build an RPN head defined by cfg.MODEL.RPN.HEAD_NAME.

class detectron2.modeling.ROIHeads(*args, **kwargs)

Bases: torch.nn.Module

ROIHeads perform all per-region computation in an R-CNN.

It typically contains logic to

  1. (in training only) match proposals with ground truth and sample them

  2. crop the regions and extract per-region features using proposals

  3. make per-region predictions with different heads

It can have many variants, implemented as subclasses of this class. This base class contains the logic to match/sample proposals. But it is not necessary to inherit this class if the sampling logic is not needed.

__init__(*, num_classes, batch_size_per_image, positive_fraction, proposal_matcher, proposal_append_gt=True)

NOTE: this interface is experimental.

Parameters
  • num_classes (int) – number of foreground classes (i.e. background is not included)

  • batch_size_per_image (int) – number of proposals to sample for training

  • positive_fraction (float) – fraction of positive (foreground) proposals to sample for training.

  • proposal_matcher (Matcher) – matcher that matches proposals and ground truth

  • proposal_append_gt (bool) – whether to include ground truth as proposals as well

forward(images: detectron2.structures.ImageList, features: Dict[str, torch.Tensor], proposals: List[detectron2.structures.Instances], targets: Optional[List[detectron2.structures.Instances]] = None) → Tuple[List[detectron2.structures.Instances], Dict[str, torch.Tensor]]
Parameters
  • images (ImageList) –

  • features (dict[str,Tensor]) – input data as a mapping from feature map name to tensor. Axis 0 represents the number of images N in the input data; axes 1-3 are channels, height, and width, which may vary between feature maps (e.g., if a feature pyramid is used).

  • proposals (list[Instances]) – length N list of Instances. The i-th Instances contains object proposals for the i-th input image, with fields “proposal_boxes” and “objectness_logits”.

  • targets (list[Instances], optional) –

    length N list of Instances. The i-th Instances contains the ground-truth per-instance annotations for the i-th input image. Specify targets during training only. It may have the following fields:

    • gt_boxes: the bounding box of each instance.

    • gt_classes: the label for each instance with a category ranging in [0, #class].

    • gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.

    • gt_keypoints: NxKx3, the ground-truth keypoints for each instance.

Returns

list[Instances] – length N list of Instances containing the detected instances. Returned during inference only; may be [] during training.

dict[str->Tensor]: mapping from a named loss to a tensor storing the loss. Used during training only.

classmethod from_config(cfg)[source]
label_and_sample_proposals(proposals: List[detectron2.structures.Instances], targets: List[detectron2.structures.Instances]) → List[detectron2.structures.Instances]

Prepare some proposals to be used to train the ROI heads. It performs box matching between proposals and targets, and assigns training labels to the proposals. It returns self.batch_size_per_image random samples from proposals and groundtruth boxes, with a fraction of positives that is no larger than self.positive_fraction.

Parameters

See ROIHeads.forward()

Returns

list[Instances] – length N list of Instances containing the proposals sampled for training. Each Instances has the following fields:

  • proposal_boxes: the proposal boxes

  • gt_boxes: the ground-truth box that the proposal is assigned to (this is only meaningful if the proposal has a label > 0; if label = 0 then the ground-truth box is random)

Other fields, such as “gt_classes” and “gt_masks”, that are included in targets.

training: bool
class detectron2.modeling.StandardROIHeads(*args, **kwargs)

Bases: detectron2.modeling.ROIHeads

It’s “standard” in the sense that there is no ROI transform sharing or feature sharing between tasks. Each head independently processes the input features with its own pooler and head.

This class is used by most models, such as FPN and C5. To implement more models, you can subclass it and implement a different forward() or a head.

__init__(*, box_in_features: List[str], box_pooler: detectron2.modeling.poolers.ROIPooler, box_head: torch.nn.Module, box_predictor: torch.nn.Module, mask_in_features: Optional[List[str]] = None, mask_pooler: Optional[detectron2.modeling.poolers.ROIPooler] = None, mask_head: Optional[torch.nn.Module] = None, keypoint_in_features: Optional[List[str]] = None, keypoint_pooler: Optional[detectron2.modeling.poolers.ROIPooler] = None, keypoint_head: Optional[torch.nn.Module] = None, train_on_pred_boxes: bool = False, **kwargs)

NOTE: this interface is experimental.

Parameters
  • box_in_features (list[str]) – list of feature names to use for the box head.

  • box_pooler (ROIPooler) – pooler to extract region features for the box head

  • box_head (nn.Module) – transform features to make box predictions

  • box_predictor (nn.Module) – make box predictions from the feature. Should have the same interface as FastRCNNOutputLayers.

  • mask_in_features (list[str]) – list of feature names to use for the mask pooler or mask head. None if not using mask head.

  • mask_pooler (ROIPooler) – pooler to extract region features from image features. The mask head will then take region features to make predictions. If None, the mask head will directly take the dict of image features defined by mask_in_features

  • mask_head (nn.Module) – transform features to make mask predictions

  • keypoint_in_features – similar to mask_*.

  • keypoint_pooler – similar to mask_*.

  • keypoint_head – similar to mask_*.

  • train_on_pred_boxes (bool) – whether to use proposal boxes or predicted boxes from the box head to train other heads.

forward(images: detectron2.structures.ImageList, features: Dict[str, torch.Tensor], proposals: List[detectron2.structures.Instances], targets: Optional[List[detectron2.structures.Instances]] = None) → Tuple[List[detectron2.structures.Instances], Dict[str, torch.Tensor]]

See ROIHeads.forward.

forward_with_given_boxes(features: Dict[str, torch.Tensor], instances: List[detectron2.structures.Instances]) → List[detectron2.structures.Instances]

Use the given boxes in instances to produce other (non-box) per-ROI outputs.

This is useful for downstream tasks where a box is known, but other attributes (outputs of other heads) need to be obtained. Test-time augmentation also uses this.

Parameters
  • features – same as in forward()

  • instances (list[Instances]) – instances to predict other outputs. Expect the keys “pred_boxes” and “pred_classes” to exist.

Returns

list[Instances] – the same Instances objects, with extra fields such as pred_masks or pred_keypoints.

classmethod from_config(cfg, input_shape)[source]
training: bool
mask_on: typing_extensions.Final[bool]
keypoint_on: typing_extensions.Final[bool]
class detectron2.modeling.BaseMaskRCNNHead(*args, **kwargs)

Bases: torch.nn.Module

Implement the basic Mask R-CNN losses and inference logic described in Mask R-CNN

__init__(*, loss_weight: float = 1.0, vis_period: int = 0)

NOTE: this interface is experimental.

Parameters
  • loss_weight (float) – multiplier of the loss

  • vis_period (int) – visualization period

forward(x, instances: List[detectron2.structures.Instances])
Parameters
  • x – input region feature(s) provided by ROIHeads.

  • instances (list[Instances]) – contains the boxes & labels corresponding to the input features. Exact format is up to its caller to decide. Typically, this is the foreground instances in training, with “proposal_boxes” field and other gt annotations. In inference, it contains boxes that are already predicted.

Returns

A dict of losses in training. The predicted “instances” in inference.

classmethod from_config(cfg, input_shape)[source]
layers(x)

Neural network layers that make predictions from input features.

training: bool
class detectron2.modeling.BaseKeypointRCNNHead(*args, **kwargs)

Bases: torch.nn.Module

Implement the basic Keypoint R-CNN losses and inference logic described in Sec. 5 of Mask R-CNN.

__init__(*, num_keypoints, loss_weight=1.0, loss_normalizer=1.0)

NOTE: this interface is experimental.

Parameters
  • num_keypoints (int) – number of keypoints to predict

  • loss_weight (float) – weight to multiply the keypoint loss by

  • loss_normalizer (float or str) – If float, divide the loss by loss_normalizer * #images. If ‘visible’, the loss is normalized by the total number of visible keypoints across images.

forward(x, instances: List[detectron2.structures.Instances])
Parameters
  • x – input 4D region feature(s) provided by ROIHeads.

  • instances (list[Instances]) – contains the boxes & labels corresponding to the input features. Exact format is up to its caller to decide. Typically, this is the foreground instances in training, with “proposal_boxes” field and other gt annotations. In inference, it contains boxes that are already predicted.

Returns

A dict of losses if in training. The predicted “instances” if in inference.

classmethod from_config(cfg, input_shape)[source]
layers(x)

Neural network layers that make predictions from regional input features.

training: bool
class detectron2.modeling.FastRCNNOutputLayers(*args, **kwargs)

Bases: torch.nn.Module

Two linear layers for predicting Fast R-CNN outputs:

  1. proposal-to-detection box regression deltas

  2. classification scores

__init__(input_shape: detectron2.layers.shape_spec.ShapeSpec, *, box2box_transform, num_classes: int, test_score_thresh: float = 0.0, test_nms_thresh: float = 0.5, test_topk_per_image: int = 100, cls_agnostic_bbox_reg: bool = False, smooth_l1_beta: float = 0.0, box_reg_loss_type: str = 'smooth_l1', loss_weight: Union[float, Dict[str, float]] = 1.0)

NOTE: this interface is experimental.

Parameters
  • input_shape (ShapeSpec) – shape of the input feature to this module

  • box2box_transform (Box2BoxTransform or Box2BoxTransformRotated) –

  • num_classes (int) – number of foreground classes

  • test_score_thresh (float) – threshold to filter prediction results.

  • test_nms_thresh (float) – NMS threshold for prediction results.

  • test_topk_per_image (int) – number of top predictions to produce per image.

  • cls_agnostic_bbox_reg (bool) – whether to use class-agnostic bbox regression

  • smooth_l1_beta (float) – transition point from L1 to L2 loss. Only used if box_reg_loss_type is “smooth_l1”

  • box_reg_loss_type (str) – Box regression loss type. One of: “smooth_l1”, “giou”

  • loss_weight (float|dict) –

    weights to use for losses. Can be a single float for weighting all losses, or a dict of individual weightings. Valid dict keys are:

    • ”loss_cls”: applied to classification loss

    • ”loss_box_reg”: applied to box regression loss
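
A rough construction sketch; the channel count, transform weights, and class count below are illustrative assumptions rather than required values:

from detectron2.layers import ShapeSpec
from detectron2.modeling import FastRCNNOutputLayers
from detectron2.modeling.box_regression import Box2BoxTransform

predictor = FastRCNNOutputLayers(
    input_shape=ShapeSpec(channels=1024),  # size of the box-head feature
    box2box_transform=Box2BoxTransform(weights=(10.0, 10.0, 5.0, 5.0)),
    num_classes=80,
    test_score_thresh=0.05,
)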

box_reg_loss(proposal_boxes, gt_boxes, pred_deltas, gt_classes)
Parameters

All boxes are tensors with the same shape RxB, where B is the box dimension (4 or 5). gt_classes is a long tensor of shape R, the ground-truth class label of each proposal. R shall be the number of proposals.

forward(x)
Parameters

x – per-region features of shape (N, …) for N bounding boxes to predict.

Returns

(Tensor, Tensor) – First tensor: shape (N, K+1), scores for each of the N boxes. Each row contains the scores for K object categories and 1 background class.

Second tensor: bounding box regression deltas for each box. The shape is (N, Kx4), or (N, 4) for class-agnostic regression.

classmethod from_config(cfg, input_shape)[source]
inference(predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[detectron2.structures.Instances])
Parameters
  • predictions – return values of forward().

  • proposals (list[Instances]) – proposals that match the features that were used to compute predictions. The proposal_boxes field is expected.

Returns

list[Instances] – same as fast_rcnn_inference. list[Tensor]: same as fast_rcnn_inference.

losses(predictions, proposals)
Parameters
  • predictions – return values of forward().

  • proposals (list[Instances]) – proposals that match the features that were used to compute predictions. The fields proposal_boxes, gt_boxes, gt_classes are expected.

Returns

Dict[str, Tensor] – dict of losses

predict_boxes(predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[detectron2.structures.Instances])
Parameters
  • predictions – return values of forward().

  • proposals (list[Instances]) – proposals that match the features that were used to compute predictions. The proposal_boxes field is expected.

Returns

list[Tensor] – A list of Tensors of predicted class-specific or class-agnostic boxes for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is the number of proposals for image i and B is the box dimension (4 or 5)

predict_boxes_for_gt_classes(predictions, proposals)
Parameters
  • predictions – return values of forward().

  • proposals (list[Instances]) – proposals that match the features that were used to compute predictions. The fields proposal_boxes, gt_classes are expected.

Returns

list[Tensor] – A list of Tensors of predicted boxes for GT classes in case of class-specific box head. Element i of the list has shape (Ri, B), where Ri is the number of proposals for image i and B is the box dimension (4 or 5)

predict_probs(predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[detectron2.structures.Instances])
Parameters
  • predictions – return values of forward().

  • proposals (list[Instances]) – proposals that match the features that were used to compute predictions.

Returns

list[Tensor] – A list of Tensors of predicted class probabilities for each image. Element i has shape (Ri, K + 1), where Ri is the number of proposals for image i.

training: bool
detectron2.modeling.build_box_head(cfg, input_shape)

Build a box head defined by cfg.MODEL.ROI_BOX_HEAD.NAME.

detectron2.modeling.build_keypoint_head(cfg, input_shape)

Build a keypoint head from cfg.MODEL.ROI_KEYPOINT_HEAD.NAME.

detectron2.modeling.build_mask_head(cfg, input_shape)

Build a mask head defined by cfg.MODEL.ROI_MASK_HEAD.NAME.

detectron2.modeling.build_roi_heads(cfg, input_shape)

Build ROIHeads defined by cfg.MODEL.ROI_HEADS.NAME.

class detectron2.modeling.DatasetMapperTTA(*args, **kwargs)

Bases: object

Implement test-time augmentation for detection data. It is a callable which takes a dataset dict from a detection dataset and returns a list of dataset dicts where the images are augmented from the input image by the transformations defined in the config.

__call__(dataset_dict)
Parameters

dict – a dict in standard model input format. See tutorials for details.

Returns

list[dict] – a list of dicts, which contain augmented version of the input image. The total number of dicts is len(min_sizes) * (2 if flip else 1). Each dict has field “transforms” which is a TransformList, containing the transforms that are used to generate this image.

__init__(min_sizes: List[int], max_size: int, flip: bool)
Parameters
  • min_sizes – list of short-edge sizes to resize the image to

  • max_size – maximum height or width of resized images

  • flip – whether to apply flipping augmentation

classmethod from_config(cfg)[source]
class detectron2.modeling.GeneralizedRCNNWithTTA(cfg, model, tta_mapper=None, batch_size=3)

Bases: torch.nn.Module

A GeneralizedRCNN with test-time augmentation enabled. Its __call__() method has the same interface as GeneralizedRCNN.forward().

__call__(batched_inputs)

Same input/output format as GeneralizedRCNN.forward()

__init__(cfg, model, tta_mapper=None, batch_size=3)
Parameters
  • cfg (CfgNode) –

  • model (GeneralizedRCNN) – a GeneralizedRCNN to apply TTA on.

  • tta_mapper (callable) – takes a dataset dict and returns a list of augmented versions of the dataset dict. Defaults to DatasetMapperTTA(cfg).

  • batch_size (int) – batch the augmented images into this batch size for inference.

training: bool
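
A minimal sketch of wrapping a model with TTA, assuming the default config from get_cfg(); in practice you would load trained weights before running inference:

from detectron2.config import get_cfg
from detectron2.modeling import GeneralizedRCNNWithTTA, build_model

cfg = get_cfg()
cfg.MODEL.DEVICE = "cpu"
model = build_model(cfg)        # load a checkpoint here in real use
model.eval()

tta_model = GeneralizedRCNNWithTTA(cfg, model)
# tta_model(batched_inputs) follows the same interface as model(batched_inputs)
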
class detectron2.modeling.MMDetBackbone(backbone: Union[torch.nn.Module, collections.abc.Mapping], neck: Optional[Union[torch.nn.Module, collections.abc.Mapping]] = None, *, pretrained_backbone: Optional[str] = None, output_shapes: List[detectron2.layers.shape_spec.ShapeSpec], output_names: Optional[List[str]] = None)

Bases: detectron2.modeling.Backbone

Wrapper of mmdetection backbones to use in detectron2.

mmdet backbones produce list/tuple of tensors, while detectron2 backbones produce a dict of tensors. This class wraps the given backbone to produce output in detectron2’s convention, so it can be used in place of detectron2 backbones.

__init__(backbone: Union[torch.nn.Module, collections.abc.Mapping], neck: Optional[Union[torch.nn.Module, collections.abc.Mapping]] = None, *, pretrained_backbone: Optional[str] = None, output_shapes: List[detectron2.layers.shape_spec.ShapeSpec], output_names: Optional[List[str]] = None)
Parameters
  • backbone – either a backbone module or a mmdet config dict that defines a backbone. The backbone takes a 4D image tensor and returns a sequence of tensors.

  • neck – either a backbone module or a mmdet config dict that defines a neck. The neck takes outputs of backbone and returns a sequence of tensors. If None, no neck is used.

  • pretrained_backbone – defines the backbone weights that can be loaded by mmdet, such as “torchvision://resnet50”.

  • output_shapes – shape for every output of the backbone (or neck, if given). stride and channels are often needed.

  • output_names – names for every output of the backbone (or neck, if given). By default, will use “out0”, “out1”, …

forward(x) → Dict[str, torch.Tensor]
output_shape() → Dict[str, detectron2.layers.shape_spec.ShapeSpec]
training: bool
class detectron2.modeling.MMDetDetector(detector: Union[torch.nn.Module, collections.abc.Mapping], *, size_divisibility=32, pixel_mean: Tuple[float], pixel_std: Tuple[float])

Bases: torch.nn.Module

Wrapper of a mmdetection detector model, for detection and instance segmentation. Input/output formats of this class follow detectron2’s convention, so a mmdetection model can be trained and evaluated in detectron2.

__init__(detector: Union[torch.nn.Module, collections.abc.Mapping], *, size_divisibility=32, pixel_mean: Tuple[float], pixel_std: Tuple[float])
Parameters
  • detector – a mmdet detector, or a mmdet config dict that defines a detector.

  • size_divisibility – pad input images to multiple of this number

  • pixel_mean – per-channel mean to normalize input image

  • pixel_std – per-channel stddev to normalize input image

property device
forward(batched_inputs: List[Dict[str, torch.Tensor]])
training: bool

detectron2.modeling.poolers module

class detectron2.modeling.poolers.ROIPooler(output_size, scales, sampling_ratio, pooler_type, canonical_box_size=224, canonical_level=4)[source]

Bases: torch.nn.Module

Region of interest feature map pooler that supports pooling from one or more feature maps.

__init__(output_size, scales, sampling_ratio, pooler_type, canonical_box_size=224, canonical_level=4)[source]
Parameters
  • output_size (int, tuple[int] or list[int]) – output size of the pooled region, e.g., 14 x 14. If a tuple or list is given, the length must be 2.

  • scales (list[float]) – The scale for each low-level pooling op relative to the input image. For a feature map with stride s relative to the input image, scale is defined as 1/s. The stride must be a power of 2. When there are multiple scales, they must form a pyramid, i.e. they must be a monotonically decreasing geometric sequence with a factor of 1/2.

  • sampling_ratio (int) – The sampling_ratio parameter for the ROIAlign op.

  • pooler_type (string) – Name of the type of pooling operation that should be applied. For instance, “ROIPool” or “ROIAlignV2”.

  • canonical_box_size (int) – A canonical box size in pixels (sqrt(box area)). The default is heuristically defined as 224 pixels in the FPN paper (based on ImageNet pre-training).

  • canonical_level (int) –

    The feature map level index from which a canonically-sized box should be placed. The default is defined as level 4 (stride=16) in the FPN paper, i.e., a box of size 224x224 will be placed on the feature with stride=16. The box placement for all boxes will be determined from their sizes w.r.t canonical_box_size. For example, a box whose area is 4x that of a canonical box should be used to pool features from feature level canonical_level+1.

    Note that the actual input feature maps given to this module may not have sufficiently many levels for the input boxes. If the boxes are too large or too small for the input feature maps, the closest level will be used.

training: bool
forward(x: List[torch.Tensor], box_lists: List[detectron2.structures.Boxes])[source]
Parameters
  • x (list[Tensor]) – A list of feature maps of NCHW shape, with scales matching those used to construct this module.

  • box_lists (list[Boxes] | list[RotatedBoxes]) – A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch. The box coordinates are defined on the original image and will be scaled by the scales argument of ROIPooler.

Returns

Tensor – A tensor of shape (M, C, output_size, output_size) where M is the total number of boxes aggregated over all N batch images and C is the number of channels in x.
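
A small usage sketch; the feature shapes, scales, and box below are arbitrary illustrations:

import torch

from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes

pooler = ROIPooler(
    output_size=7,
    scales=(1 / 8, 1 / 16),      # feature strides 8 and 16
    sampling_ratio=0,
    pooler_type="ROIAlignV2",
)
features = [torch.randn(1, 256, 64, 64), torch.randn(1, 256, 32, 32)]
boxes = [Boxes(torch.tensor([[10.0, 10.0, 200.0, 200.0]]))]  # one box for image 0
pooled = pooler(features, boxes)   # shape (1, 256, 7, 7)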

detectron2.modeling.sampling module

detectron2.modeling.sampling.subsample_labels(labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int)[source]

Return num_samples (or fewer, if not enough found) random samples from labels which is a mixture of positives & negatives. It will try to return as many positives as possible without exceeding positive_fraction * num_samples, and then try to fill the remaining slots with negatives.

Parameters
  • labels (Tensor) – (N, ) label vector with values: -1 = ignore; bg_label = background (“negative”) class; otherwise = one or more foreground (“positive”) classes

  • num_samples (int) – The total number of labels with value >= 0 to return. Values that are not sampled will be filled with -1 (ignore).

  • positive_fraction (float) – The number of subsampled labels with values > 0 is min(num_positives, int(positive_fraction * num_samples)). The number of negatives sampled is min(num_negatives, num_samples - num_positives_sampled). In other words, if there are not enough positives, the sample is filled with negatives. If there are also not enough negatives, then as many elements are sampled as is possible.

  • bg_label (int) – label index of background (“negative”) class.

Returns

pos_idx, neg_idx (Tensor) – 1D vector of indices. The total length of both is num_samples or fewer.
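
For instance, with illustrative label values:

import torch

from detectron2.modeling.sampling import subsample_labels

labels = torch.tensor([-1, 0, 0, 0, 5, 7, 0, 3])  # -1: ignore, 0: background, >0: foreground
pos_idx, neg_idx = subsample_labels(
    labels, num_samples=4, positive_fraction=0.25, bg_label=0
)
# at most int(0.25 * 4) = 1 positive index; the remaining slots are filled with negatives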

detectron2.modeling.box_regression module

class detectron2.modeling.box_regression.Box2BoxTransform(weights: Tuple[float, float, float, float], scale_clamp: float = 4.135166556742356)[source]

Bases: object

The box-to-box transform defined in R-CNN. The transformation is parameterized by 4 deltas: (dx, dy, dw, dh). The transformation scales the box’s width and height by exp(dw), exp(dh) and shifts a box’s center by the offset (dx * width, dy * height).

__init__(weights: Tuple[float, float, float, float], scale_clamp: float = 4.135166556742356)[source]
Parameters
  • weights (4-element tuple) – Scaling factors that are applied to the (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set such that the deltas have unit variance; now they are treated as hyperparameters of the system.

  • scale_clamp (float) – When predicting deltas, the predicted box scaling factors (dw and dh) are clamped such that they are <= scale_clamp.

get_deltas(src_boxes, target_boxes)[source]

Get box regression transformation deltas (dx, dy, dw, dh) that can be used to transform the src_boxes into the target_boxes. That is, the relation target_boxes == self.apply_deltas(deltas, src_boxes) is true (unless any delta is too large and is clamped).

Parameters
  • src_boxes (Tensor) – source boxes, e.g., object proposals

  • target_boxes (Tensor) – target of the transformation, e.g., ground-truth boxes.

apply_deltas(deltas, boxes)[source]

Apply transformation deltas (dx, dy, dw, dh) to boxes.

Parameters
  • deltas (Tensor) – transformation deltas of shape (N, k*4), where k >= 1. deltas[i] represents k potentially different class-specific box transformations for the single box boxes[i].

  • boxes (Tensor) – boxes to transform, of shape (N, 4)
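
A quick round-trip sketch; the weights and box coordinates are illustrative choices, not prescribed values:

import torch

from detectron2.modeling.box_regression import Box2BoxTransform

transform = Box2BoxTransform(weights=(10.0, 10.0, 5.0, 5.0))
src = torch.tensor([[0.0, 0.0, 10.0, 10.0]])      # e.g. a proposal, XYXY format
target = torch.tensor([[2.0, 2.0, 12.0, 14.0]])   # e.g. a matched ground-truth box

deltas = transform.get_deltas(src, target)         # (dx, dy, dw, dh), scaled by weights
recovered = transform.apply_deltas(deltas, src)    # approximately equal to `target`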

class detectron2.modeling.box_regression.Box2BoxTransformRotated(weights: Tuple[float, float, float, float, float], scale_clamp: float = 4.135166556742356)[source]

Bases: object

The box-to-box transform defined in Rotated R-CNN. The transformation is parameterized by 5 deltas: (dx, dy, dw, dh, da). The transformation scales the box’s width and height by exp(dw), exp(dh), shifts a box’s center by the offset (dx * width, dy * height), and rotates a box’s angle by da (radians). Note: angles of deltas are in radians while angles of boxes are in degrees.

__init__(weights: Tuple[float, float, float, float, float], scale_clamp: float = 4.135166556742356)[source]
Parameters
  • weights (5-element tuple) – Scaling factors that are applied to the (dx, dy, dw, dh, da) deltas. These are treated as hyperparameters of the system.

  • scale_clamp (float) – When predicting deltas, the predicted box scaling factors (dw and dh) are clamped such that they are <= scale_clamp.

get_deltas(src_boxes, target_boxes)[source]

Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used to transform the src_boxes into the target_boxes. That is, the relation target_boxes == self.apply_deltas(deltas, src_boxes) is true (unless any delta is too large and is clamped).

Parameters
  • src_boxes (Tensor) – Nx5 source boxes, e.g., object proposals

  • target_boxes (Tensor) – Nx5 target of the transformation, e.g., ground-truth boxes.

apply_deltas(deltas, boxes)[source]

Apply transformation deltas (dx, dy, dw, dh, da) to boxes.

Parameters
  • deltas (Tensor) – transformation deltas of shape (N, k*5). deltas[i] represents box transformation for the single box boxes[i].

  • boxes (Tensor) – boxes to transform, of shape (N, 5)

Model Registries

These are the different registries provided in modeling. Each registry gives you the ability to replace a built-in component with your own customized one, without having to modify detectron2’s code.

Note that registries cannot make every line of code customizable directly. Even to add or change a single line somewhere, you’ll likely need to find the smallest registered component that contains that line and register your own version of that component, as in the sketch below.
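
For example, a custom ROI heads variant can be registered and then selected through the config. This is a sketch; MyROIHeads is a hypothetical name.

from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads

@ROI_HEADS_REGISTRY.register()
class MyROIHeads(StandardROIHeads):
    """A hypothetical variant that reuses the standard per-region logic."""

    def forward(self, images, features, proposals, targets=None):
        # insert custom per-region logic here
        return super().forward(images, features, proposals, targets)

# then select it with: cfg.MODEL.ROI_HEADS.NAME = "MyROIHeads"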

detectron2.modeling.META_ARCH_REGISTRY = Registry of META_ARCH:

  • GeneralizedRCNN – <class 'detectron2.modeling.GeneralizedRCNN'>
  • ProposalNetwork – <class 'detectron2.modeling.ProposalNetwork'>
  • SemanticSegmentor – <class 'detectron2.modeling.SemanticSegmentor'>
  • PanopticFPN – <class 'detectron2.modeling.PanopticFPN'>
  • RetinaNet – <class 'detectron2.modeling.RetinaNet'>

Registry for meta-architectures, i.e. the whole model.

The registered object will be called with obj(cfg) and expected to return a nn.Module object.

detectron2.modeling.BACKBONE_REGISTRY = Registry of BACKBONE:

  • build_resnet_backbone – <function build_resnet_backbone>
  • build_resnet_fpn_backbone – <function build_resnet_fpn_backbone>
  • build_retinanet_resnet_fpn_backbone – <function build_retinanet_resnet_fpn_backbone>

Registry for backbones, which extract feature maps from images

The registered object must be a callable that accepts two arguments:

  1. A detectron2.config.CfgNode

  2. A detectron2.layers.ShapeSpec, which contains the input shape specification.

Registered object must return instance of Backbone.

detectron2.modeling.PROPOSAL_GENERATOR_REGISTRY = Registry of PROPOSAL_GENERATOR:

  • RPN – <class 'detectron2.modeling.proposal_generator.rpn.RPN'>
  • RRPN – <class 'detectron2.modeling.proposal_generator.rrpn.RRPN'>

Registry for proposal generators, which produce object proposals from feature maps.

The registered object will be called with obj(cfg, input_shape). The call should return a nn.Module object.

detectron2.modeling.RPN_HEAD_REGISTRY = Registry of RPN_HEAD:

  • StandardRPNHead – <class 'detectron2.modeling.proposal_generator.rpn.StandardRPNHead'>

Registry for RPN heads, which take feature maps and perform objectness classification and bounding box regression for anchors.

The registered object will be called with obj(cfg, input_shape). The call should return a nn.Module object.

detectron2.modeling.ANCHOR_GENERATOR_REGISTRY = Registry of ANCHOR_GENERATOR:

  • DefaultAnchorGenerator – <class 'detectron2.modeling.anchor_generator.DefaultAnchorGenerator'>
  • RotatedAnchorGenerator – <class 'detectron2.modeling.anchor_generator.RotatedAnchorGenerator'>

Registry for modules that create object detection anchors for feature maps.

The registered object will be called with obj(cfg, input_shape).

detectron2.modeling.ROI_HEADS_REGISTRY = Registry of ROI_HEADS:

  • Res5ROIHeads – <class 'detectron2.modeling.roi_heads.roi_heads.Res5ROIHeads'>
  • StandardROIHeads – <class 'detectron2.modeling.StandardROIHeads'>
  • CascadeROIHeads – <class 'detectron2.modeling.roi_heads.cascade_rcnn.CascadeROIHeads'>
  • RROIHeads – <class 'detectron2.modeling.roi_heads.rotated_fast_rcnn.RROIHeads'>

Registry for ROI heads in a generalized R-CNN model. ROIHeads take feature maps and region proposals, and perform per-region computation.

The registered object will be called with obj(cfg, input_shape). The call is expected to return an ROIHeads.

detectron2.modeling.ROI_BOX_HEAD_REGISTRY = Registry of ROI_BOX_HEAD:

  • FastRCNNConvFCHead – <class 'detectron2.modeling.roi_heads.box_head.FastRCNNConvFCHead'>

Registry for box heads, which make box predictions from per-region features.

The registered object will be called with obj(cfg, input_shape).

detectron2.modeling.ROI_MASK_HEAD_REGISTRY = Registry of ROI_MASK_HEAD:

  • MaskRCNNConvUpsampleHead – <class 'detectron2.modeling.roi_heads.mask_head.MaskRCNNConvUpsampleHead'>

Registry for mask heads, which predict instance masks given per-region features.

The registered object will be called with obj(cfg, input_shape).

detectron2.modeling.ROI_KEYPOINT_HEAD_REGISTRY = Registry of ROI_KEYPOINT_HEAD:

  • KRCNNConvDeconvUpsampleHead – <class 'detectron2.modeling.roi_heads.keypoint_head.KRCNNConvDeconvUpsampleHead'>

Registry for keypoint heads, which make keypoint predictions from per-region features.

The registered object will be called with obj(cfg, input_shape).