detectron2.export

Related tutorial: Deployment.

detectron2.export.add_export_config(cfg)[source]

Add options needed by caffe2 export.

Parameters

cfg (CfgNode) – a detectron2 config

Returns

CfgNode – an updated config with new options that will be used by Caffe2Tracer.
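For example (a minimal sketch using the standard config API):

from detectron2.config import get_cfg
from detectron2.export import add_export_config

cfg = get_cfg()
# ... load your model's config here ...
cfg = add_export_config(cfg)  # cfg now carries the extra export-related options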

class detectron2.export.Caffe2Model(predict_net, init_net)[source]

Bases: torch.nn.Module

A wrapper around the traced model in Caffe2’s protobuf format. The exported graph has different inputs/outputs from the original PyTorch model, as explained in Caffe2Tracer. This class wraps around the exported graph to simulate the same interface as the original PyTorch model. It also provides functions to save/load models in Caffe2’s format.

Examples:

c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
inputs = [{"image": img_tensor_CHW}]
outputs = c2_model(inputs)
orig_outputs = torch_model(inputs)

property predict_net

the underlying caffe2 predict net

Type

caffe2.core.Net

property init_net

the underlying caffe2 init net

Type

caffe2.core.Net

save_protobuf(output_dir)[source]

Save the model as caffe2’s protobuf format. It saves the following files:

  • “model.pb”: definition of the graph. Can be visualized with tools like netron.

  • “model_init.pb”: model parameters

  • “model.pbtxt”: human-readable definition of the graph. Not needed for deployment.

Parameters

output_dir (str) – the output directory to save protobuf files.
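For example (output directory is illustrative):

c2_model.save_protobuf("./caffe2_model")  # writes model.pb, model_init.pb, model.pbtxt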

save_graph(output_file, inputs=None)[source]

Save the graph in SVG format.

Parameters
  • output_file (str) – an SVG file

  • inputs – optional inputs given to the model. If given, the inputs will be used to run the graph to record the shape of every tensor. The shape information will be saved together with the graph.
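For example (output path is illustrative; this sketch reuses img_tensor_CHW from the Examples above and assumes inputs take the same format as __call__ below):

c2_model.save_graph("./caffe2_model/graph.svg", inputs=[{"image": img_tensor_CHW}])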

static load_protobuf(dir)[source]
Parameters

dir (str) – a directory used to save Caffe2Model with save_protobuf(). The files “model.pb” and “model_init.pb” are needed.

Returns

Caffe2Model – the caffe2 model loaded from this directory.
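For example, to reload a model that was saved with save_protobuf() (directory is illustrative):

c2_model = Caffe2Model.load_protobuf("./caffe2_model")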

__call__(inputs)[source]

An interface that wraps around a Caffe2 model and mimics detectron2’s models’ input/output format. See details about the format at Use Models. This is used to compare the outputs of caffe2 model with its original torch model.

Due to the extra conversion between PyTorch/Caffe2, this method is not meant for benchmarking. Because of the conversion, this method also depends on detectron2 in order to convert to detectron2’s output format.
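A sketch of such a comparison (reusing the names from the Examples above):

outputs = c2_model(inputs)          # detectron2-format outputs from the caffe2 model
orig_outputs = torch_model(inputs)  # outputs from the original model, for comparison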

training: bool

class detectron2.export.Caffe2Tracer(cfg: detectron2.config.CfgNode, model: torch.nn.Module, inputs)[source]

Bases: object

Make a detectron2 model traceable with Caffe2 operators. This class creates a traceable version of a detectron2 model which:

  1. Rewrites parts of the model using ops in Caffe2. Note that some ops do not have GPU implementations in Caffe2.

  2. Removes post-processing and only produces raw layer outputs.

After making a traceable model, the class provides methods to export it to different deployment formats. Exported graphs produced by this class take two input tensors:

  1. (1, C, H, W) float “data” which is an image (usually in [0, 255]). (H, W) often has to be padded to a multiple of 32 (depending on the model architecture).

  2. 1x3 float “im_info”, each row of which is (height, width, 1.0). Height and width are the true image shape before padding.

The class currently only supports models using builtin meta architectures. Batch inference is not supported, and contributions are welcome.
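For instance, the two input tensors described above could be prepared like this (a sketch in numpy; the shapes are illustrative, with H=480 and W=640 already multiples of 32):

import numpy as np

data = np.zeros((1, 3, 480, 640), dtype="float32")      # (1, C, H, W) image, values usually in [0, 255]
im_info = np.array([[480, 640, 1.0]], dtype="float32")  # (height, width, 1.0) before padding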

__init__(cfg: detectron2.config.CfgNode, model: torch.nn.Module, inputs)[source]
Parameters
  • cfg (CfgNode) – a detectron2 config, with extra export-related options added by add_export_config(). It’s used to construct a caffe2-compatible model.

  • model (nn.Module) – an original PyTorch model. Must be one of the few official detectron2 models that can be converted to a caffe2-compatible form automatically. Weights must already be loaded into this model.

  • inputs – sample inputs that the given model takes for inference. Will be used to trace the model. For most models, random inputs with no detected objects will not work as they lead to wrong traces.
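Putting these parameters together, construction might look like this (a sketch; the image path is illustrative, and the model must be one of the supported builtin architectures):

import cv2
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.export import Caffe2Tracer, add_export_config
from detectron2.modeling import build_model

cfg = add_export_config(cfg)  # cfg prepared elsewhere
torch_model = build_model(cfg)
DetectionCheckpointer(torch_model).load(cfg.MODEL.WEIGHTS)  # weights must be loaded
torch_model.eval()

img = cv2.imread("sample.jpg")  # a real image, so the trace contains detections
img_tensor_CHW = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))
inputs = [{"image": img_tensor_CHW}]
tracer = Caffe2Tracer(cfg, torch_model, inputs)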

export_caffe2()[source]

Export the model to Caffe2’s protobuf format. The returned object can be saved with its save_protobuf() method. The result can be loaded and executed using the Caffe2 runtime.

Returns

Caffe2Model

export_onnx()[source]

Export the model to ONNX format. Note that the exported model contains custom ops only available in caffe2, so it cannot be directly executed by other runtimes (such as onnxruntime or TensorRT). Post-processing or transformation passes may be applied to the model to accommodate different runtimes, but we currently do not provide support for them.

Returns

onnx.ModelProto – an onnx model.
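For example, the returned proto can be serialized with the onnx package (file name is illustrative):

import onnx

onnx_model = tracer.export_onnx()
onnx.save(onnx_model, "model.onnx")  # still contains caffe2 custom ops; see the note above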

export_torchscript()[source]

Export the model to a torch.jit.TracedModule by tracing. The returned object can be saved to a file by .save().

Returns

torch.jit.TracedModule – a torch TracedModule
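For example (file name is illustrative):

ts_model = tracer.export_torchscript()
ts_model.save("model.ts")  # a TracedModule saves like any other ScriptModule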

class detectron2.export.TracingAdapter(model: torch.nn.Module, inputs, inference_func: Optional[Callable] = None, allow_non_tensor: bool = False)[source]

Bases: torch.nn.Module

A model may take rich input/output formats (e.g. dict or custom classes), but torch.jit.trace requires tuples of tensors as input/output. This adapter flattens the input/output format of a model so it becomes traceable.

It also records the necessary schema to rebuild the model’s inputs/outputs from flattened inputs/outputs.

Example:

outputs = model(inputs)   # inputs/outputs may be rich structure
adapter = TracingAdapter(model, inputs)

# can now trace the model, with adapter.flattened_inputs, or another
# tuple of tensors with the same length and meaning
traced = torch.jit.trace(adapter, adapter.flattened_inputs)

# traced model can only produce flattened outputs (tuple of tensors)
flattened_outputs = traced(*adapter.flattened_inputs)
# adapter knows the schema to convert it back (new_outputs == outputs)
new_outputs = adapter.outputs_schema(flattened_outputs)

outputs_schema: detectron2.export.flatten.Schema = None

Schema of the output produced by calling the given model with inputs.

__init__(model: torch.nn.Module, inputs, inference_func: Optional[Callable] = None, allow_non_tensor: bool = False)[source]
Parameters
  • model – an nn.Module

  • inputs – An input argument or a tuple of input arguments used to call the model. After flattening, it has to consist only of tensors.

  • inference_func – a callable that takes (model, *inputs), calls the model with the inputs, and returns outputs. By default it is lambda model, *inputs: model(*inputs). Can be overridden if you need to call the model differently.

  • allow_non_tensor – allow inputs/outputs to contain non-tensor objects. This option will filter out non-tensor objects to make the model traceable, but inputs_schema/outputs_schema cannot be used anymore because inputs/outputs cannot be rebuilt from pure tensors. This is useful when you’re only interested in a single trace of execution (e.g. for flop counting), but not in generalizing the traced graph to new inputs.
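For instance, a custom inference_func can bypass post-processing so the outputs stay tensor-friendly (a sketch; it assumes a GeneralizedRCNN-style model whose inference() accepts do_postprocess):

def inference_func(model, image):
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]

adapter = TracingAdapter(torch_model, image, inference_func)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)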

flattened_inputs: Tuple[torch.Tensor] = None

Flattened version of inputs given to this class’s constructor.

inputs_schema: detectron2.export.flatten.Schema = None

Schema of the inputs given to this class’s constructor.

forward(*args: torch.Tensor)[source]

detectron2.export.scripting_with_instances(model, fields)[source]

Run torch.jit.script() on a model that uses the Instances class. Since attributes of Instances are “dynamically” added in eager mode, it is difficult for scripting to support them out of the box. This function is made to support scripting a model that uses Instances. It does the following:

  1. Create a scriptable new_Instances class which behaves similarly to Instances, but with all attributes made “static”. The attributes need to be statically declared in the fields argument.

  2. Register new_Instances, and force the scripting compiler to use it when trying to compile Instances.

After this function returns, the process is reverted. Users should be able to script another model using different fields.

Example

Assume that Instances in the model consist of two attributes named proposal_boxes and objectness_logits, with types Boxes and Tensor respectively during inference. You can call this function like:

fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
torchscript_model = scripting_with_instances(model, fields)

Note

It only supports models in evaluation mode.

Parameters
  • model (nn.Module) – The input model to be exported by scripting.

  • fields (Dict[str, type]) – Attribute names and the corresponding types that Instances will use in the model. Note that all attributes used in Instances need to be added, regardless of whether they are inputs/outputs of the model. Data types not defined in detectron2 are not supported for now.

Returns

torch.jit.ScriptModule – the model in torchscript format

detectron2.export.dump_torchscript_IR(model, dir)[source]

Dump the IR of a TracedModule/ScriptModule/ScriptFunction in various formats (code, graph, inlined graph). Useful for debugging.

Parameters
  • model (TracedModule/ScriptModule/ScriptFunction) – traced or scripted module

  • dir (str) – output directory to dump files.
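For example (output directory is illustrative; traced_model is any traced or scripted module, such as one produced by torch.jit.trace above):

from detectron2.export import dump_torchscript_IR

dump_torchscript_IR(traced_model, "./ir_dump")  # writes the code, graph, and inlined-graph dumps as text files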