Shortcuts

Source code for models.pytorch_u_net

"""U-Net architecture wrapped as PytorchModel"""

from typing import Iterable

import torch
import numpy as np

from .pytorch_model import PytorchModel
from .u_net import UNet

# pylint: disable-msg=too-many-ancestors, abstract-method
[docs]class PytorchUNet(PytorchModel): """ U-Net architecture wrapped as PytorchModel. Details about the architecture can be found in `this paper <https://arxiv.org/pdf/1505.04597.pdf>`_. Args: num_levels (int, optional): Number levels (encoder and decoder blocks) in the U-Net. Defaults to 4. dim (int, optional): The dimensionality of the U-Net. Defaults to 2. in_channels (int, optional): Number of input channels. Defaults to 1 out_channels (int): Number of output channels. Should be equal to the number of classes (for multi-label segmentation tasks excluding the background class). Defaults to 2. multi_label (bool, optional): Whether the model should produce single-label or multi-label outputs. If set to `False`, the model's predictions are computed using a Softmax activation layer. to If set to `True`, sigmoid activation layers are used to compute the model's predicitions. Defaults to False. p_dropout (float, optional): Probability of applying dropout to the outputs of the encoder layers. Defaults to 0. **kwargs: Further, dataset specific parameters. """ def __init__( self, num_levels: int = 4, dim: int = 2, in_channels=1, out_channels=2, multi_label=False, p_dropout: float = 0, **kwargs ): super().__init__(**kwargs) self.save_hyperparameters() self.num_levels = num_levels self.in_channels = in_channels self.dim = dim self.in_channels = in_channels self.out_channels = out_channels self.multi_label = multi_label self.p_dropout = p_dropout self.model = UNet( in_channels=self.in_channels, out_channels=out_channels, multi_label=multi_label, init_features=32, num_levels=self.num_levels, dim=self.dim, p_dropout=self.p_dropout, )
[docs] def input_dimensionality(self) -> int: return self.dim
# wrap model interface
[docs] def eval(self) -> None: """ Sets model to evaluation mode. """ return self.model.eval()
[docs] def train(self, mode: bool = True): """ Sets model to training mode. """ # pylint: disable-msg=unused-argument return self.model.train(mode=mode)
[docs] def parameters(self, recurse: bool = True) -> Iterable: """ Returns: Iterable: Model parameters. """ # pylint: disable-msg=unused-argument return self.model.parameters(recurse=recurse)
[docs] def forward(self, x: torch.Tensor): """ Args: x (Tensor): Batch of input images. Returns: Tensor: Segmentation masks. """ # pylint: disable-msg=arguments-differ return self.model.forward(x)
[docs] def training_step(self, batch: torch.Tensor, batch_idx: int) -> float: """ Trains the model on a given batch of input images. Args: batch (Tensor): Batch of training images. batch_idx: Index of the training batch. Returns: Loss on the training batch. """ x, y, is_pseudo_label, case_ids = batch weight_pseudo_labels = self.loss_weight_pseudo_labels if weight_pseudo_labels is not None: weight = torch.ones(len(is_pseudo_label), device=self.device) weight[is_pseudo_label] = weight_pseudo_labels else: weight = None probabilities = self(x) loss = self.loss_module(probabilities, y, weight=weight) for train_metric in self.get_train_metrics(): train_metric.update(probabilities.detach(), y, case_ids) self.logger.log_metrics( { "train/loss": loss.detach(), "trainer/iteration": self.iteration, "trainer/epoch": self.current_epoch, } ) return loss
[docs] def validation_step(self, batch, batch_idx) -> float: """ Validates the model on a given batch of input images. Args: batch (Tensor): Batch of validation images. batch_idx: Index of the validation batch. Returns: Loss on the validation batch. """ x, y, _, case_ids = batch probabilities = self(x) loss = self.loss_module(probabilities, y) if self.stage == "fit": # log to trainer for model selection self.log( "val/loss", loss.detach(), logger=False, on_epoch=True, on_step=False, ) for val_metric in self.get_val_metrics(): val_metric.update(probabilities.detach(), y, case_ids) return loss
[docs] def predict_step( self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0 ) -> np.ndarray: """ Uses the model to predict a given batch of input images. Args: batch (Tensor): Batch of prediction images. batch_idx: Index of the prediction batch. dataloader_idx: Index of the dataloader. """ return self.predict(batch)
[docs] def test_step( self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0 ) -> None: """ Tests the model on a given batch of input images. Args: batch (Tensor): Batch of prediction images. batch_idx: Index of the prediction batch. dataloader_idx: Index of the dataloader. """ x, y, _, case_ids = batch probabilities = self(x) loss = self.loss_module(probabilities, y) self.logger.log_metrics( { "test/loss": loss.detach(), "trainer/iteration": self.iteration, "trainer/epoch": self.current_epoch, } ) for test_metric in self.get_test_metrics(): test_metric.update(probabilities.detach(), y, case_ids)
[docs] def reset_parameters(self) -> None: """ This method is called when resetting the weights is activated for the active learing loop """ self.model = UNet( in_channels=self.in_channels, out_channels=self.out_channels, multi_label=self.multi_label, init_features=32, num_levels=self.num_levels, dim=self.dim, p_dropout=self.p_dropout, )

Docs

Access comprehensive developer documentation for Active Segmentation

View Docs