Shortcuts

Source code for models.u_net

"""U-Net architecture."""

from collections import OrderedDict
from typing import List

import torch
from torch import nn
import torch.nn.functional as F


[docs]class UNet(nn.Module): """ This U-Net implementation was originally taken from `this implementation <https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py>`_ and adapted to a flexible number of levels and for optional 3d mode. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. Should be equal to the number of classes (for multi-label segmentation tasks excluding the background class). multi_label (bool, optional): Whether the model should produce single-label or multi-label outputs. If set to `False`, the model's predictions are computed using a Softmax activation layer. to If set to `True`, sigmoid activation layers are used to compute the model's predicitions. Defaults to False. init_features (int, optional): Number of feature channels of the first U-Net block, in each down-sampling block, the number of feature channels is doubled. Defaults to 32. num_levels (int, optional): Number levels (encoder and decoder blocks) in the U-Net. Defaults to 4. dim (int, optional): The dimensionality of the U-Net. Defaults to 2. p_dropout (float, optional): Probability of applying dropout to the outputs of the encoder layers. Defaults to 0. """ # pylint: disable-msg=too-many-instance-attributes def __init__( self, in_channels: int, out_channels: int, multi_label: bool = False, init_features: int = 32, num_levels: int = 4, dim: int = 2, p_dropout: float = 0, ): super().__init__() self.num_levels = num_levels MaxPool = nn.MaxPool2d if dim == 2 else nn.MaxPool3d ConvTranspose = nn.ConvTranspose2d if dim == 2 else nn.ConvTranspose3d Conv = nn.Conv2d if dim == 2 else nn.Conv3d features = init_features self.encoders = nn.ModuleList( [ UNet._block( in_channels if level == 0 else features * (2 ** (level - 1)), features * (2**level), name=f"enc{level + 1}", dim=dim, p_dropout=p_dropout, ) for level in range(num_levels) ] ) self.pools = nn.ModuleList( [MaxPool(kernel_size=2, stride=2) for _ in range(num_levels)] ) self.bottleneck = UNet._block( features * (2 ** (num_levels - 1)), features * (2**num_levels), name="bottleneck", dim=dim, ) self.upconvs = nn.ModuleList( [ ConvTranspose( features * (2 ** (level + 1)), features * (2**level), kernel_size=2, stride=2, ) for level in range(num_levels) ] ) self.decoders = nn.ModuleList( [ UNet._block( features * (2 ** (level + 1)), features * (2**level), name=f"dec{level + 1}", dim=dim, ) for level in range(num_levels) ] ) self.conv = Conv(in_channels=features, out_channels=out_channels, kernel_size=1) self.prediction_layer = ( torch.nn.Sigmoid() if multi_label else torch.nn.Softmax(dim=-1 * (dim + 1)) )
[docs] def forward(self, x): """ Args: x (Tensor): Batch of input images. Returns: Tensor: Segmentation masks. """ x = x.float() # individually store encoding results for skip connections encodings: List[torch.Tensor] = [] for level in range(self.num_levels): encodings.append( self.encoders[level]( x if level == 0 else self.pools[level - 1](encodings[level - 1]) ) ) bottleneck = self.bottleneck(self.pools[-1](encodings[-1])) dec = bottleneck for level in reversed(range(self.num_levels)): dec = self.upconvs[level](dec) # for the relevant dimensions [0, 0] for even and [0, 1] for odd in reversed order flattened pad = [ p for dim in reversed(encodings[level].size()[2:]) for p in [0, dim % 2] ] dec = F.pad(dec, tuple(pad), "constant", 0) dec = torch.cat((dec, encodings[level]), dim=1) dec = self.decoders[level](dec) decoded = self.conv(dec) return self.prediction_layer(decoded)
@staticmethod def _block( in_channels: int, features: int, name: str, dim: int, p_dropout: float = 0 ): Conv = nn.Conv2d if dim == 2 else nn.Conv3d BatchNorm = nn.BatchNorm2d if dim == 2 else nn.BatchNorm3d return nn.Sequential( OrderedDict( [ ( name + "conv1", Conv( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False, ), ), (name + "norm1", BatchNorm(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", Conv( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False, ), ), (name + "norm2", BatchNorm(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), (name + "dropout", nn.Dropout(p=p_dropout)), ] ) )

Docs

Access comprehensive developer documentation for Active Segmentation

View Docs