""" Module to load and batch nifti datasets """
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from multiprocessing import Manager
from sklearn.model_selection import train_test_split
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import IterableDataset
from datasets.dataset_hooks import DatasetHooks
# pylint: disable=too-many-instance-attributes,abstract-method
[docs]class DoublyShuffledNIfTIDataset(IterableDataset, DatasetHooks):
"""
This dataset can be used with NIfTI images. It is iterable and can return both 2D and 3D images.
Args:
image_paths (List[str]): List with the paths to the images. Has to contain paths of all images which can ever
become part of the dataset.
annotation_paths (List[str]): List with the paths to the annotations. Has to contain paths of all images which
can ever become part of the dataset.
cache_size (int, optional): Number of images to keep in memory to speed-up data loading in subsequent epochs.
Defaults to zero.
combine_foreground_classes (bool, optional): Flag if the non-zero values of the annotations should be merged.
Defaults to False.
mask_filter_values (Tuple[int], optional): Values from the annotations which should be used. Defaults to using
all values.
is_unlabeled (bool, optional): Whether the dataset is unlabeled, i.e. whether items should be returned
without their annotations. Defaults to False.
shuffle (bool, optional): Whether the data should be shuffled.
transform (Callable[[Any], Tensor], optional): Function to transform the images.
target_transform (Callable[[Any], Tensor], optional): Function to transform the annotations.
dim (int, optional): 2 or 3 to define if the dataset should return 2d slices of whole 3d images.
Defaults to 2.
slice_indices (List[np.array], optional): Array of indices per image which should be part of the dataset.
Uses all slices if None. Defaults to None.
case_id_prefix (str, optional): Prefix used when building the case ids of the images. Defaults to "train".
random_state (int, optional): Controls the data shuffling. Pass an int for reproducible output across multiple
runs.
only_return_true_labels (bool, optional): Whether only true labels or also pseudo-labels are to be returned.
Defaults to `False`.
"""
[docs] @staticmethod
def normalize(img: np.ndarray) -> np.ndarray:
"""
Normalizes an image by
#. Dividing by the maximum value
#. Subtracting the mean, zeros will be ignored while calculating the mean
#. Dividing by the negative minimum value
Args:
img: The input image that should be normalized.
Returns:
Normalized image with background values normalized to -1
"""
tmp = img
tmp /= np.max(tmp)
# ignore zero values for mean calculation because background dominates
tmp -= np.mean(tmp[tmp > 0])
tmp /= -np.min(tmp)
return tmp
@staticmethod
def __read_image(filepath: str) -> Any:
"""
Reads image or annotation.
Args:
filepath (str): Path of the image file.
Returns:
The image. See https://nipy.org/nibabel/reference/nibabel.spatialimages.html#module-nibabel.spatialimages
"""
return nib.load(filepath)
@staticmethod
def __read_slice_count(filepath: str, dim: int = 2) -> int:
"""
Reads image or annotation.
Args:
filepath (str): Path of the image file.
dim (int, optional): The dimensionality of the dataset. Defaults to 2.
Returns:
The slice count of the image at the filepath or 1 if dim is not 2.
"""
return (
DoublyShuffledNIfTIDataset.__read_image(filepath).shape[2]
if dim == 2
else 1
)
[docs] @staticmethod
def generate_active_learning_split(
filepaths: List[str],
dim: int,
initial_training_set_size: int,
random_state: Optional[int] = None,
) -> Tuple[List[np.array]]:
"""
Generates a split between initial training set and initially unlabeled set for active learning.
Args:
filepaths (List[str]): The file paths to the Nifti files.
dim (int): The dimensionality of the dataset. (2 or 3.)
initial_training_set_size (int): The number of samples in the initial training set.
random_state (int, optional): The random state used to generate the split. Pass an int for reproducibility
across runs.
Returns:
A tuple of two lists of np.arrays. The lists contain one array per filepath which contains the
slice indices of the slices which should be part of the training and unlabeled sets respectively.
The lists can be passed as `slice_indices` for initialization of a DoublyShuffledNIfTIDataset.
"""
if dim == 3:
image_indices = range(len(filepaths))
(initial_training_samples, initial_unlabeled_samples,) = train_test_split(
image_indices,
train_size=initial_training_set_size,
random_state=random_state,
)
return (
[
np.array([0] if image_index in initial_training_samples else [])
for image_index in range(len(filepaths))
],
[
np.array([0] if image_index in initial_unlabeled_samples else [])
for image_index in range(len(filepaths))
],
)
all_samples = [
[image_index, slice_index]
for image_index, filepath in enumerate(filepaths)
for slice_index in range(
DoublyShuffledNIfTIDataset.__read_slice_count(filepath, dim=dim)
)
]
(initial_training_samples, initial_unlabeled_samples) = train_test_split(
all_samples,
train_size=initial_training_set_size,
random_state=random_state,
)
return (
[
np.array(
[
slice_index
for image_index, slice_index in initial_training_samples
if image_index == image_id
]
)
for image_id in range(len(filepaths))
],
[
np.array(
[
slice_index
for image_index, slice_index in initial_unlabeled_samples
if image_index == image_id
]
)
for image_id in range(len(filepaths))
],
)
@staticmethod
def __align_axes(img: np.ndarray) -> np.ndarray:
"""
Aligns the axes to (slice, x, y) or (channel, slice, x, y), depending on if there is a channel dimension.
Args:
img (np.ndarray): The image
Returns:
The images with realigned axes.
"""
img = np.moveaxis(img, 2, 0) # slice dimension to front
if len(img.shape) == 4:
img = np.moveaxis(img, 3, 0) # channel dimension to front
return img
@staticmethod
def __read_image_as_array(
filepath: str,
norm: bool,
join_non_zero: bool = False,
filter_values: Optional[Tuple[int]] = None,
) -> np.ndarray:
"""
Reads image or annotation as numpy array.
Args:
filepath: Path of the image file.
norm: Whether the image should be normalized.
join_non_zero: Whether the non-zero values of the image should be joined. Will set all non-zero values to 1.
filter_values: Values to be filtered from the images. All other values will be set to zero.
Can be used together with join_non_zero. Filtering will be applied before joining.
Returns:
The array representation of an image.
"""
img = DoublyShuffledNIfTIDataset.__read_image(filepath).get_fdata()
if filter_values is not None:
map_to_filtered_value = np.vectorize(
lambda value: value if value in filter_values else 0
)
img = map_to_filtered_value(img)
if join_non_zero:
img = np.clip(img, 0, 1)
if norm:
img = DoublyShuffledNIfTIDataset.normalize(img)
return DoublyShuffledNIfTIDataset.__align_axes(img)
@staticmethod
def __ensure_channel_dim(img: torch.Tensor, dim: int) -> torch.Tensor:
return img if len(img.shape) == dim + 1 else torch.unsqueeze(img, 0)
    def __arange_image_slice_indices(
        self,
        filepaths: List[str],
        dim: int = 2,
        shuffle: bool = False,
        random_state: Optional[int] = None,
        slice_indices: Optional[List[np.array]] = None,
    ) -> Dict[int, Dict[int, Optional[np.array]]]:
        """
        Reads the slice indices for the images at the provided slice paths and pairs them with their image
        index.

        Implements efficient shuffling for 2D image datasets like the DoublyShuffledNIfTIDataset whose
        elements represent the slices of multiple 3D images. To allow for efficient image pre-fetching, first
        the order of all 3D images is shuffled and then the order of slices within each 3D image is shuffled.
        This way the 3D images can still be loaded as a whole.

        Args:
            filepaths (List[str]): The paths of the images.
            dim (int, optional): The dimensionality of the dataset. Defaults to 2.
            shuffle (boolean, optional): Flag indicating whether to shuffle the slices. Defaults to False.
            random_state (int, optional): Random seed for shuffling.
            slice_indices (List[np.array], optional): Array of indices per image which should be part of the
                dataset. Uses all slices if None. Defaults to None.

        Returns:
            A dictionary of per-image dictionaries which contain slice indices as keys. If a slice index is
            not part of the per-image dictionary, it is not part of the dataset. If its value is None, it does
            not have a pseudo label. If its value is a np.array, this is the pseudo label for that slice.
        """
        if slice_indices is None:
            # default: every slice of every image belongs to the dataset
            slice_indices = [
                np.arange(
                    DoublyShuffledNIfTIDataset.__read_slice_count(filepath, dim=dim)
                )
                for filepath in filepaths
            ]
        if shuffle:
            # one rng seeded once so both shuffle stages are reproducible;
            # the order of the rng calls below must not change
            rng = np.random.default_rng(random_state)
            # first shuffle the slices within each image
            for slices in slice_indices:
                rng.shuffle(slices)
            # then shuffle the order of the images themselves
            enumerated_slice_indices = list(enumerate(slice_indices))
            rng.shuffle(enumerated_slice_indices)
        else:
            enumerated_slice_indices = enumerate(slice_indices)
        # build a process-shared nested mapping image_index -> {slice_index: pseudo label or None};
        # manager dicts make it visible to all DataLoader worker processes
        image_slice_indices = self.manager.dict()
        for image_index, slices in enumerated_slice_indices:
            # images without any selected slice are left out entirely
            if len(slices) > 0:
                image_slice_indices[image_index] = self.manager.dict()
                for slice_index in slices:
                    image_slice_indices[image_index][slice_index] = None
        return image_slice_indices
    # pylint: disable=too-many-arguments
    def __init__(
        self,
        image_paths: List[str],
        annotation_paths: List[str],
        cache_size: int = 0,
        combine_foreground_classes: bool = False,
        mask_filter_values: Optional[Tuple[int]] = None,
        is_unlabeled: bool = False,
        shuffle: bool = False,
        transform: Optional[Callable[[Any], torch.Tensor]] = None,
        target_transform: Optional[Callable[[Any], torch.Tensor]] = None,
        dim: int = 2,
        slice_indices: Optional[List[np.array]] = None,
        case_id_prefix: str = "train",
        random_state: Optional[int] = None,
        only_return_true_labels: bool = False,
    ):
        """
        Initializes the dataset. See the class docstring for the shared parameter documentation.

        Args:
            is_unlabeled (bool, optional): If True, iteration yields (image, case_id) tuples without
                annotations. Defaults to False.
            case_id_prefix (str, optional): Prefix used to build the case ids of the images.
                Defaults to "train".
        """
        # a multiprocessing manager provides process-shared containers so that
        # DataLoader worker processes see the same paths, caches and slice indices
        self.manager = Manager()
        self.image_paths = self.manager.list(image_paths)
        self.annotation_paths = self.manager.list(annotation_paths)
        self.combine_foreground_classes = combine_foreground_classes
        self.mask_filter_values = mask_filter_values
        self.only_return_true_labels = only_return_true_labels
        # images and annotations must correspond pairwise
        assert len(image_paths) == len(annotation_paths)
        self.is_unlabeled = is_unlabeled
        # the most recently loaded image/mask pair acts as a one-element read buffer
        self._current_image = None
        self._current_mask = None
        self._currently_loaded_image_index = None
        self.cache_size = cache_size
        # since the PyTorch dataloader uses multiple processes for data loading (if num_workers > 0),
        # shared dicts are used so that all worker processes share one image/mask cache
        # see https://github.com/ptrblck/pytorch_misc/blob/master/shared_dict.py and
        # https://discuss.pytorch.org/t/reuse-of-dataloader-worker-process-and-caching-in-dataloader/30620/14
        # for more information
        self.image_cache = self.manager.dict()
        self.mask_cache = self.manager.dict()
        self.shuffle = shuffle
        self.transform = transform
        self.target_transform = target_transform
        self.dim = dim
        self.case_id_prefix = case_id_prefix
        # nested mapping image_index -> {slice_index: pseudo label or None}
        self.image_slice_indices = self.__arange_image_slice_indices(
            filepaths=self.image_paths,
            dim=self.dim,
            shuffle=self.shuffle,
            random_state=random_state,
            slice_indices=slice_indices,
        )
        # iteration state; overwritten per worker in __iter__
        self.num_workers = 1
        self.current_image_key_index = 0
        self.current_slice_key_index = 0
