# Source code for detectron2.export.torchscript

# Copyright (c) Facebook, Inc. and its affiliates.

import os
import torch

from detectron2.utils.file_io import PathManager

from .torchscript_patch import freeze_training_mode, patch_instances

# Names exported by ``from <this module> import *``.
__all__ = ["scripting_with_instances", "dump_torchscript_IR"]


def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class.

    Since attributes of :class:`Instances` are "dynamically" added in eager mode, it is
    difficult for scripting to support it out of the box. This function is made to
    support scripting a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to
       ``Instances``, but with all attributes being "static". The attributes need to
       be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force the scripting compiler to use it when
       trying to compile ``Instances``.

    After this function, the process will be reverted. User should be able to script
    another model using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only supports models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding type that
            ``Instances`` will use in the model. Note that all attributes used in
            ``Instances`` need to be added, regardless of whether they are
            inputs/outputs of the model. Data type not defined in detectron2 is not
            supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format

    Raises:
        AssertionError: if ``model.training`` is true.
    """
    # Checked before any patching so a model in training mode is rejected
    # without touching the scripting compiler state.
    assert (
        not model.training
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    # Both context managers are active only for the duration of the scripting call.
    with freeze_training_mode(model), patch_instances(fields):
        scripted_model = torch.jit.script(model)
        return scripted_model


# alias for old name
export_torchscript_with_instances = scripting_with_instances
def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph,
    inlined graph). Useful for debugging.

    The following files are written into ``dir``:

    * ``model_ts_code.txt``: pretty-printed code of the model and, for modules,
      of every child module reached through ``named_children()``.
    * ``model_ts_IR.txt``: the graph IR.
    * ``model_ts_IR_inlined.txt``: ``str(model.inlined_graph)``.
    * ``model.txt``: ``str(model)``; skipped when ``model`` is a
      ``torch.jit.ScriptFunction``.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files. ``~`` is expanded and the
            directory is created through ``PathManager.mkdirs``.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        # A TracedModule wraps the script module that carries the `_c` handle.
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try a few ways to get code using private attributes.
            try:
                # This contains more information than just `mod.code`
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                f.write(f"Could not found code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            # Recurse into children regardless of whether this module had code.
            for child_name, child in mod.named_children():
                dump_code(prefix + "." + child_name, child)

        if isinstance(model, torch.jit.ScriptFunction):
            code = get_code(model)
            if code is None:
                # get_code() returns None when neither lookup succeeds; writing
                # None would raise TypeError, so emit a placeholder line instead.
                f.write("Could not find code for the given ScriptFunction\n")
            else:
                f.write(code)
        else:
            dump_code("", model)

    def _get_graph(model):
        try:
            # Recursively dump IR of all modules
            return _get_script_mod(model)._c.dump_to_str(True, False, False)
        except AttributeError:
            # Objects without a `_c` handle fall back to their own graph.
            return model.graph.str()

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))