
Source code for functional.metrics

"""
Module containing model evaluation metrics.

The metric implementations are based on the TorchMetrics framework. For instructions on how to implement custom metrics
with this framework, see https://torchmetrics.readthedocs.io/en/latest/pages/implement.html.
"""
import abc
from typing import List, Literal, Optional, Tuple, Union

import psutil
import torch
import torchmetrics

from .utils import (
    flatten_tensors,
    is_binary,
    preprocess_metric_inputs,
    reduce_metric,
    remove_padding,
)

# pylint: disable=too-many-lines


def dice_score(
    prediction: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    convert_to_one_hot: bool = True,
    epsilon: float = 0,
    ignore_index: Optional[int] = None,
    include_background: bool = True,
    reduction: Literal["mean", "min", "max", "none"] = "none",
) -> torch.Tensor:
    r"""
    Computes the Dice similarity coefficient (DSC) between a predicted segmentation mask and the target mask:

    :math:`DSC = \frac{2 \cdot TP}{2 \cdot TP + FP + FN}`

    Using the `epsilon` parameter, Laplacian smoothing can be applied:

    :math:`DSC = \frac{2 \cdot TP + \lambda}{2 \cdot TP + FP + FN + \lambda}`

    This metric supports both single-label and multi-label segmentation tasks.

    Args:
        prediction (Tensor): The prediction tensor.
        target (Tensor): The target tensor.
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        epsilon (float, optional): Laplacian smoothing factor (default = 0).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        reduction (string): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum

    Returns:
        Tensor: Dice similarity coefficient.

    Shape:
        - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label encoding
          | and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or multi-hot
          | encoding (`C = number of classes`).
        - Target: Same shape and type as prediction.
        - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
    """

    prediction, target = preprocess_metric_inputs(
        prediction,
        target,
        num_classes,
        convert_to_one_hot=convert_to_one_hot,
        ignore_index=ignore_index,
    )

    flattened_prediction, flattened_target = flatten_tensors(
        prediction, target, include_background=include_background
    )

    intersection = (flattened_prediction * flattened_target).sum(dim=1)
    per_class_dice_score = (2.0 * intersection + epsilon) / (
        flattened_prediction.sum(dim=1) + flattened_target.sum(dim=1) + epsilon
    )
    return reduce_metric(per_class_dice_score, reduction)
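To make the formula above concrete, the following is a minimal, self-contained sketch that evaluates the unsmoothed DSC on a tiny one-hot encoded example using plain PyTorch; the tensors and values are illustrative only and bypass this module's preprocessing helpers.

# Hypothetical illustration of the DSC formula (not part of this module).
import torch

prediction = torch.tensor([[1.0, 1.0, 0.0, 0.0],   # class 0, flattened pixels
                           [0.0, 0.0, 1.0, 1.0]])  # class 1
target = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                       [0.0, 1.0, 1.0, 1.0]])

intersection = (prediction * target).sum(dim=1)  # per-class true positives
dice = 2.0 * intersection / (prediction.sum(dim=1) + target.sum(dim=1))
print(dice)  # tensor([0.6667, 0.8000])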
def sensitivity(
    prediction: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    convert_to_one_hot: bool = True,
    epsilon: float = 0,
    ignore_index: Optional[int] = None,
    include_background: bool = True,
    reduction: Literal["mean", "min", "max", "none"] = "none",
) -> torch.Tensor:
    r"""
    Computes the sensitivity from a predicted segmentation mask and the target mask:

    :math:`Sensitivity = \frac{TP}{TP + FN}`

    Using the `epsilon` parameter, Laplacian smoothing can be applied:

    :math:`Sensitivity = \frac{TP + \lambda}{TP + FN + \lambda}`

    This metric supports both single-label and multi-label segmentation tasks.

    Args:
        prediction (Tensor): The prediction tensor.
        target (Tensor): The target tensor.
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        epsilon (float, optional): Laplacian smoothing factor (default = 0).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        reduction (string): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum

    Returns:
        Tensor: Sensitivity.

    Shape:
        - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label encoding
          | and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or multi-hot
          | encoding (`C = number of classes`).
        - Target: Same shape and type as prediction.
        - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
    """

    prediction, target = preprocess_metric_inputs(
        prediction,
        target,
        num_classes,
        convert_to_one_hot=convert_to_one_hot,
        ignore_index=ignore_index,
    )

    flattened_prediction, flattened_target = flatten_tensors(
        prediction, target, include_background=include_background
    )

    true_positives = (flattened_prediction * flattened_target).sum(dim=1)
    true_positives_false_negatives = flattened_target.sum(dim=1)

    per_class_sensitivity = (true_positives + epsilon) / (
        true_positives_false_negatives + epsilon
    )
    return reduce_metric(per_class_sensitivity, reduction)
def specificity(
    prediction: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    convert_to_one_hot: bool = True,
    epsilon: float = 0,
    ignore_index: Optional[int] = None,
    include_background: bool = True,
    reduction: Literal["mean", "min", "max", "none"] = "none",
) -> torch.Tensor:
    r"""
    Computes the specificity from a predicted segmentation mask and the target mask:

    :math:`Specificity = \frac{TN}{TN + FP}`

    Using the `epsilon` parameter, Laplacian smoothing can be applied:

    :math:`Specificity = \frac{TN + \lambda}{TN + FP + \lambda}`

    This metric supports both single-label and multi-label segmentation tasks.

    Args:
        prediction (Tensor): The prediction tensor.
        target (Tensor): The target tensor.
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        epsilon (float, optional): Laplacian smoothing factor (default = 0).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        reduction (string): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum

    Returns:
        Tensor: Specificity.

    Shape:
        - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label encoding
          | and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or multi-hot
          | encoding (`C = number of classes`).
        - Target: Same shape and type as prediction.
        - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
    """

    prediction, target = preprocess_metric_inputs(
        prediction,
        target,
        num_classes,
        convert_to_one_hot=convert_to_one_hot,
        ignore_index=ignore_index,
        ignore_value=1,
    )

    flattened_prediction, flattened_target = flatten_tensors(
        prediction, target, include_background=include_background
    )

    ones = torch.ones(flattened_prediction.shape, device=prediction.device)

    true_negatives = ((ones - flattened_prediction) * (ones - flattened_target)).sum(
        dim=1
    )
    true_negatives_false_positives = (ones - flattened_target).sum(dim=1)

    per_class_specificity = (true_negatives + epsilon) / (
        true_negatives_false_positives + epsilon
    )
    return reduce_metric(per_class_specificity, reduction)
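For the same style of tiny example, sensitivity and specificity reduce to simple per-class ratios. The following is an illustrative plain-PyTorch sketch, not this module's code path:

# Hypothetical illustration of the sensitivity and specificity formulas (not part of this module).
import torch

prediction = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                           [0.0, 0.0, 1.0, 1.0]])
target = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                       [0.0, 1.0, 1.0, 1.0]])

true_positives = (prediction * target).sum(dim=1)
sensitivity = true_positives / target.sum(dim=1)  # TP / (TP + FN)

ones = torch.ones_like(prediction)
true_negatives = ((ones - prediction) * (ones - target)).sum(dim=1)
specificity = true_negatives / (ones - target).sum(dim=1)  # TN / (TN + FP)

print(sensitivity)  # tensor([1.0000, 0.6667])
print(specificity)  # tensor([0.6667, 1.0000])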
def _distance_matrix(
    first_point_set: torch.Tensor, second_point_set: torch.Tensor
) -> torch.Tensor:
    r"""
    Computes, for each point in the first tensor, the shortest Euclidean distance to any point in the second
    tensor.

    Args:
        first_point_set (Tensor): A tensor representing a list of points in a Euclidean space (typically 2d or 3d).
        second_point_set (Tensor): A tensor representing a list of points in a Euclidean space (typically 2d or
            3d).

    Returns:
        Tensor: Vector containing the shortest Euclidean distance from each point in the first tensor to any point
        in the second tensor.

    Shape:
        - First_point_set: :math:`(N, D)` where :math:`N` is the number of points in the first set and :math:`D` is
          the dimensionality of the Euclidean space.
        - Second_point_set: :math:`(M, D)` where :math:`M` is the number of points in the second set and :math:`D`
          is the dimensionality of the Euclidean space.
        - Output: :math:`(N)` where :math:`N` is the number of points in the first set.
    """

    if len(first_point_set) == 0:
        return torch.as_tensor(float("nan"), device=first_point_set.device)

    if torch.cuda.is_available():
        free_memory = torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
    else:
        free_memory = psutil.virtual_memory().available

    second_point_set_memory = (
        second_point_set.element_size() * second_point_set.nelement()
    )

    # process the first point set in chunks so that the pairwise distance matrix fits into memory
    # (the lower bound of 1 guards against a non-positive split size when memory is scarce)
    split_size = max(int(free_memory / second_point_set_memory) - 5, 1)

    minimum_distances = torch.zeros(len(first_point_set), device=first_point_set.device)

    first_point_sets = torch.split(first_point_set, split_size)
    for idx, first_point_set_split in enumerate(first_point_sets):
        distances = torch.cdist(first_point_set_split, second_point_set.float())
        minimum_distances_for_current_split, _ = distances.min(dim=-1)

        start_index = idx * split_size
        end_index = (idx + 1) * split_size
        minimum_distances[start_index:end_index] = minimum_distances_for_current_split

    return minimum_distances


def _compute_all_image_locations(shape: Tuple[int], device: Optional[str] = None):
    r"""
    Computes a tensor of points corresponding to all pixel locations of a 2d or a 3d image in Euclidean space.

    Args:
        shape (Tuple[int]): The shape of the image for which the pixel locations are to be computed.
        device (str, optional): Device as defined by PyTorch.

    Returns:
        A tensor of points corresponding to all pixel locations of the image.

    Shape:
        - Output: :math:`(height \cdot width, 2)` for 2d images and :math:`(S \cdot height \cdot width, 3)`, where
          `S = number of slices`, for 3d images.
    """

    return torch.cartesian_prod(
        *[torch.arange(dim, device=device).float() for dim in shape]
    )


def _binary_erosion(input_image: torch.Tensor) -> torch.Tensor:
    r"""
    Applies an erosion filter to a tensor representing a binary 2d or 3d image.

    Args:
        input_image (Tensor): Binary image to be eroded.

    Returns:
        Tensor: Eroded image.

    Shape:
        - Input image: :math:`(height, width)` or :math:`(S, height, width)` where `S = number of slices`.
        - Output: Has the same dimensions as the input.
    """

    assert input_image.dim() in [
        2,
        3,
    ], "Input must be a two- or three-dimensional tensor."
    assert is_binary(input_image), "Input must be binary."

    if input_image.dim() == 2:
        max_pooling = torch.nn.MaxPool2d(
            3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False
        )
    elif input_image.dim() == 3:
        max_pooling = torch.nn.MaxPool3d(
            3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False
        )

    # to implement erosion filtering, max pooling with a 3x3 kernel is applied to the inverted image
    inverted_input = (
        torch.ones(input_image.shape, device=input_image.device) - input_image
    )

    # pad image with ones to maintain the input's dimensions
    if input_image.dim() == 2:
        inverted_input_padded = torch.nn.functional.pad(
            inverted_input, (1, 1, 1, 1), "constant", 1
        )
    else:
        inverted_input_padded = torch.nn.functional.pad(
            inverted_input, (1, 1, 1, 1, 1, 1), "constant", 1
        )

    # create class and batch dimension as this is required by the max pooling module
    inverted_input_padded = inverted_input_padded.unsqueeze(dim=0).unsqueeze(dim=0)

    # apply the max pooling and invert the result
    return torch.ones(input_image.shape, device=input_image.device) - max_pooling(
        inverted_input_padded
    ).squeeze(dim=0).squeeze(dim=0)
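The inverted-max-pooling trick used by `_binary_erosion` can be seen in isolation below: max pooling dilates the inverted (background) mask, and inverting the result erodes the foreground by one pixel on each side. A standalone sketch with plain PyTorch; the 5x5 mask is illustrative:

# Hypothetical illustration of erosion via max pooling on an inverted mask (not part of this module).
import torch

mask = torch.zeros(5, 5)
mask[1:4, 1:4] = 1.0  # a 3x3 square of ones

inverted = 1.0 - mask
padded = torch.nn.functional.pad(inverted, (1, 1, 1, 1), "constant", 1.0)
pooled = torch.nn.functional.max_pool2d(padded.unsqueeze(0).unsqueeze(0), 3, stride=1)
eroded = 1.0 - pooled.squeeze(0).squeeze(0)

print(eroded)  # only the center pixel (2, 2) remains 1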
# pylint: disable=too-many-locals
def single_class_hausdorff_distance(
    prediction: torch.Tensor,
    target: torch.Tensor,
    normalize: bool = False,
    percentile: float = 0.95,
    all_image_locations: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""
    Computes the Hausdorff distance between a predicted segmentation mask and the target mask.

    Note:
        In this method, the `prediction` tensor is considered as a segmentation mask of a single 2D or 3D image for
        a given class and thus the distances are calculated over all channels and dimensions. As this method
        operates on binarized masks, the returned value is not differentiable in PyTorch and can therefore not be
        used in loss functions.

    Args:
        prediction (Tensor): The predicted segmentation mask, where each value is in :math:`\{0, 1\}`.
        target (Tensor): The target segmentation mask, where each value is in :math:`\{0, 1\}`.
        normalize (bool, optional): Whether the Hausdorff distance should be normalized by dividing it by the
            diagonal distance.
        percentile (float, optional): Percentile for which the Hausdorff distance is to be calculated, must be in
            :math:`[0, 1]`.
        all_image_locations (Tensor, optional): A pre-computed tensor containing one point for each pixel in the
            prediction representing the pixel's location in 2d or 3d Euclidean space. Passing a pre-computed tensor
            to this parameter can be used to speed up computation.

    Returns:
        Tensor: Hausdorff distance.

    Shape:
        - | Prediction: Can have arbitrary dimensions. Typically :math:`(S, height, width)`, where
          | `S = number of slices`, or :math:`(height, width)` for single image segmentation tasks.
        - Target: Must have the same dimensions as the prediction.
        - | All_image_locations: Must contain one element per pixel in the prediction. Each element represents an
          | n-dimensional point. Typically :math:`(S \cdot height \cdot width, 3)`, where `S = number of slices`,
          | or :math:`(height \cdot width, 2)`.
        - Output: Scalar.
    """

    # adapted code from
    # https://github.com/PiechaczekMyller/brats/blob/eb9f7eade1066dd12c90f6cef101b74c5e974bfa/brats/functional.py#L135

    assert prediction.device == target.device
    assert (
        prediction.shape == target.shape
    ), "Prediction and target must have the same dimensions."
    assert (
        prediction.dim() == 2 or prediction.dim() == 3
    ), "Prediction and target must have either two or three dimensions."
    assert is_binary(prediction), "Prediction must be binary."
    assert is_binary(target), "Target must be binary."

    if torch.count_nonzero(prediction) == 0 or torch.count_nonzero(target) == 0:
        return torch.as_tensor(float("nan"), device=prediction.device)

    if all_image_locations is None:
        all_image_locations = _compute_all_image_locations(
            prediction.shape, device=prediction.device
        )

    # to reduce computational effort, the distance calculations are only done for the border points of the
    # predicted and the target shape
    # to identify border points, binary erosion is used
    prediction_boundaries = torch.logical_xor(
        prediction, _binary_erosion(prediction)
    ).int()
    target_boundaries = torch.logical_xor(target, _binary_erosion(target)).int()

    flattened_prediction = prediction_boundaries.view(-1).float()
    flattened_target = target_boundaries.view(-1).float()

    # select those points that belong to the target segmentation mask
    target_mask = flattened_target.eq(1)
    target_locations = all_image_locations[target_mask, :]

    # select those points that belong to the predicted segmentation mask
    prediction_mask = flattened_prediction.eq(1)
    prediction_locations = all_image_locations[prediction_mask, :]

    # for each point in the predicted segmentation mask, compute the Euclidean distance to the closest point of
    # the target segmentation mask
    distances_to_target = _distance_matrix(prediction_locations, target_locations)

    # for each point in the target segmentation mask, compute the Euclidean distance to the closest point of the
    # predicted segmentation mask
    distances_to_prediction = _distance_matrix(target_locations, prediction_locations)

    distances = torch.cat((distances_to_target, distances_to_prediction), 0)

    if normalize:
        maximum_distance = torch.norm(
            torch.tensor(list(prediction.shape), dtype=torch.float) - 1
        )
        distances = distances / maximum_distance

    return torch.quantile(distances, q=percentile, keepdim=False).float()
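The core of the computation above is a symmetric set of nearest-point distances followed by a quantile. A minimal sketch of that idea on two small point sets (coordinates chosen arbitrarily; boundary extraction and memory chunking are omitted):

# Hypothetical illustration of a percentile Hausdorff distance between two point sets
# (not part of this module).
import torch

prediction_points = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
target_points = torch.tensor([[0.0, 0.0], [2.0, 0.0]])

# for each point in one set, the distance to the closest point of the other set
d_pred_to_target = torch.cdist(prediction_points, target_points).min(dim=-1).values
d_target_to_pred = torch.cdist(target_points, prediction_points).min(dim=-1).values

distances = torch.cat((d_pred_to_target, d_target_to_pred))
print(torch.quantile(distances, q=0.95))  # 95th-percentile Hausdorff distance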
# pylint: disable=too-many-arguments
def hausdorff_distance(
    prediction: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    all_image_locations: Optional[torch.Tensor] = None,
    convert_to_one_hot: bool = True,
    ignore_index: Optional[int] = None,
    include_background: bool = True,
    normalize: bool = False,
    percentile: float = 0.95,
    reduction: Literal["mean", "min", "max", "none"] = "none",
) -> torch.Tensor:
    r"""
    Computes the Hausdorff distance between a predicted segmentation mask and the target mask.

    This metric supports both single-label and multi-label segmentation tasks.

    Args:
        prediction (Tensor): The prediction tensor.
        target (Tensor): The target tensor.
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        all_image_locations (Tensor, optional): A pre-computed tensor containing one point for each pixel in the
            prediction representing the pixel's location in 2d or 3d Euclidean space. Passing a pre-computed tensor
            to this parameter can be used to speed up computation.
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        normalize (bool, optional): Whether the Hausdorff distance should be normalized by dividing it by the
            diagonal distance.
        percentile (float, optional): Percentile for which the Hausdorff distance is to be calculated, must be in
            :math:`[0, 1]`.
        reduction (string): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum

    Returns:
        Tensor: Hausdorff distance.

    Shape:
        - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label encoding
          | and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or multi-hot
          | encoding (`C = number of classes`).
        - Target: Same shape and type as prediction.
        - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
    """

    # adapted code from
    # https://github.com/PiechaczekMyller/brats/blob/eb9f7eade1066dd12c90f6cef101b74c5e974bfa/brats/functional.py#L135

    prediction, target = remove_padding(
        prediction, target, ignore_index, convert_to_one_hot
    )

    prediction, target = preprocess_metric_inputs(
        prediction,
        target,
        num_classes,
        convert_to_one_hot=convert_to_one_hot,
        ignore_index=ignore_index,
    )

    if not include_background:
        num_classes = num_classes - 1
        # drop the channel of the background class
        prediction = prediction[1:]
        target = target[1:]

    per_class_hausdorff_distances = torch.ones(num_classes, device=prediction.device)

    for i in range(num_classes):
        per_class_hausdorff_distances[i] = single_class_hausdorff_distance(
            prediction[i],
            target[i],
            normalize=normalize,
            percentile=percentile,
            all_image_locations=all_image_locations,
        )

    return reduce_metric(per_class_hausdorff_distances, reduction)
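A hedged usage sketch for the functional metric; the import path `functional.metrics` is taken from this page's title, and the mask contents are illustrative:

# Hypothetical usage sketch; the import path and tensor contents are illustrative.
import torch
from functional.metrics import hausdorff_distance

prediction = torch.zeros(16, 16, dtype=torch.long)  # label-encoded mask, 2 classes
prediction[4:12, 4:12] = 1
target = torch.zeros(16, 16, dtype=torch.long)
target[5:13, 5:13] = 1

hd = hausdorff_distance(prediction, target, num_classes=2, reduction="mean")
print(hd)  # 95th-percentile Hausdorff distance averaged over both classes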
class SegmentationMetric(torchmetrics.Metric, abc.ABC):
    """
    Base class for segmentation metrics.

    Args:
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        ignore_value (float, optional): Value that should be inserted at the positions where the target is equal to
            `ignore_index`. Defaults to 0.
        reduction (string, optional): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum
    """

    def __init__(
        self,
        num_classes: int,
        convert_to_one_hot: bool = True,
        ignore_index: Optional[int] = None,
        ignore_value: float = 0,
        include_background: bool = True,
        reduction: Literal["mean", "min", "max", "none"] = "none",
    ):
        super().__init__()
        self.num_classes = num_classes
        self.convert_to_one_hot = convert_to_one_hot
        self.ignore_index = ignore_index
        self.ignore_value = ignore_value
        self.include_background = include_background

        if reduction not in ["mean", "min", "max", "none"]:
            raise ValueError("Invalid reduction method.")
        self.reduction = reduction

    def _flatten_tensors(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Reshapes and flattens prediction and target tensors except for the first dimension (class dimension).

        Args:
            prediction (Tensor): The prediction tensor (either label-encoded, one-hot encoded or multi-hot
                encoded).
            target (Tensor): The target tensor (either label-encoded, one-hot encoded or multi-hot encoded).

        Returns:
            Tuple[Tensor]: Flattened prediction and target tensors (one-hot or multi-hot encoded).

        Shape:
            - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or
              | multi-hot encoding (`C = number of classes`).
            - Target: Same shape and type as prediction.
            - | Output: :math:`(C, X \cdot Y \cdot ...)` where each element is in :math:`\{0, 1\}` indicating the
              | absence / presence of the respective class (one-hot / multi-hot encoding).
        """

        return flatten_tensors(
            prediction,
            target,
            include_background=self.include_background,
        )

    def _preprocess_inputs(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This method implements three preprocessing steps that are needed for most segmentation metrics:

        1. Validation of input shape and type
        2. Conversion from label encoding to one-hot encoding if necessary
        3. Mapping of pixels/voxels labeled with the :attr:`ignore_index` to true negatives or true positives

        Args:
            prediction (Tensor): The prediction tensor.
            target (Tensor): The target tensor.

        Returns:
            Tuple[Tensor, Tensor]: The preprocessed prediction and target tensors.
        """

        return preprocess_metric_inputs(
            prediction,
            target,
            self.num_classes,
            convert_to_one_hot=self.convert_to_one_hot,
            ignore_index=self.ignore_index,
            ignore_value=self.ignore_value,
        )
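Steps 2 and 3 of `_preprocess_inputs` (label to one-hot conversion and mapping of `ignore_index` pixels) can be pictured as follows. This is a conceptual plain-PyTorch sketch, not the actual implementation of `preprocess_metric_inputs`:

# Hypothetical sketch of one-hot conversion with an ignore_index (not the actual helper).
import torch

num_classes, ignore_index, ignore_value = 3, -1, 0

target = torch.tensor([[0, 1, -1],
                       [2, 2, 1]])  # label-encoded mask with one ignored pixel

ignore_mask = target == ignore_index
one_hot = torch.nn.functional.one_hot(target.clamp(min=0), num_classes)
one_hot = one_hot.permute(2, 0, 1)      # move the class dimension to the front: (C, X, Y)
one_hot[:, ignore_mask] = ignore_value  # ignored pixels count as true negatives here

print(one_hot.shape)  # torch.Size([3, 2, 3])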
class DiceScore(SegmentationMetric):
    """
    Computes the Dice similarity coefficient (DSC). Can be used for 3D images whose slices are scattered over
    multiple batches.

    Args:
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = 0).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        reduction (string, optional): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum
    """

    def __init__(
        self,
        num_classes: int,
        convert_to_one_hot: bool = True,
        epsilon: float = 0,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["none", "mean", "min", "max"] = "none",
    ):
        super().__init__(
            num_classes,
            convert_to_one_hot=convert_to_one_hot,
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
        )
        self.epsilon = epsilon

        num_classes = num_classes if self.include_background else num_classes - 1

        self.numerator = torch.zeros(num_classes)
        self.denominator = torch.zeros(num_classes)
        self.add_state("numerator", torch.zeros(num_classes))
        self.add_state("denominator", torch.zeros(num_classes))

    # pylint: disable=arguments-differ
    def update(self, prediction: torch.Tensor, target: torch.Tensor) -> None:
        r"""
        Updates metric using the provided prediction.

        Args:
            prediction (Tensor): The prediction tensor.
            target (Tensor): The target tensor.

        Shape:
            - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or
              | multi-hot encoding (`C = number of classes`).
            - Target: Same shape and type as prediction.
        """

        prediction, target = self._preprocess_inputs(prediction, target)

        flattened_prediction, flattened_target = self._flatten_tensors(
            prediction, target
        )

        self.numerator += (flattened_prediction * flattened_target).sum(dim=1)
        self.denominator += flattened_prediction.sum(dim=1) + flattened_target.sum(
            dim=1
        )
    def compute(self) -> torch.Tensor:
        """
        Computes the DSC over all slices that were registered using the `update` method.

        Returns:
            Tensor: Dice similarity coefficient.

        Shape:
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
        """

        per_class_dice_score = (2.0 * self.numerator + self.epsilon) / (
            self.denominator + self.epsilon
        )
        return reduce_metric(per_class_dice_score, self.reduction)
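A usage sketch for the stateful metrics (shown for `DiceScore`; `Sensitivity` and `Specificity` below follow the same update/compute pattern). The import path and data are illustrative:

# Hypothetical usage sketch; the import path and masks are illustrative.
import torch
from functional.metrics import DiceScore

metric = DiceScore(num_classes=2, reduction="none")

for _ in range(4):  # e.g. four slices of the same 3D image arriving in separate batches
    prediction = torch.zeros(8, 8, dtype=torch.long)
    prediction[2:6, 2:6] = 1
    target = torch.zeros(8, 8, dtype=torch.long)
    target[3:7, 3:7] = 1
    metric.update(prediction, target)

print(metric.compute())  # per-class DSC aggregated over all registered slices
metric.reset()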
class Sensitivity(SegmentationMetric):
    """
    Computes the sensitivity. Can be used for 3D images whose slices are scattered over multiple batches.

    Args:
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = 0).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        reduction (string, optional): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum
    """

    def __init__(
        self,
        num_classes: int,
        convert_to_one_hot: bool = True,
        epsilon: float = 0,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["none", "mean", "min", "max"] = "none",
    ):
        super().__init__(
            num_classes,
            convert_to_one_hot=convert_to_one_hot,
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
        )
        self.epsilon = epsilon

        num_classes = num_classes if self.include_background else num_classes - 1

        self.true_positives = torch.zeros(num_classes)
        self.true_positives_false_negatives = torch.zeros(num_classes)
        self.add_state("true_positives", torch.zeros(num_classes))
        self.add_state("true_positives_false_negatives", torch.zeros(num_classes))

    # pylint: disable=arguments-differ
    def update(self, prediction: torch.Tensor, target: torch.Tensor) -> None:
        r"""
        Updates metric using the provided prediction.

        Args:
            prediction (Tensor): The prediction tensor.
            target (Tensor): The target tensor.

        Shape:
            - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or
              | multi-hot encoding (`C = number of classes`).
            - Target: Same shape and type as prediction.
        """

        prediction, target = self._preprocess_inputs(prediction, target)

        flattened_prediction, flattened_target = self._flatten_tensors(
            prediction, target
        )

        self.true_positives += (flattened_prediction * flattened_target).sum(dim=1)
        self.true_positives_false_negatives += flattened_target.sum(dim=1)
    def compute(self) -> torch.Tensor:
        """
        Computes the sensitivity over all slices that were registered using the `update` method.

        Returns:
            Tensor: Sensitivity.

        Shape:
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
        """

        per_class_sensitivity = (self.true_positives + self.epsilon) / (
            self.true_positives_false_negatives + self.epsilon
        )
        return reduce_metric(per_class_sensitivity, self.reduction)
class Specificity(SegmentationMetric):
    """
    Computes the specificity. Can be used for 3D images whose slices are scattered over multiple batches.

    Args:
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        epsilon (float, optional): Laplacian smoothing term to avoid divisions by zero (default = 0).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        reduction (string, optional): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum
    """

    def __init__(
        self,
        num_classes: int,
        convert_to_one_hot: bool = True,
        epsilon: float = 0,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        reduction: Literal["none", "mean", "min", "max"] = "none",
    ):
        super().__init__(
            num_classes,
            convert_to_one_hot=convert_to_one_hot,
            ignore_index=ignore_index,
            ignore_value=1,
            include_background=include_background,
            reduction=reduction,
        )
        self.epsilon = epsilon

        num_classes = num_classes if self.include_background else num_classes - 1

        self.true_negatives = torch.zeros(num_classes)
        self.true_negatives_false_positives = torch.zeros(num_classes)
        self.add_state("true_negatives", torch.zeros(num_classes))
        self.add_state("true_negatives_false_positives", torch.zeros(num_classes))

    # pylint: disable=arguments-differ
    def update(self, prediction: torch.Tensor, target: torch.Tensor) -> None:
        r"""
        Updates metric using the provided prediction.

        Args:
            prediction (Tensor): The prediction tensor.
            target (Tensor): The target tensor.

        Shape:
            - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or
              | multi-hot encoding (`C = number of classes`).
            - Target: Same shape and type as prediction.
        """

        prediction, target = self._preprocess_inputs(prediction, target)

        flattened_prediction, flattened_target = self._flatten_tensors(
            prediction, target
        )

        ones = torch.ones(flattened_prediction.shape, device=prediction.device)

        self.true_negatives += (
            (ones - flattened_prediction) * (ones - flattened_target)
        ).sum(dim=1)
        self.true_negatives_false_positives += (ones - flattened_target).sum(dim=1)
    def compute(self) -> torch.Tensor:
        """
        Computes the specificity over all slices that were registered using the `update` method.

        Returns:
            Tensor: Specificity.

        Shape:
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
        """

        per_class_specificity = (self.true_negatives + self.epsilon) / (
            self.true_negatives_false_positives + self.epsilon
        )
        return reduce_metric(per_class_specificity, self.reduction)
class HausdorffDistance(SegmentationMetric):
    r"""
    Computes the Hausdorff distance. Can be used for 3D images whose slices are scattered over multiple batches.

    Args:
        num_classes (int): Number of classes (for single-label segmentation tasks including the background class).
        slices_per_image (int): Number of slices per 3d image.
        convert_to_one_hot (bool, optional): Determines if data is label encoded and needs to be converted to
            one-hot encoding or not (default = `True`).
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the
            metric. Defaults to `None`.
        include_background (bool, optional): if `False`, class channel index 0 (background class) is excluded from
            the calculation (default = `True`).
        normalize (bool, optional): Whether the Hausdorff distance should be normalized by dividing it by the
            diagonal distance.
        percentile (float, optional): Percentile for which the Hausdorff distance is to be calculated, must be in
            :math:`[0, 1]`.
        reduction (string, optional): A method to reduce metric scores of multiple classes.

            - ``"none"``: no reduction will be applied (default)
            - ``"mean"``: takes the mean
            - ``"min"``: takes the minimum
            - ``"max"``: takes the maximum
    """

    def __init__(
        self,
        num_classes: int,
        slices_per_image: int,
        convert_to_one_hot: bool = True,
        ignore_index: Optional[int] = None,
        include_background: bool = True,
        normalize: bool = False,
        percentile: float = 0.95,
        reduction: Literal["none", "mean", "min", "max"] = "none",
    ):
        super().__init__(
            num_classes,
            convert_to_one_hot=convert_to_one_hot,
            ignore_index=ignore_index,
            include_background=include_background,
            reduction=reduction,
        )

        num_classes = num_classes if self.include_background else num_classes - 1

        self.normalize = normalize
        self.percentile = percentile

        self.predictions = torch.empty(slices_per_image)
        self.targets = torch.empty(slices_per_image)
        self.add_state("predictions", torch.empty(slices_per_image))
        self.add_state("targets", torch.empty(slices_per_image))

        self.hausdorff_distance = torch.zeros(num_classes)
        self.add_state("hausdorff_distance", torch.zeros(num_classes))

        self.all_image_locations = None
        self.hausdorff_distance_cached = False
        self.slices_per_image = slices_per_image
        self.number_of_slices = 0
        self.collected_slice_ids = set()
    def reset(self) -> None:
        """
        Resets the metric state.
        """

        super().reset()
        self.all_image_locations = None
        self.hausdorff_distance_cached = False
        self.number_of_slices = 0
        self.collected_slice_ids = set()
    # pylint: disable=arguments-differ
    def update(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        slice_ids: Optional[Union[int, List[int]]] = None,
    ) -> None:
        r"""
        Updates metric using the provided prediction.

        Args:
            prediction (Tensor): The prediction tensor.
            target (Tensor): The target tensor.
            slice_ids (Union[int, List[int]], optional): Indices of the image slices represented by `prediction`
                and `target`. Must be provided if `prediction` and `target` are a single slice or a subset of
                slices taken from a 3d image.

        Shape:
            - | Prediction: :math:`(X, Y, ...)` where each value is in :math:`\{0, ..., C - 1\}` in case of label
              | encoding and :math:`(C, X, Y, ...)`, where each value is in :math:`\{0, 1\}`, in case of one-hot or
              | multi-hot encoding (`C = number of classes`).
            - Target: Same shape and type as prediction.
        """

        prediction, target = remove_padding(
            prediction, target, self.ignore_index, self.convert_to_one_hot
        )
        prediction, target = self._preprocess_inputs(prediction, target)

        # the prediction and target tensors include an additional class dimension
        dimensionality = 2 if prediction.dim() == 3 else 3
        added_slices = 1 if dimensionality == 2 else prediction.shape[1]

        if slice_ids is None:
            assert self.slices_per_image == 1 or (
                dimensionality == 3 and self.slices_per_image == prediction.shape[1]
            ), (
                "The `slice_ids` parameter must be provided when `prediction` and `target` are a single slice "
                "or a subset of slices taken from a 3d image."
            )

            self.predictions = prediction
            self.targets = target
        else:
            if isinstance(slice_ids, int):
                slice_ids = [slice_ids]

            assert (
                len(slice_ids) == added_slices
            ), "The length of `slice_ids` must be equal to the number of slices in `prediction` and `target`."
            assert (
                len(self.collected_slice_ids.intersection(set(slice_ids))) == 0
            ), "Tried to update the Hausdorff distance with a slice that was already added."

            # in this case we just collect all slices of the image since the Hausdorff distance needs to be
            # computed over all slices
            # note that this will be memory-intensive if this is done for multiple 3d images in parallel
            if self.number_of_slices == 0:
                # create tensors that have the shape of the original 3d images
                shape_of_full_image = (
                    prediction.shape[0],
                    self.slices_per_image,
                    prediction.shape[-2],
                    prediction.shape[-1],
                )
                self.predictions = torch.zeros(*shape_of_full_image, dtype=torch.int)
                self.targets = torch.zeros(*shape_of_full_image, dtype=torch.int)

            self.collected_slice_ids.update(slice_ids)

            for idx, slice_id in enumerate(slice_ids):
                self.predictions[:, slice_id] = (
                    prediction[:, idx] if dimensionality == 3 else prediction
                )
                self.targets[:, slice_id] = (
                    target[:, idx] if dimensionality == 3 else target
                )

        self.hausdorff_distance_cached = False
        self.number_of_slices += added_slices

        if self.number_of_slices == self.slices_per_image:
            self.compute()

        assert self.number_of_slices <= self.slices_per_image
    def compute(self) -> torch.Tensor:
        """
        Computes the Hausdorff distance over the full image assembled from all slices that were registered using
        the `update` method.

        Returns:
            Tensor: Hausdorff distance.

        Shape:
            - Output: If :attr:`reduction` is `"none"`, shape :math:`(C)`. Otherwise, scalar.
        """

        if self.hausdorff_distance_cached:
            assert (
                self.number_of_slices == 0
            ), "A cached value for the Hausdorff distance is used even though there were slices added afterwards."
            return self.hausdorff_distance

        assert self.number_of_slices == self.slices_per_image, (
            "The compute method was called on the Hausdorff distance module before all slices of the image were "
            "provided."
        )

        hausdorff_dist = hausdorff_distance(
            self.predictions,
            self.targets,
            self.num_classes,
            convert_to_one_hot=False,
            ignore_index=None,
            include_background=self.include_background,
            normalize=self.normalize,
            percentile=self.percentile,
        )

        self.hausdorff_distance = hausdorff_dist
        self.hausdorff_distance_cached = True

        # free memory
        del self.predictions
        del self.targets
        self.predictions = torch.empty(self.slices_per_image)
        self.targets = torch.empty(self.slices_per_image)

        self.number_of_slices = 0

        return reduce_metric(hausdorff_dist, self.reduction)
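A usage sketch for slice-wise accumulation with `HausdorffDistance`; the import path, shapes, and slice ordering are illustrative:

# Hypothetical usage sketch; the import path and masks are illustrative.
import torch
from functional.metrics import HausdorffDistance

metric = HausdorffDistance(num_classes=2, slices_per_image=4)

for slice_id in range(4):  # feed the 3D image one 2D slice at a time
    prediction = torch.zeros(8, 8, dtype=torch.long)
    prediction[2:6, 2:6] = 1
    target = torch.zeros(8, 8, dtype=torch.long)
    target[3:7, 3:7] = 1
    metric.update(prediction, target, slice_ids=slice_id)

print(metric.compute())  # per-class Hausdorff distance over the assembled 3D image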