def __get_case_id(self, image_index: int):
return f"{self.case_id_prefix}_{image_index}"
def __get_image_index(self, case_id: str):
return int(case_id.replace(f"{self.case_id_prefix}_", ""))
def __iter__(self):
"""
Returns:
Iterator: Iterator that yields the whole dataset if a single process is used for data loading
or a subset of the dataset if the dataloading is split across multiple worker processes.
"""
worker_info = torch.utils.data.get_worker_info()
# check whether data loading is split across multiple workers
if worker_info is not None:
self.num_workers = worker_info.num_workers
self.current_image_key_index = worker_info.id
self.current_slice_key_index = 0
else:
self.num_workers = 1
self.current_image_key_index = 0
self.current_slice_key_index = 0
return self
[docs] def read_mask_for_image(self, image_index: int) -> np.array:
"""
Reads the mask for the image from file. Uses correct mask specific parameters.
Args:
image_index (int): Index of the image to load.
"""
return self.__read_image_as_array(
self.annotation_paths[image_index],
norm=False,
join_non_zero=self.combine_foreground_classes,
filter_values=self.mask_filter_values,
)
def __load_image_and_mask(self, image_index: int) -> None:
"""
Loads image with the given index either from cache or from disk.
Args:
image_index (int): Index of the image to load.
"""
self._currently_loaded_image_index = image_index
# check if image and mask are in cache
if image_index in self.image_cache and image_index in self.mask_cache:
self._current_image = self.image_cache[image_index]
self._current_mask = self.mask_cache[image_index]
# read image and mask from disk otherwise
else:
self._current_image = self.__read_image_as_array(
self.image_paths[image_index], norm=True
)
self._current_mask = self.read_mask_for_image(image_index)
# cache image and mask if there is still space in cache
if len(self.image_cache.keys()) < self.cache_size:
self.image_cache[image_index] = self._current_image
self.mask_cache[image_index] = self._current_mask
# pylint: disable=too-many-branches
def __next__(
self,
) -> Union[
Tuple[torch.Tensor, torch.Tensor, str], Tuple[torch.Tensor, torch.Tensor]
]:
if self.current_image_key_index >= len(self.image_slice_indices):
raise StopIteration
image_index = list(self.image_slice_indices.keys())[
self.current_image_key_index
]
if image_index != self._currently_loaded_image_index:
self.__load_image_and_mask(image_index)
case_id = self.__get_case_id(image_index)
if self.dim == 2:
slice_index = list(self.image_slice_indices[image_index].keys())[
self.current_slice_key_index
]
case_id = f"{case_id}-{slice_index}"
if len(self._current_image.shape) == 4:
x = torch.from_numpy(self._current_image[:, slice_index, :, :])
else:
x = torch.from_numpy(self._current_image[slice_index, :, :])
pseudo_label = self.image_slice_indices[image_index][slice_index]
is_pseudo_label = pseudo_label is not None
y = (
torch.from_numpy(pseudo_label).int()
if is_pseudo_label
else torch.from_numpy(self._current_mask[slice_index, :, :]).int()
)
self.current_slice_key_index += 1
if self.current_slice_key_index >= len(
self.image_slice_indices[image_index]
):
self.current_image_key_index += self.num_workers
self.current_slice_key_index = 0
if is_pseudo_label and self.only_return_true_labels:
return self.__next__()
else:
x = torch.from_numpy(self._current_image)
y = torch.from_numpy(self._current_mask).int()
for slice_id in range(len(y)):
if slice_id not in self.image_slice_indices[image_index]:
y[slice_id, :, :] = -1
elif self.image_slice_indices[image_index][slice_id] is not None:
y[slice_id, :, :] = torch.from_numpy(
self.image_slice_indices[image_index][slice_id]
).int()
self.current_image_key_index += self.num_workers
if self.transform:
x = self.transform(x)
if self.target_transform:
y = self.target_transform(y)
x = DoublyShuffledNIfTIDataset.__ensure_channel_dim(x, self.dim)
if self.is_unlabeled:
return (x, case_id)
return (x, y, is_pseudo_label, case_id)
[docs] def add_image(
self,
image_id: str,
slice_index: int = 0,
pseudo_label: Optional[np.array] = None,
) -> None:
"""
Adds an image to this dataset.
Args:
image_id (str): The id of the image.
slice_index (int): Index of the slice to be added.
pseudo_label (np.array, optional): An optional pseudo label for the slice. If no pseudo label is provided,
the actual label from the corresponding file is used.
"""
image_index = self.__get_image_index(image_id)
if (
image_index in self.image_slice_indices
and slice_index in self.image_slice_indices[image_index]
and self.image_slice_indices[image_index][slice_index] is None
):
if pseudo_label is not None:
# If a pseudo label is added even though the real label already exists it should be ignored
return
raise ValueError("Slice of image already belongs to this dataset.")
if image_index not in self.image_slice_indices:
self.image_slice_indices[image_index] = self.manager.dict()
self.image_slice_indices[image_index][slice_index] = pseudo_label
[docs] def remove_image(self, image_id: str, slice_index: int = 0) -> None:
"""
Removes an image from this dataset.
Args:
image_id (str): The id of the image.
slice_index (int): Index of the slice to be removed.
"""
image_index = self.__get_image_index(image_id)
if (
image_index in self.image_slice_indices
and slice_index in self.image_slice_indices[image_index]
):
del self.image_slice_indices[image_index][slice_index]
if len(self.image_slice_indices[image_index]) == 0:
self.image_slice_indices.pop(image_index)
[docs] def get_images_by_id(
self,
case_ids: List[str],
) -> List[Tuple[np.ndarray, str]]:
"""
Retrieves the last n images and corresponding case ids from the images that were last added to the dataset.
Args:
case_ids (List[str]): List with case_ids to get.
Returns:
A list of all the images with provided case ids.
"""
# create list of files as tuple of image id and slice index
image_slice_ids = [case_id.split("-") for case_id in case_ids]
image_slice_ids = [
(split_id[0], int(split_id[1]) if len(split_id) > 1 else None)
for split_id in image_slice_ids
]
all_images = []
for case_id, (image_id, slice_index) in zip(case_ids, image_slice_ids):
image_index = self.__get_image_index(image_id)
# check if image and mask are in cache
if image_index in self.image_cache:
current_image = self.image_cache[image_index]
# read image and mask from disk otherwise
else:
current_image = self.__read_image_as_array(
self.image_paths[image_index], norm=True
)
all_images.append((current_image[slice_index, :, :], case_id))
return all_images
[docs] def get_items_for_logging(
self, case_ids: List[str]
) -> List[Tuple[str, str, Optional[int], str]]:
"""
Creates a list of files as tuple of image id and slice index.
Args:
case_ids (List[str]): List with case_ids to get.
"""
image_slice_ids = [case_id.split("-") for case_id in case_ids]
image_slice_ids = [
(split_id[0], int(split_id[1]) if len(split_id) > 1 else None)
for split_id in image_slice_ids
]
items = []
for case_id, (image_id, slice_index) in zip(case_ids, image_slice_ids):
image_index = self.__get_image_index(image_id)
image_path = self.image_paths[image_index]
items.append((case_id, image_path, image_id, slice_index))
return items
def __image_indices(self) -> Iterable[str]:
return self.image_slice_indices.keys()
[docs] def image_ids(self) -> Iterable[str]:
return [self.__get_case_id(image_idx) for image_idx in self.__image_indices()]
    def slices_per_image(self, **kwargs) -> List[int]:
        """
        Returns:
            List[int]: The slice count (third shape dimension) of every image in the dataset.
        """
        # NOTE(review): __read_slice_count is called with its default dim=2 here, so this
        # returns raw slice counts even when self.dim == 3 (where __read_slice_count
        # would return 1 per image) — confirm this is intended.
        return [
            DoublyShuffledNIfTIDataset.__read_slice_count(self.image_paths[image_idx])
            for image_idx in self.__image_indices()
        ]
[docs] def size(self) -> int:
"""
Returns:
int: Size of the dataset.
"""
if self.dim == 2:
size = 0
for inner_dict in self.image_slice_indices.values():
if self.only_return_true_labels:
for value in inner_dict.values():
if value is None:
size += 1
else:
size += len(inner_dict)
return size
return len(self.image_ids())
[docs] def num_pseudo_labels(self) -> int:
"""
Returns:
int: Number of items with pseudo-labels in the dataset.
"""
if self.only_return_true_labels:
return 0
if self.dim == 2:
num_pseudo_labels = 0
for inner_dict in self.image_slice_indices.values():
for value in inner_dict.values():
if value is not None:
num_pseudo_labels += 1
return num_pseudo_labels
return 0