Lazy Configs

The traditional yacs-based config system provides basic, standard functionality. However, it does not offer enough flexibility for many new projects. We developed an alternative, non-intrusive config system that can be used with detectron2 and potentially with any other complex project.

Python Syntax

Our config objects are still dictionaries. Instead of using YAML to define dictionaries, we create dictionaries in Python directly. This gives users the following capabilities that YAML lacks:

  • Easily manipulate the dictionary (addition & deletion) using Python.

  • Write simple arithmetic or call simple functions.

  • Use more data types / objects.

  • Import / compose other config files, using the familiar Python import syntax.
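The following sketch shows what these capabilities look like in a config file. Every name in it (base_lr, optimizer, schedule, pooler) is illustrative only and is not a key that detectron2 expects. Following the description of LazyConfig.load in this section, only the top-level dictionaries, not the plain numbers, would appear in the loaded config.

```python
# config.py: a plain-Python config; all names here are illustrative only.
max_iter = 90000
base_lr = 0.02
num_gpus = 8

# Arithmetic is evaluated when the file is loaded.
optimizer = dict(lr=base_lr * num_gpus / 16, momentum=0.9, weight_decay=1e-4)
schedule = dict(max_iter=max_iter, steps=[max_iter * 2 // 3, max_iter * 8 // 9])

# Other data types, e.g. a tuple built with a generator expression.
pooler = dict(scales=tuple(1.0 / 2**k for k in range(2, 6)))

# Ordinary Python edits the dictionaries: addition and deletion.
optimizer["nesterov"] = True
del optimizer["weight_decay"]
```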

A Python config file can be loaded like this:

# config.py:
a = dict(x=1, y=2, z=dict(xx=1))
b = dict(x=3, y=4)

# my_code.py:
from detectron2.config import LazyConfig
cfg = LazyConfig.load("path/to/config.py")  # an omegaconf dictionary
assert cfg.a.z.xx == 1
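To illustrate what "all dictionaries defined in the global scope" means, here is a simplified, hypothetical loader (load_py_config is not part of detectron2). It runs the file with the standard library's runpy and keeps top-level dicts. Unlike LazyConfig.load, it returns plain dicts without omegaconf conversion and does not handle relative imports between config files.

```python
import os
import runpy
import tempfile


def load_py_config(path):
    """Hypothetical, simplified loader: run a Python file, keep its top-level dicts."""
    namespace = runpy.run_path(path)  # executes the file and returns its globals
    return {
        name: value
        for name, value in namespace.items()
        # skip private names such as __builtins__ and anything that is not a dict
        if isinstance(value, dict) and not name.startswith("_")
    }


# Usage: write the config.py shown above to a temporary file and load it.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.py")
    with open(path, "w") as f:
        f.write("a = dict(x=1, y=2, z=dict(xx=1))\nb = dict(x=3, y=4)\n")
    cfg = load_py_config(path)

assert cfg["a"]["z"]["xx"] == 1
```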

After LazyConfig.load, cfg will be a dictionary that contains all dictionaries defined in the global scope of the config file. Note that:

  • All dictionaries are turned into omegaconf config objects during loading. This enables access to omegaconf features, such as its access syntax and interpolation.

  • Absolute imports in config.py work the same as in regular Python.

  • Relative imports can only import dictionaries from config files. They are simply syntactic sugar for LazyConfig.load_rel. They can load Python files at a relative path without requiring __init__.py.

LazyConfig.save can save a config object to YAML. Note that this does not always succeed if non-serializable objects (e.g. lambdas) appear in the config file. It is up to users whether to sacrifice the ability to save in exchange for flexibility.
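A small sketch of why a lambda blocks saving: the hypothetical helper below walks a nested config and reports the paths of values that are not plain scalars or containers. It is deliberately conservative and is not how LazyConfig.save decides what it can write; it only makes the limitation visible before attempting a save.

```python
def find_unserializable(node, path="cfg"):
    """Hypothetical helper: list paths of values a YAML/JSON-style dump cannot represent.

    Plain scalars, lists, tuples and dicts are treated as serializable;
    anything else (lambdas, arbitrary objects) is reported by its path.
    """
    if node is None or isinstance(node, (bool, int, float, str)):
        return []
    if isinstance(node, (list, tuple)):
        found = []
        for i, item in enumerate(node):
            found += find_unserializable(item, f"{path}[{i}]")
        return found
    if isinstance(node, dict):
        found = []
        for key, value in node.items():
            found += find_unserializable(value, f"{path}.{key}")
        return found
    return [path]


cfg = {
    "optimizer": {"lr": 0.01, "steps": [60000, 80000]},
    # A lambda is convenient in a Python config but has no YAML representation.
    "train": {"lr_fn": lambda it: 0.01 * (0.1 ** (it // 60000))},
}
print(find_unserializable(cfg))  # ['cfg.train.lr_fn']
```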

Recursive Instantiation

The LazyConfig system makes heavy use of recursive instantiation, a pattern that uses a dictionary to describe a call to a function or class. The dictionary consists of:

  1. A “_target_” key which contains the path to the callable, such as “module.submodule.class_name”.

  2. Other keys that represent arguments to pass to the callable. Arguments themselves can be defined using recursive instantiation.

We provide a helper, LazyCall, to create such dictionaries. The following code, which uses LazyCall,

from detectron2.config import LazyCall as L
from my_app import Trainer, Optimizer
cfg = L(Trainer)(
  optimizer=L(Optimizer)(
    lr=0.01,
    algo="SGD"
  )
)

creates a dictionary like this:

cfg = {
  "_target_": "my_app.Trainer",
  "optimizer": {
    "_target_": "my_app.Optimizer",
    "lr": 0.01, "algo": "SGD"
  }
}
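The mechanism behind LazyCall can be sketched in a few lines: wrapping a callable returns an object whose call records the arguments instead of running anything. LazyCallSketch is a hypothetical stand-in that only mirrors the dictionary shape shown above; the real LazyCall returns omegaconf objects and may store the target differently. The toy Trainer and Optimizer classes stand in for my_app.

```python
class LazyCallSketch:
    """Hypothetical, simplified stand-in for LazyCall.

    LazyCallSketch(f)(**kwargs) records the call as a dict instead of running it.
    """

    def __init__(self, target):
        if not callable(target):
            raise TypeError(f"target must be callable, got {target!r}")
        self.target = target

    def __call__(self, **kwargs):
        # Record the import path of the callable plus its keyword arguments.
        path = f"{self.target.__module__}.{self.target.__qualname__}"
        return {"_target_": path, **kwargs}


class Optimizer:  # toy stand-in for my_app.Optimizer
    def __init__(self, lr, algo):
        self.lr, self.algo = lr, algo


class Trainer:  # toy stand-in for my_app.Trainer
    def __init__(self, optimizer):
        self.optimizer = optimizer


L = LazyCallSketch
cfg = L(Trainer)(optimizer=L(Optimizer)(lr=0.01, algo="SGD"))
print(cfg)  # nested dicts; no Trainer or Optimizer has been constructed
```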

Because objects are represented as such dictionaries, a general instantiate function can turn them into actual objects:

from detectron2.config import instantiate
trainer = instantiate(cfg)
# equivalent to:
# from my_app import Trainer, Optimizer
# trainer = Trainer(optimizer=Optimizer(lr=0.01, algo="SGD"))
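The core of recursive instantiation fits in a short function: instantiate the children first, then call the “_target_” with the resulting keyword arguments. instantiate_sketch is a hypothetical, simplified version over plain dicts and lists; it does not reproduce the real instantiate's behavior in detail. The example uses standard-library classes so that it runs without my_app.

```python
import importlib


def _locate(path):
    # Split "pkg.mod.Name" into module and attribute; enough for top-level names.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_sketch(node):
    """Hypothetical, simplified recursive instantiation over plain dicts/lists."""
    if isinstance(node, list):
        return [instantiate_sketch(x) for x in node]
    if isinstance(node, dict):
        # Children are instantiated first, so arguments arrive as real objects.
        kwargs = {k: instantiate_sketch(v) for k, v in node.items() if k != "_target_"}
        if "_target_" not in node:
            return kwargs  # a plain dict stays a dict
        target = node["_target_"]
        if isinstance(target, str):
            target = _locate(target)  # "module.Name" -> the callable itself
        return target(**kwargs)
    return node  # scalars pass through unchanged


cfg = {
    "_target_": "types.SimpleNamespace",
    "lr": {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 100},
    "steps": [60000, 80000],
}
obj = instantiate_sketch(cfg)
print(obj.lr)  # 1/100
```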

This pattern is powerful enough to describe very complex objects, e.g.:

A full Mask R-CNN described with recursive instantiation:
from detectron2.config import LazyCall as L
from detectron2.layers import ShapeSpec
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
from detectron2.modeling.backbone.fpn import LastLevelMaxPool
from detectron2.modeling.backbone import BasicStem, FPN, ResNet
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.proposal_generator import RPN, StandardRPNHead
from detectron2.modeling.roi_heads import (
    StandardROIHeads,
    FastRCNNOutputLayers,
    MaskRCNNConvUpsampleHead,
    FastRCNNConvFCHead,
)

model = L(GeneralizedRCNN)(
    backbone=L(FPN)(
        bottom_up=L(ResNet)(
            stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
            stages=L(ResNet.make_default_stages)(
                depth=50,
                stride_in_1x1=True,
                norm="FrozenBN",
            ),
            out_features=["res2", "res3", "res4", "res5"],
        ),
        in_features="${.bottom_up.out_features}",
        out_channels=256,
        top_block=L(LastLevelMaxPool)(),
    ),
    proposal_generator=L(RPN)(
        in_features=["p2", "p3", "p4", "p5", "p6"],
        head=L(StandardRPNHead)(in_channels=256, num_anchors=3),
        anchor_generator=L(DefaultAnchorGenerator)(
            sizes=[[32], [64], [128], [256], [512]],
            aspect_ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64],
            offset=0.0,
        ),
        anchor_matcher=L(Matcher)(
            thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True
        ),
        box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]),
        batch_size_per_image=256,
        positive_fraction=0.5,
        pre_nms_topk=(2000, 1000),
        post_nms_topk=(1000, 1000),
        nms_thresh=0.7,
    ),
    roi_heads=L(StandardROIHeads)(
        num_classes=80,
        batch_size_per_image=512,
        positive_fraction=0.25,
        proposal_matcher=L(Matcher)(
            thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False
        ),
        box_in_features=["p2", "p3", "p4", "p5"],
        box_pooler=L(ROIPooler)(
            output_size=7,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        box_head=L(FastRCNNConvFCHead)(
            input_shape=ShapeSpec(channels=256, height=7, width=7),
            conv_dims=[],
            fc_dims=[1024, 1024],
        ),
        box_predictor=L(FastRCNNOutputLayers)(
            input_shape=ShapeSpec(channels=1024),
            test_score_thresh=0.05,
            box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)),
            num_classes="${..num_classes}",
        ),
        mask_in_features=["p2", "p3", "p4", "p5"],
        mask_pooler=L(ROIPooler)(
            output_size=14,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        mask_head=L(MaskRCNNConvUpsampleHead)(
            input_shape=ShapeSpec(channels=256, width=14, height=14),
            num_classes="${..num_classes}",
            conv_dims=[256, 256, 256, 256, 256],
        ),
    ),
    pixel_mean=[103.530, 116.280, 123.675],
    pixel_std=[1.0, 1.0, 1.0],
    input_format="BGR",
)
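The strings "${.bottom_up.out_features}" and "${..num_classes}" above are omegaconf relative interpolations: one leading dot refers to a key in the same dict, and each extra dot moves one level up. So FPN's in_features reuses bottom_up.out_features, and both box_predictor and mask_head pick up roi_heads.num_classes. resolve_relative is a hypothetical, simplified resolver, not omegaconf's implementation; it handles only whole-string relative references.

```python
import re

# "${" + one or more dots + a dotted key path + "}"
_REF = re.compile(r"^\$\{(\.+)([\w.]+)\}$")


def resolve_relative(root, path):
    """Resolve the value stored at `path` (a list of keys) inside `root`.

    One leading dot means "the dict containing this key"; each extra dot
    goes one level up. Non-reference values are returned unchanged.
    """
    node = root
    for key in path:
        node = node[key]
    match = _REF.match(node) if isinstance(node, str) else None
    if match is None:
        return node
    dots, rest = match.groups()
    container = list(path[:-1])  # keys leading to the dict that holds the value
    up = len(dots) - 1
    if up > len(container):
        raise KeyError(f"reference {node!r} goes above the root")
    target = root
    for key in container[: len(container) - up] + rest.split("."):
        target = target[key]
    return target


cfg = {"model": {
    "backbone": {"bottom_up": {"out_features": ["res2", "res3"]},
                 "in_features": "${.bottom_up.out_features}"},
    "roi_heads": {"num_classes": 80,
                  "box_predictor": {"num_classes": "${..num_classes}"}},
}}
print(resolve_relative(cfg, ["model", "backbone", "in_features"]))
print(resolve_relative(cfg, ["model", "roi_heads", "box_predictor", "num_classes"]))
```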

There are also objects or logic that cannot be described simply by a dictionary, such as reused objects or method calls. They may require some refactoring to work with recursive instantiation.
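One way to see why reused objects need care: with a toy builder (build is hypothetical, not detectron2's instantiate), referencing the same sub-config from two places yields two separate objects, because each reference is instantiated independently. If sharing is intended, one possible workaround is to construct the shared object once in code and pass it explicitly.

```python
import itertools

_ids = itertools.count()


class Encoder:
    """Toy module; each construction gets a fresh uid so sharing is visible."""

    def __init__(self, dim):
        self.dim, self.uid = dim, next(_ids)


def build(node):
    # Minimal recursive instantiation over dicts whose "_target_" is a callable.
    if isinstance(node, dict) and "_target_" in node:
        kwargs = {k: build(v) for k, v in node.items() if k != "_target_"}
        return node["_target_"](**kwargs)
    if isinstance(node, dict):
        return {k: build(v) for k, v in node.items()}
    return node


shared = {"_target_": Encoder, "dim": 256}
cfg = {"student": shared, "teacher": shared}  # intent: one shared encoder

objs = build(cfg)
print(objs["student"] is objs["teacher"])  # False: built twice
```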

Using Model Zoo LazyConfigs

We provide some configs in the model zoo using the LazyConfig system, for example:

After installing detectron2, they can be loaded by the model zoo API model_zoo.get_config.

Using these as references, you’re free to define a custom config structure and fields for your own project, as long as your training script can understand them. Despite this, our model zoo configs still follow some simple conventions for consistency, e.g. cfg.model defines a model object, cfg.dataloader.{train,test} defines dataloader objects, and cfg.train contains training options in key-value form. Besides print(), a better way to view the structure of a config is:

from detectron2.model_zoo import get_config
from detectron2.config import LazyConfig
print(LazyConfig.to_py(get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.py")))

From the output it’s easier to find relevant options to change, e.g. dataloader.train.total_batch_size for the batch size, or optimizer.lr for the base learning rate.
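Dotted names such as dataloader.train.total_batch_size are just the path of keys through the nested config. The hypothetical flatten helper below lists every leaf under its dotted path, which is another way to browse a config's options. The toy config only mimics the model zoo's top-level conventions; its keys and values are illustrative, not actual defaults.

```python
def flatten(cfg, prefix=""):
    """Hypothetical helper: map nested dict keys to dotted paths like 'optimizer.lr'."""
    flat = {}
    for key, value in cfg.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, path))  # recurse into sub-configs
        else:
            flat[path] = value
    return flat


# Toy config; keys and values are illustrative only.
cfg = {
    "dataloader": {"train": {"total_batch_size": 16}, "test": {"num_workers": 4}},
    "optimizer": {"lr": 0.02},
}
for path, value in sorted(flatten(cfg).items()):
    print(path, "=", value)
```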

We provide a reference training script, tools/lazyconfig_train_net.py, which can train and evaluate our model zoo configs. It also shows how to support command-line value overrides.
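As a rough illustration of command-line overrides, assuming they take the form "dotted.key=value", the hypothetical helper below applies such strings to a nested dict. It is not the implementation used by tools/lazyconfig_train_net.py: values are parsed with ast.literal_eval when possible and otherwise kept as strings.

```python
import ast


def apply_overrides_sketch(cfg, overrides):
    """Hypothetical helper: apply "a.b.c=value" strings to a nested dict in place.

    Values are parsed as Python literals when possible (numbers, lists, ...),
    otherwise kept as plain strings. Missing parent keys raise KeyError.
    """
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # e.g. Adam given without quotes
        *parents, last = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]
        node[last] = value
    return cfg


cfg = {"optimizer": {"lr": 0.02, "algo": "SGD"}, "train": {"max_iter": 90000}}
apply_overrides_sketch(cfg, ["optimizer.lr=0.01", "train.max_iter=1000", "optimizer.algo=Adam"])
print(cfg)
```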

To demonstrate the power and flexibility of the new system, we show that a simple config file can let detectron2 train an ImageNet classification model from torchvision, even though detectron2 itself contains no ImageNet classification features. This can serve as a reference for using detectron2 in other deep learning tasks.

Summary

By using recursive instantiation to create objects, we avoid passing a giant config to many places, because cfg is only passed to instantiate. This has the following benefits:

  • It’s non-intrusive: objects to be constructed are config-agnostic, regular Python functions/classes. They can even live in other libraries. For example, {"_target_": "torch.nn.Conv2d", "in_channels": 10, "out_channels": 10, "kernel_size": 1} defines a conv layer.

  • Clarity about which functions/classes will be called and which arguments they use.

  • cfg doesn’t need pre-defined keys and structures. It’s valid as long as it translates to valid code. This gives a lot more flexibility.

  • You can still pass huge dictionaries as arguments, just like the old way.

Recursive instantiation and Python syntax are orthogonal: you can use one without the other. But by putting them together, the config file looks a lot like the code that will be executed:

[Figure: a LazyConfig file shown next to the code it corresponds to]

However, the config file only defines dictionaries, which can easily be manipulated further by composition or overrides. The corresponding code is executed only later, when instantiate is called. In a sense, config files contain “editable code” that will be “lazily executed” when needed. That’s why we call this system “LazyConfig”.
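To make the "lazy" part concrete, here is a toy example with hypothetical names (Model and instantiate_toy are not detectron2 API). Defining and editing the config runs no model code; the constructor runs only when the config is instantiated, so the override is what the model actually sees.

```python
calls = []


class Model:
    def __init__(self, depth):
        calls.append(depth)  # record when the constructor actually runs
        self.depth = depth


def instantiate_toy(node):
    # Minimal, non-recursive stand-in: call "_target_" with the remaining keys.
    kwargs = {k: v for k, v in node.items() if k != "_target_"}
    return node["_target_"](**kwargs)


cfg = {"_target_": Model, "depth": 50}  # defining the config runs no model code
cfg["depth"] = 101                      # "editable code": override before execution
assert calls == []                      # nothing has been constructed yet

model = instantiate_toy(cfg)            # the constructor runs only now
print(calls, model.depth)               # [101] 101
```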