detectron2.engine package

detectron2.engine.launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=())[source]
Parameters:
  • main_func – a function that will be called by main_func(*args)
  • num_gpus_per_machine (int) – the number of GPUs per machine
  • num_machines (int) – the total number of machines
  • machine_rank (int) – the rank of this machine (one per machine)
  • dist_url (str) – url to connect to for distributed training, including protocol, e.g. “tcp://127.0.0.1:8686”. Can be set to “auto” to automatically select a free port on localhost
  • args (tuple) – arguments passed to main_func
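
A minimal sketch of a typical invocation, mirroring the pattern used by detectron2’s training scripts (the body of main is left to the user):

from detectron2.engine import default_argument_parser, launch

def main(args):
    # build the config and a trainer, then run training here
    ...

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        num_gpus_per_machine=args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )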
class detectron2.engine.HookBase[source]

Bases: object

Base class for hooks that can be registered with TrainerBase.

Each hook can implement 4 methods. The way they are called is demonstrated in the following snippet:

hook.before_train()
for iter in range(start_iter, max_iter):
    hook.before_step()
    trainer.run_step()
    hook.after_step()
hook.after_train()

Notes

  1. In the hook method, users can access self.trainer to access more properties about the context (e.g., current iteration).

  2. A hook that does something in before_step() can often be implemented equivalently in after_step(). If the hook takes non-trivial time, it is strongly recommended to implement the hook in after_step() instead of before_step(). The convention is that before_step() should only take negligible time.

    Following this convention will allow hooks that do care about the difference between before_step() and after_step() (e.g., timer) to function properly.

trainer

A weak reference to the trainer object. Set by the trainer when the hook is registered.

before_train()[source]

Called before the first iteration.

after_train()[source]

Called after the last iteration.

before_step()[source]

Called before each iteration.

after_step()[source]

Called after each iteration.
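
For example, a minimal custom hook might look like this (ETAHook and its logging period are hypothetical):

from detectron2.engine import HookBase

class ETAHook(HookBase):  # hypothetical example hook
    def after_step(self):
        # self.trainer is available once the hook has been registered
        if self.trainer.iter % 100 == 0:
            remaining = self.trainer.max_iter - self.trainer.iter
            print(f"iter {self.trainer.iter}: {remaining} iterations remaining")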

class detectron2.engine.TrainerBase[source]

Bases: object

Base class for iterative trainer with hooks.

The only assumption made here is that the training runs in a loop; a subclass can define what the loop is. No assumptions are made about the existence of a dataloader, optimizer, model, etc.

iter

the current iteration.

Type:int
start_iter

The iteration to start with. By convention the minimum possible value is 0.

Type:int
max_iter

The iteration to end training.

Type:int
storage

An EventStorage that’s opened during the course of training.

Type:EventStorage
register_hooks(hooks)[source]

Register hooks to the trainer. The hooks are executed in the order they are registered.

Parameters:hooks (list[Optional[HookBase]]) – list of hooks
train(start_iter: int, max_iter: int)[source]
Parameters:start_iter, max_iter (int) – see the attribute docs above
before_train()[source]
after_train()[source]
before_step()[source]
after_step()[source]
run_step()[source]
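
A minimal sketch of a TrainerBase subclass (the training logic shown is illustrative; SimpleTrainer below is a complete implementation of this pattern):

from detectron2.engine import TrainerBase

class MyTrainer(TrainerBase):  # hypothetical subclass
    def __init__(self, model, data_loader, optimizer):
        super().__init__()
        self.model = model
        self._data_iter = iter(data_loader)
        self.optimizer = optimizer

    def run_step(self):
        data = next(self._data_iter)
        losses = sum(self.model(data).values())  # assume model returns a dict of losses
        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()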
class detectron2.engine.SimpleTrainer(model, data_loader, optimizer)[source]

Bases: detectron2.engine.train_loop.TrainerBase

A simple trainer for the most common type of task: single-cost single-optimizer single-data-source iterative optimization. It assumes that in every step, you:

  1. Compute the loss with data from the data_loader.
  2. Compute the gradients with the above loss.
  3. Update the model with the optimizer.

If you want to do anything fancier than this, either subclass TrainerBase and implement your own run_step, or write your own training loop.

__init__(model, data_loader, optimizer)[source]
Parameters:
  • model – a torch Module. Takes data from data_loader and returns a dict of losses.
  • data_loader – an iterable. Contains data to be used to call model.
  • optimizer – a torch optimizer.
run_step()[source]

Implement the standard training logic described above.
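
A self-contained usage sketch with a toy model (everything here is illustrative; in practice the model, data loader, and optimizer come from your own setup):

from itertools import repeat

import torch
from detectron2.engine import SimpleTrainer

class ToyModel(torch.nn.Module):
    # hypothetical model that returns a dict of losses, as SimpleTrainer expects
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(2))

    def forward(self, data):
        return {"loss": (self.weight - data).pow(2).sum()}

model = ToyModel()
data_loader = repeat(torch.ones(2))  # endless stream of dummy data
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

trainer = SimpleTrainer(model, data_loader, optimizer)
trainer.train(start_iter=0, max_iter=100)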

detectron2.engine.defaults module

This file contains components with some default boilerplate logic that users may need in training/testing. They will not work for everyone, but many users may find them useful.

The behavior of functions/classes in this file is subject to change, since they are meant to represent the “common default behavior” people need in their projects.

detectron2.engine.defaults.default_argument_parser()[source]

Create a parser with some common arguments used by detectron2 users.

Returns:argparse.ArgumentParser
detectron2.engine.defaults.default_setup(cfg, args)[source]

Perform some basic common setups at the beginning of a job, including:

  1. Set up the detectron2 logger
  2. Log basic information about environment, cmdline arguments, and config
  3. Backup the config to the output directory
Parameters:
  • cfg (CfgNode) – the full config to be used
  • args (argparse.Namespace) – the command line arguments to be logged
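
A typical setup flow, sketched after detectron2’s own training scripts (the config file path and extra options come from the command line):

from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup

args = default_argument_parser().parse_args()
cfg = get_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)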
class detectron2.engine.defaults.DefaultPredictor(cfg)[source]

Bases: object

Create a simple end-to-end predictor with the given config. The predictor takes a BGR image, resizes it to the specified resolution, runs the model, and produces a dict of predictions.

This predictor takes care of model loading and input preprocessing for you. If you’d like to do anything fancier, please refer to its source code as an example of how to build and use the model manually.

metadata

the metadata of the underlying dataset, obtained from cfg.DATASETS.TEST.

Type:Metadata

Examples:

pred = DefaultPredictor(cfg)
outputs = pred(inputs)  # inputs: a single BGR image, e.g. loaded with cv2.imread
class detectron2.engine.defaults.DefaultTrainer(cfg)[source]

Bases: detectron2.engine.train_loop.SimpleTrainer

A trainer with default training logic. Compared to SimpleTrainer, it contains the following logic in addition:

  1. Create model, optimizer, scheduler, dataloader from the given config.
  2. Load a checkpoint or cfg.MODEL.WEIGHTS, if it exists.
  3. Register a few common hooks.

It is created to simplify the standard model training workflow and reduce code boilerplate for users who only need the standard training workflow, with standard features. This means the class makes many assumptions about your training logic that may easily become invalid in new research. In fact, any assumptions beyond those made in SimpleTrainer are too much for research.

The code of this class has been annotated with the restrictive assumptions it makes. When they do not work for you, you’re encouraged to:

  1. Overwrite methods of this class, OR:
  2. Use SimpleTrainer, which only does minimal SGD training and nothing else. You can then add your own hooks if needed. OR:
  3. Write your own training loop similar to tools/plain_train_net.py.

Also note that the behavior of this class, like other functions/classes in this file, is not stable, since it is meant to represent the “common default behavior”. It is only guaranteed to work well with the standard models and training workflow in detectron2. To obtain more stable behavior, write your own training logic with other public APIs.

scheduler
checkpointer
Type:DetectionCheckpointer
cfg
Type:CfgNode

Examples:

trainer = DefaultTrainer(cfg)
trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
trainer.train()
__init__(cfg)[source]
Parameters:cfg (CfgNode)
resume_or_load(resume=True)[source]

If resume==True and a last checkpoint exists, resume from it.

Otherwise, load a model specified by the config.

Parameters:resume (bool) – whether to do resume or not
build_hooks()[source]

Build a list of default hooks, including timing, evaluation, checkpointing, lr scheduling, precise BN, writing events.

Returns:list[HookBase]
build_writers()[source]

Build a list of writers to be used. By default, it contains writers that write metrics to the screen, to a JSON file, and to a TensorBoard event file, respectively. If you’d like a different list of writers, overwrite this method in your trainer.

Returns:list[EventWriter] – a list of EventWriter objects.

It is now implemented by:

return [
    CommonMetricPrinter(self.max_iter),
    JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
    TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
train()[source]

Run training.

Returns:OrderedDict of results, if evaluation is enabled. Otherwise None.
classmethod build_model(cfg)[source]
Returns:torch.nn.Module

It now calls detectron2.modeling.build_model(). Overwrite it if you’d like a different model.

classmethod build_optimizer(cfg, model)[source]
Returns:torch.optim.Optimizer

It now calls detectron2.solver.build_optimizer(). Overwrite it if you’d like a different optimizer.

classmethod build_lr_scheduler(cfg, optimizer)[source]

It now calls detectron2.solver.build_lr_scheduler(). Overwrite it if you’d like a different scheduler.

classmethod build_train_loader(cfg)[source]
Returns:iterable

It now calls detectron2.data.build_detection_train_loader(). Overwrite it if you’d like a different data loader.
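
A sketch of overriding it in a subclass, e.g. to plug in a customized mapper (the subclass name is hypothetical; the mapper shown is just the default one):

from detectron2.data import DatasetMapper, build_detection_train_loader
from detectron2.engine import DefaultTrainer

class CustomLoaderTrainer(DefaultTrainer):  # hypothetical subclass
    @classmethod
    def build_train_loader(cls, cfg):
        # replace DatasetMapper with your own mapper to customize data loading
        return build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))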

classmethod build_test_loader(cfg, dataset_name)[source]
Returns:iterable

It now calls detectron2.data.build_detection_test_loader(). Overwrite it if you’d like a different data loader.

classmethod build_evaluator(cfg, dataset_name)[source]
Returns:DatasetEvaluator

It is not implemented by default.
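
A typical override, sketched here with COCOEvaluator (the subclass name and output folder are illustrative):

import os

from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator

class CocoTrainer(DefaultTrainer):  # hypothetical subclass
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, distributed=True, output_dir=output_folder)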

classmethod test(cfg, model, evaluators=None)[source]
Parameters:
  • cfg (CfgNode) –
  • model (nn.Module) –
  • evaluators (list[DatasetEvaluator] or None) – if None, will call build_evaluator(). Otherwise, must have the same length as cfg.DATASETS.TEST.
Returns:dict – a dict of result metrics

detectron2.engine.hooks module

class detectron2.engine.hooks.CallbackHook(*, before_train=None, after_train=None, before_step=None, after_step=None)[source]

Bases: detectron2.engine.train_loop.HookBase

Create a hook using callback functions provided by the user.

__init__(*, before_train=None, after_train=None, before_step=None, after_step=None)[source]

Each argument is a function that takes one argument: the trainer.

before_train()[source]
after_train()[source]
before_step()[source]
after_step()[source]
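
For example (the callback body is hypothetical):

from detectron2.engine.hooks import CallbackHook

# trainer is an existing TrainerBase instance
hook = CallbackHook(after_step=lambda trainer: print(trainer.iter))
trainer.register_hooks([hook])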
class detectron2.engine.hooks.IterationTimer(warmup_iter=3)[source]

Bases: detectron2.engine.train_loop.HookBase

Track the time spent for each iteration (each run_step call in the trainer). Print a summary at the end of training.

This hook uses the time between the call to its before_step() and after_step() methods. Under the convention that before_step() of all hooks should take only a negligible amount of time, the IterationTimer hook should be placed at the beginning of the list of hooks to obtain accurate timing.

__init__(warmup_iter=3)[source]
Parameters:warmup_iter (int) – the number of iterations at the beginning to exclude from timing.
before_train()[source]
after_train()[source]
before_step()[source]
after_step()[source]
class detectron2.engine.hooks.PeriodicWriter(writers, period=20)[source]

Bases: detectron2.engine.train_loop.HookBase

Write events to EventStorage periodically.

It is executed every period iterations and after the last iteration.

__init__(writers, period=20)[source]
Parameters:
  • writers (list[EventWriter]) – a list of EventWriter objects
  • period (int) – events are written every period iterations
after_step()[source]
after_train()[source]
class detectron2.engine.hooks.PeriodicCheckpointer(checkpointer: Any, period: int, max_iter: int = None)[source]

Bases: fvcore.common.checkpoint.PeriodicCheckpointer, detectron2.engine.train_loop.HookBase

Same as detectron2.checkpoint.PeriodicCheckpointer, but as a hook.

Note that when used as a hook, it is unable to save additional data other than what’s defined by the given checkpointer.

It is executed every period iterations and after the last iteration.

before_train()[source]
after_step()[source]
class detectron2.engine.hooks.LRScheduler(optimizer, scheduler)[source]

Bases: detectron2.engine.train_loop.HookBase

A hook which executes a torch builtin LR scheduler and summarizes the LR. It is executed after every iteration.

__init__(optimizer, scheduler)[source]
Parameters:
  • optimizer (torch.optim.Optimizer) – the optimizer whose learning rate is recorded
  • scheduler (torch.optim.lr_scheduler._LRScheduler) – a torch builtin LR scheduler, stepped after every iteration
after_step()[source]
class detectron2.engine.hooks.AutogradProfiler(enable_predicate, output_dir, *, use_cuda=True)[source]

Bases: detectron2.engine.train_loop.HookBase

A hook which runs torch.autograd.profiler.profile.

Examples:

hooks.AutogradProfiler(
    lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
)

The above example runs the profiler for iterations 10–20 and dumps the results to OUTPUT_DIR. The first few iterations are not profiled because they are typically slower than the rest. The result files can be loaded in the chrome://tracing page of the Chrome browser.

Note

When used together with NCCL on older generations of GPUs, the autograd profiler may cause deadlock, because it unnecessarily allocates memory on every device it sees. The memory management calls, if interleaved with NCCL calls, lead to deadlock on GPUs that do not support cudaLaunchCooperativeKernelMultiDevice.

__init__(enable_predicate, output_dir, *, use_cuda=True)[source]
Parameters:
  • enable_predicate (callable[trainer -> bool]) – a function which takes a trainer, and returns whether to enable the profiler. It will be called once every step, and can be used to select which steps to profile.
  • output_dir (str) – the output directory to dump tracing files.
  • use_cuda (bool) – same as in torch.autograd.profiler.profile.
before_step()[source]
after_step()[source]
class detectron2.engine.hooks.EvalHook(eval_period, eval_function)[source]

Bases: detectron2.engine.train_loop.HookBase

Run an evaluation function periodically, and at the end of training.

It is executed every eval_period iterations and after the last iteration.

__init__(eval_period, eval_function)[source]
Parameters:
  • eval_period (int) – the period to run eval_function.
  • eval_function (callable) – a function which takes no arguments, and returns a nested dict of evaluation metrics.

Note

This hook must be enabled either in all workers or in none of them. If you would like only certain workers to perform evaluation, give the other workers a no-op function (eval_function=lambda: None).
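
A registration sketch (the period and the evaluation function are hypothetical):

def eval_function():
    # run your evaluation here and return a nested dict of metrics
    return {"segm": {"AP": 0.0}}  # placeholder values

# trainer is an existing TrainerBase instance
trainer.register_hooks([EvalHook(5000, eval_function)])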

after_step()[source]
after_train()[source]
class detectron2.engine.hooks.PreciseBN(period, model, data_loader, num_iter)[source]

Bases: detectron2.engine.train_loop.HookBase

The standard implementation of BatchNorm uses EMA statistics in inference, which is sometimes suboptimal. This class computes the true average of the statistics rather than the moving average, and puts the true averages into every BN layer in the given model.

It is executed every period iterations and after the last iteration.

__init__(period, model, data_loader, num_iter)[source]
Parameters:
  • period (int) – the period this hook is run, or 0 to not run during training. The hook will always run at the end of training.
  • model (nn.Module) – a module whose BN layers in training mode will be updated by precise BN. Note that the user is responsible for ensuring that the BN layers to be updated are in training mode when this hook is triggered.
  • data_loader (iterable) – it will produce data to be run by model(data).
  • num_iter (int) – number of iterations used to compute the precise statistics.
after_step()[source]
update_stats()[source]

Update the model with precise statistics. Users can call this method manually.
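
A registration sketch (the data loader and iteration count are illustrative):

# model and trainer are assumed to exist; train_loader is a hypothetical loader
# producing training-like batches
precise_bn = PreciseBN(
    period=0,  # 0: do not run during training; the hook still runs once at the end
    model=model,
    data_loader=train_loader,
    num_iter=200,
)
trainer.register_hooks([precise_bn])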