Welcome to detectron2’s documentation!¶
Tutorials¶
Installation¶
Requirements¶
Linux or macOS with Python ≥ 3.6
PyTorch ≥ 1.7 and torchvision that matches the PyTorch installation. Install them together at pytorch.org to make sure of this
OpenCV is optional but needed by demo and visualization
Build Detectron2 from Source¶
gcc & g++ ≥ 5.4 are required. ninja is optional but recommended for faster build. After having them, run:
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
# (add --user if you don't have permission)
# Or, to install it from a local clone:
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2
# On macOS, you may need to prepend the above commands with a few environment variables:
CC=clang CXX=clang++ ARCHFLAGS="-arch x86_64" python -m pip install ...
To rebuild detectron2 that’s built from a local clone, use rm -rf build/ **/*.so to clean the old build first. You often need to rebuild detectron2 after reinstalling PyTorch.
Install Pre-Built Detectron2 (Linux only)¶
Choose from this table to install v0.5 (Jul 2021):
CUDA | torch 1.9 | torch 1.8 | torch 1.7
---|---|---|---
11.1 | install | install |
11.0 | | | install
10.2 | install | install | install
10.1 | | install | install
9.2 | | | install
cpu | install | install | install
Note that:
The pre-built packages have to be used with the corresponding version of CUDA and the official package of PyTorch. Otherwise, please build detectron2 from source.
New packages are released every few months. Therefore, packages may not contain latest features in the master branch and may not be compatible with the master branch of a research project that uses detectron2 (e.g. those in projects).
Common Installation Issues¶
Click each issue for its solutions:
Undefined symbols that contain TH, aten, torch, caffe2.
This usually happens when detectron2 or torchvision is not compiled with the version of PyTorch you’re running.
If the error comes from a pre-built torchvision, uninstall torchvision and pytorch and reinstall them following pytorch.org so that the versions match.
If the error comes from a pre-built detectron2, check the release notes, then uninstall and reinstall the correct pre-built detectron2 that matches your pytorch version.
If the error comes from detectron2 or torchvision that you built manually from source, remove the files you built (build/, **/*.so) and rebuild it so it can pick up the version of pytorch currently in your environment.
If the above instructions do not resolve this problem, please provide an environment (e.g. a dockerfile) that can reproduce the issue.
Missing torch dynamic libraries, OR segmentation fault immediately when using detectron2.
This usually happens when detectron2 or torchvision is not compiled with the version of PyTorch you're running. See the previous common issue for the solution.
Undefined C++ symbols (e.g. GLIBCXX) or C++ symbols not found.
Usually it's because the library is compiled with a newer C++ compiler but run with an old C++ runtime. This often happens with old anaconda. It may help to run conda update libgcc to upgrade its runtime.
The fundamental solution is to avoid the mismatch, either by compiling with an older version of the C++ compiler, or by running the code with a proper C++ runtime. To run the code with a specific C++ runtime, you can use the environment variable LD_PRELOAD=/path/to/libstdc++.so.
"nvcc not found" or "Not compiled with GPU support" or "Detectron2 CUDA Compiler: not available".
CUDA is not found when building detectron2. You should make sure
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
prints (True, a directory with cuda) at the time you build detectron2.
Most models can run inference (but not training) without GPU support. To use CPUs, set MODEL.DEVICE='cpu' in the config.
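For example, here is a minimal sketch of forcing CPU inference through the config API; the config path is a placeholder and should point at whichever config file you actually use:
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

cfg = get_cfg()
cfg.merge_from_file("path/to/config.yaml")  # placeholder config path
cfg.MODEL.DEVICE = "cpu"                    # same effect as passing MODEL.DEVICE cpu after --opts
predictor = DefaultPredictor(cfg)           # the model now runs on CPU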
"invalid device function" or "no kernel image is available for execution".
Two possibilities:
You built detectron2 with one version of CUDA but run it with a different version.
To check whether this is the case, use python -m detectron2.utils.collect_env to find out inconsistent CUDA versions. In the output of this command, you should expect “Detectron2 CUDA Compiler”, “CUDA_HOME”, “PyTorch built with - CUDA” to contain cuda libraries of the same version.
When they are inconsistent, you need to either install a different build of PyTorch (or build it yourself) to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
PyTorch/torchvision/detectron2 is not built for the correct GPU SM architecture (aka. compute capability).
The architecture included by PyTorch/detectron2/torchvision is available in the “architecture flags” in python -m detectron2.utils.collect_env. It must include the architecture of your GPU, which can be found at developer.nvidia.com/cuda-gpus.
If you’re using pre-built PyTorch/detectron2/torchvision, they have included support for most popular GPUs already. If not supported, you need to build them from source.
When building detectron2/torchvision from source, they detect the GPU device and build for only that device. This means the compiled code may not work on a different GPU device. To recompile them for the correct architecture, remove all installed/compiled files, and rebuild them with the TORCH_CUDA_ARCH_LIST environment variable set properly. For example, export TORCH_CUDA_ARCH_LIST="6.0;7.0" makes it compile for both P100s and V100s.
Undefined CUDA symbols; Cannot open libcudart.so
The version of NVCC you use to build detectron2 or torchvision does not match the version of CUDA you are running with. This often happens when using anaconda's CUDA runtime.
Use python -m detectron2.utils.collect_env to find out inconsistent CUDA versions. In the output of this command, you should expect “Detectron2 CUDA Compiler”, “CUDA_HOME”, “PyTorch built with - CUDA” to contain cuda libraries of the same version.
When they are inconsistent, you need to either install a different build of PyTorch (or build it yourself) to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
C++ compilation errors from NVCC / NVRTC, or "Unsupported gpu architecture"
A few possibilities:
Local CUDA/NVCC version has to match the CUDA version of your PyTorch. Both can be found in python collect_env.py. When they are inconsistent, you need to either install a different build of PyTorch (or build it yourself) to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
Local CUDA/NVCC version shall support the SM architecture (a.k.a. compute capability) of your GPU. The capability of your GPU can be found at developer.nvidia.com/cuda-gpus. The capability supported by NVCC is listed at here. If your NVCC version is too old, this can be worked around by setting the environment variable TORCH_CUDA_ARCH_LIST to a lower, supported capability.
The combination of NVCC and GCC you use is incompatible. You need to change one of their versions. See here for some valid combinations. Notably, CUDA<=10.1.105 doesn’t support GCC>7.3.
The CUDA/GCC version used by PyTorch can be found by print(torch.__config__.show()).
"ImportError: cannot import name '_C'".
Please build and install detectron2 following the instructions above.
Or, if you are running code from detectron2’s root directory, cd to a different one.
Otherwise you may not import the code that you installed.
Any issue on windows.
Detectron2 is continuously built on Windows with CircleCI. However, we do not provide official support for it. PRs that improve code compatibility on Windows are welcome.
ONNX conversion segfault after some "TraceWarning".
The ONNX package is compiled with a compiler that is too old. Please build and install ONNX from its source code using a compiler whose version is closer to what’s used by PyTorch (available in torch.__config__.show()).
"library not found for -lstdc++" on older version of MacOS
See [this stackoverflow answer](https://stackoverflow.com/questions/56083725/macos-build-issues-lstdc-not-found-while-building-python-package).
Installation inside specific environments:¶
Colab: see our Colab Tutorial which has step-by-step instructions.
Docker: The official Dockerfile installs detectron2 with a few simple commands.
Getting Started with Detectron2¶
This document provides a brief intro of the usage of builtin command-line tools in detectron2.
For a tutorial that involves actual coding with the API, see our Colab Notebook which covers how to run inference with an existing model, and how to train a builtin model on a custom dataset.
Inference Demo with Pre-trained Models¶
Pick a model and its config file from model zoo, for example, mask_rcnn_R_50_FPN_3x.yaml.
We provide demo.py that is able to demo builtin configs. Run it with:
cd demo/
python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
--input input1.jpg input2.jpg \
[--other-options]
--opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
The configs are made for training, therefore we need to specify MODEL.WEIGHTS to a model from model zoo for evaluation.
This command will run the inference and show visualizations in an OpenCV window.
For details of the command line arguments, see demo.py -h or look at its source code to understand its behavior. Some common arguments are:
- To run on your webcam, replace --input files with --webcam.
- To run on a video, replace --input files with --video-input video.mp4.
- To run on cpu, add MODEL.DEVICE cpu after --opts.
- To save outputs to a directory (for images) or a file (for webcam or video), use --output.
Training & Evaluation in Command Line¶
We provide two scripts, “tools/plain_train_net.py” and “tools/train_net.py”, that are made to train all the configs provided in detectron2. You may want to use them as a reference to write your own training script.
Compared to “train_net.py”, “plain_train_net.py” supports fewer default features. It also includes fewer abstractions and is therefore easier to extend with custom logic.
To train a model with “train_net.py”, first setup the corresponding datasets following datasets/README.md, then run:
cd tools/
./train_net.py --num-gpus 8 \
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
The configs are made for 8-GPU training. To train on 1 GPU, you may need to change some parameters, e.g.:
./train_net.py \
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
--num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
To evaluate a model’s performance, use
./train_net.py \
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
For more options, see ./train_net.py -h.
Use Detectron2 APIs in Your Code¶
See our Colab Notebook to learn how to use detectron2 APIs to:
run inference with an existing model
train a builtin model on a custom dataset
See detectron2/projects for more ways to build your project on detectron2.
Use Builtin Datasets¶
A dataset can be used by accessing DatasetCatalog for its data, or MetadataCatalog for its metadata (class names, etc).
This document explains how to set up the builtin datasets so they can be used by the above APIs. Use Custom Datasets gives a deeper dive on how to use DatasetCatalog and MetadataCatalog, and how to add new datasets to them.
Detectron2 has builtin support for a few datasets.
The datasets are assumed to exist in a directory specified by the environment variable DETECTRON2_DATASETS. Under this directory, detectron2 will look for datasets in the structure described below, if needed.
$DETECTRON2_DATASETS/
coco/
lvis/
cityscapes/
VOC20{07,12}/
You can set the location for builtin datasets by export DETECTRON2_DATASETS=/path/to/datasets. If left unset, the default is ./datasets relative to your current working directory.
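Equivalently, a minimal sketch of setting this variable from Python, which must run before detectron2 looks up the builtin datasets (the path below is a placeholder):
import os

os.environ["DETECTRON2_DATASETS"] = "/path/to/datasets"  # placeholder path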
The model zoo contains configs and models that use these builtin datasets.
Expected dataset structure for COCO instance/keypoint detection:¶
coco/
annotations/
instances_{train,val}2017.json
person_keypoints_{train,val}2017.json
{train,val}2017/
# image files that are mentioned in the corresponding json
You can use the 2014 version of the dataset as well.
Some of the builtin tests (dev/run_*_tests.sh) use a tiny version of the COCO dataset, which you can download with ./datasets/prepare_for_tests.sh.
Expected dataset structure for PanopticFPN:¶
Extract panoptic annotations from COCO website into the following structure:
coco/
annotations/
panoptic_{train,val}2017.json
panoptic_{train,val}2017/ # png annotations
panoptic_stuff_{train,val}2017/ # generated by the script mentioned below
Install panopticapi by:
pip install git+https://github.com/cocodataset/panopticapi.git
Then, run python datasets/prepare_panoptic_fpn.py to extract semantic annotations from panoptic annotations.
Expected dataset structure for LVIS instance segmentation:¶
coco/
{train,val,test}2017/
lvis/
lvis_v0.5_{train,val}.json
lvis_v0.5_image_info_test.json
lvis_v1_{train,val}.json
lvis_v1_image_info_test{,_challenge}.json
Install lvis-api by:
pip install git+https://github.com/lvis-dataset/lvis-api.git
To evaluate models trained on the COCO dataset using LVIS annotations, run python datasets/prepare_cocofied_lvis.py to prepare “cocofied” LVIS annotations.
Expected dataset structure for cityscapes:¶
cityscapes/
gtFine/
train/
aachen/
color.png, instanceIds.png, labelIds.png, polygons.json,
labelTrainIds.png
...
val/
test/
# below are generated Cityscapes panoptic annotation
cityscapes_panoptic_train.json
cityscapes_panoptic_train/
cityscapes_panoptic_val.json
cityscapes_panoptic_val/
cityscapes_panoptic_test.json
cityscapes_panoptic_test/
leftImg8bit/
train/
val/
test/
Install cityscapes scripts by:
pip install git+https://github.com/mcordts/cityscapesScripts.git
Note: to create labelTrainIds.png, first prepare the above structure, then run the cityscapes scripts with:
CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py
These files are not needed for instance segmentation.
Note: to generate the Cityscapes panoptic dataset, run the cityscapes scripts with:
CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py
These files are not needed for semantic and instance segmentation.
Expected dataset structure for Pascal VOC:¶
VOC20{07,12}/
Annotations/
ImageSets/
Main/
trainval.txt
test.txt
# train.txt or val.txt, if you use these splits
JPEGImages/
Expected dataset structure for ADE20k Scene Parsing:¶
ADEChallengeData2016/
annotations/
annotations_detectron2/
images/
objectInfo150.txt
The directory annotations_detectron2 is generated by running python datasets/prepare_ade20k_sem_seg.py.
Extend Detectron2’s Defaults¶
Research is about doing things in new ways. This brings a tension in how to create abstractions in code, which is a challenge for any research engineering project of a significant size:
On one hand, it needs to have very thin abstractions to allow for the possibility of doing everything in new ways. It should be reasonably easy to break existing abstractions and replace them with new ones.
On the other hand, such a project also needs reasonably high-level abstractions, so that users can easily do things in standard ways, without worrying too much about the details that only certain researchers care about.
In detectron2, there are two types of interfaces that address this tension together:
Functions and classes that take a config (cfg) argument created from a yaml file (sometimes with a few extra arguments).
Such functions and classes implement the “standard default” behavior: they read what they need from a given config and do the “standard” thing. Users only need to load an expert-made config and pass it around, without having to worry about which arguments are used and what they all mean.
See Yacs Configs for a detailed tutorial.
Functions and classes that have well-defined explicit arguments.
Each of these is a small building block of the entire system. They require users’ expertise to understand what each argument should be, and require more effort to stitch together into a larger system. But they can be stitched together in more flexible ways.
When you need to implement something not supported by the “standard defaults” included in detectron2, these well-defined components can be reused.
The LazyConfig system relies on such functions and classes.
A few functions and classes are implemented with the @configurable decorator - they can be called with either a config, or with explicit arguments, or a mixture of both. Their explicit argument interfaces are currently experimental.
As an example, a Mask R-CNN model can be built in the following ways:
Config-only:
# load proper yaml config file, then
model = build_model(cfg)
Mixture of config and additional argument overrides:
model = GeneralizedRCNN(
    cfg,
    roi_heads=StandardROIHeads(cfg, batch_size_per_image=666),
    pixel_std=[57.0, 57.0, 57.0])
Full explicit arguments:
(click to expand)
model = GeneralizedRCNN(
    backbone=FPN(
        ResNet(
            BasicStem(3, 64, norm="FrozenBN"),
            ResNet.make_default_stages(50, stride_in_1x1=True, norm="FrozenBN"),
            out_features=["res2", "res3", "res4", "res5"],
        ).freeze(2),
        ["res2", "res3", "res4", "res5"],
        256,
        top_block=LastLevelMaxPool(),
    ),
    proposal_generator=RPN(
        in_features=["p2", "p3", "p4", "p5", "p6"],
        head=StandardRPNHead(in_channels=256, num_anchors=3),
        anchor_generator=DefaultAnchorGenerator(
            sizes=[[32], [64], [128], [256], [512]],
            aspect_ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64],
            offset=0.0,
        ),
        anchor_matcher=Matcher([0.3, 0.7], [0, -1, 1], allow_low_quality_matches=True),
        box2box_transform=Box2BoxTransform([1.0, 1.0, 1.0, 1.0]),
        batch_size_per_image=256,
        positive_fraction=0.5,
        pre_nms_topk=(2000, 1000),
        post_nms_topk=(1000, 1000),
        nms_thresh=0.7,
    ),
    roi_heads=StandardROIHeads(
        num_classes=80,
        batch_size_per_image=512,
        positive_fraction=0.25,
        proposal_matcher=Matcher([0.5], [0, 1], allow_low_quality_matches=False),
        box_in_features=["p2", "p3", "p4", "p5"],
        box_pooler=ROIPooler(7, (1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 0, "ROIAlignV2"),
        box_head=FastRCNNConvFCHead(
            ShapeSpec(channels=256, height=7, width=7), conv_dims=[], fc_dims=[1024, 1024]
        ),
        box_predictor=FastRCNNOutputLayers(
            ShapeSpec(channels=1024),
            test_score_thresh=0.05,
            box2box_transform=Box2BoxTransform((10, 10, 5, 5)),
            num_classes=80,
        ),
        mask_in_features=["p2", "p3", "p4", "p5"],
        mask_pooler=ROIPooler(14, (1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 0, "ROIAlignV2"),
        mask_head=MaskRCNNConvUpsampleHead(
            ShapeSpec(channels=256, width=14, height=14),
            num_classes=80,
            conv_dims=[256, 256, 256, 256, 256],
        ),
    ),
    pixel_mean=[103.530, 116.280, 123.675],
    pixel_std=[1.0, 1.0, 1.0],
    input_format="BGR",
)
If you only need the standard behavior, the Beginner’s Tutorial should suffice. If you need to extend detectron2 to your own needs, see the following tutorials for more details:
Detectron2 includes a few standard datasets. To use custom ones, see Use Custom Datasets.
Detectron2 contains the standard logic that creates a data loader for training/testing from a dataset, but you can write your own as well. See Use Custom Data Loaders.
Detectron2 implements many standard detection models, and provides ways for you to overwrite their behaviors. See Use Models and Write Models.
Detectron2 provides a default training loop that is good for common training tasks. You can customize it with hooks, or write your own loop instead. See training.
Use Custom Datasets¶
This document explains how the dataset APIs (DatasetCatalog, MetadataCatalog) work, and how to use them to add custom datasets.
Datasets that have builtin support in detectron2 are listed in builtin datasets. If you want to use a custom dataset while also reusing detectron2’s data loaders, you will need to:
Register your dataset (i.e., tell detectron2 how to obtain your dataset).
Optionally, register metadata for your dataset.
Next, we explain the above two concepts in detail.
The Colab tutorial has a live example of how to register and train on a dataset of custom formats.
Register a Dataset¶
To let detectron2 know how to obtain a dataset named “my_dataset”, users need to implement a function that returns the items in your dataset and then tell detectron2 about this function:
def my_dataset_function():
    ...
    # return a list[dict] in the format described below
from detectron2.data import DatasetCatalog
DatasetCatalog.register("my_dataset", my_dataset_function)
# later, to access the data:
data: List[Dict] = DatasetCatalog.get("my_dataset")
Here, the snippet associates a dataset named “my_dataset” with a function that returns the data. The function must return the same data (with same order) if called multiple times. The registration stays effective until the process exits.
The function can do arbitrary things and should return the data in list[dict], each dict in either of the following formats:
Detectron2’s standard dataset dict, described below. This will make it work with many other builtin features in detectron2, so it’s recommended to use it when it’s sufficient.
Any custom format. You can also return arbitrary dicts in your own format, such as adding extra keys for new tasks. Then you will need to handle them properly downstream as well. See below for more details.
Standard Dataset Dicts¶
For standard tasks
(instance detection, instance/semantic/panoptic segmentation, keypoint detection),
we load the original dataset into list[dict]
with a specification similar to COCO’s annotations.
This is our standard representation for a dataset.
Each dict contains information about one image. The dict may have the following fields, and the required fields vary based on what the dataloader or the task needs (see more below).
Task | Fields
---|---
Common | file_name, height, width, image_id
Instance detection/segmentation | annotations
Semantic segmentation | sem_seg_file_name
Panoptic segmentation | pan_seg_file_name, segments_info
- file_name: the full path to the image file.
- height, width: integer. The shape of the image.
- image_id (str or int): a unique id that identifies this image. Required by many evaluators to identify the images, but a dataset may use it for different purposes.
- annotations (list[dict]): Required by instance detection/segmentation or keypoint detection tasks. Each dict corresponds to annotations of one instance in this image, and may contain the following keys:
  - bbox (list[float], required): list of 4 numbers representing the bounding box of the instance.
  - bbox_mode (int, required): the format of bbox. It must be a member of structures.BoxMode. Currently supports: BoxMode.XYXY_ABS, BoxMode.XYWH_ABS.
  - category_id (int, required): an integer in the range [0, num_categories-1] representing the category label. The value num_categories is reserved to represent the “background” category, if applicable.
  - segmentation (list[list[float]] or dict): the segmentation mask of the instance.
    - If list[list[float]], it represents a list of polygons, one for each connected component of the object. Each list[float] is one simple polygon in the format of [x1, y1, ..., xn, yn] (n≥3). The Xs and Ys are absolute coordinates in unit of pixels.
    - If dict, it represents the per-pixel segmentation mask in COCO’s compressed RLE format. The dict should have keys “size” and “counts”. You can convert a uint8 segmentation mask of 0s and 1s into such dict by pycocotools.mask.encode(np.asarray(mask, order="F")). cfg.INPUT.MASK_FORMAT must be set to bitmask if using the default data loader with such format.
  - keypoints (list[float]): in the format of [x1, y1, v1, ..., xn, yn, vn]. v[i] means the visibility of this keypoint. n must be equal to the number of keypoint categories. The Xs and Ys are absolute real-value coordinates in range [0, W or H]. (Note that the keypoint coordinates in COCO format are integers in range [0, W-1 or H-1], which is different from our standard format. Detectron2 adds 0.5 to COCO keypoint coordinates to convert them from discrete pixel indices to floating point coordinates.)
  - iscrowd: 0 (default) or 1. Whether this instance is labeled as COCO’s “crowd region”. Don’t include this field if you don’t know what it means.
  If annotations is an empty list, it means the image is labeled to have no objects. Such images will by default be removed from training, but can be included using DATALOADER.FILTER_EMPTY_ANNOTATIONS.
- sem_seg_file_name (str): The full path to the semantic segmentation ground truth file. It should be a grayscale image whose pixel values are integer labels.
- pan_seg_file_name (str): The full path to the panoptic segmentation ground truth file. It should be an RGB image whose pixel values are integer ids encoded using the panopticapi.utils.id2rgb function. The ids are defined by segments_info. If an id does not appear in segments_info, the pixel is considered unlabeled and is usually ignored in training & evaluation.
- segments_info (list[dict]): defines the meaning of each id in the panoptic segmentation ground truth. Each dict has the following keys:
  - id (int): integer that appears in the ground truth image.
  - category_id (int): an integer in the range [0, num_categories-1] representing the category label.
  - iscrowd: 0 (default) or 1. Whether this instance is labeled as COCO’s “crowd region”.
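To make the format concrete, here is a minimal sketch of one dataset dict for instance detection/segmentation with a single annotation; the file path and coordinate values are made up for illustration:
from detectron2.structures import BoxMode

one_image = {
    "file_name": "/path/to/images/0001.jpg",  # placeholder path
    "height": 480,
    "width": 640,
    "image_id": 1,
    "annotations": [
        {
            "bbox": [100.0, 120.0, 80.0, 60.0],
            "bbox_mode": BoxMode.XYWH_ABS,
            "category_id": 0,
            # one polygon, [x1, y1, ..., xn, yn] in absolute pixel coordinates
            "segmentation": [[100.0, 120.0, 180.0, 120.0, 180.0, 180.0, 100.0, 180.0]],
            "iscrowd": 0,
        }
    ],
}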
Note
The PanopticFPN model does not use the panoptic segmentation format defined here, but a combination of both instance segmentation and semantic segmentation data format. See Use Builtin Datasets for instructions on COCO.
Fast R-CNN (with pre-computed proposals) models are rarely used today. To train a Fast R-CNN, the following extra keys are needed:
- proposal_boxes (array): 2D numpy array with shape (K, 4) representing K precomputed proposal boxes for this image.
- proposal_objectness_logits (array): numpy array with shape (K, ), which corresponds to the objectness logits of proposals in ‘proposal_boxes’.
- proposal_bbox_mode (int): the format of the precomputed proposal bbox. It must be a member of structures.BoxMode. Default is BoxMode.XYXY_ABS.
Custom Dataset Dicts for New Tasks¶
In the list[dict] that your dataset function returns, the dictionary can also have arbitrary custom data. This will be useful for a new task that needs extra information not covered by the standard dataset dicts. In this case, you need to make sure the downstream code can handle your data correctly. Usually this requires writing a new mapper for the dataloader (see Use Custom Dataloaders).
When designing a custom format, note that all dicts are stored in memory (sometimes serialized and with multiple copies). To save memory, each dict is meant to contain small but sufficient information about each sample, such as file names and annotations. Loading full samples typically happens in the data loader.
For attributes shared among the entire dataset, use Metadata (see below). To avoid extra memory, do not save such information inside each sample.
“Metadata” for Datasets¶
Each dataset is associated with some metadata, accessible through MetadataCatalog.get(dataset_name).some_metadata.
Metadata is a key-value mapping that contains information that’s shared among
the entire dataset, and usually is used to interpret what’s in the dataset, e.g.,
names of classes, colors of classes, root of files, etc.
This information will be useful for augmentation, evaluation, visualization, logging, etc.
The structure of metadata depends on what is needed from the corresponding downstream code.
If you register a new dataset through DatasetCatalog.register, you may also want to add its corresponding metadata through MetadataCatalog.get(dataset_name).some_key = some_value, to enable any features that need the metadata.
You can do it like this (using the metadata key “thing_classes” as an example):
from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]
Here is a list of metadata keys that are used by builtin features in detectron2. If you add your own dataset without these metadata, some features may be unavailable to you:
- thing_classes (list[str]): Used by all instance detection/segmentation tasks. A list of names for each instance/thing category. If you load a COCO format dataset, it will be automatically set by the function load_coco_json.
- thing_colors (list[tuple(r, g, b)]): Pre-defined color (in [0, 255]) for each thing category. Used for visualization. If not given, random colors will be used.
- stuff_classes (list[str]): Used by semantic and panoptic segmentation tasks. A list of names for each stuff category.
- stuff_colors (list[tuple(r, g, b)]): Pre-defined color (in [0, 255]) for each stuff category. Used for visualization. If not given, random colors are used.
- ignore_label (int): Used by semantic and panoptic segmentation tasks. Pixels in ground-truth annotations with this category label should be ignored in evaluation. Typically these are “unlabeled” pixels.
- keypoint_names (list[str]): Used by keypoint detection. A list of names for each keypoint.
- keypoint_flip_map (list[tuple[str]]): Used by keypoint detection. A list of pairs of names, where each pair are the two keypoints that should be flipped if the image is flipped horizontally during augmentation.
- keypoint_connection_rules: list[tuple(str, str, (r, g, b))]. Each tuple specifies a pair of keypoints that are connected and the color to use for the line between them when visualized.
Some additional metadata that are specific to the evaluation of certain datasets (e.g. COCO):
- thing_dataset_id_to_contiguous_id (dict[int->int]): Used by all instance detection/segmentation tasks in the COCO format. A mapping from instance class ids in the dataset to contiguous ids in range [0, #class). Will be automatically set by the function load_coco_json.
- stuff_dataset_id_to_contiguous_id (dict[int->int]): Used when generating prediction json files for semantic/panoptic segmentation. A mapping from semantic segmentation class ids in the dataset to contiguous ids in [0, num_categories). It is useful for evaluation only.
- json_file: The COCO annotation json file. Used by COCO evaluation for COCO-format datasets.
- panoptic_root, panoptic_json: Used by COCO-format panoptic evaluation.
- evaluator_type: Used by the builtin main training script to select the evaluator. Don’t use it in a new training script. You can just provide the DatasetEvaluator for your dataset directly in your main script.
Note
In recognition, sometimes we use the term “thing” for instance-level tasks, and “stuff” for semantic segmentation tasks. Both are used in panoptic segmentation tasks. For background on the concept of “thing” and “stuff”, see On Seeing Stuff: The Perception of Materials by Humans and Machines.
Register a COCO Format Dataset¶
If your instance-level (detection, segmentation, keypoint) dataset is already a json file in the COCO format, the dataset and its associated metadata can be registered easily with:
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset", {}, "json_annotation.json", "path/to/image/dir")
If your dataset is in COCO format but needs to be further processed, or has extra custom per-instance annotations, the load_coco_json function might be useful.
Update the Config for New Datasets¶
Once you’ve registered the dataset, you can use the name of the dataset (e.g., “my_dataset” in the example above) in cfg.DATASETS.{TRAIN,TEST}.
There are other configs you might want to change to train or evaluate on new datasets:
- MODEL.ROI_HEADS.NUM_CLASSES and MODEL.RETINANET.NUM_CLASSES are the number of thing classes for R-CNN and RetinaNet models, respectively.
- MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS sets the number of keypoints for Keypoint R-CNN. You’ll also need to set Keypoint OKS with TEST.KEYPOINT_OKS_SIGMAS for evaluation.
- MODEL.SEM_SEG_HEAD.NUM_CLASSES sets the number of stuff classes for Semantic FPN & Panoptic FPN.
- TEST.DETECTIONS_PER_IMAGE controls the maximum number of objects to be detected. Set it to a larger number if test images may contain >100 objects.
- If you’re training Fast R-CNN (with precomputed proposals), DATASETS.PROPOSAL_FILES_{TRAIN,TEST} need to match the datasets. The format of proposal files is documented here.
New models (e.g. TensorMask, PointRend) often have similar configs of their own that need to be changed as well.
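For instance, a minimal sketch of pointing a config at a newly registered dataset and adjusting a few of the keys above; the dataset names and values are placeholders:
from detectron2.config import get_cfg

cfg = get_cfg()
# ... merge your model config file as usual, then:
cfg.DATASETS.TRAIN = ("my_dataset_train",)   # names of registered datasets (placeholders)
cfg.DATASETS.TEST = ("my_dataset_val",)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3          # number of thing classes in the custom dataset
cfg.TEST.DETECTIONS_PER_IMAGE = 200          # if test images may contain many objects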
Tip
After changing the number of classes, certain layers in a pre-trained model will become incompatible and therefore cannot be loaded to the new model. This is expected, and loading such pre-trained models will produce warnings about such layers.
Dataloader¶
Dataloader is the component that provides data to models. A dataloader usually (but not necessarily) takes raw information from datasets and processes it into a format needed by the model.
How the Existing Dataloader Works¶
Detectron2 contains a builtin data loading pipeline. It’s good to understand how it works, in case you need to write a custom one.
Detectron2 provides two functions build_detection_{train,test}_loader that create a default data loader from a given config. Here is how build_detection_{train,test}_loader work:
1. It takes the name of a registered dataset (e.g., “coco_2017_train”) and loads a list[dict] representing the dataset items in a lightweight format. These dataset items are not yet ready to be used by the model (e.g., images are not loaded into memory, random augmentations have not been applied, etc.). Details about the dataset format and dataset registration can be found in datasets.
2. Each dict in this list is mapped by a function (“mapper”):
  - Users can customize this mapping function by specifying the “mapper” argument in build_detection_{train,test}_loader. The default mapper is DatasetMapper.
  - The output format of the mapper can be arbitrary, as long as it is accepted by the consumer of this data loader (usually the model). The outputs of the default mapper, after batching, follow the default model input format documented in Use Models.
  - The role of the mapper is to transform the lightweight representation of a dataset item into a format that is ready for the model to consume (including, e.g., read images, perform random data augmentation and convert to torch Tensors). If you would like to perform custom transformations to data, you often want a custom mapper.
3. The outputs of the mapper are batched (simply into a list).
4. This batched data is the output of the data loader. Typically, it’s also the input of model.forward().
Write a Custom Dataloader¶
Using a different “mapper” with build_detection_{train,test}_loader(mapper=) works for most use cases of custom data loading.
For example, if you want to resize all images to a fixed size for training, use:
import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader  # DatasetMapper is the default mapper

dataloader = build_detection_train_loader(cfg,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ]))
# use this dataloader instead of the default
If the arguments of the default DatasetMapper do not provide what you need, you may write a custom mapper function and use it instead, e.g.:
import copy
import torch
import detectron2.data.transforms as T
from detectron2.data import detection_utils as utils

# Show how to implement a minimal mapper, similar to the default DatasetMapper
def mapper(dataset_dict):
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    # can use other ways to read image
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    # See "Data Augmentation" tutorial for detailed usage
    auginput = T.AugInput(image)
    transform = T.Resize((800, 800))(auginput)
    image = torch.from_numpy(auginput.image.transpose(2, 0, 1))
    annos = [
        utils.transform_instance_annotations(annotation, [transform], image.shape[1:])
        for annotation in dataset_dict.pop("annotations")
    ]
    return {
        # create the format that the model expects
        "image": image,
        "instances": utils.annotations_to_instances(annos, image.shape[1:])
    }
dataloader = build_detection_train_loader(cfg, mapper=mapper)
If you want to change not only the mapper (e.g., in order to implement different sampling or batching logic), build_detection_train_loader won’t work and you will need to write a different data loader. The data loader is simply a python iterator that produces the format that the model accepts. You can implement it using any tools you like.
No matter what you implement, it’s recommended to check out the API documentation of detectron2.data to learn more about the APIs of these functions.
Use a Custom Dataloader¶
If you use DefaultTrainer, you can overwrite its build_{train,test}_loader method to use your own dataloader. See the deeplab dataloader for an example.
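A minimal sketch of such an override, assuming the custom mapper function defined in the previous section:
from detectron2.data import build_detection_train_loader
from detectron2.engine import DefaultTrainer

class MyTrainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        # plug the custom mapper into the default loader-building logic
        return build_detection_train_loader(cfg, mapper=mapper)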
If you write your own training loop, you can plug in your data loader easily.
Data Augmentation¶
Augmentation is an important part of training. Detectron2’s data augmentation system aims at addressing the following goals:
Allow augmenting multiple data types together (e.g., images together with their bounding boxes and masks)
Allow applying a sequence of statically-declared augmentations
Allow adding custom new data types to augment (rotated bounding boxes, video clips, etc.)
Process and manipulate the operations that are applied by augmentations
The first two features cover most of the common use cases, and are also available in other libraries such as albumentations. Supporting the other features adds some overhead to detectron2’s augmentation API, which we’ll explain in this tutorial.
This tutorial focuses on how to use augmentations when writing new data loaders, and how to write new augmentations. If you use the default data loader in detectron2, it already supports taking a user-provided list of custom augmentations, as explained in the Dataloader tutorial.
Basic Usage¶
The basic usage of features (1) and (2) is like the following:
from detectron2.data import transforms as T
# Define a sequence of augmentations:
augs = T.AugmentationList([
T.RandomBrightness(0.9, 1.1),
T.RandomFlip(prob=0.5),
T.RandomCrop("absolute", (640, 640))
]) # type: T.Augmentation
# Define the augmentation input ("image" required, others optional):
input = T.AugInput(image, boxes=boxes, sem_seg=sem_seg)
# Apply the augmentation:
transform = augs(input) # type: T.Transform
image_transformed = input.image # new image
sem_seg_transformed = input.sem_seg # new semantic segmentation
# For any extra data that needs to be augmented together, use transform, e.g.:
image2_transformed = transform.apply_image(image2)
polygons_transformed = transform.apply_polygons(polygons)
Three basic concepts are involved here. They are:
- T.Augmentation defines the “policy” to modify inputs.
  - Its __call__(AugInput) -> Transform method augments the inputs in-place, and returns the operation that is applied.
- T.Transform implements the actual operations to transform data.
  - It has methods such as apply_image, apply_coords that define how to transform each data type.
- T.AugInput stores inputs needed by T.Augmentation and how they should be transformed. This concept is needed for some advanced usage. Using this class directly should be sufficient for all common use cases, since extra data not in T.AugInput can be augmented using the returned transform, as shown in the above example.
Write New Augmentations¶
Most 2D augmentations only need to know about the input image. Such augmentation can be implemented easily like this:
class MyColorAugmentation(T.Augmentation):
    def get_transform(self, image):
        r = np.random.rand(2)
        return T.ColorTransform(lambda x: x * r[0] + r[1] * 10)

class MyCustomResize(T.Augmentation):
    def get_transform(self, image):
        old_h, old_w = image.shape[:2]
        new_h, new_w = int(old_h * np.random.rand()), int(old_w * 1.5)
        return T.ResizeTransform(old_h, old_w, new_h, new_w)
augs = MyCustomResize()
transform = augs(input)
In addition to image, any attributes of the given AugInput can be used as long as they are part of the function signature, e.g.:
class MyCustomCrop(T.Augmentation):
    def get_transform(self, image, sem_seg):
        # decide where to crop using both image and sem_seg
        return T.CropTransform(...)
augs = MyCustomCrop()
assert hasattr(input, "image") and hasattr(input, "sem_seg")
transform = augs(input)
New transform operation can also be added by subclassing T.Transform.
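As a rough sketch (not an official transform), a new operation only has to implement the apply_* methods for the data types it supports; the transform below, its name, and the shift semantics are all illustrative assumptions:
import numpy as np
from detectron2.data import transforms as T

class ShiftTransform(T.Transform):
    """Hypothetical deterministic transform: pad the top-left so content shifts by (dx, dy)."""
    def __init__(self, dx: int, dy: int):
        super().__init__()
        self._set_attributes(locals())  # stores dx, dy as attributes

    def apply_image(self, img: np.ndarray) -> np.ndarray:
        # pad dy rows on top and dx columns on the left (extra channels untouched)
        pad_width = ((self.dy, 0), (self.dx, 0)) + ((0, 0),) * (img.ndim - 2)
        return np.pad(img, pad_width)

    def apply_coords(self, coords: np.ndarray) -> np.ndarray:
        coords[:, 0] += self.dx
        coords[:, 1] += self.dy
        return coords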
Advanced Usage¶
We give a few examples of advanced usages that are enabled by our system. These options can be interesting to new research, although changing them is often not needed for standard use cases.
Custom transform strategy¶
Instead of only returning the augmented data, detectron2’s Augmentation returns the operations as T.Transform.
This allows users to apply custom transform strategy on their data.
We use keypoints data as an example.
Keypoints are (x, y) coordinates, but they are not so trivial to augment due to the semantic meaning they carry.
Such meaning is only known to the users, therefore users may want to augment them manually by looking at the returned transform.
For example, when an image is horizontally flipped, we’d like to swap the keypoint annotations for “left eye” and “right eye”.
This can be done like this (included by default in detectron2’s default data loader):
# augs, input are defined as in previous examples
transform = augs(input) # type: T.Transform
keypoints_xy = transform.apply_coords(keypoints_xy) # transform the coordinates
# get a list of all transforms that were applied
transforms = T.TransformList([transform]).transforms
# check if it is flipped for odd number of times
do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms) % 2 == 1
if do_hflip:
    keypoints_xy = keypoints_xy[flip_indices_mapping]
As another example, keypoints annotations often have a “visibility” field. A sequence of augmentations might augment a visible keypoint out of the image boundary (e.g. with cropping), but then bring it back within the boundary afterwards (e.g. with image padding). If users decide to label such keypoints “invisible”, then the visibility check has to happen after every transform step. This can be achieved by:
transform = augs(input) # type: T.TransformList
assert isinstance(transform, T.TransformList)
for t in transform.transforms:
    keypoints_xy = t.apply_coords(keypoints_xy)
    visibility &= ((keypoints_xy >= [0, 0]) & (keypoints_xy <= [W, H])).all(axis=1)

# btw, detectron2's `transform_keypoint_annotations` function chooses to label such keypoints "visible":
# keypoints_xy = transform.apply_coords(keypoints_xy)
# visibility &= ((keypoints_xy >= [0, 0]) & (keypoints_xy <= [W, H])).all(axis=1)
Geometrically invert the transform¶
If images are pre-processed by augmentations before inference, the predicted results such as segmentation masks are localized on the augmented image. We’d like to invert the applied augmentation with the inverse() API, to obtain results on the original image:
transform = augs(input)
pred_mask = make_prediction(input.image)
inv_transform = transform.inverse()
pred_mask_orig = inv_transform.apply_segmentation(pred_mask)
Add new data types¶
T.Transform supports a few common data types to transform, including images, coordinates, masks, boxes, polygons. It allows registering new data types, e.g.:
@T.HFlipTransform.register_type("rotated_boxes")
def func(flip_transform: T.HFlipTransform, rotated_boxes: Any):
    # do the work
    return flipped_rotated_boxes

t = T.HFlipTransform(width=800)
transformed_rotated_boxes = t.apply_rotated_boxes(rotated_boxes)  # func will be called
Extend T.AugInput¶
An augmentation can only access attributes available in the given input. T.AugInput defines “image”, “boxes”, “sem_seg”, which are sufficient for common augmentation strategies to decide how to augment. If not, a custom implementation is needed.
By re-implementing the “transform()” method in AugInput, it is also possible to augment different fields in ways that are dependent on each other. Such use cases are uncommon (e.g. post-processing bounding boxes based on augmented masks), but allowed by the system.
Use Models¶
Build Models from Yacs Config¶
From a yacs config object, models (and their sub-models) can be built by functions such as build_model, build_backbone, build_roi_heads:
from detectron2.modeling import build_model
model = build_model(cfg) # returns a torch.nn.Module
build_model only builds the model structure and fills it with random parameters. See below for how to load an existing checkpoint to the model and how to use the model object.
Load/Save a Checkpoint¶
from detectron2.checkpoint import DetectionCheckpointer
DetectionCheckpointer(model).load(file_path_or_url) # load a file, usually from cfg.MODEL.WEIGHTS
checkpointer = DetectionCheckpointer(model, save_dir="output")
checkpointer.save("model_999") # save to output/model_999.pth
Detectron2’s checkpointer recognizes models in pytorch’s .pth format, as well as the .pkl files in our model zoo. See API doc for more details about its usage.
The model files can be arbitrarily manipulated using torch.{load,save} for .pth files or pickle.{dump,load} for .pkl files.
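For instance, a minimal sketch of inspecting both formats; the file names are placeholders, and the exact keys depend on how each file was produced:
import pickle
import torch

# model-zoo checkpoints are pickled dicts, typically with a "model" key
with open("model_final.pkl", "rb") as f:             # placeholder file name
    data = pickle.load(f, encoding="latin1")
print(data.keys())

# checkpoints saved by DetectionCheckpointer are regular torch files
ckpt = torch.load("output/model_999.pth", map_location="cpu")
print(list(ckpt["model"].keys())[:5])                # first few parameter names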
Use a Model¶
A model can be called by outputs = model(inputs), where inputs is a list[dict].
Each dict corresponds to one image, and the required keys depend on the type of model and whether the model is in training or evaluation mode. For example, in order to do inference, all existing models expect the “image” key, and optionally “height” and “width”. The detailed format of inputs and outputs of existing models is explained below.
Training: When in training mode, all models are required to be used under an EventStorage. The training statistics will be put into the storage:
from detectron2.utils.events import EventStorage
with EventStorage() as storage:
    losses = model(inputs)
Inference: If you only want to do simple inference using an existing model, DefaultPredictor is a wrapper around the model that provides such basic functionality. It includes default behavior such as model loading and preprocessing, and it operates on a single image rather than on batches. See its documentation for usage.
You can also run inference directly like this:
model.eval()
with torch.no_grad():
    outputs = model(inputs)
Model Input Format¶
Users can implement custom models that support any arbitrary input format.
Here we describe the standard input format that all builtin models support in detectron2.
They all take a list[dict] as the inputs. Each dict corresponds to information about one image.
The dict may contain the following keys:
- “image”: Tensor in (C, H, W) format. The meaning of channels is defined by cfg.INPUT.FORMAT. Image normalization, if any, will be performed inside the model using cfg.MODEL.PIXEL_{MEAN,STD}.
- “height”, “width”: the desired output height and width, which is not necessarily the same as the height or width of the image field. For example, the image field contains the resized image, if resize is used as a preprocessing step. But you may want the outputs to be in original resolution. If provided, the model will produce output in this resolution, rather than in the resolution of the image as input into the model. This is more efficient and accurate.
- “instances”: an Instances object for training, with the following fields:
  - “gt_boxes”: a Boxes object storing N boxes, one for each instance.
  - “gt_classes”: Tensor of long type, a vector of N labels, in range [0, num_categories).
  - “gt_masks”: a PolygonMasks or BitMasks object storing N masks, one for each instance.
  - “gt_keypoints”: a Keypoints object storing N keypoint sets, one for each instance.
- “sem_seg”: Tensor[int] in (H, W) format. The semantic segmentation ground truth for training. Values represent category labels starting from 0.
- “proposals”: an Instances object used only in Fast R-CNN style models, with the following fields:
  - “proposal_boxes”: a Boxes object storing P proposal boxes.
  - “objectness_logits”: Tensor, a vector of P scores, one for each proposal.
For inference of builtin models, only the “image” key is required, and “width/height” are optional.
We currently don’t define standard input format for panoptic segmentation training, because models now use custom formats produced by custom data loaders.
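For example, a minimal sketch of building an inference input from an image array; original_image is a placeholder HxWx3 uint8 BGR array, and this roughly mirrors what DefaultPredictor does internally, minus its resizing step:
import torch

height, width = original_image.shape[:2]
image = torch.as_tensor(original_image.astype("float32").transpose(2, 0, 1))  # to (C, H, W)
inputs = [{"image": image, "height": height, "width": width}]
with torch.no_grad():
    outputs = model(inputs)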
How it connects to data loader:¶
The output of the default DatasetMapper is a dict that follows the above format. After the data loader performs batching, it becomes list[dict] which the builtin models support.
Model Output Format¶
When in training mode, the builtin models output a dict[str->ScalarTensor] with all the losses.
When in inference mode, the builtin models output a list[dict], one dict for each image. Based on the tasks the model is doing, each dict may contain the following fields:
- “instances”: Instances object with the following fields:
  - “pred_boxes”: Boxes object storing N boxes, one for each detected instance.
  - “scores”: Tensor, a vector of N confidence scores.
  - “pred_classes”: Tensor, a vector of N labels in range [0, num_categories).
  - “pred_masks”: a Tensor of shape (N, H, W), masks for each detected instance.
  - “pred_keypoints”: a Tensor of shape (N, num_keypoint, 3). Each row in the last dimension is (x, y, score). Confidence scores are larger than 0.
- “sem_seg”: Tensor of (num_categories, H, W), the semantic segmentation prediction.
- “proposals”: Instances object with the following fields:
  - “proposal_boxes”: Boxes object storing N boxes.
  - “objectness_logits”: a torch vector of N confidence scores.
- “panoptic_seg”: A tuple of (pred: Tensor, segments_info: Optional[list[dict]]). The pred tensor has shape (H, W), containing the segment id of each pixel.
  - If segments_info exists, each dict describes one segment id in pred and has the following fields:
    - “id”: the segment id
    - “isthing”: whether the segment is a thing or stuff
    - “category_id”: the category id of this segment.
    If a pixel’s id does not exist in segments_info, it is considered to be the void label defined in Panoptic Segmentation.
  - If segments_info is None, all pixel values in pred must be ≥ -1. Pixels with value -1 are assigned void labels. Otherwise, the category id of each pixel is obtained by category_id = pixel // metadata.label_divisor.
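A minimal sketch of reading instance predictions from this output format, continuing from the inputs built above:
outputs = model(inputs)                        # list[dict], one dict per input image
instances = outputs[0]["instances"].to("cpu")  # Instances object for the first image
boxes = instances.pred_boxes.tensor            # (N, 4) tensor of box coordinates
scores = instances.scores                      # (N,) confidence scores
classes = instances.pred_classes               # (N,) predicted class ids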
Partially execute a model:¶
Sometimes you may want to obtain an intermediate tensor inside a model, such as the input of a certain layer or the output before post-processing. Since there are typically hundreds of intermediate tensors, there isn’t an API that provides the intermediate result you need. You have the following options:
Write a (sub)model. Following the tutorial, you can rewrite a model component (e.g. a head of a model), such that it does the same thing as the existing component, but returns the output you need.
Partially execute a model. You can create the model as usual, but use custom code to execute it instead of its forward(). For example, the following code obtains mask features before the mask head.

images = ImageList.from_tensors(...)  # preprocessed input tensor
model = build_model(cfg)
model.eval()
features = model.backbone(images.tensor)
proposals, _ = model.proposal_generator(images, features)
instances, _ = model.roi_heads(images, features, proposals)
mask_features = [features[f] for f in model.roi_heads.in_features]
mask_features = model.roi_heads.mask_pooler(mask_features, [x.pred_boxes for x in instances])
Use forward hooks. Forward hooks can help you obtain inputs or outputs of a certain module. If they are not exactly what you want, they can at least be used together with partial execution to obtain other tensors.
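A minimal sketch of the forward-hook option, here capturing the backbone output of a builtin model; the choice of module and the dict used to stash results are illustrative:
import torch

captured = {}

def save_output(module, inputs, output):
    captured["backbone_features"] = output   # for FPN backbones, a dict of feature maps

handle = model.backbone.register_forward_hook(save_output)
with torch.no_grad():
    model(inputs)
handle.remove()                              # detach the hook when done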
All options require you to read documentation and sometimes code of the existing models to understand the internal logic, in order to write code to obtain the internal tensors.
Write Models¶
If you are trying to do something completely new, you may wish to implement a model entirely from scratch. However, in many situations you may be interested in modifying or extending some components of an existing model. Therefore, we also provide mechanisms that let users override the behavior of certain internal components of standard models.
Register New Components¶
For common concepts that users often want to customize, such as “backbone feature extractor”, “box head”, we provide a registration mechanism for users to inject custom implementation that will be immediately available to use in config files.
For example, to add a new backbone, import this code in your code:
from torch import nn
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec

@BACKBONE_REGISTRY.register()
class ToyBackbone(Backbone):
    def __init__(self, cfg, input_shape):
        super().__init__()
        # create your own backbone
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=16, padding=3)

    def forward(self, image):
        return {"conv1": self.conv1(image)}

    def output_shape(self):
        return {"conv1": ShapeSpec(channels=64, stride=16)}
In this code, we implement a new backbone following the interface of the Backbone class, and register it into the BACKBONE_REGISTRY which requires subclasses of Backbone.
After importing this code, detectron2 can link the name of the class to its implementation. Therefore you can write the following code:
cfg = ... # read a config
cfg.MODEL.BACKBONE.NAME = 'ToyBackbone' # or set it in the config file
model = build_model(cfg) # it will find `ToyBackbone` defined above
As another example, to add new abilities to the ROI heads in the Generalized R-CNN meta-architecture, you can implement a new ROIHeads subclass and put it in the ROI_HEADS_REGISTRY. DensePose and MeshRCNN are two examples that implement new ROIHeads to perform new tasks. And projects/ contains more examples that implement different architectures.
A complete list of registries can be found in API documentation. You can register components in these registries to customize different parts of a model, or the entire model.
Construct Models with Explicit Arguments¶
A registry is a bridge that connects names in config files to the actual code. Registries are meant to cover a few main components that users frequently need to replace. However, the capability of a text-based config file is sometimes limited, and some deeper customization may be available only through writing code.
Most model components in detectron2 have a clear __init__ interface that documents what input arguments they need. Calling them with custom arguments will give you a custom variant of the model.
As an example, to use custom loss function in the box head of a Faster R-CNN, we can do the following:
Losses are currently computed in FastRCNNOutputLayers. We need to implement a variant or a subclass of it, with custom loss functions, named MyRCNNOutput.
Call StandardROIHeads with a box_predictor=MyRCNNOutput() argument instead of the builtin FastRCNNOutputLayers. If all other arguments should stay unchanged, this can be easily achieved by using the configurable __init__ mechanism:

roi_heads = StandardROIHeads(
    cfg, backbone.output_shape(),
    box_predictor=MyRCNNOutput(...)
)
(optional) If we want to enable this new model from a config file, registration is needed:
@ROI_HEADS_REGISTRY.register()
class MyStandardROIHeads(StandardROIHeads):
    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape, box_predictor=MyRCNNOutput(...))
Training¶
From the previous tutorials, you may now have a custom model and a data loader. To run training, users typically have a preference in one of the following two styles:
Custom Training Loop¶
With a model and a data loader ready, everything else needed to write a training loop can be found in PyTorch, and you are free to write the training loop yourself. This style allows researchers to manage the entire training logic more clearly and have full control. One such example is provided in tools/plain_train_net.py.
Any customization on the training logic is then easily controlled by the user.
Trainer Abstraction¶
We also provide a standardized “trainer” abstraction with a hook system that helps simplify the standard training behavior. It includes the following two instantiations:
SimpleTrainer provides a minimal training loop for single-cost single-optimizer single-data-source training, with nothing else. Other tasks (checkpointing, logging, etc) can be implemented using the hook system.
DefaultTrainer is a SimpleTrainer initialized from a yacs config, used by tools/train_net.py and many scripts. It includes more standard default behaviors that one might want to opt in, including default configurations for optimizer, learning rate schedule, logging, evaluation, checkpointing etc.
To customize a DefaultTrainer:
For simple customizations (e.g. change optimizer, evaluator, LR scheduler, data loader, etc.), overwrite its methods in a subclass, just like tools/train_net.py.
For extra tasks during training, check the hook system to see if it’s supported.
As an example, to print hello during training:
from detectron2.engine import HookBase

class HelloHook(HookBase):
    def after_step(self):
        if self.trainer.iter % 100 == 0:
            print(f"Hello at iteration {self.trainer.iter}!")
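Such a hook still has to be attached to the trainer; a minimal sketch, assuming a DefaultTrainer is created from a config as in tools/train_net.py:
from detectron2.engine import DefaultTrainer

trainer = DefaultTrainer(cfg)
trainer.register_hooks([HelloHook()])   # appended after the default hooks
trainer.resume_or_load(resume=False)
trainer.train()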
Using a trainer+hook system means there will always be some non-standard behaviors that cannot be supported, especially in research. For this reason, we intentionally keep the trainer & hook system minimal, rather than powerful. If anything cannot be achieved by such a system, it’s easier to start from tools/plain_train_net.py to implement custom training logic manually.
Logging of Metrics¶
During training, detectron2 models and trainer put metrics to a centralized EventStorage. You can use the following code to access it and log metrics to it:
from detectron2.utils.events import get_event_storage
# inside the model:
if self.training:
    value = ...  # compute the value from inputs
storage = get_event_storage()
storage.put_scalar("some_accuracy", value)
Refer to its documentation for more details.
Metrics are then written to various destinations with EventWriter.
DefaultTrainer enables a few EventWriter
with default configurations.
See above for how to customize them.
Evaluation¶
Evaluation is a process that takes a number of input/output pairs and aggregates them. You can always use the model directly and just parse its inputs/outputs manually to perform evaluation. Alternatively, evaluation is implemented in detectron2 using the DatasetEvaluator interface.
Detectron2 includes a few DatasetEvaluator implementations that compute metrics using standard dataset-specific APIs (e.g., COCO, LVIS).
You can also implement your own DatasetEvaluator
that performs some other jobs
using the inputs/outputs pairs.
For example, to count how many instances are detected on the validation set:
class Counter(DatasetEvaluator):
def reset(self):
self.count = 0
def process(self, inputs, outputs):
for output in outputs:
self.count += len(output["instances"])
def evaluate(self):
# save self.count somewhere, or print it, or return it.
return {"count": self.count}
Use evaluators¶
To evaluate using the methods of evaluators manually:
def get_all_inputs_outputs():
for data in data_loader:
yield data, model(data)
evaluator.reset()
for inputs, outputs in get_all_inputs_outputs():
evaluator.process(inputs, outputs)
eval_results = evaluator.evaluate()
Evaluators can also be used with inference_on_dataset. For example,
eval_results = inference_on_dataset(
model,
data_loader,
DatasetEvaluators([COCOEvaluator(...), Counter()]))
This will execute model
on all inputs from data_loader
, and call evaluator to process them.
Compared to running the evaluation manually using the model, the benefit of this function is that evaluators can be merged together using DatasetEvaluators, and all the evaluation can finish in one forward pass over the dataset. This function also provides accurate speed benchmarks for the given model and dataset.
Evaluators for custom dataset¶
Many evaluators in detectron2 are made for specific datasets, in order to obtain scores using each dataset’s official API. In addition to that, two evaluators are able to evaluate any generic dataset that follows detectron2’s standard dataset format, so they can be used to evaluate custom datasets:
COCOEvaluator is able to evaluate AP (Average Precision) for box detection, instance segmentation, and keypoint detection on any custom dataset (a usage sketch follows this list).
SemSegEvaluator is able to evaluate semantic segmentation metrics on any custom dataset.
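For example, a usage sketch for a custom dataset (the dataset name "my_dataset_val" is hypothetical and must already be registered in DatasetCatalog; `cfg` and `model` are assumed to be built as in the earlier tutorials):
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

evaluator = COCOEvaluator("my_dataset_val", output_dir="./output")
val_loader = build_detection_test_loader(cfg, "my_dataset_val")
print(inference_on_dataset(model, val_loader, evaluator))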
Yacs Configs¶
Detectron2 provides a key-value based config system that can be used to obtain standard, common behaviors.
This system uses YAML and yacs. Yaml is a very limited language, so we do not expect all features in detectron2 to be available through configs. If you need something that’s not available in the config space, please write code using detectron2’s API.
With the introduction of a more powerful LazyConfig system, we no longer add functionality / new keys to the Yacs/Yaml-based config system.
Basic Usage¶
Some basic usage of the CfgNode
object is shown here. See more in documentation.
from detectron2.config import get_cfg
cfg = get_cfg() # obtain detectron2's default config
cfg.xxx = yyy # add new configs for your own custom components
cfg.merge_from_file("my_cfg.yaml") # load values from a file
cfg.merge_from_list(["MODEL.WEIGHTS", "weights.pth"]) # can also load values from a list of str
print(cfg.dump()) # print formatted configs
In addition to the basic Yaml syntax, the config file can
define a _BASE_: base.yaml
field, which will load a base config file first.
Values in the base config will be overwritten in sub-configs, if there are any conflicts.
We provide several base configs for standard model architectures.
Many builtin tools in detectron2 accept command-line config overwrites: key-value pairs provided on the command line will overwrite the existing values in the config file. For example, demo.py can be used with
./demo.py --config-file config.yaml [--other-options] \
--opts MODEL.WEIGHTS /path/to/weights INPUT.MIN_SIZE_TEST 1000
To see a list of available configs in detectron2 and what they mean, check Config References.
Configs in Projects¶
A project that lives outside the detectron2 library may define its own configs, which will need to be added for the project to be functional, e.g.:
from detectron2.projects.point_rend import add_pointrend_config
cfg = get_cfg() # obtain detectron2's default config
add_pointrend_config(cfg) # add pointrend's default config
# ... ...
Best Practice with Configs¶
Treat the configs you write as “code”: avoid copying them or duplicating them; use `_BASE_` to share common parts between configs.
Keep the configs you write simple: don’t include keys that do not affect the experimental setting.
Lazy Configs¶
The traditional yacs-based config system provides basic, standard functionalities. However, it does not offer enough flexibility for many new projects. We develop an alternative, non-intrusive config system that can be used with detectron2 or potentially any other complex projects.
Python Syntax¶
Our config objects are still dictionaries. Instead of using Yaml to define dictionaries, we create dictionaries in Python directly. This gives users the following power that doesn’t exist in Yaml:
Easily manipulate the dictionary (addition & deletion) using Python.
Write simple arithmetic or call simple functions.
Use more data types / objects.
Import / compose other config files, using the familiar Python import syntax.
A Python config file can be loaded like this:
# config.py:
a = dict(x=1, y=2, z=dict(xx=1))
b = dict(x=3, y=4)
# my_code.py:
from detectron2.config import LazyConfig
cfg = LazyConfig.load("path/to/config.py") # an omegaconf dictionary
assert cfg.a.z.xx == 1
After LazyConfig.load
, cfg
will be a dictionary that contains all dictionaries
defined in the global scope of the config file. Note that:
All dictionaries are turned into an omegaconf config object during loading. This enables access to omegaconf features, such as its access syntax and interpolation (a minimal sketch is given after this list).
Absolute imports in `config.py` work the same as in regular Python.
Relative imports can only import dictionaries from config files. They are simply syntactic sugar for LazyConfig.load_rel. They can load Python files at a relative path without requiring `__init__.py`.
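For example, a config value can reference another value through omegaconf interpolation; a minimal sketch (all names are illustrative):
# config.py:
train = dict(
    output_dir="./output/exp1",
    log_dir="${train.output_dir}/logs",  # omegaconf interpolation, resolved on access
)
# my_code.py:
from detectron2.config import LazyConfig
cfg = LazyConfig.load("path/to/config.py")
assert cfg.train.log_dir == "./output/exp1/logs"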
Recursive Instantiation¶
The LazyConfig system heavily uses recursive instantiation, which is a pattern that uses a dictionary to describe a call to a function/class. The dictionary consists of:
A “_target_” key which contains path to the callable, such as “module.submodule.class_name”.
Other keys that represent arguments to pass to the callable. Arguments themselves can be defined using recursive instantiation.
We provide a helper function LazyCall that helps create such dictionaries.
The following code using LazyCall
from detectron2.config import LazyCall as L
from my_app import Trainer, Optimizer
cfg = L(Trainer)(
optimizer=L(Optimizer)(
lr=0.01,
algo="SGD"
)
)
creates a dictionary like this:
cfg = {
"_target_": "my_app.Trainer",
"optimizer": {
"_target_": "my_app.Optimizer",
"lr": 0.01, "algo": "SGD"
}
}
By representing objects using such dictionaries, a general instantiate function can turn them into actual objects, i.e.:
from detectron2.config import instantiate
trainer = instantiate(cfg)
# equivalent to:
# from my_app import Trainer, Optimizer
# trainer = Trainer(optimizer=Optimizer(lr=0.01, algo="SGD"))
This pattern is powerful enough to describe very complex objects, e.g.:
A full Mask R-CNN described in recursive instantiation:
from detectron2.config import LazyCall as L
from detectron2.layers import ShapeSpec
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
from detectron2.modeling.backbone.fpn import LastLevelMaxPool
from detectron2.modeling.backbone import BasicStem, FPN, ResNet
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.proposal_generator import RPN, StandardRPNHead
from detectron2.modeling.roi_heads import (
StandardROIHeads,
FastRCNNOutputLayers,
MaskRCNNConvUpsampleHead,
FastRCNNConvFCHead,
)
model = L(GeneralizedRCNN)(
backbone=L(FPN)(
bottom_up=L(ResNet)(
stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
stages=L(ResNet.make_default_stages)(
depth=50,
stride_in_1x1=True,
norm="FrozenBN",
),
out_features=["res2", "res3", "res4", "res5"],
),
in_features="${.bottom_up.out_features}",
out_channels=256,
top_block=L(LastLevelMaxPool)(),
),
proposal_generator=L(RPN)(
in_features=["p2", "p3", "p4", "p5", "p6"],
head=L(StandardRPNHead)(in_channels=256, num_anchors=3),
anchor_generator=L(DefaultAnchorGenerator)(
sizes=[[32], [64], [128], [256], [512]],
aspect_ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64],
offset=0.0,
),
anchor_matcher=L(Matcher)(
thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True
),
box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]),
batch_size_per_image=256,
positive_fraction=0.5,
pre_nms_topk=(2000, 1000),
post_nms_topk=(1000, 1000),
nms_thresh=0.7,
),
roi_heads=L(StandardROIHeads)(
num_classes=80,
batch_size_per_image=512,
positive_fraction=0.25,
proposal_matcher=L(Matcher)(
thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False
),
box_in_features=["p2", "p3", "p4", "p5"],
box_pooler=L(ROIPooler)(
output_size=7,
scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
sampling_ratio=0,
pooler_type="ROIAlignV2",
),
box_head=L(FastRCNNConvFCHead)(
input_shape=ShapeSpec(channels=256, height=7, width=7),
conv_dims=[],
fc_dims=[1024, 1024],
),
box_predictor=L(FastRCNNOutputLayers)(
input_shape=ShapeSpec(channels=1024),
test_score_thresh=0.05,
box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)),
num_classes="${..num_classes}",
),
mask_in_features=["p2", "p3", "p4", "p5"],
mask_pooler=L(ROIPooler)(
output_size=14,
scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
sampling_ratio=0,
pooler_type="ROIAlignV2",
),
mask_head=L(MaskRCNNConvUpsampleHead)(
input_shape=ShapeSpec(channels=256, width=14, height=14),
num_classes="${..num_classes}",
conv_dims=[256, 256, 256, 256, 256],
),
),
pixel_mean=[103.530, 116.280, 123.675],
pixel_std=[1.0, 1.0, 1.0],
input_format="BGR",
)
There are also objects or logic that cannot be described simply by a dictionary, such as reused objects or method calls. They may require some refactoring to work with recursive instantiation.
Using Model Zoo LazyConfigs¶
We provide some configs in the model zoo that use the LazyConfig system. After installing detectron2, they can be loaded by the model zoo API model_zoo.get_config.
Our model zoo configs follow some simple conventions, e.g.
cfg.model
defines a model object, cfg.dataloader.{train,test}
defines dataloader objects,
and cfg.train
contains training options in key-value form.
We provide a reference training script, tools/lazyconfig_train_net.py, which can train/eval our model zoo configs.
It also shows how to support command line value overrides.
Nevertheless, you are free to define any custom structure for your project and use it with your own scripts.
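For instance, loading and instantiating one of the LazyConfig baselines could look like the following sketch (the config path is one example from the model zoo; any other LazyConfig path works the same way):
from detectron2 import model_zoo
from detectron2.config import instantiate

cfg = model_zoo.get_config("new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py")
model = instantiate(cfg.model)                    # build the model object
train_loader = instantiate(cfg.dataloader.train)  # build the training dataloader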
To demonstrate the power and flexibility of the new system, we show that a simple config file can let detectron2 train an ImageNet classification model from torchvision, even though detectron2 contains no features about ImageNet classification. This can serve as a reference for using detectron2 in other deep learning tasks.
Summary¶
By using recursive instantiation to create objects,
we avoid passing a giant config to many places, because cfg
is only passed to instantiate
.
This has the following benefits:
It’s non-intrusive: objects to be constructed are config-agnostic, regular Python functions/classes. They can even live in other libraries. For example, {"_target_": "torch.nn.Conv2d", "in_channels": 10, "out_channels": 10, "kernel_size": 1} defines a conv layer.
Clarity of what functions/classes will be called, and what arguments they use.
`cfg` doesn’t need pre-defined keys and structures. It’s valid as long as it translates to valid code. This gives a lot more flexibility.
You can still pass huge dictionaries as arguments, just like the old way.
Putting recursive instantiation together with the Python config file syntax, the config file looks a lot like the code that will be executed.
However, the config file just defines dictionaries, which can be easily manipulated further
by composition or overrides.
The corresponding code will only be executed
later when instantiate
is called. In some way,
in config files we’re writing “editable code” that will be “lazily executed” later when needed.
That’s why we call this system “LazyConfig”.
Deployment¶
Models written in Python need to go through an export process to become a deployable artifact. A few basic concepts about this process:
“Export method” is how a Python model is fully serialized to a deployable format. We support the following export methods:
tracing: see pytorch documentation to learn about it
scripting: see pytorch documentation to learn about it
caffe2_tracing: replace parts of the model by caffe2 operators, then use tracing.
“Format” is how a serialized model is described in a file, e.g. TorchScript, Caffe2 protobuf, ONNX format. “Runtime” is an engine that loads a serialized model and executes it, e.g., PyTorch, Caffe2, TensorFlow, onnxruntime, TensorRT, etc. A runtime is often tied to a specific format (e.g. PyTorch needs TorchScript format, Caffe2 needs protobuf format). We currently support the following combination and each has some limitations:
Export Method | tracing | scripting | caffe2_tracing |
---|---|---|---|
Formats | TorchScript | TorchScript | Caffe2, TorchScript, ONNX |
Runtime | PyTorch | PyTorch | Caffe2, PyTorch |
C++/Python inference | ✅ | ✅ | ✅ |
Dynamic resolution | ✅ | ✅ | ✅ |
Batch size requirement | Constant | Dynamic | Batch inference unsupported |
Extra runtime deps | torchvision | torchvision | Caffe2 ops (usually already included in PyTorch) |
Faster/Mask/Keypoint R-CNN | ✅ | ✅ | ✅ |
RetinaNet | ✅ | ✅ | ✅ |
PointRend R-CNN | ✅ | ❌ | ❌ |
We don’t plan to work on additional support for other formats/runtime, but contributions are welcome.
Deployment with Tracing or Scripting¶
Models can be exported to TorchScript format, by either tracing or scripting. The output model file can be loaded without detectron2 dependency in either Python or C++. The exported model often requires torchvision (or its C++ library) dependency for some custom ops.
This feature requires PyTorch ≥ 1.8 (or latest on github before 1.8 is released).
Coverage¶
Most official models under the meta architectures GeneralizedRCNN
and RetinaNet
are supported in both tracing and scripting mode. Cascade R-CNN is currently not supported.
PointRend is currently supported in tracing.
Users’ custom extensions are supported if they are also scriptable or traceable.
For models exported with tracing, dynamic input resolution is allowed, but batch size (number of input images) must be fixed. Scripting can support dynamic batch size.
Usage¶
The main export APIs for tracing and scripting are TracingAdapter
and scripting_with_instances.
Their usage is currently demonstrated in test_export_torchscript.py
(see TestScripting
and TestTracing
)
as well as the deployment example.
Please check that these examples can run, and then modify for your use cases.
The usage currently requires some user effort and the necessary knowledge of each model to work around the limitations of scripting and tracing.
In the future we plan to wrap these under simpler APIs to lower the bar to use them.
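For orientation, a tracing export roughly follows the pattern below (a sketch only; `model` is an eval-mode detectron2 model and `image` is a CHW image tensor, and the real examples handle input formats and output flattening in more detail):
import torch
from detectron2.export import TracingAdapter

inputs = [{"image": image}]              # detectron2's standard input format
wrapper = TracingAdapter(model, inputs)  # wraps the model to take/return flattened tensors
traced = torch.jit.trace(wrapper, wrapper.flattened_inputs)
traced.save("model.ts")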
Deployment with Caffe2-tracing¶
We provide Caffe2Tracer, which performs the export logic. It replaces parts of the model with Caffe2 operators, and then exports the model into Caffe2, TorchScript, or ONNX format.
The converted model is able to run in either Python or C++ without detectron2/torchvision dependency, on CPU or GPUs. It has a runtime optimized for CPU & mobile inference, but not optimized for GPU inference.
This feature requires 1.9 > ONNX ≥ 1.6.
Coverage¶
Most official models under these 3 common meta architectures: GeneralizedRCNN
, RetinaNet
, PanopticFPN
are supported. Cascade R-CNN is not supported. Batch inference is not supported.
Users’ custom extensions under these architectures (added through registration) are supported as long as they do not contain control flow or operators not available in Caffe2 (e.g. deformable convolution). For example, custom backbones and heads are often supported out of the box.
Usage¶
The APIs are listed at the API documentation. We provide export_model.py as an example that uses these APIs to convert a standard model. For custom models/datasets, you can add them to this script.
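Roughly, the export flow looks like the sketch below (assumptions: `cfg` is a yacs config, `model` is an eval-mode model, and `inputs` is a list of dataset dicts in detectron2’s standard format; see export_model.py for the complete logic):
from detectron2.export import Caffe2Tracer

tracer = Caffe2Tracer(cfg, model, inputs)
# pick one of the supported output formats:
onnx_model = tracer.export_onnx()
# caffe2_model = tracer.export_caffe2()
# ts_model = tracer.export_torchscript()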
Use the model in C++/Python¶
The model can be loaded in C++ and deployed with either the Caffe2 or PyTorch runtime. C++ examples for Mask R-CNN are given as a reference. Note that:
Models exported with the caffe2_tracing method take a special input format described in documentation. This was taken care of in the C++ example.
The converted models do not contain post-processing operations that transform raw layer outputs into formatted predictions. For example, the C++ examples only produce raw outputs (28x28 masks) from the final layers that are not post-processed, because in actual deployment, an application often needs its custom lightweight post-processing, so this step is left for users.
To help use the Caffe2-format model in Python, we provide a Python wrapper around the converted model, in the Caffe2Model.__call__ method. This method has an interface that’s identical to the PyTorch versions of models, and it internally applies pre/post-processing code to match the formats. This wrapper can serve as a reference for how to use Caffe2’s Python API, or for how to implement pre/post-processing in actual deployment.
Conversion to TensorFlow¶
tensorpack Faster R-CNN provides scripts to convert a few standard detectron2 R-CNN models to TensorFlow’s pb format. It works by translating configs and weights, and therefore only supports a few models.
Notes¶
Benchmarks¶
Here we benchmark the training speed of a Mask R-CNN in detectron2, with some other popular open source Mask R-CNN implementations.
Settings¶
Hardware: 8 NVIDIA V100s with NVLink.
Software: Python 3.7, CUDA 10.1, cuDNN 7.6.5, PyTorch 1.5, TensorFlow 1.15.0rc2, Keras 2.2.5, MxNet 1.6.0b20190820.
Model: an end-to-end R-50-FPN Mask R-CNN model, using the same hyperparameters as the Detectron baseline config (it does not have scale augmentation).
Metrics: We use the average throughput in iterations 100-500 to skip GPU warmup time. Note that for R-CNN-style models, the throughput of a model typically changes during training, because it depends on the predictions of the model. Therefore this metric is not directly comparable with “train speed” in model zoo, which is the average speed of the entire training run.
Main Results¶
Implementation | Throughput (img/s) |
---|---|
Detectron2 | 62 |
mmdetection | 53 |
maskrcnn-benchmark | 53 |
tensorpack | 50 |
SimpleDet | 39 |
Detectron | 19 |
matterport/Mask_RCNN | 14 |
Details for each implementation:
Detectron2: with release v0.1.2, run:
python tools/train_net.py --config-file configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml --num-gpus 8
mmdetection: at commit
b0d845f
, run ./tools/dist_train.sh configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py 8
maskrcnn-benchmark: use commit
0ce8f6f
withsed -i 's/torch.uint8/torch.bool/g' **/*.py; sed -i 's/AT_CHECK/TORCH_CHECK/g' **/*.cu
to make it compatible with PyTorch 1.5. Then, run training withpython -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/e2e_mask_rcnn_R_50_FPN_1x.yaml
The speed we observed is faster than its model zoo, likely due to different software versions.
tensorpack: at commit
caafda
,export TF_CUDNN_USE_AUTOTUNE=0
, then runmpirun -np 8 ./train.py --config DATA.BASEDIR=/data/coco TRAINER=horovod BACKBONE.STRIDE_1X1=True TRAIN.STEPS_PER_EPOCH=50 --load ImageNet-R50-AlignPadding.npz
SimpleDet: at commit
9187a1
, runpython detection_train.py --config config/mask_r50v1_fpn_1x.py
Detectron: run
python tools/train_net.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml
Note that many of its ops run on CPUs, therefore the performance is limited.
matterport/Mask_RCNN: at commit
3deaec
, apply the following diff,export TF_CUDNN_USE_AUTOTUNE=0
, then runpython coco.py train --dataset=/data/coco/ --model=imagenet
Note that many small details in this implementation might be different from Detectron’s standards.
(diff to make it use the same hyperparameters)
diff --git i/mrcnn/model.py w/mrcnn/model.py
index 62cb2b0..61d7779 100644
--- i/mrcnn/model.py
+++ w/mrcnn/model.py
@@ -2367,8 +2367,8 @@ class MaskRCNN():
             epochs=epochs,
             steps_per_epoch=self.config.STEPS_PER_EPOCH,
             callbacks=callbacks,
-            validation_data=val_generator,
-            validation_steps=self.config.VALIDATION_STEPS,
+            #validation_data=val_generator,
+            #validation_steps=self.config.VALIDATION_STEPS,
             max_queue_size=100,
             workers=workers,
             use_multiprocessing=True,
diff --git i/mrcnn/parallel_model.py w/mrcnn/parallel_model.py
index d2bf53b..060172a 100644
--- i/mrcnn/parallel_model.py
+++ w/mrcnn/parallel_model.py
@@ -32,6 +32,7 @@ class ParallelModel(KM.Model):
         keras_model: The Keras model to parallelize
         gpu_count: Number of GPUs. Must be > 1
         """
+        super().__init__()
         self.inner_model = keras_model
         self.gpu_count = gpu_count
         merged_outputs = self.make_parallel()
diff --git i/samples/coco/coco.py w/samples/coco/coco.py
index 5d172b5..239ed75 100644
--- i/samples/coco/coco.py
+++ w/samples/coco/coco.py
@@ -81,7 +81,10 @@ class CocoConfig(Config):
     IMAGES_PER_GPU = 2

     # Uncomment to train on 8 GPUs (default is 1)
-    # GPU_COUNT = 8
+    GPU_COUNT = 8
+    BACKBONE = "resnet50"
+    STEPS_PER_EPOCH = 50
+    TRAIN_ROIS_PER_IMAGE = 512

     # Number of classes (including background)
     NUM_CLASSES = 1 + 80  # COCO has 80 classes
@@ -496,29 +499,10 @@ if __name__ == '__main__':
        # *** This training schedule is an example. Update to your needs ***

        # Training - Stage 1
-       print("Training network heads")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
-                   layers='heads',
-                   augmentation=augmentation)
-
-       # Training - Stage 2
-       # Finetune layers from ResNet stage 4 and up
-       print("Fine tune Resnet stage 4 and up")
-       model.train(dataset_train, dataset_val,
-                   learning_rate=config.LEARNING_RATE,
-                   epochs=120,
-                   layers='4+',
-                   augmentation=augmentation)
-
-       # Training - Stage 3
-       # Fine tune all layers
-       print("Fine tune all layers")
-       model.train(dataset_train, dataset_val,
-                   learning_rate=config.LEARNING_RATE / 10,
-                   epochs=160,
-                   layers='all',
+                   layers='3+',
                    augmentation=augmentation)

    elif args.command == "evaluate":
Compatibility with Other Libraries¶
Compatibility with Detectron (and maskrcnn-benchmark)¶
Detectron2 addresses some legacy issues left in Detectron. As a result, their models are not compatible: running inference with the same model weights will produce different results in the two code bases.
The major differences regarding inference are:
The height and width of a box with corners (x1, y1) and (x2, y2) is now computed more naturally as width = x2 - x1 and height = y2 - y1; in Detectron, a “+ 1” was added to both the height and width.
Note that the relevant ops in Caffe2 have adopted this change of convention with an extra option. So it is still possible to run inference with a Detectron2-trained model in Caffe2.
The change in height/width calculations most notably changes:
encoding/decoding in bounding box regression.
non-maximum suppression. The effect here is negligible, though.
RPN now uses simpler anchors with fewer quantization artifacts.
In Detectron, the anchors were quantized and do not have accurate areas. In Detectron2, the anchors are center-aligned to feature grid points and not quantized.
Classification layers have a different ordering of class labels.
This involves any trainable parameter with shape (…, num_categories + 1, …). In Detectron2, integer labels [0, K-1] correspond to the K = num_categories object categories and the label “K” corresponds to the special “background” category. In Detectron, label “0” means background, and labels [1, K] correspond to the K categories.
ROIAlign is implemented differently. The new implementation is available in Caffe2.
All the ROIs are shifted by half a pixel compared to Detectron in order to create better image-feature-map alignment. See layers/roi_align.py for details. To enable the old behavior, use ROIAlign(aligned=False), or POOLER_TYPE=ROIAlign instead of ROIAlignV2 (the default).
The ROIs are not required to have a minimum size of 1. This will lead to tiny differences in the output, but should be negligible.
Mask inference function is different.
In Detectron2, the “paste_mask” function is different and should be more accurate than in Detectron. This change can improve mask AP on COCO by ~0.5% absolute.
There are some other differences in training as well, but they won’t affect model-level compatibility. The major ones are:
We fixed a bug in Detectron, by making RPN.POST_NMS_TOPK_TRAIN per-image, rather than per-batch. The fix may lead to a small accuracy drop for a few models (e.g. keypoint detection) and will require some parameter tuning to match the Detectron results.
For simplicity, we change the default loss in bounding box regression to L1 loss, instead of smooth L1 loss. We have observed that this tends to slightly decrease box AP50 while improving box AP for higher overlap thresholds (and leading to a slight overall improvement in box AP).
We interpret the coordinates in COCO bounding box and segmentation annotations as coordinates in range
[0, width]
or[0, height]
. The coordinates in COCO keypoint annotations are interpreted as pixel indices in range[0, width - 1]
or[0, height - 1]
. Note that this affects how flip augmentation is implemented.
This article explains more details on the above mentioned issues about pixels, coordinates, and “+1”s.
Compatibility with Caffe2¶
As mentioned above, despite the incompatibilities with Detectron, the relevant ops have been implemented in Caffe2. Therefore, models trained with detectron2 can be converted to Caffe2. See Deployment for the tutorial.
Compatibility with TensorFlow¶
Most ops are available in TensorFlow, although some tiny differences in the implementation of resize / ROIAlign / padding need to be addressed. A working conversion script is provided by tensorpack Faster R-CNN to run a standard detectron2 model in TensorFlow.
Contributing to detectron2¶
Issues¶
We use GitHub issues to track public bugs and questions. Please make sure to follow one of the issue templates when reporting any issues.
Facebook has a bounty program for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue.
Pull Requests¶
We actively welcome pull requests.
However, if you’re adding any significant features (e.g. > 50 lines), please make sure to discuss your motivation and proposals with the maintainers in an issue before sending a PR. This saves your time, so you don’t spend it on a PR that we won’t accept.
We do not always accept new features, and we take the following factors into consideration:
Whether the same feature can be achieved without modifying detectron2. Detectron2 is designed so that you can implement many extensions from the outside, e.g. those in projects.
If some part of detectron2 is not extensible enough, you can also bring up a more general issue to improve it. Such a feature request may be useful to more users.
Whether the feature is potentially useful to a large audience (e.g. an impactful detection paper, a popular dataset, a significant speedup, a widely useful utility), or only to a small portion of users (e.g., a less-known paper, an improvement not in the object detection field, a trick that’s not very popular in the community, code to handle a non-standard type of data)
Additional models, datasets, or new tasks are by default not added to detectron2 before they receive significant popularity in the community. We sometimes accept such features in
projects/
, or as a link inprojects/README.md
.
Whether the proposed solution has a good design / interface. This can be discussed in the issue prior to PRs, or in the form of a draft PR.
Whether the proposed solution adds extra mental/practical overhead to users who don’t need such feature.
Whether the proposed solution breaks existing APIs.
To add a feature to an existing function/class Func
, there are always two approaches:
(1) add new arguments to Func
; (2) write a new Func_with_new_feature
.
To meet the above criteria, we often prefer approach (2), because:
It does not involve modifying or potentially breaking existing code.
It does not add overhead to users who do not need the new feature.
Adding new arguments to a function/class is not scalable w.r.t. all the possible new research ideas in the future.
When sending a PR, please do:
If a PR contains multiple orthogonal changes, split it into several PRs.
If you’ve added code that should be tested, add tests.
For PRs that need experiments (e.g. adding a new model or new methods), you don’t need to update model zoo, but do provide experiment results in the description of the PR.
If APIs are changed, update the documentation.
We use the Google style docstrings in python.
Make sure your code lints with
./dev/linter.sh
.
Contributor License Agreement (“CLA”)¶
In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook’s open source projects.
Complete your CLA here: https://code.facebook.com/cla
License¶
By contributing to detectron2, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.
Change Log and Backward Compatibility¶
Releases¶
See release logs at https://github.com/facebookresearch/detectron2/releases for new updates.
Backward Compatibility¶
Due to the research nature of what the library does, there might be backward incompatible changes. But we try to reduce users’ disruption in the following ways:
APIs listed in API documentation, including function/class names, their arguments, and documented class attributes, are considered stable unless otherwise noted in the documentation. They are less likely to be broken, but if needed, will trigger a deprecation warning for a reasonable period before getting broken, and will be documented in release logs.
Other functions/classes/attributes are considered internal, and are more likely to change. However, we’re aware that some of them may already be used by other projects, and in particular we may use them for convenience among projects under
detectron2/projects
. For such APIs, we may treat them as stable APIs and also apply the above strategies. They may be promoted to stable when we’re ready.Projects under “detectron2/projects” or imported with “detectron2.projects” are research projects and are all considered experimental.
Classes/functions that contain the word “default” or are explicitly documented to produce “default behavior” may change their behaviors when new features are added.
Despite the possible breakage, if a third-party project would like to keep up with the latest updates in detectron2, using it as a library will still be less disruptive than forking, because the frequency and scope of API changes will be much smaller than code changes.
To see such changes, search for “incompatible changes” in release logs.
Config Version Change Log¶
Detectron2’s config version has not changed since the open-source release. There is no need for an open source user to worry about this.
v1: Rename
RPN_HEAD.NAME
toRPN.HEAD_NAME
.v2: A batch of rename of many configurations before release.
Silent Regressions in Historical Versions¶
We list a few silent regressions, since they may silently produce incorrect results and will be hard to debug.
04/01/2020 - 05/11/2020: Bad accuracy if TRAIN_ON_PRED_BOXES is set to True.
03/30/2020 - 04/01/2020: ResNets are not correctly built.
12/19/2019 - 12/26/2019: Using aspect ratio grouping causes a drop in accuracy.
11/9/2019: Test time augmentation does not predict the last category.
API Documentation¶
detectron2.checkpoint¶
-
class
detectron2.checkpoint.
Checkpointer
(model: torch.nn.Module, save_dir: str = '', *, save_to_disk: bool = True, **checkpointables: Any)[source]¶ Bases:
object
A checkpointer that can save/load model as well as extra checkpointable objects.
-
__init__
(model: torch.nn.Module, save_dir: str = '', *, save_to_disk: bool = True, **checkpointables: Any) → None[source]¶ - Parameters
model (nn.Module) – model.
save_dir (str) – a directory to save and find checkpoints.
save_to_disk (bool) – if True, save checkpoint to disk, otherwise disable saving for this checkpointer.
checkpointables (object) – any checkpointable objects, i.e., objects that have the
state_dict()
andload_state_dict()
methods. For example, it can be used like Checkpointer(model, “dir”, optimizer=optimizer).
-
add_checkpointable
(key: str, checkpointable: Any) → None[source]¶ Add checkpointable object for this checkpointer to track.
- Parameters
key (str) – the key used to save the object
checkpointable – any object with
state_dict()
andload_state_dict()
method
-
load
(path: str, checkpointables: Optional[List[str]] = None) → Dict[str, Any][source]¶ Load from the given checkpoint.
- Parameters
- Returns
dict – extra data loaded from the checkpoint that has not been processed. For example, those saved with
save(**extra_data)()
.
-
has_checkpoint
() → bool[source]¶ - Returns
bool – whether a checkpoint exists in the target directory.
-
get_all_checkpoint_files
() → List[str][source]¶ - Returns
list –
- All available checkpoint files (.pth files) in target
directory.
-
-
class
detectron2.checkpoint.
PeriodicCheckpointer
(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]¶ Bases:
object
Save checkpoints periodically. When .step(iteration) is called, it will execute checkpointer.save on the given checkpointer, if iteration is a multiple of period or if max_iter is reached.
-
checkpointer
¶ the underlying checkpointer object
- Type
-
__init__
(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model') → None[source]¶ - Parameters
checkpointer – the checkpointer object used to save checkpoints.
period (int) – the period to save checkpoint.
max_iter (int) – maximum number of iterations. When it is reached, a checkpoint named “{file_prefix}_final” will be saved.
max_to_keep (int) – maximum number of the most recent checkpoints to keep; older checkpoints will be deleted
file_prefix (str) – the prefix of checkpoint’s filename
-
step
(iteration: int, **kwargs: Any) → None[source]¶ Perform the appropriate action at the given iteration.
- Parameters
iteration (int) – the current iteration, ranged in [0, max_iter-1].
kwargs (Any) – extra data to save, same as in
Checkpointer.save()
.
-
save
(name: str, **kwargs: Any) → None[source]¶ Same argument as
Checkpointer.save()
. Use this method to manually save checkpoints outside the schedule.- Parameters
name (str) – file name.
kwargs (Any) – extra data to save, same as in
Checkpointer.save()
.
-
-
class
detectron2.checkpoint.
DetectionCheckpointer
(model, save_dir='', *, save_to_disk=None, **checkpointables)[source]¶ Bases:
fvcore.common.checkpoint.Checkpointer
Same as
Checkpointer
, but is able to: 1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models. 2. correctly load checkpoints that are only available on the master worker
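For example, a sketched usage combining the checkpointer classes above (paths, period, and max_iter are illustrative):
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer

checkpointer = DetectionCheckpointer(model, save_dir="output")
checkpointer.load(cfg.MODEL.WEIGHTS)  # load model zoo or local weights, with legacy conversions
periodic_checkpointer = PeriodicCheckpointer(checkpointer, period=5000, max_iter=90000)
# inside a training loop:
#     periodic_checkpointer.step(iteration)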
detectron2.config¶
Related tutorials: Yacs Configs, Extend Detectron2’s Defaults.
-
class
detectron2.config.
CfgNode
(init_dict=None, key_list=None, new_allowed=False)¶ Bases:
fvcore.common.config.CfgNode
The same as fvcore.common.config.CfgNode, but different in:
Use unsafe yaml loading by default. Note that this may lead to arbitrary code execution: you must not load a config file from untrusted sources before manually inspecting the content of the file.
Support config versioning. When attempting to merge an old config, it will convert the old config automatically.
-
DEPRECATED_KEYS
= '__deprecated_keys__'¶
-
IMMUTABLE
= '__immutable__'¶
-
NEW_ALLOWED
= '__new_allowed__'¶
-
RENAMED_KEYS
= '__renamed_keys__'¶
-
clear
() → None. Remove all items from D.¶
-
copy
() → a shallow copy of D¶
-
dump
(*args, **kwargs)¶ - Returns
str – a yaml string representation of the config
-
fromkeys
(value=None, /)¶ Create a new dictionary with keys from iterable and values set to value.
-
get
(key, default=None, /)¶ Return the value for key if key is in the dictionary, else default.
-
items
() → a set-like object providing a view on D’s items¶
-
keys
() → a set-like object providing a view on D’s keys¶
-
classmethod
load_cfg
(cfg_file_obj_or_str)[source]¶ Load a cfg. :param cfg_file_obj_or_str: Supports loading from:
A file object backed by a YAML file
A file object backed by a Python source file that exports an attribute “cfg” that is either a dict or a CfgNode
A string that can be parsed as valid YAML
-
classmethod
load_yaml_with_base
(filename: str, allow_unsafe: bool = False) → Dict[str, Any][source]¶ - Just like yaml.load(open(filename)), but inherits attributes from its _BASE_.
-
merge_from_list
(cfg_list: List[str]) → Callable[[], None][source]¶ - Parameters
cfg_list (list) – list of configs to merge from.
-
merge_from_other_cfg
(cfg_other: fvcore.common.config.CfgNode) → Callable[[], None][source]¶ - Parameters
cfg_other (CfgNode) – configs to merge from.
-
pop
(k[, d]) → v, remove specified key and return the corresponding value.¶ If key is not found, d is returned if given, otherwise KeyError is raised
-
popitem
() → (k, v), remove and return some (key, value) pair as a¶ 2-tuple; but raise KeyError if D is empty.
-
register_deprecated_key
(key)[source]¶ Register key (e.g. FOO.BAR) a deprecated option. When merging deprecated keys a warning is generated and the key is ignored.
-
register_renamed_key
(old_name, new_name, message=None)[source]¶ Register a key as having been renamed from old_name to new_name. When merging a renamed key, an exception is thrown alerting the user to the fact that the key has been renamed.
-
set_new_allowed
(is_new_allowed)[source]¶ Set this config (and recursively its subconfigs) to allow merging new keys from other configs.
-
setdefault
(key, default=None, /)¶ Insert key with a value of default if key is not in the dictionary.
Return the value for key if key is in the dictionary, else default.
-
update
([E, ]**F) → None. Update D from dict/iterable E and F.¶ If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]
-
values
() → an object providing a view on D’s values¶
-
detectron2.config.
get_cfg
() → detectron2.config.CfgNode¶ Get a copy of the default config.
- Returns
a detectron2 CfgNode instance.
-
detectron2.config.
set_global_cfg
(cfg: detectron2.config.CfgNode) → None¶ Let the global config point to the given cfg.
Assume that the given “cfg” has the key “KEY”; after calling set_global_cfg(cfg), the key can be accessed by:
from detectron2.config import global_cfg
print(global_cfg.KEY)
By using a hacky global config, you can access these configs anywhere, without having to pass the config object or the values deep into the code. This is a hacky feature introduced for quick prototyping / research exploration.
-
detectron2.config.
downgrade_config
(cfg: detectron2.config.CfgNode, to_version: int) → detectron2.config.CfgNode¶ Downgrade a config from its current version to an older version.
Note
A general downgrade of arbitrary configs is not always possible due to the different functionalities in different versions. The purpose of downgrade is only to recover the defaults in old versions, allowing it to load an old partial yaml config. Therefore, the implementation only needs to fill in the default values in the old version when a general downgrade is not possible.
-
detectron2.config.
upgrade_config
(cfg: detectron2.config.CfgNode, to_version: Optional[int] = None) → detectron2.config.CfgNode¶ Upgrade a config from its current version to a newer version.
-
detectron2.config.
configurable
(init_func=None, *, from_config=None)¶ Decorate a function or a class’s __init__ method so that it can be called with a
CfgNode
object using afrom_config()
function that translatesCfgNode
to arguments.Examples:
# Usage 1: Decorator on __init__:
class A:
    @configurable
    def __init__(self, a, b=2, c=3):
        pass

    @classmethod
    def from_config(cls, cfg):   # 'cfg' must be the first argument
        # Returns kwargs to be passed to __init__
        return {"a": cfg.A, "b": cfg.B}

a1 = A(a=1, b=2)       # regular construction
a2 = A(cfg)            # construct with a cfg
a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

# Usage 2: Decorator on any function. Needs an extra from_config argument:
@configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
def a_func(a, b=2, c=3):
    pass

a1 = a_func(a=1, b=2)       # regular call
a2 = a_func(cfg)            # call with a cfg
a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite
- Parameters
init_func (callable) – a class’s
__init__
method in usage 1. The class must have afrom_config
classmethod which takes cfg as the first argument.from_config (callable) – the from_config function in usage 2. It must take cfg as its first argument.
-
detectron2.config.
instantiate
(cfg)¶ Recursively instantiate objects defined in dictionaries by “_target_” and arguments.
- Parameters
cfg – a dict-like object with “_target_” that defines the caller, and other keys that define the arguments
- Returns
object instantiated by cfg
-
class
detectron2.config.
LazyCall
(target)¶ Bases:
object
Wrap a callable so that when it’s called, the call will not be executed, but returns a dict that describes the call.
LazyCall object has to be called with only keyword arguments. Positional arguments are not yet supported.
Examples:
from torch import nn  # needed for nn.Conv2d
from detectron2.config import instantiate, LazyCall

layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
layer_cfg.out_channels = 64    # can edit it afterwards
layer = instantiate(layer_cfg)
-
class
detectron2.config.
LazyConfig
¶ Bases:
object
Provide methods to save, load, and override an omegaconf config object, which may contain definitions of lazily-constructed objects.
-
static
apply_overrides
(cfg, overrides: List[str])[source]¶ In-place override contents of cfg.
- Parameters
cfg – an omegaconf config object
overrides – list of strings in the format of “a=b” to override configs. See https://hydra.cc/docs/next/advanced/override_grammar/basic/ for syntax.
- Returns
the cfg object
-
static
load
(filename: str, keys: Union[None, str, Tuple[str, …]] = None)[source]¶ Load a config file.
- Parameters
filename – absolute path or relative path w.r.t. the current working directory
keys – keys to load and return. If not given, return all keys (whose values are config objects) in a dict.
-
static
load_rel
(filename: str, keys: Union[None, str, Tuple[str, …]] = None)[source]¶ Similar to
load()
, but load path relative to the caller’s source file.This has the same functionality as a relative import, except that this method accepts filename as a string, so more characters are allowed in the filename.
-
static
save
(cfg, filename: str)[source]¶ Save a config object to a yaml file. Note that when the config dictionary contains complex objects (e.g. lambda), it can’t be saved to yaml. In that case we will print an error and attempt to save to a pkl file instead.
- Parameters
cfg – an omegaconf config object
filename – yaml file name to save the config file
-
static
to_py
(cfg, prefix: str = 'cfg.')[source]¶ Try to convert a config object into its equivalent Python code.
Note that this is not always possible. So the returned results are mainly meant to be human-readable, and not meant to be loaded back.
- Parameters
cfg – an omegaconf config object
prefix – root name for the resulting code (default: “cfg.”)
- Returns
str of formatted Python code
Yaml Config References¶
# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
_C = CN()
# The version number, to upgrade from old configs to new ones if any
# changes happen. It's recommended to keep a VERSION in your config file.
_C.VERSION = 2
_C.MODEL = CN()
_C.MODEL.LOAD_PROPOSALS = False
_C.MODEL.MASK_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""
# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR).
# To train on images of different number of channels, just set different mean & std.
# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
# When using pre-trained models in Detectron1 or any MSRA models,
# std has been absorbed into its conv1 weights, so the std needs to be set 1.
# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]
# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the smallest side of the image during training
_C.INPUT.MIN_SIZE_TRAIN = (800,)
# Sample size of smallest side by choice or random selection from range given by
# INPUT.MIN_SIZE_TRAIN
_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
# Maximum size of the side of the image during training
_C.INPUT.MAX_SIZE_TRAIN = 1333
# Size of the smallest side of the image during testing. Set to zero to disable resize in testing.
_C.INPUT.MIN_SIZE_TEST = 800
# Maximum size of the side of the image during testing
_C.INPUT.MAX_SIZE_TEST = 1333
# Mode for flipping images used in data augmentation during training
# choose one of ["horizontal", "vertical", "none"]
_C.INPUT.RANDOM_FLIP = "horizontal"
# `True` if cropping is used for data augmentation during training
_C.INPUT.CROP = CN({"ENABLED": False})
# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation.
_C.INPUT.CROP.TYPE = "relative_range"
# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of
# pixels if CROP.TYPE is "absolute"
_C.INPUT.CROP.SIZE = [0.9, 0.9]
# Whether the model needs RGB, YUV, HSV etc.
# Should be one of the modes defined here, as we use PIL to read the image:
# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
# with BGR being the one exception. One can set image format to BGR, we will
# internally use RGB for conversion and flip the channels over
_C.INPUT.FORMAT = "BGR"
# The ground truth mask format that the model will use.
# Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask"
# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training. Must be registered in DatasetCatalog
# Samples from these datasets will be merged and used as one dataset.
_C.DATASETS.TRAIN = ()
# List of the pre-computed proposal files for training, which must be consistent
# with datasets listed in DATASETS.TRAIN.
_C.DATASETS.PROPOSAL_FILES_TRAIN = ()
# Number of top scoring precomputed proposals to keep for training
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000
# List of the dataset names for testing. Must be registered in DatasetCatalog
_C.DATASETS.TEST = ()
# List of the pre-computed proposal files for test, which must be consistent
# with datasets listed in DATASETS.TEST.
_C.DATASETS.PROPOSAL_FILES_TEST = ()
# Number of top scoring precomputed proposals to keep for test
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000
# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 4
# If True, each batch should contain only images for which the aspect ratio
# is compatible. This groups portrait images together, and landscape images
# are not batched with portrait images.
_C.DATALOADER.ASPECT_RATIO_GROUPING = True
# Options: TrainingSampler, RepeatFactorTrainingSampler
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
# Repeat threshold for RepeatFactorTrainingSampler
_C.DATALOADER.REPEAT_THRESHOLD = 0.0
# If True, when working on datasets that have instance annotations, the
# training dataloader will filter out images without associated annotations
_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True
# ---------------------------------------------------------------------------- #
# Backbone options
# ---------------------------------------------------------------------------- #
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
# Freeze the first several stages so they are not trained.
# There are 5 stages in ResNet. The first is a convolution, and the following
# stages are each group of residual blocks.
_C.MODEL.BACKBONE.FREEZE_AT = 2
# ---------------------------------------------------------------------------- #
# FPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.FPN = CN()
# Names of the input feature maps to be used by FPN
# They must have contiguous power of 2 strides
# e.g., ["res2", "res3", "res4", "res5"]
_C.MODEL.FPN.IN_FEATURES = []
_C.MODEL.FPN.OUT_CHANNELS = 256
# Options: "" (no norm), "GN"
_C.MODEL.FPN.NORM = ""
# Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
_C.MODEL.FPN.FUSE_TYPE = "sum"
# ---------------------------------------------------------------------------- #
# Proposal generator options
# ---------------------------------------------------------------------------- #
_C.MODEL.PROPOSAL_GENERATOR = CN()
# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals"
_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN"
# Proposal height and width both need to be greater than MIN_SIZE
# (at the scale used during training or inference)
_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0
# ---------------------------------------------------------------------------- #
# Anchor generator options
# ---------------------------------------------------------------------------- #
_C.MODEL.ANCHOR_GENERATOR = CN()
# The generator can be any name in the ANCHOR_GENERATOR registry
_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator"
# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for
# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1.
# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES.
_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]
# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
# ratios are generated by an anchor generator.
# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W)
# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
# for all IN_FEATURES.
_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]
# Anchor angles.
# list[list[float]], the angle in degrees, for each input feature map.
# ANGLES[i] specifies the list of angles for IN_FEATURES[i].
_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]]
# Relative offset between the center of the first anchor and the top-left corner of the image
# Value has to be in [0, 1). Recommend to use 0.5, which means half stride.
# The value is not expected to affect model accuracy.
_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0
# ---------------------------------------------------------------------------- #
# RPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.RPN = CN()
_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY
# Names of the input feature maps to be used by RPN
# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN
_C.MODEL.RPN.IN_FEATURES = ["res4"]
# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels
# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
_C.MODEL.RPN.BOUNDARY_THRESH = -1
# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD]
# Minimum overlap required between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
# ==> positive RPN example: 1)
# Maximum overlap allowed between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a negative example (IoU < BG_IOU_THRESHOLD
# ==> negative RPN example: 0)
# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD)
# are ignored (-1)
_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7]
_C.MODEL.RPN.IOU_LABELS = [0, -1, 1]
# Number of regions per image used to train RPN
_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
# Target fraction of foreground (positive) examples per RPN minibatch
_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
# Options are: "smooth_l1", "giou"
_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1"
_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0
# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0
_C.MODEL.RPN.LOSS_WEIGHT = 1.0
# Number of top scoring RPN proposals to keep before applying NMS
# When FPN is used, this is *per FPN level* (not total)
_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000
_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000
# Number of top scoring RPN proposals to keep after applying NMS
# When FPN is used, this limit is applied per level and then again to the union
# of proposals from all levels
# NOTE: When FPN is used, the meaning of this config is different from Detectron1.
# It means per-batch topk in Detectron1, but per-image topk here.
# See the "find_top_rpn_proposals" function for details.
_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
# NMS threshold used on RPN proposals
_C.MODEL.RPN.NMS_THRESH = 0.7
# Set this to -1 to use the same number of output channels as input channels.
_C.MODEL.RPN.CONV_DIMS = [-1]
# ---------------------------------------------------------------------------- #
# ROI HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_HEADS = CN()
_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads"
# Number of foreground classes
_C.MODEL.ROI_HEADS.NUM_CLASSES = 80
# Names of the input feature maps to be used by ROI heads
# Currently all heads (box, mask, ...) use the same input feature map list
# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN
_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"]
# IOU overlap ratios [IOU_THRESHOLD]
# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
# RoI minibatch size *per image* (number of regions of interest [ROIs])
# Total number of RoIs per training minibatch =
# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
# E.g., a common configuration is: 512 * 16 = 8192
_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
# Only used on test mode
# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
# balance obtaining high recall with not having too many low precision
# detections that will slow down inference post processing steps (like NMS)
# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down
# inference.
_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
# Overlap threshold used for non-maximum suppression (suppress boxes with
# IoU >= this threshold)
_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
# If True, augment proposals with ground-truth boxes before sampling proposals to
# train ROI heads.
_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True
# ---------------------------------------------------------------------------- #
# Box Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_BOX_HEAD = CN()
# C4 models don't use the head name option.
# Options for non-C4 models: FastRCNNConvFCHead,
_C.MODEL.ROI_BOX_HEAD.NAME = ""
# Options are: "smooth_l1", "giou"
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1"
# The final scaling coefficient on the box regression loss, used to balance the magnitude of its
# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`.
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0
# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
# These are empirically chosen to approximately lead to unit variance targets
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0
_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2"
_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0
# Hidden layer dimension for FC layers in the RoI box head
_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024
_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
# Channel dimension for Conv layers in the RoI box head
_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256
# Normalization method for the convolution layers.
# Options: "" (no norm), "GN", "SyncBN".
_C.MODEL.ROI_BOX_HEAD.NORM = ""
# Whether to use class agnostic for bbox regression
_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False
# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes.
_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False
# ---------------------------------------------------------------------------- #
# Cascaded Box Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_BOX_CASCADE_HEAD = CN()
# The number of cascade stages is implicitly defined by the length of the following two configs.
_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = (
(10.0, 10.0, 5.0, 5.0),
(20.0, 20.0, 10.0, 10.0),
(30.0, 30.0, 15.0, 15.0),
)
_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7)
# ---------------------------------------------------------------------------- #
# Mask Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_MASK_HEAD = CN()
_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head
_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256
# Normalization method for the convolution layers.
# Options: "" (no norm), "GN", "SyncBN".
_C.MODEL.ROI_MASK_HEAD.NORM = ""
# Whether to use class agnostic for mask prediction
_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2"
# ---------------------------------------------------------------------------- #
# Keypoint Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_KEYPOINT_HEAD = CN()
_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead"
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8))
_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO.
# Images with too few (or no) keypoints are excluded from training.
_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1
# Normalize by the total number of visible keypoints in the minibatch if True.
# Otherwise, normalize by the total number of keypoints that could ever exist
# in the minibatch.
# The keypoint softmax loss is only calculated on visible keypoints.
# Since the number of visible keypoints can vary significantly between
# minibatches, this has the effect of up-weighting the importance of
# minibatches with few visible keypoints. (Imagine the extreme case of
# only one visible keypoint versus N: in the case of N, each one
# contributes 1/N to the gradient compared to the single keypoint
# determining the gradient direction). Instead, we can normalize the
# loss by the total number of keypoints, if it were the case that all
# keypoints were visible in a full minibatch. (Returning to the example,
# this means that the one visible keypoint contributes as much as each
# of the N keypoints.)
_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True
# Multi-task loss weight to use for keypoints
# Recommended values:
# - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True
# - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False
_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2"
# ---------------------------------------------------------------------------- #
# Semantic Segmentation Head
# ---------------------------------------------------------------------------- #
_C.MODEL.SEM_SEG_HEAD = CN()
_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead"
_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"]
# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for
# the corresponding pixel.
_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255
# Number of classes in the semantic segmentation head
_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54
# Number of channels in the 3x3 convs inside semantic-FPN heads.
_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128
# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride.
_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
# Normalization method for the convolution layers. Options: "" (no norm), "GN".
_C.MODEL.SEM_SEG_HEAD.NORM = "GN"
_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0
_C.MODEL.PANOPTIC_FPN = CN()
# Scaling of all losses from instance detection / segmentation head.
_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0
# options when combining instance & semantic segmentation outputs
_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True}) # "COMBINE.ENABLED" is deprecated & not used
_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5
_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096
_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5
# ---------------------------------------------------------------------------- #
# RetinaNet Head
# ---------------------------------------------------------------------------- #
_C.MODEL.RETINANET = CN()
# This is the number of foreground classes.
_C.MODEL.RETINANET.NUM_CLASSES = 80
_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
# Convolutions to use in the cls and bbox tower
# NOTE: this doesn't include the last conv for logits
_C.MODEL.RETINANET.NUM_CONVS = 4
# IoU overlap ratio [bg, fg] for labeling anchors.
# Anchors with < bg are labeled negative (0)
# Anchors with >= bg and < fg are ignored (-1)
# Anchors with >= fg are labeled positive (1)
_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5]
_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1]
# Prior prob for rare case (i.e. foreground) at the beginning of training.
# This is used to set the bias for the logits layer of the classifier subnet.
# This improves training stability in the case of heavy class imbalance.
_C.MODEL.RETINANET.PRIOR_PROB = 0.01
# Inference cls score threshold, only anchors with score > INFERENCE_TH are
# considered for inference (to improve speed)
_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
# Select topk candidates before NMS
_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets
_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
# Loss parameters
_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
# Options are: "smooth_l1", "giou"
_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1"
# One of BN, SyncBN, FrozenBN, GN
# Only supports GN until unshared norm is implemented
_C.MODEL.RETINANET.NORM = ""
# ---------------------------------------------------------------------------- #
# ResNe[X]t options (ResNets = {ResNet, ResNeXt})
# Note that parts of a resnet may be used for both the backbone and the head
# These options apply to both
# ---------------------------------------------------------------------------- #
_C.MODEL.RESNETS = CN()
_C.MODEL.RESNETS.DEPTH = 50
_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone
# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
_C.MODEL.RESNETS.NUM_GROUPS = 1
# Options: FrozenBN, GN, "SyncBN", "BN"
_C.MODEL.RESNETS.NORM = "FrozenBN"
# Baseline width of each group.
# Scaling this parameter will scale the width of all bottleneck layers.
_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
# Place the stride 2 conv on the 1x1 filter
# Use True only for the original MSRA ResNet; use False for C2 and Torch models
_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
# Apply dilation in stage "res5"
_C.MODEL.RESNETS.RES5_DILATION = 1
# Output width of res2. Scaling this parameter will scale the width of all 1x1 convs in ResNet.
# For R18 and R34, this needs to be set to 64
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
# Apply Deformable Convolution in stages
# Specify if apply deform_conv on Res2, Res3, Res4, Res5
_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False]
# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168);
# Use False for DeformableV1.
_C.MODEL.RESNETS.DEFORM_MODULATED = False
# Number of groups in deformable conv.
_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1
# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
# See detectron2/solver/build.py for LR scheduler options
_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
_C.SOLVER.MAX_ITER = 40000
_C.SOLVER.BASE_LR = 0.001
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.NESTEROV = False
_C.SOLVER.WEIGHT_DECAY = 0.0001
# The weight decay that's applied to parameters of normalization layers
# (typically the affine transformation)
_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
_C.SOLVER.GAMMA = 0.1
# The iteration number to decrease learning rate by GAMMA.
_C.SOLVER.STEPS = (30000,)
_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
_C.SOLVER.WARMUP_ITERS = 1000
_C.SOLVER.WARMUP_METHOD = "linear"
# Save a checkpoint after every this number of iterations
_C.SOLVER.CHECKPOINT_PERIOD = 5000
# Number of images per batch across all machines. This is also the number
# of training images per step (i.e. per iteration). If we use 16 GPUs
# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch.
# May be adjusted automatically if REFERENCE_WORLD_SIZE is set.
_C.SOLVER.IMS_PER_BATCH = 16
# The reference number of workers (GPUs) this config is meant to train with.
# It takes no effect when set to 0.
# With a non-zero value, it will be used by DefaultTrainer to compute a desired
# per-worker batch size, and then scale the other related configs (total batch size,
# learning rate, etc) to match the per-worker batch size.
# See documentation of `DefaultTrainer.auto_scale_workers` for details:
_C.SOLVER.REFERENCE_WORLD_SIZE = 0
# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for
# biases. This is not useful (at least for recent models). You should avoid
# changing these and they exist only to reproduce Detectron v1 training if
# desired.
_C.SOLVER.BIAS_LR_FACTOR = 1.0
_C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY
# Gradient clipping
_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
# Type of gradient clipping, currently 2 values are supported:
# - "value": the absolute values of elements of each gradients are clipped
# - "norm": the norm of the gradient for each parameter is clipped thus
# affecting all elements in the parameter
_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
# Maximum absolute value used for clipping gradients
_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
# Floating point number p for L-p norm to be used with the "norm"
# gradient clipping type; for L-inf, please specify .inf
_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
# Enable automatic mixed precision for training
# Note that this does not change model's inference behavior.
# To use AMP in inference, run inference under autocast()
_C.SOLVER.AMP = CN({"ENABLED": False})
# ---------------------------------------------------------------------------- #
# Specific test options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# For end-to-end tests to verify the expected accuracy.
# Each item is [task, metric, value, tolerance]
# e.g.: [['bbox', 'AP', 38.5, 0.2]]
_C.TEST.EXPECTED_RESULTS = []
# The period (in terms of steps) to evaluate the model during training.
# Set to 0 to disable.
_C.TEST.EVAL_PERIOD = 0
# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval
# When empty, it will use the defaults in COCO.
# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
_C.TEST.KEYPOINT_OKS_SIGMAS = []
# Maximum number of detections to return per image during inference (100 is
# based on the limit established for the COCO dataset).
_C.TEST.DETECTIONS_PER_IMAGE = 100
_C.TEST.AUG = CN({"ENABLED": False})
_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
_C.TEST.AUG.MAX_SIZE = 4000
_C.TEST.AUG.FLIP = True
_C.TEST.PRECISE_BN = CN({"ENABLED": False})
_C.TEST.PRECISE_BN.NUM_ITER = 200
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Directory where output files are written
_C.OUTPUT_DIR = "./output"
# Set seed to negative to fully randomize everything.
# Set seed to positive to use a fixed seed. Note that a fixed seed increases
# reproducibility but does not guarantee fully deterministic behavior.
# Disabling all parallelism further increases reproducibility.
_C.SEED = -1
# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option will have large overhead
# for about 10k iterations. It usually hurts total time, but can benefit certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False
# The period (in terms of steps) for minibatch visualization at train time.
# Set to 0 to disable.
_C.VIS_PERIOD = 0
# global config is for quick hack purposes.
# You can set them in command line or config files,
# and access it with:
#
# from detectron2.config import global_cfg
# print(global_cfg.HACK)
#
# Do not commit any configs into it.
_C.GLOBAL = CN()
_C.GLOBAL.HACK = 1.0
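The defaults above are normally obtained through detectron2.config.get_cfg() and then overridden from a YAML file, a flat list of key-value pairs, or directly in Python. A minimal sketch (the YAML path is hypothetical):
from detectron2.config import get_cfg

cfg = get_cfg()                                     # a fresh copy of the defaults listed above
cfg.merge_from_file("configs/my_experiment.yaml")   # hypothetical YAML file overriding some keys
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3                 # override values directly in Python
cfg.merge_from_list(["SOLVER.BASE_LR", "0.0025"])   # or from a flat list of key-value pairs
cfg.freeze()                                        # make the config immutable before training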
detectron2.data¶
-
detectron2.data.
DatasetCatalog
(dict)¶ A global dictionary that stores information about the datasets and how to obtain them.
It contains a mapping from strings (which are names that identify a dataset, e.g. “coco_2014_train”) to a function which parses the dataset and returns the samples in the format of list[dict].
The returned dicts should be in Detectron2 Dataset format (see DATASETS.md for details) if used with the data loader functionalities in data/build.py, data/detection_transform.py.
The purpose of having this catalog is to make it easy to choose different datasets, by just using the strings in the config.
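A minimal registration sketch, assuming a hypothetical dataset name and parser function:
from detectron2.data import DatasetCatalog

def get_my_dicts():
    # hypothetical parser that returns samples in the Detectron2 Dataset format (list[dict])
    return [{"file_name": "img1.jpg", "image_id": 0, "height": 480, "width": 640, "annotations": []}]

DatasetCatalog.register("my_dataset", get_my_dicts)   # map the name to the parser
dicts = DatasetCatalog.get("my_dataset")              # calls get_my_dicts() and returns its result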
-
detectron2.data.
MetadataCatalog
(dict)¶ MetadataCatalog is a global dictionary that provides access to
Metadata
of a given dataset. The metadata associated with a certain name is a singleton: once created, the metadata will stay alive and will be returned by future calls to
get(name)
.It’s like global variables, so don’t abuse it. It’s meant for storing knowledge that’s constant and shared across the execution of the program, e.g.: the class names in COCO.
-
detectron2.data.
build_detection_test_loader
(dataset, *, mapper, sampler=None, num_workers=0)[source]¶ Similar to build_detection_train_loader, but uses a batch size of 1, and
InferenceSampler
. This sampler coordinates all workers to produce the exact set of all samples. This interface is experimental.- Parameters
dataset (list or torch.utils.data.Dataset) – a list of dataset dicts, or a map-style pytorch dataset. They can be obtained by using
DatasetCatalog.get()
orget_detection_dataset_dicts()
.mapper (callable) – a callable which takes a sample (dict) from dataset and returns the format to be consumed by the model. When using cfg, the default choice is
DatasetMapper(cfg, is_train=False)
.sampler (torch.utils.data.sampler.Sampler or None) – a sampler that produces indices to be applied on
dataset
. Default toInferenceSampler
, which splits the dataset across all workers.num_workers (int) – number of parallel data loading workers
- Returns
DataLoader – a torch DataLoader, that loads the given detection dataset, with test-time transformation and batching.
Examples:
data_loader = build_detection_test_loader(
    DatasetCatalog.get("my_test"),
    mapper=DatasetMapper(...))

# or, instantiate with a CfgNode:
data_loader = build_detection_test_loader(cfg, "my_test")
-
detectron2.data.
build_detection_train_loader
(dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0)[source]¶ Build a dataloader for object detection with some default features. This interface is experimental.
- Parameters
dataset (list or torch.utils.data.Dataset) – a list of dataset dicts, or a map-style pytorch dataset. They can be obtained by using
DatasetCatalog.get()
orget_detection_dataset_dicts()
.mapper (callable) – a callable which takes a sample (dict) from dataset and returns the format to be consumed by the model. When using cfg, the default choice is
DatasetMapper(cfg, is_train=True)
.sampler (torch.utils.data.sampler.Sampler or None) – a sampler that produces indices to be applied on
dataset
. Default toTrainingSampler
, which coordinates an infinite random shuffle sequence across all workers.total_batch_size (int) – total batch size across all workers. Batching simply puts data into a list.
aspect_ratio_grouping (bool) – whether to group images with similar aspect ratio for efficiency. When enabled, it requires each element in dataset be a dict with keys “width” and “height”.
num_workers (int) – number of parallel data loading workers
- Returns
torch.utils.data.DataLoader – a dataloader. Each output from it is a list[mapped_element] of length total_batch_size / num_workers, where mapped_element is produced by the mapper.
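A usage sketch with explicit arguments, assuming cfg is a CfgNode from get_cfg() and the builtin COCO dataset files are available locally:
from detectron2.config import get_cfg
from detectron2.data import DatasetMapper, build_detection_train_loader, get_detection_dataset_dicts

cfg = get_cfg()
dataset = get_detection_dataset_dicts(["coco_2017_train"])
train_loader = build_detection_train_loader(
    dataset,
    mapper=DatasetMapper(cfg, is_train=True),
    total_batch_size=16,
)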
-
detectron2.data.
get_detection_dataset_dicts
(names, filter_empty=True, min_keypoints=0, proposal_files=None)[source]¶ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
- Parameters
names (str or list[str]) – a dataset name or a list of dataset names
filter_empty (bool) – whether to filter out images without instance annotations
min_keypoints (int) – filter out images with fewer keypoints than min_keypoints. Set to 0 to do nothing.
proposal_files (list[str]) – if given, a list of object proposal files that match each dataset in names.
- Returns
list[dict] – a list of dicts following the standard dataset dict format.
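For example (assuming the builtin COCO dataset files are present):
from detectron2.data import get_detection_dataset_dicts

# load one or several registered datasets into a single list of dataset dicts
dataset_dicts = get_detection_dataset_dicts(["coco_2017_train"], filter_empty=True)
print(len(dataset_dicts), dataset_dicts[0]["file_name"])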
-
detectron2.data.
load_proposals_into_dataset
(dataset_dicts, proposal_file)[source]¶ Load precomputed object proposals into the dataset.
The proposal file should be a pickled dict with the following keys:
“ids”: list[int] or list[str], the image ids
“boxes”: list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
“objectness_logits”: list[np.ndarray], each is an N sized array of objectness scores corresponding to the boxes.
“bbox_mode”: the BoxMode of the boxes array. Defaults to
BoxMode.XYXY_ABS
.
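A sketch of the expected proposal file structure, following the keys described above (the file path, values, and dataset name are hypothetical):
import pickle
import numpy as np
from detectron2.data import DatasetCatalog, load_proposals_into_dataset
from detectron2.structures import BoxMode

proposals = {
    "ids": [0],                                                        # image ids
    "boxes": [np.array([[0.0, 0.0, 50.0, 50.0]], dtype=np.float32)],   # one Nx4 array per image
    "objectness_logits": [np.array([0.9], dtype=np.float32)],          # one N-sized array per image
    "bbox_mode": BoxMode.XYXY_ABS,
}
with open("proposals.pkl", "wb") as f:
    pickle.dump(proposals, f)

dataset_dicts = DatasetCatalog.get("my_dataset")   # a previously registered (hypothetical) dataset
dataset_dicts = load_proposals_into_dataset(dataset_dicts, "proposals.pkl")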
-
class
detectron2.data.
Metadata
[source]¶ Bases:
types.SimpleNamespace
A class that supports simple attribute setter/getter. It is intended for storing metadata of a dataset and make it accessible globally.
Examples:
# somewhere when you load the data:
MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]

# somewhere when you print statistics or visualize:
classes = MetadataCatalog.get("mydataset").thing_classes
-
class
detectron2.data.
DatasetFromList
(*args, **kwds)[source]¶ Bases:
torch.utils.data.Dataset
Wrap a list to a torch Dataset. It produces elements of the list as data.
-
__init__
(lst: list, copy: bool = True, serialize: bool = True)[source]¶ - Parameters
lst (list) – a list which contains elements to produce.
copy (bool) – whether to deepcopy the element when producing it, so that the result can be modified in place without affecting the source in the list.
serialize (bool) – whether to hold memory using serialized objects, when enabled, data loader workers can use shared RAM from master process instead of making a copy.
-
-
class
detectron2.data.
MapDataset
(*args, **kwds)[source]¶ Bases:
torch.utils.data.Dataset
Map a function over the elements in a dataset.
- Parameters
dataset – a dataset where map function is applied.
map_func – a callable which maps the element in dataset. map_func is responsible for error handling; when an error happens, it needs to return None so that the MapDataset will randomly use other elements from the dataset.
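A small sketch combining DatasetFromList and MapDataset (the data and map function are hypothetical):
from detectron2.data import DatasetFromList, MapDataset

raw = [{"value": i} for i in range(10)]
dataset = DatasetFromList(raw, copy=False)                    # wrap the list as a torch Dataset
doubled = MapDataset(dataset, lambda d: {"value": d["value"] * 2})
print(doubled[3])                                             # {'value': 6}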
-
class
detectron2.data.
DatasetMapper
(*args, **kwargs)[source]¶ Bases:
object
A callable which takes a dataset dict in Detectron2 Dataset format, and maps it into a format used by the model.
This is the default callable to be used to map your dataset dict into training data. You may need to follow it to implement your own one for customized logic, such as a different way to read or transform images. See Dataloader for details.
The callable currently does the following:
Read the image from “file_name”
Applies cropping/geometric transforms to the image and annotations
Prepare data and annotations to Tensor and
Instances
-
__init__
(is_train: bool, *, augmentations: List[Union[detectron2.data.transforms.Augmentation, detectron2.data.transforms.Transform]], image_format: str, use_instance_mask: bool = False, use_keypoint: bool = False, instance_mask_format: str = 'polygon', keypoint_hflip_indices: Optional[numpy.ndarray] = None, precomputed_proposal_topk: Optional[int] = None, recompute_boxes: bool = False)[source]¶ NOTE: this interface is experimental.
- Parameters
is_train – whether it’s used in training or inference
augmentations – a list of augmentations or deterministic transforms to apply
image_format – an image format supported by
detection_utils.read_image()
.use_instance_mask – whether to process instance segmentation annotations, if available
use_keypoint – whether to process keypoint annotations if available
instance_mask_format – one of “polygon” or “bitmask”. Process instance segmentation masks into this format.
keypoint_hflip_indices – see
detection_utils.create_keypoint_hflip_indices()
precomputed_proposal_topk – if given, will load pre-computed proposals from dataset_dict and keep the top k proposals for each image.
recompute_boxes – whether to overwrite bounding box annotations by computing tight bounding boxes from instance mask annotations.
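A usage sketch, assuming cfg comes from get_cfg() and the referenced dataset files are available locally:
from detectron2.config import get_cfg
from detectron2.data import DatasetMapper, DatasetCatalog

cfg = get_cfg()
mapper = DatasetMapper(cfg, is_train=True)
dataset_dicts = DatasetCatalog.get("coco_2017_train")   # any registered dataset name
example = mapper(dataset_dicts[0])                      # dict with an "image" tensor and "instances"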
detectron2.data.detection_utils module¶
Common data processing utilities that are used in a typical object detection data pipeline.
-
exception
detectron2.data.detection_utils.
SizeMismatchError
[source]¶ Bases:
ValueError
Raised when the loaded image has a different width/height than what is specified in the annotation.
-
detectron2.data.detection_utils.
convert_image_to_rgb
(image, format)[source]¶ Convert an image from given format to RGB.
- Parameters
image (np.ndarray or Tensor) – an HWC image
format (str) – the format of input image, also see read_image
- Returns
(np.ndarray) – (H,W,3) RGB image in 0-255 range, can be either float or uint8
-
detectron2.data.detection_utils.
check_image_size
(dataset_dict, image)[source]¶ Raise an error if the image does not match the size specified in the dict.
-
detectron2.data.detection_utils.
transform_proposals
(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0)[source]¶ Apply transformations to the proposals in dataset_dict, if any.
- Parameters
dataset_dict (dict) – a dict read from the dataset, possibly contains fields “proposal_boxes”, “proposal_objectness_logits”, “proposal_bbox_mode”
image_shape (tuple) – height, width
transforms (TransformList) –
proposal_topk (int) – only keep top-K scoring proposals
min_box_size (int) – proposals with either side smaller than this threshold are removed
The input dict is modified in-place, with the above-mentioned keys removed. A new key “proposals” will be added. Its value is an Instances object which contains the transformed proposals in its fields “proposal_boxes” and “objectness_logits”.
-
detectron2.data.detection_utils.
transform_instance_annotations
(annotation, transforms, image_size, *, keypoint_hflip_indices=None)[source]¶ Apply transforms to box, segmentation and keypoints annotations of a single instance.
It will use transforms.apply_box for the box, and transforms.apply_coords for segmentation polygons & keypoints. If you need anything more specially designed for each data structure, you’ll need to implement your own version of this function or the transforms.
- Parameters
- Returns
dict – the same input dict with fields “bbox”, “segmentation”, “keypoints” transformed according to transforms. The “bbox_mode” field will be set to XYXY_ABS.
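A sketch of how this is typically used inside a custom mapper; the augmentation choice is illustrative, and a full mapper would also convert the image to a CHW tensor under the key "image":
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

def my_mapper(dataset_dict):
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    aug_input = T.AugInput(image)
    transforms = T.AugmentationList([T.RandomFlip()])(aug_input)   # illustrative augmentation
    image = aug_input.image
    annos = [
        utils.transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop("annotations", [])
    ]
    dataset_dict["instances"] = utils.annotations_to_instances(annos, image.shape[:2])
    return dataset_dict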
-
detectron2.data.detection_utils.
annotations_to_instances
(annos, image_size, mask_format='polygon')[source]¶ Create an
Instances
object used by the models, from instance annotations in the dataset dict.- Parameters
- Returns
Instances – It will contain fields “gt_boxes”, “gt_classes”, “gt_masks”, “gt_keypoints”, if they can be obtained from annos. This is the format that builtin models expect.
-
detectron2.data.detection_utils.
annotations_to_instances_rotated
(annos, image_size)[source]¶ Create an
Instances
object used by the models, from instance annotations in the dataset dict. Compared to annotations_to_instances, this function is for rotated boxes only
-
detectron2.data.detection_utils.
build_augmentation
(cfg, is_train)[source]¶ Create a list of default
Augmentation
from config. Now it includes resizing and flipping.- Returns
list[Augmentation]
-
detectron2.data.detection_utils.
create_keypoint_hflip_indices
(dataset_names: Union[str, List[str]]) → List[int][source]¶ - Parameters
dataset_names – list of dataset names
- Returns
list[int] – a list of size=#keypoints, storing the horizontally-flipped keypoint indices.
-
detectron2.data.detection_utils.
filter_empty_instances
(instances, by_box=True, by_mask=True, box_threshold=1e-05, return_mask=False)[source]¶ Filter out empty instances in an Instances object.
- Parameters
instances (Instances) –
by_box (bool) – whether to filter out instances with empty boxes
by_mask (bool) – whether to filter out instances with empty masks
box_threshold (float) – minimum width and height to be considered non-empty
return_mask (bool) – whether to return boolean mask of filtered instances
- Returns
Instances – the filtered instances. tensor[bool], optional: boolean mask of filtered instances
detectron2.data.datasets module¶
-
detectron2.data.datasets.
load_coco_json
(json_file, image_root, dataset_name=None, extra_annotation_keys=None)[source]¶ Load a json file with COCO’s instances annotation format. Currently supports instance detection, instance segmentation, and person keypoints annotations.
- Parameters
json_file (str) – full path to the json file in COCO instances annotation format.
image_root (str or path-like) – the directory where the images in this json file exists.
dataset_name (str or None) – the name of the dataset (e.g., coco_2017_train). When provided, this function will also do the following:
Put “thing_classes” into the metadata associated with this dataset.
Map the category ids into a contiguous range (needed by standard dataset format), and add “thing_dataset_id_to_contiguous_id” to the metadata associated with this dataset.
This option should usually be provided, unless users need to load the original json content and apply more processing manually.
extra_annotation_keys (list[str]) – list of per-annotation keys that should also be loaded into the dataset dict (besides “iscrowd”, “bbox”, “keypoints”, “category_id”, “segmentation”). The values for these keys will be returned as-is. For example, the densepose annotations are loaded in this way.
- Returns
list[dict] – a list of dicts in Detectron2 standard dataset dicts format (see Using Custom Datasets) when dataset_name is not None. If dataset_name is None, the returned category_ids may be non-contiguous and may not conform to the Detectron2 standard format.
Notes
This function does not read the image files. The results do not have the “image” field.
-
detectron2.data.datasets.
load_sem_seg
(gt_root, image_root, gt_ext='png', image_ext='jpg')[source]¶ Load semantic segmentation datasets. All files under “gt_root” with “gt_ext” extension are treated as ground truth annotations and all files under “image_root” with “image_ext” extension as input images. Ground truth and input images are matched using file paths relative to “gt_root” and “image_root” respectively without taking into account file extensions. This works for COCO as well as some other datasets.
- Parameters
gt_root (str) – full path to ground truth semantic segmentation files. Semantic segmentation annotations are stored as images with integer values in pixels that represent corresponding semantic labels.
image_root (str) – the directory where the input images are.
gt_ext (str) – file extension for ground truth annotations.
image_ext (str) – file extension for input images.
- Returns
list[dict] – a list of dicts in detectron2 standard format without instance-level annotation.
Notes
This function does not read the image and ground truth files. The results do not have the “image” and “sem_seg” fields.
-
detectron2.data.datasets.
register_coco_instances
(name, metadata, json_file, image_root)[source]¶ Register a dataset in COCO’s json annotation format for instance detection, instance segmentation and keypoint detection. (i.e., Type 1 and 2 in http://cocodataset.org/#format-data. instances*.json and person_keypoints*.json in the dataset).
This is an example of how to register a new dataset. You can do something similar to this function, to register new datasets.
- Parameters
name (str) – the name that identifies a dataset, e.g. “coco_2014_train”.
metadata (dict) – extra metadata associated with this dataset. You can leave it as an empty dict.
json_file (str) – path to the json instance annotation file.
image_root (str or path-like) – directory which contains all the images.
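For example (paths are hypothetical):
from detectron2.data.datasets import register_coco_instances

register_coco_instances(
    "my_coco_train",                              # dataset name
    {},                                           # extra metadata (can be empty)
    "datasets/my_coco/annotations/train.json",    # COCO-format json
    "datasets/my_coco/images/train",              # image directory
)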
-
detectron2.data.datasets.
register_coco_panoptic
(name, metadata, image_root, panoptic_root, panoptic_json, instances_json=None)[source]¶ Register a “standard” version of COCO panoptic segmentation dataset named name. The dictionaries in this registered dataset follow detectron2’s standard format. Hence it’s called “standard”.
- Parameters
name (str) – the name that identifies a dataset, e.g. “coco_2017_train_panoptic”
metadata (dict) – extra metadata associated with this dataset.
image_root (str) – directory which contains all the images
panoptic_root (str) – directory which contains panoptic annotation images in COCO format
panoptic_json (str) – path to the json panoptic annotation file in COCO format
sem_seg_root (none) – not used, to be consistent with register_coco_panoptic_separated.
instances_json (str) – path to the json instance annotation file
-
detectron2.data.datasets.
register_coco_panoptic_separated
(name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json)[source]¶ Register a “separated” version of COCO panoptic segmentation dataset named name. The annotations in this registered dataset will contain both instance annotations and semantic annotations, each with its own contiguous ids. Hence it’s called “separated”.
It follows the setting used by the PanopticFPN paper:
The instance annotations directly come from polygons in the COCO instances annotation task, rather than from the masks in the COCO panoptic annotations.
The two formats have small differences: polygons in the instance annotations may overlap. The mask annotations are produced by labeling the overlapped polygons with depth ordering.
The semantic annotations are converted from panoptic annotations, where all “things” are assigned a semantic id of 0. All semantic categories will therefore have ids in contiguous range [1, #stuff_categories].
This function will also register a pure semantic segmentation dataset named
name + '_stuffonly'
.- Parameters
name (str) – the name that identifies a dataset, e.g. “coco_2017_train_panoptic”
metadata (dict) – extra metadata associated with this dataset.
image_root (str) – directory which contains all the images
panoptic_root (str) – directory which contains panoptic annotation images
panoptic_json (str) – path to the json panoptic annotation file
sem_seg_root (str) – directory which contains all the ground truth segmentation annotations.
instances_json (str) – path to the json instance annotation file
-
detectron2.data.datasets.
load_lvis_json
(json_file, image_root, dataset_name=None)[source]¶ Load a json file in LVIS’s annotation format.
- Parameters
json_file (str) – full path to the LVIS json annotation file.
image_root (str) – the directory where the images in this json file exists.
dataset_name (str) – the name of the dataset (e.g., “lvis_v0.5_train”). If provided, this function will put “thing_classes” into the metadata associated with this dataset.
- Returns
list[dict] – a list of dicts in Detectron2 standard format. (See Using Custom Datasets )
Notes
This function does not read the image files. The results do not have the “image” field.
-
detectron2.data.datasets.
register_lvis_instances
(name, metadata, json_file, image_root)[source]¶ Register a dataset in LVIS’s json annotation format for instance detection and segmentation.
- Parameters
-
detectron2.data.datasets.
get_lvis_instances_meta
(dataset_name)[source]¶ Load LVIS metadata.
- Parameters
dataset_name (str) – LVIS dataset name without the split name (e.g., “lvis_v0.5”).
- Returns
dict – LVIS metadata with keys: thing_classes
-
detectron2.data.datasets.
load_voc_instances
(dirname: str, split: str, class_names: Union[List[str], Tuple[str, …]])[source]¶ Load Pascal VOC detection annotations to Detectron2 format.
- Parameters
dirname – Contain “Annotations”, “ImageSets”, “JPEGImages”
split (str) – one of “train”, “test”, “val”, “trainval”
class_names – list or tuple of class names
detectron2.data.samplers module¶
-
class
detectron2.data.samplers.
TrainingSampler
(*args, **kwds)[source]¶ Bases:
torch.utils.data.Sampler
In training, we only care about the “infinite stream” of training data. So this sampler produces an infinite stream of indices and all workers cooperate to correctly shuffle the indices and sample different indices.
The sampler in each worker effectively produces indices[worker_id::num_workers], where indices is an infinite stream of indices consisting of shuffle(range(size)) + shuffle(range(size)) + … (if shuffle is True) or range(size) + range(size) + … (if shuffle is False).
-
__init__
(size: int, shuffle: bool = True, seed: Optional[int] = None)[source]¶ - Parameters
size (int) – the total number of data of the underlying dataset to sample from
shuffle (bool) – whether to shuffle the indices or not
seed (int) – the initial seed of the shuffle. Must be the same across all workers. If None, will use a random seed shared among workers (require synchronization among all workers).
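A usage sketch, passing an explicit sampler to the train loader (dataset and mapper setup as in the earlier examples, assuming the dataset files are available locally):
from detectron2.config import get_cfg
from detectron2.data import DatasetMapper, build_detection_train_loader, get_detection_dataset_dicts
from detectron2.data.samplers import TrainingSampler

cfg = get_cfg()
dataset = get_detection_dataset_dicts(["coco_2017_train"])
loader = build_detection_train_loader(
    dataset,
    mapper=DatasetMapper(cfg, is_train=True),
    sampler=TrainingSampler(len(dataset)),   # infinite shuffled stream over the dataset indices
    total_batch_size=16,
)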
-
-
class
detectron2.data.samplers.
InferenceSampler
(*args, **kwds)[source]¶ Bases:
torch.utils.data.Sampler
Produce indices for inference across all workers. Inference needs to run on the __exact__ set of samples; therefore, when the total number of samples is not divisible by the number of workers, this sampler produces a different number of samples on different workers.
-
class
detectron2.data.samplers.
RepeatFactorTrainingSampler
(*args, **kwds)[source]¶ Bases:
torch.utils.data.Sampler
Similar to TrainingSampler, but a sample may appear more times than others based on its “repeat factor”. This is suitable for training on class imbalanced datasets like LVIS.
-
__init__
(repeat_factors, *, shuffle=True, seed=None)[source]¶ - Parameters
repeat_factors (Tensor) – a float vector, the repeat factor for each index. When it’s full of ones, it is equivalent to
TrainingSampler(len(repeat_factors), ...)
.shuffle (bool) – whether to shuffle the indices or not
seed (int) – the initial seed of the shuffle. Must be the same across all workers. If None, will use a random seed shared among workers (require synchronization among all workers).
-
static
repeat_factors_from_category_frequency
(dataset_dicts, repeat_thresh)[source]¶ Compute (fractional) per-image repeat factors based on category frequency. The repeat factor for an image is a function of the frequency of the rarest category labeled in that image. The “frequency of category c” in [0, 1] is defined as the fraction of images in the training set (without repeats) in which category c appears. See LVIS: A Dataset for Large Vocabulary Instance Segmentation (>= v2) Appendix B.2.
- Parameters
- Returns
torch.Tensor – the i-th element is the repeat factor for the dataset image at index i.
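A sketch of the typical pairing with the sampler; the dataset name and threshold value are illustrative, and the dataset files are assumed to be available locally:
from detectron2.data import get_detection_dataset_dicts
from detectron2.data.samplers import RepeatFactorTrainingSampler

dataset_dicts = get_detection_dataset_dicts(["lvis_v0.5_train"])        # any registered dataset
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
    dataset_dicts, repeat_thresh=0.001
)
sampler = RepeatFactorTrainingSampler(repeat_factors)                   # images with rare classes repeat more often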
-
detectron2.data.transforms¶
Related tutorial: Data Augmentation.
-
class
detectron2.data.transforms.
Transform
¶ Bases:
object
Base class for implementations of deterministic transformations for image and other data structures. “Deterministic” requires that the output of all methods of this class is deterministic w.r.t. their input arguments. Note that this is different from (random) data augmentations. To perform data augmentations in training, there should be a higher-level policy that generates these transform ops.
Each transform op may handle several data types, e.g.: image, coordinates, segmentation, bounding boxes, with its
apply_*
methods. Some of them have a default implementation, but can be overwritten if the default isn’t appropriate. See documentation of each pre-definedapply_*
methods for details. Note that the implementation of these methods may choose to modify its input data in-place for efficient transformation. The class can be extended to support arbitrary new data types with its
register_type()
method.-
__repr__
()¶ Produce something like: “MyTransform(field1={self.field1}, field2={self.field2})”
-
apply_box
(box: numpy.ndarray) → numpy.ndarray¶ Apply the transform on an axis-aligned box. By default will transform the corner points and use their minimum/maximum to create a new axis-aligned box. Note that this default may change the size of your box, e.g. after rotations.
- Parameters
box (ndarray) – Nx4 floating point array of XYXY format in absolute coordinates.
- Returns
ndarray – box after applying the transformation.
Note
The coordinates are not pixel indices. Coordinates inside an image of shape (H, W) are in range [0, W] or [0, H].
This function does not clip boxes to force them inside the image. It is up to the application that uses the boxes to decide.
-
abstract
apply_coords
(coords: numpy.ndarray)¶ Apply the transform on coordinates.
- Parameters
coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
- Returns
ndarray – coordinates after applying the transformation.
Note
The coordinates are not pixel indices. Coordinates inside an image of shape (H, W) are in range [0, W] or [0, H]. This function should correctly transform coordinates outside the image as well.
-
abstract
apply_image
(img: numpy.ndarray)¶ Apply the transform on an image.
- Parameters
img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- Returns
ndarray – image after applying the transformation.
-
apply_polygons
(polygons: list) → list¶ Apply the transform on a list of polygons, each represented by a Nx2 array. By default will just transform all the points.
- Parameters
polygon (list[ndarray]) – each is a Nx2 floating point array of (x, y) format in absolute coordinates.
- Returns
list[ndarray] – polygons after applying the transformation.
Note
The coordinates are not pixel indices. Coordinates on an image of shape (H, W) are in range [0, W] or [0, H].
-
apply_segmentation
(segmentation: numpy.ndarray) → numpy.ndarray¶ Apply the transform on a full-image segmentation. By default will just perform “apply_image”.
- Parameters
segmentation (ndarray) – of shape HxW. The array should have integer or bool dtype.
- Returns
ndarray – segmentation after applying the transformation.
-
inverse
() → detectron2.data.transforms.Transform¶ Create a transform that inverts the geometric changes (i.e. change of coordinates) of this transform.
Note that the inverse is meant for geometric changes only. The inverse of photometric transforms that do not change coordinates is defined to be a no-op, even if they may be invertible.
- Returns
Transform
-
classmethod
register_type
(data_type: str, func: Optional[Callable] = None)[source]¶ Register the given function as a handler that this transform will use for a specific data type.
- Parameters
data_type (str) – the name of the data type (e.g., box)
func (callable) – takes a transform and a data, returns the transformed data.
Examples:
# call it directly
def func(flip_transform, voxel_data):
    return transformed_voxel_data

HFlipTransform.register_type("voxel", func)

# or, use it as a decorator
@HFlipTransform.register_type("voxel")
def func(flip_transform, voxel_data):
    return transformed_voxel_data

# ...
transform = HFlipTransform(...)
transform.apply_voxel(voxel_data)  # func will be called
-
-
class
detectron2.data.transforms.
TransformList
(transforms: List[detectron2.data.transforms.Transform])¶ Bases:
detectron2.data.transforms.Transform
Maintain a list of transform operations which will be applied in sequence.
- Attributes
transforms (list[Transform])
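A small sketch composing two deterministic transforms (the sizes and box are illustrative):
import numpy as np
from detectron2.data.transforms import HFlipTransform, ResizeTransform, TransformList

img = np.zeros((480, 640, 3), dtype=np.uint8)
tfms = TransformList([ResizeTransform(480, 640, 800, 1067), HFlipTransform(1067)])
out = tfms.apply_image(img)                                       # resize, then horizontal flip
boxes = tfms.apply_box(np.array([[10.0, 10.0, 100.0, 100.0]]))    # boxes follow the same sequence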
-
__add__
(other: detectron2.data.transforms.TransformList) → detectron2.data.transforms.TransformList¶ - Parameters
other (TransformList) – transformation to add.
- Returns
TransformList – list of transforms.
-
__iadd__
(other: detectron2.data.transforms.TransformList) → detectron2.data.transforms.TransformList¶ - Parameters
other (TransformList) – transformation to add.
- Returns
TransformList – list of transforms.
-
__init__
(transforms: List[detectron2.data.transforms.Transform])¶
-
__radd__
(other: detectron2.data.transforms.TransformList) → detectron2.data.transforms.TransformList¶ - Parameters
other (TransformList) – transformation to add.
- Returns
TransformList – list of transforms.
-
apply_coords
(x)¶
-
apply_image
(x)¶
-
inverse
() → detectron2.data.transforms.TransformList¶ Invert each transform in reversed order.
-
class
detectron2.data.transforms.
BlendTransform
(src_image: numpy.ndarray, src_weight: float, dst_weight: float)¶ Bases:
detectron2.data.transforms.Transform
Transforms pixel colors with PIL enhance functions.
-
__init__
(src_image: numpy.ndarray, src_weight: float, dst_weight: float)¶ Blends the input image (dst_image) with the src_image using formula:
src_weight * src_image + dst_weight * dst_image
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray¶ Apply no transform on the coordinates.
-
apply_image
(img: numpy.ndarray, interp: str = None) → numpy.ndarray¶ Apply blend transform on the image(s).
- Parameters
img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
interp (str) – kept only for API consistency; the blend operation does not require interpolation.
- Returns
ndarray – blended image(s).
-
apply_segmentation
(segmentation: numpy.ndarray) → numpy.ndarray¶ Apply no transform on the full-image segmentation.
-
inverse
() → detectron2.data.transforms.Transform¶ The inverse is a no-op.
-
-
class
detectron2.data.transforms.
CropTransform
(x0: int, y0: int, w: int, h: int, orig_w: Optional[int] = None, orig_h: Optional[int] = None)¶ Bases:
detectron2.data.transforms.Transform
-
__init__
(x0: int, y0: int, w: int, h: int, orig_w: Optional[int] = None, orig_h: Optional[int] = None)¶ - Parameters
x0 (int) – crop the image(s) by img[y0:y0+h, x0:x0+w].
y0 (int) – crop the image(s) by img[y0:y0+h, x0:x0+w].
w (int) – crop the image(s) by img[y0:y0+h, x0:x0+w].
h (int) – crop the image(s) by img[y0:y0+h, x0:x0+w].
orig_w (int) – optional, the original width before cropping. Needed to make this transform invertible.
orig_h (int) – optional, the original height before cropping. Needed to make this transform invertible.
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray¶ Apply crop transform on coordinates.
- Parameters
coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
- Returns
ndarray – cropped coordinates.
-
apply_image
(img: numpy.ndarray) → numpy.ndarray¶ Crop the image(s).
- Parameters
img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- Returns
ndarray – cropped image(s).
-
apply_polygons
(polygons: list) → list¶ Apply crop transform on a list of polygons, each represented by a Nx2 array. It will crop the polygon with the box, therefore the number of points in the polygon might change.
- Parameters
polygon (list[ndarray]) – each is a Nx2 floating point array of (x, y) format in absolute coordinates.
- Returns
ndarray – cropped polygons.
-
inverse
() → detectron2.data.transforms.Transform¶
-
-
class
detectron2.data.transforms.
PadTransform
(x0: int, y0: int, x1: int, y1: int, orig_w: Optional[int] = None, orig_h: Optional[int] = None, pad_value: float = 0)¶ Bases:
detectron2.data.transforms.Transform
-
__init__
(x0: int, y0: int, x1: int, y1: int, orig_w: Optional[int] = None, orig_h: Optional[int] = None, pad_value: float = 0)¶ - Parameters
x0 – number of padded pixels on the left
y0 – number of padded pixels on the top
x1 – number of padded pixels on the right
y1 – number of padded pixels on the bottom
orig_w – optional, the original width before padding. Needed to make this transform invertible.
orig_h – optional, the original height before padding. Needed to make this transform invertible.
pad_value – the padding value
-
apply_coords
(coords)¶
-
apply_image
(img)¶
-
inverse
() → detectron2.data.transforms.Transform¶
-
-
class
detectron2.data.transforms.
GridSampleTransform
(grid: numpy.ndarray, interp: str)¶ Bases:
detectron2.data.transforms.Transform
-
__init__
(grid: numpy.ndarray, interp: str)¶ - Parameters
grid (ndarray) – grid has x and y input pixel locations which are used to compute output. Grid has values in the range of [-1, 1], which is normalized by the input height and width. The dimension is N x H x W x 2.
interp (str) – interpolation methods. Options include nearest and bilinear.
-
apply_coords
(coords: numpy.ndarray)¶ Not supported.
-
apply_image
(img: numpy.ndarray, interp: str = None) → numpy.ndarray¶ Apply grid sampling on the image(s).
- Parameters
img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
interp (str) – interpolation methods. Options include nearest and bilinear.
- Returns
ndarray – grid sampled image(s).
-
apply_segmentation
(segmentation: numpy.ndarray) → numpy.ndarray¶ Apply grid sampling on the full-image segmentation.
- Parameters
segmentation (ndarray) – of shape HxW. The array should have integer or bool dtype.
- Returns
ndarray – grid sampled segmentation.
-
-
class
detectron2.data.transforms.
HFlipTransform
(width: int)¶ Bases:
detectron2.data.transforms.Transform
Perform horizontal flip.
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray¶ Flip the coordinates.
- Parameters
coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
- Returns
ndarray – the flipped coordinates.
Note
The inputs are floating point coordinates, not pixel indices. Therefore they are flipped by (W - x, H - y), not (W - 1 - x, H - 1 - y).
-
apply_image
(img: numpy.ndarray) → numpy.ndarray¶ Flip the image(s).
- Parameters
img (ndarray) – of shape HxW, HxWxC, or NxHxWxC. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- Returns
ndarray – the flipped image(s).
-
apply_rotated_box
(rotated_boxes)¶ Apply the horizontal flip transform on rotated boxes.
- Parameters
rotated_boxes (ndarray) – Nx5 floating point array of (x_center, y_center, width, height, angle_degrees) format in absolute coordinates.
-
inverse
() → detectron2.data.transforms.Transform¶ The inverse is to flip again
-
-
class
detectron2.data.transforms.
VFlipTransform
(height: int)¶ Bases:
detectron2.data.transforms.Transform
Perform vertical flip.
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray¶ Flip the coordinates.
- Parameters
coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
- Returns
ndarray – the flipped coordinates.
Note
The inputs are floating point coordinates, not pixel indices. Therefore they are flipped by (W - x, H - y), not (W - 1 - x, H - 1 - y).
-
apply_image
(img: numpy.ndarray) → numpy.ndarray¶ Flip the image(s).
- Parameters
img (ndarray) – of shape HxW, HxWxC, or NxHxWxC. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- Returns
ndarray – the flipped image(s).
-
inverse
() → detectron2.data.transforms.Transform¶ The inverse is to flip again
-
-
class
detectron2.data.transforms.
NoOpTransform
¶ Bases:
detectron2.data.transforms.Transform
A transform that does nothing.
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray¶
-
apply_image
(img: numpy.ndarray) → numpy.ndarray¶
-
apply_rotated_box
(x)¶
-
inverse
() → detectron2.data.transforms.Transform¶
-
-
class
detectron2.data.transforms.
ScaleTransform
(h: int, w: int, new_h: int, new_w: int, interp: str = None)¶ Bases:
detectron2.data.transforms.Transform
Resize the image to a target size.
-
__init__
(h: int, w: int, new_h: int, new_w: int, interp: str = None)¶ - Parameters
h (int) – original image size.
w (int) – original image size.
new_h (int) – new image size.
new_w (int) – new image size.
interp (str) – interpolation methods. Options include nearest, linear (3D-only), bilinear, bicubic (4D-only), and area. Details can be found in: https://pytorch.org/docs/stable/nn.functional.html
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray¶ Compute the coordinates after resize.
- Parameters
coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
- Returns
ndarray – resized coordinates.
-
apply_image
(img: numpy.ndarray, interp: str = None) → numpy.ndarray¶ Resize the image(s).
- Parameters
img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
interp (str) – interpolation methods. Options include nearest, linear (3D-only), bilinear, bicubic (4D-only), and area. Details can be found in: https://pytorch.org/docs/stable/nn.functional.html
- Returns
ndarray – resized image(s).
-
apply_segmentation
(segmentation: numpy.ndarray) → numpy.ndarray¶ Apply resize on the full-image segmentation.
- Parameters
segmentation (ndarray) – of shape HxW. The array should have integer or bool dtype.
- Returns
ndarray – resized segmentation.
-
inverse
() → detectron2.data.transforms.Transform¶ The inverse is to resize it back.
-
-
class
detectron2.data.transforms.
ExtentTransform
(src_rect, output_size, interp=2, fill=0)¶ Bases:
detectron2.data.transforms.Transform
Extracts a subregion from the source image and scales it to the output size.
The fill color is used to map pixels from the source rect that fall outside the source image.
See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
-
__init__
(src_rect, output_size, interp=2, fill=0)¶ - Parameters
src_rect (x0, y0, x1, y1) – src coordinates
output_size (h, w) – dst image size
interp – PIL interpolation methods
fill – Fill color used when src_rect extends outside image
-
apply_coords
(coords)¶
-
apply_image
(img, interp=None)¶
-
apply_segmentation
(segmentation)¶
-
-
class
detectron2.data.transforms.
ResizeTransform
(h, w, new_h, new_w, interp=None)¶ Bases:
detectron2.data.transforms.Transform
Resize the image to a target size.
-
__init__
(h, w, new_h, new_w, interp=None)¶
-
apply_coords
(coords)¶
-
apply_image
(img, interp=None)¶
-
apply_rotated_box
(rotated_boxes)¶ Apply the resizing transform on rotated boxes. For details of how these (approximation) formulas are derived, please refer to
RotatedBoxes.scale()
.- Parameters
rotated_boxes (ndarray) – Nx5 floating point array of (x_center, y_center, width, height, angle_degrees) format in absolute coordinates.
-
apply_segmentation
(segmentation)¶
-
inverse
()¶
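Example: a minimal sketch of using ResizeTransform directly; the array shapes and sizes below are illustrative assumptions, not part of the API:
import numpy as np
from detectron2.data import transforms as T

img = np.zeros((480, 640, 3), dtype=np.uint8)       # HxWxC uint8 image (assumed shape)
coords = np.array([[10.0, 20.0], [300.0, 400.0]])   # Nx2 array of (x, y) points

tfm = T.ResizeTransform(480, 640, 240, 320)         # h, w -> new_h, new_w
new_img = tfm.apply_image(img)                      # resized to 240x320
new_coords = tfm.apply_coords(coords)               # scaled accordingly
restored = tfm.inverse().apply_image(new_img)       # resize back to the original size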
-
-
class
detectron2.data.transforms.
RotationTransform
(h, w, angle, expand=True, center=None, interp=None)¶ Bases:
detectron2.data.transforms.Transform
This transform returns a copy of the image, rotated the given number of degrees counterclockwise around its center.
-
__init__
(h, w, angle, expand=True, center=None, interp=None)¶ - Parameters
h (int) – original image height
w (int) – original image width
angle (float) – degrees for rotation
expand (bool) – choose if the image should be resized to fit the whole rotated image (default), or simply cropped
center (tuple (width, height)) – coordinates of the rotation center. If left as None, the center will be the center of each image. center has no effect if expand=True because it only affects shifting.
interp – cv2 interpolation method, default cv2.INTER_LINEAR
-
apply_coords
(coords)¶ coords should be a N * 2 array-like, containing N couples of (x, y) points
-
apply_image
(img, interp=None)¶ img should be a numpy array, formatted as Height * Width * Nchannels
-
apply_segmentation
(segmentation)¶
-
create_rotation_matrix
(offset=0)¶
-
inverse
()¶ The inverse is to rotate it back with expand, and crop to get the original shape.
-
-
class
detectron2.data.transforms.
ColorTransform
(op)¶ Bases:
detectron2.data.transforms.Transform
Generic wrapper for any photometric transforms. These transformations should only affect the color space and
not the coordinate space of the image (e.g. annotation coordinates such as bounding boxes should not be changed)
-
__init__
(op)¶ - Parameters
op (Callable) – operation to be applied to the image, which takes in an ndarray and returns an ndarray.
-
apply_coords
(coords)¶
-
apply_image
(img)¶
-
apply_segmentation
(segmentation)¶
-
inverse
()¶
-
-
class
detectron2.data.transforms.
PILColorTransform
(op)¶ Bases:
detectron2.data.transforms.ColorTransform
- Generic wrapper for PIL Photometric image transforms,
which affect the color space and not the coordinate space of the image
-
__init__
(op)¶ - Parameters
op (Callable) – operation to be applied to the image, which takes in a PIL Image and returns a transformed PIL Image. For reference on possible operations see: - https://pillow.readthedocs.io/en/stable/
-
apply_image
(img)¶
-
class
detectron2.data.transforms.
Augmentation
¶ Bases:
object
Augmentation defines (often random) policies/strategies to generate Transform from data. It is often used for pre-processing of input data.
A “policy” that generates a Transform may, in the most general case, need arbitrary information from input data in order to determine what transforms to apply. Therefore, each Augmentation instance defines the arguments needed by its get_transform() method. When called with the positional arguments, the get_transform() method executes the policy.
Note that Augmentation defines the policies to create a Transform, but not how to execute the actual transform operations on those data. Its __call__() method will use AugInput.transform() to execute the transform.
The returned Transform object is meant to describe a deterministic transformation, which means it can be re-applied on associated data, e.g. the geometry of an image and its segmentation masks need to be transformed together. (If such re-application is not needed, then determinism is not a crucial requirement.)
-
__call__
(aug_input) → detectron2.data.transforms.Transform¶ Augment the given aug_input in-place, and return the transform that’s used.
This method will be called to apply the augmentation. In most augmentations, it is enough to use the default implementation, which calls get_transform() using the inputs. But a subclass can overwrite it to have more complicated logic.
- Parameters
aug_input (AugInput) – an object that has attributes needed by this augmentation (defined by self.get_transform). Its transform method will be called to in-place transform it.
- Returns
Transform – the transform that is applied on the input.
-
__repr__
()¶ Produce something like: “MyAugmentation(field1={self.field1}, field2={self.field2})”
-
__str__
()¶ Produce something like: “MyAugmentation(field1={self.field1}, field2={self.field2})”
-
get_transform
(*args) → detectron2.data.transforms.Transform¶ Execute the policy based on input data, and decide what transform to apply to inputs.
- Parameters
args – Any fixed-length positional arguments. By default, the name of the arguments should exist in the
AugInput
to be used.- Returns
Transform – Returns the deterministic transform to apply to the input.
Examples:
class MyAug:
    # if a policy needs to know both image and semantic segmentation
    def get_transform(image, sem_seg) -> T.Transform:
        pass

tfm: Transform = MyAug().get_transform(image, sem_seg)
new_image = tfm.apply_image(image)
Notes
Users can freely use arbitrary new argument names in a custom get_transform() method, as long as they are available in the input data. In detectron2 we use the following convention:
image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
boxes: (N,4) ndarray of float32. It represents the instance bounding boxes of N instances. Each is in XYXY format in unit of absolute coordinates.
sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.
We do not specify convention for other types and do not include builtin
Augmentation
that uses other types in detectron2.
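As a minimal sketch of the convention above (the class name, probability, and shapes are illustrative assumptions), a custom policy only needs to implement get_transform() with arguments that exist on the input:
import numpy as np
from detectron2.data import transforms as T

class RandomHalfResize(T.Augmentation):
    # Illustrative policy: with probability `prob`, resize the image to half its size.
    def __init__(self, prob=0.5):
        super().__init__()
        self.prob = prob

    def get_transform(self, image):
        h, w = image.shape[:2]
        if np.random.rand() < self.prob:
            return T.ResizeTransform(h, w, h // 2, w // 2)
        return T.NoOpTransform()

aug = RandomHalfResize(prob=0.5)
image = np.zeros((480, 640, 3), dtype=np.uint8)
tfm = aug(T.AugInput(image))   # __call__ returns the Transform that was applied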
-
-
class
detectron2.data.transforms.
AugmentationList
(augs)¶ Bases:
detectron2.data.transforms.Augmentation
Apply a sequence of augmentations.
It has a __call__ method to apply the augmentations.
Note that the get_transform() method is impossible (it will throw an error if called) for AugmentationList, because in order to apply a sequence of augmentations, the kth augmentation must be applied first, to provide the inputs needed by the (k+1)th augmentation.
-
__init__
(augs)¶ - Parameters
augs (list[Augmentation or Transform]) –
-
-
class
detectron2.data.transforms.
AugInput
(image: numpy.ndarray, *, boxes: Optional[numpy.ndarray] = None, sem_seg: Optional[numpy.ndarray] = None)¶ Bases:
object
Input that can be used with
Augmentation.__call__()
. This is a standard implementation for the majority of use cases. This class provides the standard attributes “image”, “boxes”, “sem_seg” defined in __init__(), and they may be needed by different augmentations. Most augmentation policies do not need attributes beyond these three.
After applying augmentations to these attributes (using AugInput.transform()), the returned transforms can then be used to transform other data structures that users have.
Examples:
input = AugInput(image, boxes=boxes)
tfms = augmentation(input)
transformed_image = input.image
transformed_boxes = input.boxes
transformed_other_data = tfms.apply_other(other_data)
An extended project that works with new data types may implement augmentation policies that need other inputs. An algorithm may need to transform inputs in a way different from the standard approach defined in this class. In those rare situations, users can implement a class similar to this one that satisfies the following conditions:
The input must provide access to these data in the form of attribute access (getattr). For example, if an Augmentation to be applied needs “image” and “sem_seg” arguments, its input must have the attributes “image” and “sem_seg”.
The input must have a transform(tfm: Transform) -> None method which in-place transforms all its attributes.
-
__init__
(image: numpy.ndarray, *, boxes: Optional[numpy.ndarray] = None, sem_seg: Optional[numpy.ndarray] = None)¶ - Parameters
image (ndarray) – (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255]. The meaning of C is up to users.
boxes (ndarray or None) – Nx4 float32 boxes in XYXY_ABS mode
sem_seg (ndarray or None) – HxW uint8 semantic segmentation mask. Each element is an integer label of pixel.
-
transform
(tfm: detectron2.data.transforms.Transform) → None¶ In-place transform all attributes of this class.
By “in-place”, it means after calling this method, accessing an attribute such as
self.image
will return transformed data.
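A minimal end-to-end sketch of the workflow described above, assuming an AugmentationList of builtin augmentations (the shapes and augmentation choices are illustrative):
import numpy as np
from detectron2.data import transforms as T

image = np.zeros((480, 640, 3), dtype=np.uint8)             # HxWxC uint8 (assumed)
boxes = np.array([[10, 10, 100, 100]], dtype=np.float32)    # Nx4, XYXY_ABS
sem_seg = np.zeros((480, 640), dtype=np.uint8)

augs = T.AugmentationList([
    T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
    T.ResizeShortestEdge(short_edge_length=(640, 672, 704), max_size=1333, sample_style="choice"),
])

aug_input = T.AugInput(image, boxes=boxes, sem_seg=sem_seg)
tfms = augs(aug_input)                  # transforms aug_input in-place, returns a TransformList
new_image, new_boxes = aug_input.image, aug_input.boxes
# the same transforms can be re-applied to other data the user has, e.g. more boxes:
more_boxes = tfms.apply_box(np.array([[5, 5, 50, 50]], dtype=np.float32))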
-
class
detectron2.data.transforms.
FixedSizeCrop
(crop_size: Tuple[int], pad_value: float = 128.0)¶ Bases:
detectron2.data.transforms.Augmentation
If crop_size is smaller than the input image size, then it uses a random crop of the crop size. If crop_size is larger than the input image size, then it pads the right and the bottom of the image to the crop size.
-
__init__
(crop_size: Tuple[int], pad_value: float = 128.0)¶ - Parameters
crop_size – target image (height, width).
pad_value – the padding value.
-
get_transform
(image: numpy.ndarray) → detectron2.data.transforms.TransformList¶
-
-
class
detectron2.data.transforms.
RandomApply
(tfm_or_aug, prob=0.5)¶ Bases:
detectron2.data.transforms.Augmentation
Randomly apply an augmentation with a given probability.
-
__init__
(tfm_or_aug, prob=0.5)¶ - Parameters
tfm_or_aug (Transform, Augmentation) – the transform or augmentation to be applied. It can either be a Transform or Augmentation instance.
prob (float) – probability between 0.0 and 1.0 that the wrapper transformation is applied
-
get_transform
(*args)¶
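For example (the wrapped augmentation and probability are illustrative), an augmentation can be applied only part of the time:
import numpy as np
from detectron2.data import transforms as T

image = np.zeros((480, 640, 3), dtype=np.uint8)                   # dummy input (assumed shape)
aug = T.RandomApply(T.RandomRotation(angle=[-30, 30]), prob=0.3)  # rotate ~30% of the inputs
tfm = aug(T.AugInput(image))   # either a RotationTransform or a NoOpTransform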
-
-
class
detectron2.data.transforms.
RandomBrightness
(intensity_min, intensity_max)¶ Bases:
detectron2.data.transforms.Augmentation
Randomly transforms image brightness.
Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce brightness
- intensity = 1 will preserve the input image
- intensity > 1 will increase brightness
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
-
__init__
(intensity_min, intensity_max)¶
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomContrast
(intensity_min, intensity_max)¶ Bases:
detectron2.data.transforms.Augmentation
Randomly transforms image contrast.
Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce contrast
- intensity = 1 will preserve the input image
- intensity > 1 will increase contrast
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
-
__init__
(intensity_min, intensity_max)¶
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomCrop
(crop_type: str, crop_size)¶ Bases:
detectron2.data.transforms.Augmentation
Randomly crop a rectangle region out of an image.
-
__init__
(crop_type: str, crop_size)¶ - Parameters
“relative”: crop a (H * crop_size[0], W * crop_size[1]) region from an input image of size (H, W). crop_size should be in (0, 1]
“relative_range”: uniformly sample two values from [crop_size[0], 1] and [crop_size[1], 1], and use them as in the “relative” crop type.
“absolute”: crop a (crop_size[0], crop_size[1]) region from the input image. crop_size must be smaller than the input image size.
“absolute_range”, for an input of size (H, W), uniformly sample H_crop in [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])]. Then crop a region (H_crop, W_crop).
-
get_crop_size
(image_size)¶ - Parameters
image_size (tuple) – height, width
- Returns
crop_size (tuple) – height, width in absolute pixels
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomExtent
(scale_range, shift_range)¶ Bases:
detectron2.data.transforms.Augmentation
Outputs an image by cropping a random “subrect” of the source image.
The subrect can be parameterized to include pixels outside the source image, in which case they will be set to zeros (i.e. black). The size of the output image will vary with the size of the random subrect.
-
__init__
(scale_range, shift_range)¶ - Parameters
output_size (h, w) – Dimensions of output image
scale_range (l, h) – Range of input-to-output size scaling factor
shift_range (x, y) – Range of shifts of the cropped subrect. The rect is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)], where (w, h) is the (width, height) of the input image. Set each component to zero to crop at the image’s center.
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomFlip
(prob=0.5, *, horizontal=True, vertical=False)¶ Bases:
detectron2.data.transforms.Augmentation
Flip the image horizontally or vertically with the given probability.
-
__init__
(prob=0.5, *, horizontal=True, vertical=False)¶ - Parameters
prob (float) – probability of flip.
horizontal (boolean) – whether to apply horizontal flipping
vertical (boolean) – whether to apply vertical flipping
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomSaturation
(intensity_min, intensity_max)¶ Bases:
detectron2.data.transforms.Augmentation
Randomly transforms saturation of an RGB image. Input images are assumed to have ‘RGB’ channel order.
Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce saturation (make the image more grayscale)
- intensity = 1 will preserve the input image
- intensity > 1 will increase saturation
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
-
__init__
(intensity_min, intensity_max)¶
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomLighting
(scale)¶ Bases:
detectron2.data.transforms.Augmentation
The “lighting” augmentation described in AlexNet, using fixed PCA over ImageNet. Input images are assumed to have ‘RGB’ channel order.
The degree of color jittering is randomly sampled via a normal distribution, with standard deviation given by the scale parameter.
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomRotation
(angle, expand=True, center=None, sample_style='range', interp=None)¶ Bases:
detectron2.data.transforms.Augmentation
This transform returns a copy of the image, rotated the given number of degrees counterclockwise around the given center.
-
__init__
(angle, expand=True, center=None, sample_style='range', interp=None)¶ - Parameters
angle (list[float]) – If sample_style=="range", a [min, max] interval from which to sample the angle (in degrees). If sample_style=="choice", a list of angles to sample from.
expand (bool) – choose if the image should be resized to fit the whole rotated image (default), or simply cropped
center (list[[float, float]]) – If sample_style=="range", a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center, [0, 0] being the top left of the image and [1, 1] the bottom right. If sample_style=="choice", a list of centers to sample from. Default: None, which means that the center of rotation is the center of the image. center has no effect if expand=True because it only affects shifting.
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
Resize
(shape, interp=2)¶ Bases:
detectron2.data.transforms.Augmentation
Resize image to a fixed target size
-
__init__
(shape, interp=2)¶ - Parameters
shape – (h, w) tuple or an int
interp – PIL interpolation method
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
ResizeScale
(min_scale: float, max_scale: float, target_height: int, target_width: int, interp: int = 2)¶ Bases:
detectron2.data.transforms.Augmentation
Takes target size as input and randomly scales the given target size between min_scale and max_scale. It then scales the input image such that it fits inside the scaled target box, keeping the aspect ratio constant. This implements the resize part of Google’s ‘resize_and_crop’ data augmentation: https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
-
__init__
(min_scale: float, max_scale: float, target_height: int, target_width: int, interp: int = 2)¶ - Parameters
min_scale – minimum image scale range.
max_scale – maximum image scale range.
target_height – target image height.
target_width – target image width.
interp – image interpolation method.
-
get_transform
(image: numpy.ndarray) → detectron2.data.transforms.Transform¶
-
-
class
detectron2.data.transforms.
ResizeShortestEdge
(short_edge_length, max_size=9223372036854775807, sample_style='range', interp=2)¶ Bases:
detectron2.data.transforms.Augmentation
Scale the shorter edge to the given size, with a limit of max_size on the longer edge. If max_size is reached, then downscale so that the longer edge does not exceed max_size.
-
__init__
(short_edge_length, max_size=9223372036854775807, sample_style='range', interp=2)¶ - Parameters
short_edge_length (list[int]) – If sample_style=="range", a [min, max] interval from which to sample the shortest edge length. If sample_style=="choice", a list of shortest edge lengths to sample from.
max_size (int) – maximum allowed longest edge length.
sample_style (str) – either “range” or “choice”.
-
get_transform
(image)¶
-
-
class
detectron2.data.transforms.
RandomCrop_CategoryAreaConstraint
(crop_type: str, crop_size, single_category_max_area: float = 1.0, ignored_category: int = None)¶ Bases:
detectron2.data.transforms.Augmentation
Similar to
RandomCrop
, but finds a cropping window such that no single category occupies a ratio of more than single_category_max_area in the semantic segmentation ground truth, which can cause instability in training. The function attempts to find such a valid cropping window at most 10 times.
__init__
(crop_type: str, crop_size, single_category_max_area: float = 1.0, ignored_category: int = None)¶ - Parameters
crop_type – same as in
RandomCrop
crop_size – same as in
RandomCrop
single_category_max_area – the maximum allowed area ratio of a category. Set to 1.0 to disable
ignored_category – allow this category in the semantic segmentation ground truth to exceed the area ratio. Usually set to the category that’s ignored in training.
-
get_transform
(image, sem_seg)¶
-
detectron2.engine¶
Related tutorial: Training.
-
detectron2.engine.
launch
(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=(), timeout=datetime.timedelta(seconds=1800))[source]¶ Launch multi-gpu or distributed training. This function must be called on all machines involved in the training. It will spawn child processes (defined by
num_gpus_per_machine
) on each machine.- Parameters
main_func – a function that will be called by main_func(*args)
num_gpus_per_machine (int) – number of GPUs per machine
num_machines (int) – the total number of machines
machine_rank (int) – the rank of this machine
dist_url (str) – url to connect to for distributed jobs, including protocol e.g. “tcp://127.0.0.1:8686”. Can be set to “auto” to automatically select a free port on localhost
timeout (timedelta) – timeout of the distributed workers
args (tuple) – arguments passed to main_func
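A sketch of a typical entry point, loosely following tools/train_net.py; the body of main is omitted and the argument plumbing relies on default_argument_parser() described below:
from detectron2.engine import default_argument_parser, launch

def main(args):
    # build the config, trainer, etc. here (omitted)
    pass

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        num_gpus_per_machine=args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )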
-
class
detectron2.engine.
HookBase
[source]¶ Bases:
object
Base class for hooks that can be registered with
TrainerBase
.Each hook can implement 4 methods. The way they are called is demonstrated in the following snippet:
hook.before_train()
for iter in range(start_iter, max_iter):
    hook.before_step()
    trainer.run_step()
    hook.after_step()
iter += 1
hook.after_train()
Notes
In the hook method, users can access self.trainer to access more properties about the context (e.g., model, current iteration, or config if using DefaultTrainer).
A hook that does something in before_step() can often be implemented equivalently in after_step(). If the hook takes non-trivial time, it is strongly recommended to implement the hook in after_step() instead of before_step(). The convention is that before_step() should only take negligible time.
Following this convention will allow hooks that do care about the difference between before_step() and after_step() (e.g., timer) to function properly.
-
trainer
: detectron2.engine.train_loop.TrainerBase = None¶ A weak reference to the trainer object. Set by the trainer when the hook is registered.
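A minimal custom hook sketch, assuming the standard “total_loss” metric logged by SimpleTrainer and access to self.trainer.iter and self.trainer.storage as described above (the class name and period are illustrative):
from detectron2.engine import HookBase

class LossPrinter(HookBase):
    # Illustrative hook: print the latest total_loss every `period` iterations.
    def __init__(self, period=100):
        self._period = period

    def after_step(self):
        it = self.trainer.iter
        if (it + 1) % self._period == 0:
            latest = self.trainer.storage.latest()   # {name: (value, iteration)}
            if "total_loss" in latest:
                print(f"iter {it}: total_loss = {latest['total_loss'][0]:.4f}")

# register it on any TrainerBase subclass:
# trainer.register_hooks([LossPrinter(period=20)])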
-
class
detectron2.engine.
TrainerBase
[source]¶ Bases:
object
Base class for iterative trainer with hooks.
The only assumption we make here is that the training runs in a loop. A subclass can implement what the loop is. We make no assumptions about the existence of a dataloader, optimizer, model, etc.
-
storage
¶ An EventStorage that’s opened during the course of training.
- Type
EventStorage
-
register_hooks
(hooks: List[Optional[detectron2.engine.train_loop.HookBase]]) → None[source]¶ Register hooks to the trainer. The hooks are executed in the order they are registered.
-
-
class
detectron2.engine.
SimpleTrainer
(model, data_loader, optimizer)[source]¶ Bases:
detectron2.engine.train_loop.TrainerBase
A simple trainer for the most common type of task: single-cost single-optimizer single-data-source iterative optimization, optionally using data-parallelism. It assumes that every step, you:
Compute the loss with data from the data_loader.
Compute the gradients with the above loss.
Update the model with the optimizer.
All other tasks during training (checkpointing, logging, evaluation, LR schedule) are maintained by hooks, which can be registered by
TrainerBase.register_hooks()
.If you want to do anything fancier than this, either subclass TrainerBase and implement your own run_step, or write your own training loop.
-
__init__
(model, data_loader, optimizer)[source]¶ - Parameters
model – a torch Module. Takes a data from data_loader and returns a dict of losses.
data_loader – an iterable. Contains data to be used to call model.
optimizer – a torch optimizer.
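A sketch of wiring SimpleTrainer up by hand, assuming a config-defined model, optimizer, and data loader, and assuming the COCO dataset is registered and available (the config name and iteration counts are illustrative):
import os
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.engine import SimpleTrainer, hooks
from detectron2.engine.defaults import default_writers
from detectron2.modeling import build_model
from detectron2.solver import build_optimizer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("coco_2017_train",)
cfg.SOLVER.MAX_ITER = 10000
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

model = build_model(cfg)
optimizer = build_optimizer(cfg, model)
data_loader = build_detection_train_loader(cfg)

trainer = SimpleTrainer(model, data_loader, optimizer)
trainer.register_hooks([
    hooks.IterationTimer(),
    hooks.PeriodicWriter(default_writers(cfg.OUTPUT_DIR, cfg.SOLVER.MAX_ITER)),
])
trainer.train(0, cfg.SOLVER.MAX_ITER)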
-
class
detectron2.engine.
AMPTrainer
(model, data_loader, optimizer, grad_scaler=None)[source]¶ Bases:
detectron2.engine.train_loop.SimpleTrainer
Like
SimpleTrainer
, but uses PyTorch’s native automatic mixed precision in the training loop.-
__init__
(model, data_loader, optimizer, grad_scaler=None)[source]¶ - Parameters
model – same as in
SimpleTrainer
.data_loader – same as in
SimpleTrainer
.optimizer – same as in
SimpleTrainer
.grad_scaler – torch GradScaler to automatically scale gradients.
-
detectron2.engine.defaults module¶
This file contains components with some default boilerplate logic user may need in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change, since they are meant to represent the “common default behavior” people need in their projects.
-
detectron2.engine.defaults.
create_ddp_model
(model, *, fp16_compression=False, **kwargs)[source]¶ Create a DistributedDataParallel model if there are >1 processes.
- Parameters
model – a torch.nn.Module
fp16_compression – add fp16 compression hooks to the ddp object. See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
kwargs – other arguments of torch.nn.parallel.DistributedDataParallel.
-
detectron2.engine.defaults.
default_argument_parser
(epilog=None)[source]¶ Create a parser with some common arguments used by detectron2 users.
- Parameters
epilog (str) – epilog passed to ArgumentParser describing the usage.
- Returns
argparse.ArgumentParser
-
detectron2.engine.defaults.
default_setup
(cfg, args)[source]¶ Perform some basic common setups at the beginning of a job, including:
Set up the detectron2 logger
Log basic information about environment, cmdline arguments, and config
Backup the config to the output directory
- Parameters
cfg (CfgNode or omegaconf.DictConfig) – the full config to be used
args (argparse.Namespace) – the command line arguments to be logged
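A sketch of the common setup boilerplate built on these helpers; the setup function name is a convention, not an API (paths come from the command line):
from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup

def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)    # --config-file on the command line
    cfg.merge_from_list(args.opts)           # e.g. SOLVER.BASE_LR 0.01
    cfg.freeze()
    default_setup(cfg, args)                 # logger, environment info, config backup
    return cfg

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    cfg = setup(args)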
-
detectron2.engine.defaults.
default_writers
(output_dir: str, max_iter: Optional[int] = None)[source]¶ Build a list of
EventWriter
to be used. It now consists of a CommonMetricPrinter, TensorboardXWriter, and JSONWriter.
- Parameters
output_dir – directory to store JSON metrics and tensorboard events
max_iter – the total number of iterations
- Returns
list[EventWriter] – a list of
EventWriter
objects.
-
class
detectron2.engine.defaults.
DefaultPredictor
(cfg)[source]¶ Bases:
object
Create a simple end-to-end predictor with the given config that runs on a single device for a single input image.
Compared to using the model directly, this class does the following additions:
Load checkpoint from cfg.MODEL.WEIGHTS.
Always take BGR image as the input and apply conversion defined by cfg.INPUT.FORMAT.
Apply resizing defined by cfg.INPUT.{MIN,MAX}_SIZE_TEST.
Take one input image and produce a single output, instead of a batch.
This is meant for simple demo purposes, so it does the above steps automatically. This is not meant for benchmarks or running complicated inference logic. If you’d like to do anything more complicated, please refer to its source code as examples to build and use the model manually.
Examples:
pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
-
__call__
(original_image)[source]¶ - Parameters
original_image (np.ndarray) – an image of shape (H, W, C) (in BGR order).
- Returns
predictions (dict) – the output of the model for one image only. See Use Models for details about the format.
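A slightly fuller sketch of the example above, assuming model zoo weights; the config name, image path, and score threshold are illustrative:
import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5    # illustrative threshold

predictor = DefaultPredictor(cfg)
image = cv2.imread("input.jpg")                # BGR, matching the default cfg.INPUT.FORMAT
outputs = predictor(image)
print(outputs["instances"].pred_classes, outputs["instances"].pred_boxes)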
-
class
detectron2.engine.defaults.
DefaultTrainer
(cfg)[source]¶ Bases:
detectron2.engine.train_loop.TrainerBase
A trainer with default training logic. It does the following:
Create a
SimpleTrainer
using the model, optimizer, and dataloader defined by the given config. Create an LR scheduler defined by the config.
Load the last checkpoint or cfg.MODEL.WEIGHTS, if it exists, when resume_or_load is called.
Register a few common hooks defined by the config.
It is created to simplify the standard model training workflow and reduce code boilerplate for users who only need the standard training workflow, with standard features. It means this class makes many assumptions about your training logic that may easily become invalid in new research. In fact, any assumptions beyond those made in the SimpleTrainer are too much for research.
The code of this class has been annotated with the restrictive assumptions it makes. When they do not work for you, you’re encouraged to:
Overwrite methods of this class, OR:
Use
SimpleTrainer
, which only does minimal SGD training and nothing else. You can then add your own hooks if needed. OR:
Write your own training loop similar to tools/plain_train_net.py.
See the Training tutorials for more details.
Note that the behavior of this class, like other functions/classes in this file, is not stable, since it is meant to represent the “common default behavior”. It is only guaranteed to work well with the standard models and training workflow in detectron2. To obtain more stable behavior, write your own training logic with other public APIs.
Examples:
trainer = DefaultTrainer(cfg)
trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
trainer.train()
-
scheduler
¶
-
checkpointer
¶
-
resume_or_load
(resume=True)[source]¶ If resume==True and cfg.OUTPUT_DIR contains the last checkpoint (defined by a last_checkpoint file), resume from that file. Resuming means loading all available states (e.g. optimizer and scheduler) and updating the iteration counter from the checkpoint. cfg.MODEL.WEIGHTS will not be used.
Otherwise, this is considered an independent training run. The method will load model weights from the file cfg.MODEL.WEIGHTS (but will not load other states) and start from iteration 0.
- Parameters
resume (bool) – whether to do resume or not
-
build_hooks
()[source]¶ Build a list of default hooks, including timing, evaluation, checkpointing, lr scheduling, precise BN, writing events.
- Returns
list[HookBase]
-
build_writers
()[source]¶ Build a list of writers to be used using
default_writers()
. If you’d like a different list of writers, you can overwrite it in your trainer.
- Returns
list[EventWriter] – a list of
EventWriter
objects.
-
train
()[source]¶ Run training.
- Returns
OrderedDict of results, if evaluation is enabled. Otherwise None.
-
classmethod
build_model
(cfg)[source]¶ - Returns
torch.nn.Module
It now calls
detectron2.modeling.build_model()
. Overwrite it if you’d like a different model.
-
classmethod
build_optimizer
(cfg, model)[source]¶ - Returns
torch.optim.Optimizer
It now calls
detectron2.solver.build_optimizer()
. Overwrite it if you’d like a different optimizer.
-
classmethod
build_lr_scheduler
(cfg, optimizer)[source]¶ It now calls
detectron2.solver.build_lr_scheduler()
. Overwrite it if you’d like a different scheduler.
-
classmethod
build_train_loader
(cfg)[source]¶ - Returns
iterable
It now calls
detectron2.data.build_detection_train_loader()
. Overwrite it if you’d like a different data loader.
-
classmethod
build_test_loader
(cfg, dataset_name)[source]¶ - Returns
iterable
It now calls
detectron2.data.build_detection_test_loader()
. Overwrite it if you’d like a different data loader.
-
classmethod
build_evaluator
(cfg, dataset_name)[source]¶ - Returns
DatasetEvaluator or None
It is not implemented by default.
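A common way to provide one, loosely following the official tutorials (the subclass name and output folder are illustrative):
import os
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator

class TrainerWithEval(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=output_folder)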
-
classmethod
test
(cfg, model, evaluators=None)[source]¶ - Parameters
cfg (CfgNode) –
model (nn.Module) –
evaluators (list[DatasetEvaluator] or None) – if None, will call
build_evaluator()
. Otherwise, must have the same length as cfg.DATASETS.TEST
.
- Returns
dict – a dict of result metrics
-
static
auto_scale_workers
(cfg, num_workers: int)[source]¶ When the config is defined for a certain number of workers (according to cfg.SOLVER.REFERENCE_WORLD_SIZE) that’s different from the number of workers currently in use, returns a new cfg where the total batch size is scaled so that the per-GPU batch size stays the same as the original IMS_PER_BATCH // REFERENCE_WORLD_SIZE.
Other config options are also scaled accordingly:
* training steps and warmup steps are scaled inversely proportionally.
* learning rate is scaled proportionally, following Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.
For example, with the original config like the following:
IMS_PER_BATCH: 16
BASE_LR: 0.1
REFERENCE_WORLD_SIZE: 8
MAX_ITER: 5000
STEPS: (4000,)
CHECKPOINT_PERIOD: 1000
When this config is used on 16 GPUs instead of the reference number 8, calling this method will return a new config with:
IMS_PER_BATCH: 32
BASE_LR: 0.2
REFERENCE_WORLD_SIZE: 16
MAX_ITER: 2500
STEPS: (2000,)
CHECKPOINT_PERIOD: 500
Note that both the original config and this new config can be trained on 16 GPUs. It’s up to the user whether to enable this feature (by setting
REFERENCE_WORLD_SIZE
).- Returns
CfgNode – a new config. Same as original if
cfg.SOLVER.REFERENCE_WORLD_SIZE==0
.
-
property
data_loader
¶
-
property
model
¶
-
property
optimizer
¶
detectron2.engine.hooks module¶
-
class
detectron2.engine.hooks.
CallbackHook
(*, before_train=None, after_train=None, before_step=None, after_step=None)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
Create a hook using callback functions provided by the user.
-
class
detectron2.engine.hooks.
IterationTimer
(warmup_iter=3)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
Track the time spent for each iteration (each run_step call in the trainer). Print a summary at the end of training.
This hook uses the time between the calls to its before_step() and after_step() methods. Under the convention that before_step() of all hooks should only take a negligible amount of time, the IterationTimer hook should be placed at the beginning of the list of hooks to obtain accurate timing.
-
class
detectron2.engine.hooks.
PeriodicWriter
(writers, period=20)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
Write events to EventStorage (by calling
writer.write()
) periodically.
It is executed every period iterations and after the last iteration. Note that period
does not affect how data is smoothed by each writer.
-
class
detectron2.engine.hooks.
PeriodicCheckpointer
(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]¶ Bases:
fvcore.common.checkpoint.PeriodicCheckpointer
,detectron2.engine.train_loop.HookBase
Same as
detectron2.checkpoint.PeriodicCheckpointer
, but as a hook.
Note that when used as a hook, it is unable to save additional data other than what’s defined by the given checkpointer.
It is executed every
period
iterations and after the last iteration.
-
class
detectron2.engine.hooks.
LRScheduler
(optimizer=None, scheduler=None)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
A hook which executes a torch builtin LR scheduler and summarizes the LR. It is executed after every iteration.
-
__init__
(optimizer=None, scheduler=None)[source]¶ - Parameters
optimizer (torch.optim.Optimizer) –
scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler) – if a
ParamScheduler
object, it defines the multiplier over the base LR in the optimizer.
If any argument is not given, will try to obtain it from the trainer.
-
property
scheduler
¶
-
-
class
detectron2.engine.hooks.
AutogradProfiler
(enable_predicate, output_dir, *, use_cuda=True)[source]¶ Bases:
detectron2.engine.hooks.TorchProfiler
A hook which runs torch.autograd.profiler.profile.
Examples:
hooks.AutogradProfiler( lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR )
The above example will run the profiler for iterations 10-20 and dump results to OUTPUT_DIR. We did not profile the first few iterations because they are typically slower than the rest. The result files can be loaded in the chrome://tracing page in the Chrome browser.
Note
When used together with NCCL on older versions of GPUs, the autograd profiler may cause a deadlock because it unnecessarily allocates memory on every device it sees. The memory management calls, if interleaved with NCCL calls, lead to deadlock on GPUs that do not support
cudaLaunchCooperativeKernelMultiDevice
.-
__init__
(enable_predicate, output_dir, *, use_cuda=True)[source]¶ - Parameters
enable_predicate (callable[trainer -> bool]) – a function which takes a trainer, and returns whether to enable the profiler. It will be called once every step, and can be used to select which steps to profile.
output_dir (str) – the output directory to dump tracing files.
use_cuda (bool) – same as in torch.autograd.profiler.profile.
-
-
class
detectron2.engine.hooks.
EvalHook
(eval_period, eval_function)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
Run an evaluation function periodically, and at the end of training.
It is executed every
eval_period
iterations and after the last iteration.-
__init__
(eval_period, eval_function)[source]¶ - Parameters
eval_period (int) – the period to run eval_function. Set to 0 to not evaluate periodically (but still after the last iteration).
eval_function (callable) – a function which takes no arguments, and returns a nested dict of evaluation metrics.
Note
This hook must be enabled in either all workers or none. If you would like only certain workers to perform evaluation, give the other workers a no-op function (eval_function=lambda: None).
-
-
class
detectron2.engine.hooks.
PreciseBN
(period, model, data_loader, num_iter)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
The standard implementation of BatchNorm uses EMA in inference, which is sometimes suboptimal. This class computes the true average of statistics rather than the moving average, and puts the true averages into every BN layer in the given model.
It is executed every
period
iterations and after the last iteration.-
__init__
(period, model, data_loader, num_iter)[source]¶ - Parameters
period (int) – the period this hook is run, or 0 to not run during training. The hook will always run in the end of training.
model (nn.Module) – a module whose all BN layers in training mode will be updated by precise BN. Note that user is responsible for ensuring the BN layers to be updated are in training mode when this hook is triggered.
data_loader (iterable) – it will produce data to be run by model(data).
num_iter (int) – number of iterations used to compute the precise statistics.
-
-
class
detectron2.engine.hooks.
TorchProfiler
(enable_predicate, output_dir, *, activities=None, save_tensorboard=True)[source]¶ Bases:
detectron2.engine.train_loop.HookBase
A hook which runs torch.profiler.profile.
Examples:
hooks.TorchProfiler( lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR )
The above example will run the profiler for iterations 10-20 and dump results to OUTPUT_DIR. We did not profile the first few iterations because they are typically slower than the rest. The result files can be loaded in the chrome://tracing page in the Chrome browser, and the tensorboard visualizations can be viewed using tensorboard --logdir OUTPUT_DIR/log
-
__init__
(enable_predicate, output_dir, *, activities=None, save_tensorboard=True)[source]¶ - Parameters
enable_predicate (callable[trainer -> bool]) – a function which takes a trainer, and returns whether to enable the profiler. It will be called once every step, and can be used to select which steps to profile.
output_dir (str) – the output directory to dump tracing files.
activities (iterable) – same as in torch.profiler.profile.
save_tensorboard (bool) – whether to save tensorboard visualizations at (output_dir)/log/
-
detectron2.evaluation¶
-
class
detectron2.evaluation.
CityscapesInstanceEvaluator
(dataset_name)[source]¶ Bases:
detectron2.evaluation.cityscapes_evaluation.CityscapesEvaluator
Evaluate instance segmentation results on cityscapes dataset using cityscapes API.
Note
It does not work in multi-machine distributed training.
It contains a synchronization, therefore has to be used on all ranks.
Only the main process runs evaluation.
-
class
detectron2.evaluation.
CityscapesSemSegEvaluator
(dataset_name)[source]¶ Bases:
detectron2.evaluation.cityscapes_evaluation.CityscapesEvaluator
Evaluate semantic segmentation results on cityscapes dataset using cityscapes API.
Note
It does not work in multi-machine distributed training.
It contains a synchronization, therefore has to be used on all ranks.
Only the main process runs evaluation.
-
class
detectron2.evaluation.
COCOEvaluator
(dataset_name, tasks=None, distributed=True, output_dir=None, *, use_fast_impl=True, kpt_oks_sigmas=())[source]¶ Bases:
detectron2.evaluation.evaluator.DatasetEvaluator
Evaluate AR for object proposals, AP for instance detection/segmentation, AP for keypoint detection outputs using COCO’s metrics. See http://cocodataset.org/#detection-eval and http://cocodataset.org/#keypoints-eval to understand its metrics. The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means the metric cannot be computed (e.g. due to no predictions made).
In addition to COCO, this evaluator is able to support any bounding box detection, instance segmentation, or keypoint detection dataset.
-
__init__
(dataset_name, tasks=None, distributed=True, output_dir=None, *, use_fast_impl=True, kpt_oks_sigmas=())[source]¶ - Parameters
dataset_name (str) –
name of the dataset to be evaluated. It must either have the following corresponding metadata:
”json_file”: the path to the COCO format annotation
Or it must be in detectron2’s standard dataset format so it can be converted to COCO format automatically.
tasks (tuple[str]) – tasks that can be evaluated under the given configuration. A task is one of “bbox”, “segm”, “keypoints”. By default, will infer this automatically from predictions.
distributed (True) – if True, will collect results from all ranks and run evaluation in the main process. Otherwise, will only evaluate the results in the current process.
output_dir (str) –
optional, an output directory to dump all results predicted on the dataset. The dump contains two files:
”instances_predictions.pth” a file that can be loaded with torch.load and contains all the results in the format they are produced by the model.
”coco_instances_results.json” a json file in COCO’s result format.
use_fast_impl (bool) – use a fast but unofficial implementation to compute AP. Although the results should be very close to the official implementation in COCO API, it is still recommended to compute results with the official API for use in papers. The faster implementation also uses more RAM.
kpt_oks_sigmas (list[float]) – The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval When empty, it will use the defaults in COCO. Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
-
process
(inputs, outputs)[source]¶ - Parameters
inputs – the inputs to a COCO model (e.g., GeneralizedRCNN). It is a list of dict. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”, “image_id”.
outputs – the outputs of a COCO model. It is a list of dicts with key “instances” that contains
Instances
.
-
-
class
detectron2.evaluation.
RotatedCOCOEvaluator
(dataset_name, tasks=None, distributed=True, output_dir=None, *, use_fast_impl=True, kpt_oks_sigmas=())[source]¶ Bases:
detectron2.evaluation.coco_evaluation.COCOEvaluator
Evaluate object proposal/instance detection outputs using COCO-like metrics and APIs, with rotated boxes support. Note: this uses IOU only and does not consider angle differences.
-
process
(inputs, outputs)[source]¶ - Parameters
inputs – the inputs to a COCO model (e.g., GeneralizedRCNN). It is a list of dict. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”, “image_id”.
outputs – the outputs of a COCO model. It is a list of dicts with key “instances” that contains
Instances
.
-
-
class
detectron2.evaluation.
DatasetEvaluator
[source]¶ Bases:
object
Base class for a dataset evaluator.
The function
inference_on_dataset()
runs the model over all samples in the dataset, and has a DatasetEvaluator process the inputs/outputs.
This class will accumulate information about the inputs/outputs (by process()), and produce evaluation results in the end (by evaluate()).
reset
()[source]¶ Preparation for a new round of evaluation. Should be called before starting a round of evaluation.
-
process
(inputs, outputs)[source]¶ Process the pair of inputs and outputs. If they contain batches, the pairs can be consumed one-by-one using zip:
for input_, output in zip(inputs, outputs):
    # do evaluation on single input/output pair
    ...
-
evaluate
()[source]¶ Evaluate/summarize the performance, after processing all input/output pairs.
- Returns
dict – A new evaluator class can return a dict of arbitrary format as long as the user can process the results. In our train_net.py, we expect the following format:
key: the name of the task (e.g., bbox)
value: a dict of {metric name: score}, e.g.: {“AP50”: 80}
-
-
class
detectron2.evaluation.
DatasetEvaluators
(evaluators)[source]¶ Bases:
detectron2.evaluation.evaluator.DatasetEvaluator
Wrapper class to combine multiple
DatasetEvaluator
instances.
This class dispatches every evaluation call to all of its
DatasetEvaluator
.
-
detectron2.evaluation.
inference_context
(model)[source]¶ A context where the model is temporarily changed to eval mode, and restored to previous mode afterwards.
- Parameters
model – a torch Module
-
detectron2.evaluation.
inference_on_dataset
(model, data_loader, evaluator: Optional[Union[detectron2.evaluation.evaluator.DatasetEvaluator, List[detectron2.evaluation.evaluator.DatasetEvaluator]]])[source]¶ Run model on the data_loader and evaluate the metrics with evaluator. Also benchmark the inference speed of model.__call__ accurately. The model will be used in eval mode.
- Parameters
model (callable) –
a callable which takes an object from data_loader and returns some outputs.
If it’s an nn.Module, it will be temporarily set to eval mode. If you wish to evaluate a model in training mode instead, you can wrap the given model and override its behavior of .eval() and .train().
data_loader – an iterable object with a length. The elements it generates will be the inputs to the model.
evaluator – the evaluator(s) to run. Use None if you only want to benchmark, but don’t want to do any evaluation.
- Returns
The return value of evaluator.evaluate()
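A sketch of running evaluation manually, assuming cfg and model are already built (e.g. as in the DefaultTrainer / build_model sections above) and that the dataset name is registered; the names below are illustrative:
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

val_loader = build_detection_test_loader(cfg, "coco_2017_val")
evaluator = COCOEvaluator("coco_2017_val", output_dir="./output")
results = inference_on_dataset(model, val_loader, evaluator)   # dict of result metrics
print(results)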
-
class
detectron2.evaluation.
LVISEvaluator
(dataset_name, tasks=None, distributed=True, output_dir=None)[source]¶ Bases:
detectron2.evaluation.evaluator.DatasetEvaluator
Evaluate object proposal and instance detection/segmentation outputs using LVIS’s metrics and evaluation API.
-
__init__
(dataset_name, tasks=None, distributed=True, output_dir=None)[source]¶ - Parameters
dataset_name (str) – name of the dataset to be evaluated. It must have the following corresponding metadata: “json_file”: the path to the LVIS format annotation
tasks (tuple[str]) – tasks that can be evaluated under the given configuration. A task is one of “bbox”, “segm”. By default, will infer this automatically from predictions.
distributed (True) – if True, will collect results from all ranks for evaluation. Otherwise, will evaluate the results in the current process.
output_dir (str) – optional, an output directory to dump results.
-
process
(inputs, outputs)[source]¶ - Parameters
inputs – the inputs to a LVIS model (e.g., GeneralizedRCNN). It is a list of dict. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”, “image_id”.
outputs – the outputs of a LVIS model. It is a list of dicts with key “instances” that contains
Instances
.
-
-
class
detectron2.evaluation.
COCOPanopticEvaluator
(dataset_name: str, output_dir: Optional[str] = None)[source]¶ Bases:
detectron2.evaluation.evaluator.DatasetEvaluator
Evaluate Panoptic Quality metrics on COCO using PanopticAPI. It saves panoptic segmentation predictions in output_dir.
It contains a synchronize call and has to be called from all workers.
-
class
detectron2.evaluation.
PascalVOCDetectionEvaluator
(dataset_name)[source]¶ Bases:
detectron2.evaluation.evaluator.DatasetEvaluator
Evaluate Pascal VOC style AP for Pascal VOC dataset. It contains a synchronization, therefore has to be called from all ranks.
Note that the concept of AP can be implemented in different ways and may not produce identical results. This class mimics the implementation of the official Pascal VOC Matlab API, and should produce similar but not identical results to the official API.
-
class
detectron2.evaluation.
SemSegEvaluator
(dataset_name, distributed=True, output_dir=None, *, num_classes=None, ignore_label=None)[source]¶ Bases:
detectron2.evaluation.evaluator.DatasetEvaluator
Evaluate semantic segmentation metrics.
-
__init__
(dataset_name, distributed=True, output_dir=None, *, num_classes=None, ignore_label=None)[source]¶ - Parameters
dataset_name (str) – name of the dataset to be evaluated.
distributed (bool) – if True, will collect results from all ranks for evaluation. Otherwise, will evaluate the results in the current process.
output_dir (str) – an output directory to dump results.
num_classes – deprecated argument
ignore_label – deprecated argument
-
process
(inputs, outputs)[source]¶ - Parameters
inputs – the inputs to a model. It is a list of dicts. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”.
outputs – the outputs of a model. It is either list of semantic segmentation predictions (Tensor [H, W]) or list of dicts with key “sem_seg” that contains semantic segmentation prediction in the same format.
-
evaluate
()[source]¶ Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
Mean intersection-over-union averaged across classes (mIoU)
Frequency Weighted IoU (fwIoU)
Mean pixel accuracy averaged across classes (mACC)
Pixel Accuracy (pACC)
-
encode_json_sem_seg
(sem_seg, input_file_name)[source]¶ Convert semantic segmentation to COCO stuff format with segments encoded as RLEs. See http://cocodataset.org/#format-results
-
detectron2.layers¶
-
class
detectron2.layers.
FrozenBatchNorm2d
(num_features, eps=1e-05)[source]¶ Bases:
torch.nn.Module
BatchNorm2d where the batch statistics and the affine parameters are fixed.
It contains non-trainable buffers called “weight”, “bias”, “running_mean”, and “running_var”, initialized to perform the identity transformation.
The pre-trained backbone models from Caffe2 only contain “weight” and “bias”, which are computed from the original four parameters of BN. The affine transform x * weight + bias will perform the equivalent computation of (x - running_mean) / sqrt(running_var) * weight + bias. When loading a backbone model from Caffe2, “running_mean” and “running_var” will be left unchanged as identity transformation.
Other pre-trained backbone models may contain all 4 parameters.
The forward is implemented by F.batch_norm(…, training=False).
-
classmethod
convert_frozen_batchnorm
(module)[source]¶ Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
- Parameters
module (torch.nn.Module) –
- Returns
If module is BatchNorm/SyncBatchNorm, returns a new module. Otherwise, in-place convert module and return it.
Similar to convert_sync_batchnorm in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
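A minimal sketch; the torchvision backbone is used only for illustration:
import torch
from torchvision.models import resnet18
from detectron2.layers import FrozenBatchNorm2d

model = resnet18()
model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)   # all BN layers replaced
# batch statistics and affine parameters are now fixed buffers
assert not any(isinstance(m, torch.nn.BatchNorm2d) for m in model.modules())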
-
-
detectron2.layers.
get_norm
(norm, out_channels)[source]¶ - Parameters
norm (str or callable) – either one of BN, SyncBN, FrozenBN, GN; or a callable that takes a channel number and returns the normalization layer as a nn.Module.
- Returns
nn.Module or None – the normalization layer
-
class
detectron2.layers.
NaiveSyncBatchNorm
(*args, stats_mode='', **kwargs)[source]¶ Bases:
torch.nn.BatchNorm2d
In PyTorch<=1.5,
nn.SyncBatchNorm
has incorrect gradients when the batch size on each worker is different (e.g., when scale augmentation is used, or when it is applied to a mask head).
This is a slower but correct alternative to nn.SyncBatchNorm.
Note
There isn’t a single definition of Sync BatchNorm.
When
stats_mode==""
, this module computes overall statistics by using the statistics of each worker with equal weight. The result is the true statistics of all samples (as if they were all on one worker) only when all workers have the same (N, H, W). This mode does not support inputs with zero batch size.
When stats_mode=="N", this module computes overall statistics by weighting the statistics of each worker by their N. The result is the true statistics of all samples (as if they were all on one worker) only when all workers have the same (H, W). It is slower than stats_mode=="".
Even though the result of this module may not be the true statistics of all samples, it may still be reasonable because it might be preferable to assign equal weights to all workers, regardless of their (H, W) dimension, instead of putting larger weight on larger images. From preliminary experiments, little difference is found between such a simplified implementation and an accurate computation of overall mean & variance.
-
class
detectron2.layers.
DeformConv
(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=False, norm=None, activation=None)[source]¶ Bases:
torch.nn.Module
-
__init__
(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=False, norm=None, activation=None)[source]¶ Deformable convolution from Deformable Convolutional Networks.
Arguments are similar to
Conv2D
. Extra arguments:- Parameters
deformable_groups (int) – number of groups used in deformable convolution.
norm (nn.Module, optional) – a normalization layer
activation (callable(Tensor) -> Tensor) – a callable activation function
-
-
class
detectron2.layers.
ModulatedDeformConv
(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=True, norm=None, activation=None)[source]¶ Bases:
torch.nn.Module
-
__init__
(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, bias=True, norm=None, activation=None)[source]¶ Modulated deformable convolution from Deformable ConvNets v2: More Deformable, Better Results.
Arguments are similar to
Conv2D
. Extra arguments:- Parameters
deformable_groups (int) – number of groups used in deformable convolution.
norm (nn.Module, optional) – a normalization layer
activation (callable(Tensor) -> Tensor) – a callable activation function
-
-
detectron2.layers.
paste_masks_in_image
(masks: torch.Tensor, boxes: detectron2.structures.Boxes, image_shape: Tuple[int, int], threshold: float = 0.5)[source]¶ Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image. The location, height, and width for pasting each mask is determined by their corresponding bounding boxes in boxes.
Note
This is a complicated but more accurate implementation. In actual deployment, it is often enough to use a faster but less accurate implementation. See
paste_mask_in_image_old()
in this file for an alternative implementation.- Parameters
masks (tensor) – Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of detected object instances in the image and Hmask, Wmask are the mask height and mask width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
boxes (Boxes or Tensor) – A Boxes of length Bimg or Tensor of shape (Bimg, 4). boxes[i] and masks[i] correspond to the same object instance.
image_shape (tuple) – height, width
threshold (float) – A threshold in [0, 1] for converting the (soft) masks to binary masks.
- Returns
img_masks (Tensor) – A tensor of shape (Bimg, Himage, Wimage), where Bimg is the number of detected object instances and Himage, Wimage are the image height and width. img_masks[i] is a binary mask for object instance i.
-
detectron2.layers.
nms
(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) → torch.Tensor[source]¶ Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).
NMS iteratively removes lower scoring boxes which have an IoU greater than iou_threshold with another (higher scoring) box.
If multiple boxes have the exact same score and satisfy the IoU criterion with respect to a reference box, the selected box is not guaranteed to be the same between CPU and GPU. This is similar to the behavior of argsort in PyTorch when repeated values are present.
- Parameters
boxes (Tensor[N, 4]) – boxes to perform NMS on. They are expected to be in (x1, y1, x2, y2) format
scores (Tensor[N]) – scores for each one of the boxes
iou_threshold (float) – discards all overlapping boxes with IoU > iou_threshold
- Returns
keep (Tensor) – int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores
-
detectron2.layers.
batched_nms
(boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float)[source]¶ Same as torchvision.ops.boxes.batched_nms, but safer.
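A small sketch of both functions; the boxes, scores, and threshold are illustrative:
import torch
from detectron2.layers import batched_nms, nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])

keep = nms(boxes, scores, iou_threshold=0.5)     # the two overlapping boxes collapse to one
classes = torch.tensor([0, 1, 0])
keep_per_class = batched_nms(boxes, scores, classes, iou_threshold=0.5)  # NMS within each class only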
-
detectron2.layers.
batched_nms_rotated
(boxes, scores, idxs, iou_threshold)[source]¶ Performs non-maximum suppression in a batched fashion.
Each index value corresponds to a category, and NMS will not be applied between elements of different categories.
- Parameters
boxes (Tensor[N, 5]) – boxes where NMS will be performed. They are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
scores (Tensor[N]) – scores for each one of the boxes
idxs (Tensor[N]) – indices of the categories for each one of the boxes.
iou_threshold (float) – discards all overlapping boxes with IoU > iou_threshold
- Returns
Tensor – int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores
-
detectron2.layers.
nms_rotated
(boxes, scores, iou_threshold)[source]¶ Performs non-maximum suppression (NMS) on the rotated boxes according to their intersection-over-union (IoU).
Rotated NMS iteratively removes lower scoring rotated boxes which have an IoU greater than iou_threshold with another (higher scoring) rotated box.
Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they can be representing completely different objects in certain tasks, e.g., OCR.
As for the question of whether rotated-NMS should treat them as faraway boxes even though their IOU is 1, it depends on the application and/or ground truth annotation.
As an extreme example, consider a single character v and the square box around it.
If the angle is 0 degree, the object (text) would be read as ‘v’;
If the angle is 90 degrees, the object (text) would become ‘>’;
If the angle is 180 degrees, the object (text) would become ‘^’;
If the angle is 270/-90 degrees, the object (text) would become ‘<’
All of these cases have IoU of 1 to each other, and rotated NMS that only uses IoU as its criterion would keep only the one with the highest score - which, practically, still makes sense in most cases because typically only one of these orientations is the correct one. Also, it does not matter as much if the box is only used to classify the object (instead of transcribing it with a sequential OCR recognition model) later.
On the other hand, when we use IoU to filter proposals that are close to the ground truth during training, we should definitely take the angle into account if we know the ground truth is labeled with the strictly correct orientation (as in, upside-down words are annotated with -180 degrees even though they can be covered with a 0/90/-90 degree box, etc.)
The way the original dataset is annotated also matters. For example, if the dataset is a 4-point polygon dataset that does not enforce ordering of vertices/orientation, we can estimate a minimum rotated bounding box for this polygon, but there is no way we can tell the correct angle with 100% confidence (as shown above, there could be 4 different rotated boxes, with angles differing by 90 degrees from each other, covering exactly the same region). In that case we have to just use IoU to determine the box proximity (as many detection benchmarks, even for text, do) unless there are other assumptions we can make (like width is always larger than height, or the object is not rotated by more than 90 degrees CCW/CW, etc.)
In summary, not considering angles in rotated NMS seems to be a good option for now, but we should be aware of its implications.
- Parameters
boxes (Tensor[N, 5]) – Rotated boxes to perform NMS on. They are expected to be in (x_center, y_center, width, height, angle_degrees) format.
scores (Tensor[N]) – Scores for each one of the rotated boxes
iou_threshold (float) – Discards all overlapping rotated boxes with IoU > iou_threshold
- Returns
keep (Tensor) – int64 tensor with the indices of the elements that have been kept by Rotated NMS, sorted in decreasing order of scores
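Example (an illustrative sketch of the behavior discussed above; it assumes detectron2 was built with its C++/CUDA extensions, which provide nms_rotated):
import torch
from detectron2.layers import nms_rotated

# (x_center, y_center, width, height, angle_degrees); the first two boxes cover
# the same region with angles differing by 180 degrees, so their IoU is 1
boxes = torch.tensor([[5., 3., 4., 2., -90.],
                      [5., 3., 4., 2., 90.],
                      [20., 20., 4., 2., 45.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms_rotated(boxes, scores, iou_threshold=0.5)  # keeps boxes 0 and 2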
-
detectron2.layers.
roi_align
(input: torch.Tensor, boxes: torch.Tensor, output_size, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) → torch.Tensor[source]¶ Performs Region of Interest (RoI) Align operator described in Mask R-CNN
- Parameters
input (Tensor[N, C, H, W]) – input tensor
boxes (Tensor[K, 5] or List[Tensor[L, 4]]) – the box coordinates in (x1, y1, x2, y2) format where the regions will be taken from. If a single Tensor is passed, then the first column should contain the batch index. If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i in a batch
output_size (int or Tuple[int, int]) – the size of the output after the cropping is performed, as (height, width)
spatial_scale (float) – a scaling factor that maps the input coordinates to the box coordinates. Default: 1.0
sampling_ratio (int) – number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. If > 0, then exactly sampling_ratio x sampling_ratio grid points are used. If <= 0, then an adaptive number of grid points are used (computed as ceil(roi_width / pooled_w), and likewise for height). Default: -1
aligned (bool) – If False, use the legacy implementation. If True, shift the box coordinates by -0.5 for a closer alignment with the two neighboring pixel indices. This version is used in Detectron2.
- Returns
output (Tensor[K, C, output_size[0], output_size[1]])
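Example (a minimal sketch; the box is given in image coordinates and spatial_scale maps it to the feature map, here assumed to have stride 4):
import torch
from detectron2.layers import roi_align

features = torch.randn(1, 256, 64, 64)          # (N, C, H, W) feature map
rois = torch.tensor([[0., 8., 8., 56., 56.]])   # (batch_index, x1, y1, x2, y2)
pooled = roi_align(features, rois, output_size=(7, 7),
                   spatial_scale=0.25, sampling_ratio=2, aligned=True)
# pooled: (1, 256, 7, 7)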
-
class
detectron2.layers.
ROIAlign
(output_size, spatial_scale, sampling_ratio, aligned=True)[source]¶ Bases:
torch.nn.Module
-
__init__
(output_size, spatial_scale, sampling_ratio, aligned=True)[source]¶ - Parameters
output_size (tuple) – h, w
spatial_scale (float) – scale the input boxes by this number
sampling_ratio (int) – number of input samples to take for each output sample. 0 to take samples densely.
aligned (bool) – if False, use the legacy implementation in Detectron. If True, align the results more perfectly.
Note
The meaning of aligned=True:
Given a continuous coordinate c, its two neighboring pixel indices (in our pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled from the underlying signal at continuous coordinates 0.5 and 1.5). But the original roi_align (aligned=False) does not subtract the 0.5 when computing neighboring pixel indices and therefore it uses pixels with a slightly incorrect alignment (relative to our pixel model) when performing bilinear interpolation.
With aligned=True, we first appropriately scale the ROI and then shift it by -0.5 prior to calling roi_align. This produces the correct neighbors; see detectron2/tests/test_roi_align.py for verification.
This difference does not affect the model’s performance if ROIAlign is used together with conv layers.
-
-
detectron2.layers.
roi_align_rotated
()¶
-
class
detectron2.layers.
ROIAlignRotated
(output_size, spatial_scale, sampling_ratio)[source]¶ Bases:
torch.nn.Module
-
__init__
(output_size, spatial_scale, sampling_ratio)[source]¶ - Parameters
output_size (tuple) – h, w
spatial_scale (float) – scale the input boxes by this number
sampling_ratio (int) – number of input samples to take for each output sample. 0 to take samples densely.
Note
ROIAlignRotated supports continuous coordinate by default: Given a continuous coordinate c, its two neighboring pixel indices (in our pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled from the underlying signal at continuous coordinates 0.5 and 1.5).
-
-
class
detectron2.layers.
ShapeSpec
(channels=None, height=None, width=None, stride=None)[source]¶ Bases:
detectron2.layers.shape_spec._ShapeSpec
A simple structure that contains basic shape specification about a tensor. It is often used as the auxiliary inputs/outputs of models, to complement the lack of shape inference ability among pytorch modules.
-
channels
¶
-
height
¶
-
width
¶
-
stride
¶
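Example (a small sketch of typical usage when describing a feature map to a model component):
from detectron2.layers import ShapeSpec

# describe a feature map with 256 channels and stride 4 relative to the input image;
# height/width are left as None because they are usually not needed
spec = ShapeSpec(channels=256, stride=4)
print(spec.channels, spec.stride)  # 256 4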
-
-
class
detectron2.layers.
BatchNorm2d
(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)¶ Bases:
torch.nn.modules.batchnorm._BatchNorm
Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]
The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the input size). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False).
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.
If track_running_stats is set to False, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well.
Note
This momentum argument is different from the one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.
Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) slices, it’s common terminology to call this Spatial Batch Normalization.
- Parameters
num_features – \(C\) from an expected input of size \((N, C, H, W)\)
eps – a value added to the denominator for numerical stability. Default: 1e-5
momentum – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine – a boolean value that when set to True, this module has learnable affine parameters. Default: True
track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and initializes the statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Default: True
- Shape:
Input: \((N, C, H, W)\)
Output: \((N, C, H, W)\) (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = nn.BatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm2d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45)
>>> output = m(input)
-
class
detectron2.layers.
Conv2d
(*args, **kwargs)[source]¶ Bases:
torch.nn.Conv2d
A wrapper around
torch.nn.Conv2d
to support empty inputs and more features.-
__init__
(*args, **kwargs)[source]¶ Extra keyword arguments supported in addition to those in torch.nn.Conv2d:
- Parameters
norm (nn.Module, optional) – a normalization layer
activation (callable(Tensor) -> Tensor) – a callable activation function
It assumes that the norm layer is used before the activation.
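Example (a small sketch of the extra arguments; it uses detectron2.layers.get_norm, which is referenced elsewhere on this page, to build the norm layer):
import torch
import torch.nn.functional as F
from detectron2.layers import Conv2d, get_norm

conv = Conv2d(3, 16, kernel_size=3, padding=1, bias=False,
              norm=get_norm("BN", 16), activation=F.relu)
y = conv(torch.randn(2, 3, 32, 32))  # conv -> BN -> ReLU, shape (2, 16, 32, 32)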
-
bias
: Optional[torch.Tensor]¶
-
weight
: torch.Tensor¶
-
-
class
detectron2.layers.
ConvTranspose2d
(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, bias: bool = True, dilation: int = 1, padding_mode: str = 'zeros')¶ Bases:
torch.nn.modules.conv._ConvTransposeNd
Applies a 2D transposed convolution operator over an input image composed of several input planes.
This module can be seen as the gradient of Conv2d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation).
This module supports TensorFloat32.
stride
controls the stride for the cross-correlation.padding
controls the amount of implicit zero-paddings on both sides fordilation * (kernel_size - 1) - padding
number of points. See note below for details.output_padding
controls the additional size added to one side of the output shape. See note below for details.dilation
controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of whatdilation
does.groups
controls the connections between inputs and outputs.in_channels
andout_channels
must both be divisible bygroups
. For example,At groups=1, all inputs are convolved to all outputs.
At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.
At groups=
in_channels
, each input channel is convolved with its own set of filters (of size \(\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor\)).
The parameters
kernel_size
,stride
,padding
,output_padding
can either be:a single
int
– in which case the same value is used for the height and width dimensionsa
tuple
of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension
Note
Depending on the size of your kernel, several (of the last) columns of the input might be lost, because it is a valid cross-correlation, and not a full cross-correlation. It is up to the user to add proper padding.
Note
The
padding
argument effectively adds dilation * (kernel_size - 1) - padding
amount of zero padding to both sides of the input. This is set so that when a Conv2d
and aConvTranspose2d
are initialized with same parameters, they are inverses of each other in regard to the input and output shapes. However, whenstride > 1
,Conv2d
maps multiple input shapes to the same output shape.output_padding
is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note thatoutput_padding
is only used to find output shape, but does not actually add zero-padding to output.Note
In some circumstances when using the CUDA backend with CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting
torch.backends.cudnn.deterministic = True
. Please see the notes on /notes/randomness for background.- Parameters
in_channels (int) – Number of channels in the input image
out_channels (int) – Number of channels produced by the convolution
kernel_size (int or tuple) – Size of the convolving kernel
stride (int or tuple, optional) – Stride of the convolution. Default: 1
padding (int or tuple, optional) –
dilation * (kernel_size - 1) - padding
zero-padding will be added to both sides of each dimension in the input. Default: 0
output_padding (int or tuple, optional) – Additional size added to one side of each dimension in the output shape. Default: 0
groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional) – If
True
, adds a learnable bias to the output. Default:True
dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1
- Shape:
Input: \((N, C_{in}, H_{in}, W_{in})\)
Output: \((N, C_{out}, H_{out}, W_{out})\) where
\[H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1\]
\[W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1\]
-
weight
¶ the learnable weights of the module of shape \((\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},\) \(\text{kernel\_size[0]}, \text{kernel\_size[1]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}\)
- Type
Tensor
-
bias
¶ the learnable bias of the module of shape (out_channels) If
bias
isTrue
, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}\)- Type
Tensor
Examples:
>>> # With square kernels and equal stride
>>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> input = torch.randn(20, 16, 50, 100)
>>> output = m(input)
>>> # exact output size can be also specified as an argument
>>> input = torch.randn(1, 16, 12, 12)
>>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
>>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12])
-
forward
(input: torch.Tensor, output_size: Optional[List[int]] = None) → torch.Tensor¶
-
bias
: Optional[torch.Tensor]¶
-
weight
: torch.Tensor¶
-
detectron2.layers.
cat
(tensors: List[torch.Tensor], dim: int = 0)[source]¶ Efficient version of torch.cat that avoids a copy if there is only a single element in a list
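A trivial example (same semantics as torch.cat; the single-element case is the one that avoids a copy):
import torch
from detectron2.layers import cat

x = cat([torch.zeros(2, 4), torch.ones(3, 4)], dim=0)  # shape (5, 4)
y = cat([torch.zeros(2, 4)], dim=0)                    # single element: returned without a copy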
-
detectron2.layers.
interpolate
(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)[source]¶ Down/up samples the input to either the given
size
or the givenscale_factor
The algorithm used for interpolation is determined by
mode
.Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape.
The input dimensions are interpreted in the form: mini-batch x channels x [optional depth] x [optional height] x width.
The modes available for resizing are: nearest, linear (3D-only), bilinear, bicubic (4D-only), trilinear (5D-only), area
- Parameters
input (Tensor) – the input tensor
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]) – output spatial size.
scale_factor (float or Tuple[float]) – multiplier for spatial size. Has to match input size if it is a tuple.
mode (str) – algorithm used for upsampling:
'nearest'
|'linear'
|'bilinear'
|'bicubic'
|'trilinear'
|'area'
. Default:'nearest'
align_corners (bool, optional) – Geometrically, we consider the pixels of the input and output as squares rather than points. If set to
True
, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels. If set toFalse
, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values, making this operation independent of input size whenscale_factor
is kept the same. This only has an effect whenmode
is'linear'
,'bilinear'
,'bicubic'
or'trilinear'
. Default:False
recompute_scale_factor (bool, optional) – recompute the scale_factor for use in the interpolation calculation. When scale_factor is passed as a parameter, it is used to compute the output_size. If recompute_scale_factor is
False
or not specified, the passed-in scale_factor will be used in the interpolation computation. Otherwise, a new scale_factor will be computed based on the output and input sizes for use in the interpolation computation (i.e. the computation will be identical to if the computed output_size were passed-in explicitly). Note that when scale_factor is floating-point, the recomputed scale_factor may differ from the one passed in due to rounding and precision issues.
Note
With
mode='bicubic'
, it’s possible to cause overshoot, in other words it can produce negative values or values greater than 255 for images. Explicitly callresult.clamp(min=0, max=255)
if you want to reduce the overshoot when displaying the image.Warning
With
align_corners = True
, the linearly interpolating modes (linear, bilinear, and trilinear) don’t proportionally align the output and input pixels, and thus the output values can depend on the input size. This was the default behavior for these modes up to version 0.3.1. Since then, the default behavior isalign_corners = False
. SeeUpsample
for concrete examples on how this affects the outputs.Warning
When scale_factor is specified, if recompute_scale_factor=True, scale_factor is used to compute the output_size which will then be used to infer new scales for the interpolation. The default behavior for recompute_scale_factor changed to False in 1.6.0, and scale_factor is used in the interpolation calculation.
Note
When using the CUDA backend, this operation may induce nondeterministic behaviour in its backward pass that is not easily switched off. Please see the notes on /notes/randomness for background.
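Example (a short sketch; the call signature mirrors torch.nn.functional.interpolate):
import torch
from detectron2.layers import interpolate

x = torch.randn(1, 3, 32, 32)
y = interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)
# y: (1, 3, 64, 64)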
-
class
detectron2.layers.
Linear
(in_features: int, out_features: int, bias: bool = True)¶ Bases:
torch.nn.Module
Applies a linear transformation to the incoming data: \(y = xA^T + b\)
This module supports TensorFloat32.
- Parameters
in_features – size of each input sample
out_features – size of each output sample
bias – If set to
False
, the layer will not learn an additive bias. Default:True
- Shape:
Input: \((N, *, H_{in})\) where \(*\) means any number of additional dimensions and \(H_{in} = \text{in\_features}\)
Output: \((N, *, H_{out})\) where all but the last dimension are the same shape as the input and \(H_{out} = \text{out\_features}\).
-
weight
¶ the learnable weights of the module of shape \((\text{out\_features}, \text{in\_features})\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{in\_features}}\)
-
bias
¶ the learnable bias of the module of shape \((\text{out\_features})\). If
bias
isTrue
, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{in\_features}}\)
Examples:
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
-
forward
(input: torch.Tensor) → torch.Tensor¶
-
weight
: torch.Tensor¶
-
detectron2.layers.
nonzero_tuple
(x)[source]¶ An ‘as_tuple=True’ version of torch.nonzero, to support torchscript (because of https://github.com/pytorch/pytorch/issues/38718).
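Example (a sketch; for a 1-D input the returned tuple contains a single index tensor):
import torch
from detectron2.layers import nonzero_tuple

mask = torch.tensor([0, 3, 0, 5]) > 0
(idx,) = nonzero_tuple(mask)  # tensor([1, 3])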
-
detectron2.layers.
cross_entropy
(input, target, *, reduction='mean', **kwargs)[source]¶ Same as torch.nn.functional.cross_entropy, but returns 0 (instead of nan) for empty inputs.
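Example (a sketch of the empty-input behavior this wrapper exists for):
import torch
from detectron2.layers import cross_entropy

logits = torch.empty(0, 80)              # no proposals in the batch
targets = torch.empty(0, dtype=torch.long)
loss = cross_entropy(logits, targets)    # 0.0 instead of nan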
-
class
detectron2.layers.
CNNBlockBase
(in_channels, out_channels, stride)[source]¶ Bases:
torch.nn.Module
A CNN block is assumed to have input channels, output channels and a stride. The input and output of the forward() method must be NCHW tensors. The method can perform arbitrary computation but must match the given channels and stride specification.
- Attribute:
in_channels (int)
out_channels (int)
stride (int)
-
__init__
(in_channels, out_channels, stride)[source]¶ The __init__ method of any subclass should also contain these arguments.
-
class
detectron2.layers.
DepthwiseSeparableConv2d
(in_channels, out_channels, kernel_size=3, padding=1, dilation=1, *, norm1=None, activation1=None, norm2=None, activation2=None)[source]¶ Bases:
torch.nn.Module
A kxk depthwise convolution + a 1x1 convolution.
In Xception: Deep Learning with Depthwise Separable Convolutions, norm & activation are applied on the second conv. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications uses norm & activation on both convs.
-
__init__
(in_channels, out_channels, kernel_size=3, padding=1, dilation=1, *, norm1=None, activation1=None, norm2=None, activation2=None)[source]¶ - Parameters
norm1 (str or callable) – normalization for the first (depthwise) conv layer.
norm2 (str or callable) – normalization for the second (1x1) conv layer.
activation1 (callable(Tensor) -> Tensor) – activation function for the first (depthwise) conv layer.
activation2 (callable(Tensor) -> Tensor) – activation function for the second (1x1) conv layer.
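An illustrative construction (a sketch only; norm and activation are applied on the second, pointwise conv, in the Xception style mentioned above):
import torch
from detectron2.layers import DepthwiseSeparableConv2d

block = DepthwiseSeparableConv2d(
    64, 128, kernel_size=3, padding=1,
    norm1=None, activation1=None,            # no norm/activation on the depthwise conv
    norm2="BN", activation2=torch.nn.ReLU()  # norm & activation on the 1x1 conv
)
y = block(torch.randn(2, 64, 32, 32))  # (2, 128, 32, 32)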
-
-
class
detectron2.layers.
ASPP
(in_channels, out_channels, dilations, *, norm, activation, pool_kernel_size=None, dropout: float = 0.0, use_depthwise_separable_conv=False)[source]¶ Bases:
torch.nn.Module
Atrous Spatial Pyramid Pooling (ASPP).
-
__init__
(in_channels, out_channels, dilations, *, norm, activation, pool_kernel_size=None, dropout: float = 0.0, use_depthwise_separable_conv=False)[source]¶ - Parameters
in_channels (int) – number of input channels for ASPP.
out_channels (int) – number of output channels.
dilations (list) – a list of 3 dilations in ASPP.
norm (str or callable) – normalization for all conv layers. See
layers.get_norm()
for supported format. norm is applied to all conv layers except the conv following global average pooling.activation (callable) – activation function.
pool_kernel_size (tuple, list) – the average pooling size (kh, kw) for image pooling layer in ASPP. If set to None, it always performs global average pooling. If not None, it must be divisible by the shape of inputs in forward(). It is recommended to use a fixed input feature size in training, and set this option to match this size, so that it performs global average pooling in training, and the size of the pooling window stays consistent in inference.
dropout (float) – apply dropout on the output of ASPP. It is used in the official DeepLab implementation with a rate of 0.1: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa
use_depthwise_separable_conv (bool) – use DepthwiseSeparableConv2d for 3x3 convs in ASPP, proposed in Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
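A construction sketch (illustrative only; with pool_kernel_size=None the image-pooling branch uses global average pooling):
import torch
from detectron2.layers import ASPP

aspp = ASPP(in_channels=256, out_channels=256, dilations=[6, 12, 18],
            norm="BN", activation=torch.nn.ReLU(), dropout=0.1)
y = aspp(torch.randn(1, 256, 32, 32))  # (1, 256, 32, 32)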
-
detectron2.model_zoo¶
Model Zoo API for Detectron2: a collection of functions to create common model architectures listed in MODEL_ZOO.md, and optionally load their pre-trained weights.
-
detectron2.model_zoo.
get_checkpoint_url
(config_path)[source]¶ Returns the URL to the model trained using the given config
- Parameters
config_path (str) – config file name relative to detectron2’s “configs/” directory, e.g., “COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml”
- Returns
str – a URL to the model
-
detectron2.model_zoo.
get
(config_path, trained: bool = False, device: Optional[str] = None)[source]¶ Get a model specified by relative path under Detectron2’s official
configs/
directory.
- Parameters
config_path (str) – config file name relative to detectron2’s “configs/” directory, e.g., “COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml”
trained (bool) – see
get_config()
.device (str or None) – overwrite the device in config, if given.
- Returns
nn.Module – a detectron2 model. Will be in training mode.
Example:
from detectron2 import model_zoo
model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
-
detectron2.model_zoo.
get_config_file
(config_path)[source]¶ Returns path to a builtin config file.
- Parameters
config_path (str) – config file name relative to detectron2’s “configs/” directory, e.g., “COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml”
- Returns
str – the real path to the config file.
-
detectron2.model_zoo.
get_config
(config_path, trained: bool = False)[source]¶ Returns a config object for a model in model zoo.
- Parameters
config_path (str) – config file name relative to detectron2’s “configs/” directory, e.g., “COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml”
trained (bool) – If True, will set
MODEL.WEIGHTS
to trained model zoo weights. If False, the checkpoint specified in the config file’sMODEL.WEIGHTS
is used instead; this will typically (though not always) initialize a subset of weights using an ImageNet pre-trained model, while randomly initializing the other weights.
- Returns
CfgNode or omegaconf.DictConfig – a config object
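Example (a sketch mirroring the get() example above):
from detectron2 import model_zoo

cfg = model_zoo.get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
print(cfg.MODEL.WEIGHTS)  # URL of the trained model zoo weights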
detectron2.modeling¶
-
detectron2.modeling.
build_anchor_generator
(cfg, input_shape)¶ Build an anchor generator from cfg.MODEL.ANCHOR_GENERATOR.NAME.
-
class
detectron2.modeling.
FPN
(bottom_up, in_features, out_channels, norm='', top_block=None, fuse_type='sum')¶ Bases:
detectron2.modeling.Backbone
This module implements Feature Pyramid Networks for Object Detection. It creates pyramid features built on top of some input feature maps.
-
__init__
(bottom_up, in_features, out_channels, norm='', top_block=None, fuse_type='sum')¶ - Parameters
bottom_up (Backbone) – module representing the bottom up subnetwork. Must be a subclass of
Backbone
. The multi-scale feature maps generated by the bottom up network, and listed in in_features, are used to generate FPN levels.in_features (list[str]) – names of the input feature maps coming from the backbone to which FPN is attached. For example, if the backbone produces [“res2”, “res3”, “res4”], any contiguous sublist of these may be used; order must be from high to low resolution.
out_channels (int) – number of channels in the output feature maps.
norm (str) – the normalization to use.
top_block (nn.Module or None) – if provided, an extra operation will be performed on the output of the last (smallest resolution) FPN output, and the result will extend the result list. The top_block further downsamples the feature map. It must have an attribute “num_levels”, meaning the number of extra FPN levels added by this block, and “in_feature”, which is a string representing its input feature (e.g., p5).
fuse_type (str) – types for fusing the top down features and the lateral ones. It can be “sum” (default), which sums up element-wise; or “avg”, which takes the element-wise mean of the two.
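An illustrative construction on top of a ResNet bottom-up (a sketch only; it assumes the default config, that LastLevelMaxPool lives in detectron2.modeling.backbone.fpn, and configures the ResNet to output all four residual stages):
from detectron2.config import get_cfg
from detectron2.layers import ShapeSpec
from detectron2.modeling import FPN, build_resnet_backbone
from detectron2.modeling.backbone.fpn import LastLevelMaxPool

cfg = get_cfg()
cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
bottom_up = build_resnet_backbone(cfg, ShapeSpec(channels=3))
fpn = FPN(bottom_up=bottom_up, in_features=["res2", "res3", "res4", "res5"],
          out_channels=256, norm="", top_block=LastLevelMaxPool(), fuse_type="sum")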
-
forward
(x)¶ - Parameters
input (dict[str->Tensor]) – mapping feature map name (e.g., “res5”) to feature map tensor for each feature level in high to low resolution order.
- Returns
dict[str->Tensor] – mapping from feature map name to FPN feature map tensor in high to low resolution order. Returned feature names follow the FPN paper convention: “p<stage>”, where stage has stride = 2 ** stage e.g., [“p2”, “p3”, …, “p6”].
-
output_shape
()¶
-
property
size_divisibility
¶
-
-
class
detectron2.modeling.
Backbone
¶ Bases:
torch.nn.Module
Abstract base class for network backbones.
-
__init__
()¶ The __init__ method of any subclass can specify its own set of arguments.
-
abstract
forward
()¶ Subclasses must override this method, but adhere to the same return type.
- Returns
dict[str->Tensor] – mapping from feature name (e.g., “res2”) to tensor
-
output_shape
()¶ - Returns
dict[str->ShapeSpec]
-
property
size_divisibility
¶ Some backbones require the input height and width to be divisible by a specific integer. This is typically true for encoder / decoder type networks with lateral connection (e.g., FPN) for which feature maps need to match dimension in the “bottom up” and “top down” paths. Set to 0 if no specific input size divisibility is required.
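A toy subclass sketch (illustrative only) that satisfies the interface described above:
import torch
from detectron2.layers import ShapeSpec
from detectron2.modeling import Backbone

class ToyBackbone(Backbone):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)

    def forward(self, x):
        # return a dict mapping feature name to tensor, as required by Backbone
        return {"res2": self.conv(x)}

    def output_shape(self):
        return {"res2": ShapeSpec(channels=64, stride=4)}

features = ToyBackbone()(torch.randn(1, 3, 64, 64))  # {"res2": (1, 64, 16, 16)}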
-
-
class
detectron2.modeling.
ResNet
(stem, stages, num_classes=None, out_features=None, freeze_at=0)¶ Bases:
detectron2.modeling.Backbone
Implement Deep Residual Learning for Image Recognition.
-
__init__
(stem, stages, num_classes=None, out_features=None, freeze_at=0)¶ - Parameters
stem (nn.Module) – a stem module
stages (list[list[CNNBlockBase]]) – several (typically 4) stages, each contains multiple
CNNBlockBase
.num_classes (None or int) – if None, will not perform classification. Otherwise, will create a linear layer.
out_features (list[str]) – name of the layers whose outputs should be returned in forward. Can be anything in “stem”, “linear”, or “res2” … If None, will return the output of the last layer.
freeze_at (int) – The number of stages at the beginning to freeze. see
freeze()
for detailed explanation.
-
forward
(x)¶ - Parameters
x – Tensor of shape (N,C,H,W). H, W must be a multiple of
self.size_divisibility
.- Returns
dict[str->Tensor] – names and the corresponding features
-
freeze
(freeze_at=0)¶ Freeze the first several stages of the ResNet. Commonly used in fine-tuning.
Layers that produce the same feature map spatial size are defined as one “stage” by Feature Pyramid Networks for Object Detection.
- Parameters
freeze_at (int) – number of stages to freeze. 1 means freezing the stem. 2 means freezing the stem and one residual stage, etc.
- Returns
nn.Module – this ResNet itself
-
static
make_default_stages
(depth, block_class=None, **kwargs)[source]¶ Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152). If it doesn’t create the ResNet variant you need, please use
make_stage()
instead for fine-grained customization.
- Parameters
depth (int) – depth of ResNet
block_class (type) – the CNN block class. Has to accept bottleneck_channels argument for depth > 50. By default it is BasicBlock or BottleneckBlock, based on the depth.
kwargs – other arguments to pass to make_stage. Should not contain stride and channels, as they are predefined for each depth.
- Returns
list[list[CNNBlockBase]] – modules in all stages; see arguments of ResNet.__init__().
-
static
make_stage
(block_class, num_blocks, *, in_channels, out_channels, **kwargs)[source]¶ Create a list of blocks of the same type that forms one ResNet stage.
- Parameters
block_class (type) – a subclass of CNNBlockBase that’s used to create all blocks in this stage. A module of this type must not change spatial resolution of inputs unless its stride != 1.
num_blocks (int) – number of blocks in this stage
in_channels (int) – input channels of the entire stage.
out_channels (int) – output channels of every block in the stage.
kwargs – other arguments passed to the constructor of block_class. If the argument name is “xx_per_block”, the argument is a list of values to be passed to each block in the stage. Otherwise, the same argument is passed to every block in the stage.
- Returns
list[CNNBlockBase] – a list of block module.
Examples:
stage = ResNet.make_stage(
    BottleneckBlock, 3,
    in_channels=16,
    out_channels=64,
    bottleneck_channels=16,
    num_groups=1,
    stride_per_block=[2, 1, 1],
    dilations_per_block=[1, 1, 2],
)
Usually, layers that produce the same feature map spatial size are defined as one “stage” (in Feature Pyramid Networks for Object Detection). Under such definition,
stride_per_block[1:]
should all be 1.
-
output_shape
()¶
-
-
detectron2.modeling.
build_backbone
(cfg, input_shape=None)¶ Build a backbone from cfg.MODEL.BACKBONE.NAME.
- Returns
an instance of
Backbone
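Example (a sketch; with the default config this builds a ResNet, and the input shape is inferred from cfg.MODEL.PIXEL_MEAN):
import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone

cfg = get_cfg()
backbone = build_backbone(cfg)                    # default: build_resnet_backbone
features = backbone(torch.randn(1, 3, 224, 224))  # e.g. {"res4": (1, 1024, 14, 14)}
print(backbone.output_shape())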
-
detectron2.modeling.
build_resnet_backbone
(cfg, input_shape)¶ Create a ResNet instance from config.
- Returns
ResNet – a
ResNet
instance.
-
class
detectron2.modeling.
GeneralizedRCNN
(*args, **kwargs)¶ Bases:
torch.nn.Module
Generalized R-CNN. Any model that contains the following three components: 1. Per-image feature extraction (aka backbone) 2. Region proposal generation 3. Per-region feature extraction and prediction
-
__init__
(*, backbone: detectron2.modeling.Backbone, proposal_generator: torch.nn.Module, roi_heads: torch.nn.Module, pixel_mean: Tuple[float], pixel_std: Tuple[float], input_format: Optional[str] = None, vis_period: int = 0)¶ - Parameters
backbone – a backbone module, must follow detectron2’s backbone interface
proposal_generator – a module that generates proposals using backbone features
roi_heads – a ROI head that performs per-region computation
pixel_mean – list or tuple with #channels elements, representing the per-channel mean to be used to normalize the input image
pixel_std – list or tuple with #channels elements, representing the per-channel std to be used to normalize the input image
input_format – describe the meaning of channels of input. Needed by visualization
vis_period – the period to run visualization. Set to 0 to disable.
-
property
device
¶
-
forward
(batched_inputs: List[Dict[str, torch.Tensor]])¶ - Parameters
batched_inputs –
a list, batched outputs of
DatasetMapper
. Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains:image: Tensor, image in (C, H, W) format.
instances (optional): groundtruth
Instances
proposals (optional):
Instances
, precomputed proposals.
Other information that’s included in the original dicts, such as:
”height”, “width” (int): the output resolution of the model, used in inference. See
postprocess()
for details.
- Returns
list[dict] – Each dict is the output for one input image. The dict contains one key “instances” whose value is a
Instances
. The Instances
object has the following keys: “pred_boxes”, “pred_classes”, “scores”, “pred_masks”, “pred_keypoints”
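A minimal inference sketch (illustrative only; it builds a trained model from the model zoo and feeds a random image as a (C, H, W) float tensor in the model’s expected input format):
import torch
from detectron2 import model_zoo

model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
model.eval()
inputs = [{"image": torch.rand(3, 480, 640) * 255, "height": 480, "width": 640}]
with torch.no_grad():
    outputs = model(inputs)
print(outputs[0]["instances"].pred_boxes)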
-
inference
(batched_inputs: List[Dict[str, torch.Tensor]], detected_instances: Optional[List[detectron2.structures.Instances]] = None, do_postprocess: bool = True)¶ Run inference on the given inputs.
- Parameters
detected_instances (None or list[Instances]) – if not None, it contains an Instances object per image. The Instances object contains “pred_boxes” and “pred_classes” which are known boxes in the image. The inference will then skip the detection of bounding boxes, and only predict other per-ROI outputs.
do_postprocess (bool) – whether to apply post-processing on the outputs.
- Returns
When do_postprocess=True, same as in
forward()
. Otherwise, a list[Instances] containing raw network outputs.
-
preprocess_image
(batched_inputs: List[Dict[str, torch.Tensor]])¶ Normalize, pad and batch the input images.
-
visualize_training
(batched_inputs, proposals)¶ A function used to visualize images and proposals. It shows ground truth bounding boxes on the original image and up to 20 top-scoring predicted object proposals on the original image. Users can implement different visualization functions for different models.
-
-
class
detectron2.modeling.
PanopticFPN
(*args, **kwargs)¶ Bases:
detectron2.modeling.GeneralizedRCNN
Implement the paper Panoptic Feature Pyramid Networks.
-
__init__
(*, sem_seg_head: torch.nn.Module, combine_overlap_thresh: float = 0.5, combine_stuff_area_thresh: float = 4096, combine_instances_score_thresh: float = 0.5, **kwargs)¶ NOTE: this interface is experimental.
- Parameters
sem_seg_head – a module for the semantic segmentation head.
combine_overlap_thresh – combine masks into one instance if they have enough overlap
combine_stuff_area_thresh – ignore stuff areas smaller than this threshold
combine_instances_score_thresh – ignore instances whose score is smaller than this threshold
Other arguments are the same as
GeneralizedRCNN
.
-
forward
(batched_inputs)¶ - Parameters
batched_inputs –
a list, batched outputs of
DatasetMapper
. Each item in the list contains the inputs for one image.For now, each item in the list is a dict that contains:
”image”: Tensor, image in (C, H, W) format.
”instances”: Instances
”sem_seg”: semantic segmentation ground truth.
Other information that’s included in the original dicts, such as: “height”, “width” (int): the output resolution of the model, used in inference. See
postprocess()
for details.
- Returns
list[dict] – each dict has the results for one image. The dict contains the following keys:
”instances”: see
GeneralizedRCNN.forward()
for its format.”sem_seg”: see
SemanticSegmentor.forward()
for its format.”panoptic_seg”: See the return value of
combine_semantic_and_instance_outputs()
for its format.
-
inference
(batched_inputs: List[Dict[str, torch.Tensor]], do_postprocess: bool = True)¶ Run inference on the given inputs.
- Parameters
- Returns
When do_postprocess=True, see docs in
forward()
. Otherwise, returns a (list[Instances], list[Tensor]) that contains the raw detector outputs, and raw semantic segmentation outputs.
-
-
class
detectron2.modeling.
ProposalNetwork
(*args, **kwargs)¶ Bases:
torch.nn.Module
A meta architecture that only predicts object proposals.
-
__init__
(*, backbone: detectron2.modeling.Backbone, proposal_generator: torch.nn.Module, pixel_mean: Tuple[float], pixel_std: Tuple[float])¶ - Parameters
backbone – a backbone module, must follow detectron2’s backbone interface
proposal_generator – a module that generates proposals using backbone features
pixel_mean – list or tuple with #channels elements, representing the per-channel mean to be used to normalize the input image
pixel_std – list or tuple with #channels elements, representing the per-channel std to be used to normalize the input image
-
property
device
¶
-
forward
(batched_inputs)¶ - Parameters
batched_inputs – same as in GeneralizedRCNN.forward()
- Returns
list[dict] – Each dict is the output for one input image. The dict contains one key “proposals” whose value is a
Instances
with keys “proposal_boxes” and “objectness_logits”.
-
-
class
detectron2.modeling.
RetinaNet
(*args, **kwargs)¶ Bases:
torch.nn.Module
Implement RetinaNet in Focal Loss for Dense Object Detection.
-
__init__
(*, backbone: detectron2.modeling.Backbone, head: torch.nn.Module, head_in_features, anchor_generator, box2box_transform, anchor_matcher, num_classes, focal_loss_alpha=0.25, focal_loss_gamma=2.0, smooth_l1_beta=0.0, box_reg_loss_type='smooth_l1', test_score_thresh=0.05, test_topk_candidates=1000, test_nms_thresh=0.5, max_detections_per_image=100, pixel_mean, pixel_std, vis_period=0, input_format='BGR')¶ NOTE: this interface is experimental.
- Parameters
backbone – a backbone module, must follow detectron2’s backbone interface
head (nn.Module) – a module that predicts logits and regression deltas for each level from a list of per-level features
head_in_features (Tuple[str]) – Names of the input feature maps to be used in head
anchor_generator (nn.Module) – a module that creates anchors from a list of features. Usually an instance of
AnchorGenerator
box2box_transform (Box2BoxTransform) – defines the transform from anchors boxes to instance boxes
anchor_matcher (Matcher) – label the anchors by matching them with ground truth.
num_classes (int) – number of classes. Used to label background proposals.
Loss parameters (#) –
focal_loss_alpha (float) – focal_loss_alpha
focal_loss_gamma (float) – focal_loss_gamma
smooth_l1_beta (float) – smooth_l1_beta
box_reg_loss_type (str) – Options are “smooth_l1”, “giou”
Inference parameters (#) –
test_score_thresh (float) – Inference cls score threshold, only anchors with score > INFERENCE_TH are considered for inference (to improve speed)
test_topk_candidates (int) – Select topk candidates before NMS
test_nms_thresh (float) – Overlap threshold used for non-maximum suppression (suppress boxes with IoU >= this threshold)
max_detections_per_image (int) – Maximum number of detections to return per image during inference (100 is based on the limit established for the COCO dataset).
Input parameters (#) –
pixel_mean (Tuple[float]) – Values to be used for image normalization (BGR order). To train on images of different number of channels, set different mean & std. Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
pixel_std (Tuple[float]) – When using pre-trained models in Detectron1 or any MSRA models, std has been absorbed into its conv1 weights, so the std needs to be set 1. Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
vis_period (int) – The period (in terms of steps) for minibatch visualization at train time. Set to 0 to disable.
input_format (str) – Whether the model needs RGB, YUV, HSV etc.
-
property
device
¶
-
forward
(batched_inputs: List[Dict[str, torch.Tensor]])¶ - Parameters
batched_inputs –
a list, batched outputs of
DatasetMapper
. Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains:image: Tensor, image in (C, H, W) format.
instances: Instances
Other information that’s included in the original dicts, such as:
”height”, “width” (int): the output resolution of the model, used in inference. See
postprocess()
for details.
- Returns
In training, dict[str, Tensor] – mapping from a named loss to a tensor storing the loss. Used during training only. In inference, the standard output format, described in Use Models.
-
inference
(anchors: List[detectron2.structures.Boxes], pred_logits: List[torch.Tensor], pred_anchor_deltas: List[torch.Tensor], image_sizes: List[Tuple[int, int]])¶ - Parameters
anchors (list[Boxes]) – A list of #feature level Boxes. The Boxes contain anchors of this image on the specific feature level.
pred_logits – list[Tensor], one per level. Each has shape (N, Hi * Wi * Ai, K or 4)
pred_anchor_deltas – list[Tensor], one per level. Each has shape (N, Hi * Wi * Ai, K or 4)
image_sizes (List[(h, w)]) – the input image sizes
- Returns
results (List[Instances]) – a list of #images elements.
-
inference_single_image
(anchors: List[detectron2.structures.Boxes], box_cls: List[torch.Tensor], box_delta: List[torch.Tensor], image_size: Tuple[int, int])¶ Single-image inference. Return bounding-box detection results by thresholding on scores and applying non-maximum suppression (NMS).
- Parameters
anchors (list[Boxes]) – list of #feature levels. Each entry contains a Boxes object, which contains all the anchors in that feature level.
box_cls (list[Tensor]) – list of #feature levels. Each entry contains tensor of size (H x W x A, K)
box_delta (list[Tensor]) – Same shape as ‘box_cls’ except that K becomes 4.
image_size (tuple(H, W)) – a tuple of the image height and width.
- Returns
Same as inference, but for only one image.
-
label_anchors
(anchors, gt_instances)¶ - Parameters
- Returns
list[Tensor] – List of #img tensors. i-th element is a vector of labels whose length is the total number of anchors across all feature maps (sum(Hi * Wi * A)). Label values are in {-1, 0, …, K}, with -1 means ignore, and K means background.
list[Tensor]: i-th element is a Rx4 tensor, where R is the total number of anchors across feature maps. The values are the matched gt boxes for each anchor. Values are undefined for those anchors not labeled as foreground.
-
losses
(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)¶ - Parameters
gt_labels – see output of
RetinaNet.label_anchors()
. Their shapes are (N, R) and (N, R, 4), respectively, where R is the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)gt_boxes – see output of
RetinaNet.label_anchors()
. Their shapes are (N, R) and (N, R, 4), respectively, where R is the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)pred_logits – both are list[Tensor]. Each element in the list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4). Where K is the number of classes used in pred_logits.
pred_anchor_deltas – both are list[Tensor]. Each element in the list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4). Where K is the number of classes used in pred_logits.
- Returns
dict[str, Tensor] – mapping from a named loss to a scalar tensor storing the loss. Used during training only. The dict keys are: “loss_cls” and “loss_box_reg”
-
preprocess_image
(batched_inputs: List[Dict[str, torch.Tensor]])¶ Normalize, pad and batch the input images.
-
visualize_training
(batched_inputs, results)¶ A function used to visualize ground truth images and final network predictions. It shows ground truth bounding boxes on the original image and up to 20 predicted object bounding boxes on the original image.
-
-
class
detectron2.modeling.
SemanticSegmentor
(*args, **kwargs)¶ Bases:
torch.nn.Module
Main class for semantic segmentation architectures.
-
__init__
(*, backbone: detectron2.modeling.Backbone, sem_seg_head: torch.nn.Module, pixel_mean: Tuple[float], pixel_std: Tuple[float])¶ - Parameters
backbone – a backbone module, must follow detectron2’s backbone interface
sem_seg_head – a module that predicts semantic segmentation from backbone features
pixel_mean – list or tuple with #channels elements, representing the per-channel mean to be used to normalize the input image
pixel_std – list or tuple with #channels elements, representing the per-channel std to be used to normalize the input image
-
property
device
¶
-
forward
(batched_inputs)¶ - Parameters
batched_inputs –
a list, batched outputs of
DatasetMapper
. Each item in the list contains the inputs for one image.For now, each item in the list is a dict that contains:
”image”: Tensor, image in (C, H, W) format.
”sem_seg”: semantic segmentation ground truth
Other information that’s included in the original dicts, such as: “height”, “width” (int): the output resolution of the model (may be different from input resolution), used in inference.
- Returns
list[dict] – Each dict is the output for one input image. The dict contains one key “sem_seg” whose value is a Tensor that represents the per-pixel segmentation predicted by the head. The prediction has shape KxHxW that represents the logits of each class for each pixel.
-
-
detectron2.modeling.
build_model
(cfg)¶ Build the whole model architecture, defined by
cfg.MODEL.META_ARCHITECTURE
. Note that it does not load any weights fromcfg
.
-
detectron2.modeling.
build_sem_seg_head
(cfg, input_shape)¶ Build a semantic segmentation head from cfg.MODEL.SEM_SEG_HEAD.NAME.
-
detectron2.modeling.
detector_postprocess
(results: detectron2.structures.Instances, output_height: int, output_width: int, mask_threshold: float = 0.5)¶ Resize the output instances. The input images are often resized when entering an object detector. As a result, we often need the outputs of the detector in a different resolution from its inputs.
This function will resize the raw outputs of an R-CNN detector to produce outputs according to the desired output resolution.
- Parameters
results (Instances) – the raw outputs from the detector. results.image_size contains the input image resolution the detector sees. This object might be modified in-place.
output_height – the desired output resolution.
output_width – the desired output resolution.
- Returns
Instances – the resized output from the model, based on the output resolution
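A small sketch (illustrative only; it builds an Instances object by hand at the resolution the detector saw and rescales it to the desired output resolution):
import torch
from detectron2.modeling import detector_postprocess
from detectron2.structures import Boxes, Instances

results = Instances((400, 600))  # resolution seen by the detector
results.pred_boxes = Boxes(torch.tensor([[10., 10., 100., 100.]]))
results.scores = torch.tensor([0.9])
results.pred_classes = torch.tensor([0])
resized = detector_postprocess(results, output_height=800, output_width=1200)
print(resized.pred_boxes)        # boxes scaled by 2x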
-
detectron2.modeling.
build_proposal_generator
(cfg, input_shape)¶ Build a proposal generator from cfg.MODEL.PROPOSAL_GENERATOR.NAME. The name can be “PrecomputedProposals” to use no proposal generator.
-
detectron2.modeling.
build_rpn_head
(cfg, input_shape)¶ Build an RPN head defined by cfg.MODEL.RPN.HEAD_NAME.
-
class
detectron2.modeling.
ROIHeads
(*args, **kwargs)¶ Bases:
torch.nn.Module
ROIHeads perform all per-region computation in an R-CNN.
It typically contains logic to
(in training only) match proposals with ground truth and sample them
crop the regions and extract per-region features using proposals
make per-region predictions with different heads
It can have many variants, implemented as subclasses of this class. This base class contains the logic to match/sample proposals. But it is not necessary to inherit this class if the sampling logic is not needed.
-
__init__
(*, num_classes, batch_size_per_image, positive_fraction, proposal_matcher, proposal_append_gt=True)¶ NOTE: this interface is experimental.
- Parameters
num_classes (int) – number of foreground classes (i.e. background is not included)
batch_size_per_image (int) – number of proposals to sample for training
positive_fraction (float) – fraction of positive (foreground) proposals to sample for training.
proposal_matcher (Matcher) – matcher that matches proposals and ground truth
proposal_append_gt (bool) – whether to include ground truth as proposals as well
-
forward
(images: detectron2.structures.ImageList, features: Dict[str, torch.Tensor], proposals: List[detectron2.structures.Instances], targets: Optional[List[detectron2.structures.Instances]] = None) → Tuple[List[detectron2.structures.Instances], Dict[str, torch.Tensor]]¶ - Parameters
images (ImageList) –
features (dict[str,Tensor]) – input data as a mapping from feature map name to tensor. Axis 0 represents the number of images N in the input data; axes 1-3 are channels, height, and width, which may vary between feature maps (e.g., if a feature pyramid is used).
proposals (list[Instances]) – length N list of Instances. The i-th Instances contains object proposals for the i-th input image, with fields “proposal_boxes” and “objectness_logits”.
targets (list[Instances], optional) –
length N list of Instances. The i-th Instances contains the ground-truth per-instance annotations for the i-th input image. Specify targets during training only. It may have the following fields:
gt_boxes: the bounding box of each instance.
gt_classes: the label for each instance with a category ranging in [0, #class].
gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
gt_keypoints: NxKx3, the ground-truth keypoints for each instance.
- Returns
list[Instances] – length N list of Instances containing the detected instances. Returned during inference only; may be [] during training.
dict[str->Tensor]: mapping from a named loss to a tensor storing the loss. Used during training only.
-
label_and_sample_proposals
(proposals: List[detectron2.structures.Instances], targets: List[detectron2.structures.Instances]) → List[detectron2.structures.Instances]¶ Prepare some proposals to be used to train the ROI heads. It performs box matching between proposals and targets, and assigns training labels to the proposals. It returns
self.batch_size_per_image
random samples from proposals and groundtruth boxes, with a fraction of positives that is no larger than self.positive_fraction.
- Parameters
proposals, targets – see ROIHeads.forward()
- Returns
list[Instances] – length N list of Instances containing the proposals sampled for training. Each Instances has the following fields:
proposal_boxes: the proposal boxes
gt_boxes: the ground-truth box that the proposal is assigned to (this is only meaningful if the proposal has a label > 0; if label = 0 then the ground-truth box is random)
Other fields such as “gt_classes”, “gt_masks”, that’s included in targets.
-
class
detectron2.modeling.
StandardROIHeads
(*args, **kwargs)¶ Bases:
detectron2.modeling.ROIHeads
It’s “standard” in the sense that there is no ROI transform sharing or feature sharing between tasks. Each head independently processes the input features by each head’s own pooler and head.
This class is used by most models, such as FPN and C5. To implement more models, you can subclass it and implement a different
forward()
or a head.-
__init__
(*, box_in_features: List[str], box_pooler: detectron2.modeling.poolers.ROIPooler, box_head: torch.nn.Module, box_predictor: torch.nn.Module, mask_in_features: Optional[List[str]] = None, mask_pooler: Optional[detectron2.modeling.poolers.ROIPooler] = None, mask_head: Optional[torch.nn.Module] = None, keypoint_in_features: Optional[List[str]] = None, keypoint_pooler: Optional[detectron2.modeling.poolers.ROIPooler] = None, keypoint_head: Optional[torch.nn.Module] = None, train_on_pred_boxes: bool = False, **kwargs)¶ NOTE: this interface is experimental.
- Parameters
box_in_features (list[str]) – list of feature names to use for the box head.
box_pooler (ROIPooler) – pooler to extract region features for the box head
box_head (nn.Module) – transform features to make box predictions
box_predictor (nn.Module) – make box predictions from the feature. Should have the same interface as
FastRCNNOutputLayers
.mask_in_features (list[str]) – list of feature names to use for the mask pooler or mask head. None if not using mask head.
mask_pooler (ROIPooler) – pooler to extract region features from image features. The mask head will then take region features to make predictions. If None, the mask head will directly take the dict of image features defined by mask_in_features
mask_head (nn.Module) – transform features to make mask predictions
keypoint_in_features – similar to
mask_*
.keypoint_pooler – similar to
mask_*
.keypoint_head – similar to
mask_*
.train_on_pred_boxes (bool) – whether to use proposal boxes or predicted boxes from the box head to train other heads.
-
forward
(images: detectron2.structures.ImageList, features: Dict[str, torch.Tensor], proposals: List[detectron2.structures.Instances], targets: Optional[List[detectron2.structures.Instances]] = None) → Tuple[List[detectron2.structures.Instances], Dict[str, torch.Tensor]]¶ See
ROIHeads.forward
.
-
forward_with_given_boxes
(features: Dict[str, torch.Tensor], instances: List[detectron2.structures.Instances]) → List[detectron2.structures.Instances]¶ Use the given boxes in instances to produce other (non-box) per-ROI outputs.
This is useful for downstream tasks where a box is known but other attributes (outputs of other heads) need to be obtained. Test-time augmentation also uses this.
-
-
class
detectron2.modeling.
BaseMaskRCNNHead
(*args, **kwargs)¶ Bases:
torch.nn.Module
Implement the basic Mask R-CNN losses and inference logic described in Mask R-CNN
-
forward
(x, instances: List[detectron2.structures.Instances])¶ - Parameters
x – input region feature(s) provided by
ROIHeads
.instances (list[Instances]) – contains the boxes & labels corresponding to the input features. Exact format is up to its caller to decide. Typically, this is the foreground instances in training, with “proposal_boxes” field and other gt annotations. In inference, it contains boxes that are already predicted.
- Returns
A dict of losses in training. The predicted “instances” in inference.
-
layers
(x)¶ Neural network layers that make predictions from input features.
-
-
class
detectron2.modeling.
BaseKeypointRCNNHead
(*args, **kwargs)¶ Bases:
torch.nn.Module
Implement the basic Keypoint R-CNN losses and inference logic described in Sec. 5 of Mask R-CNN.
-
__init__
(*, num_keypoints, loss_weight=1.0, loss_normalizer=1.0)¶ NOTE: this interface is experimental.
- Parameters
-
forward
(x, instances: List[detectron2.structures.Instances])¶ - Parameters
x – input 4D region feature(s) provided by
ROIHeads
.instances (list[Instances]) – contains the boxes & labels corresponding to the input features. Exact format is up to its caller to decide. Typically, this is the foreground instances in training, with “proposal_boxes” field and other gt annotations. In inference, it contains boxes that are already predicted.
- Returns
A dict of losses if in training. The predicted “instances” if in inference.
-
layers
(x)¶ Neural network layers that make predictions from regional input features.
-
-
class
detectron2.modeling.
FastRCNNOutputLayers
(*args, **kwargs)¶ Bases:
torch.nn.Module
Two linear layers for predicting Fast R-CNN outputs:
proposal-to-detection box regression deltas
classification scores
-
__init__
(input_shape: detectron2.layers.shape_spec.ShapeSpec, *, box2box_transform, num_classes: int, test_score_thresh: float = 0.0, test_nms_thresh: float = 0.5, test_topk_per_image: int = 100, cls_agnostic_bbox_reg: bool = False, smooth_l1_beta: float = 0.0, box_reg_loss_type: str = 'smooth_l1', loss_weight: Union[float, Dict[str, float]] = 1.0)¶ NOTE: this interface is experimental.
- Parameters
input_shape (ShapeSpec) – shape of the input feature to this module
box2box_transform (Box2BoxTransform or Box2BoxTransformRotated) –
num_classes (int) – number of foreground classes
test_score_thresh (float) – threshold to filter prediction results.
test_nms_thresh (float) – NMS threshold for prediction results.
test_topk_per_image (int) – number of top predictions to produce per image.
cls_agnostic_bbox_reg (bool) – whether to use class-agnostic bbox regression
smooth_l1_beta (float) – transition point from L1 to L2 loss. Only used if box_reg_loss_type is “smooth_l1”
box_reg_loss_type (str) – Box regression loss type. One of: “smooth_l1”, “giou”
loss_weight (float|dict) –
weights to use for losses. Can be single float for weighting all losses, or a dict of individual weightings. Valid dict keys are:
”loss_cls”: applied to classification loss
”loss_box_reg”: applied to box regression loss
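An illustrative construction sketch (not taken from the official docs); the channel count, number of classes, and box2box weights below are arbitrary example values:
from detectron2.layers import ShapeSpec
from detectron2.modeling import FastRCNNOutputLayers
from detectron2.modeling.box_regression import Box2BoxTransform

# Assumed values: 1024 input channels, 80 foreground classes.
box_predictor = FastRCNNOutputLayers(
    ShapeSpec(channels=1024),
    box2box_transform=Box2BoxTransform(weights=(10.0, 10.0, 5.0, 5.0)),
    num_classes=80,
    test_score_thresh=0.05,
    test_nms_thresh=0.5,
    box_reg_loss_type="smooth_l1",
)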
-
box_reg_loss
(proposal_boxes, gt_boxes, pred_deltas, gt_classes)¶ - Parameters
All boxes are tensors with the same shape Rx(4 or 5).
gt_classes is a long tensor of shape R, the gt class label of each proposal. R shall be the number of proposals.
-
forward
(x)¶ - Parameters
x – per-region features of shape (N, …) for N bounding boxes to predict.
- Returns
(Tensor, Tensor) – First tensor: shape (N, K+1), scores for each of the N boxes. Each row contains the scores for K object categories and 1 background class.
Second tensor: bounding box regression deltas for each box. The shape is (N, Kx4), or (N, 4) for class-agnostic regression.
-
inference
(predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[detectron2.structures.Instances])¶
-
losses
(predictions, proposals)¶
-
predict_boxes
(predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[detectron2.structures.Instances])¶ - Parameters
- Returns
list[Tensor] – A list of Tensors of predicted class-specific or class-agnostic boxes for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is the number of proposals for image i and B is the box dimension (4 or 5)
-
predict_boxes_for_gt_classes
(predictions, proposals)¶ - Parameters
- Returns
list[Tensor] – A list of Tensors of predicted boxes for GT classes in case of class-specific box head. Element i of the list has shape (Ri, B), where Ri is the number of proposals for image i and B is the box dimension (4 or 5)
-
predict_probs
(predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[detectron2.structures.Instances])¶ - Parameters
- Returns
list[Tensor] – A list of Tensors of predicted class probabilities for each image. Element i has shape (Ri, K + 1), where Ri is the number of proposals for image i.
-
detectron2.modeling.
build_box_head
(cfg, input_shape)¶ Build a box head defined by cfg.MODEL.ROI_BOX_HEAD.NAME.
-
detectron2.modeling.
build_keypoint_head
(cfg, input_shape)¶ Build a keypoint head from cfg.MODEL.ROI_KEYPOINT_HEAD.NAME.
-
detectron2.modeling.
build_mask_head
(cfg, input_shape)¶ Build a mask head defined by cfg.MODEL.ROI_MASK_HEAD.NAME.
-
detectron2.modeling.
build_roi_heads
(cfg, input_shape)¶ Build ROIHeads defined by cfg.MODEL.ROI_HEADS.NAME.
-
class
detectron2.modeling.
DatasetMapperTTA
(*args, **kwargs)¶ Bases:
object
Implement test-time augmentation for detection data. It is a callable which takes a dataset dict from a detection dataset, and returns a list of dataset dicts where the images are augmented from the input image by the transformations defined in the config. This is used for test-time augmentation.
-
__call__
(dataset_dict)¶ - Parameters
dict – a dict in standard model input format. See tutorials for details.
- Returns
list[dict] – a list of dicts, which contain augmented versions of the input image. The total number of dicts is
len(min_sizes) * (2 if flip else 1)
. Each dict has field “transforms” which is a TransformList, containing the transforms that are used to generate this image.
-
-
class
detectron2.modeling.
GeneralizedRCNNWithTTA
(cfg, model, tta_mapper=None, batch_size=3)¶ Bases:
torch.nn.Module
A GeneralizedRCNN with test-time augmentation enabled. Its
__call__()
method has the same interface asGeneralizedRCNN.forward()
.-
__call__
(batched_inputs)¶ Same input/output format as
GeneralizedRCNN.forward()
-
__init__
(cfg, model, tta_mapper=None, batch_size=3)¶ - Parameters
cfg (CfgNode) –
model (GeneralizedRCNN) – a GeneralizedRCNN to apply TTA on.
tta_mapper (callable) – takes a dataset dict and returns a list of augmented versions of the dataset dict. Defaults to DatasetMapperTTA(cfg).
batch_size (int) – batch the augmented images into this batch size for inference.
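A minimal usage sketch, assuming cfg is a CfgNode and model is an already-built GeneralizedRCNN created elsewhere:
from detectron2.modeling import GeneralizedRCNNWithTTA

# Wrap an existing model; batched_inputs follows the standard model input format.
tta_model = GeneralizedRCNNWithTTA(cfg, model, batch_size=3)
outputs = tta_model(batched_inputs)  # same input/output format as GeneralizedRCNN.forward()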
-
-
class
detectron2.modeling.
MMDetBackbone
(backbone: Union[torch.nn.Module, collections.abc.Mapping], neck: Optional[Union[torch.nn.Module, collections.abc.Mapping]] = None, *, pretrained_backbone: Optional[str] = None, output_shapes: List[detectron2.layers.shape_spec.ShapeSpec], output_names: Optional[List[str]] = None)¶ Bases:
detectron2.modeling.Backbone
Wrapper of mmdetection backbones to use in detectron2.
mmdet backbones produce list/tuple of tensors, while detectron2 backbones produce a dict of tensors. This class wraps the given backbone to produce output in detectron2’s convention, so it can be used in place of detectron2 backbones.
-
__init__
(backbone: Union[torch.nn.Module, collections.abc.Mapping], neck: Optional[Union[torch.nn.Module, collections.abc.Mapping]] = None, *, pretrained_backbone: Optional[str] = None, output_shapes: List[detectron2.layers.shape_spec.ShapeSpec], output_names: Optional[List[str]] = None)¶ - Parameters
backbone – either a backbone module or a mmdet config dict that defines a backbone. The backbone takes a 4D image tensor and returns a sequence of tensors.
neck – either a backbone module or a mmdet config dict that defines a neck. The neck takes outputs of backbone and returns a sequence of tensors. If None, no neck is used.
pretrained_backbone – defines the backbone weights that can be loaded by mmdet, such as “torchvision://resnet50”.
output_shapes – shape for every output of the backbone (or neck, if given). stride and channels are often needed.
output_names – names for every output of the backbone (or neck, if given). By default, will use “out0”, “out1”, …
-
forward
(x) → Dict[str, torch.Tensor]¶
-
-
class
detectron2.modeling.
MMDetDetector
(detector: Union[torch.nn.Module, collections.abc.Mapping], *, size_divisibility=32, pixel_mean: Tuple[float], pixel_std: Tuple[float])¶ Bases:
torch.nn.Module
Wrapper of a mmdetection detector model, for detection and instance segmentation. Input/output formats of this class follow detectron2’s convention, so a mmdetection model can be trained and evaluated in detectron2.
-
__init__
(detector: Union[torch.nn.Module, collections.abc.Mapping], *, size_divisibility=32, pixel_mean: Tuple[float], pixel_std: Tuple[float])¶ - Parameters
detector – a mmdet detector, or a mmdet config dict that defines a detector.
size_divisibility – pad input images to multiple of this number
pixel_mean – per-channel mean to normalize input image
pixel_std – per-channel stddev to normalize input image
-
property
device
¶
-
forward
(batched_inputs: List[Dict[str, torch.Tensor]])¶
-
detectron2.modeling.poolers module¶
-
class
detectron2.modeling.poolers.
ROIPooler
(output_size, scales, sampling_ratio, pooler_type, canonical_box_size=224, canonical_level=4)[source]¶ Bases:
torch.nn.Module
Region of interest feature map pooler that supports pooling from one or more feature maps.
-
__init__
(output_size, scales, sampling_ratio, pooler_type, canonical_box_size=224, canonical_level=4)[source]¶ - Parameters
output_size (int, tuple[int] or list[int]) – output size of the pooled region, e.g., 14 x 14. If tuple or list is given, the length must be 2.
scales (list[float]) – The scale for each low-level pooling op relative to the input image. For a feature map with stride s relative to the input image, scale is defined as 1/s. The stride must be a power of 2. When there are multiple scales, they must form a pyramid, i.e. they must be a monotonically decreasing geometric sequence with a factor of 1/2.
sampling_ratio (int) – The sampling_ratio parameter for the ROIAlign op.
pooler_type (string) – Name of the type of pooling operation that should be applied. For instance, “ROIPool” or “ROIAlignV2”.
canonical_box_size (int) – A canonical box size in pixels (sqrt(box area)). The default is heuristically defined as 224 pixels in the FPN paper (based on ImageNet pre-training).
canonical_level (int) –
The feature map level index from which a canonically-sized box should be placed. The default is defined as level 4 (stride=16) in the FPN paper, i.e., a box of size 224x224 will be placed on the feature with stride=16. The box placement for all boxes will be determined from their sizes w.r.t canonical_box_size. For example, a box whose area is 4x that of a canonical box should be used to pool features from feature level
canonical_level+1
. Note that the actual input feature maps given to this module may not have sufficiently many levels for the input boxes. If the boxes are too large or too small for the input feature maps, the closest level will be used.
-
forward
(x: List[torch.Tensor], box_lists: List[detectron2.structures.Boxes])[source]¶ - Parameters
x (list[Tensor]) – A list of feature maps of NCHW shape, with scales matching those used to construct this module.
box_lists (list[Boxes] | list[RotatedBoxes]) – A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch. The box coordinates are defined on the original image and will be scaled by the scales argument of
ROIPooler
.
- Returns
Tensor – A tensor of shape (M, C, output_size, output_size) where M is the total number of boxes aggregated over all N batch images and C is the number of channels in x.
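An illustrative sketch of pooling 7x7 region features from two feature maps; the feature shapes, strides and box coordinates are made-up example values:
import torch
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes

pooler = ROIPooler(
    output_size=7,
    scales=(1.0 / 8, 1.0 / 16),   # one scale per feature map (strides 8 and 16)
    sampling_ratio=0,
    pooler_type="ROIAlignV2",
)
features = [torch.randn(1, 256, 64, 64), torch.randn(1, 256, 32, 32)]
boxes = [Boxes(torch.tensor([[10.0, 10.0, 100.0, 120.0]]))]  # one image, one box
region_feats = pooler(features, boxes)  # shape (M, 256, 7, 7), M = total number of boxes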
-
detectron2.modeling.sampling module¶
-
detectron2.modeling.sampling.
subsample_labels
(labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int)[source]¶ Return num_samples (or fewer, if not enough found) random samples from labels which is a mixture of positives & negatives. It will try to return as many positives as possible without exceeding positive_fraction * num_samples, and then try to fill the remaining slots with negatives.
- Parameters
labels (Tensor) – (N, ) label vector with values: * -1: ignore * bg_label: background (“negative”) class * otherwise: one or more foreground (“positive”) classes
num_samples (int) – The total number of labels with value >= 0 to return. Values that are not sampled will be filled with -1 (ignore).
positive_fraction (float) – The number of subsampled labels with values > 0 is min(num_positives, int(positive_fraction * num_samples)). The number of negatives sampled is min(num_negatives, num_samples - num_positives_sampled). In other words, if there are not enough positives, the sample is filled with negatives. If there are also not enough negatives, then as many elements are sampled as is possible.
bg_label (int) – label index of background (“negative”) class.
- Returns
pos_idx, neg_idx (Tensor) – 1D vector of indices. The total length of both is num_samples or fewer.
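A small illustrative call with made-up labels, using 0 as the background class:
import torch
from detectron2.modeling.sampling import subsample_labels

labels = torch.tensor([-1, 0, 0, 3, 5, 0, 2, -1, 0, 7])
# at most 4 samples total, at most 25% of them positive
pos_idx, neg_idx = subsample_labels(labels, num_samples=4, positive_fraction=0.25, bg_label=0)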
detectron2.modeling.box_regression module¶
-
class
detectron2.modeling.box_regression.
Box2BoxTransform
(weights: Tuple[float, float, float, float], scale_clamp: float = 4.135166556742356)[source]¶ Bases:
object
The box-to-box transform defined in R-CNN. The transformation is parameterized by 4 deltas: (dx, dy, dw, dh). The transformation scales the box’s width and height by exp(dw), exp(dh) and shifts a box’s center by the offset (dx * width, dy * height).
-
__init__
(weights: Tuple[float, float, float, float], scale_clamp: float = 4.135166556742356)[source]¶ - Parameters
weights (4-element tuple) – Scaling factors that are applied to the (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set such that the deltas have unit variance; now they are treated as hyperparameters of the system.
scale_clamp (float) – When predicting deltas, the predicted box scaling factors (dw and dh) are clamped such that they are <= scale_clamp.
-
get_deltas
(src_boxes, target_boxes)[source]¶ Get box regression transformation deltas (dx, dy, dw, dh) that can be used to transform the src_boxes into the target_boxes. That is, the relation
target_boxes == self.apply_deltas(deltas, src_boxes)
is true (unless any delta is too large and is clamped).- Parameters
src_boxes (Tensor) – source boxes, e.g., object proposals
target_boxes (Tensor) – target of the transformation, e.g., ground-truth boxes.
-
apply_deltas
(deltas, boxes)[source]¶ Apply transformation deltas (dx, dy, dw, dh) to boxes.
- Parameters
deltas (Tensor) – transformation deltas of shape (N, k*4), where k >= 1. deltas[i] represents k potentially different class-specific box transformations for the single box boxes[i].
boxes (Tensor) – boxes to transform, of shape (N, 4)
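An illustrative round trip with arbitrary boxes and weights: deltas produced by get_deltas map the source boxes back onto the targets when fed to apply_deltas:
import torch
from detectron2.modeling.box_regression import Box2BoxTransform

transform = Box2BoxTransform(weights=(1.0, 1.0, 1.0, 1.0))
src = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
target = torch.tensor([[2.0, 2.0, 14.0, 12.0]])
deltas = transform.get_deltas(src, target)       # shape (1, 4)
recovered = transform.apply_deltas(deltas, src)  # approximately equal to target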
-
-
class
detectron2.modeling.box_regression.
Box2BoxTransformRotated
(weights: Tuple[float, float, float, float, float], scale_clamp: float = 4.135166556742356)[source]¶ Bases:
object
The box-to-box transform defined in Rotated R-CNN. The transformation is parameterized by 5 deltas: (dx, dy, dw, dh, da). The transformation scales the box’s width and height by exp(dw), exp(dh), shifts a box’s center by the offset (dx * width, dy * height), and rotates a box’s angle by da (radians). Note: angles of deltas are in radians while angles of boxes are in degrees.
-
__init__
(weights: Tuple[float, float, float, float, float], scale_clamp: float = 4.135166556742356)[source]¶ - Parameters
weights (5-element tuple) – Scaling factors that are applied to the (dx, dy, dw, dh, da) deltas. These are treated as hyperparameters of the system.
scale_clamp (float) – When predicting deltas, the predicted box scaling factors (dw and dh) are clamped such that they are <= scale_clamp.
-
get_deltas
(src_boxes, target_boxes)[source]¶ Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used to transform the src_boxes into the target_boxes. That is, the relation
target_boxes == self.apply_deltas(deltas, src_boxes)
is true (unless any delta is too large and is clamped).- Parameters
src_boxes (Tensor) – Nx5 source boxes, e.g., object proposals
target_boxes (Tensor) – Nx5 target of the transformation, e.g., ground-truth boxes.
-
Model Registries¶
These are the different registries provided in modeling. Each registry gives you the ability to replace the corresponding component with your own customized one, without having to modify detectron2’s code.
Note that it is impossible to allow users to customize any line of code directly. Even just to add one line somewhere, you’ll likely need to find the smallest registry which contains that line, and register your component to that registry.
-
detectron2.modeling.
META_ARCH_REGISTRY
= Registry of META_ARCH: ╒═══════════════════╤═════════════════════════════════════════════════╕ │ Names │ Objects │ ╞═══════════════════╪═════════════════════════════════════════════════╡ │ GeneralizedRCNN │ <class 'detectron2.modeling.GeneralizedRCNN'> │ ├───────────────────┼─────────────────────────────────────────────────┤ │ ProposalNetwork │ <class 'detectron2.modeling.ProposalNetwork'> │ ├───────────────────┼─────────────────────────────────────────────────┤ │ SemanticSegmentor │ <class 'detectron2.modeling.SemanticSegmentor'> │ ├───────────────────┼─────────────────────────────────────────────────┤ │ PanopticFPN │ <class 'detectron2.modeling.PanopticFPN'> │ ├───────────────────┼─────────────────────────────────────────────────┤ │ RetinaNet │ <class 'detectron2.modeling.RetinaNet'> │ ╘═══════════════════╧═════════════════════════════════════════════════╛¶ Registry for meta-architectures, i.e. the whole model.
The registered object will be called with obj(cfg) and expected to return a nn.Module object.
-
detectron2.modeling.
BACKBONE_REGISTRY
= Registry of BACKBONE: ╒═════════════════════════════════════╤══════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞═════════════════════════════════════╪══════════════════════════════════════════════════════════════════╡ │ build_resnet_backbone │ <function build_resnet_backbone> │ ├─────────────────────────────────────┼──────────────────────────────────────────────────────────────────┤ │ build_resnet_fpn_backbone │ <function build_resnet_fpn_backbone> │ ├─────────────────────────────────────┼──────────────────────────────────────────────────────────────────┤ │ build_retinanet_resnet_fpn_backbone │ <function build_retinanet_resnet_fpn_backbone> │ ╘═════════════════════════════════════╧══════════════════════════════════════════════════════════════════╛¶ Registry for backbones, which extract feature maps from images
The registered object must be a callable that accepts two arguments:
A detectron2.config.CfgNode
A detectron2.layers.ShapeSpec, which contains the input shape specification.
The registered object must return an instance of Backbone.
-
detectron2.modeling.
PROPOSAL_GENERATOR_REGISTRY
= Registry of PROPOSAL_GENERATOR: ╒═════════╤════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞═════════╪════════════════════════════════════════════════════════════╡ │ RPN │ <class 'detectron2.modeling.proposal_generator.rpn.RPN'> │ ├─────────┼────────────────────────────────────────────────────────────┤ │ RRPN │ <class 'detectron2.modeling.proposal_generator.rrpn.RRPN'> │ ╘═════════╧════════════════════════════════════════════════════════════╛¶ Registry for proposal generator, which produces object proposals from feature maps.
The registered object will be called with obj(cfg, input_shape). The call should return a nn.Module object.
-
detectron2.modeling.
RPN_HEAD_REGISTRY
= Registry of RPN_HEAD: ╒═════════════════╤══════════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞═════════════════╪══════════════════════════════════════════════════════════════════════╡ │ StandardRPNHead │ <class 'detectron2.modeling.proposal_generator.rpn.StandardRPNHead'> │ ╘═════════════════╧══════════════════════════════════════════════════════════════════════╛¶ Registry for RPN heads, which take feature maps and perform objectness classification and bounding box regression for anchors.
The registered object will be called with obj(cfg, input_shape). The call should return a nn.Module object.
-
detectron2.modeling.
ANCHOR_GENERATOR_REGISTRY
= Registry of ANCHOR_GENERATOR: ╒════════════════════════╤═══════════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞════════════════════════╪═══════════════════════════════════════════════════════════════════════╡ │ DefaultAnchorGenerator │ <class 'detectron2.modeling.anchor_generator.DefaultAnchorGenerator'> │ ├────────────────────────┼───────────────────────────────────────────────────────────────────────┤ │ RotatedAnchorGenerator │ <class 'detectron2.modeling.anchor_generator.RotatedAnchorGenerator'> │ ╘════════════════════════╧═══════════════════════════════════════════════════════════════════════╛¶ Registry for modules that creates object detection anchors for feature maps.
The registered object will be called with obj(cfg, input_shape).
-
detectron2.modeling.
ROI_HEADS_REGISTRY
= Registry of ROI_HEADS: ╒══════════════════╤══════════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞══════════════════╪══════════════════════════════════════════════════════════════════════╡ │ Res5ROIHeads │ <class 'detectron2.modeling.roi_heads.roi_heads.Res5ROIHeads'> │ ├──────────────────┼──────────────────────────────────────────────────────────────────────┤ │ StandardROIHeads │ <class 'detectron2.modeling.StandardROIHeads'> │ ├──────────────────┼──────────────────────────────────────────────────────────────────────┤ │ CascadeROIHeads │ <class 'detectron2.modeling.roi_heads.cascade_rcnn.CascadeROIHeads'> │ ├──────────────────┼──────────────────────────────────────────────────────────────────────┤ │ RROIHeads │ <class 'detectron2.modeling.roi_heads.rotated_fast_rcnn.RROIHeads'> │ ╘══════════════════╧══════════════════════════════════════════════════════════════════════╛¶ Registry for ROI heads in a generalized R-CNN model. ROIHeads take feature maps and region proposals, and perform per-region computation.
The registered object will be called with obj(cfg, input_shape). The call is expected to return an
ROIHeads
.
-
detectron2.modeling.
ROI_BOX_HEAD_REGISTRY
= Registry of ROI_BOX_HEAD: ╒════════════════════╤═════════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞════════════════════╪═════════════════════════════════════════════════════════════════════╡ │ FastRCNNConvFCHead │ <class 'detectron2.modeling.roi_heads.box_head.FastRCNNConvFCHead'> │ ╘════════════════════╧═════════════════════════════════════════════════════════════════════╛¶ Registry for box heads, which make box predictions from per-region features.
The registered object will be called with obj(cfg, input_shape).
-
detectron2.modeling.
ROI_MASK_HEAD_REGISTRY
= Registry of ROI_MASK_HEAD: ╒══════════════════════════╤════════════════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞══════════════════════════╪════════════════════════════════════════════════════════════════════════════╡ │ MaskRCNNConvUpsampleHead │ <class 'detectron2.modeling.roi_heads.mask_head.MaskRCNNConvUpsampleHead'> │ ╘══════════════════════════╧════════════════════════════════════════════════════════════════════════════╛¶ Registry for mask heads, which predicts instance masks given per-region features.
The registered object will be called with obj(cfg, input_shape).
-
detectron2.modeling.
ROI_KEYPOINT_HEAD_REGISTRY
= Registry of ROI_KEYPOINT_HEAD: ╒═════════════════════════════╤═══════════════════════════════════════════════════════════════════════════════════╕ │ Names │ Objects │ ╞═════════════════════════════╪═══════════════════════════════════════════════════════════════════════════════════╡ │ KRCNNConvDeconvUpsampleHead │ <class 'detectron2.modeling.roi_heads.keypoint_head.KRCNNConvDeconvUpsampleHead'> │ ╘═════════════════════════════╧═══════════════════════════════════════════════════════════════════════════════════╛¶ Registry for keypoint heads, which make keypoint predictions from per-region features.
The registered object will be called with obj(cfg, input_shape).
detectron2.solver¶
-
detectron2.solver.
build_lr_scheduler
(cfg: detectron2.config.CfgNode, optimizer: torch.optim.optimizer.Optimizer) → torch.optim.lr_scheduler._LRScheduler[source]¶ Build a LR scheduler from config.
-
detectron2.solver.
build_optimizer
(cfg: detectron2.config.CfgNode, model: torch.nn.Module) → torch.optim.optimizer.Optimizer[source]¶ Build an optimizer from config.
-
detectron2.solver.
get_default_optimizer_params
(model: torch.nn.Module, base_lr: Optional[float] = None, weight_decay: Optional[float] = None, weight_decay_norm: Optional[float] = None, bias_lr_factor: Optional[float] = 1.0, weight_decay_bias: Optional[float] = None, overrides: Optional[Dict[str, Dict[str, float]]] = None)[source]¶ Get default param list for optimizer, with support for a few types of overrides. If no overrides needed, this is equivalent to model.parameters().
- Parameters
base_lr – lr for every group by default. Can be omitted to use the one in optimizer.
weight_decay – weight decay for every group by default. Can be omitted to use the one in optimizer.
weight_decay_norm – override weight decay for params in normalization layers
bias_lr_factor – multiplier of lr for bias parameters.
weight_decay_bias – override weight decay for bias parameters
overrides – if not None, provides values for optimizer hyperparameters (LR, weight decay) for module parameters with a given name; e.g.
{"embedding": {"lr": 0.01, "weight_decay": 0.1}}
will set the LR and weight decay values for all module parameters named embedding.
For common detection models, weight_decay_norm is the only option that needs to be set. bias_lr_factor and weight_decay_bias are legacy settings from Detectron1 that are not found useful.
Example:
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), lr=0.01, weight_decay=1e-4, momentum=0.9)
-
class
detectron2.solver.
LRMultiplier
(optimizer: torch.optim.optimizer.Optimizer, multiplier: fvcore.common.param_scheduler.ParamScheduler, max_iter: int, last_iter: int = - 1)[source]¶ Bases:
torch.optim.lr_scheduler._LRScheduler
A LRScheduler which uses fvcore
ParamScheduler
to multiply the learning rate of each param in the optimizer. Every step, the learning rate of each parameter becomes its initial value multiplied by the output of the givenParamScheduler
. The absolute learning rate value of each parameter can be different. This scheduler can be used as long as the relative scale among them does not change during training.
Examples:
LRMultiplier( opt, WarmupParamScheduler( MultiStepParamScheduler( [1, 0.1, 0.01], milestones=[60000, 80000], num_updates=90000, ), 0.001, 100 / 90000 ), max_iter=90000 )
-
__init__
(optimizer: torch.optim.optimizer.Optimizer, multiplier: fvcore.common.param_scheduler.ParamScheduler, max_iter: int, last_iter: int = - 1)[source]¶ - Parameters
optimizer – See torch.optim.lr_scheduler._LRScheduler.
last_iter – See torch.optim.lr_scheduler._LRScheduler. last_iter is the same as last_epoch.
.multiplier – a fvcore ParamScheduler that defines the multiplier on every LR of the optimizer
max_iter – the total number of training iterations
-
-
class
detectron2.solver.
WarmupParamScheduler
(scheduler: fvcore.common.param_scheduler.ParamScheduler, warmup_factor: float, warmup_length: float, warmup_method: str = 'linear')[source]¶ Bases:
fvcore.common.param_scheduler.CompositeParamScheduler
Add an initial warmup stage to another scheduler.
-
__init__
(scheduler: fvcore.common.param_scheduler.ParamScheduler, warmup_factor: float, warmup_length: float, warmup_method: str = 'linear')[source]¶ - Parameters
scheduler – warmup will be added at the beginning of this scheduler
warmup_factor – the factor w.r.t the initial value of scheduler, e.g. 0.001
warmup_length – the relative length (in [0, 1]) of warmup steps w.r.t the entire training, e.g. 0.01
warmup_method – one of “linear” or “constant”
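An illustrative sketch mirroring the LRMultiplier example above: a multi-step schedule with a linear warmup over the first 100 of 90000 iterations:
from fvcore.common.param_scheduler import MultiStepParamScheduler
from detectron2.solver import WarmupParamScheduler

sched = WarmupParamScheduler(
    MultiStepParamScheduler([1, 0.1, 0.01], milestones=[60000, 80000], num_updates=90000),
    warmup_factor=0.001,
    warmup_length=100 / 90000,
    warmup_method="linear",
)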
-
detectron2.structures¶
-
class
detectron2.structures.
Boxes
(tensor: torch.Tensor)¶ Bases:
object
This structure stores a list of boxes as a Nx4 torch.Tensor. It supports some common methods about boxes (area, clip, nonempty, etc), and also behaves like a Tensor (support indexing, to(device), .device, and iteration over all boxes)
-
tensor
¶ float matrix of Nx4. Each row is (x1, y1, x2, y2).
- Type: torch.Tensor
-
__getitem__
(item) → detectron2.structures.Boxes¶ - Parameters
item – int, slice, or a BoolTensor
- Returns
Boxes – Create a new
Boxes
by indexing.
The following usages are allowed:
new_boxes = boxes[3]: return a Boxes which contains only one box.
new_boxes = boxes[2:10]: return a slice of boxes.
new_boxes = boxes[vector], where vector is a torch.BoolTensor with length = len(boxes). Nonzero elements in the vector will be selected.
Note that the returned Boxes might share storage with this Boxes, subject to Pytorch’s indexing semantics.
-
__init__
(tensor: torch.Tensor)¶ - Parameters
tensor (Tensor[float]) – a Nx4 matrix. Each row is (x1, y1, x2, y2).
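A minimal sketch of common Boxes operations; the coordinates and image size are arbitrary:
import torch
from detectron2.structures import Boxes

boxes = Boxes(torch.tensor([[0.0, 0.0, 10.0, 20.0], [5.0, 5.0, 400.0, 300.0]]))
areas = boxes.area()        # per-box areas
boxes.clip((256, 256))      # clip in place to a 256x256 (height, width) image
keep = boxes.nonempty()     # bool vector of non-empty boxes
subset = boxes[keep]        # indexing returns a new Boxes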
-
__iter__
()¶ Yield a box as a Tensor of shape (4,) at a time.
-
area
() → torch.Tensor¶ Computes the area of all the boxes.
- Returns
torch.Tensor – a vector with areas of each box.
-
classmethod
cat
(boxes_list: List[Boxes]) → detectron2.structures.Boxes[source]¶ Concatenates a list of Boxes into a single Boxes
-
clip
(box_size: Tuple[int, int]) → None¶ Clip (in place) the boxes by limiting x coordinates to the range [0, width] and y coordinates to the range [0, height].
- Parameters
box_size (height, width) – The clipping box’s size.
-
clone
() → detectron2.structures.Boxes¶ Clone the Boxes.
- Returns
Boxes
-
property
device
¶
-
get_centers
() → torch.Tensor¶ - Returns
The box centers in a Nx2 array of (x, y).
-
inside_box
(box_size: Tuple[int, int], boundary_threshold: int = 0) → torch.Tensor¶ - Parameters
box_size (height, width) – Size of the reference box.
boundary_threshold (int) – Boxes that extend beyond the reference box boundary by more than boundary_threshold are considered “outside”.
- Returns
a binary vector, indicating whether each box is inside the reference box.
-
nonempty
(threshold: float = 0.0) → torch.Tensor¶ Find boxes that are non-empty. A box is considered empty if either of its sides is no larger than threshold.
- Returns
Tensor – a binary vector which represents whether each box is empty (False) or non-empty (True).
-
scale
(scale_x: float, scale_y: float) → None¶ Scale the box with horizontal and vertical scaling factors
-
to
(device: torch.device)¶
-
-
class
detectron2.structures.
BoxMode
(value)¶ Bases:
enum.IntEnum
Enum of different ways to represent a box.
-
XYXY_ABS
= 0¶
-
XYWH_ABS
= 1¶
-
XYXY_REL
= 2¶
-
XYWH_REL
= 3¶
-
XYWHA_ABS
= 4¶
-
static
convert
(box: Union[List[float], Tuple[float, …], torch.Tensor, numpy.ndarray], from_mode: detectron2.structures.BoxMode, to_mode: detectron2.structures.BoxMode) → Union[List[float], Tuple[float, …], torch.Tensor, numpy.ndarray][source]¶
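An illustrative conversion of a single box from XYWH_ABS to XYXY_ABS (values are arbitrary):
from detectron2.structures import BoxMode

xyxy = BoxMode.convert([10.0, 10.0, 30.0, 40.0], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
# -> [10.0, 10.0, 40.0, 50.0]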
-
-
detectron2.structures.
pairwise_iou
(boxes1: detectron2.structures.Boxes, boxes2: detectron2.structures.Boxes) → torch.Tensor¶ Given two lists of boxes of size N and M, compute the IoU (intersection over union) between all N x M pairs of boxes. The box order must be (xmin, ymin, xmax, ymax).
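A small illustrative pairwise_iou call with arbitrary boxes:
import torch
from detectron2.structures import Boxes, pairwise_iou

b1 = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
b2 = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]]))
iou = pairwise_iou(b1, b2)  # shape (1, 2)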
-
detectron2.structures.
pairwise_ioa
(boxes1: detectron2.structures.Boxes, boxes2: detectron2.structures.Boxes) → torch.Tensor¶ Similar to
pairwise_iou()
but compute the IoA (intersection over boxes2 area).
-
class
detectron2.structures.
ImageList
(tensor: torch.Tensor, image_sizes: List[Tuple[int, int]])¶ Bases:
object
Structure that holds a list of images (of possibly varying sizes) as a single tensor. This works by padding the images to the same size, and storing in a field the original sizes of each image
-
image_sizes
¶ each tuple is (h, w). During tracing, it becomes list[Tensor] instead.
-
__getitem__
(idx) → torch.Tensor¶ Access the individual image in its original size.
- Parameters
idx – int or slice
- Returns
Tensor – an image of shape (H, W) or (C_1, …, C_K, H, W) where K >= 1
-
__init__
(tensor: torch.Tensor, image_sizes: List[Tuple[int, int]])¶
-
property
device
¶
-
static
from_tensors
(tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0) → detectron2.structures.ImageList[source]¶ - Parameters
tensors – a tuple or list of torch.Tensor, each of shape (Hi, Wi) or (C_1, …, C_K, Hi, Wi) where K >= 1. The Tensors will be padded to the same shape with pad_value.
size_divisibility (int) – If size_divisibility > 0, add padding to ensure the common height and width is divisible by size_divisibility. This depends on the model and many models need a divisibility of 32.
pad_value (float) – value to pad
- Returns
an ImageList.
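An illustrative sketch of batching two images of different sizes with padding to a multiple of 32; the shapes are arbitrary:
import torch
from detectron2.structures import ImageList

imgs = [torch.randn(3, 480, 640), torch.randn(3, 300, 500)]
image_list = ImageList.from_tensors(imgs, size_divisibility=32)
print(image_list.tensor.shape)   # (2, 3, 480, 640) after padding
print(image_list.image_sizes)    # [(480, 640), (300, 500)]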
-
to
(*args: Any, **kwargs: Any) → detectron2.structures.ImageList¶
-
-
class
detectron2.structures.
Instances
(image_size: Tuple[int, int], **kwargs: Any)¶ Bases:
object
This class represents a list of instances in an image. It stores the attributes of instances (e.g., boxes, masks, labels, scores) as “fields”. All fields must have the same
__len__
which is the number of instances. All other (non-field) attributes of this class are considered private: they must start with ‘_’ and are not modifiable by a user.
Some basic usage:
Set/get/check a field:
instances.gt_boxes = Boxes(...) print(instances.pred_masks) # a tensor of shape (N, H, W) print('gt_masks' in instances)
len(instances)
returns the number of instances. Indexing:
instances[indices]
will apply the indexing on all the fields and return a new Instances. Typically, indices is an integer vector of indices, or a binary mask of length num_instances:
category_3_detections = instances[instances.pred_classes == 3] confident_detections = instances[instances.scores > 0.9]
-
__getitem__
(item: Union[int, slice, torch.BoolTensor]) → detectron2.structures.Instances¶ - Parameters
item – an index-like object and will be used to index all the fields.
- Returns
If item is a string, return the data in the corresponding field. Otherwise, returns an Instances where all fields are indexed by item.
-
__init__
(image_size: Tuple[int, int], **kwargs: Any)¶ - Parameters
image_size (height, width) – the spatial size of the image.
kwargs – fields to add to this Instances.
-
static
cat
(instance_lists: List[Instances]) → detectron2.structures.Instances[source]¶
-
get_fields
() → Dict[str, Any]¶ - Returns
dict – a dict which maps names (str) to data of the fields
Modifying the returned dict will modify this instance.
-
property
image_size
¶ Returns: tuple: height, width
-
set
(name: str, value: Any) → None¶ Set the field named name to value. The length of value must be the number of instances, and must agree with other existing fields in this object.
-
to
(*args: Any, **kwargs: Any) → detectron2.structures.Instances¶ - Returns
Instances – all fields are called with a to(device), if the field has this method.
-
class
detectron2.structures.
Keypoints
(keypoints: Union[torch.Tensor, numpy.ndarray, List[List[float]]])¶ Bases:
object
Stores keypoint annotation data. GT Instances have a gt_keypoints property containing the x,y location and visibility flag of each keypoint. This tensor has shape (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.
The visibility flag follows the COCO format and must be one of three integers:
v=0: not labeled (in which case x=y=0)
v=1: labeled but not visible
v=2: labeled and visible
-
__getitem__
(item: Union[int, slice, torch.BoolTensor]) → detectron2.structures.Keypoints¶ Create a new Keypoints by indexing on this Keypoints.
The following usages are allowed:
new_kpts = kpts[3]: return a Keypoints which contains only one instance.
new_kpts = kpts[2:10]: return a slice of key points.
new_kpts = kpts[vector], where vector is a torch.ByteTensor with length = len(kpts). Nonzero elements in the vector will be selected.
Note that the returned Keypoints might share storage with this Keypoints, subject to Pytorch’s indexing semantics.
-
__init__
(keypoints: Union[torch.Tensor, numpy.ndarray, List[List[float]]])¶ - Parameters
keypoints – A Tensor, numpy array, or list of the x, y, and visibility of each keypoint. The shape should be (N, K, 3) where N is the number of instances, and K is the number of keypoints per instance.
-
property
device
¶
-
to
(*args: Any, **kwargs: Any) → detectron2.structures.Keypoints¶
-
to_heatmap
(boxes: torch.Tensor, heatmap_size: int) → torch.Tensor¶ Convert keypoint annotations to a heatmap of one-hot labels for training, as described in Mask R-CNN.
- Parameters
boxes – Nx4 tensor, the boxes to draw the keypoints to
- Returns
heatmaps – A tensor of shape (N, K); each element is an integer spatial label in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
valid – A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
-
detectron2.structures.
heatmaps_to_keypoints
(maps: torch.Tensor, rois: torch.Tensor) → torch.Tensor¶ Extract predicted keypoint locations from heatmaps.
- Parameters
maps (Tensor) – (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for each ROI and each keypoint.
rois (Tensor) – (#ROIs, 4). The box of each ROI.
- Returns
Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to (x, y, logit, score) for each keypoint.
When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate, we maintain consistency with
Keypoints.to_heatmap()
by using the conversion from Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
-
class
detectron2.structures.
BitMasks
(tensor: Union[torch.Tensor, numpy.ndarray])¶ Bases:
object
This class stores the segmentation masks for all objects in one image, in the form of bitmaps.
-
tensor
¶ bool Tensor of N,H,W, representing N instances in the image.
-
__getitem__
(item: Union[int, slice, torch.BoolTensor]) → detectron2.structures.BitMasks¶ - Returns
BitMasks – Create a new
BitMasks
by indexing.
The following usages are allowed:
new_masks = masks[3]: return a BitMasks which contains only one mask.
new_masks = masks[2:10]: return a slice of masks.
new_masks = masks[vector], where vector is a torch.BoolTensor with length = len(masks). Nonzero elements in the vector will be selected.
Note that the returned object might share storage with this object, subject to Pytorch’s indexing semantics.
-
__init__
(tensor: Union[torch.Tensor, numpy.ndarray])¶ - Parameters
tensor – bool Tensor of N,H,W, representing N instances in the image.
-
static
cat
(bitmasks_list: List[BitMasks]) → detectron2.structures.BitMasks[source]¶ Concatenates a list of BitMasks into a single BitMasks
-
crop_and_resize
(boxes: torch.Tensor, mask_size: int) → torch.Tensor¶ Crop each bitmask by the given box, and resize results to (mask_size, mask_size). This can be used to prepare training targets for Mask R-CNN. It has less reconstruction error compared to rasterization with polygons. However we observe no difference in accuracy, but BitMasks requires more memory to store all the masks.
- Parameters
boxes (Tensor) – Nx4 tensor storing the boxes for each mask
mask_size (int) – the size of the rasterized mask.
- Returns
Tensor – A bool tensor of shape (N, mask_size, mask_size), where N is the number of predicted boxes for this image.
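An illustrative sketch of rasterizing 28x28 training targets from full-image bitmasks; the mask tensor and boxes are arbitrary:
import torch
from detectron2.structures import BitMasks

masks = BitMasks(torch.zeros(2, 128, 128, dtype=torch.bool))
boxes = torch.tensor([[0.0, 0.0, 64.0, 64.0], [10.0, 10.0, 90.0, 70.0]])
targets = masks.crop_and_resize(boxes, mask_size=28)  # bool tensor of shape (2, 28, 28)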
-
property
device
¶
-
static
from_polygon_masks
(polygon_masks: Union[PolygonMasks, List[List[numpy.ndarray]]], height: int, width: int) → detectron2.structures.BitMasks[source]¶ - Parameters
polygon_masks (list[list[ndarray]] or PolygonMasks) –
height (int) –
width (int) –
-
static
from_roi_masks
(roi_masks: detectron2.structures.ROIMasks, height: int, width: int) → detectron2.structures.BitMasks[source]¶
-
get_bounding_boxes
() → detectron2.structures.Boxes¶ - Returns
Boxes – tight bounding boxes around bitmasks. If a mask is empty, its bounding box will be all zero.
-
nonempty
() → torch.Tensor¶ Find masks that are non-empty.
- Returns
Tensor – a BoolTensor which represents whether each mask is empty (False) or non-empty (True).
-
to
(*args: Any, **kwargs: Any) → detectron2.structures.BitMasks¶
-
-
class
detectron2.structures.
PolygonMasks
(polygons: List[List[Union[torch.Tensor, numpy.ndarray]]])¶ Bases:
object
This class stores the segmentation masks for all objects in one image, in the form of polygons.
-
polygons
¶ list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
-
__getitem__
(item: Union[int, slice, List[int], torch.BoolTensor]) → detectron2.structures.PolygonMasks¶ Support indexing over the instances and return a PolygonMasks object. item can be:
An integer. It will return an object with only one instance.
A slice. It will return an object with the selected instances.
A list[int]. It will return an object with the selected instances, corresponding to the indices in the list.
A vector mask of type BoolTensor, whose length is num_instances. It will return an object with the instances whose mask is nonzero.
-
__init__
(polygons: List[List[Union[torch.Tensor, numpy.ndarray]]])¶ - Parameters
polygons (list[list[np.ndarray]]) – The first level of the list corresponds to individual instances, the second level to all the polygons that compose the instance, and the third level to the polygon coordinates. The third level array should have the format of [x0, y0, x1, y1, …, xn, yn] (n >= 3).
-
__iter__
() → Iterator[List[numpy.ndarray]]¶ - Yields
list[ndarray] – the polygons for one instance. Each ndarray is a float64 vector representing a polygon.
-
area
()¶ Computes area of the mask. Only works with Polygons, using the shoelace formula: https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
- Returns
Tensor – a vector, area for each instance
-
static
cat
(polymasks_list: List[PolygonMasks]) → detectron2.structures.PolygonMasks[source]¶ Concatenates a list of PolygonMasks into a single PolygonMasks
- Parameters
polymasks_list (list[PolygonMasks]) –
- Returns
PolygonMasks – the concatenated PolygonMasks
-
crop_and_resize
(boxes: torch.Tensor, mask_size: int) → torch.Tensor¶ Crop each mask by the given box, and resize results to (mask_size, mask_size). This can be used to prepare training targets for Mask R-CNN.
- Parameters
boxes (Tensor) – Nx4 tensor storing the boxes for each mask
mask_size (int) – the size of the rasterized mask.
- Returns
Tensor – A bool tensor of shape (N, mask_size, mask_size), where N is the number of predicted boxes for this image.
-
property
device
¶
-
get_bounding_boxes
() → detectron2.structures.Boxes¶ - Returns
Boxes – tight bounding boxes around polygon masks.
-
nonempty
() → torch.Tensor¶ Find masks that are non-empty.
- Returns
Tensor – a BoolTensor which represents whether each mask is empty (False) or not (True).
-
to
(*args: Any, **kwargs: Any) → detectron2.structures.PolygonMasks¶
-
-
detectron2.structures.
polygons_to_bitmask
(polygons: List[numpy.ndarray], height: int, width: int) → numpy.ndarray¶
-
class
detectron2.structures.
ROIMasks
(tensor: torch.Tensor)¶ Bases:
object
Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given, full-image bitmask can be obtained by “pasting” the mask on the region defined by the corresponding ROI box.
-
__getitem__
(item) → detectron2.structures.ROIMasks¶ - Returns
ROIMasks – Create a new
ROIMasks
by indexing.
The following usages are allowed:
new_masks = masks[2:10]: return a slice of masks.
new_masks = masks[vector], where vector is a torch.BoolTensor with length = len(masks). Nonzero elements in the vector will be selected.
Note that the returned object might share storage with this object, subject to Pytorch’s indexing semantics.
-
__init__
(tensor: torch.Tensor)¶ - Parameters
tensor – (N, M, M) mask tensor that defines the mask within each ROI.
-
property
device
¶
-
to
(device: torch.device) → detectron2.structures.ROIMasks¶
-
to_bitmasks
(boxes: torch.Tensor, height, width, threshold=0.5)¶ Args: see documentation of paste_masks_in_image().
-
-
class
detectron2.structures.
RotatedBoxes
(tensor: torch.Tensor)¶ Bases:
detectron2.structures.Boxes
This structure stores a list of rotated boxes as a Nx5 torch.Tensor. It supports some common methods about boxes (area, clip, nonempty, etc), and also behaves like a Tensor (support indexing, to(device), .device, and iteration over all boxes)
-
__getitem__
(item) → detectron2.structures.RotatedBoxes¶ - Returns
RotatedBoxes – Create a new
RotatedBoxes
by indexing.
The following usages are allowed:
new_boxes = boxes[3]: return a RotatedBoxes which contains only one box.
new_boxes = boxes[2:10]: return a slice of boxes.
new_boxes = boxes[vector], where vector is a torch.ByteTensor with length = len(boxes). Nonzero elements in the vector will be selected.
Note that the returned RotatedBoxes might share storage with this RotatedBoxes, subject to Pytorch’s indexing semantics.
-
__init__
(tensor: torch.Tensor)¶ - Parameters
tensor (Tensor[float]) – a Nx5 matrix. Each row is (x_center, y_center, width, height, angle), in which angle is represented in degrees. While there’s no strict range restriction for it, the recommended principal range is between [-180, 180) degrees.
Assume we have a horizontal box B = (x_center, y_center, width, height), where width is along the x-axis and height is along the y-axis. The rotated box B_rot (x_center, y_center, width, height, angle) can be seen as:
When angle == 0: B_rot == B
When angle > 0: B_rot is obtained by rotating B w.r.t its center by \(|angle|\) degrees CCW;
When angle < 0: B_rot is obtained by rotating B w.r.t its center by \(|angle|\) degrees CW.
Mathematically, since the right-handed coordinate system for image space is (y, x), where y is top->down and x is left->right, the 4 vertices of the rotated rectangle \((yr_i, xr_i)\) (i = 1, 2, 3, 4) can be obtained from the vertices of the horizontal rectangle \((y_i, x_i)\) (i = 1, 2, 3, 4) in the following way (\(\theta = angle*\pi/180\) is the angle in radians, \((y_c, x_c)\) is the center of the rectangle):
\[ \begin{align}\begin{aligned}yr_i = \cos(\theta) (y_i - y_c) - \sin(\theta) (x_i - x_c) + y_c,\\xr_i = \sin(\theta) (y_i - y_c) + \cos(\theta) (x_i - x_c) + x_c,\end{aligned}\end{align} \]which is the standard rigid-body rotation transformation.
Intuitively, the angle is (1) the rotation angle from y-axis in image space to the height vector (top->down in the box’s local coordinate system) of the box in CCW, and (2) the rotation angle from x-axis in image space to the width vector (left->right in the box’s local coordinate system) of the box in CCW.
More intuitively, consider the following horizontal box ABCD represented in (x1, y1, x2, y2): (3, 2, 7, 4), covering the [3, 7] x [2, 4] region of the continuous coordinate system which looks like this:
O--------> x | | A---B | | | | D---C | v y
Note that each capital letter represents one 0-dimensional geometric point instead of a ‘square pixel’ here.
In the example above, using (x, y) to represent a point we have:
\[O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4)\]We name vector AB = vector DC as the width vector in box’s local coordinate system, and vector AD = vector BC as the height vector in box’s local coordinate system. Initially, when angle = 0 degree, they’re aligned with the positive directions of x-axis and y-axis in the image space, respectively.
For better illustration, we denote the center of the box as E,
O--------> x | | A---B | | E | | D---C | v y
where the center E = ((3+7)/2, (2+4)/2) = (5, 3).
Also,
\[width = |AB| = |CD| = 7 - 3 = 4, height = |AD| = |BC| = 4 - 2 = 2.\]Therefore, the corresponding representation for the same shape in rotated box in (x_center, y_center, width, height, angle) format is:
(5, 3, 4, 2, 0),
Now, let’s consider (5, 3, 4, 2, 90), which is rotated by 90 degrees CCW (counter-clockwise) by definition. It looks like this:
O--------> x | B-C | | | | |E| | | | | A-D v y
The center E is still located at the same point (5, 3), while the vertices ABCD are rotated by 90 degrees CCW with regard to E: A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5)
Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to vector AD or vector BC (the top->down height vector in box’s local coordinate system), or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right width vector in box’s local coordinate system).
\[width = |AB| = |CD| = 5 - 1 = 4, height = |AD| = |BC| = 6 - 4 = 2.\]Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise) by definition? It looks like this:
O--------> x | D-A | | | | |E| | | | | C-B v y
The center E is still located at the same point (5, 3), while the vertices ABCD are rotated by 90 degrees CW with regard to E: A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1)
\[width = |AB| = |CD| = 5 - 1 = 4, height = |AD| = |BC| = 6 - 4 = 2.\]This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU will be 1. However, these two will generate different RoI Pooling results and should not be treated as an identical box.
On the other hand, it’s easy to see that (X, Y, W, H, A) is identical to (X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is equivalent to rotating the same shape 90 degrees CW.
We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180):
O--------> x | | C---D | | E | | B---A | v y
\[ \begin{align}\begin{aligned}A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2),\\width = |AB| = |CD| = 7 - 3 = 4, height = |AD| = |BC| = 4 - 2 = 2.\end{aligned}\end{align} \]Finally, this is a very inaccurate (heavily quantized) illustration of how (5, 3, 4, 2, 60) looks like in case anyone wonders:
O--------> x | B | / C | /E / | A / | `D v y
It’s still a rectangle with center of (5, 3), width of 4 and height of 2, but its angle (and thus orientation) is somewhere between (5, 3, 4, 2, 0) and (5, 3, 4, 2, 90).
-
__iter__
()¶ Yield a box as a Tensor of shape (5,) at a time.
-
area
() → torch.Tensor¶ Computes the area of all the boxes.
- Returns
torch.Tensor – a vector with areas of each box.
-
classmethod
cat
(boxes_list: List[RotatedBoxes]) → detectron2.structures.RotatedBoxes[source]¶ Concatenates a list of RotatedBoxes into a single RotatedBoxes
- Parameters
boxes_list (list[RotatedBoxes]) –
- Returns
RotatedBoxes – the concatenated RotatedBoxes
-
clip
(box_size: Tuple[int, int], clip_angle_threshold: float = 1.0) → None¶ Clip (in place) the boxes by limiting x coordinates to the range [0, width] and y coordinates to the range [0, height].
For RRPN: Only clip boxes that are almost horizontal with a tolerance of clip_angle_threshold to maintain backward compatibility.
Rotated boxes beyond this threshold are not clipped for two reasons:
There are potentially multiple ways to clip a rotated box to make it fit within the image.
It’s tricky to make the entire rectangular box fit within the image and still be able to not leave out pixels of interest.
Therefore we rely on ops like RoIAlignRotated to safely handle this.
- Parameters
box_size (height, width) – The clipping box’s size.
clip_angle_threshold – Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees), we do the clipping as horizontal boxes.
-
clone
() → detectron2.structures.RotatedBoxes¶ Clone the RotatedBoxes.
- Returns
RotatedBoxes
-
property
device
¶
-
get_centers
() → torch.Tensor¶ - Returns
The box centers in a Nx2 array of (x, y).
-
inside_box
(box_size: Tuple[int, int], boundary_threshold: int = 0) → torch.Tensor¶ - Parameters
box_size (height, width) – Size of the reference box covering [0, width] x [0, height]
boundary_threshold (int) – Boxes that extend beyond the reference box boundary by more than boundary_threshold are considered “outside”.
For RRPN, it might not be necessary to call this function since it’s common for rotated box to extend to outside of the image boundaries (the clip function only clips the near-horizontal boxes)
- Returns
a binary vector, indicating whether each box is inside the reference box.
-
nonempty
(threshold: float = 0.0) → torch.Tensor¶ Find boxes that are non-empty. A box is considered empty if either of its sides is no larger than threshold.
- Returns
Tensor – a binary vector which represents whether each box is empty (False) or non-empty (True).
-
scale
(scale_x: float, scale_y: float) → None¶ Scale the rotated box with horizontal and vertical scaling factors. Note: when scale_factor_x != scale_factor_y, the rotated box does not preserve the rectangular shape when the angle is not a multiple of 90 degrees under resize transformation. Instead, the shape is a parallelogram (that has skew). Here we make an approximation by fitting a rotated rectangle to the parallelogram.
-
to
(device: torch.device)¶
-
-
detectron2.structures.
pairwise_iou_rotated
(boxes1: detectron2.structures.RotatedBoxes, boxes2: detectron2.structures.RotatedBoxes) → None¶ Given two lists of rotated boxes of size N and M, compute the IoU (intersection over union) between all N x M pairs of boxes. The box order must be (x_center, y_center, width, height, angle).
- Parameters
boxes1 (RotatedBoxes) – two RotatedBoxes. Contains N & M rotated boxes, respectively.
boxes2 (RotatedBoxes) – two RotatedBoxes. Contains N & M rotated boxes, respectively.
- Returns
Tensor – IoU, sized [N,M].
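A small illustrative call using the (5, 3, 4, 2, ±90) boxes from the RotatedBoxes discussion above, which cover the same region and therefore have IoU 1:
import torch
from detectron2.structures import RotatedBoxes, pairwise_iou_rotated

r1 = RotatedBoxes(torch.tensor([[5.0, 3.0, 4.0, 2.0, 90.0]]))
r2 = RotatedBoxes(torch.tensor([[5.0, 3.0, 4.0, 2.0, -90.0]]))
iou = pairwise_iou_rotated(r1, r2)  # shape (1, 1), approximately 1.0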
detectron2.utils¶
detectron2.utils.colormap module¶
An awesome colormap for really neat visualizations. Copied from Detectron, and removed gray colors.
detectron2.utils.comm module¶
This file contains primitives for multi-gpu communication. This is useful when doing distributed training.
-
detectron2.utils.comm.
get_local_rank
() → int[source]¶ - Returns
The rank of the current process within the local (per-machine) process group.
-
detectron2.utils.comm.
get_local_size
() → int[source]¶ - Returns
The size of the per-machine process group, i.e. the number of processes per machine.
-
detectron2.utils.comm.
synchronize
()[source]¶ Helper function to synchronize (barrier) among all processes when using distributed training
-
detectron2.utils.comm.
all_gather
(data, group=None)[source]¶ Run all_gather on arbitrary picklable data (not necessarily tensors).
- Parameters
data – any picklable object
group – a torch process group. By default, will use a group which contains all ranks on gloo backend.
- Returns
list[data] – list of data gathered from each rank
-
detectron2.utils.comm.
gather
(data, dst=0, group=None)[source]¶ Run gather on arbitrary picklable data (not necessarily tensors).
- Parameters
data – any picklable object
dst (int) – destination rank
group – a torch process group. By default, will use a group which contains all ranks on gloo backend.
- Returns
list[data] – on dst, a list of data gathered from each rank. Otherwise, an empty list.
-
detectron2.utils.comm.
shared_random_seed
()[source]¶ - Returns
int – a random number that is the same across all workers. If workers need a shared RNG, they can use this shared seed to create one.
All workers must call this function, otherwise it will deadlock.
detectron2.utils.events module¶
-
detectron2.utils.events.
get_event_storage
()[source]¶ - Returns
The
EventStorage
object that’s currently being used. Throws an error if noEventStorage
is currently enabled.
-
class
detectron2.utils.events.
JSONWriter
(json_file, window_size=20)[source]¶ Bases:
detectron2.utils.events.EventWriter
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Examples parsing such a json file:
$ cat metrics.json | jq -s '.[0:2]' [ { "data_time": 0.008433341979980469, "iteration": 19, "loss": 1.9228371381759644, "loss_box_reg": 0.050025828182697296, "loss_classifier": 0.5316952466964722, "loss_mask": 0.7236229181289673, "loss_rpn_box": 0.0856662318110466, "loss_rpn_cls": 0.48198649287223816, "lr": 0.007173333333333333, "time": 0.25401854515075684 }, { "data_time": 0.007216215133666992, "iteration": 39, "loss": 1.282649278640747, "loss_box_reg": 0.06222952902317047, "loss_classifier": 0.30682939291000366, "loss_mask": 0.6970193982124329, "loss_rpn_box": 0.038663312792778015, "loss_rpn_cls": 0.1471673548221588, "lr": 0.007706666666666667, "time": 0.2490077018737793 } ] $ cat metrics.json | jq '.loss_mask' 0.7126231789588928 0.689423680305481 0.6776131987571716 ...
-
class
detectron2.utils.events.
TensorboardXWriter
(log_dir: str, window_size: int = 20, **kwargs)[source]¶ Bases:
detectron2.utils.events.EventWriter
Write all scalars to a tensorboard file.
-
class
detectron2.utils.events.
CommonMetricPrinter
(max_iter: Optional[int] = None, window_size: int = 20)[source]¶ Bases:
detectron2.utils.events.EventWriter
Print common metrics to the terminal, including iteration time, ETA, memory, all losses, and the learning rate. It also applies smoothing using a window of 20 elements.
It’s meant to print common metrics in common ways. To print something in more customized ways, please implement a similar printer by yourself.
-
class
detectron2.utils.events.
EventStorage
(start_iter=0)[source]¶ Bases:
object
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
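A minimal sketch of using EventStorage as a context manager so that get_event_storage() works inside the block, mirroring how a training loop typically uses it; the scalar name and value are arbitrary:
from detectron2.utils.events import EventStorage

with EventStorage(start_iter=0) as storage:
    storage.put_scalar("loss", 0.5)
    storage.step()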
-
put_image
(img_name, img_tensor)[source]¶ Add an img_tensor associated with img_name, to be shown on tensorboard.
- Parameters
img_name (str) – The name of the image to put into tensorboard.
img_tensor (torch.Tensor or numpy.array) – A uint8 or float Tensor of shape [channel, height, width] where channel is 3. The image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard.
-
put_scalar
(name, value, smoothing_hint=True)[source]¶ Add a scalar value to the HistoryBuffer associated with name.
- Parameters
smoothing_hint (bool) –
a ‘hint’ on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through
EventStorage.smoothing_hints()
. A writer may ignore the hint and apply custom smoothing rule.It defaults to True because most scalars we save need to be smoothed to provide any useful signal.
-
put_scalars
(*, smoothing_hint=True, **kwargs)[source]¶ Put multiple scalars from keyword arguments.
Examples
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
-
put_histogram
(hist_name, hist_tensor, bins=1000)[source]¶ Create a histogram from a tensor.
- Parameters
hist_name (str) – The name of the histogram to put into tensorboard.
hist_tensor (torch.Tensor) – A Tensor of arbitrary shape to be converted into a histogram.
bins (int) – Number of histogram bins.
-
latest
()[source]¶ - Returns
dict[str -> (float, int)] –
- mapping from the name of each scalar to its most recent value and the iteration number at which it was added.
-
latest_with_smoothing_hint
(window_size=20)[source]¶ Similar to
latest()
, but the returned values are either the un-smoothed original latest value, or a median of the given window_size, depending on whether the smoothing_hint is True. This provides a default behavior that other writers can use.
-
smoothing_hints
()[source]¶ - Returns
dict[name -> bool] –
- the user-provided hint on whether the scalar
is noisy and needs smoothing.
-
step
()[source]¶ Users should either (1) call this function to increment storage.iter when needed, or (2) set storage.iter to the correct iteration number before each iteration.
The storage will then be able to associate the new data with an iteration number.
-
property
iter
¶ Returns: int: The current iteration number. When used together with a trainer,
this is ensured to be the same as trainer.iter.
-
property
iteration
¶
-
name_scope
(name)[source]¶ - Yields
A context within which all the events added to this storage will be prefixed by the name scope.
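Putting the methods above together, here is a minimal sketch of using EventStorage directly, outside the default trainer (it assumes, as in detectron2's own training loop, that the storage is entered as a context manager so that get_event_storage() can find it):
from detectron2.utils.events import EventStorage

with EventStorage(start_iter=0) as storage:
    for it in range(3):
        storage.put_scalar("loss", 1.0 / (it + 1))
        storage.step()  # associate subsequent data with the next iteration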
-
detectron2.utils.logger module¶
-
detectron2.utils.logger.
setup_logger
(output=None, distributed_rank=0, *, color=True, name='detectron2', abbrev_name=None)[source]¶ Initialize the detectron2 logger and set its verbosity level to “DEBUG”.
- Parameters
output (str) – a file name or a directory to save logs. If None, no log file will be saved. If it ends with “.txt” or “.log”, it is assumed to be a file name. Otherwise, logs will be saved to output/log.txt.
name (str) – the root module name of this logger
abbrev_name (str) – an abbreviation of the module, to avoid long names in logs. Set to “” to not log the root module in logs. By default, will abbreviate “detectron2” to “d2” and leave other modules unchanged.
- Returns
logging.Logger – a logger
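A brief usage sketch (the output directory and logger name are illustrative):
from detectron2.utils.logger import setup_logger

logger = setup_logger(output="./output", name="my_project")  # also writes ./output/log.txt
logger.info("starting training")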
-
detectron2.utils.logger.
log_first_n
(lvl, msg, n=1, *, name=None, key='caller')[source]¶ Log only for the first n times.
- Parameters
lvl (int) – the logging level
msg (str) –
n (int) –
name (str) – name of the logger to use. Will use the caller’s module by default.
key (str or tuple[str]) – the string(s) can be one of “caller” or “message”, which defines how to identify duplicated logs. For example, if called with n=1, key=”caller”, this function will only log the first call from the same caller, regardless of the message content. If called with n=1, key=”message”, this function will log the same content only once, even if it is called from different places. If called with n=1, key=(“caller”, “message”), this function will skip logging only if the same caller has already logged the same message before.
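A brief sketch of de-duplicated logging with log_first_n (the message is illustrative):
import logging
from detectron2.utils.logger import log_first_n

for _ in range(100):
    # only the first 5 occurrences of this exact message are logged
    log_first_n(logging.WARNING, "data loading is slow", n=5, key="message")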
detectron2.utils.registry module¶
-
class
detectron2.utils.registry.
Registry
(*args, **kwds)[source]¶ Bases:
collections.abc.Iterable
,typing.Generic
The registry that provides name -> object mapping, to support third-party users’ custom modules.
To create a registry (e.g. a backbone registry):
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
@BACKBONE_REGISTRY.register()
class MyBackbone():
    ...
Or:
BACKBONE_REGISTRY.register(MyBackbone)
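Registered objects can later be looked up by name, e.g. when a config specifies which backbone to build. A brief sketch using the registry created above:
# retrieve the class registered under the name "MyBackbone"
backbone_cls = BACKBONE_REGISTRY.get("MyBackbone")
backbone = backbone_cls()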
detectron2.utils.memory module¶
-
detectron2.utils.memory.
retry_if_cuda_oom
(func)[source]¶ Makes a function retry itself after encountering pytorch’s CUDA OOM error. It will first retry after calling torch.cuda.empty_cache().
If that still fails, it will then retry by trying to convert inputs to CPUs. In this case, it expects the function to dispatch to a CPU implementation. The return values may become CPU tensors as well, and it is the user’s responsibility to convert them back to CUDA tensors if needed.
- Parameters
func – a stateless callable that takes tensor-like objects as arguments
- Returns
a callable which retries func if OOM is encountered.
Examples:
output = retry_if_cuda_oom(some_torch_function)(input1, input2) # output may be on CPU even if inputs are on GPU
Note
When converting inputs to CPU, it will only look at each argument and check if it has .device and .to for conversion. Nested structures of tensors are not supported.
Since the function might be called more than once, it has to be stateless.
detectron2.utils.analysis module¶
-
detectron2.utils.analysis.
activation_count_operators
(model: torch.nn.Module, inputs: list, **kwargs) → DefaultDict[str, float][source]¶ Implement operator-level activation counting using jit. This is a wrapper of fvcore.nn.activation_count that supports standard detection models in detectron2.
Note
The function runs the input through the model to compute activations. The activations of a detection model are often input-dependent; for example, the activations of the box & mask heads depend on the number of proposals and the number of detected objects.
-
detectron2.utils.analysis.
flop_count_operators
(model: torch.nn.Module, inputs: list) → DefaultDict[str, float][source]¶ Implement operator-level flop counting using jit. This is a wrapper of
fvcore.nn.flop_count()
and adds support for standard detection models in detectron2. Please use FlopCountAnalysis
for more advanced functionalities.
Note
The function runs the input through the model to compute flops. The flops of a detection model are often input-dependent; for example, the flops of the box & mask heads depend on the number of proposals and the number of detected objects. Therefore, flop counting using a single input may not accurately reflect the computation cost of a model. It is recommended to average across a number of inputs.
- Parameters
model – a detectron2 model that takes list[dict] as input.
inputs (list[dict]) – inputs to model, in detectron2’s standard format. Only “image” key will be used.
supported_ops (dict[str, Handle]) – see documentation of
fvcore.nn.flop_count()
- Returns
Counter – Gflop count per operator
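A brief sketch of counting flops for one input (the model and the list[dict] inputs are assumed to be constructed elsewhere, e.g. taken from a dataloader):
from detectron2.utils.analysis import flop_count_operators

# `model` is a detectron2 model, `inputs` is a list[dict] in detectron2's standard format
counts = flop_count_operators(model, inputs)
print(sum(counts.values()), "Gflops (approximate and input-dependent)")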
-
detectron2.utils.analysis.
parameter_count_table
(model: torch.nn.Module, max_depth: int = 3) → str[source]¶ Format the parameter count of the model (and its submodules or parameters) in a nice table. It looks like this:
| name                          | #elements or shape |
|:------------------------------|:-------------------|
| model                         | 37.9M              |
| backbone                      | 31.5M              |
| backbone.fpn_lateral3         | 0.1M               |
| backbone.fpn_lateral3.weight  | (256, 512, 1, 1)   |
| backbone.fpn_lateral3.bias    | (256,)             |
| backbone.fpn_output3          | 0.6M               |
| backbone.fpn_output3.weight   | (256, 256, 3, 3)   |
| backbone.fpn_output3.bias     | (256,)             |
| backbone.fpn_lateral4         | 0.3M               |
| backbone.fpn_lateral4.weight  | (256, 1024, 1, 1)  |
| backbone.fpn_lateral4.bias    | (256,)             |
| backbone.fpn_output4          | 0.6M               |
| backbone.fpn_output4.weight   | (256, 256, 3, 3)   |
| backbone.fpn_output4.bias     | (256,)             |
| backbone.fpn_lateral5         | 0.5M               |
| backbone.fpn_lateral5.weight  | (256, 2048, 1, 1)  |
| backbone.fpn_lateral5.bias    | (256,)             |
| backbone.fpn_output5          | 0.6M               |
| backbone.fpn_output5.weight   | (256, 256, 3, 3)   |
| backbone.fpn_output5.bias     | (256,)             |
| backbone.top_block            | 5.3M               |
| backbone.top_block.p6         | 4.7M               |
| backbone.top_block.p7         | 0.6M               |
| backbone.bottom_up            | 23.5M              |
| backbone.bottom_up.stem       | 9.4K               |
| backbone.bottom_up.res2       | 0.2M               |
| backbone.bottom_up.res3       | 1.2M               |
| backbone.bottom_up.res4       | 7.1M               |
| backbone.bottom_up.res5       | 14.9M              |
| ......                        | .....              |
- Parameters
model – a torch module
max_depth (int) – maximum depth to recursively print submodules or parameters
- Returns
str – the table to be printed
-
detectron2.utils.analysis.
parameter_count
(model: torch.nn.Module) → DefaultDict[str, int][source]¶ Count parameters of a model and its submodules.
- Parameters
model – a torch module
- Returns
dict (str-> int) – the key is either a parameter name or a module name. The value is the number of elements in the parameter, or in all parameters of the module. The key “” corresponds to the total number of parameters of the model.
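A brief usage sketch (any torch module works; the model variable is assumed to exist):
from detectron2.utils.analysis import parameter_count

counts = parameter_count(model)  # `model` is an nn.Module
print(counts[""])                # total number of parameters in the model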
-
class
detectron2.utils.analysis.
FlopCountAnalysis
(model, inputs)[source]¶ Bases:
fvcore.nn.flop_count.FlopCountAnalysis
Same as
fvcore.nn.FlopCountAnalysis
, but supports detectron2 models.
detectron2.utils.visualizer module¶
-
class
detectron2.utils.visualizer.
ColorMode
(value)[source]¶ Bases:
enum.Enum
Enum of different color modes to use for instance visualizations.
-
IMAGE
= 0¶ Picks a random color for every instance and overlays segmentations with low opacity.
-
SEGMENTATION
= 1¶ Let instances of the same category have similar colors (from metadata.thing_colors), and overlay them with high opacity. This draws more attention to the quality of segmentation.
-
IMAGE_BW
= 2¶ Same as IMAGE, but converts all areas without masks to gray-scale. Only available for drawing per-instance mask predictions.
-
-
class
detectron2.utils.visualizer.
VisImage
(img, scale=1.0)[source]¶ Bases:
object
-
__init__
(img, scale=1.0)[source]¶ - Parameters
img (ndarray) – an RGB image of shape (H, W, 3).
scale (float) – scale the input image
-
-
class
detectron2.utils.visualizer.
Visualizer
(img_rgb, metadata=None, scale=1.0, instance_mode=<ColorMode.IMAGE: 0>)[source]¶ Bases:
object
Visualizer that draws data about detection/segmentation on images.
It contains methods like draw_{text,box,circle,line,binary_mask,polygon} that draw primitive objects to images, as well as high-level wrappers like draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict} that draw composite data in some pre-defined style.
Note that the exact visualization style for the high-level wrappers is subject to change. Styles such as color, opacity, label contents, visibility of labels, or even the visibility of objects themselves (e.g. when the object is too small) may change according to different heuristics, as long as the results still look visually reasonable.
To obtain a consistent style, you can implement custom drawing functions with the abovementioned primitive methods instead. If you need more customized visualization styles, you can process the data yourself following their format documented in tutorials (Use Models, Use Custom Datasets). This class does not intend to satisfy everyone’s preference on drawing styles.
This visualizer focuses on high rendering quality rather than performance. It is not designed to be used for real-time applications.
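A typical usage sketch, following the output format described in the Use Models tutorial (the image im and the prediction dict outputs are assumed to come from a predictor; the dataset name is illustrative):
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

v = Visualizer(im, MetadataCatalog.get("coco_2017_val"), scale=1.0)
vis = v.draw_instance_predictions(outputs["instances"].to("cpu"))
result = vis.get_image()  # RGB numpy array with the drawn predictions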
-
__init__
(img_rgb, metadata=None, scale=1.0, instance_mode=<ColorMode.IMAGE: 0>)[source]¶ - Parameters
img_rgb – a numpy array of shape (H, W, C), where H and W correspond to the height and width of the image respectively. C is the number of color channels. The image is required to be in RGB format since that is a requirement of the Matplotlib library. The image is also expected to be in the range [0, 255].
metadata (Metadata) – dataset metadata (e.g. class names and colors)
instance_mode (ColorMode) – defines one of the pre-defined style for drawing instances on an image.
-
draw_instance_predictions
(predictions)[source]¶ Draw instance-level prediction results on an image.
- Parameters
predictions (Instances) – the output of an instance detection/segmentation model. The following fields will be used for drawing: “pred_boxes”, “pred_classes”, “scores”, “pred_masks” (or “pred_masks_rle”).
- Returns
output (VisImage) – image object with visualizations.
-
draw_sem_seg
(sem_seg, area_threshold=None, alpha=0.8)[source]¶ Draw semantic segmentation predictions/labels.
- Parameters
- Returns
output (VisImage) – image object with visualizations.
-
draw_panoptic_seg
(panoptic_seg, segments_info, area_threshold=None, alpha=0.7)[source]¶ Draw panoptic prediction annotations or results.
- Parameters
panoptic_seg (Tensor) – of shape (height, width) where the values are ids for each segment.
segments_info (list[dict] or None) – Describe each segment in panoptic_seg. If it is a
list[dict]
, each dict contains keys “id”, “category_id”. If None, the category id of each pixel is computed by pixel // metadata.label_divisor
.
area_threshold (int) – stuff segments with area less than area_threshold are not drawn.
- Returns
output (VisImage) – image object with visualizations.
-
draw_dataset_dict
(dic)[source]¶ Draw annotations/segmentations in Detectron2 Dataset format.
- Parameters
dic (dict) – annotation/segmentation data of one image, in Detectron2 Dataset format.
- Returns
output (VisImage) – image object with visualizations.
-
overlay_instances
(*, boxes=None, labels=None, masks=None, keypoints=None, assigned_colors=None, alpha=0.5)[source]¶ - Parameters
boxes (Boxes, RotatedBoxes or ndarray) – either a
Boxes
, or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, or a RotatedBoxes
, or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format for the N objects in a single image.
labels (list[str]) – the text to be displayed for each instance.
masks (masks-like object) –
Supported types are:
detectron2.structures.PolygonMasks
,detectron2.structures.BitMasks
list[list[ndarray]]: contains the segmentation masks for all objects in one image. The first level of the list corresponds to individual instances. The second level to all the polygons that compose the instance, and the third level to the polygon coordinates. The third level should have the format of [x0, y0, x1, y1, …, xn, yn] (n >= 3).
list[ndarray]: each ndarray is a binary mask of shape (H, W).
list[dict]: each dict is a COCO-style RLE.
keypoints (Keypoint or array like) – an array-like object of shape (N, K, 3), where N is the number of instances and K is the number of keypoints. The last dimension corresponds to (x, y, visibility or score).
assigned_colors (list[matplotlib.colors]) – a list of colors, where each color corresponds to each mask or box in the image. Refer to ‘matplotlib.colors’ for full list of formats that the colors are accepted in.
- Returns
output (VisImage) – image object with visualizations.
-
overlay_rotated_instances
(boxes=None, labels=None, assigned_colors=None)[source]¶ - Parameters
boxes (ndarray) – an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format for the N objects in a single image.
labels (list[str]) – the text to be displayed for each instance.
assigned_colors (list[matplotlib.colors]) – a list of colors, where each color corresponds to each mask or box in the image. Refer to ‘matplotlib.colors’ for full list of formats that the colors are accepted in.
- Returns
output (VisImage) – image object with visualizations.
-
draw_and_connect_keypoints
(keypoints)[source]¶ Draws keypoints of an instance and follows the rules for keypoint connections to draw lines between appropriate keypoints. This follows color heuristics for line color.
- Parameters
keypoints (Tensor) – a tensor of shape (K, 3), where K is the number of keypoints and the last dimension corresponds to (x, y, probability).
- Returns
output (VisImage) – image object with visualizations.
-
draw_text
(text, position, *, font_size=None, color='g', horizontal_alignment='center', rotation=0)[source]¶ - Parameters
text (str) – class label
position (tuple) – a tuple of the x and y coordinates to place text on image.
font_size (int, optional) – font size of the text. If not provided, a font size proportional to the image width is calculated and used.
color – color of the text. Refer to matplotlib.colors for full list of formats that are accepted.
horizontal_alignment (str) – see matplotlib.text.Text
rotation – rotation angle in degrees CCW
- Returns
output (VisImage) – image object with text drawn.
-
draw_box
(box_coord, alpha=0.5, edge_color='g', line_style='-')[source]¶ - Parameters
box_coord (tuple) – a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 are the coordinates of the image’s top left corner. x1 and y1 are the coordinates of the image’s bottom right corner.
alpha (float) – blending coefficient. Smaller values lead to more transparent masks.
edge_color – color of the outline of the box. Refer to matplotlib.colors for full list of formats that are accepted.
line_style (string) – the string to use to create the outline of the boxes.
- Returns
output (VisImage) – image object with box drawn.
-
draw_rotated_box_with_label
(rotated_box, alpha=0.5, edge_color='g', line_style='-', label=None)[source]¶ Draw a rotated box with label on its top-left corner.
- Parameters
rotated_box (tuple) – a tuple containing (cnt_x, cnt_y, w, h, angle), where cnt_x and cnt_y are the center coordinates of the box. w and h are the width and height of the box. angle represents how many degrees the box is rotated CCW with regard to the 0-degree box.
alpha (float) – blending coefficient. Smaller values lead to more transparent masks.
edge_color – color of the outline of the box. Refer to matplotlib.colors for full list of formats that are accepted.
line_style (string) – the string to use to create the outline of the boxes.
label (string) – label for rotated box. It will not be rendered when set to None.
- Returns
output (VisImage) – image object with box drawn.
-
draw_circle
(circle_coord, color, radius=3)[source]¶ - Parameters
- Returns
output (VisImage) – image object with box drawn.
-
draw_line
(x_data, y_data, color, linestyle='-', linewidth=None)[source]¶ - Parameters
x_data (list[int]) – a list containing x values of all the points being drawn. Length of list should match the length of y_data.
y_data (list[int]) – a list containing y values of all the points being drawn. Length of list should match the length of x_data.
color – color of the line. Refer to matplotlib.colors for a full list of formats that are accepted.
linestyle – style of the line. Refer to matplotlib.lines.Line2D for a full list of formats that are accepted.
linewidth (float or None) – width of the line. When it’s None, a default value will be computed and used.
- Returns
output (VisImage) – image object with line drawn.
-
draw_binary_mask
(binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0)[source]¶ - Parameters
binary_mask (ndarray) – numpy array of shape (H, W), where H is the image height and W is the image width. Each value in the array is either a 0 or 1 value of uint8 type.
color – color of the mask. Refer to matplotlib.colors for a full list of formats that are accepted. If None, will pick a random color.
edge_color – color of the polygon edges. Refer to matplotlib.colors for a full list of formats that are accepted.
text (str) – if not None, it will be drawn at the object’s center of mass.
alpha (float) – blending coefficient. Smaller values lead to more transparent masks.
area_threshold (float) – connected components smaller than this area will not be shown.
- Returns
output (VisImage) – image object with mask drawn.
-
draw_polygon
(segment, color, edge_color=None, alpha=0.5)[source]¶ - Parameters
segment – numpy array of shape Nx2, containing all the points in the polygon.
color – color of the polygon. Refer to matplotlib.colors for a full list of formats that are accepted.
edge_color – color of the polygon edges. Refer to matplotlib.colors for a full list of formats that are accepted. If not provided, a darker shade of the polygon color will be used instead.
alpha (float) – blending coefficient. Smaller values lead to more transparent masks.
- Returns
output (VisImage) – image object with polygon drawn.
-
detectron2.utils.video_visualizer module¶
-
class
detectron2.utils.video_visualizer.
VideoVisualizer
(metadata, instance_mode=<ColorMode.IMAGE: 0>)[source]¶ Bases:
object
-
__init__
(metadata, instance_mode=<ColorMode.IMAGE: 0>)[source]¶ - Parameters
metadata (MetadataCatalog) – image metadata.
-
draw_instance_predictions
(frame, predictions)[source]¶ Draw instance-level prediction results on an image.
- Parameters
frame (ndarray) – an RGB image of shape (H, W, C), in the range [0, 255].
predictions (Instances) – the output of an instance detection/segmentation model. The following fields will be used for drawing: “pred_boxes”, “pred_classes”, “scores”, “pred_masks” (or “pred_masks_rle”).
- Returns
output (VisImage) – image object with visualizations.
-
detectron2.export¶
Related tutorial: Deployment.
-
detectron2.export.
add_export_config
(cfg)[source]¶ Add options needed by caffe2 export.
- Parameters
cfg (CfgNode) – a detectron2 config
- Returns
CfgNode – an updated config with new options that will be used by
Caffe2Tracer
.
-
class
detectron2.export.
Caffe2Model
(predict_net, init_net)[source]¶ Bases:
torch.nn.Module
A wrapper around the traced model in Caffe2’s protobuf format. The exported graph has different inputs/outputs from the original PyTorch model, as explained in
Caffe2Tracer
. This class wraps around the exported graph to simulate the same interface as the original PyTorch model. It also provides functions to save/load models in Caffe2’s format.
Examples:
c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
inputs = [{"image": img_tensor_CHW}]
outputs = c2_model(inputs)
orig_outputs = torch_model(inputs)
-
property
predict_net
¶ the underlying caffe2 predict net
- Type
caffe2.core.Net
-
property
init_net
¶ the underlying caffe2 init net
- Type
caffe2.core.Net
-
save_protobuf
(output_dir)[source]¶ Save the model as caffe2’s protobuf format. It saves the following files:
“model.pb”: definition of the graph. Can be visualized with tools like netron.
“model_init.pb”: model parameters
“model.pbtxt”: human-readable definition of the graph. Not needed for deployment.
- Parameters
output_dir (str) – the output directory to save protobuf files.
-
save_graph
(output_file, inputs=None)[source]¶ Save the graph as SVG format.
- Parameters
output_file (str) – an SVG file
inputs – optional inputs given to the model. If given, the inputs will be used to run the graph to record shape of every tensor. The shape information will be saved together with the graph.
-
static
load_protobuf
(dir)[source]¶ - Parameters
dir (str) – a directory used to save Caffe2Model with
save_protobuf()
. The files “model.pb” and “model_init.pb” are needed.
- Returns
Caffe2Model – the caffe2 model loaded from this directory.
-
__call__
(inputs)[source]¶ An interface that wraps around a Caffe2 model and mimics detectron2’s models’ input/output format. See details about the format at Use Models. This is used to compare the outputs of the caffe2 model with those of its original torch model.
Due to the extra conversion between PyTorch and Caffe2, this method is not meant for benchmarking. Because of the conversion, this method also depends on detectron2 in order to convert to detectron2’s output format.
-
class
detectron2.export.
Caffe2Tracer
(cfg: detectron2.config.CfgNode, model: torch.nn.Module, inputs)[source]¶ Bases:
object
Make a detectron2 model traceable with Caffe2 operators. This class creates a traceable version of a detectron2 model which:
Rewrites parts of the model using ops in Caffe2. Note that some ops do not have a GPU implementation in Caffe2.
Removes post-processing and only produces raw layer outputs.
After making a traceable model, the class provides methods to export such a model to different deployment formats. Exported graphs produced by this class take two input tensors:
(1, C, H, W) float “data” which is an image (usually in [0, 255]). (H, W) often has to be padded to a multiple of 32 (depending on the model architecture).
1x3 float “im_info”, each row of which is (height, width, 1.0). Height and width are true image shapes before padding.
The class currently only supports models using builtin meta architectures. Batch inference is not supported, and contributions are welcome.
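A minimal sketch of the export workflow built from the APIs documented in this section (cfg, model, and the sample inputs are assumed to be prepared as in the Deployment tutorial):
from detectron2.export import add_export_config, Caffe2Tracer

cfg = add_export_config(cfg)                # add caffe2 export options to the config
tracer = Caffe2Tracer(cfg, model, inputs)   # weights must already be loaded into `model`
c2_model = tracer.export_caffe2()           # a Caffe2Model
c2_model.save_protobuf("./caffe2_output")
# alternatively: onnx_model = tracer.export_onnx()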
-
__init__
(cfg: detectron2.config.CfgNode, model: torch.nn.Module, inputs)[source]¶ - Parameters
cfg (CfgNode) – a detectron2 config, with extra export-related options added by
add_export_config()
. It’s used to construct a caffe2-compatible model.
model (nn.Module) – An original PyTorch model. Must be among a few official models in detectron2 that can be converted to become caffe2-compatible automatically. Weights have to be already loaded to this model.
inputs – sample inputs that the given model takes for inference. Will be used to trace the model. For most models, random inputs with no detected objects will not work as they lead to wrong traces.
-
export_caffe2
()[source]¶ Export the model to Caffe2’s protobuf format. The returned object can be saved with its
save_protobuf()
method. The result can be loaded and executed using Caffe2 runtime.
- Returns
-
export_onnx
()[source]¶ Export the model to ONNX format. Note that the exported model contains custom ops only available in caffe2, therefore it cannot be directly executed by other runtimes (such as onnxruntime or TensorRT). Post-processing or transformation passes may be applied to the model to accommodate different runtimes, but we currently do not provide support for them.
- Returns
onnx.ModelProto – an onnx model.
-
class
detectron2.export.
TracingAdapter
(model: torch.nn.Module, inputs, inference_func: Optional[Callable] = None, allow_non_tensor: bool = False)[source]¶ Bases:
torch.nn.Module
A model may take rich input/output format (e.g. dict or custom classes), but torch.jit.trace requires tuple of tensors as input/output. This adapter flattens input/output format of a model so it becomes traceable.
It also records the necessary schema to rebuild model’s inputs/outputs from flattened inputs/outputs.
Example:
outputs = model(inputs)  # inputs/outputs may be rich structure
adapter = TracingAdapter(model, inputs)

# can now trace the model, with adapter.flattened_inputs, or another
# tuple of tensors with the same length and meaning
traced = torch.jit.trace(adapter, adapter.flattened_inputs)

# traced model can only produce flattened outputs (tuple of tensors)
flattened_outputs = traced(*adapter.flattened_inputs)

# adapter knows the schema to convert it back (new_outputs == outputs)
new_outputs = adapter.outputs_schema(flattened_outputs)
-
outputs_schema
: detectron2.export.flatten.Schema = None¶ Schema of the output produced by calling the given model with inputs.
-
__init__
(model: torch.nn.Module, inputs, inference_func: Optional[Callable] = None, allow_non_tensor: bool = False)[source]¶ - Parameters
model – an nn.Module
inputs – An input argument or a tuple of input arguments used to call model. After flattening, it has to only consist of tensors.
inference_func – a callable that takes (model, *inputs), calls the model with inputs, and returns outputs. By default it is
lambda model, *inputs: model(*inputs)
. Can be overridden if you need to call the model differently.
allow_non_tensor – allow inputs/outputs to contain non-tensor objects. This option will filter out non-tensor objects to make the model traceable, but
inputs_schema
/outputs_schema
cannot be used anymore because inputs/outputs cannot be rebuilt from pure tensors. This is useful when you’re only interested in the single trace of execution (e.g. for flop count), but not interested in generalizing the traced graph to new inputs.
-
flattened_inputs
: Tuple[torch.Tensor] = None¶ Flattened version of inputs given to this class’s constructor.
-
inputs_schema
: detectron2.export.flatten.Schema = None¶ Schema of the inputs given to this class’s constructor.
-
forward
(*args: torch.Tensor)[source]¶
-
-
detectron2.export.
scripting_with_instances
(model, fields)[source]¶ Run
torch.jit.script()
on a model that uses the Instances
class. Since attributes of Instances
are “dynamically” added in eager mode, it is difficult for scripting to support them out of the box. This function is made to support scripting a model that uses Instances
. It does the following:
Create a scriptable new_Instances
class which behaves similarly to Instances
, but with all attributes being “static”. The attributes need to be statically declared in the fields
argument.
Register new_Instances
, and force the scripting compiler to use it when trying to compile Instances
.
After this function returns, the process is reverted. Users should be able to script another model using different fields.
Example
Assume that
Instances
in the model consists of two attributes named proposal_boxes
and objectness_logits
with types Boxes
and Tensor
respectively during inference. You can call this function like:
fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
torchscript_model = scripting_with_instances(model, fields)
Note
It only supports models in evaluation mode.
- Parameters
model (nn.Module) – The input model to be exported by scripting.
fields (Dict[str, type]) – Attribute names and corresponding types that
Instances
will use in the model. Note that all attributes used in Instances
need to be added, regardless of whether they are inputs/outputs of the model. Data types not defined in detectron2 are not supported for now.
- Returns
torch.jit.ScriptModule – the model in torchscript format
-
detectron2.export.
dump_torchscript_IR
(model, dir)[source]¶ Dump the IR of a TracedModule/ScriptModule/Function in various formats (code, graph, inlined graph). Useful for debugging.
- Parameters
model (TracedModule/ScriptModule/ScriptFunction) – traced or scripted module
dir (str) – output directory to dump files.
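A brief usage sketch (the traced or scripted module ts_model is assumed to exist; the output directory is illustrative):
from detectron2.export import dump_torchscript_IR

dump_torchscript_IR(ts_model, "./torchscript_dump")  # writes code/graph dumps into the directory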
fvcore documentation¶
Detectron2 depends on utilities in fvcore. We include part of fvcore documentation here for easier reference.
fvcore.nn¶
-
fvcore.nn.
activation_count
(model: torch.nn.Module, inputs: Tuple[Any, …], supported_ops: Optional[Dict[str, Callable[[List[Any], List[Any]], Union[Counter[str], numbers.Number]]]] = None) → Tuple[DefaultDict[str, float], Counter[str]][source]¶ Given a model and an input to the model, compute the total number of activations of the model.
- Parameters
model (nn.Module) – The model to compute activation counts.
inputs (tuple) – Inputs that are passed to model to count activations. Inputs need to be in a tuple.
supported_ops (dict(str,Callable) or None) – provide additional handlers for extra ops, or overwrite the existing handlers for convolution and matmul. The key is operator name and the value is a function that takes (inputs, outputs) of the op.
- Returns
tuple[defaultdict, Counter] –
- A dictionary that records the number of
activations (in millions) for each operation and a Counter that records the number of unsupported operations.
-
class
fvcore.nn.
ActivationCountAnalysis
(model: torch.nn.Module, inputs: Union[torch.Tensor, Tuple[torch.Tensor, …]])[source]¶ Bases:
fvcore.nn.jit_analysis.JitModelAnalysis
Provides access to per-submodule model activation count obtained by tracing a model with pytorch’s jit tracing functionality. By default, comes with standard activation counters for convolutional and dot-product operators.
Handles for additional operators may be added, or the default ones overwritten, using the
.set_op_handle(name, func)
method. See the method documentation for details. Activation counts can be obtained as:
.total(module_name="")
: total activation count for a module
.by_operator(module_name="")
: activation counts for the module, as a Counter over different operator types
.by_module()
: Counter of activation counts for all submodules
.by_module_and_operator()
: dictionary indexed by descendant (submodule name) of Counters over different operator types
An operator is treated as within a module if it is executed inside the module’s
__call__
method. Note that this does not include calls to other methods of the module or explicit calls to module.forward(...)
. Example usage:
>>> import torch.nn as nn >>> import torch >>> class TestModel(nn.Module): ... def __init__(self): ... super().__init__() ... self.fc = nn.Linear(in_features=1000, out_features=10) ... self.conv = nn.Conv2d( ... in_channels=3, out_channels=10, kernel_size=1 ... ) ... self.act = nn.ReLU() ... def forward(self, x): ... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel() >>> inputs = (torch.randn((1,3,10,10)),) >>> acts = ActivationCountAnalysis(model, inputs) >>> acts.total() 1010 >>> acts.total("fc") 10 >>> acts.by_operator() Counter({"conv" : 1000, "addmm" : 10}) >>> acts.by_module() Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0}) >>> acts.by_module_and_operator() {"" : Counter({"conv" : 1000, "addmm" : 10}), "fc" : Counter({"addmm" : 10}), "conv" : Counter({"conv" : 1000}), "act" : Counter() }
-
__init__
(model: torch.nn.Module, inputs: Union[torch.Tensor, Tuple[torch.Tensor, …]]) → None[source]¶ - Parameters
model – The model to analyze
inputs – The inputs to the model for analysis.
We will trace the execution of model.forward(inputs). This means inputs have to be tensors or tuple of tensors (see https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace). In order to trace other methods or unsupported input types, you may need to implement a wrapper module.
-
ancestor_mode
(mode: str) → T¶ Sets how to determine the ancestor modules of an operator. Must be one of “owner” or “caller”.
“caller”: an operator belongs to all modules that are currently executing forward() at the time the operator is called.
“owner”: an operator belongs to the last module that’s executing forward() at the time the operator is called, plus this module’s recursive parents. If a module has multiple parents (e.g. a shared module), only one will be picked.
For most cases, a module only calls submodules it owns, so both options would work identically. In certain edge cases, this option will affect the hierarchy of results, but won’t affect the total count.
-
by_module
() → Counter[str]¶ Returns the statistics for all submodules, aggregated over all operators.
- Returns
Counter(str) – statistics counter grouped by submodule names
-
by_module_and_operator
() → Dict[str, Counter[str]]¶ Returns the statistics for all submodules, separated out by operator type for each submodule. The operator handle determines the name associated with each operator type.
- Returns
dict(str, Counter(str)) – The statistics for each submodule and each operator. Grouped by submodule names, then by operator name.
-
by_operator
(module_name: str = '') → Counter[str]¶ Returns the statistics for a requested module, grouped by operator type. The operator handle determines the name associated with each operator type.
- Parameters
module_name (str) – The submodule to get data for. Defaults to the entire model.
- Returns
Counter(str) – The statistics for each operator.
-
canonical_module_name
(name: str) → str¶ Returns the canonical module name of the given
name
, which might be different from the given name
if the module is shared. This is the name that will be used as a key when statistics are output using .by_module() and .by_module_and_operator().
- Parameters
name (str) – The name of the module to find the canonical name for.
- Returns
str – The canonical name of the module.
-
clear_op_handles
() → fvcore.nn.jit_analysis.JitModelAnalysis¶ Clears all operator handles currently set.
-
copy
(new_model: Optional[torch.nn.Module] = None, new_inputs: Union[None, torch.Tensor, Tuple[torch.Tensor, …]] = None) → fvcore.nn.jit_analysis.JitModelAnalysis¶ Returns a copy of the
JitModelAnalysis
object, keeping all settings, but on a new model or new inputs.
-
set_op_handle
(*args, **kwargs: Optional[Callable[[List[Any], List[Any]], Union[Counter[str], numbers.Number]]]) → fvcore.nn.jit_analysis.JitModelAnalysis¶ Sets additional operator handles, or replaces existing ones.
- Parameters
args – (str, Handle) pairs of operator names and handles.
kwargs – mapping from operator names to handles.
If a handle is
None
, the op will be explicitly ignored. Otherwise, handle should be a function that calculates the desirable statistic from an operator. The function must take two arguments, which are the inputs and outputs of the operator, in the form of list(torch._C.Value)
. The function should return a counter object with per-operator statistics.
Examples
handlers = {"aten::linear": my_handler}
counter.set_op_handle("aten::matmul", None, "aten::bmm", my_handler2).set_op_handle(**handlers)
-
total
(module_name: str = '') → int¶ Returns the total aggregated statistic across all operators for the requested module.
- Parameters
module_name (str) – The submodule to get data for. Defaults to the entire model.
- Returns
int – The aggregated statistic.
-
tracer_warnings
(mode: str) → T¶ Sets which warnings to print when tracing the graph to calculate statistics. There are three modes. Defaults to ‘no_tracer_warning’. Allowed values are:
‘all’ : keeps all warnings raised while tracing
‘no_tracer_warning’ : suppress torch.jit.TracerWarning only
‘none’ : suppress all warnings raised while tracing
- Parameters
mode (str) – warning mode in one of the above values.
-
uncalled_modules
() → Set[str]¶ Returns a set of submodules that were never called during the trace of the graph. This may be because they were unused, or because they were accessed via direct calls to .forward() or other python methods. In the latter case, statistics will not be attributed to the submodule, though the statistics will be included in the parent module.
- Returns
set(str) –
- The set of submodule names that were never called
during the trace of the model.
-
uncalled_modules_warnings
(enabled: bool) → T¶ Sets if warnings from uncalled submodules are shown. Defaults to true. A submodule is considered “uncalled” if it is never called during tracing. This may be because it is actually unused, or because it is accessed via calls to
.forward()
or other methods of the module. The set of uncalled modules may be obtained from uncalled_modules()
regardless of this setting.
- Parameters
enabled (bool) – Set to ‘True’ to show warnings.
-
unsupported_ops
(module_name: str = '') → Counter[str]¶ Lists the number of operators that were encountered but unsupported because no operator handle is available for them. Does not include operators that are explicitly ignored.
- Parameters
module_name (str) – The submodule to list unsupported ops. Defaults to the entire model.
- Returns
Counter(str) – The number of occurrences of each unsupported operator.
-
unsupported_ops_warnings
(enabled: bool) → T¶ Sets if warnings for unsupported operators are shown. Defaults to True. Counts of unsupported operators may be obtained from
unsupported_ops()
regardless of this setting.
- Parameters
enabled (bool) – Set to ‘True’ to show unsupported operator warnings.
-
fvcore.nn.
flop_count
(model: torch.nn.Module, inputs: Tuple[Any, …], supported_ops: Optional[Dict[str, Callable[[List[Any], List[Any]], Union[Counter[str], numbers.Number]]]] = None) → Tuple[DefaultDict[str, float], Counter[str]][source]¶ Given a model and an input to the model, compute the per-operator Gflops of the given model.
- Parameters
model (nn.Module) – The model to compute flop counts.
inputs (tuple) – Inputs that are passed to model to count flops. Inputs need to be in a tuple.
supported_ops (dict(str,Callable) or None) – provide additional handlers for extra ops, or overwrite the existing handlers for convolution and matmul and einsum. The key is operator name and the value is a function that takes (inputs, outputs) of the op. We count one Multiply-Add as one FLOP.
- Returns
tuple[defaultdict, Counter] –
- A dictionary that records the number of
gflops for each operation and a Counter that records the number of unsupported operations.
-
class
fvcore.nn.
FlopCountAnalysis
(model: torch.nn.Module, inputs: Union[torch.Tensor, Tuple[torch.Tensor, …]])[source]¶ Bases:
fvcore.nn.jit_analysis.JitModelAnalysis
Provides access to per-submodule model flop count obtained by tracing a model with pytorch’s jit tracing functionality. By default, comes with standard flop counters for a few common operators. Note that:
Flop is not a well-defined concept. We just produce our best estimate.
We count one fused multiply-add as one flop.
Handles for additional operators may be added, or the default ones overwritten, using the
.set_op_handle(name, func)
method. See the method documentation for details. Flop counts can be obtained as:
.total(module_name="")
: total flop count for the module
.by_operator(module_name="")
: flop counts for the module, as a Counter over different operator types
.by_module()
: Counter of flop counts for all submodules
.by_module_and_operator()
: dictionary indexed by descendant (submodule name) of Counters over different operator types
An operator is treated as within a module if it is executed inside the module’s
__call__
method. Note that this does not include calls to other methods of the module or explicit calls to module.forward(...)
. Example usage:
>>> import torch.nn as nn >>> import torch >>> class TestModel(nn.Module): ... def __init__(self): ... super().__init__() ... self.fc = nn.Linear(in_features=1000, out_features=10) ... self.conv = nn.Conv2d( ... in_channels=3, out_channels=10, kernel_size=1 ... ) ... self.act = nn.ReLU() ... def forward(self, x): ... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel() >>> inputs = (torch.randn((1,3,10,10)),) >>> flops = FlopCountAnalysis(model, inputs) >>> flops.total() 13000 >>> flops.total("fc") 10000 >>> flops.by_operator() Counter({"addmm" : 10000, "conv" : 3000}) >>> flops.by_module() Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0}) >>> flops.by_module_and_operator() {"" : Counter({"addmm" : 10000, "conv" : 3000}), "fc" : Counter({"addmm" : 10000}), "conv" : Counter({"conv" : 3000}), "act" : Counter() }
-
__init__
(model: torch.nn.Module, inputs: Union[torch.Tensor, Tuple[torch.Tensor, …]]) → None[source]¶ - Parameters
model – The model to analyze
inputs – The inputs to the model for analysis.
We will trace the execution of model.forward(inputs). This means inputs have to be tensors or tuple of tensors (see https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace). In order to trace other methods or unsupported input types, you may need to implement a wrapper module.
-
ancestor_mode
(mode: str) → T¶ Sets how to determine the ancestor modules of an operator. Must be one of “owner” or “caller”.
“caller”: an operator belongs to all modules that are currently executing forward() at the time the operator is called.
“owner”: an operator belongs to the last module that’s executing forward() at the time the operator is called, plus this module’s recursive parents. If a module has multiple parents (e.g. a shared module), only one will be picked.
For most cases, a module only calls submodules it owns, so both options would work identically. In certain edge cases, this option will affect the hierarchy of results, but won’t affect the total count.
-
by_module
() → Counter[str]¶ Returns the statistics for all submodules, aggregated over all operators.
- Returns
Counter(str) – statistics counter grouped by submodule names
-
by_module_and_operator
() → Dict[str, Counter[str]]¶ Returns the statistics for all submodules, separated out by operator type for each submodule. The operator handle determines the name associated with each operator type.
- Returns
dict(str, Counter(str)) – The statistics for each submodule and each operator. Grouped by submodule names, then by operator name.
-
by_operator
(module_name: str = '') → Counter[str]¶ Returns the statistics for a requested module, grouped by operator type. The operator handle determines the name associated with each operator type.
- Parameters
module_name (str) – The submodule to get data for. Defaults to the entire model.
- Returns
Counter(str) – The statistics for each operator.
-
canonical_module_name
(name: str) → str¶ Returns the canonical module name of the given
name
, which might be different from the given name
if the module is shared. This is the name that will be used as a key when statistics are output using .by_module() and .by_module_and_operator().
- Parameters
name (str) – The name of the module to find the canonical name for.
- Returns
str – The canonical name of the module.
-
clear_op_handles
() → fvcore.nn.jit_analysis.JitModelAnalysis¶ Clears all operator handles currently set.
-
copy
(new_model: Optional[torch.nn.Module] = None, new_inputs: Union[None, torch.Tensor, Tuple[torch.Tensor, …]] = None) → fvcore.nn.jit_analysis.JitModelAnalysis¶ Returns a copy of the
JitModelAnalysis
object, keeping all settings, but on a new model or new inputs.
-
set_op_handle
(*args, **kwargs: Optional[Callable[[List[Any], List[Any]], Union[Counter[str], numbers.Number]]]) → fvcore.nn.jit_analysis.JitModelAnalysis¶ Sets additional operator handles, or replaces existing ones.
- Parameters
args – (str, Handle) pairs of operator names and handles.
kwargs – mapping from operator names to handles.
If a handle is
None
, the op will be explicitly ignored. Otherwise, handle should be a function that calculates the desirable statistic from an operator. The function must take two arguments, which are the inputs and outputs of the operator, in the form of list(torch._C.Value)
. The function should return a counter object with per-operator statistics.
Examples
handlers = {"aten::linear": my_handler}
counter.set_op_handle("aten::matmul", None, "aten::bmm", my_handler2).set_op_handle(**handlers)
-
total
(module_name: str = '') → int¶ Returns the total aggregated statistic across all operators for the requested module.
- Parameters
module_name (str) – The submodule to get data for. Defaults to the entire model.
- Returns
int – The aggregated statistic.
-
tracer_warnings
(mode: str) → T¶ Sets which warnings to print when tracing the graph to calculate statistics. There are three modes. Defaults to ‘no_tracer_warning’. Allowed values are:
‘all’ : keeps all warnings raised while tracing
‘no_tracer_warning’ : suppress torch.jit.TracerWarning only
‘none’ : suppress all warnings raised while tracing
- Parameters
mode (str) – warning mode in one of the above values.
-
uncalled_modules
() → Set[str]¶ Returns a set of submodules that were never called during the trace of the graph. This may be because they were unused, or because they were accessed via direct calls to .forward() or other python methods. In the latter case, statistics will not be attributed to the submodule, though the statistics will be included in the parent module.
- Returns
set(str) –
- The set of submodule names that were never called
during the trace of the model.
-
uncalled_modules_warnings
(enabled: bool) → T¶ Sets if warnings from uncalled submodules are shown. Defaults to true. A submodule is considered “uncalled” if it is never called during tracing. This may be because it is actually unused, or because it is accessed via calls to
.forward()
or other methods of the module. The set of uncalled modules may be obtained from uncalled_modules()
regardless of this setting.
- Parameters
enabled (bool) – Set to ‘True’ to show warnings.
-
unsupported_ops
(module_name: str = '') → Counter[str]¶ Lists the number of operators that were encountered but unsupported because no operator handle is available for them. Does not include operators that are explicitly ignored.
- Parameters
module_name (str) – The submodule to list unsupported ops. Defaults to the entire model.
- Returns
Counter(str) – The number of occurrences of each unsupported operator.
-
unsupported_ops_warnings
(enabled: bool) → T¶ Sets if warnings for unsupported operators are shown. Defaults to True. Counts of unsupported operators may be obtained from
unsupported_ops()
regardless of this setting.
- Parameters
enabled (bool) – Set to ‘True’ to show unsupported operator warnings.
-
fvcore.nn.
sigmoid_focal_loss
(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = - 1, gamma: float = 2, reduction: str = 'none') → torch.Tensor[source]¶ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
- Parameters
inputs – A float tensor of arbitrary shape. The predictions for each example.
targets – A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class).
alpha – (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting).
gamma – Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
reduction – ‘none’ | ‘mean’ | ‘sum’ ‘none’: No reduction will be applied to the output. ‘mean’: The output will be averaged. ‘sum’: The output will be summed.
- Returns
Loss tensor with the reduction option applied.
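A brief usage sketch with random logits and binary targets (the shapes are illustrative):
import torch
from fvcore.nn import sigmoid_focal_loss

logits = torch.randn(8, 80)                     # raw predictions, before sigmoid
targets = torch.randint(0, 2, (8, 80)).float()  # binary labels
loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")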
-
fvcore.nn.
sigmoid_focal_loss_star
(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = - 1, gamma: float = 1, reduction: str = 'none') → torch.Tensor[source]¶ FL* described in the RetinaNet paper Appendix: https://arxiv.org/abs/1708.02002.
- Parameters
inputs – A float tensor of arbitrary shape. The predictions for each example.
targets – A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class).
alpha – (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting).
gamma – Gamma parameter described in FL*. Default = 1 (no weighting).
reduction – ‘none’ | ‘mean’ | ‘sum’ ‘none’: No reduction will be applied to the output. ‘mean’: The output will be averaged. ‘sum’: The output will be summed.
- Returns
Loss tensor with the reduction option applied.
-
fvcore.nn.
giou_loss
(boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = 'none', eps: float = 1e-07) → torch.Tensor[source]¶ Generalized Intersection over Union Loss (Hamid Rezatofighi et al.) https://arxiv.org/abs/1902.09630
Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap and scales with the size of their smallest enclosing box. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
- Parameters
boxes1 (Tensor) – box locations in XYXY format, shape (N, 4) or (4,).
boxes2 (Tensor) – box locations in XYXY format, shape (N, 4) or (4,).
reduction – ‘none’ | ‘mean’ | ‘sum’ ‘none’: No reduction will be applied to the output. ‘mean’: The output will be averaged. ‘sum’: The output will be summed.
eps (float) – small number to prevent division by zero
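A brief usage sketch with two pairs of boxes (the values are illustrative):
import torch
from fvcore.nn import giou_loss

boxes1 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
boxes2 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])
loss = giou_loss(boxes1, boxes2, reduction="mean")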
-
fvcore.nn.
parameter_count
(model: torch.nn.Module) → DefaultDict[str, int][source]¶ Count parameters of a model and its submodules.
- Parameters
model – a torch module
- Returns
dict (str-> int) – the key is either a parameter name or a module name. The value is the number of elements in the parameter, or in all parameters of the module. The key “” corresponds to the total number of parameters of the model.
-
fvcore.nn.
parameter_count_table
(model: torch.nn.Module, max_depth: int = 3) → str[source]¶ Format the parameter count of the model (and its submodules or parameters) in a nice table. It looks like this:
| name                          | #elements or shape |
|:------------------------------|:-------------------|
| model                         | 37.9M              |
| backbone                      | 31.5M              |
| backbone.fpn_lateral3         | 0.1M               |
| backbone.fpn_lateral3.weight  | (256, 512, 1, 1)   |
| backbone.fpn_lateral3.bias    | (256,)             |
| backbone.fpn_output3          | 0.6M               |
| backbone.fpn_output3.weight   | (256, 256, 3, 3)   |
| backbone.fpn_output3.bias     | (256,)             |
| backbone.fpn_lateral4         | 0.3M               |
| backbone.fpn_lateral4.weight  | (256, 1024, 1, 1)  |
| backbone.fpn_lateral4.bias    | (256,)             |
| backbone.fpn_output4          | 0.6M               |
| backbone.fpn_output4.weight   | (256, 256, 3, 3)   |
| backbone.fpn_output4.bias     | (256,)             |
| backbone.fpn_lateral5         | 0.5M               |
| backbone.fpn_lateral5.weight  | (256, 2048, 1, 1)  |
| backbone.fpn_lateral5.bias    | (256,)             |
| backbone.fpn_output5          | 0.6M               |
| backbone.fpn_output5.weight   | (256, 256, 3, 3)   |
| backbone.fpn_output5.bias     | (256,)             |
| backbone.top_block            | 5.3M               |
| backbone.top_block.p6         | 4.7M               |
| backbone.top_block.p7         | 0.6M               |
| backbone.bottom_up            | 23.5M              |
| backbone.bottom_up.stem       | 9.4K               |
| backbone.bottom_up.res2       | 0.2M               |
| backbone.bottom_up.res3       | 1.2M               |
| backbone.bottom_up.res4       | 7.1M               |
| backbone.bottom_up.res5       | 14.9M              |
| ......                        | .....              |
- Parameters
model – a torch module
max_depth (int) – maximum depth to recursively print submodules or parameters
- Returns
str – the table to be printed
-
fvcore.nn.
get_bn_modules
(model: torch.nn.Module) → List[torch.nn.Module][source]¶ Find all BatchNorm (BN) modules that are in training mode. See fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are included in this search.
- Parameters
model (nn.Module) – a model possibly containing BN modules.
- Returns
list[nn.Module] – all BN modules in the model.
-
fvcore.nn.
update_bn_stats
(model: torch.nn.Module, data_loader: Iterable[Any], num_iters: int = 200, progress: Optional[str] = None) → None[source]¶ Recompute and update the batch norm stats to make them more precise. During training both BN stats and the weight are changing after every iteration, so the running average can not precisely reflect the actual stats of the current model. In this function, the BN stats are recomputed with fixed weights, to make the running average more precise. Specifically, it computes the true average of per-batch mean/variance instead of the running average. See Sec. 3 of the paper “Rethinking Batch in BatchNorm” for details.
- Parameters
model (nn.Module) –
the model whose bn stats will be recomputed.
Note that:
This function will not alter the training mode of the given model. Users are responsible for setting the layers that need precise-BN to training mode, prior to calling this function.
Be careful if your models contain other stateful layers in addition to BN, i.e. layers whose state can change in forward iterations. This function will alter their state. If you wish to keep them unchanged, you need to either pass in a submodule without those layers, or back up their states.
data_loader (iterator) – an iterator. Produce data as inputs to the model.
num_iters (int) – number of iterations to compute the stats.
progress – None or “tqdm”. If set, use tqdm to report the progress.
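A minimal sketch of the precise-BN procedure described above (the data_loader is assumed to yield inputs in the model's expected format):
from fvcore.nn import get_bn_modules, update_bn_stats

model.train()  # the BN layers to be recomputed must be in training mode
if len(get_bn_modules(model)) > 0:
    update_bn_stats(model, data_loader, num_iters=200)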
-
fvcore.nn.
flop_count_str
(flops: fvcore.nn.flop_count.FlopCountAnalysis, activations: Optional[fvcore.nn.activation_count.ActivationCountAnalysis] = None) → str[source]¶ Calculates the parameters and flops of the model with the given inputs and returns a string representation of the model that includes the parameters and flops of every submodule. The string is structured to be similar to that given by str(model), though it is not guaranteed to be identical in form if the default string representation of a module has been overridden. If a module has zero parameters and flops, statistics will not be reported for succinctness.
The trace can only register the scope of a module if it is called directly, which means flops (and activations) arising from explicit calls to .forward() or to other python functions of the module will not be attributed to that module. Modules that are never called will have ‘N/A’ listed for their flops; this means they are either unused or their statistics are missing for this reason. Any such flops are still counted towards the parent module.
Example:
>>> import torch >>> import torch.nn as nn
>>> class InnerNet(nn.Module): ... def __init__(self): ... super().__init__() ... self.fc1 = nn.Linear(10,10) ... self.fc2 = nn.Linear(10,10) ... def forward(self, x): ... return self.fc1(self.fc2(x))
>>> class TestNet(nn.Module): ... def __init__(self): ... super().__init__() ... self.fc1 = nn.Linear(10,10) ... self.fc2 = nn.Linear(10,10) ... self.inner = InnerNet() ... def forward(self, x): ... return self.fc1(self.fc2(self.inner(x)))
>>> model = TestNet() >>> inputs = torch.randn((1,10)) >>> print(flop_count_str(FlopCountAnalysis(model, inputs))) TestNet( #params: 0.44K, #flops: 0.4K (fc1): Linear( in_features=10, out_features=10, bias=True #params: 0.11K, #flops: 100 ) (fc2): Linear( in_features=10, out_features=10, bias=True #params: 0.11K, #flops: 100 ) (inner): InnerNet( #params: 0.22K, #flops: 0.2K (fc1): Linear( in_features=10, out_features=10, bias=True #params: 0.11K, #flops: 100 ) (fc2): Linear( in_features=10, out_features=10, bias=True #params: 0.11K, #flops: 100 ) ) )
- Parameters
flops (FlopCountAnalysis) – the flop counting object
activations (ActivationCountAnalysis or None) – If given, the activations of each layer will also be calculated and included in the representation.
- Returns
str – a string representation of the model with the number of parameters and flops included.
-
fvcore.nn.
flop_count_table
(flops: fvcore.nn.flop_count.FlopCountAnalysis, max_depth: int = 3, activations: Optional[fvcore.nn.activation_count.ActivationCountAnalysis] = None, show_param_shapes: bool = True) → str[source]¶ Format the per-module parameters and flops of a model in a table. It looks like this:
| model                          | #parameters or shape | #flops |
|:-------------------------------|:---------------------|:-------|
| model                          | 34.6M                | 65.7G  |
| s1                             | 15.4K                | 4.32G  |
| s1.pathway0_stem               | 9.54K                | 1.23G  |
| s1.pathway0_stem.conv          | 9.41K                | 1.23G  |
| s1.pathway0_stem.bn            | 0.128K               |        |
| s1.pathway1_stem               | 5.9K                 | 3.08G  |
| s1.pathway1_stem.conv          | 5.88K                | 3.08G  |
| s1.pathway1_stem.bn            | 16                   |        |
| s1_fuse                        | 0.928K               | 29.4M  |
| s1_fuse.conv_f2s               | 0.896K               | 29.4M  |
| s1_fuse.conv_f2s.weight        | (16, 8, 7, 1, 1)     |        |
| s1_fuse.bn                     | 32                   |        |
| s1_fuse.bn.weight              | (16,)                |        |
| s1_fuse.bn.bias                | (16,)                |        |
| s2                             | 0.226M               | 7.73G  |
| s2.pathway0_res0               | 80.1K                | 2.58G  |
| s2.pathway0_res0.branch1       | 20.5K                | 0.671G |
| s2.pathway0_res0.branch1_bn    | 0.512K               |        |
| s2.pathway0_res0.branch2       | 59.1K                | 1.91G  |
| s2.pathway0_res1.branch2       | 70.4K                | 2.28G  |
| s2.pathway0_res1.branch2.a     | 16.4K                | 0.537G |
| s2.pathway0_res1.branch2.a_bn  | 0.128K               |        |
| s2.pathway0_res1.branch2.b     | 36.9K                | 1.21G  |
| s2.pathway0_res1.branch2.b_bn  | 0.128K               |        |
| s2.pathway0_res1.branch2.c     | 16.4K                | 0.537G |
| s2.pathway0_res1.branch2.c_bn  | 0.512K               |        |
| s2.pathway0_res2.branch2       | 70.4K                | 2.28G  |
| s2.pathway0_res2.branch2.a     | 16.4K                | 0.537G |
| s2.pathway0_res2.branch2.a_bn  | 0.128K               |        |
| s2.pathway0_res2.branch2.b     | 36.9K                | 1.21G  |
| s2.pathway0_res2.branch2.b_bn  | 0.128K               |        |
| s2.pathway0_res2.branch2.c     | 16.4K                | 0.537G |
| s2.pathway0_res2.branch2.c_bn  | 0.512K               |        |
| .............................  | ......               | ...... |
- Parameters
flops (FlopCountAnalysis) – the flop counting object
max_depth (int) – The max depth of submodules to include in the table. Defaults to 3.
activations (ActivationCountAnalysis or None) – If given, include activation counts as an additional column in the table.
show_param_shapes (bool) – If true, shapes for parameters will be included in the table. Defaults to True.
- Returns
str – The formatted table.
Examples:
print(flop_count_table(FlopCountAnalysis(model, inputs)))
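As a rough sketch of the optional arguments (assuming model and inputs are already defined), activation counts can be shown as an extra column by passing an ActivationCountAnalysis object:
from fvcore.nn import ActivationCountAnalysis, FlopCountAnalysis, flop_count_table

flops = FlopCountAnalysis(model, inputs)
acts = ActivationCountAnalysis(model, inputs)
# limit the table to two levels of submodules and omit parameter shapes
print(flop_count_table(flops, max_depth=2, activations=acts, show_param_shapes=False))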
-
fvcore.nn.
smooth_l1_loss
(input: torch.Tensor, target: torch.Tensor, beta: float, reduction: str = 'none') → torch.Tensor[source]¶ Smooth L1 loss defined in the Fast R-CNN paper as:
                | 0.5 * x ** 2 / beta   if abs(x) < beta
 smoothl1(x) =  |
                | abs(x) - 0.5 * beta   otherwise,
where x = input - target.
Smooth L1 loss is related to Huber loss, which is defined as:
            | 0.5 * x ** 2                   if abs(x) < beta
 huber(x) = |
            | beta * (abs(x) - 0.5 * beta)   otherwise
Smooth L1 loss is equal to huber(x) / beta. This leads to the following differences:
As beta -> 0, Smooth L1 loss converges to L1 loss, while Huber loss converges to a constant 0 loss.
As beta -> +inf, Smooth L1 converges to a constant 0 loss, while Huber loss converges to L2 loss.
For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. For Huber loss, the slope of the L1 segment is beta.
Smooth L1 loss can be seen as exactly L1 loss, but with the abs(x) < beta portion replaced with a quadratic function such that at abs(x) = beta, its slope is 1. The quadratic segment smooths the L1 loss near x = 0.
- Parameters
input (Tensor) – input tensor of any shape
target (Tensor) – target value tensor with the same shape as input
beta (float) – L1 to L2 change point. For beta values < 1e-5, L1 loss is computed.
reduction – ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied to the output; ‘mean’: the output will be averaged; ‘sum’: the output will be summed.
- Returns
The loss with the reduction option applied.
Note
PyTorch’s builtin “Smooth L1 loss” implementation does not actually implement Smooth L1 loss, nor does it implement Huber loss. It implements the special case of both in which they are equal (beta=1). See: https://pytorch.org/docs/stable/nn.html#torch.nn.SmoothL1Loss.
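A minimal usage sketch; the tensors and beta are chosen only for illustration, and the comment restates the relationship described above:
import torch
from fvcore.nn import smooth_l1_loss

x = torch.tensor([0.05, 0.5, 2.0])
y = torch.zeros(3)
loss = smooth_l1_loss(x, y, beta=0.5, reduction="none")
# per the definitions above, beta * smoothl1(d) equals huber(d) for the same beta
mean_loss = smooth_l1_loss(x, y, beta=0.5, reduction="mean")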
-
fvcore.nn.
c2_msra_fill
(module: torch.nn.Module) → None[source]¶ Initialize module.weight using the “MSRAFill” implemented in Caffe2. Also initializes module.bias to 0.
- Parameters
module (torch.nn.Module) – module to initialize.
-
fvcore.nn.
c2_xavier_fill
(module: torch.nn.Module) → None[source]¶ Initialize module.weight using the “XavierFill” implemented in Caffe2. Also initializes module.bias to 0.
- Parameters
module (torch.nn.Module) – module to initialize.
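A short usage sketch; the layer shapes are arbitrary and only illustrate that both helpers initialize a module in place:
import torch.nn as nn
from fvcore.nn import c2_msra_fill, c2_xavier_fill

conv = nn.Conv2d(3, 64, kernel_size=3)
c2_msra_fill(conv)    # MSRA/He-style init of conv.weight; conv.bias set to 0

fc = nn.Linear(64, 10)
c2_xavier_fill(fc)    # Xavier-style init of fc.weight; fc.bias set to 0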
fvcore.common¶
-
class
fvcore.common.checkpoint.
Checkpointer
(model: torch.nn.Module, save_dir: str = '', *, save_to_disk: bool = True, **checkpointables: Any)[source]¶ Bases:
object
A checkpointer that can save/load model as well as extra checkpointable objects.
-
__init__
(model: torch.nn.Module, save_dir: str = '', *, save_to_disk: bool = True, **checkpointables: Any) → None[source]¶ - Parameters
model (nn.Module) – model.
save_dir (str) – a directory to save and find checkpoints.
save_to_disk (bool) – if True, save checkpoint to disk, otherwise disable saving for this checkpointer.
checkpointables (object) – any checkpointable objects, i.e., objects that have the state_dict() and load_state_dict() methods. For example, it can be used like Checkpointer(model, “dir”, optimizer=optimizer).
-
add_checkpointable
(key: str, checkpointable: Any) → None[source]¶ Add a checkpointable object for this checkpointer to track.
- Parameters
key (str) – the key used to save the object
checkpointable – any object with the state_dict() and load_state_dict() methods
-
load
(path: str, checkpointables: Optional[List[str]] = None) → Dict[str, Any][source]¶ Load from the given checkpoint.
- Parameters
path (str) – path to the checkpoint to load. If empty, nothing will be loaded.
checkpointables (list) – names of the checkpointables to load. If None, all tracked checkpointables will be loaded.
- Returns
dict – extra data loaded from the checkpoint that has not been processed. For example, data saved with save(**extra_data).
-
has_checkpoint
() → bool[source]¶ - Returns
bool – whether a checkpoint exists in the target directory.
-
get_all_checkpoint_files
() → List[str][source]¶ - Returns
list – all available checkpoint files (.pth files) in the target directory.
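A minimal sketch of the typical save/load cycle, assuming a model and optimizer already exist; the directory and file names are placeholders:
from fvcore.common.checkpoint import Checkpointer

checkpointer = Checkpointer(model, "output_dir", optimizer=optimizer)
# extra keyword arguments are stored inside the checkpoint file
checkpointer.save("model_0001000", iteration=1000)

if checkpointer.has_checkpoint():
    # returns extra data (e.g. "iteration") that was saved but not consumed
    extra = checkpointer.load("output_dir/model_0001000.pth")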
-
-
class
fvcore.common.checkpoint.
PeriodicCheckpointer
(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]¶ Bases:
object
Save checkpoints periodically. When .step(iteration) is called, it will execute checkpointer.save on the given checkpointer, if iteration is a multiple of period or if max_iter is reached.
-
checkpointer
¶ the underlying checkpointer object
- Type
Checkpointer
-
__init__
(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model') → None[source]¶ - Parameters
checkpointer – the checkpointer object used to save checkpoints.
period (int) – the period to save checkpoint.
max_iter (int) – maximum number of iterations. When it is reached, a checkpoint named “{file_prefix}_final” will be saved.
max_to_keep (int) – maximum number of the most recent checkpoints to keep; older checkpoints will be deleted
file_prefix (str) – the prefix of checkpoint’s filename
-
step
(iteration: int, **kwargs: Any) → None[source]¶ Perform the appropriate action at the given iteration.
- Parameters
iteration (int) – the current iteration, in the range [0, max_iter-1].
kwargs (Any) – extra data to save, same as in
Checkpointer.save()
.
-
save
(name: str, **kwargs: Any) → None[source]¶ Same arguments as
Checkpointer.save()
. Use this method to manually save checkpoints outside the schedule.
- Parameters
name (str) – file name.
kwargs (Any) – extra data to save, same as in
Checkpointer.save()
.
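A sketch of how this is typically wired into a training loop, assuming model and optimizer already exist; the period and iteration counts are placeholders:
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

checkpointer = Checkpointer(model, "output_dir", optimizer=optimizer)
periodic = PeriodicCheckpointer(
    checkpointer, period=1000, max_iter=90000, max_to_keep=5)

for iteration in range(90000):
    # ... one training step ...
    periodic.step(iteration)  # saves every 1000 iterations and "model_final" at the end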
-
-
class
fvcore.common.config.
CfgNode
(init_dict=None, key_list=None, new_allowed=False)[source]¶ Bases:
yacs.config.CfgNode
Our own extended version of
yacs.config.CfgNode
. It contains the following extra features:The
merge_from_file()
method supports the “_BASE_” key, which allows the new CfgNode to inherit all the attributes from the base configuration file.Keys that start with “COMPUTED_” are treated as insertion-only “computed” attributes. They can be inserted regardless of whether the CfgNode is frozen or not.
With “allow_unsafe=True”, it supports pyyaml tags that evaluate expressions in the config. See examples in https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types. Note that this may lead to arbitrary code execution: do not load config files from untrusted sources without manually inspecting their content first.
-
classmethod
load_yaml_with_base
(filename: str, allow_unsafe: bool = False) → Dict[str, Any][source]¶ Just like yaml.load(open(filename)), but inherit attributes from its _BASE_.
-
merge_from_file
(cfg_filename: str, allow_unsafe: bool = False) → None[source]¶ Merge configs from a given yaml file.
- Parameters
cfg_filename – the file name of the yaml config.
allow_unsafe – whether to allow loading the config file with yaml.unsafe_load.
-
merge_from_other_cfg
(cfg_other: fvcore.common.config.CfgNode) → Callable[[], None][source]¶ - Parameters
cfg_other (CfgNode) – configs to merge from.
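A sketch of how the “_BASE_” inheritance is typically used; the file names and keys here are hypothetical:
from fvcore.common.config import CfgNode

# exp.yaml may start with "_BASE_: base.yaml" and only override a few keys;
# load_yaml_with_base resolves the base file first, then applies the overrides
cfg = CfgNode(CfgNode.load_yaml_with_base("exp.yaml", allow_unsafe=False))

# merge_from_file is usually called on a config that already defines the keys
defaults = CfgNode({"MODEL": CfgNode({"DEPTH": 50})})
defaults.merge_from_file("exp.yaml", allow_unsafe=False)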
-
class
fvcore.common.history_buffer.
HistoryBuffer
(max_length: int = 1000000)[source]¶ Bases:
object
Track a series of scalar values and provide access to smoothed values over a window or the global average of the series.
-
__init__
(max_length: int = 1000000) → None[source]¶ - Parameters
max_length – maximum number of values that can be stored in the buffer. When the capacity of the buffer is exhausted, the oldest values will be removed.
-
update
(value: float, iteration: Optional[float] = None) → None[source]¶ Add a new scalar value produced at a certain iteration. If the length of the buffer exceeds self._max_length, the oldest element will be removed from the buffer.
-
median
(window_size: int) → float[source]¶ Return the median of the latest window_size values in the buffer.
-
avg
(window_size: int) → float[source]¶ Return the mean of the latest window_size values in the buffer.
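A small usage sketch with made-up loss values:
from fvcore.common.history_buffer import HistoryBuffer

losses = HistoryBuffer(max_length=1000)
for it, value in enumerate([0.9, 0.8, 0.7, 0.65]):
    losses.update(value, iteration=it)

print(losses.avg(window_size=2))     # mean of the latest 2 values -> 0.675
print(losses.median(window_size=2))  # median of the latest 2 values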
-
-
class
fvcore.common.param_scheduler.
ParamScheduler
[source]¶ Bases:
object
Base class for parameter schedulers. A parameter scheduler defines a mapping from a progress value in [0, 1) to a number (e.g. learning rate).
-
WHERE_EPSILON
= 1e-06¶
-
__call__
(where: float) → float[source]¶ Get the value of the param at a given point in training.
We update params (such as learning rate) based on the percent progress of training completed. This allows a scheduler to be agnostic to the exact length of a particular run (e.g. 120 epochs vs 90 epochs), as long as the relative progress where params should be updated is the same. However, it assumes that the total length of training is known.
- Parameters
where – A float in [0,1) that represents how far training has progressed
-
-
class
fvcore.common.param_scheduler.
ConstantParamScheduler
(value: float)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Returns a constant value for a param.
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
CosineParamScheduler
(start_value: float, end_value: float)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Cosine decay or cosine warmup schedules based on start and end values. The schedule is updated based on the fraction of training progress. The schedule was proposed in ‘SGDR: Stochastic Gradient Descent with Warm Restarts’ (https://arxiv.org/abs/1608.03983). Note that this class only implements the cosine annealing part of SGDR, and not the restarts.
Example
CosineParamScheduler(start_value=0.1, end_value=0.0001)
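Schedulers are called with the training progress in [0, 1); evaluating the example above at a few points (a sketch, values rounded):
from fvcore.common.param_scheduler import CosineParamScheduler

sched = CosineParamScheduler(start_value=0.1, end_value=0.0001)
sched(0.0)   # 0.1 at the start of training
sched(0.5)   # roughly midway between the two values (~0.05)
sched(0.99)  # approaches 0.0001 near the end of training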
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
ExponentialParamScheduler
(start_value: float, decay: float)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Exponential schedule parameterized by a start value and a decay factor. The schedule is updated based on the fraction of training progress, where, according to the formula param_t = start_value * (decay ** where).
Example
ExponentialParamScheduler(start_value=2.0, decay=0.02)
Corresponds to a decreasing schedule with values in [2.0, 0.04).
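For instance, evaluating that schedule at a few progress points (a sketch, values rounded):
from fvcore.common.param_scheduler import ExponentialParamScheduler

sched = ExponentialParamScheduler(start_value=2.0, decay=0.02)
sched(0.0)  # 2.0 * 0.02 ** 0.0 = 2.0
sched(0.5)  # 2.0 * 0.02 ** 0.5 ≈ 0.28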
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
LinearParamScheduler
(start_value: float, end_value: float)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Linearly interpolates parameter between
start_value
andend_value
. Can be used for either warmup or decay based on start and end values. The schedule is updated after every train step by default.
Example
LinearParamScheduler(start_value=0.0001, end_value=0.01)
Corresponds to a linearly increasing schedule with values in [0.0001, 0.01).
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
CompositeParamScheduler
(schedulers: Sequence[fvcore.common.param_scheduler.ParamScheduler], lengths: List[float], interval_scaling: Sequence[str])[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Composite parameter scheduler composed of intermediate schedulers. Takes a list of schedulers and a list of lengths corresponding to the percentage of training each scheduler should run for. Schedulers are run in order. All values in lengths should sum to 1.0.
Each scheduler also has a corresponding interval scale. If the interval scale is ‘fixed’, the intermediate scheduler is run without any rescaling of time. If the interval scale is ‘rescaled’, the intermediate scheduler is run such that it starts and ends at the same values as it would if it were the only scheduler. The default is ‘rescaled’ for all schedulers.
Example
schedulers = [
    ConstantParamScheduler(value=0.42),
    CosineParamScheduler(start_value=0.42, end_value=1e-4),
]
CompositeParamScheduler(
    schedulers=schedulers,
    interval_scaling=['rescaled', 'rescaled'],
    lengths=[0.3, 0.7],
)
The parameter value will be 0.42 for the first [0%, 30%) of steps, and then will cosine decay from 0.42 to 0.0001 for [30%, 100%) of training.
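Evaluating such a composite schedule at a few progress points (a sketch, values rounded):
from fvcore.common.param_scheduler import (
    CompositeParamScheduler, ConstantParamScheduler, CosineParamScheduler)

sched = CompositeParamScheduler(
    schedulers=[
        ConstantParamScheduler(value=0.42),
        CosineParamScheduler(start_value=0.42, end_value=1e-4),
    ],
    interval_scaling=['rescaled', 'rescaled'],
    lengths=[0.3, 0.7],
)
sched(0.15)  # 0.42, inside the constant phase ([0%, 30%))
sched(0.99)  # decaying towards 1e-4, near the end of the cosine phase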
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
MultiStepParamScheduler
(values: List[float], num_updates: Optional[int] = None, milestones: Optional[List[int]] = None)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Takes a predefined schedule for a param value, and a list of epochs or steps which stand for the upper boundary (excluded) of each range.
Example
MultiStepParamScheduler(
    values=[0.1, 0.01, 0.001, 0.0001],
    milestones=[30, 60, 80, 120]
)
Then the param value will be 0.1 for epochs 0-29, 0.01 for epochs 30-59, 0.001 for epochs 60-79, and 0.0001 for epochs 80-119. Note that the length of values must equal the length of milestones plus one when num_updates is given; when num_updates is omitted, the last milestone is treated as the end of training and the two lists have equal length (as in this example).
-
__init__
(values: List[float], num_updates: Optional[int] = None, milestones: Optional[List[int]] = None) → None[source]¶ - Parameters
values – param value in each range
num_updates – the end of the last range. If None, will use
milestones[-1]
milestones – the boundary of each range. If None, will evenly split
num_updates
For example, all the following combinations define the same scheduler:
num_updates=90, milestones=[30, 60], values=[1, 0.1, 0.01]
num_updates=90, values=[1, 0.1, 0.01]
milestones=[30, 60, 90], values=[1, 0.1, 0.01]
milestones=[3, 6, 9], values=[1, 0.1, 0.01] (ParamScheduler is scale-invariant)
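A sketch that evaluates one of the equivalent configurations above at a given training progress:
from fvcore.common.param_scheduler import MultiStepParamScheduler

sched = MultiStepParamScheduler(values=[1, 0.1, 0.01], milestones=[30, 60, 90])
sched(15 / 90)  # 1, progress falls in the first range [0, 30)
sched(45 / 90)  # 0.1, progress falls in [30, 60)
sched(75 / 90)  # 0.01, progress falls in [60, 90)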
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
StepParamScheduler
(num_updates: Union[int, float], values: List[float])[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Takes a fixed schedule for a param value. If the length of the fixed schedule is less than the number of epochs, then the epochs are divided evenly among the param schedule. The schedule is updated after every train epoch by default.
Example
StepParamScheduler(values=[0.1, 0.01, 0.001, 0.0001], num_updates=120)
Then the param value will be 0.1 for epochs 0-29, 0.01 for epochs 30-59, 0.001 for epochs 60-89, and 0.0001 for epochs 90-119.
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
StepWithFixedGammaParamScheduler
(base_value: float, num_decays: int, gamma: float, num_updates: int)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Decays the param value by gamma at equally spaced steps so as to produce the specified total number of decays.
Example
StepWithFixedGammaParamScheduler(base_value=0.1, gamma=0.1, num_decays=3, num_updates=120)
Then the param value will be 0.1 for epochs 0-29, 0.01 for epochs 30-59, 0.001 for epochs 60-89, and 0.0001 for epochs 90-119.
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.param_scheduler.
PolynomialDecayParamScheduler
(base_value: float, power: float)[source]¶ Bases:
fvcore.common.param_scheduler.ParamScheduler
Decays the param value after every epoch according to a polynomial function with a fixed power. The schedule is updated after every train step by default.
Example
PolynomialDecayParamScheduler(base_value=0.1, power=0.9)
Then the param value will be 0.1 for epoch 0, 0.099 for epoch 1, and so on.
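A sketch, assuming the usual polynomial decay form base_value * (1 - where) ** power, which is consistent with the values quoted above:
from fvcore.common.param_scheduler import PolynomialDecayParamScheduler

sched = PolynomialDecayParamScheduler(base_value=0.1, power=0.9)
sched(0.0)   # 0.1
sched(0.5)   # 0.1 * (1 - 0.5) ** 0.9 ≈ 0.054
sched(0.99)  # ≈ 0.1 * 0.01 ** 0.9 ≈ 0.0016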
-
WHERE_EPSILON
= 1e-06¶
-
-
class
fvcore.common.registry.
Registry
(*args, **kwds)[source]¶ Bases:
collections.abc.Iterable
,typing.Generic
The registry that provides name -> object mapping, to support third-party users’ custom modules.
To create a registry (e.g. a backbone registry):
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
@BACKBONE_REGISTRY.register()
class MyBackbone():
    ...
Or:
BACKBONE_REGISTRY.register(MyBackbone)
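Registered objects can later be retrieved by name, e.g. when the name comes from a config (a sketch, assuming MyBackbone was registered as above):
backbone_class = BACKBONE_REGISTRY.get("MyBackbone")
backbone = backbone_class()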