detectron2.structures

class detectron2.structures.Boxes(tensor: torch.Tensor)

Bases: object

This structure stores a list of boxes as an Nx4 torch.Tensor. It supports some common methods about boxes (area, clip, nonempty, etc.), and also behaves like a Tensor (it supports indexing, to(device), .device, and iteration over all boxes).

tensor

float matrix of Nx4. Each row is (x1, y1, x2, y2).

Type

torch.Tensor

__getitem__(item) → detectron2.structures.Boxes
Parameters

item – int, slice, or a BoolTensor

Returns

Boxes – Create a new Boxes by indexing.

The following usages are allowed:

  1. new_boxes = boxes[3]: return a Boxes which contains only one box.

  2. new_boxes = boxes[2:10]: return a slice of boxes.

  3. new_boxes = boxes[vector], where vector is a torch.BoolTensor with length = len(boxes). Nonzero elements in the vector will be selected.

Note that the returned Boxes might share storage with this Boxes, subject to PyTorch’s indexing semantics.

__init__(tensor: torch.Tensor)
Parameters

tensor (Tensor[float]) – a Nx4 matrix. Each row is (x1, y1, x2, y2).

__iter__()

Yield a box as a Tensor of shape (4,) at a time.

area() → torch.Tensor

Computes the area of all the boxes.

Returns

torch.Tensor – a vector with areas of each box.
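As a plain-Python illustration of what area() computes for (x1, y1, x2, y2) rows, here is a minimal sketch; boxes_area is a hypothetical helper that works on nested lists rather than a torch.Tensor and is not part of detectron2.

```python
from typing import List, Sequence


def boxes_area(boxes: Sequence[Sequence[float]]) -> List[float]:
    """Return (x2 - x1) * (y2 - y1) for every (x1, y1, x2, y2) row."""
    return [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in boxes]


print(boxes_area([[0, 0, 4, 2], [1, 1, 3, 5]]))  # [8, 8]
```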

classmethod cat(boxes_list: List[Boxes]) → detectron2.structures.Boxes

Concatenates a list of Boxes into a single Boxes

Parameters

boxes_list (list[Boxes]) –

Returns

Boxes – the concatenated Boxes

clip(box_size: Tuple[int, int]) → None

Clip (in place) the boxes by limiting x coordinates to the range [0, width] and y coordinates to the range [0, height].

Parameters

box_size (height, width) – The clipping box’s size.
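Because box_size is given as (height, width), it is easy to swap the two limits. The sketch below shows the in-place clamping described above; clip_boxes is a hypothetical list-based helper, not detectron2 code.

```python
from typing import List, Tuple


def clip_boxes(boxes: List[List[float]], box_size: Tuple[int, int]) -> None:
    """Clamp (x1, y1, x2, y2) boxes in place to [0, width] x [0, height]."""
    height, width = box_size  # note: (height, width), not (width, height)
    for box in boxes:
        box[0] = min(max(box[0], 0.0), width)   # x1
        box[1] = min(max(box[1], 0.0), height)  # y1
        box[2] = min(max(box[2], 0.0), width)   # x2
        box[3] = min(max(box[3], 0.0), height)  # y2


boxes = [[-5.0, 2.0, 120.0, 90.0]]
clip_boxes(boxes, (80, 100))  # image is 80 pixels tall and 100 pixels wide
print(boxes)  # [[0.0, 2.0, 100, 80]]
```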

clone() → detectron2.structures.Boxes

Clone the Boxes.

Returns

Boxes

property device
get_centers() → torch.Tensor
Returns

The box centers in a Nx2 array of (x, y).

inside_box(box_size: Tuple[int, int], boundary_threshold: int = 0) → torch.Tensor
Parameters
  • box_size (height, width) – Size of the reference box.

  • boundary_threshold (int) – Boxes that extend beyond the reference box boundary by more than boundary_threshold are considered “outside”.

Returns

a binary vector, indicating whether each box is inside the reference box.
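A list-based sketch of the inside/outside test. The exact comparisons (>= on the top-left corner, strict < on the bottom-right corner, each relaxed by boundary_threshold) are an assumption about the library's implementation, and inside_mask is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def inside_mask(
    boxes: Sequence[Sequence[float]],
    box_size: Tuple[int, int],
    boundary_threshold: int = 0,
) -> List[bool]:
    """True for each (x1, y1, x2, y2) box that stays within the reference box."""
    height, width = box_size
    t = boundary_threshold
    # Assumed comparisons: corners may stick out by at most t pixels.
    return [
        x1 >= -t and y1 >= -t and x2 < width + t and y2 < height + t
        for x1, y1, x2, y2 in boxes
    ]


boxes = [[0, 0, 10, 10], [-3, 0, 10, 10]]
print(inside_mask(boxes, (20, 20)))                        # [True, False]
print(inside_mask(boxes, (20, 20), boundary_threshold=5))  # [True, True]
```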

nonempty(threshold: float = 0.0) → torch.Tensor

Find boxes that are non-empty. A box is considered empty if either of its sides is no larger than threshold.

Returns

Tensor – a binary vector which represents whether each box is empty (False) or non-empty (True).
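Following the definition above, a box is kept only when both its width and its height exceed threshold. A minimal sketch; nonempty_mask is hypothetical and not detectron2's function.

```python
from typing import List, Sequence


def nonempty_mask(boxes: Sequence[Sequence[float]], threshold: float = 0.0) -> List[bool]:
    """A box is non-empty iff both of its sides are larger than threshold."""
    return [
        (x2 - x1) > threshold and (y2 - y1) > threshold
        for x1, y1, x2, y2 in boxes
    ]


boxes = [[0, 0, 4, 2], [1, 1, 1, 5], [0, 0, 3, 0.5]]
print(nonempty_mask(boxes))                 # [True, False, True]
print(nonempty_mask(boxes, threshold=1.0))  # [True, False, False]
```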

scale(scale_x: float, scale_y: float) → None

Scale the boxes by horizontal and vertical scaling factors.

to(device: torch.device)
class detectron2.structures.BoxMode(value)

Bases: enum.IntEnum

Enum of different ways to represent a box.

XYXY_ABS = 0
XYWH_ABS = 1
XYXY_REL = 2
XYWH_REL = 3
XYWHA_ABS = 4
static convert(box: Union[List[float], Tuple[float, ...], torch.Tensor, numpy.ndarray], from_mode: detectron2.structures.BoxMode, to_mode: detectron2.structures.BoxMode) → Union[List[float], Tuple[float, ...], torch.Tensor, numpy.ndarray]
Parameters
  • box – can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5

  • from_mode (BoxMode) –

  • to_mode (BoxMode) –

Returns

The converted box of the same type.
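For the two absolute axis-aligned modes, XYXY_ABS stores (x0, y0, x1, y1) corners and XYWH_ABS stores the top-left corner plus width and height. The sketch below shows the arithmetic for a single 4-tuple; xywh_to_xyxy and xyxy_to_xywh are hypothetical helpers, and with detectron2 itself the equivalent call would be BoxMode.convert(box, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def xywh_to_xyxy(box: Box) -> Box:
    """(x0, y0, w, h) -> (x0, y0, x0 + w, y0 + h)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def xyxy_to_xywh(box: Box) -> Box:
    """(x0, y0, x1, y1) -> (x0, y0, x1 - x0, y1 - y0)."""
    x0, y0, x1, y1 = box
    return (x0, y0, x1 - x0, y1 - y0)


print(xywh_to_xyxy((3, 2, 4, 2)))  # (3, 2, 7, 4)
print(xyxy_to_xywh((3, 2, 7, 4)))  # (3, 2, 4, 2)
```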

detectron2.structures.pairwise_iou(boxes1: detectron2.structures.Boxes, boxes2: detectron2.structures.Boxes) → torch.Tensor

Given two lists of boxes of size N and M, compute the IoU (intersection over union) between all N x M pairs of boxes. The box order must be (xmin, ymin, xmax, ymax).

Parameters
  • boxes1 (Boxes) – contains N boxes.

  • boxes2 (Boxes) – contains M boxes.

Returns

Tensor – IoU, sized [N,M].
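The computation can be sketched in plain Python as a double loop over the two box lists. iou_matrix is a hypothetical helper (detectron2 works on tensors instead of loops), and returning 0.0 for a zero-area union is this sketch's own choice.

```python
from typing import List, Sequence

Box = Sequence[float]


def box_area(b: Box) -> float:
    return (b[2] - b[0]) * (b[3] - b[1])


def iou_matrix(boxes1: List[Box], boxes2: List[Box]) -> List[List[float]]:
    """N x M IoU matrix for (x1, y1, x2, y2) boxes."""
    result = []
    for a in boxes1:
        row = []
        for b in boxes2:
            # Overlap along each axis, clamped at zero for disjoint boxes.
            iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
            ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
            inter = iw * ih
            union = box_area(a) + box_area(b) - inter
            row.append(inter / union if union > 0 else 0.0)
        result.append(row)
    return result


# One box against three: partial overlap (1/7), identical (1.0), disjoint (0.0).
print(iou_matrix([[0, 0, 2, 2]], [[1, 1, 3, 3], [0, 0, 2, 2], [5, 5, 6, 6]]))
```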

detectron2.structures.pairwise_ioa(boxes1: detectron2.structures.Boxes, boxes2: detectron2.structures.Boxes) → torch.Tensor

Similar to pairwise_iou(), but computes the IoA (intersection over boxes2 area).

Parameters
  • boxes1 (Boxes) – contains N boxes.

  • boxes2 (Boxes) – contains M boxes.

Returns

Tensor – IoA, sized [N,M].
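The only difference from IoU is the denominator: the intersection is divided by the area of the box from boxes2, so the result is not symmetric. A plain-Python sketch with a hypothetical helper, returning 0.0 for a zero-area boxes2 box by this sketch's own choice:

```python
from typing import List, Sequence

Box = Sequence[float]


def ioa_matrix(boxes1: List[Box], boxes2: List[Box]) -> List[List[float]]:
    """N x M matrix of intersection area divided by the area of the boxes2 box."""
    result = []
    for a in boxes1:
        row = []
        for b in boxes2:
            iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
            ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
            area_b = (b[2] - b[0]) * (b[3] - b[1])
            row.append(iw * ih / area_b if area_b > 0 else 0.0)
        result.append(row)
    return result


# Unlike IoU, IoA is not symmetric in its two arguments.
print(ioa_matrix([[0, 0, 4, 4]], [[1, 1, 2, 2], [2, 2, 6, 6]]))  # [[1.0, 0.25]]
print(ioa_matrix([[1, 1, 2, 2]], [[0, 0, 4, 4]]))                # [[0.0625]]
```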

detectron2.structures.pairwise_point_box_distance(points: torch.Tensor, boxes: detectron2.structures.Boxes)

Pairwise distance between N points and M boxes. The distance between a point and a box is represented by the distance from the point to 4 edges of the box. Distances are all positive when the point is inside the box.

Parameters
  • points – Nx2 coordinates. Each row is (x, y)

  • boxes – M boxes

Returns

Tensor – distances of size (N, M, 4). The 4 values are distances from the point to the left, top, right, and bottom edges of the box.
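For one point and one (x1, y1, x2, y2) box, the four distances are signed coordinate differences, and all four are positive exactly when the point lies strictly inside the box. A plain-Python sketch; point_box_distance is a hypothetical list-based helper, not the library function.

```python
from typing import List, Sequence


def point_box_distance(
    points: Sequence[Sequence[float]], boxes: Sequence[Sequence[float]]
) -> List[List[List[float]]]:
    """Nested (N, M, 4) list of (left, top, right, bottom) distances."""
    return [
        [[x - x1, y - y1, x2 - x, y2 - y] for x1, y1, x2, y2 in boxes]
        for x, y in points
    ]


d = point_box_distance([(5, 3), (0, 0)], [(3, 2, 7, 4)])
print(d)  # [[[2, 1, 2, 1]], [[-3, -2, 7, 4]]]
```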

class detectron2.structures.ImageList(tensor: torch.Tensor, image_sizes: List[Tuple[int, int]])

Bases: object

Structure that holds a list of images (of possibly varying sizes) as a single tensor. This works by padding the images to the same size. The original size of each image is stored in image_sizes.

image_sizes

each tuple is (h, w). During tracing, it becomes list[Tensor] instead.

Type

list[tuple[int, int]]

__getitem__(idx) → torch.Tensor

Access the individual image in its original size.

Parameters

idx – int or slice

Returns

Tensor – an image of shape (H, W) or (C_1, …, C_K, H, W) where K >= 1

__init__(tensor: torch.Tensor, image_sizes: List[Tuple[int, int]])
Parameters
  • tensor (Tensor) – of shape (N, H, W) or (N, C_1, …, C_K, H, W) where K >= 1

  • image_sizes (list[tuple[int, int]]) – Each tuple is (h, w). It can be smaller than (H, W) due to padding.

property device
static from_tensors(tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0, padding_constraints: Optional[Dict[str, int]] = None) → detectron2.structures.ImageList
Parameters
  • tensors – a tuple or list of torch.Tensor, each of shape (Hi, Wi) or (C_1, …, C_K, Hi, Wi) where K >= 1. The Tensors will be padded to the same shape with pad_value.

  • size_divisibility (int) – If size_divisibility > 0, add padding to ensure the common height and width are divisible by size_divisibility. This depends on the model; many models need a divisibility of 32.

  • pad_value (float) – value to pad.

  • padding_constraints (Optional[Dict]) – If given, a dict of the form {“size_divisibility”: int, “square_size”: int}. Its size_divisibility, if present, overrides the argument above, and square_size, if > 0, is the side length of the square that the images are padded to.

Returns

an ImageList.
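The common padded shape can be sketched as follows: take the largest height and width, optionally replace them with square_size, then round each up to a multiple of size_divisibility. padded_size is a hypothetical helper, and the order of those two steps is an assumption about from_tensors rather than something stated above.

```python
from typing import Dict, List, Optional, Tuple


def padded_size(
    image_sizes: List[Tuple[int, int]],
    size_divisibility: int = 0,
    padding_constraints: Optional[Dict[str, int]] = None,
) -> Tuple[int, int]:
    """Common (H, W) that every (h, w) image would be padded to."""
    max_h = max(h for h, _ in image_sizes)
    max_w = max(w for _, w in image_sizes)
    if padding_constraints is not None:
        square_size = padding_constraints.get("square_size", 0)
        if square_size > 0:
            max_h = max_w = square_size  # pad to a square
        # A size_divisibility entry overrides the argument.
        size_divisibility = padding_constraints.get("size_divisibility", size_divisibility)
    if size_divisibility > 0:
        # Round each side up to the next multiple (integer ceiling division).
        max_h = -(-max_h // size_divisibility) * size_divisibility
        max_w = -(-max_w // size_divisibility) * size_divisibility
    return max_h, max_w


print(padded_size([(480, 640), (500, 600)], size_divisibility=32))          # (512, 640)
print(padded_size([(480, 640)], padding_constraints={"square_size": 1024}))  # (1024, 1024)
```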

to(*args: Any, **kwargs: Any) → detectron2.structures.ImageList
class detectron2.structures.Instances(image_size: Tuple[int, int], **kwargs: Any)

Bases: object

This class represents a list of instances in an image. It stores the attributes of instances (e.g., boxes, masks, labels, scores) as “fields”. All fields must have the same __len__ which is the number of instances.

All other (non-field) attributes of this class are considered private: they must start with ‘_’ and are not modifiable by a user.

Some basic usage:

  1. Set/get/check a field:

    instances.gt_boxes = Boxes(...)
    print(instances.pred_masks)  # a tensor of shape (N, H, W)
    print('gt_masks' in instances)
    
  2. len(instances) returns the number of instances

  3. Indexing: instances[indices] applies the indexing to all the fields and returns a new Instances. Typically, indices is an integer vector of indices, or a binary mask of length num_instances:

    category_3_detections = instances[instances.pred_classes == 3]
    confident_detections = instances[instances.scores > 0.9]
    
__getitem__(item: Union[int, slice, torch.BoolTensor]) → detectron2.structures.Instances
Parameters

item – an index-like object and will be used to index all the fields.

Returns

An Instances where all fields are indexed by item. To read a single field by name, use get(name).

__init__(image_size: Tuple[int, int], **kwargs: Any)
Parameters
  • image_size (height, width) – the spatial size of the image.

  • kwargs – fields to add to this Instances.

static cat(instance_lists: List[Instances]) → detectron2.structures.Instances
Parameters

instance_lists (list[Instances]) –

Returns

Instances

get(name: str) → Any

Returns the field called name.

get_fields() → Dict[str, Any]
Returns

dict – a dict which maps names (str) to data of the fields

Modifying the returned dict will modify this instance.

has(name: str) → bool
Returns

bool – whether the field called name exists.

property image_size

Returns: tuple – (height, width)

remove(name: str) → None

Remove the field called name.

set(name: str, value: Any) → None

Set the field named name to value. The length of value must be the number of instances, and must agree with other existing fields in this object.
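To make the field-length rule concrete, here is a toy stand-in. ToyInstances is hypothetical and much simpler than the real class (detectron2 also exposes fields as attributes, supports indexing, and may report a length mismatch with a different error type).

```python
from typing import Any, Dict, Tuple


class ToyInstances:
    """Hypothetical mini version of Instances: named fields of equal length."""

    def __init__(self, image_size: Tuple[int, int]) -> None:
        self._image_size = image_size  # private, non-field attribute
        self._fields: Dict[str, Any] = {}

    def set(self, name: str, value: Any) -> None:
        # Every field must have the same length (the number of instances).
        for other_name, other in self._fields.items():
            if other_name != name and len(other) != len(value):
                raise ValueError(
                    f"field {name!r} has length {len(value)}, expected {len(other)}"
                )
        self._fields[name] = value

    def get(self, name: str) -> Any:
        return self._fields[name]

    def has(self, name: str) -> bool:
        return name in self._fields


inst = ToyInstances((480, 640))
inst.set("scores", [0.95, 0.30])
inst.set("pred_classes", [3, 1])
print(inst.has("scores"), inst.get("pred_classes"))  # True [3, 1]
```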

to(*args: Any, **kwargs: Any) → detectron2.structures.Instances
Returns

Instances – a new Instances whose fields have been converted with to(*args, **kwargs); fields without a to method are kept unchanged.

class detectron2.structures.Keypoints(keypoints: Union[torch.Tensor, numpy.ndarray, List[List[float]]])

Bases: object

Stores keypoint annotation data. GT Instances have a gt_keypoints property containing the x,y location and visibility flag of each keypoint. This tensor has shape (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.

The visibility flag follows the COCO format and must be one of three integers:

  • v=0: not labeled (in which case x=y=0)

  • v=1: labeled but not visible

  • v=2: labeled and visible

__getitem__(item: Union[int, slice, torch.BoolTensor]) → detectron2.structures.Keypoints

Create a new Keypoints by indexing on this Keypoints.

The following usages are allowed:

  1. new_kpts = kpts[3]: return a Keypoints which contains only one instance.

  2. new_kpts = kpts[2:10]: return a slice of keypoints.

  3. new_kpts = kpts[vector], where vector is a torch.BoolTensor with length = len(kpts). Nonzero elements in the vector will be selected.

Note that the returned Keypoints might share storage with this Keypoints, subject to PyTorch’s indexing semantics.

__init__(keypoints: Union[torch.Tensor, numpy.ndarray, List[List[float]]])
Parameters

keypoints – A Tensor, numpy array, or list of the x, y, and visibility of each keypoint. The shape should be (N, K, 3) where N is the number of instances, and K is the number of keypoints per instance.

static cat(keypoints_list: List[Keypoints]) → detectron2.structures.Keypoints

Concatenates a list of Keypoints into a single Keypoints

Parameters

keypoints_list (list[Keypoints]) –

Returns

Keypoints – the concatenated Keypoints

property device
to(*args: Any, **kwargs: Any) → detectron2.structures.Keypoints
to_heatmap(boxes: torch.Tensor, heatmap_size: int) → torch.Tensor

Convert keypoint annotations to a heatmap of one-hot labels for training, as described in Mask R-CNN.

Parameters
  • boxes – Nx4 tensor, the boxes to draw the keypoints to.

  • heatmap_size (int) – the side length of the square heatmap.

Returns

heatmaps – a tensor of shape (N, K); each element is an integer spatial label in the range [0, heatmap_size**2 - 1] for each keypoint in the input.

valid – a tensor of shape (N, K) indicating whether each keypoint is inside the ROI.
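For a single keypoint, the discretization can be sketched as: map the keypoint into the box, scale to heatmap_size cells, floor, and flatten (row, column) into one index. This is a simplified sketch under stated assumptions: keypoint_to_label is hypothetical, a keypoint counts as valid only if its visibility flag is positive and it falls inside the box, invalid keypoints get label 0, and any special handling of keypoints lying exactly on the right or bottom box edge is omitted.

```python
import math
from typing import Sequence, Tuple


def keypoint_to_label(
    x: float, y: float, v: int, box: Sequence[float], heatmap_size: int
) -> Tuple[int, bool]:
    """Return (flattened heatmap index, valid) for one keypoint and one box."""
    x1, y1, x2, y2 = box
    # Position inside the box, scaled to heatmap cells, then floored.
    col = math.floor((x - x1) * heatmap_size / (x2 - x1))
    row = math.floor((y - y1) * heatmap_size / (y2 - y1))
    valid = v > 0 and 0 <= col < heatmap_size and 0 <= row < heatmap_size
    label = row * heatmap_size + col if valid else 0  # assumed label for invalid points
    return label, valid


box = (3.0, 2.0, 7.0, 4.0)
print(keypoint_to_label(5.0, 3.0, 2, box, 56))  # (1596, True)
print(keypoint_to_label(5.0, 3.0, 0, box, 56))  # (0, False)
```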

detectron2.structures.heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) → torch.Tensor

Extract predicted keypoint locations from heatmaps.

Parameters
  • maps (Tensor) – (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for each ROI and each keypoint.

  • rois (Tensor) – (#ROIs, 4). The box of each ROI.

Returns

Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to (x, y, logit, score) for each keypoint.

When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate, we maintain consistency with Keypoints.to_heatmap() by using the conversion from Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
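Under that convention, a discrete heatmap cell d maps back to the continuous position d + 0.5 (the cell center), which is then rescaled into image coordinates. A simplified sketch along one axis, assuming the heatmap spans the ROI exactly; heatmap_index_to_image is hypothetical and skips the resizing the library performs before this step.

```python
def discrete_to_continuous(d: int) -> float:
    """Heckbert (1990): continuous coordinate c = d + 0.5."""
    return d + 0.5


def heatmap_index_to_image(d: int, roi_start: float, roi_len: float, heatmap_size: int) -> float:
    """Map heatmap cell d along one axis of an ROI back to image coordinates."""
    return roi_start + discrete_to_continuous(d) * roi_len / heatmap_size


# Cell 28 of a 56-cell heatmap over x in [3, 7] maps to the center of that cell.
print(heatmap_index_to_image(28, 3.0, 4.0, 56))
```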

class detectron2.structures.BitMasks(tensor: Union[torch.Tensor, numpy.ndarray])

Bases: object

This class stores the segmentation masks for all objects in one image, in the form of bitmaps.

tensor

bool Tensor of shape (N, H, W), representing N instances in the image.

__getitem__(item: Union[int, slice, torch.BoolTensor]) → detectron2.structures.BitMasks
Returns

BitMasks – Create a new BitMasks by indexing.

The following usages are allowed:

  1. new_masks = masks[3]: return a BitMasks which contains only one mask.

  2. new_masks = masks[2:10]: return a slice of masks.

  3. new_masks = masks[vector], where vector is a torch.BoolTensor with length = len(masks). Nonzero elements in the vector will be selected.

Note that the returned object might share storage with this object, subject to PyTorch’s indexing semantics.

__init__(tensor: Union[torch.Tensor, numpy.ndarray])
Parameters

tensor – bool Tensor of shape (N, H, W), representing N instances in the image.

static cat(bitmasks_list: List[BitMasks]) → detectron2.structures.BitMasks

Concatenates a list of BitMasks into a single BitMasks

Parameters

bitmasks_list (list[BitMasks]) –

Returns

BitMasks – the concatenated BitMasks

crop_and_resize(boxes: torch.Tensor, mask_size: int) → torch.Tensor

Crop each bitmask by the given box, and resize the results to (mask_size, mask_size). This can be used to prepare training targets for Mask R-CNN. It has lower reconstruction error than rasterizing polygons. However, we observe no difference in accuracy, and BitMasks requires more memory to store all the masks.

Parameters
  • boxes (Tensor) – Nx4 tensor storing the boxes for each mask

  • mask_size (int) – the size of the rasterized mask.

Returns

Tensor – A bool tensor of shape (N, mask_size, mask_size), where N is the number of predicted boxes for this image.

property device
static from_polygon_masks(polygon_masks: Union[PolygonMasks, List[List[numpy.ndarray]]], height: int, width: int) → detectron2.structures.BitMasks
Parameters
  • polygon_masks (PolygonMasks or list[list[ndarray]]) –

  • height (int) –

  • width (int) –

static from_roi_masks(roi_masks: detectron2.structures.ROIMasks, height: int, width: int) → detectron2.structures.BitMasks
Parameters
  • roi_masks (ROIMasks) –

  • height (int) –

  • width (int) –

get_bounding_boxes() → detectron2.structures.Boxes
Returns

Boxes – tight bounding boxes around bitmasks. If a mask is empty, its bounding box will be all zeros.
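A list-based sketch of the tight box around one bitmask. It assumes the right and bottom edges lie one past the last foreground column and row, so a single pixel at (col, row) gives the box (col, row, col + 1, row + 1); mask_bounding_box is a hypothetical helper.

```python
from typing import List, Tuple


def mask_bounding_box(mask: List[List[bool]]) -> Tuple[int, int, int, int]:
    """Tight (x1, y1, x2, y2) box around True pixels; all zeros if the mask is empty."""
    rows = [r for r, line in enumerate(mask) if any(line)]
    width = len(mask[0]) if mask else 0
    cols = [c for c in range(width) if any(line[c] for line in mask)]
    if not rows or not cols:
        return (0, 0, 0, 0)
    # Assumed convention: x2 and y2 are one past the last foreground pixel.
    return (cols[0], rows[0], cols[-1] + 1, rows[-1] + 1)


mask = [[False] * 5 for _ in range(4)]
mask[1][1] = True
mask[2][3] = True
print(mask_bounding_box(mask))  # (1, 1, 4, 3)
```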

nonempty() → torch.Tensor

Find masks that are non-empty.

Returns

Tensor – a BoolTensor which represents whether each mask is empty (False) or non-empty (True).

to(*args: Any, **kwargs: Any) → detectron2.structures.BitMasks
class detectron2.structures.PolygonMasks(polygons: List[List[Union[torch.Tensor, numpy.ndarray]]])

Bases: object

This class stores the segmentation masks for all objects in one image, in the form of polygons.

polygons

list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.

__getitem__(item: Union[int, slice, List[int], torch.BoolTensor]) → detectron2.structures.PolygonMasks

Support indexing over the instances and return a PolygonMasks object. item can be:

  1. An integer. It will return an object with only one instance.

  2. A slice. It will return an object with the selected instances.

  3. A list[int]. It will return an object with the selected instances, corresponding to the indices in the list.

  4. A vector mask of type BoolTensor, whose length is num_instances. It will return an object with the instances whose mask is nonzero.

__init__(polygons: List[List[Union[torch.Tensor, numpy.ndarray]]])
Parameters

polygons (list[list[np.ndarray]]) – The first level of the list corresponds to individual instances, the second level to all the polygons that compose the instance, and the third level to the polygon coordinates. The third level array should have the format of [x0, y0, x1, y1, …, xn, yn], with at least 3 points (n >= 2).

__iter__() → Iterator[List[numpy.ndarray]]
Yields

list[ndarray] – the polygons for one instance. Each ndarray is a float64 vector representing a polygon.

area()

Computes the area of the masks. This only works with polygons, using the shoelace formula: https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

Returns

Tensor – a vector, area for each instance
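The shoelace formula sums cross products of consecutive vertices: area = |Σ (x_i y_{i+1} − x_{i+1} y_i)| / 2. Below is a plain-Python sketch over the flat [x0, y0, x1, y1, …] layout used by PolygonMasks. The helper names are hypothetical, and summing the areas of all polygons of an instance (without subtracting holes) is an assumption about how the per-instance area is aggregated.

```python
from typing import List, Sequence


def polygon_area(coords: Sequence[float]) -> float:
    """Shoelace formula for a flat [x0, y0, x1, y1, ...] polygon."""
    xs, ys = coords[0::2], coords[1::2]
    n = len(xs)
    twice_area = sum(xs[i] * ys[(i + 1) % n] - xs[(i + 1) % n] * ys[i] for i in range(n))
    return abs(twice_area) / 2.0


def instance_areas(polygons: List[List[Sequence[float]]]) -> List[float]:
    """Area per instance, summing the areas of that instance's polygons."""
    return [sum(polygon_area(p) for p in instance) for instance in polygons]


rect = [3, 2, 7, 2, 7, 4, 3, 4]  # 4 x 2 axis-aligned rectangle
tri = [0, 0, 2, 0, 0, 2]         # right triangle with legs of length 2
print(instance_areas([[rect], [rect, tri]]))  # [8.0, 10.0]
```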

static cat(polymasks_list: List[PolygonMasks]) → detectron2.structures.PolygonMasks

Concatenates a list of PolygonMasks into a single PolygonMasks

Parameters

polymasks_list (list[PolygonMasks]) –

Returns

PolygonMasks – the concatenated PolygonMasks

crop_and_resize(boxes: torch.Tensor, mask_size: int) → torch.Tensor

Crop each mask by the given box, and resize results to (mask_size, mask_size). This can be used to prepare training targets for Mask R-CNN.

Parameters
  • boxes (Tensor) – Nx4 tensor storing the boxes for each mask

  • mask_size (int) – the size of the rasterized mask.

Returns

Tensor – A bool tensor of shape (N, mask_size, mask_size), where N is the number of predicted boxes for this image.

property device
get_bounding_boxes() → detectron2.structures.Boxes
Returns

Boxes – tight bounding boxes around polygon masks.
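For polygons, the tight box of an instance is the minimum and maximum of all its x and y coordinates, taken across every polygon of that instance. A minimal sketch with a hypothetical helper:

```python
from typing import List, Sequence, Tuple


def polygons_bounding_box(instance: List[Sequence[float]]) -> Tuple[float, float, float, float]:
    """Tight (x1, y1, x2, y2) box around all flat [x0, y0, ...] polygons of one instance."""
    xs = [x for poly in instance for x in poly[0::2]]
    ys = [y for poly in instance for y in poly[1::2]]
    return (min(xs), min(ys), max(xs), max(ys))


print(polygons_bounding_box([[3, 2, 7, 2, 7, 4, 3, 4], [8, 1, 9, 1, 9, 2]]))  # (3, 1, 9, 4)
```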

nonempty() → torch.Tensor

Find masks that are non-empty.

Returns

Tensor – a BoolTensor which represents whether each mask is empty (False) or not (True).

to(*args: Any, **kwargs: Any) → detectron2.structures.PolygonMasks
detectron2.structures.polygons_to_bitmask(polygons: List[numpy.ndarray], height: int, width: int) → numpy.ndarray
Parameters
  • polygons (list[ndarray]) – each array has shape (Nx2,)

  • height (int) –

  • width (int) –

Returns

ndarray – a bool mask of shape (height, width)

class detectron2.structures.ROIMasks(tensor: torch.Tensor)

Bases: object

Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given, full-image bitmask can be obtained by “pasting” the mask on the region defined by the corresponding ROI box.

__getitem__(item) → detectron2.structures.ROIMasks
Returns

ROIMasks – Create a new ROIMasks by indexing.

The following usages are allowed:

  1. new_masks = masks[2:10]: return a slice of masks.

  2. new_masks = masks[vector], where vector is a torch.BoolTensor with length = len(masks). Nonzero elements in the vector will be selected.

Note that the returned object might share storage with this object, subject to PyTorch’s indexing semantics.

__init__(tensor: torch.Tensor)
Parameters

tensor – (N, M, M) mask tensor that defines the mask within each ROI.

property device
to(device: torch.device) → detectron2.structures.ROIMasks
to_bitmasks(boxes: torch.Tensor, height, width, threshold=0.5)

Args: see documentation of paste_masks_in_image().

class detectron2.structures.RotatedBoxes(tensor: torch.Tensor)

Bases: detectron2.structures.Boxes

This structure stores a list of rotated boxes as an Nx5 torch.Tensor. It supports some common methods about boxes (area, clip, nonempty, etc.), and also behaves like a Tensor (it supports indexing, to(device), .device, and iteration over all boxes).

__getitem__(item) → detectron2.structures.RotatedBoxes
Returns

RotatedBoxes – Create a new RotatedBoxes by indexing.

The following usages are allowed:

  1. new_boxes = boxes[3]: return a RotatedBoxes which contains only one box.

  2. new_boxes = boxes[2:10]: return a slice of boxes.

  3. new_boxes = boxes[vector], where vector is a torch.BoolTensor with length = len(boxes). Nonzero elements in the vector will be selected.

Note that the returned RotatedBoxes might share storage with this RotatedBoxes, subject to PyTorch’s indexing semantics.

__init__(tensor: torch.Tensor)
Parameters

tensor (Tensor[float]) – a Nx5 matrix. Each row is (x_center, y_center, width, height, angle), in which angle is represented in degrees. While there’s no strict range restriction for it, the recommended principal range is between [-180, 180) degrees.

Assume we have a horizontal box B = (x_center, y_center, width, height), where width is along the x-axis and height is along the y-axis. The rotated box B_rot (x_center, y_center, width, height, angle) can be seen as:

  1. When angle == 0: B_rot == B

  2. When angle > 0: B_rot is obtained by rotating B w.r.t its center by \(|angle|\) degrees CCW;

  3. When angle < 0: B_rot is obtained by rotating B w.r.t its center by \(|angle|\) degrees CW.

Mathematically, since the right-handed coordinate system for image space is (y, x), where y is top->down and x is left->right, the 4 vertices of the rotated rectangle \((yr_i, xr_i)\) (i = 1, 2, 3, 4) can be obtained from the vertices of the horizontal rectangle \((y_i, x_i)\) (i = 1, 2, 3, 4) in the following way (\(\theta = angle*\pi/180\) is the angle in radians, \((y_c, x_c)\) is the center of the rectangle):

\[ \begin{align}\begin{aligned}yr_i = \cos(\theta) (y_i - y_c) - \sin(\theta) (x_i - x_c) + y_c,\\xr_i = \sin(\theta) (y_i - y_c) + \cos(\theta) (x_i - x_c) + x_c,\end{aligned}\end{align} \]

which is the standard rigid-body rotation transformation.
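The formula can be checked numerically. The sketch below (rotated_box_vertices is a hypothetical helper) applies it to the four corners of the horizontal box, in the order top-left, top-right, bottom-right, bottom-left, and returns (x, y) pairs; for (5, 3, 4, 2, 90) it reproduces the vertices A, B, C, D given in the worked example that follows. Rounding only hides floating-point noise from cos(90°).

```python
import math
from typing import List, Sequence, Tuple


def rotated_box_vertices(box: Sequence[float]) -> List[Tuple[float, float]]:
    """Vertices (x, y) of an (x_center, y_center, width, height, angle) box."""
    xc, yc, w, h, angle = box
    theta = angle * math.pi / 180.0
    c, s = math.cos(theta), math.sin(theta)
    # Corners of the horizontal box: top-left, top-right, bottom-right, bottom-left.
    corners = [
        (xc - w / 2, yc - h / 2),
        (xc + w / 2, yc - h / 2),
        (xc + w / 2, yc + h / 2),
        (xc - w / 2, yc + h / 2),
    ]
    out = []
    for x, y in corners:
        yr = c * (y - yc) - s * (x - xc) + yc
        xr = s * (y - yc) + c * (x - xc) + xc
        out.append((xr, yr))
    return out


print([(round(x, 6), round(y, 6)) for x, y in rotated_box_vertices((5, 3, 4, 2, 90))])
# [(4.0, 5.0), (4.0, 1.0), (6.0, 1.0), (6.0, 5.0)]
```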

Intuitively, the angle is (1) the rotation angle from y-axis in image space to the height vector (top->down in the box’s local coordinate system) of the box in CCW, and (2) the rotation angle from x-axis in image space to the width vector (left->right in the box’s local coordinate system) of the box in CCW.

More intuitively, consider the following horizontal box ABCD represented in (x1, y1, x2, y2): (3, 2, 7, 4), covering the [3, 7] x [2, 4] region of the continuous coordinate system which looks like this:

O--------> x
|
|  A---B
|  |   |
|  D---C
|
v y

Note that each capital letter represents one 0-dimensional geometric point instead of a ‘square pixel’ here.

In the example above, using (x, y) to represent a point we have:

\[O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4)\]

We name vector AB = vector DC as the width vector in box’s local coordinate system, and vector AD = vector BC as the height vector in box’s local coordinate system. Initially, when angle = 0 degree, they’re aligned with the positive directions of x-axis and y-axis in the image space, respectively.

For better illustration, we denote the center of the box as E,

O--------> x
|
|  A---B
|  | E |
|  D---C
|
v y

where the center E = ((3+7)/2, (2+4)/2) = (5, 3).

Also,

\[width = |AB| = |CD| = 7 - 3 = 4, height = |AD| = |BC| = 4 - 2 = 2.\]

Therefore, the corresponding representation for the same shape in rotated box in (x_center, y_center, width, height, angle) format is:

(5, 3, 4, 2, 0).

Now, let’s consider (5, 3, 4, 2, 90), which is rotated by 90 degrees CCW (counter-clockwise) by definition. It looks like this:

O--------> x
|   B-C
|   | |
|   |E|
|   | |
|   A-D
v y

The center E is still located at the same point (5, 3), while the vertices ABCD are rotated by 90 degrees CCW with regard to E: A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5)

Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to vector AD or vector BC (the top->down height vector in box’s local coordinate system), or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right width vector in box’s local coordinate system).

\[width = |AB| = |CD| = 5 - 1 = 4, height = |AD| = |BC| = 6 - 4 = 2.\]

Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise) by definition? It looks like this:

O--------> x
|   D-A
|   | |
|   |E|
|   | |
|   C-B
v y

The center E is still located at the same point (5, 3), while the vertices ABCD are rotated by 90 degrees CW with regard to E: A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1)

\[width = |AB| = |CD| = 5 - 1 = 4, height = |AD| = |BC| = 6 - 4 = 2.\]

This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU will be 1. However, these two will generate different RoI Pooling results and should not be treated as an identical box.

On the other hand, it’s easy to see that (X, Y, W, H, A) is identical to (X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is equivalent to rotating the same shape 90 degrees CW.

We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180):

O--------> x
|
|  C---D
|  | E |
|  B---A
|
v y
\[ \begin{align}\begin{aligned}A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2),\\width = |AB| = |CD| = 7 - 3 = 4, height = |AD| = |BC| = 4 - 2 = 2.\end{aligned}\end{align} \]

Finally, here is a very inaccurate (heavily quantized) illustration of what (5, 3, 4, 2, 60) looks like, in case anyone wonders:

O--------> x
|     B
|    /  C
|   /E /
|  A  /
|   `D
v y

It’s still a rectangle with center of (5, 3), width of 4 and height of 2, but its angle (and thus orientation) is somewhere between (5, 3, 4, 2, 0) and (5, 3, 4, 2, 90).

__iter__()

Yield a box as a Tensor of shape (5,) at a time.

area() → torch.Tensor

Computes the area of all the boxes.

Returns

torch.Tensor – a vector with areas of each box.

classmethod cat(boxes_list: List[RotatedBoxes]) → detectron2.structures.RotatedBoxes

Concatenates a list of RotatedBoxes into a single RotatedBoxes

Parameters

boxes_list (list[RotatedBoxes]) –

Returns

RotatedBoxes – the concatenated RotatedBoxes

clip(box_size: Tuple[int, int], clip_angle_threshold: float = 1.0) → None

Clip (in place) the boxes by limiting x coordinates to the range [0, width] and y coordinates to the range [0, height].

For RRPN: Only clip boxes that are almost horizontal with a tolerance of clip_angle_threshold to maintain backward compatibility.

Rotated boxes beyond this threshold are not clipped for two reasons:

  1. There are potentially multiple ways to clip a rotated box to make it fit within the image.

  2. It’s tricky to make the entire rectangular box fit within the image and still be able to not leave out pixels of interest.

Therefore we rely on ops like RoIAlignRotated to safely handle this.

Parameters
  • box_size (height, width) – The clipping box’s size.

  • clip_angle_threshold (float) – if and only if abs(normalized(angle)) <= clip_angle_threshold (in degrees), the box is clipped as a horizontal box.
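A sketch of that rule on plain lists. It assumes, as the parameter description suggests, that angles are first normalized to [-180, 180) and that a near-horizontal box is clipped by converting it to corners, clamping them, and converting back without ever growing the box; clip_rotated is hypothetical, and boxes beyond the threshold are left untouched.

```python
from typing import List, Tuple


def clip_rotated(
    boxes: List[List[float]], box_size: Tuple[int, int], clip_angle_threshold: float = 1.0
) -> None:
    """In-place clip of near-horizontal (xc, yc, w, h, angle) boxes."""
    height, width = box_size
    for b in boxes:
        b[4] = (b[4] + 180.0) % 360.0 - 180.0  # normalize angle to [-180, 180)
        if abs(b[4]) > clip_angle_threshold:
            continue  # too rotated: left for ops like RoIAlignRotated
        # Treat the box as horizontal: convert to corners and clamp them.
        x1 = min(max(b[0] - b[2] / 2.0, 0.0), width)
        y1 = min(max(b[1] - b[3] / 2.0, 0.0), height)
        x2 = min(max(b[0] + b[2] / 2.0, 0.0), width)
        y2 = min(max(b[1] + b[3] / 2.0, 0.0), height)
        # Convert back to center and size, never growing the box.
        b[0], b[1] = (x1 + x2) / 2.0, (y1 + y2) / 2.0
        b[2], b[3] = min(b[2], x2 - x1), min(b[3], y2 - y1)


boxes = [[5.0, 3.0, 20.0, 2.0, 0.5], [5.0, 3.0, 20.0, 2.0, 45.0]]
clip_rotated(boxes, (10, 10))
print(boxes)
```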

clone() → detectron2.structures.RotatedBoxes

Clone the RotatedBoxes.

Returns

RotatedBoxes

property device
get_centers() → torch.Tensor
Returns

The box centers in a Nx2 array of (x, y).

inside_box(box_size: Tuple[int, int], boundary_threshold: int = 0) → torch.Tensor
Parameters
  • box_size (height, width) – Size of the reference box covering [0, width] x [0, height]

  • boundary_threshold (int) – Boxes that extend beyond the reference box boundary by more than boundary_threshold are considered “outside”.

For RRPN, it might not be necessary to call this function, since it is common for rotated boxes to extend outside the image boundaries (the clip function only clips near-horizontal boxes).

Returns

a binary vector, indicating whether each box is inside the reference box.

nonempty(threshold: float = 0.0) → torch.Tensor

Find boxes that are non-empty. A box is considered empty if either of its sides is no larger than threshold.

Returns

Tensor – a binary vector which represents whether each box is empty (False) or non-empty (True).

normalize_angles() → None

Restrict angles to the range of [-180, 180) degrees
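Because (X, Y, W, H, A) and (X, Y, W, H, A+360N) describe the same box, normalization is a modular reduction. A minimal sketch; normalize_angle is hypothetical.

```python
def normalize_angle(angle: float) -> float:
    """Map any angle in degrees to the equivalent angle in [-180, 180)."""
    return (angle + 180.0) % 360.0 - 180.0


print(normalize_angle(270.0))   # -90.0, the same box as angle -90
print(normalize_angle(180.0))   # -180.0, the closed end of the range
print(normalize_angle(-190.0))  # 170.0
```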

scale(scale_x: float, scale_y: float) → None

Scale the rotated boxes by horizontal and vertical scaling factors. Note: when scale_x != scale_y, a rotated box whose angle is not a multiple of 90 degrees does not keep its rectangular shape under the resize transformation; it becomes a parallelogram (with skew). Here we approximate it by fitting a rotated rectangle to the parallelogram.
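One way to fit that rectangle follows from the rotation formula above: in (x, y) image coordinates the box's width vector is (cos θ, −sin θ) and its height vector is (sin θ, cos θ), so scaling x by scale_x and y by scale_y gives new side lengths, and the new angle can be read from the scaled height vector. This is a sketch under that derivation; scale_rotated is hypothetical and is not claimed to be detectron2's exact code.

```python
import math
from typing import Tuple

RBox = Tuple[float, float, float, float, float]


def scale_rotated(box: RBox, scale_x: float, scale_y: float) -> RBox:
    """Scale an (xc, yc, w, h, angle) box, refitting a rectangle if skewed."""
    xc, yc, w, h, angle = box
    theta = math.radians(angle)
    c, s = math.cos(theta), math.sin(theta)
    # Length of the scaled width vector (scale_x * c, -scale_y * s).
    new_w = w * math.hypot(scale_x * c, scale_y * s)
    # Length of the scaled height vector (scale_x * s, scale_y * c).
    new_h = h * math.hypot(scale_x * s, scale_y * c)
    # New angle: direction of the scaled height vector.
    new_angle = math.degrees(math.atan2(scale_x * s, scale_y * c))
    return (xc * scale_x, yc * scale_y, new_w, new_h, new_angle)


# A box rotated by 90 degrees has its height along x, so scale_x stretches h.
print([round(v, 6) for v in scale_rotated((5, 3, 4, 2, 90), 2.0, 1.0)])  # [10.0, 3.0, 4.0, 4.0, 90.0]
```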

to(device: torch.device)
detectron2.structures.pairwise_iou_rotated(boxes1: detectron2.structures.RotatedBoxes, boxes2: detectron2.structures.RotatedBoxes) → torch.Tensor

Given two lists of rotated boxes of size N and M, compute the IoU (intersection over union) between all N x M pairs of boxes. The box order must be (x_center, y_center, width, height, angle).

Parameters
  • boxes1 (RotatedBoxes) – contains N rotated boxes.

  • boxes2 (RotatedBoxes) – contains M rotated boxes.

Returns

Tensor – IoU, sized [N,M].
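For intuition and small CPU checks, the IoU of two rotated boxes can be computed by clipping one box's polygon against the other with the Sutherland-Hodgman algorithm and measuring areas with the shoelace formula. This is a swapped-in, plain-Python technique, not detectron2's compiled implementation; all names are hypothetical, and the vertices follow the rotation formula given for RotatedBoxes.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def box_polygon(box: Sequence[float]) -> List[Point]:
    """Vertices of an (x_center, y_center, width, height, angle) box."""
    xc, yc, w, h, angle = box
    t = math.radians(angle)
    c, s = math.cos(t), math.sin(t)
    pts = []
    for dx, dy in ((-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)):
        # RotatedBoxes formula with dx = x - xc and dy = y - yc.
        pts.append((xc + s * dy + c * dx, yc + c * dy - s * dx))
    return pts


def polygon_area(poly: List[Point]) -> float:
    """Shoelace formula."""
    n = len(poly)
    twice = sum(poly[i][0] * poly[(i + 1) % n][1] - poly[(i + 1) % n][0] * poly[i][1] for i in range(n))
    return abs(twice) / 2.0


def _side(a: Point, b: Point, p: Point) -> float:
    """Cross product (b - a) x (p - a); >= 0 means p is on the inner side of edge a->b."""
    return (b[0] - a[0]) * (p[1] - a[1]) - (b[1] - a[1]) * (p[0] - a[0])


def clip_polygon(subject: List[Point], clipper: List[Point], eps: float = 1e-9) -> List[Point]:
    """Sutherland-Hodgman: the part of subject inside the convex clipper polygon."""
    out = list(subject)
    for i in range(len(clipper)):
        a, b = clipper[i], clipper[(i + 1) % len(clipper)]
        inp, out = out, []
        if not inp:
            break
        for j in range(len(inp)):
            p, q = inp[j - 1], inp[j]
            dp, dq = _side(a, b, p), _side(a, b, q)
            if dq >= -eps:
                if dp < -eps:  # entering: add the crossing point first
                    t = dp / (dp - dq)
                    out.append((p[0] + t * (q[0] - p[0]), p[1] + t * (q[1] - p[1])))
                out.append(q)
            elif dp >= -eps:  # leaving: add only the crossing point
                t = dp / (dp - dq)
                out.append((p[0] + t * (q[0] - p[0]), p[1] + t * (q[1] - p[1])))
    return out


def rotated_iou(box1: Sequence[float], box2: Sequence[float]) -> float:
    p1, p2 = box_polygon(box1), box_polygon(box2)
    inter_poly = clip_polygon(p1, p2)
    inter = polygon_area(inter_poly) if len(inter_poly) >= 3 else 0.0
    union = polygon_area(p1) + polygon_area(p2) - inter
    return inter / union if union > 0 else 0.0


# Same region, different angles: IoU is 1 up to floating-point rounding.
print(rotated_iou((5, 3, 4, 2, 90), (5, 3, 4, 2, -90)))
```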