Source code for fvcore.nn.focal_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
from torch.nn import functional as F


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example; a sigmoid is
                applied to them inside this function.
        targets: A tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
                 Integer or bool labels are accepted and are cast to the
                 dtype of ``inputs``; floating-point targets are used as-is.
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Any negative value disables
               the weighting. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
                 Any other value leaves the loss unreduced, like 'none'.
    Returns:
        Loss tensor with the reduction option applied.
    """
    # binary_cross_entropy_with_logits rejects integer/bool targets, so
    # promote hard labels to the prediction dtype. Floating-point targets are
    # deliberately left untouched to keep existing results and dtypes.
    if not targets.is_floating_point():
        targets = targets.to(inputs.dtype)

    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability the model assigns to the labelled class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Down-weight elements whose labelled class already has high probability.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # alpha for positive elements, (1 - alpha) for negative ones.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
# TorchScript-compiled variant of sigmoid_focal_loss, built once at import time.
# NOTE(review): torch.jit.script applied to a plain function is documented to
# return a ScriptFunction; the original ScriptModule annotation is kept
# unchanged here -- confirm before relying on it for type checking.
sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss)
def sigmoid_focal_loss_star(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 1,
    reduction: str = "none",
) -> torch.Tensor:
    """
    FL* described in RetinaNet paper Appendix: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
                 Integer or bool labels are accepted and are cast to the
                 dtype of ``inputs``; floating-point targets are used as-is.
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Any negative value disables
               the weighting. Default = -1 (no weighting).
        gamma: Gamma parameter described in FL*. The loss is divided by
               gamma, so it must be non-zero. Default = 1 (no weighting).
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
                 Any other value leaves the loss unreduced, like 'none'.
    Returns:
        Loss tensor with the reduction option applied.
    """
    # Bool tensors do not support `1 - targets`, and integer labels would make
    # alpha_t fall back to the default float dtype; promote hard labels to the
    # prediction dtype. Floating-point targets are left untouched.
    if not targets.is_floating_point():
        targets = targets.to(inputs.dtype)

    # (2 * targets - 1) maps labels {0, 1} to signs {-1, +1}, so the logit is
    # flipped for negative elements before being scaled by gamma.
    shifted_inputs = gamma * (inputs * (2 * targets - 1))
    loss = -(F.logsigmoid(shifted_inputs)) / gamma

    if alpha >= 0:
        # alpha for positive elements, (1 - alpha) for negative ones.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        # Out-of-place multiply, matching sigmoid_focal_loss.
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
# TorchScript-compiled variant of sigmoid_focal_loss_star, built once at import time.
# NOTE(review): torch.jit.script applied to a plain function is documented to
# return a ScriptFunction; the original ScriptModule annotation is kept
# unchanged here -- confirm before relying on it for type checking.
sigmoid_focal_loss_star_jit: "torch.jit.ScriptModule" = torch.jit.script(
    sigmoid_focal_loss_star
)