Source code for detectron2.utils.analysis

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# -*- coding: utf-8 -*-

import logging
import typing
import torch
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn

from detectron2.structures import BitMasks, Boxes, ImageList, Instances

from .logger import log_first_n

__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
]

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"


# some extra ops to ignore from counting.
_IGNORED_OPS = [
    "aten::add",
    "aten::add_",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::nonzero_numpy",
    "aten::relu",
    "aten::relu_",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "aten::upsample_nearest2d",
    "prim::PythonOp",
    "torchvision::nms",
]


[docs]def flop_count_operators( model: nn.Module, inputs: list, **kwargs ) -> typing.DefaultDict[str, float]: """ Implement operator-level flops counting using jit. This is a wrapper of fvcore.nn.flop_count, that supports standard detection models in detectron2. Note: The function runs the input through the model to compute flops. The flops of a detection model is often input-dependent, for example, the flops of box & mask head depends on the number of proposals & the number of detected objects. Therefore, the flops counting using a single input may not accurately reflect the computation cost of a model. Args: model: a detectron2 model that takes `list[dict]` as input. inputs (list[dict]): inputs to model, in detectron2's standard format. """ return _wrapper_count_operators(model=model, inputs=inputs, mode=FLOPS_MODE, **kwargs)
[docs]def activation_count_operators( model: nn.Module, inputs: list, **kwargs ) -> typing.DefaultDict[str, float]: """ Implement operator-level activations counting using jit. This is a wrapper of fvcore.nn.activation_count, that supports standard detection models in detectron2. Note: The function runs the input through the model to compute activations. The activations of a detection model is often input-dependent, for example, the activations of box & mask head depends on the number of proposals & the number of detected objects. Args: model: a detectron2 model that takes `list[dict]` as input. inputs (list[dict]): inputs to model, in detectron2's standard format. """ return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
def _flatten_to_tuple(outputs): result = [] if isinstance(outputs, torch.Tensor): result.append(outputs) elif isinstance(outputs, (list, tuple)): for v in outputs: result.extend(_flatten_to_tuple(v)) elif isinstance(outputs, dict): for _, v in outputs.items(): result.extend(_flatten_to_tuple(v)) elif isinstance(outputs, Instances): result.extend(_flatten_to_tuple(outputs.get_fields())) elif isinstance(outputs, (Boxes, BitMasks, ImageList)): result.append(outputs.tensor) else: log_first_n( logging.WARN, f"Output of type {type(outputs)} not included in flops/activations count.", n=10, ) return tuple(result) def _wrapper_count_operators( model: nn.Module, inputs: list, mode: str, **kwargs ) -> typing.DefaultDict[str, float]: # ignore some ops supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS} supported_ops.update(kwargs.pop("supported_ops", {})) kwargs["supported_ops"] = supported_ops assert len(inputs) == 1, "Please use batch size=1" tensor_input = inputs[0]["image"] class WrapModel(nn.Module): def __init__(self, model): super().__init__() if isinstance( model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel) ): self.model = model.module else: self.model = model def forward(self, image): # jit requires the input/output to be Tensors inputs = [{"image": image}] outputs = self.model.forward(inputs) # Only the subgraph that computes the returned tuple of tensor will be # counted. So we flatten everything we found to tuple of tensors. return _flatten_to_tuple(outputs) old_train = model.training with torch.no_grad(): if mode == FLOPS_MODE: ret = flop_count(WrapModel(model).train(False), (tensor_input,), **kwargs) elif mode == ACTIVATIONS_MODE: ret = activation_count(WrapModel(model).train(False), (tensor_input,), **kwargs) else: raise NotImplementedError("Count for mode {} is not supported yet.".format(mode)) # compatible with change in fvcore if isinstance(ret, tuple): ret = ret[0] model.train(old_train) return ret