detectron2.engine

Related tutorial: Training.

detectron2.engine.launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=(), timeout=datetime.timedelta(seconds=1800))[source]

Launch multi-gpu or distributed training. This function must be called on all machines involved in the training. It will spawn child processes (defined by num_gpus_per_machine) on each machine.

Parameters
  • main_func – a function that will be called by main_func(*args)

  • num_gpus_per_machine (int) – number of GPUs per machine

  • num_machines (int) – the total number of machines

  • machine_rank (int) – the rank of this machine

  • dist_url (str) – url to connect to for distributed jobs, including protocol e.g. “tcp://127.0.0.1:8686”. Can be set to “auto” to automatically select a free port on localhost

  • timeout (timedelta) – timeout of the distributed workers

  • args (tuple) – arguments passed to main_func
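
For example, a minimal single-machine launch might look like the following sketch (it assumes 8 GPUs are available; main is a placeholder for your own per-process entry point):

from detectron2.engine import launch

def main():
    pass  # per-process training logic goes here

if __name__ == "__main__":
    launch(main, num_gpus_per_machine=8, dist_url="auto")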

class detectron2.engine.HookBase[source]

Bases: object

Base class for hooks that can be registered with TrainerBase.

Each hook can implement 4 methods. The way they are called is demonstrated in the following snippet:

hook.before_train()
for iter in range(start_iter, max_iter):
    hook.before_step()
    trainer.run_step()
    hook.after_step()
iter += 1
hook.after_train()

Notes

  1. In the hook methods, users can use self.trainer to access more properties of the context (e.g., model, current iteration, or config if using DefaultTrainer).

  2. A hook that does something in before_step() can often be implemented equivalently in after_step(). If the hook takes non-trivial time, it is strongly recommended to implement the hook in after_step() instead of before_step(). The convention is that before_step() should only take negligible time.

    Following this convention will allow hooks that do care about the difference between before_step() and after_step() (e.g., timer) to function properly.

trainer: detectron2.engine.train_loop.TrainerBase = None

A weak reference to the trainer object. Set by the trainer when the hook is registered.

before_train()[source]

Called before the first iteration.

after_train()[source]

Called after the last iteration.

before_step()[source]

Called before each iteration.

after_step()[source]

Called after each iteration.

state_dict()[source]

Hooks are stateless by default, but can be made checkpointable by implementing state_dict and load_state_dict.
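
As an illustrative sketch (the class name below is not part of the API), a checkpointable hook only needs to subclass HookBase and implement these two methods alongside its regular callbacks:

from detectron2.engine import HookBase

class CounterHook(HookBase):
    """Counts completed steps and survives checkpoint save/load."""

    def __init__(self):
        self.count = 0

    def after_step(self):
        self.count += 1  # self.trainer is also available here

    def state_dict(self):
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        self.count = state_dict["count"]

Such a hook would then be registered with trainer.register_hooks([CounterHook()]).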

class detectron2.engine.TrainerBase[source]

Bases: object

Base class for iterative trainer with hooks.

The only assumption made here is that the training runs in a loop; a subclass can define what the loop is. No assumptions are made about the existence of a dataloader, optimizer, model, etc.

iter

the current iteration.

Type

int

start_iter

The iteration to start with. By convention the minimum possible value is 0.

Type

int

max_iter

The iteration to end training.

Type

int

storage

An EventStorage that’s opened during the course of training.

Type

EventStorage

register_hooks(hooks: List[Optional[detectron2.engine.train_loop.HookBase]]) → None[source]

Register hooks to the trainer. The hooks are executed in the order they are registered.

Parameters

hooks (list[Optional[HookBase]]) – list of hooks

train(start_iter: int, max_iter: int)[source]
Parameters
  • start_iter (int) – See docs above

  • max_iter (int) – See docs above

before_train()[source]
after_train()[source]
before_step()[source]
after_step()[source]
run_step()[source]
state_dict()[source]
load_state_dict(state_dict)[source]
class detectron2.engine.SimpleTrainer(model, data_loader, optimizer)[source]

Bases: detectron2.engine.train_loop.TrainerBase

A simple trainer for the most common type of task: single-cost single-optimizer single-data-source iterative optimization, optionally using data-parallelism. It assumes that in every step, you:

  1. Compute the loss with data from the data_loader.

  2. Compute the gradients with the above loss.

  3. Update the model with the optimizer.

All other tasks during training (checkpointing, logging, evaluation, LR schedule) are maintained by hooks, which can be registered by TrainerBase.register_hooks().

If you want to do anything fancier than this, either subclass TrainerBase and implement your own run_step, or write your own training loop.
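
A rough wiring sketch (assuming an existing config cfg; the builder functions below are the standard detectron2 ones, but any model, optimizer, and iterable data loader will do):

from detectron2.data import build_detection_train_loader
from detectron2.engine import SimpleTrainer, hooks
from detectron2.modeling import build_model
from detectron2.solver import build_optimizer

model = build_model(cfg)  # cfg is assumed to be a CfgNode built elsewhere
optimizer = build_optimizer(cfg, model)
data_loader = build_detection_train_loader(cfg)

trainer = SimpleTrainer(model, data_loader, optimizer)
trainer.register_hooks([hooks.IterationTimer()])
trainer.train(0, cfg.SOLVER.MAX_ITER)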

__init__(model, data_loader, optimizer)[source]
Parameters
  • model – a torch Module. Takes data from data_loader and returns a dict of losses.

  • data_loader – an iterable. Contains data to be used to call model.

  • optimizer – a torch optimizer.

run_step()[source]

Implement the standard training logic described above.

static write_metrics(loss_dict: Mapping[str, torch.Tensor], data_time: float, prefix: str = '') → None[source]
Parameters
  • loss_dict (dict) – dict of scalar losses

  • data_time (float) – time taken by the dataloader iteration

  • prefix (str) – prefix for logging keys

state_dict()[source]
load_state_dict(state_dict)[source]
class detectron2.engine.AMPTrainer(model, data_loader, optimizer, grad_scaler=None)[source]

Bases: detectron2.engine.train_loop.SimpleTrainer

Like SimpleTrainer, but uses PyTorch’s native automatic mixed precision in the training loop.
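
A hedged sketch of the drop-in replacement (model, data_loader, and optimizer as in the SimpleTrainer sketch above):

from detectron2.engine import AMPTrainer

trainer = AMPTrainer(model, data_loader, optimizer)  # a default GradScaler is created when none is given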

__init__(model, data_loader, optimizer, grad_scaler=None)[source]
Parameters
  • model, data_loader, optimizer – same as in SimpleTrainer.

  • grad_scaler – a torch GradScaler used to automatically scale gradients; if not given, a default one is created.
run_step()[source]

Implement the AMP training logic.

state_dict()[source]
load_state_dict(state_dict)[source]

detectron2.engine.defaults module

This file contains components with some default boilerplate logic that users may need in training / testing. They will not work for everyone, but many users may find them useful.

The behavior of functions/classes in this file is subject to change, since they are meant to represent the “common default behavior” people need in their projects.

detectron2.engine.defaults.create_ddp_model(model, *, fp16_compression=False, **kwargs)[source]

Create a DistributedDataParallel model if there are >1 processes.

Parameters
  • model – a torch.nn.Module

  • fp16_compression – whether to add fp16 compression hooks to the DDP object

  • kwargs – other arguments of torch.nn.parallel.DistributedDataParallel

detectron2.engine.defaults.default_argument_parser(epilog=None)[source]

Create a parser with some common arguments used by detectron2 users.

Parameters

epilog (str) – epilog passed to ArgumentParser describing the usage.

Returns

argparse.ArgumentParser
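
A typical usage sketch (the attribute names correspond to the flags this parser defines, such as --num-gpus; main is your own entry point and must accept the parsed args; launch is documented at the top of this page):

args = default_argument_parser().parse_args()
launch(
    main,
    args.num_gpus,
    num_machines=args.num_machines,
    machine_rank=args.machine_rank,
    dist_url=args.dist_url,
    args=(args,),
)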

detectron2.engine.defaults.default_setup(cfg, args)[source]

Perform some basic common setups at the beginning of a job, including:

  1. Set up the detectron2 logger

  2. Log basic information about environment, cmdline arguments, and config

  3. Backup the config to the output directory

Parameters
  • cfg (CfgNode or omegaconf.DictConfig) – the full config to be used

  • args (argparse.Namespace) – the command line arguments to be logged
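
A common pattern, sketched after detectron2's own training scripts (assumes args comes from default_argument_parser()):

from detectron2.config import get_cfg

def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg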

detectron2.engine.defaults.default_writers(output_dir: str, max_iter: Optional[int] = None)[source]

Build a list of EventWriter to be used. It now consists of a CommonMetricPrinter, TensorboardXWriter and JSONWriter.

Parameters
  • output_dir – directory to store JSON metrics and tensorboard events

  • max_iter – the total number of iterations

Returns

list[EventWriter] – a list of EventWriter objects.

class detectron2.engine.defaults.DefaultPredictor(cfg)[source]

Bases: object

Create a simple end-to-end predictor with the given config that runs on a single device for a single input image.

Compared to using the model directly, this class adds the following:

  1. Load checkpoint from cfg.MODEL.WEIGHTS.

  2. Always take a BGR image as the input and apply the conversion defined by cfg.INPUT.FORMAT.

  3. Apply resizing defined by cfg.INPUT.{MIN,MAX}_SIZE_TEST.

  4. Take one input image and produce a single output, instead of a batch.

This is meant for simple demo purposes, so it does the above steps automatically. This is not meant for benchmarks or running complicated inference logic. If you’d like to do anything more complicated, please refer to its source code as an example of how to build and use the model manually.

metadata

the metadata of the underlying dataset, obtained from cfg.DATASETS.TEST.

Type

Metadata

Examples:

pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
__call__(original_image)[source]
Parameters

original_image (np.ndarray) – an image of shape (H, W, C) (in BGR order).

Returns

predictions (dict) – the output of the model for one image only. See Use Models for details about the format.

class detectron2.engine.defaults.DefaultTrainer(cfg)[source]

Bases: detectron2.engine.train_loop.TrainerBase

A trainer with default training logic. It does the following:

  1. Create a SimpleTrainer using model, optimizer, dataloader defined by the given config. Create a LR scheduler defined by the config.

  2. Load the last checkpoint or cfg.MODEL.WEIGHTS, if exists, when resume_or_load is called.

  3. Register a few common hooks defined by the config.

It is created to simplify the standard model training workflow and reduce code boilerplate for users who only need the standard training workflow, with standard features. It means this class makes many assumptions about your training logic that may easily become invalid in new research. In fact, any assumptions beyond those made in the SimpleTrainer are too much for research.

The code of this class has been annotated about restrictive assumptions it makes. When they do not work for you, you’re encouraged to:

  1. Overwrite methods of this class, OR:

  2. Use SimpleTrainer, which only does minimal SGD training and nothing else. You can then add your own hooks if needed. OR:

  3. Write your own training loop similar to tools/plain_train_net.py.

See the Training tutorials for more details.

Note that the behavior of this class, like other functions/classes in this file, is not stable, since it is meant to represent the “common default behavior”. It is only guaranteed to work well with the standard models and training workflow in detectron2. To obtain more stable behavior, write your own training logic with other public APIs.

Examples:

trainer = DefaultTrainer(cfg)
trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
trainer.train()
scheduler
checkpointer
Type

DetectionCheckpointer

cfg
Type

CfgNode

__init__(cfg)[source]
Parameters

cfg (CfgNode) –

resume_or_load(resume=True)[source]

If resume==True and cfg.OUTPUT_DIR contains the last checkpoint (defined by a last_checkpoint file), resume from the file. Resuming means loading all available states (e.g., optimizer and scheduler) and updating the iteration counter from the checkpoint. cfg.MODEL.WEIGHTS will not be used.

Otherwise, this is considered an independent training run. The method will load model weights from the file cfg.MODEL.WEIGHTS (but will not load other states) and start from iteration 0.

Parameters

resume (bool) – whether to do resume or not

build_hooks()[source]

Build a list of default hooks, including timing, evaluation, checkpointing, lr scheduling, precise BN, writing events.

Returns

list[HookBase]

build_writers()[source]

Build a list of writers to be used using default_writers(). If you’d like a different list of writers, you can overwrite it in your trainer.

Returns

list[EventWriter] – a list of EventWriter objects.

train()[source]

Run training.

Returns

OrderedDict of results, if evaluation is enabled. Otherwise None.

run_step()[source]
classmethod build_model(cfg)[source]
Returns

torch.nn.Module

It now calls detectron2.modeling.build_model(). Overwrite it if you’d like a different model.

classmethod build_optimizer(cfg, model)[source]
Returns

torch.optim.Optimizer

It now calls detectron2.solver.build_optimizer(). Overwrite it if you’d like a different optimizer.

classmethod build_lr_scheduler(cfg, optimizer)[source]

It now calls detectron2.solver.build_lr_scheduler(). Overwrite it if you’d like a different scheduler.

classmethod build_train_loader(cfg)[source]
Returns

iterable

It now calls detectron2.data.build_detection_train_loader(). Overwrite it if you’d like a different data loader.

classmethod build_test_loader(cfg, dataset_name)[source]
Returns

iterable

It now calls detectron2.data.build_detection_test_loader(). Overwrite it if you’d like a different data loader.

classmethod build_evaluator(cfg, dataset_name)[source]
Returns

DatasetEvaluator or None

It is not implemented by default.
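
Subclasses that need evaluation typically override it. A hedged sketch using COCOEvaluator (any DatasetEvaluator would work):

from detectron2.evaluation import COCOEvaluator

class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        return COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)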

classmethod test(cfg, model, evaluators=None)[source]
Parameters
  • cfg (CfgNode) –

  • model (nn.Module) –

  • evaluators (list[DatasetEvaluator] or None) – if None, will call build_evaluator(). Otherwise, must have the same length as cfg.DATASETS.TEST.

Returns

dict – a dict of result metrics

static auto_scale_workers(cfg, num_workers: int)[source]

When the config is defined for a certain number of workers (according to cfg.SOLVER.REFERENCE_WORLD_SIZE) that is different from the number of workers currently in use, this returns a new cfg where the total batch size is scaled so that the per-GPU batch size stays the same as the original IMS_PER_BATCH // REFERENCE_WORLD_SIZE.

Other config options are also scaled accordingly:

  • training steps and warmup steps are scaled inversely proportionally.

  • learning rate is scaled proportionally, following Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.

For example, with the original config like the following:

IMS_PER_BATCH: 16
BASE_LR: 0.1
REFERENCE_WORLD_SIZE: 8
MAX_ITER: 5000
STEPS: (4000,)
CHECKPOINT_PERIOD: 1000

When this config is used on 16 GPUs instead of the reference number 8, calling this method will return a new config with:

IMS_PER_BATCH: 32
BASE_LR: 0.2
REFERENCE_WORLD_SIZE: 16
MAX_ITER: 2500
STEPS: (2000,)
CHECKPOINT_PERIOD: 500

Note that both the original config and this new config can be trained on 16 GPUs. It is up to the user whether to enable this feature (by setting REFERENCE_WORLD_SIZE).

Returns

CfgNode – a new config. Same as original if cfg.SOLVER.REFERENCE_WORLD_SIZE==0.
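
A usage sketch (assumes an existing cfg):

cfg = DefaultTrainer.auto_scale_workers(cfg, num_workers=16)  # returns cfg unchanged if REFERENCE_WORLD_SIZE == 0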

property data_loader
property model
property optimizer

detectron2.engine.hooks module

class detectron2.engine.hooks.CallbackHook(*, before_train=None, after_train=None, before_step=None, after_step=None)[source]

Bases: detectron2.engine.train_loop.HookBase

Create a hook using callback functions provided by the user.

__init__(*, before_train=None, after_train=None, before_step=None, after_step=None)[source]

Each argument is a function that takes one argument: the trainer.
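
For example, an illustrative sketch that prints the iteration counter after every step (trainer is assumed to exist):

trainer.register_hooks([
    CallbackHook(after_step=lambda trainer: print(trainer.iter)),
])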

before_train()[source]
after_train()[source]
before_step()[source]
after_step()[source]
class detectron2.engine.hooks.IterationTimer(warmup_iter=3)[source]

Bases: detectron2.engine.train_loop.HookBase

Track the time spent on each iteration (each run_step call in the trainer). Print a summary at the end of training.

This hook uses the time between the calls to its before_step() and after_step() methods. Under the convention that before_step() of all hooks should only take a negligible amount of time, the IterationTimer hook should be placed at the beginning of the list of hooks to obtain accurate timing.

__init__(warmup_iter=3)[source]
Parameters

warmup_iter (int) – the number of iterations at the beginning to exclude from timing.

before_train()[source]
after_train()[source]
before_step()[source]
after_step()[source]
class detectron2.engine.hooks.PeriodicWriter(writers, period=20)[source]

Bases: detectron2.engine.train_loop.HookBase

Write events to EventStorage (by calling writer.write()) periodically.

It is executed every period iterations and after the last iteration. Note that period does not affect how data is smoothed by each writer.

__init__(writers, period=20)[source]
Parameters
  • writers (list[EventWriter]) – a list of EventWriter objects

  • period (int) –
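
A usage sketch (default_writers() is documented in detectron2.engine.defaults above; cfg and trainer are assumed to exist):

trainer.register_hooks([
    hooks.PeriodicWriter(default_writers(cfg.OUTPUT_DIR, cfg.SOLVER.MAX_ITER), period=20),
])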

after_step()[source]
after_train()[source]
class detectron2.engine.hooks.PeriodicCheckpointer(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]

Bases: fvcore.common.checkpoint.PeriodicCheckpointer, detectron2.engine.train_loop.HookBase

Same as detectron2.checkpoint.PeriodicCheckpointer, but as a hook.

Note that when used as a hook, it is unable to save additional data other than what’s defined by the given checkpointer.

It is executed every period iterations and after the last iteration.

before_train()[source]
after_step()[source]
class detectron2.engine.hooks.LRScheduler(optimizer=None, scheduler=None)[source]

Bases: detectron2.engine.train_loop.HookBase

A hook which executes a torch builtin LR scheduler and summarizes the LR. It is executed after every iteration.

__init__(optimizer=None, scheduler=None)[source]
Parameters

  • optimizer, scheduler – if either is not given, it will be obtained from the trainer.

before_train()[source]
static get_best_param_group_id(optimizer)[source]
after_step()[source]
property scheduler
state_dict()[source]
load_state_dict(state_dict)[source]
class detectron2.engine.hooks.AutogradProfiler(enable_predicate, output_dir, *, use_cuda=True)[source]

Bases: detectron2.engine.hooks.TorchProfiler

A hook which runs torch.autograd.profiler.profile.

Examples:

hooks.AutogradProfiler(
     lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
)

The above example will run the profiler for iterations 10–20 and dump results to OUTPUT_DIR. The first few iterations are not profiled because they are typically slower than the rest. The result files can be loaded in the chrome://tracing page of the Chrome browser.

Note

When used together with NCCL on older GPUs, the autograd profiler may cause a deadlock because it unnecessarily allocates memory on every device it sees. The memory management calls, if interleaved with NCCL calls, lead to deadlock on GPUs that do not support cudaLaunchCooperativeKernelMultiDevice.

__init__(enable_predicate, output_dir, *, use_cuda=True)[source]
Parameters
  • enable_predicate (callable[trainer -> bool]) – a function which takes a trainer, and returns whether to enable the profiler. It will be called once every step, and can be used to select which steps to profile.

  • output_dir (str) – the output directory to dump tracing files.

  • use_cuda (bool) – same as in torch.autograd.profiler.profile.

before_step()[source]
class detectron2.engine.hooks.EvalHook(eval_period, eval_function)[source]

Bases: detectron2.engine.train_loop.HookBase

Run an evaluation function periodically, and at the end of training.

It is executed every eval_period iterations and after the last iteration.

__init__(eval_period, eval_function)[source]
Parameters
  • eval_period (int) – the period to run eval_function. Set to 0 to not evaluate periodically (but still after the last iteration).

  • eval_function (callable) – a function which takes no arguments, and returns a nested dict of evaluation metrics.

Note

This hook must be enabled in all workers or in none of them. If you would like only certain workers to perform evaluation, give the other workers a no-op function (eval_function=lambda: None).
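
A hedged sketch (do_evaluation here is a placeholder for your own evaluation routine, not a detectron2 API; cfg and trainer are assumed to exist):

trainer.register_hooks([
    hooks.EvalHook(5000, lambda: do_evaluation(cfg, trainer.model)),
])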

after_step()[source]
after_train()[source]
class detectron2.engine.hooks.PreciseBN(period, model, data_loader, num_iter)[source]

Bases: detectron2.engine.train_loop.HookBase

The standard implementation of BatchNorm uses EMA in inference, which is sometimes suboptimal. This class computes the true average of statistics rather than the moving average, and puts the true averages into every BN layer in the given model.

It is executed every period iterations and after the last iteration.

__init__(period, model, data_loader, num_iter)[source]
Parameters
  • period (int) – the period at which this hook runs, or 0 to not run it during training. The hook always runs at the end of training.

  • model (nn.Module) – a module all of whose BN layers in training mode will be updated by precise BN. Note that the user is responsible for ensuring that the BN layers to be updated are in training mode when this hook is triggered.

  • data_loader (iterable) – it will produce data to be run by model(data).

  • num_iter (int) – number of iterations used to compute the precise statistics.
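
A hedged sketch (reuses the training data loader from the SimpleTrainer sketch above; 200 iterations of statistics is only an illustrative choice):

trainer.register_hooks([
    hooks.PreciseBN(1000, trainer.model, build_detection_train_loader(cfg), 200),
])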

after_step()[source]
update_stats()[source]

Update the model with precise statistics. Users can manually call this method.

class detectron2.engine.hooks.TorchProfiler(enable_predicate, output_dir, *, activities=None, save_tensorboard=True)[source]

Bases: detectron2.engine.train_loop.HookBase

A hook which runs torch.profiler.profile.

Examples:

hooks.TorchProfiler(
     lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
)

The above example will run the profiler for iterations 10–20 and dump results to OUTPUT_DIR. The first few iterations are not profiled because they are typically slower than the rest. The result files can be loaded in the chrome://tracing page of the Chrome browser, and the TensorBoard visualizations can be viewed with tensorboard --logdir OUTPUT_DIR/log.

__init__(enable_predicate, output_dir, *, activities=None, save_tensorboard=True)[source]
Parameters
  • enable_predicate (callable[trainer -> bool]) – a function which takes a trainer, and returns whether to enable the profiler. It will be called once every step, and can be used to select which steps to profile.

  • output_dir (str) – the output directory to dump tracing files.

  • activities (iterable) – same as in torch.profiler.profile.

  • save_tensorboard (bool) – whether to save tensorboard visualizations at (output_dir)/log/

before_step()[source]
after_step()[source]