detectron2.solver

detectron2.solver.build_lr_scheduler(cfg: detectron2.config.CfgNode, optimizer: torch.optim.optimizer.Optimizer) → torch.optim.lr_scheduler._LRScheduler

Build an LR scheduler from config.

detectron2.solver.build_optimizer(cfg: detectron2.config.CfgNode, model: torch.nn.Module) → torch.optim.optimizer.Optimizer

Build an optimizer from config.

detectron2.solver.get_default_optimizer_params(model: torch.nn.Module, base_lr: Optional[float] = None, weight_decay: Optional[float] = None, weight_decay_norm: Optional[float] = None, bias_lr_factor: Optional[float] = 1.0, weight_decay_bias: Optional[float] = None, overrides: Optional[Dict[str, Dict[str, float]]] = None) → List[Dict[str, Any]]

Get the default parameter list for an optimizer, with support for a few types of overrides. If no overrides are needed, this is equivalent to model.parameters().

Parameters
  • base_lr – lr for every group by default. Can be omitted to use the one in optimizer.

  • weight_decay – weight decay for every group by default. Can be omitted to use the one in optimizer.

  • weight_decay_norm – override weight decay for params in normalization layers

  • bias_lr_factor – multiplier of lr for bias parameters.

  • weight_decay_bias – override weight decay for bias parameters

  • overrides – if not None, provides values for optimizer hyperparameters (LR, weight decay) for module parameters with a given name; e.g. {"embedding": {"lr": 0.01, "weight_decay": 0.1}} will set the LR and weight decay values for all module parameters named embedding.

For common detection models, weight_decay_norm is the only option that needs to be set. bias_lr_factor and weight_decay_bias are legacy settings from Detectron1 that have not been found useful.

Example:

torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
               lr=0.01, weight_decay=1e-4, momentum=0.9)
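
The way the arguments above combine can be illustrated without a model. The following is a minimal sketch in plain Python, not the library's implementation: sketch_param_groups is a hypothetical helper, each parameter is represented only by its name and a flag saying whether it sits in a normalization layer, and the order in which the settings are applied (norm, then bias, then overrides) is an assumption made for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple


def sketch_param_groups(
    params: List[Tuple[str, bool]],
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
    """Hypothetical stand-in: one hyperparameter dict per (name, in_norm_layer) entry."""
    overrides = overrides or {}
    groups: List[Dict[str, Any]] = []
    for name, in_norm_layer in params:
        group: Dict[str, Any] = {"name": name}
        # Defaults are written only when given; otherwise the optimizer's own values apply.
        if base_lr is not None:
            group["lr"] = base_lr
        if weight_decay is not None:
            group["weight_decay"] = weight_decay
        # Parameters in normalization layers can get their own weight decay.
        if in_norm_layer and weight_decay_norm is not None:
            group["weight_decay"] = weight_decay_norm
        # Legacy bias settings (assumed here to target parameters named "bias").
        if name == "bias":
            if base_lr is not None and bias_lr_factor not in (None, 1.0):
                group["lr"] = base_lr * bias_lr_factor
            if weight_decay_bias is not None:
                group["weight_decay"] = weight_decay_bias
        # Per-name overrides are applied last in this sketch.
        group.update(overrides.get(name, {}))
        groups.append(group)
    return groups


groups = sketch_param_groups(
    [("weight", False), ("bias", False), ("weight", True), ("embedding", False)],
    base_lr=0.01,
    weight_decay=1e-4,
    weight_decay_norm=0.0,
    overrides={"embedding": {"lr": 0.1, "weight_decay": 0.1}},
)
for g in groups:
    print(g)
```

In the real function each returned dict carries a "params" entry holding the tensors, which is what lets the list be passed directly to an optimizer constructor as per-parameter options.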

class detectron2.solver.LRMultiplier(optimizer: torch.optim.optimizer.Optimizer, multiplier: fvcore.common.param_scheduler.ParamScheduler, max_iter: int, last_iter: int = -1)

Bases: torch.optim.lr_scheduler._LRScheduler

An LRScheduler that uses an fvcore ParamScheduler to multiply the learning rate of each parameter in the optimizer. At every step, the learning rate of each parameter becomes its initial value multiplied by the output of the given ParamScheduler.

The absolute learning rate of each parameter can be different. This scheduler can be used as long as the relative scale among them does not change during training.

Examples:

LRMultiplier(
    opt,
    WarmupParamScheduler(
        MultiStepParamScheduler(
            [1, 0.1, 0.01],
            milestones=[60000, 80000],
            num_updates=90000,
        ), 0.001, 100 / 90000
    ),
    max_iter=90000
)
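
To make the multiplication concrete, the step schedule in the example above can be traced by hand. This is a minimal sketch in plain Python under stated assumptions: step_multiplier and scaled_lrs are hypothetical stand-ins (not fvcore or detectron2 functions), the warmup stage is left out, and the two base learning rates are made-up values.

```python
from bisect import bisect_right
from typing import Callable, List, Sequence


def step_multiplier(
    values: Sequence[float], milestones: Sequence[int], num_updates: int
) -> Callable[[float], float]:
    """Hypothetical stand-in for a multi-step schedule over progress in [0, 1]."""
    # Express milestones as fractions of training so they compare against progress.
    boundaries = [m / num_updates for m in milestones]

    def multiplier(where: float) -> float:
        # The number of milestones already reached selects the value.
        return values[bisect_right(boundaries, where)]

    return multiplier


def scaled_lrs(
    base_lrs: Sequence[float],
    multiplier: Callable[[float], float],
    iteration: int,
    max_iter: int,
) -> List[float]:
    """Each learning rate is its initial value times the multiplier at this progress."""
    factor = multiplier(iteration / max_iter)
    return [lr * factor for lr in base_lrs]


multiplier = step_multiplier([1, 0.1, 0.01], milestones=[60000, 80000], num_updates=90000)
base_lrs = [0.02, 0.002]  # made-up initial LRs for two parameter groups
for it in (0, 59999, 60000, 80000):
    print(it, scaled_lrs(base_lrs, multiplier, it, max_iter=90000))
```

Because every group is scaled by the same factor, the 10:1 ratio between the two groups is the same at every iteration, which is the condition stated above for using this scheduler.
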

__init__(optimizer: torch.optim.optimizer.Optimizer, multiplier: fvcore.common.param_scheduler.ParamScheduler, max_iter: int, last_iter: int = -1)
Parameters
  • optimizer – See torch.optim.lr_scheduler._LRScheduler.

  • last_iter – See torch.optim.lr_scheduler._LRScheduler. last_iter is the same as last_epoch.

  • multiplier – a fvcore ParamScheduler that defines the multiplier on every LR of the optimizer

  • max_iter – the total number of training iterations

state_dict()
get_lr() → List[float]

class detectron2.solver.WarmupParamScheduler(scheduler: fvcore.common.param_scheduler.ParamScheduler, warmup_factor: float, warmup_length: float, warmup_method: str = 'linear')

Bases: fvcore.common.param_scheduler.CompositeParamScheduler

Add an initial warmup stage to another scheduler.

__init__(scheduler: fvcore.common.param_scheduler.ParamScheduler, warmup_factor: float, warmup_length: float, warmup_method: str = 'linear')
Parameters
  • scheduler – warmup will be added at the beginning of this scheduler

  • warmup_factor – the starting factor of the warmup, relative to the initial value of scheduler, e.g. 0.001

  • warmup_length – the length of the warmup as a fraction (in [0, 1]) of the entire training, e.g. 0.01

  • warmup_method – one of "linear" or "constant"
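
One way to read these parameters is as a piecewise schedule over training progress. The sketch below is plain Python written from the parameter descriptions above, not the class's actual code: warmup_value is a hypothetical helper, and it assumes that linear warmup ramps from warmup_factor times the scheduler's initial value up to the scheduler's value at the end of the warmup stage.

```python
from typing import Callable


def warmup_value(
    scheduler: Callable[[float], float],
    where: float,
    warmup_factor: float,
    warmup_length: float,
    warmup_method: str = "linear",
) -> float:
    """Hypothetical sketch: value of the warmed-up schedule at progress `where` in [0, 1]."""
    # After the warmup stage, the wrapped scheduler is used unchanged.
    if where >= warmup_length:
        return scheduler(where)
    # The warmup starts at a fraction of the scheduler's initial value.
    start = warmup_factor * scheduler(0.0)
    if warmup_method == "constant":
        return start
    if warmup_method == "linear":
        # Assumed: ramp linearly towards the scheduler's value at the end of warmup.
        end = scheduler(warmup_length)
        alpha = where / warmup_length
        return start + (end - start) * alpha
    raise ValueError(f"Unknown warmup method: {warmup_method}")


def constant_one(where: float) -> float:
    """A trivial wrapped schedule that always returns 1.0."""
    return 1.0


for where in (0.0, 0.005, 0.01, 0.5):
    print(where, warmup_value(constant_one, where, warmup_factor=0.001, warmup_length=0.01))
```

With warmup_length=0.01 and a 90000-iteration run, the warmup stage would cover the first 900 iterations under this reading.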