detectron2.data package
-
detectron2.data.build_detection_test_loader(cfg, dataset_name, mapper=None)
Similar to build_detection_train_loader, but this function uses the given dataset_name argument (instead of the names in cfg), and uses batch size 1.
Parameters:
- cfg – a detectron2 CfgNode
- dataset_name (str) – the name of a dataset that’s available in the DatasetCatalog
- mapper (callable) – a callable which takes a sample (dict) from the dataset and returns the format to be consumed by the model. By default it will be DatasetMapper(cfg, False).
Returns: DataLoader – a torch DataLoader that loads the given detection dataset, with test-time transformation and batching.
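For illustration, a minimal usage sketch (assuming the builtin “coco_2017_val” dataset files are available locally; any name registered in the DatasetCatalog works here):

    from detectron2.config import get_cfg
    from detectron2.data import build_detection_test_loader

    cfg = get_cfg()
    # Builtin names are registered when detectron2.data.datasets is imported.
    test_loader = build_detection_test_loader(cfg, "coco_2017_val")
    for batch in test_loader:   # each batch is a list with one mapped dict
        print(batch[0]["file_name"])
        break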
-
detectron2.data.build_detection_train_loader(cfg, mapper=None)
A data loader is created by the following steps:
1. Use the dataset names in the config to query DatasetCatalog, and obtain a list of dicts.
2. Start workers to work on the dicts. Each worker will:
   - Map each metadata dict into another format to be consumed by the model.
   - Batch them by simply putting dicts into a list.
The batched list[mapped_dict] is what this dataloader will return.
Parameters:
- cfg (CfgNode) – the config
- mapper (callable) – a callable which takes a sample (dict) from the dataset and returns the format to be consumed by the model. By default it will be DatasetMapper(cfg, True).
Returns: a torch DataLoader object
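A minimal sketch of typical usage (the dataset name is a placeholder; passing mapper=None is equivalent to the explicit DatasetMapper shown):

    from detectron2.config import get_cfg
    from detectron2.data import DatasetMapper, build_detection_train_loader

    cfg = get_cfg()
    cfg.DATASETS.TRAIN = ("coco_2017_train",)   # names looked up in DatasetCatalog
    train_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
    batch = next(iter(train_loader))            # list[mapped_dict], length = batch size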
-
detectron2.data.get_detection_dataset_dicts(dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None)
Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
Parameters:
- dataset_names (list[str]) – a list of dataset names
- filter_empty (bool) – whether to filter out images without instance annotations
- min_keypoints (int) – filter out images with fewer keypoints than min_keypoints. Set to 0 to do nothing.
- proposal_files (list[str]) – if given, a list of object proposal files that match each dataset in dataset_names.
-
detectron2.data.load_proposals_into_dataset(dataset_dicts, proposal_file)
Load precomputed object proposals into the dataset.
The proposal file should be a pickled dict with the following keys:
- “ids”: list[int] or list[str], the image ids
- “boxes”: list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
- “objectness_logits”: list[np.ndarray], each is an N-sized array of objectness scores corresponding to the boxes
- “bbox_mode”: the BoxMode of the boxes array. Defaults to BoxMode.XYXY_ABS.
Parameters:
Returns: list[dict] – the same format as dataset_dicts, but with an added “proposals” field.
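A sketch of a proposal file matching this schema (ids, box counts and values below are placeholders):

    import pickle
    import numpy as np
    from detectron2.structures import BoxMode

    proposals = {
        "ids": [139, 285],                              # one entry per image
        "boxes": [np.zeros((100, 4), dtype=np.float32), # Nx4 boxes per image
                  np.zeros((80, 4), dtype=np.float32)],
        "objectness_logits": [np.zeros(100, dtype=np.float32),
                              np.zeros(80, dtype=np.float32)],
        "bbox_mode": BoxMode.XYXY_ABS,                  # optional; this is the default
    }
    with open("proposals.pkl", "wb") as f:
        pickle.dump(proposals, f)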
-
class detectron2.data.DatasetCatalog
Bases: object
A catalog that stores information about the datasets and how to obtain them.
It contains a mapping from strings (which are names that identify a dataset, e.g. “coco_2014_train”) to a function which parses the dataset and returns the samples in the format of list[dict].
The returned dicts should be in Detectron2 Dataset format (see DATASETS.md for details) if used with the data loader functionalities in data/build.py, data/detection_transform.py.
The purpose of having this catalog is to make it easy to choose different datasets, by just using the strings in the config.
-
static register(name, func)
Parameters:
- name (str) – the name that identifies a dataset, e.g. “coco_2014_train”.
- func (callable) – a callable which takes no arguments and returns a list of dicts.
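For illustration, a minimal registration sketch (the dataset name and loader function are placeholders):

    from detectron2.data import DatasetCatalog

    def get_my_dicts():
        # Must return list[dict] in Detectron2 Dataset format (see DATASETS.md).
        return [{"file_name": "images/0001.jpg", "image_id": 0,
                 "height": 480, "width": 640, "annotations": []}]

    DatasetCatalog.register("my_dataset_train", get_my_dicts)
    dicts = DatasetCatalog.get("my_dataset_train")   # invokes get_my_dicts()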
-
class
detectron2.data.
MetadataCatalog
[source]¶ Bases:
object
MetadataCatalog provides access to “Metadata” of a given dataset.
The metadata associated with a certain name is a singleton: once created, the metadata will stay alive and will be returned by future calls to get(name).
It’s like global variables, so don’t abuse it. It’s meant for storing knowledge that’s constant and shared across the execution of the program, e.g.: the class names in COCO.
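A short sketch of the singleton behavior (the dataset name and class list are placeholders):

    from detectron2.data import MetadataCatalog

    meta = MetadataCatalog.get("my_dataset_train")   # created on first access
    meta.thing_classes = ["cat", "dog"]              # constant, program-wide knowledge
    # Every later call with the same name returns the same singleton:
    assert MetadataCatalog.get("my_dataset_train").thing_classes == ["cat", "dog"]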
-
class detectron2.data.DatasetFromList(lst: list, copy: bool = True)
Bases: torch.utils.data.dataset.Dataset
Wrap a list to a torch Dataset. It produces elements of the list as data.
-
class detectron2.data.MapDataset(dataset, map_func)
Bases: torch.utils.data.dataset.Dataset
Map a function over the elements in a dataset.
Parameters:
- dataset – a dataset where the map function is applied.
- map_func – a callable which maps the elements in the dataset. map_func is responsible for error handling: when an error happens, it should return None, and MapDataset will then randomly use other elements from the dataset.
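A minimal sketch of the None-on-error contract (the map function and data are placeholders):

    from detectron2.data import DatasetFromList, MapDataset

    def add_area(d):
        try:
            d = dict(d)                       # copy; avoid mutating the source list
            d["area"] = d["width"] * d["height"]
            return d
        except KeyError:
            return None   # signals failure; MapDataset retries a random element

    ds = MapDataset(DatasetFromList([{"width": 640, "height": 480},
                                     {"width": 800}]), add_area)
    print(ds[0]["area"])   # 307200
    print(ds[1]["area"])   # element 1 fails, so a random valid element is used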
-
class detectron2.data.DatasetMapper(cfg, is_train=True)
Bases: object
A callable which takes a dataset dict in Detectron2 Dataset format, and maps it into a format used by the model.
This is the default callable used to map your dataset dict into training data. You may need to follow it to implement your own one for customized logic.
The callable currently does the following:
1. Read the image from “file_name”
2. Apply cropping/geometric transforms to the image and annotations
3. Prepare data and annotations as Tensor and Instances
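For illustration, a hypothetical minimal custom mapper following the same steps (transforms and annotation handling omitted):

    import copy
    import torch
    from detectron2.data import detection_utils as utils

    def my_mapper(dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)   # the dict is reused; don't mutate it
        image = utils.read_image(dataset_dict["file_name"], format="BGR")
        utils.check_image_size(dataset_dict, image)
        dataset_dict["image"] = torch.as_tensor(
            image.transpose(2, 0, 1).astype("float32"))   # HWC -> CHW float tensor
        return dataset_dict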
detectron2.data.detection_utils module
Common data processing utilities that are used in a typical object detection data pipeline.
-
exception detectron2.data.detection_utils.SizeMismatchError
Bases: ValueError
Raised when a loaded image has a different width/height than the annotation.
-
detectron2.data.detection_utils.read_image(file_name, format=None)
Read an image into the given format. Rotation and flipping will be applied if the image has such EXIF information.
Parameters:
Returns: image (np.ndarray) – an HWC image
-
detectron2.data.detection_utils.check_image_size(dataset_dict, image)
Raise an error if the image does not match the size specified in the dict.
-
detectron2.data.detection_utils.transform_proposals(dataset_dict, image_shape, transforms, min_box_side_len, proposal_topk)
Apply transformations to the proposals in dataset_dict, if any.
Parameters:
- dataset_dict (dict) – a dict read from the dataset; possibly contains fields “proposal_boxes”, “proposal_objectness_logits”, “proposal_bbox_mode”
- image_shape (tuple) – height, width
- transforms (TransformList) –
- min_box_side_len (int) – keep proposals with at least this size
- proposal_topk (int) – only keep top-K scoring proposals
The input dict is modified in-place, with the above-mentioned keys removed. A new key “proposals” will be added. Its value is an Instances object which contains the transformed proposals in its fields “proposal_boxes” and “objectness_logits”.
-
detectron2.data.detection_utils.transform_instance_annotations(annotation, transforms, image_size, *, keypoint_hflip_indices=None)
Apply transforms to box, segmentation and keypoints annotations of a single instance.
It will use transforms.apply_box for the box, and transforms.apply_coords for segmentation polygons & keypoints. If you need anything more specially designed for each data structure, you’ll need to implement your own version of this function or the transforms.
Parameters:
- annotation (dict) – dict of instance annotations for a single instance. It will be modified in-place.
- transforms (TransformList) –
- image_size (tuple) – the height, width of the transformed image
- keypoint_hflip_indices (ndarray[int]) – see create_keypoint_hflip_indices.
Returns: dict – the same input dict with fields “bbox”, “segmentation”, “keypoints” transformed according to transforms. The “bbox_mode” field will be set to XYXY_ABS.
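A minimal sketch of this function on a box-only annotation (values are placeholders):

    from detectron2.data import transforms as T
    from detectron2.data.detection_utils import transform_instance_annotations
    from detectron2.structures import BoxMode

    anno = {"bbox": [10.0, 20.0, 100.0, 200.0], "bbox_mode": BoxMode.XYXY_ABS}
    tfms = T.TransformList([T.HFlipTransform(width=640)])
    out = transform_instance_annotations(anno, tfms, image_size=(480, 640))
    print(out["bbox"])        # flipped horizontally: [540., 20., 630., 200.]
    print(out["bbox_mode"])   # always BoxMode.XYXY_ABS after transformation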
-
detectron2.data.detection_utils.transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None)
Transform keypoint annotations of an image.
Parameters:
- keypoints (list[float]) – Nx3 float in Detectron2 Dataset format.
- transforms (TransformList) –
- image_size (tuple) – the height, width of the transformed image
- keypoint_hflip_indices (ndarray[int]) – see create_keypoint_hflip_indices.
-
detectron2.data.detection_utils.annotations_to_instances(annos, image_size, mask_format='polygon')
Create an Instances object used by the models, from instance annotations in the dataset dict.
Parameters:
Returns: Instances – It will contain fields “gt_boxes”, “gt_classes”, “gt_masks”, “gt_keypoints”, if they can be obtained from annos. This is the format that builtin models expect.
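A minimal sketch with a single box annotation (values are placeholders; image_size is (height, width)):

    from detectron2.data.detection_utils import annotations_to_instances
    from detectron2.structures import BoxMode

    annos = [{"bbox": [10.0, 20.0, 110.0, 220.0],
              "bbox_mode": BoxMode.XYXY_ABS,
              "category_id": 3}]
    instances = annotations_to_instances(annos, image_size=(480, 640))
    print(instances.gt_boxes)     # Boxes(tensor([[ 10.,  20., 110., 220.]]))
    print(instances.gt_classes)   # tensor([3])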
-
detectron2.data.detection_utils.annotations_to_instances_rotated(annos, image_size)
Create an Instances object used by the models, from instance annotations in the dataset dict. Compared to annotations_to_instances, this function is for rotated boxes only.
Parameters:
Returns: Instances – Containing fields “gt_boxes”, “gt_classes”, if they can be obtained from annos. This is the format that builtin models expect.
-
detectron2.data.detection_utils.filter_empty_instances(instances, by_box=True, by_mask=True)
Filter out empty instances in an Instances object.
Parameters:
Returns: Instances – the filtered instances.
-
detectron2.data.detection_utils.create_keypoint_hflip_indices(dataset_names)
Parameters: dataset_names (list[str]) – list of dataset names
Returns: ndarray[int] – a vector of size #keypoints, storing the horizontally-flipped keypoint indices.
-
detectron2.data.detection_utils.gen_crop_transform_with_instance(crop_size, image_size, instance)
Generate a CropTransform so that the cropping region contains the center of the given instance.
Parameters:
-
detectron2.data.detection_utils.check_metadata_consistency(key, dataset_names)
Check that the datasets have consistent metadata.
Parameters:
Raises:
- AttributeError – if the key does not exist in the metadata
- ValueError – if the given datasets do not have the same metadata values defined by key
detectron2.data.datasets module
-
detectron2.data.datasets.load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True)
Parameters:
- image_dir (str) – path to the raw dataset. e.g., “~/cityscapes/leftImg8bit/train”.
- gt_dir (str) – path to the raw annotations. e.g., “~/cityscapes/gtFine/train”.
- from_json (bool) – whether to read annotations from the raw json file or the png files.
- to_polygons (bool) – whether to represent the segmentation as polygons (COCO’s format) instead of masks (cityscapes’s format).
Returns: list[dict] – a list of dicts in Detectron2 standard format. (See Using Custom Datasets )
-
detectron2.data.datasets.load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None)
Load a json file with COCO’s instances annotation format. Currently supports instance detection, instance segmentation, and person keypoints annotations.
Parameters:
- json_file (str) – full path to the json file in COCO instances annotation format.
- image_root (str) – the directory where the images in this json file exist.
- dataset_name (str) – the name of the dataset (e.g., coco_2017_train). If provided, this function will also put “thing_classes” into the metadata associated with this dataset.
- extra_annotation_keys (list[str]) – list of per-annotation keys that should also be loaded into the dataset dict (besides “iscrowd”, “bbox”, “keypoints”, “category_id”, “segmentation”). The values for these keys will be returned as-is. For example, the densepose annotations are loaded in this way.
Returns: list[dict] – a list of dicts in Detectron2 standard format. (See Using Custom Datasets )
Notes
- This function does not read the image files. The results do not have the “image” field.
-
detectron2.data.datasets.load_sem_seg(gt_root, image_root, gt_ext='png', image_ext='jpg')
Load semantic segmentation datasets. All files under “gt_root” with the “gt_ext” extension are treated as ground truth annotations, and all files under “image_root” with the “image_ext” extension as input images. Ground truth and input images are matched using file paths relative to “gt_root” and “image_root” respectively, without taking into account file extensions. This works for COCO as well as some other datasets.
Parameters:
- gt_root (str) – full path to ground truth semantic segmentation files. Semantic segmentation annotations are stored as images with integer values in pixels that represent corresponding semantic labels.
- image_root (str) – the directory where the input images are.
- gt_ext (str) – file extension for ground truth annotations.
- image_ext (str) – file extension for input images.
Returns: list[dict] – a list of dicts in detectron2 standard format without instance-level annotation.
Notes
- This function does not read the image and ground truth files. The results do not have the “image” and “sem_seg” fields.
-
detectron2.data.datasets.load_lvis_json(json_file, image_root, dataset_name=None)
Load a json file in LVIS’s annotation format.
Parameters:
- json_file (str) – full path to the LVIS json annotation file.
- image_root (str) – the directory where the images in this json file exist.
- dataset_name (str) – the name of the dataset (e.g., “lvis_v0.5_train”). If provided, this function will put “thing_classes” into the metadata associated with this dataset.
Returns: list[dict] – a list of dicts in Detectron2 standard format. (See Using Custom Datasets )
Notes
- This function does not read the image files. The results do not have the “image” field.
-
detectron2.data.datasets.register_lvis_instances(name, metadata, json_file, image_root)
Register a dataset in LVIS’s json annotation format for instance detection and segmentation.
Parameters:
-
detectron2.data.datasets.get_lvis_instances_meta(dataset_name)
Load LVIS metadata.
Parameters: dataset_name (str) – LVIS dataset name without the split name (e.g., “lvis_v0.5”).
Returns: dict – LVIS metadata with keys: thing_classes
-
detectron2.data.datasets.register_coco_instances(name, metadata, json_file, image_root)
Register a dataset in COCO’s json annotation format for instance detection, instance segmentation and keypoint detection (i.e., Type 1 and 2 in http://cocodataset.org/#format-data, the instances*.json and person_keypoints*.json files in the dataset).
This is an example of how to register a new dataset. You can do something similar to this function to register new datasets.
Parameters:
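For illustration, a typical call (the name and paths are placeholders for your own COCO-format dataset):

    from detectron2.data.datasets import register_coco_instances

    register_coco_instances(
        "my_coco_train",                                   # usable in cfg.DATASETS.TRAIN
        {},                                                # extra metadata, may be empty
        "datasets/my_coco/annotations/instances_train.json",
        "datasets/my_coco/images/train",
    )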
-
detectron2.data.datasets.register_coco_panoptic_separated(name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json)
Register a COCO panoptic segmentation dataset named name. The annotations in this registered dataset will contain both instance annotations and semantic annotations, each with its own contiguous ids. Hence it’s called “separated”.
It follows the setting used by the PanopticFPN paper:
- The instance annotations directly come from polygons in the COCO instances annotation task, rather than from the masks in the COCO panoptic annotations. The two formats have small differences: polygons in the instance annotations may have overlaps, while the mask annotations are produced by labeling the overlapped polygons with depth ordering.
- The semantic annotations are converted from panoptic annotations, where all “things” are assigned a semantic id of 0. All semantic categories will therefore have ids in the contiguous range [1, #stuff_categories].
This function will also register a pure semantic segmentation dataset named name + '_stuffonly'.
Parameters:
- name (str) – the name that identifies a dataset, e.g. “coco_2017_train_panoptic”
- metadata (dict) – extra metadata associated with this dataset.
- image_root (str) – directory which contains all the images
- panoptic_root (str) – directory which contains panoptic annotation images
- panoptic_json (str) – path to the json panoptic annotation file
- sem_seg_root (str) – directory which contains all the ground truth segmentation annotations.
- instances_json (str) – path to the json instance annotation file
detectron2.data.samplers module
-
class detectron2.data.samplers.GroupedBatchSampler(sampler, group_ids, batch_size)
Bases: torch.utils.data.sampler.BatchSampler
Wraps another sampler to yield a mini-batch of indices. It enforces that each batch only contains elements from the same group. It also tries to provide mini-batches that follow an ordering as close as possible to that of the original sampler.
-
class detectron2.data.samplers.TrainingSampler(size: int, shuffle: bool = True, seed: Optional[int] = None)
Bases: torch.utils.data.sampler.Sampler
In training, we only care about the “infinite stream” of training data. So this sampler produces an infinite stream of indices, and all workers cooperate to correctly shuffle the indices and sample different indices.
The sampler in each worker effectively produces indices[worker_id::num_workers], where indices is an infinite stream of indices consisting of shuffle(range(size)) + shuffle(range(size)) + … (if shuffle is True), or range(size) + range(size) + … (if shuffle is False).
-
__init__(size: int, shuffle: bool = True, seed: Optional[int] = None)
Parameters:
- size (int) – the total number of data of the underlying dataset to sample from
- shuffle (bool) – whether to shuffle the indices or not
- seed (int) – the initial seed of the shuffle. Must be the same across all workers. If None, a random seed shared among workers will be used (requires synchronization among all workers).
-
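A minimal single-worker sketch of the infinite stream (taking only a finite prefix; seed is a placeholder):

    import itertools
    from detectron2.data.samplers import TrainingSampler

    sampler = TrainingSampler(size=5, shuffle=True, seed=42)
    first10 = list(itertools.islice(iter(sampler), 10))
    # first10 is one shuffled permutation of range(5) followed by the start of another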
-
class detectron2.data.samplers.InferenceSampler(size: int)
Bases: torch.utils.data.sampler.Sampler
Produce indices for inference. Inference needs to run on the __exact__ set of samples; therefore, when the total number of samples is not divisible by the number of workers, this sampler produces a different number of samples on different workers.
-
class detectron2.data.samplers.RepeatFactorTrainingSampler(dataset_dicts, repeat_thresh, shuffle=True, seed=None)
Bases: torch.utils.data.sampler.Sampler
Similar to TrainingSampler, but suitable for training on class-imbalanced datasets like LVIS. In each epoch, an image may appear multiple times based on its “repeat factor”. The repeat factor for an image is a function of the frequency of the rarest category labeled in that image. The “frequency of category c” in [0, 1] is defined as the fraction of images in the training set (without repeats) in which category c appears.
See https://arxiv.org/abs/1908.03195 (>= v2) Appendix B.2.
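Concretely, per the referenced paper, each category c gets a category-level repeat factor r(c) = max(1, sqrt(t / f(c))), where t is repeat_thresh and f(c) is the frequency defined above, and each image I gets r(I) = max over categories c in I of r(c). For example, with t = 0.001 and a category seen in only 0.01% of images (f(c) = 0.0001), r(c) = sqrt(0.001 / 0.0001) = sqrt(10) ≈ 3.16, so images containing that category are repeated about 3 times per epoch.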
-
__init__(dataset_dicts, repeat_thresh, shuffle=True, seed=None)
Parameters:
- dataset_dicts (list[dict]) – annotations in Detectron2 dataset format.
- repeat_thresh (float) – frequency threshold below which data is repeated.
- shuffle (bool) – whether to shuffle the indices or not
- seed (int) – the initial seed of the shuffle. Must be the same across all workers. If None, a random seed shared among workers will be used (requires synchronization among all workers).
-
detectron2.data.transforms module
-
class detectron2.data.transforms.ExtentTransform(src_rect, output_size, interp=2, fill=0)
Bases: fvcore.transforms.transform.Transform
Extracts a subregion from the source image and scales it to the output size.
The fill color is used to map pixels from the source rect that fall outside the source image.
See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
-
class detectron2.data.transforms.ResizeTransform(h, w, new_h, new_w, interp)
Bases: fvcore.transforms.transform.Transform
Resize the image to a target size.
-
__init__(h, w, new_h, new_w, interp)
Parameters:
- h, w – original image size
- new_h, new_w – new image size
- interp – PIL interpolation methods
-
apply_rotated_box(rotated_boxes)
Apply the resizing transform on rotated boxes. For details of how these (approximation) formulas are derived, please refer to RotatedBoxes.scale().
Parameters: rotated_boxes (ndarray) – Nx5 floating point array of (x_center, y_center, width, height, angle_degrees) format in absolute coordinates.
-
-
class detectron2.data.transforms.BlendTransform(src_image: numpy.ndarray, src_weight: float, dst_weight: float)
Bases: fvcore.transforms.transform.Transform
Transforms pixel colors with PIL enhance functions.
-
__init__(src_image: numpy.ndarray, src_weight: float, dst_weight: float)
Blends the input image (dst_image) with the src_image using the formula: src_weight * src_image + dst_weight * dst_image
Parameters:
-
apply_image(img: numpy.ndarray, interp: str = None) → numpy.ndarray
Apply blend transform on the image(s).
Parameters:
- img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- interp (str) – kept for API consistency; blending does not require interpolation.
Returns: ndarray – blended image(s).
-
-
class
detectron2.data.transforms.
CropTransform
(x0: int, y0: int, w: int, h: int)[source]¶ Bases:
fvcore.transforms.transform.Transform
-
__init__
(x0: int, y0: int, w: int, h: int)[source]¶ Parameters: y0, w, h (x0,) – crop the image(s) by img[y0:y0+h, x0:x0+w].
-
apply_image
(img: numpy.ndarray) → numpy.ndarray[source]¶ Crop the image(s).
Parameters: img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255]. Returns: ndarray – cropped image(s).
-
apply_coords
(coords: numpy.ndarray) → numpy.ndarray[source]¶ Apply crop transform on coordinates.
Parameters: coords (ndarray) – floating point array of shape Nx2. Each row is (x, y). Returns: ndarray – cropped coordinates.
-
apply_polygons
(polygons: list) → list[source]¶ Apply crop transform on a list of polygons, each represented by a Nx2 array. It will crop the polygon with the box, therefore the number of points in the polygon might change.
Parameters: polygon (list[ndarray]) – each is a Nx2 floating point array of (x, y) format in absolute coordinates. Returns: ndarray – cropped polygons.
-
-
class detectron2.data.transforms.GridSampleTransform(grid: numpy.ndarray, interp: str)
Bases: fvcore.transforms.transform.Transform
-
__init__(grid: numpy.ndarray, interp: str)
Parameters:
- grid (ndarray) – grid contains the x and y input pixel locations used to compute the output. Grid values are in the range [-1, 1], normalized by the input height and width. The dimension is N x H x W x 2.
- interp (str) – interpolation methods. Options include nearest and bilinear.
-
apply_image(img: numpy.ndarray, interp: str = None) → numpy.ndarray
Apply grid sampling on the image(s).
Parameters:
- img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- interp (str) – interpolation methods. Options include nearest and bilinear.
Returns: ndarray – grid sampled image(s).
-
-
class detectron2.data.transforms.HFlipTransform(width: int)
Bases: fvcore.transforms.transform.Transform
Perform horizontal flip.
-
apply_image(img: numpy.ndarray) → numpy.ndarray
Flip the image(s).
Parameters: img (ndarray) – of shape HxW, HxWxC, or NxHxWxC. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
Returns: ndarray – the flipped image(s).
-
apply_coords(coords: numpy.ndarray) → numpy.ndarray
Flip the coordinates.
Parameters: coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
Returns: ndarray – the flipped coordinates.
Note
The inputs are floating point coordinates, not pixel indices. Therefore they are flipped by (W - x, H - y), not (W - 1 - x, H - 1 - y).
-
apply_rotated_box(rotated_boxes)
Apply the horizontal flip transform on rotated boxes.
Parameters: rotated_boxes (ndarray) – Nx5 floating point array of (x_center, y_center, width, height, angle_degrees) format in absolute coordinates.
-
-
class detectron2.data.transforms.NoOpTransform
Bases: fvcore.transforms.transform.Transform
A transform that does nothing.
-
apply_rotated_box(x)
-
-
class detectron2.data.transforms.ScaleTransform(h: int, w: int, new_h: int, new_w: int, interp: str)
Bases: fvcore.transforms.transform.Transform
Resize the image to a target size.
-
__init__(h: int, w: int, new_h: int, new_w: int, interp: str)
Parameters:
- h, w – original image size.
- new_h, new_w – new image size.
- interp (str) – interpolation methods. Options include nearest, linear (3D-only), bilinear, bicubic (4D-only), and area. Details can be found in: https://pytorch.org/docs/stable/nn.functional.html
-
apply_image(img: numpy.ndarray, interp: str = None) → numpy.ndarray
Resize the image(s).
Parameters:
- img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
- interp (str) – interpolation methods. Options include nearest, linear (3D-only), bilinear, bicubic (4D-only), and area. Details can be found in: https://pytorch.org/docs/stable/nn.functional.html
Returns: ndarray – resized image(s).
-
-
class detectron2.data.transforms.Transform
Bases: object
Base class for implementations of __deterministic__ transformations for image and other data structures. “Deterministic” requires that the output of all methods of this class are deterministic w.r.t their input arguments. In training, there should be a higher-level policy that generates (likely with random variations) these transform ops. Each transform op may handle several data types, e.g.: image, coordinates, segmentation, bounding boxes. Some of them have a default implementation, but can be overwritten if the default isn’t appropriate. The implementation of each method may choose to modify its input data in-place for efficient transformation.
-
apply_image(img: numpy.ndarray)
Apply the transform on an image.
Parameters: img (ndarray) – of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255].
Returns: ndarray – image after applying the transformation.
-
apply_coords(coords: numpy.ndarray)
Apply the transform on coordinates.
Parameters: coords (ndarray) – floating point array of shape Nx2. Each row is (x, y).
Returns: ndarray – coordinates after applying the transformation.
Note
The coordinates are not pixel indices. Coordinates on an image of shape (H, W) are in range [0, W] or [0, H].
-
apply_segmentation(segmentation: numpy.ndarray) → numpy.ndarray
Apply the transform on a full-image segmentation. By default will just perform “apply_image”.
Parameters: segmentation (ndarray) – of shape HxW. The array should have integer or bool dtype.
Returns: ndarray – segmentation after applying the transformation.
-
apply_box(box: numpy.ndarray) → numpy.ndarray
Apply the transform on an axis-aligned box. By default will transform the corner points and use their minimum/maximum to create a new axis-aligned box. Note that this default may change the size of your box, e.g. in rotations.
Parameters: box (ndarray) – Nx4 floating point array of XYXY format in absolute coordinates.
Returns: ndarray – box after applying the transformation.
Note
The coordinates are not pixel indices. Coordinates on an image of shape (H, W) are in range [0, W] or [0, H].
-
apply_polygons(polygons: list) → list
Apply the transform on a list of polygons, each represented by an Nx2 array. By default will just transform all the points.
Parameters: polygons (list[ndarray]) – each is an Nx2 floating point array of (x, y) format in absolute coordinates.
Returns: list[ndarray] – polygons after applying the transformation.
Note
The coordinates are not pixel indices. Coordinates on an image of shape (H, W) are in range [0, W] or [0, H].
-
classmethod register_type(data_type: str, func: Callable)
Register the given function as a handler that this transform will use for a specific data type.
Parameters:
- data_type (str) – the name of the data type (e.g., box)
- func (callable) – takes a transform and a data, returns the transformed data.
Examples:

    def func(flip_transform, voxel_data):
        return transformed_voxel_data

    HFlipTransform.register_type("voxel", func)

    # ...
    transform = HFlipTransform(...)
    transform.apply_voxel(voxel_data)  # func will be called
-
-
class detectron2.data.transforms.TransformList(transforms: list)
Bases: object
Maintain a list of transform operations which will be applied in sequence.
Attributes: transforms (list[Transform])
-
__init__(transforms: list)
Parameters: transforms (list[Transform]) – list of transforms to perform.
-
__add__(other: fvcore.transforms.transform.TransformList) → fvcore.transforms.transform.TransformList
Parameters: other (TransformList) – transformation to add.
Returns: TransformList – list of transforms.
-
__iadd__(other: fvcore.transforms.transform.TransformList) → fvcore.transforms.transform.TransformList
Parameters: other (TransformList) – transformation to add.
Returns: TransformList – list of transforms.
-
__radd__(other: fvcore.transforms.transform.TransformList) → fvcore.transforms.transform.TransformList
Parameters: other (TransformList) – transformation to add.
Returns: TransformList – list of transforms.
-
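A minimal sketch of composing and applying a TransformList (image and sizes are placeholders):

    import numpy as np
    from detectron2.data.transforms import HFlipTransform, ResizeTransform, TransformList

    img = np.zeros((480, 640, 3), dtype=np.uint8)
    tl = TransformList([ResizeTransform(480, 640, 800, 1067, 2),  # interp=2 is PIL bilinear
                        HFlipTransform(1067)])
    out = tl.apply_image(img)      # transforms applied in sequence
    print(out.shape)               # (800, 1067, 3)
    longer = tl + TransformList([HFlipTransform(1067)])  # __add__ concatenates lists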
-
class detectron2.data.transforms.RandomBrightness(intensity_min, intensity_max)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Randomly transforms image brightness.
Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce brightness
- intensity = 1 will preserve the input image
- intensity > 1 will increase brightness
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
-
class detectron2.data.transforms.RandomContrast(intensity_min, intensity_max)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Randomly transforms image contrast.
Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce contrast
- intensity = 1 will preserve the input image
- intensity > 1 will increase contrast
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
-
class detectron2.data.transforms.RandomCrop(crop_type: str, crop_size)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Randomly crop a subimage out of an image.
-
class detectron2.data.transforms.RandomExtent(scale_range, shift_range)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Outputs an image by cropping a random “subrect” of the source image.
The subrect can be parameterized to include pixels outside the source image, in which case they will be set to zeros (i.e. black). The size of the output image will vary with the size of the random subrect.
-
__init__(scale_range, shift_range)
Parameters:
- output_size (h, w) – dimensions of the output image
- scale_range (l, h) – range of input-to-output size scaling factor
- shift_range (x, y) – range of shifts of the cropped subrect. The rect is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)], where (w, h) is the (width, height) of the input image. Set each component to zero to crop at the image’s center.
-
-
class detectron2.data.transforms.RandomFlip(prob=0.5)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Flip the image horizontally with the given probability.
TODO Vertical flip to be implemented.
-
class detectron2.data.transforms.RandomSaturation(intensity_min, intensity_max)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Randomly transforms image saturation.
Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce saturation (make the image more grayscale)
- intensity = 1 will preserve the input image
- intensity > 1 will increase saturation
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
-
class detectron2.data.transforms.RandomLighting(scale)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Randomly transforms image color using fixed PCA over ImageNet.
The degree of color jittering is randomly sampled via a normal distribution, with standard deviation given by the scale parameter.
-
class detectron2.data.transforms.Resize(shape, interp=2)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Resize the image to a target size.
-
class detectron2.data.transforms.ResizeShortestEdge(short_edge_length, max_size=9223372036854775807, sample_style='range', interp=2)
Bases: detectron2.data.transforms.transform_gen.TransformGen
Scale the shorter edge to the given size, with a limit of max_size on the longer edge. If max_size is reached, then downscale so that the longer edge does not exceed max_size.
-
__init__(short_edge_length, max_size=9223372036854775807, sample_style='range', interp=2)
Parameters:
- short_edge_length (list[int]) – If sample_style=="range", a [min, max] interval from which to sample the shortest edge length. If sample_style=="choice", a list of shortest edge lengths to sample from.
- max_size (int) – maximum allowed longest edge length.
- sample_style (str) – either “range” or “choice”.
-
-
class detectron2.data.transforms.TransformGen
Bases: object
TransformGen takes an image of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255], as input.
It creates a Transform based on the given image, sometimes with randomness. The transform can then be used to transform images or other data (boxes, points, annotations, etc.) associated with it. The assumption made in this class is that the image itself is sufficient to instantiate a transform. When this assumption is not true, you need to create the transforms on your own.
A list of TransformGen can be applied with apply_transform_gens().
-
__repr__()
Produce something like: “MyTransformGen(field1={self.field1}, field2={self.field2})”
-
__str__()
Produce something like: “MyTransformGen(field1={self.field1}, field2={self.field2})”
-
-
detectron2.data.transforms.apply_transform_gens(transform_gens, img)
Apply a list of TransformGen on the input image, and return the transformed image and a list of transforms.
We cannot simply create and return all transforms without applying them to the image, because a subsequent transform may need the output of the previous one.
Parameters:
- transform_gens (list) – list of TransformGen instances to be applied.
- img (ndarray) – uint8 or floating point images with 1 or 3 channels.
Returns: ndarray – the transformed image. TransformList – contains the transforms that were used.
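A minimal sketch of applying TransformGen instances and reusing the resulting TransformList on associated data (image content and box values are placeholders):

    import numpy as np
    from detectron2.data import transforms as T

    img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    gens = [T.ResizeShortestEdge([640, 672, 704], max_size=1333, sample_style="choice"),
            T.RandomFlip(prob=0.5)]
    new_img, tfms = T.apply_transform_gens(gens, img)
    # tfms is a TransformList; apply the same ops to data associated with the image:
    boxes = tfms.apply_box(np.array([[10.0, 20.0, 100.0, 200.0]]))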