detectron2.evaluation

class detectron2.evaluation.CityscapesInstanceEvaluator(dataset_name)[source]

Bases: detectron2.evaluation.cityscapes_evaluation.CityscapesEvaluator

Evaluate instance segmentation results on cityscapes dataset using cityscapes API.

Note

  • It does not work in multi-machine distributed training.

  • It contains a synchronization, therefore has to be used on all ranks.

  • Only the main process runs evaluation.

process(inputs, outputs)[source]
evaluate()[source]
Returns

dict – has a key “segm”, whose value is a dict of “AP” and “AP50”.

class detectron2.evaluation.CityscapesSemSegEvaluator(dataset_name)[source]

Bases: detectron2.evaluation.cityscapes_evaluation.CityscapesEvaluator

Evaluate semantic segmentation results on cityscapes dataset using cityscapes API.

Note

  • It does not work in multi-machine distributed training.

  • It contains a synchronization, therefore has to be used on all ranks.

  • Only the main process runs evaluation.

process(inputs, outputs)[source]
evaluate()[source]
class detectron2.evaluation.COCOEvaluator(dataset_name, tasks=None, distributed=True, output_dir=None, *, max_dets_per_image=None, use_fast_impl=True, kpt_oks_sigmas=())[source]

Bases: detectron2.evaluation.evaluator.DatasetEvaluator

Evaluate AR for object proposals, AP for instance detection/segmentation, and AP for keypoint detection outputs using COCO’s metrics. See http://cocodataset.org/#detection-eval and http://cocodataset.org/#keypoints-eval to understand its metrics. The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means the metric cannot be computed (e.g. due to no predictions made).

In addition to COCO, this evaluator is able to support any bounding box detection, instance segmentation, or keypoint detection dataset.

__init__(dataset_name, tasks=None, distributed=True, output_dir=None, *, max_dets_per_image=None, use_fast_impl=True, kpt_oks_sigmas=())[source]
Parameters
  • dataset_name (str) –

    name of the dataset to be evaluated. It must either have the corresponding metadata ”json_file” (the path to the COCO-format annotation), or be in detectron2’s standard dataset format so it can be converted to COCO format automatically.

  • tasks (tuple[str]) – tasks that can be evaluated under the given configuration. A task is one of “bbox”, “segm”, “keypoints”. By default, will infer this automatically from predictions.

  • distributed (bool) – if True, will collect results from all ranks and run evaluation in the main process. Otherwise, will only evaluate the results in the current process.

  • output_dir (str) –

    optional, an output directory to dump all results predicted on the dataset. The dump contains two files:

    1. ”instances_predictions.pth”: a file that can be loaded with torch.load and contains all the results in the format they are produced by the model.

    2. ”coco_instances_results.json”: a json file in COCO’s result format.

  • max_dets_per_image (int) – limit on the maximum number of detections per image. By default in COCO, this limit is 100, but it can be customized to be greater, as is needed for the AP fixed and AP pool evaluation metrics (see https://arxiv.org/pdf/2102.01066.pdf). This doesn’t affect keypoint evaluation.

  • use_fast_impl (bool) – use a fast but unofficial implementation to compute AP. Although the results should be very close to the official implementation in COCO API, it is still recommended to compute results with the official API for use in papers. The faster implementation also uses more RAM.

  • kpt_oks_sigmas (list[float]) – The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval. When empty, it will use the defaults in COCO. Otherwise it should have the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.

reset()[source]
process(inputs, outputs)[source]
Parameters
  • inputs – the inputs to a COCO model (e.g., GeneralizedRCNN). It is a list of dict. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”, “image_id”.

  • outputs – the outputs of a COCO model. It is a list of dicts with key “instances” that contains Instances.

evaluate(img_ids=None)[source]
Parameters

img_ids – a list of image IDs to evaluate on. Default to None for the whole dataset
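
A minimal usage sketch (illustration only, not part of the API above): build a COCOEvaluator for a registered COCO-format split and run it through inference_on_dataset. The dataset name “my_coco_val” and the cfg/model objects are assumptions.

from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

# "my_coco_val" is a hypothetical registered dataset; cfg and model are an
# existing config and a trained detectron2 model, assumed to be defined.
evaluator = COCOEvaluator("my_coco_val", output_dir="./output")
val_loader = build_detection_test_loader(cfg, "my_coco_val")
results = inference_on_dataset(model, val_loader, evaluator)
# results is a dict such as {"bbox": {"AP": ..., "AP50": ...}, "segm": {...}}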

class detectron2.evaluation.RotatedCOCOEvaluator(dataset_name, tasks=None, distributed=True, output_dir=None, *, max_dets_per_image=None, use_fast_impl=True, kpt_oks_sigmas=())[source]

Bases: detectron2.evaluation.coco_evaluation.COCOEvaluator

Evaluate object proposal/instance detection outputs using COCO-like metrics and APIs, with rotated boxes support. Note: this uses IOU only and does not consider angle differences.

process(inputs, outputs)[source]
Parameters
  • inputs – the inputs to a COCO model (e.g., GeneralizedRCNN). It is a list of dict. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”, “image_id”.

  • outputs – the outputs of a COCO model. It is a list of dicts with key “instances” that contains Instances.

instances_to_json(instances, img_id)[source]
class detectron2.evaluation.DatasetEvaluator[source]

Bases: object

Base class for a dataset evaluator.

The function inference_on_dataset() runs the model over all samples in the dataset, and uses a DatasetEvaluator to process the inputs/outputs.

This class will accumulate information about the inputs/outputs (by process()), and produce evaluation results in the end (by evaluate()).

reset()[source]

Preparation for a new round of evaluation. Should be called before starting a round of evaluation.

process(inputs, outputs)[source]

Process the pair of inputs and outputs. If they contain batches, the pairs can be consumed one-by-one using zip:

for input_, output in zip(inputs, outputs):
    # do evaluation on single input/output pair
    ...
Parameters
  • inputs (list) – the inputs that are used to call the model.

  • outputs (list) – the return value of model(inputs).

evaluate()[source]

Evaluate/summarize the performance, after processing all input/output pairs.

Returns

dict – A new evaluator class can return a dict of arbitrary format as long as the user can process the results. In our train_net.py, we expect the following format:

  • key: the name of the task (e.g., bbox)

  • value: a dict of {metric name: score}, e.g.: {“AP50”: 80}
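
As an illustrative sketch (not part of detectron2 itself), a toy evaluator following this interface could count predicted instances and report them in the task -> metrics format described above:

from detectron2.evaluation import DatasetEvaluator

class InstanceCounter(DatasetEvaluator):
    """Toy evaluator: counts how many instances were predicted in total."""

    def reset(self):
        self.count = 0

    def process(self, inputs, outputs):
        # Standard detection models return one dict per image with an "instances" field.
        for output in outputs:
            self.count += len(output["instances"])

    def evaluate(self):
        # Follow the task_name -> {metric name: score} convention.
        return {"counting": {"num_instances": self.count}}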

class detectron2.evaluation.DatasetEvaluators(evaluators)[source]

Bases: detectron2.evaluation.evaluator.DatasetEvaluator

Wrapper class to combine multiple DatasetEvaluator instances.

This class dispatches every evaluation call to all of its DatasetEvaluator instances.

__init__(evaluators)[source]
Parameters

evaluators (list) – the evaluators to combine.

reset()[source]
process(inputs, outputs)[source]
evaluate()[source]
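
A brief sketch of combining evaluators so that a single inference pass produces all metrics; the dataset name and output directory are assumptions:

from detectron2.evaluation import COCOEvaluator, SemSegEvaluator, DatasetEvaluators

evaluator = DatasetEvaluators([
    COCOEvaluator("my_dataset_val", output_dir="./output"),
    SemSegEvaluator("my_dataset_val", output_dir="./output"),
])
# reset(), process() and evaluate() now dispatch to both wrapped evaluators,
# and evaluate() merges their result dicts.
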
detectron2.evaluation.inference_context(model)[source]

A context where the model is temporarily changed to eval mode, and restored to its previous mode afterwards.

Parameters

model – a torch Module
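
A minimal sketch of using the context manager; model and inputs are assumed to exist:

import torch
from detectron2.evaluation import inference_context

with inference_context(model), torch.no_grad():
    outputs = model(inputs)
# outside the block, model.training is restored to its previous value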

detectron2.evaluation.inference_on_dataset(model, data_loader, evaluator: Optional[Union[detectron2.evaluation.evaluator.DatasetEvaluator, List[detectron2.evaluation.evaluator.DatasetEvaluator]]])[source]

Run model on the data_loader and evaluate the metrics with evaluator. Also benchmark the inference speed of model.__call__ accurately. The model will be used in eval mode.

Parameters
  • model (callable) –

    a callable which takes an object from data_loader and returns some outputs.

    If it’s an nn.Module, it will be temporarily set to eval mode. If you wish to evaluate a model in training mode instead, you can wrap the given model and override its behavior of .eval() and .train().

  • data_loader – an iterable object with a length. The elements it generates will be the inputs to the model.

  • evaluator – the evaluator(s) to run. Use None if you only want to benchmark, but don’t want to do any evaluation.

Returns

The return value of evaluator.evaluate()
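
Continuing the earlier COCOEvaluator sketch (val_loader, model and the individual evaluators are assumed to exist), the evaluator argument accepts a single evaluator, a list of evaluators, or None to benchmark inference speed only:

results = inference_on_dataset(model, val_loader, [coco_eval, sem_seg_eval])
inference_on_dataset(model, val_loader, None)  # benchmark only; returns an empty dict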

class detectron2.evaluation.LVISEvaluator(dataset_name, tasks=None, distributed=True, output_dir=None, *, max_dets_per_image=None)[source]

Bases: detectron2.evaluation.evaluator.DatasetEvaluator

Evaluate object proposal and instance detection/segmentation outputs using LVIS’s metrics and evaluation API.

__init__(dataset_name, tasks=None, distributed=True, output_dir=None, *, max_dets_per_image=None)[source]
Parameters
  • dataset_name (str) – name of the dataset to be evaluated. It must have the following corresponding metadata: ”json_file”: the path to the LVIS-format annotation.

  • tasks (tuple[str]) – tasks that can be evaluated under the given configuration. A task is one of “bbox”, “segm”. By default, will infer this automatically from predictions.

  • distributed (bool) – if True, will collect results from all ranks for evaluation. Otherwise, will evaluate the results in the current process.

  • output_dir (str) – optional, an output directory to dump results.

  • max_dets_per_image (None or int) – limit on the maximum number of detections per image when evaluating AP. By default, this limit is 300, following the LVIS dataset.

reset()[source]
process(inputs, outputs)[source]
Parameters
  • inputs – the inputs to a LVIS model (e.g., GeneralizedRCNN). It is a list of dict. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”, “image_id”.

  • outputs – the outputs of a LVIS model. It is a list of dicts with key “instances” that contains Instances.

evaluate()[source]
class detectron2.evaluation.COCOPanopticEvaluator(dataset_name: str, output_dir: Optional[str] = None)[source]

Bases: detectron2.evaluation.evaluator.DatasetEvaluator

Evaluate Panoptic Quality metrics on COCO using PanopticAPI. It saves panoptic segmentation predictions in output_dir.

It contains a synchronize call and has to be called from all workers.

__init__(dataset_name: str, output_dir: Optional[str] = None)[source]
Parameters
  • dataset_name – name of the dataset

  • output_dir – output directory to save results for evaluation.

reset()[source]
process(inputs, outputs)[source]
evaluate()[source]
class detectron2.evaluation.PascalVOCDetectionEvaluator(dataset_name)[source]

Bases: detectron2.evaluation.evaluator.DatasetEvaluator

Evaluate Pascal VOC style AP for Pascal VOC dataset. It contains a synchronization, therefore has to be called from all ranks.

Note that the concept of AP can be implemented in different ways and may not produce identical results. This class mimics the implementation of the official Pascal VOC Matlab API, and should produce similar but not identical results to the official API.

__init__(dataset_name)[source]
Parameters

dataset_name (str) – name of the dataset, e.g., “voc_2007_test”

reset()[source]
process(inputs, outputs)[source]
evaluate()[source]
Returns

dict – has a key “bbox”, whose value is a dict of “AP”, “AP50”, and “AP75”.

class detectron2.evaluation.SemSegEvaluator(dataset_name, distributed=True, output_dir=None, *, num_classes=None, ignore_label=None)[source]

Bases: detectron2.evaluation.evaluator.DatasetEvaluator

Evaluate semantic segmentation metrics.

__init__(dataset_name, distributed=True, output_dir=None, *, num_classes=None, ignore_label=None)[source]
Parameters
  • dataset_name (str) – name of the dataset to be evaluated.

  • distributed (bool) – if True, will collect results from all ranks for evaluation. Otherwise, will evaluate the results in the current process.

  • output_dir (str) – an output directory to dump results.

  • num_classes – deprecated argument

  • ignore_label – deprecated argument

reset()[source]
process(inputs, outputs)[source]
Parameters
  • inputs – the inputs to a model. It is a list of dicts. Each dict corresponds to an image and contains keys like “height”, “width”, “file_name”.

  • outputs – the outputs of a model. It is either list of semantic segmentation predictions (Tensor [H, W]) or list of dicts with key “sem_seg” that contains semantic segmentation prediction in the same format.

evaluate()[source]

Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):

  • Mean intersection-over-union averaged across classes (mIoU)

  • Frequency Weighted IoU (fwIoU)

  • Mean pixel accuracy averaged across classes (mACC)

  • Pixel Accuracy (pACC)

encode_json_sem_seg(sem_seg, input_file_name)[source]

Convert semantic segmentation to COCO stuff format with segments encoded as RLEs. See http://cocodataset.org/#format-results

detectron2.evaluation.print_csv_format(results)[source]

Print main metrics in a format similar to Detectron, so that they are easy to copy-paste into a spreadsheet.

Parameters

results (OrderedDict[dict]) – task_name -> {metric -> score}. An unordered dict can also be printed, but in arbitrary order.
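
A small illustrative snippet; results is assumed to come from an inference_on_dataset call such as the earlier sketch:

from detectron2.evaluation import print_csv_format

print_csv_format(results)  # logs one block of metrics per task, e.g. bbox, segm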

detectron2.evaluation.verify_results(cfg, results)[source]
Parameters

results (OrderedDict[dict]) – task_name -> {metric -> score}

Returns

bool – whether the verification succeeds or not