Benchmarks

Here we benchmark the training speed of a Mask R-CNN in detectron2 and compare it with some other popular open source Mask R-CNN implementations.

Settings

  • Hardware: 8 NVIDIA V100s with NVLink.

  • Software: Python 3.7, CUDA 10.1, cuDNN 7.6.5, PyTorch 1.5, TensorFlow 1.15.0rc2, Keras 2.2.5, MXNet 1.6.0b20190820.

  • Model: an end-to-end R-50-FPN Mask R-CNN model, using the same hyperparameters as the Detectron baseline config (which does not use scale augmentation).

  • Metrics: We use the average throughput in iterations 100-500 to skip GPU warmup time. Note that for R-CNN-style models, the throughput of a model typically changes during training, because it depends on the predictions of the model. Therefore, this metric is not directly comparable with the “train speed” reported in the model zoo, which is the average speed over the entire training run.
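To make the metric concrete, here is a minimal sketch of how an average throughput over iterations 100-500 could be computed from per-iteration timestamps. It is not the script used to produce the numbers on this page: the function name `average_throughput` is hypothetical, `iter_end_times[i]` is assumed to be the wall-clock time at which iteration `i` finished, and the window is assumed to run from the end of iteration 100 to the end of iteration 500. The 16 images per iteration in the example (2 images per GPU on 8 GPUs) is also an assumption.

```python
def average_throughput(iter_end_times, images_per_iter, start=100, end=500):
    """Return the average throughput (images/second) between two iterations.

    iter_end_times[i] is the wall-clock time, in seconds, at which iteration i
    finished. Everything up to iteration `start` is treated as warmup and ignored.
    """
    if not 0 <= start < end < len(iter_end_times):
        raise ValueError("iter_end_times must cover iterations start..end")
    # (end - start) iterations complete between these two timestamps.
    elapsed = iter_end_times[end] - iter_end_times[start]
    return (end - start) * images_per_iter / elapsed


# Example: a slow warmup (2 s/iter) followed by a steady 0.25 s/iter.
times = [2.0 * i if i <= 100 else 200.0 + 0.25 * (i - 100) for i in range(600)]
print(average_throughput(times, images_per_iter=16))  # prints 64.0
```

The slow warmup iterations do not affect the result, because only timestamps at or after iteration 100 enter the calculation.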

Main Results

Implementation          Framework    Throughput (img/s)
Detectron2              PyTorch      62
mmdetection             PyTorch      53
maskrcnn-benchmark      PyTorch      53
tensorpack              TensorFlow   50
SimpleDet               MXNet        39
Detectron               Caffe2       19
matterport/Mask_RCNN    TensorFlow   14
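To relate these throughputs to per-iteration timings, the sketch below converts each one into seconds per iteration and into a ratio relative to Detectron2. It assumes 16 images per iteration for every implementation (2 images per GPU on 8 GPUs, matching the IMAGES_PER_GPU = 2 and GPU_COUNT = 8 settings in the matterport/Mask_RCNN diff); the names `THROUGHPUT`, `IMAGES_PER_ITER`, and `summarize` are illustrative only. The caveat from the Metrics item still applies: these figures describe iterations 100-500, not a whole training run.

```python
# Throughput in images/second, copied from the table above.
THROUGHPUT = {
    "Detectron2": 62,
    "mmdetection": 53,
    "maskrcnn-benchmark": 53,
    "tensorpack": 50,
    "SimpleDet": 39,
    "Detectron": 19,
    "matterport/Mask_RCNN": 14,
}

# Assumption: 16 images per iteration (2 images per GPU on 8 GPUs).
IMAGES_PER_ITER = 16


def summarize(throughput, images_per_iter=IMAGES_PER_ITER, reference="Detectron2"):
    """Return (name, seconds per iteration, throughput relative to `reference`) rows."""
    ref = throughput[reference]
    rows = []
    for name, img_per_s in throughput.items():
        sec_per_iter = images_per_iter / img_per_s
        rows.append((name, round(sec_per_iter, 3), round(img_per_s / ref, 2)))
    return rows


for name, sec_per_iter, rel in summarize(THROUGHPUT):
    print(f"{name:22s} {sec_per_iter:.3f} s/iter  {rel:.2f}x Detectron2")
```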

Details for each implementation:

  • Detectron2: with release v0.1.2, run:

    python tools/train_net.py --config-file configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml --num-gpus 8
    
  • mmdetection: at commit b0d845f, run:

    ./tools/dist_train.sh configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py 8
    
  • maskrcnn-benchmark: use commit 0ce8f6f and make it compatible with PyTorch 1.5 by running:

    sed -i 's/torch.uint8/torch.bool/g' **/*.py; sed -i 's/AT_CHECK/TORCH_CHECK/g' **/*.cu

    Then, run training with:

    python -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/e2e_mask_rcnn_R_50_FPN_1x.yaml
    

    The speed we observed is faster than the speed reported in its model zoo, likely due to different software versions.

  • tensorpack: at commit caafda, export TF_CUDNN_USE_AUTOTUNE=0, then run:

    mpirun -np 8 ./train.py --config DATA.BASEDIR=/data/coco TRAINER=horovod BACKBONE.STRIDE_1X1=True TRAIN.STEPS_PER_EPOCH=50 --load ImageNet-R50-AlignPadding.npz
    
  • SimpleDet: at commit 9187a1, run:

    python detection_train.py --config config/mask_r50v1_fpn_1x.py
    
  • Detectron: run:

    python tools/train_net.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml
    

    Note that many of its ops run on CPUs, which limits its performance.

  • matterport/Mask_RCNN: at commit 3deaec, apply the following diff, export TF_CUDNN_USE_AUTOTUNE=0, then run:

    python coco.py train --dataset=/data/coco/ --model=imagenet
    

    Note that many small details in this implementation may differ from Detectron’s standards.

    Diff to make it use the same hyperparameters:
    diff --git i/mrcnn/model.py w/mrcnn/model.py
    index 62cb2b0..61d7779 100644
    --- i/mrcnn/model.py
    +++ w/mrcnn/model.py
    @@ -2367,8 +2367,8 @@ class MaskRCNN():
                 epochs=epochs,
                 steps_per_epoch=self.config.STEPS_PER_EPOCH,
                 callbacks=callbacks,
    -            validation_data=val_generator,
    -            validation_steps=self.config.VALIDATION_STEPS,
    +            #validation_data=val_generator,
    +            #validation_steps=self.config.VALIDATION_STEPS,
                 max_queue_size=100,
                 workers=workers,
                 use_multiprocessing=True,
    diff --git i/mrcnn/parallel_model.py w/mrcnn/parallel_model.py
    index d2bf53b..060172a 100644
    --- i/mrcnn/parallel_model.py
    +++ w/mrcnn/parallel_model.py
    @@ -32,6 +32,7 @@ class ParallelModel(KM.Model):
             keras_model: The Keras model to parallelize
             gpu_count: Number of GPUs. Must be > 1
             """
    +        super().__init__()
             self.inner_model = keras_model
             self.gpu_count = gpu_count
             merged_outputs = self.make_parallel()
    diff --git i/samples/coco/coco.py w/samples/coco/coco.py
    index 5d172b5..239ed75 100644
    --- i/samples/coco/coco.py
    +++ w/samples/coco/coco.py
    @@ -81,7 +81,10 @@ class CocoConfig(Config):
         IMAGES_PER_GPU = 2
     
         # Uncomment to train on 8 GPUs (default is 1)
    -    # GPU_COUNT = 8
    +    GPU_COUNT = 8
    +    BACKBONE = "resnet50"
    +    STEPS_PER_EPOCH = 50
    +    TRAIN_ROIS_PER_IMAGE = 512
     
         # Number of classes (including background)
         NUM_CLASSES = 1 + 80  # COCO has 80 classes
    @@ -496,29 +499,10 @@ if __name__ == '__main__':
             # *** This training schedule is an example. Update to your needs ***
     
             # Training - Stage 1
    -        print("Training network heads")
             model.train(dataset_train, dataset_val,
                         learning_rate=config.LEARNING_RATE,
                         epochs=40,
    -                    layers='heads',
    -                    augmentation=augmentation)
    -
    -        # Training - Stage 2
    -        # Finetune layers from ResNet stage 4 and up
    -        print("Fine tune Resnet stage 4 and up")
    -        model.train(dataset_train, dataset_val,
    -                    learning_rate=config.LEARNING_RATE,
    -                    epochs=120,
    -                    layers='4+',
    -                    augmentation=augmentation)
    -
    -        # Training - Stage 3
    -        # Fine tune all layers
    -        print("Fine tune all layers")
    -        model.train(dataset_train, dataset_val,
    -                    learning_rate=config.LEARNING_RATE / 10,
    -                    epochs=160,
    -                    layers='all',
    +                    layers='3+',
                         augmentation=augmentation)
     
         elif args.command == "evaluate":
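The two sed commands in the maskrcnn-benchmark step rely on the shell expanding `**/*.py` and `**/*.cu`; in bash, `**` only recurses into subdirectories when the `globstar` option is enabled, and otherwise behaves like `*`. As an alternative, the sketch below applies the same two substitutions with Python's `pathlib`, recursing through the whole tree. It is not part of the original benchmark instructions: the function name `patch_tree` is hypothetical, and it performs literal string replacement, whereas sed treats the `.` in `torch.uint8` as a regex wildcard.

```python
from pathlib import Path

# (old text, new text, file glob) triples mirroring the two sed commands.
REPLACEMENTS = [
    ("torch.uint8", "torch.bool", "*.py"),
    ("AT_CHECK", "TORCH_CHECK", "*.cu"),
]


def patch_tree(root):
    """Apply the literal substitutions to every matching file under `root`.

    Returns the list of files that were actually modified.
    """
    changed = []
    for old, new, pattern in REPLACEMENTS:
        # rglob recurses into all subdirectories, like `**` with bash globstar.
        for path in sorted(Path(root).rglob(pattern)):
            if not path.is_file():
                continue
            text = path.read_text(encoding="utf-8")
            if old in text:
                path.write_text(text.replace(old, new), encoding="utf-8")
                changed.append(path)
    return changed
```

For example, calling `patch_tree(".")` from the root of a maskrcnn-benchmark checkout should apply both edits; files that contain neither pattern are left untouched.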