Source code for detectron2.data.samplers.distributed_sampler

# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import math
from collections import defaultdict
from typing import Optional
import torch
from torch.utils.data.sampler import Sampler

from detectron2.utils import comm

logger = logging.getLogger(__name__)


class TrainingSampler(Sampler):
    """
    Endless index sampler for (distributed) training.

    Training only consumes an "infinite stream" of data, so this sampler never
    stops yielding. All workers build the very same stream from a shared seed
    and each one keeps only its own strided slice, i.e. a worker effectively
    produces `indices[worker_id::num_workers]`, where `indices` is
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False).

    Note that this sampler does not shard based on pytorch DataLoader worker id.
    A sampler passed to pytorch DataLoader is used only with map-style dataset
    and will not be executed inside workers.
    But if this sampler is used in a way that it gets executed inside a
    dataloader worker, then extra work needs to be done to shard its outputs
    based on worker id. This is required so that workers don't produce
    identical data. :class:`ToIterableDataset` implements this logic.
    This note is true for all samplers in detectron2.
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).

        Raises:
            TypeError: if `size` is not an int.
            ValueError: if `size` is not positive.
        """
        if not isinstance(size, int):
            raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
        if size <= 0:
            raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")

        self._size = size
        self._shuffle = shuffle
        # Only ask the other workers for a seed when the caller gave none.
        self._seed = int(comm.shared_random_seed() if seed is None else seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        # Every worker walks the same stream; the strided slice gives each
        # rank its own disjoint share of it.
        stream = self._infinite_indices()
        yield from itertools.islice(stream, self._rank, None, self._world_size)

    def _infinite_indices(self):
        """Yield dataset indices forever, one pass over the dataset at a time."""
        rng = torch.Generator()
        rng.manual_seed(self._seed)
        while True:
            if not self._shuffle:
                epoch = torch.arange(self._size)
            else:
                epoch = torch.randperm(self._size, generator=rng)
            yield from epoch.tolist()
class RandomSubsetTrainingSampler(TrainingSampler):
    """
    A TrainingSampler restricted to a fixed, randomly chosen subset of indices.

    This is useful when you want to estimate the accuracy vs data-number curves
    by training the model with different subset_ratio.
    """

    def __init__(
        self,
        size: int,
        subset_ratio: float,
        shuffle: bool = True,
        seed_shuffle: Optional[int] = None,
        seed_subset: Optional[int] = None,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            subset_ratio (float): the ratio of subset data to sample from the underlying dataset;
                must lie in (0, 1] and keep at least one element after truncation
            shuffle (bool): whether to shuffle the indices or not
            seed_shuffle (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
            seed_subset (int): the seed to randomize the subset to be sampled.
                Must be the same across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle)

        assert 0.0 < subset_ratio <= 1.0
        # The subset size is the product truncated towards zero.
        self._size_subset = int(size * subset_ratio)
        assert self._size_subset > 0

        if seed_subset is None:
            seed_subset = comm.shared_random_seed()
        self._seed_subset = int(seed_subset)

        # Pick the subset once: the head of a seeded permutation of all indices.
        subset_rng = torch.Generator()
        subset_rng.manual_seed(self._seed_subset)
        permuted_all = torch.randperm(self._size, generator=subset_rng)
        self._indexes_subset = permuted_all[: self._size_subset]

        logger.info("Using RandomSubsetTrainingSampler......")
        logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data")

    def _infinite_indices(self):
        """Yield indices of the chosen subset forever, one pass at a time."""
        rng = torch.Generator()
        rng.manual_seed(self._seed)  # self._seed holds seed_shuffle, set by the parent __init__()
        while True:
            if not self._shuffle:
                yield from self._indexes_subset.tolist()
                continue
            # Reorder the fixed subset with a fresh permutation for this pass.
            order = torch.randperm(self._size_subset, generator=rng)
            yield from self._indexes_subset[order].tolist()
class RepeatFactorTrainingSampler(Sampler):
    """
    Similar to TrainingSampler, but a sample may appear more times than others based
    on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS.
    """

    def __init__(self, repeat_factors, *, shuffle=True, seed=None):
        """
        Args:
            repeat_factors (Tensor): a float vector, the repeat factor for each index.
                When it's full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(repeat_factors)
        self._frac_part = repeat_factors - self._int_part

    @staticmethod
    def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
        """
        Compute (fractional) per-image repeat factors based on category frequency.
        The repeat factor for an image is a function of the frequency of the rarest
        category labeled in that image. The "frequency of category c" in [0, 1] is defined
        as the fraction of images in the training set (without repeats) in which category c
        appears.
        See :paper:`lvis` (>= v2) Appendix B.2.

        Args:
            dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
                Each dict is read through ``d["annotations"]`` and each annotation
                through ``ann["category_id"]``.
            repeat_thresh (float): frequency threshold below which data is repeated.
                If the frequency is half of `repeat_thresh`, the image will be
                repeated twice.

        Returns:
            torch.Tensor:
                a float32 tensor whose i-th element is the repeat factor for the
                dataset image at index i. Images without annotations get 1.0.
        """
        # Distinct categories of every image, computed once and reused in
        # steps 1 and 3.
        image_cat_ids = [
            {ann["category_id"] for ann in dataset_dict["annotations"]}
            for dataset_dict in dataset_dicts
        ]

        # 1. For each category c, count the images (without repeats) that contain
        #    it; dividing by the number of images gives the frequency f(c).
        category_count = defaultdict(int)
        for cat_ids in image_cat_ids:
            for cat_id in cat_ids:
                category_count[cat_id] += 1
        num_images = len(dataset_dicts)

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t / f(c)))
        category_rep = {
            cat_id: max(1.0, math.sqrt(repeat_thresh / (count / num_images)))
            for cat_id, count in category_count.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        #    r(I) = max_{c in I} r(c)
        rep_factors = [
            max((category_rep[cat_id] for cat_id in cat_ids), default=1.0)
            for cat_ids in image_cat_ids
        ]

        return torch.tensor(rep_factors, dtype=torch.float32)

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: int64 vector of dataset indices to use in one epoch, in
                ascending order. Each index is repeated based on its calculated
                repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()

        # Repeat every dataset index by its rounded factor in one vectorized
        # call. Casting to int64 truncates towards zero, and clamping at zero
        # makes a negative factor contribute no copies.
        repeats = rep_factors.to(torch.int64).clamp_(min=0)
        dataset_indices = torch.arange(len(repeats), dtype=torch.int64)
        return torch.repeat_interleave(dataset_indices, repeats)

    def __iter__(self):
        # All workers generate the same stream; each keeps its own strided slice.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield dataset indices forever, one stochastically rounded epoch at a time."""
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            # Sample indices with repeats determined by stochastic rounding; each
            # "epoch" may have a slightly different size due to the rounding.
            indices = self._get_epoch_indices(g)
            if self._shuffle:
                randperm = torch.randperm(len(indices), generator=g)
                yield from indices[randperm].tolist()
            else:
                yield from indices.tolist()
class InferenceSampler(Sampler):
    """
    Produce indices for inference across all workers.

    Inference needs to run on the __exact__ set of samples,
    therefore when the total number of samples is not divisible by the number of workers,
    this sampler produces different number of samples on different workers.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self._local_indices = self._get_local_indices(
            total_size=size, world_size=self._world_size, rank=self._rank
        )

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        """
        Return the contiguous ``range`` of indices assigned to worker `rank`.

        The first ``total_size % world_size`` workers each receive one more
        index than the remaining workers.
        """
        base, remainder = divmod(total_size, world_size)
        per_worker = [base + 1 if r < remainder else base for r in range(world_size)]

        first = sum(per_worker[:rank])
        last = sum(per_worker[: rank + 1])
        if last > total_size:
            last = total_size
        return range(first, last)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)