Shortcuts

Source code for functional.losses

""" Module containing segmentation losses """
import abc
from typing import Literal, Optional, Tuple

from monai.losses.dice import (
    DiceLoss as MonaiDiceLoss,
    GeneralizedDiceLoss as MonaiGeneralizedDiceLoss,
)
import torch

from .utils import mask_tensor, one_hot_encode


class SegmentationLoss(torch.nn.Module, abc.ABC):
    r"""Base class for implementations of segmentation losses.

    Provides shared functionality for the concrete losses in this module: input
    preprocessing (one-hot encoding, masking of ignored voxels), flattening of
    spatial dimensions, and uniform reduction behaviour.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            calculation (default = `True`).
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"`. `"none"`: no reduction will be applied, `"mean"`: the mean
            of the output is taken, `"sum"`: the output will be summed (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).

    Raises:
        ValueError: If `reduction` is a non-empty string other than `"mean"`, `"sum"` or `"none"`.
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.include_background = include_background
        # NOTE: falsy values (e.g. `None` or `""`) pass through unvalidated; this is
        # kept for backward compatibility with existing callers
        if reduction and reduction not in ["mean", "sum", "none"]:
            raise ValueError("Invalid reduction method.")
        self.reduction = reduction
        self.epsilon = epsilon

    def _reduce_loss(
        self, loss: torch.Tensor, weight: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""
        Aggregates the loss values of the different classes of an image as well as the different images of a batch.

        Args:
            loss (Tensor): Loss to be aggregated.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: Aggregated loss value.

        Shape:
            - Loss: :math:`(N, C)`, where `N = batch size`, and `C = number of classes`, or :math:`(N)` for binary
              segmentation tasks.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, same shape as the input loss. Otherwise, scalar.
        """
        if weight is not None:
            if loss.ndim > 1:
                # append trailing singleton dimensions so the per-image weights
                # broadcast over the class dimension(s)
                for _ in range(1, loss.ndim):
                    # FIX: use the canonical `dim` keyword; the NumPy-style `axis`
                    # keyword is not supported by `Tensor.unsqueeze` in all PyTorch
                    # versions
                    weight = weight.unsqueeze(dim=-1)
            loss = weight * loss
        # aggregate loss values for all class channels and the entire batch
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

    @staticmethod
    def _flatten_tensor(tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Flattens a tensor except for its first two dimensions (batch dimension and class dimension).

        Args:
            tensor (Tensor): The tensor to be flattened.

        Returns:
            Tensor: Flattened (and float-cast) view of the input tensor.

        Shape:
            - Tensor: :math:`(N, C, X, Y, ...)` where `N = batch size`, and `C = number of classes`.
            - Output: :math:`(N, C, X*Y*...)`
        """
        return tensor.contiguous().view(*tensor.shape[0:2], -1).float()

    def _preprocess_inputs(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        This method implements preprocessing steps that are needed for most segmentation losses:

            1. Conversion from label encoding to one-hot encoding if necessary
            2. Mapping of pixels/voxels labeled with the :attr:`ignore_index` to true negatives

        Args:
            prediction (Tensor): The prediction tensor to be preprocessed.
            target (Tensor): The target tensor to be preprocessed.

        Returns:
            Tuple[Tensor, Tensor]: The preprocessed prediction and target tensors.

        Shape:
            - Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value is
              in :math:`[0, 1]`.
            - Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              encoding or :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              multi-hot encoding.
            - Output: :math:`(N, C, X, Y, ...)` for both prediction and target.
        """
        # during one-hot encoding the ignore index is removed, therefore the original target including the
        # ignore index is kept (the clone protects the caller's tensor from in-place modification)
        target_including_ignore_index = target
        target = target.clone()

        if prediction.dim() != target.dim():
            # label-encoded target: must have exactly one dimension less than the prediction
            assert prediction.dim() == target.dim() + 1
            target = one_hot_encode(
                target,
                target.dim() - 1,
                prediction.shape[1],
                ignore_index=self.ignore_index,
            )
            # add a class dimension so the ignore mask broadcasts against the one-hot target
            target_including_ignore_index = target_including_ignore_index.unsqueeze(
                dim=1
            )

        # map ignored voxels to true negatives in both tensors
        prediction = mask_tensor(
            prediction, target_including_ignore_index, self.ignore_index
        )
        target = mask_tensor(target, target_including_ignore_index, self.ignore_index)

        return prediction, target
class AbstractDiceLoss(SegmentationLoss, abc.ABC):
    r"""Shared scaffolding for the Dice-style losses (Dice loss and Generalized Dice loss).

    Subclasses supply the concrete loss module via :meth:`get_dice_loss_module`;
    this class handles input validation and preprocessing.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            calculation (default = `True`).
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        super().__init__(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )

    @abc.abstractmethod
    def get_dice_loss_module(self) -> torch.nn.Module:
        """
        Returns:
            Module: The concrete Dice loss module used by :meth:`forward`.
        """

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.

        Returns:
            Tensor: Dice loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding or :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N, C)`. Otherwise, scalar.
        """
        # target is either shaped like the prediction (one-/multi-hot) or label
        # encoded with exactly one dimension less
        assert (
            prediction.shape == target.shape or prediction.dim() == target.dim() + 1
        )

        prediction, target = self._preprocess_inputs(prediction, target)
        return self.get_dice_loss_module()(prediction, target)
class DiceLoss(AbstractDiceLoss):
    r"""Implementation of the Dice loss for segmentation tasks.

    The Dice loss for binary segmentation tasks originally was formulated in:

        `Fausto Milletari, Nassir Navab, and Seyed-Ahmad Ahmadi. V-net: Fully convolutional neural networks for
        volumetric medical image segmentation, 2016 <https://arxiv.org/pdf/1606.04797.pdf>`_.

    In this implementation an adapted version of the Dice loss is used that includes an epsilon term
    :math:`\epsilon` to avoid divisions by zero and does not square the terms in the denominator. Additionally, the
    loss formulation is generalized to multi-class classification tasks by averaging dice losses over the class and
    batch dimension:

    .. math::
        DL = 1 - \frac{1}{N \cdot L} \cdot \sum_{n=1}^N \sum_{l=1}^L
        \frac{2 \cdot \sum_{i} r_{nli} p_{nli} + \epsilon}{\sum_{i} (r_{nli} + p_{nli}) + \epsilon}

    where :math:`N` is the batch size, :math:`L` is the number of classes, :math:`r_{nli}` are the ground-truth labels
    for class :math:`l` in the :math:`i`-th voxel of the :math:`n`-th image. Analogously, :math:`p` is the predicted
    probability for class :math:`l` in the :math:`i`-th voxel of the :math:`n`-th image.

    This implementation is a wrapper of the Dice loss implementation from the `MONAI package
    <https://docs.monai.io/en/stable/losses.html#diceloss>`_ that adapts the reduction behaviour. It supports both
    single-label and multi-label segmentation tasks. For a discussion on different dice loss implementations, see
    https://github.com/pytorch/pytorch/issues/1249.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            calculation (default = `True`).
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        # FIX: forward `include_background` to the base class as well, so that
        # `self.include_background` reflects the constructor argument instead of
        # silently staying at its default
        super().__init__(
            epsilon=epsilon,
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
        )
        # reduction is performed by `_reduce_loss`, hence "none" here
        self.dice_loss = MonaiDiceLoss(
            include_background=include_background,
            reduction="none",
            smooth_nr=epsilon,
            smooth_dr=epsilon,
        )

    def get_dice_loss_module(self) -> torch.nn.Module:
        """
        Returns:
            Module: Dice loss module.
        """
        return self.dice_loss

    # pylint: disable=arguments-differ
    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: Dice loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N, C)`. Otherwise, scalar.
        """
        dice_loss = super().forward(prediction, target)

        # the MONAI Dice loss implementation returns a loss tensor of shape `(N, C, X, Y, ...)` when reduction is
        # set to "none"; since the spatial dimensions only contain a single element, they are squeezed here
        dice_loss = dice_loss.reshape((dice_loss.shape[0], dice_loss.shape[1]))

        return self._reduce_loss(dice_loss, weight=weight)
class GeneralizedDiceLoss(AbstractDiceLoss):
    r"""Implementation of Generalized Dice loss for segmentation tasks.

    The Generalized Dice loss was formulated in:

        `Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M. Jorge Cardoso: Generalised Dice overlap
        as a deep learning loss function for highly unbalanced segmentations, 2017.
        <https://arxiv.org/pdf/1707.03237.pdf>`_

    It is formulated as:

    .. math::
        GDL = \frac{1}{N} \cdot \sum_{n=1}^N (1 - 2 \frac{\sum_{l=1}^L w_l \cdot \sum_{i} r_{nli} p_{nli} +
        \epsilon}{\sum_{l=1}^L w_l \cdot \sum_{i} (r_{nli} + p_{nli}) + \epsilon})

    where :math:`N` is the batch size, :math:`L` is the number of classes, :math:`w_l` is a class weight,
    :math:`r_{nli}` are the ground-truth labels for class :math:`l` in the :math:`i`-th voxel of the :math:`n`-th
    image. Analogously, :math:`p` is the predicted probability for class :math:`l` in the :math:`i`-th voxel of the
    :math:`n`-th image.

    This implementation is a wrapper of the Generalized Dice loss implementation from the `MONAI package
    <https://docs.monai.io/en/stable/losses.html#generalizeddiceloss>`_ that adapts the reduction behaviour. It
    supports both single-label and multi-label segmentation tasks.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            calculation (default = `True`).
        weight_type (string, optional): Type of function to transform ground truth volume to a weight factor:
            `"square"` | `"simple"` | `"uniform"`. Defaults to `"square"`.
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        weight_type: Literal["square", "simple", "uniform"] = "square",
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        # FIX: forward `include_background` to the base class as well, so that
        # `self.include_background` reflects the constructor argument instead of
        # silently staying at its default
        super().__init__(
            epsilon=epsilon,
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
        )
        # reduction is performed by `_reduce_loss`, hence "none" here
        self.generalized_dice_loss = MonaiGeneralizedDiceLoss(
            include_background=include_background,
            w_type=weight_type,
            reduction="none",
            smooth_nr=epsilon,
            smooth_dr=epsilon,
        )

    def get_dice_loss_module(self) -> torch.nn.Module:
        """
        Returns:
            Module: Dice loss module.
        """
        return self.generalized_dice_loss

    # pylint: disable=arguments-differ
    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: Generalized dice loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N)`. Otherwise, scalar.
        """
        dice_loss = super().forward(prediction, target)

        # the MONAI Generalized Dice loss returns a loss tensor of shape `(N, C, X, Y, ...)` when reduction is
        # set to "none"; since the class dimension and the spatial dimensions only contain a single element, they
        # are squeezed here
        dice_loss = dice_loss.reshape(dice_loss.shape[0])

        return self._reduce_loss(dice_loss, weight=weight)
class FalsePositiveLoss(SegmentationLoss):
    """Loss term that penalizes the predicted probability mass placed on negative voxels.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            calculation (default = `True`).
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        super().__init__(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: False positive loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N, C)`. Otherwise, scalar.
        """
        assert prediction.shape == target.shape or prediction.dim() == target.dim() + 1
        assert prediction.isfinite().all() and target.isfinite().all()

        prediction, target = self._preprocess_inputs(prediction, target)

        if not self.include_background:
            # drop the background channel (class index 0) from both tensors
            prediction, target = prediction[:, 1:], target[:, 1:]

        # collapse the spatial dimensions so sums run over a single axis
        prediction_flat = self._flatten_tensor(prediction)
        target_flat = self._flatten_tensor(target)

        # predicted probability mass assigned to voxels that are negative in the target
        false_positives = (prediction_flat * (1 - target_flat)).sum(dim=-1)
        predicted_positives = prediction_flat.sum(dim=-1)

        # in contrast to the Dice loss, the epsilon term is only added to the denominator:
        # if there are no positive predictions at all, the loss is exactly zero
        fp_loss = false_positives / (predicted_positives + self.epsilon)

        return self._reduce_loss(fp_loss, weight=weight)
class FalsePositiveDiceLoss(SegmentationLoss):
    """Implements a loss function that combines the Dice loss with the false positive loss.

    The two loss terms are computed independently on the same inputs and summed.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            calculation (default = `True`).
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        # FIX: forward `ignore_index` and `include_background` to the base class so
        # that `self.ignore_index` / `self.include_background` reflect the constructor
        # arguments instead of silently staying at their defaults (the sub-losses
        # already received the correct values)
        super().__init__(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )
        self.fp_loss = FalsePositiveLoss(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )
        self.dice_loss = DiceLoss(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: Combined loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N, C)`. Otherwise, scalar.
        """
        return self.fp_loss(prediction, target, weight=weight) + self.dice_loss(
            prediction, target, weight=weight
        )
class CrossEntropyLoss(SegmentationLoss):
    """Wrapper for the PyTorch BCE loss / NLLLoss that unifies the reduction behaviour with the other losses in
    this module.

    Args:
        multi_label (bool, optional): Determines if data is multilabel or not (default = `False`).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        multi_label: bool = False,
        ignore_index: Optional[int] = None,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        super().__init__(
            ignore_index=ignore_index,
            include_background=True,
            reduction=reduction,
            epsilon=epsilon,
        )
        self.multi_label = multi_label
        if multi_label:
            # multi-label: independent binary decision per class channel
            self.cross_entropy_loss = torch.nn.BCELoss(reduction="none")
        else:
            # -100 is PyTorch's default sentinel meaning "ignore nothing"
            self.cross_entropy_loss = torch.nn.NLLLoss(
                ignore_index=-100 if ignore_index is None else ignore_index,
                reduction="none",
            )

    def _compute_loss(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        # hook that subclasses (e.g. the focal loss) override to post-process the
        # per-voxel loss
        return self.cross_entropy_loss(prediction, target)

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: Cross-entropy loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N, C)`. Otherwise, scalar.
        """
        assert prediction.shape == target.shape or prediction.dim() == target.dim() + 1
        assert prediction.isfinite().all() and target.isfinite().all()

        if self.multi_label:
            target = target.float()
        else:
            # the PyTorch NLLLoss expects the output of a LogSoftmax layer, while this
            # loss receives Softmax outputs, so the (smoothed) log is taken here
            prediction = torch.log(prediction + self.epsilon)
            target = target.long()

        loss = self._compute_loss(prediction, target)

        if self.multi_label and self.ignore_index is not None:
            # BCELoss provides no `ignore_index` argument, so the loss values of the
            # voxels to be ignored are zeroed out manually
            loss = mask_tensor(loss, target, self.ignore_index)

        if self.reduction == "mean":
            # images in one batch can have different (padded) sizes; averaging over the
            # spatial dimensions (and the class dimension for multi-label tasks) first
            # weights every image equally regardless of its size
            loss = loss.mean(dim=tuple(range(1, loss.dim())))

        return self._reduce_loss(loss, weight=weight)
class FocalLoss(CrossEntropyLoss):
    """Focal loss: cross-entropy with the contribution of well-classified examples down-weighted.

    Implemented as a wrapper around :class:`CrossEntropyLoss` that rescales the per-voxel cross-entropy values.

    Args:
        multi_label (bool, optional): Determines if data is multilabel or not (default = `False`).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
        gamma (float, optional): Specifies how far the loss of well-classified examples is down-weighed
            (default = `5`).
    """

    def __init__(
        self,
        multi_label: bool = False,
        ignore_index: Optional[int] = None,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
        gamma: float = 5,
    ):
        super().__init__(multi_label, ignore_index, reduction, epsilon)
        self.gamma = gamma

    def _compute_loss(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        cross_entropy = self.cross_entropy_loss(prediction, target)
        # recover the model's probability for the true class from the
        # cross-entropy value: p_t = exp(-CE)
        true_class_probability = torch.exp(-cross_entropy)
        # the modulating factor (1 - p_t)^gamma shrinks the loss of confident,
        # correct predictions
        return cross_entropy * (1 - true_class_probability) ** self.gamma
class CrossEntropyDiceLoss(SegmentationLoss):
    """Implements a loss function that combines the Dice loss with the binary cross-entropy (negative log-likelihood)
    loss.

    The two loss terms are computed independently on the same inputs and summed.

    Args:
        multi_label (bool, optional): Determines if data is multilabel or not (default = `False`).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from the
            Dice loss calculation, but not from the Cross-entropy loss calculation (default = `True`).
        reduction (str, optional): Specifies the reduction to aggregate the loss values over the images of a batch and
            multiple classes: `"none"` | `"mean"` | `"sum"` (default = `"mean"`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = `1e-10`).
    """

    def __init__(
        self,
        multi_label: bool = False,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["mean", "sum", "none"] = "mean",
        epsilon: float = 1e-10,
    ):
        # FIX: forward `ignore_index` to the base class so that `self.ignore_index`
        # reflects the constructor argument instead of silently staying `None`
        # (the sub-losses already received the correct value)
        super().__init__(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )
        self.cross_entropy_loss = CrossEntropyLoss(
            multi_label=multi_label,
            ignore_index=ignore_index,
            reduction=reduction,
            epsilon=epsilon,
        )
        self.dice_loss = DiceLoss(
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
            epsilon=epsilon,
        )

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            prediction (Tensor): Predicted segmentation mask which is either the output of a softmax or a sigmoid
                layer.
            target (Tensor): Target segmentation mask which is either label encoded, one-hot encoded or multi-hot
                encoded.
            weight (Tensor, optional): Manual weight given to the loss of each image / slice. Defaults to `None`,
                which means that all images are weighted equally.

        Returns:
            Tensor: Combined loss.

        Shape:
            - | Prediction: :math:`(N, C, X, Y, ...)`, where `N = batch size`, `C = number of classes` and each value
              | is in :math:`[0, 1]`
            - | Target: :math:`(N, X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(N, C, X, Y, ...)`, where each value is in :math:`\{0, 1\}` in case of one-hot or
              | multi-hot encoding.
            - Weight: :math:`(N)` where `N = batch size`.
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(N, C)`. Otherwise, scalar.
        """
        return self.cross_entropy_loss(
            prediction, target, weight=weight
        ) + self.dice_loss(prediction, target, weight=weight)

Docs

Access comprehensive developer documentation for Active Segmentation

View Docs