""" Module containing the data module for brats data """
import os
import random
from typing import Dict, List, Optional, Tuple
import numpy as np
from torch.utils.data import DataLoader, Dataset
from .data_module import ActiveLearningDataModule
from .doubly_shuffled_nifti_dataset import DoublyShuffledNIfTIDataset
class BraTSDataModule(ActiveLearningDataModule):
    """
    Initializes the BraTS data module.

    Args:
        data_dir (string): Path of the directory that contains the data.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for DataLoader.
        active_learning_mode (bool, optional): Whether the datamodule should be configured for active learning or for
            conventional model training (default = False).
        batch_size_unlabeled_set (int, optional): Batch size for the unlabeled set. Defaults to :attr:`batch_size`.
        cache_size (int, optional): Number of images to keep in memory between epochs to speed-up data loading
            (default = 0).
        initial_training_set_size (int, optional): Initial size of the training set if the active learning mode is
            activated.
        pin_memory (bool, optional): `pin_memory` parameter as defined by the PyTorch `DataLoader` class.
        shuffle (boolean): Flag if the data should be shuffled.
        dim (int): 2 or 3 to define if the datasets should return 2d slices or whole 3d images.
        combine_foreground_classes (bool, optional): Flag if the non zero values of the annotations should be merged.
            (default = False)
        mask_filter_values (Tuple[int], optional): Values from the annotations which should be used. Defaults to using
            all values.
        random_state (int, optional): Random state for splitting the data into an initial training set and an unlabeled
            set and for shuffling the data. Pass an int for reproducibility across runs.
        only_return_true_labels (bool, optional): Whether only true labels or also pseudo-labels are to be returned.
            Defaults to `False`.
        **kwargs: Further, dataset specific parameters.
    """

    # pylint: disable=unused-argument,no-self-use
    @staticmethod
    def discover_paths(
        dir_path: str,
        modality: str = "flair",
        random_samples: Optional[int] = None,
    ) -> Tuple[List[str], List[str]]:
        """
        Discover the ``.nii.gz`` file paths with a given modality.

        Every non-hidden sub-directory of ``dir_path`` is treated as one case that is expected to contain
        ``<case>_<modality>.nii.gz`` and ``<case>_seg.nii.gz``. The paths are only constructed here; the files
        themselves are not checked for existence.

        Args:
            dir_path: directory to discover paths in
            modality (string, optional): modality of scan
            random_samples: the amount of random samples from the data sets. If `None` or not smaller than the
                number of discovered cases, all cases are returned in sorted order.

        Returns:
            list of files as tuple of image paths, annotation paths
        """

        cases = [
            case
            for case in sorted(os.listdir(dir_path))
            if not case.startswith(".") and os.path.isdir(os.path.join(dir_path, case))
        ]

        if random_samples is not None and random_samples < len(cases):
            # A dedicated generator with the fixed seed yields the same reproducible subset as seeding the
            # module-level generator would, but without clobbering the global random state of the process.
            cases = random.Random(42).sample(cases, random_samples)

        # `cases` holds bare directory names, so they can be used directly as file name prefixes
        image_paths = [
            os.path.join(dir_path, case, f"{case}_{modality}.nii.gz") for case in cases
        ]
        annotation_paths = [
            os.path.join(dir_path, case, f"{case}_seg.nii.gz") for case in cases
        ]

        return image_paths, annotation_paths

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        num_workers: int,
        active_learning_mode: bool = False,
        batch_size_unlabeled_set: Optional[int] = None,
        cache_size: int = 0,
        initial_training_set_size: int = 1,
        pin_memory: bool = True,
        shuffle: bool = True,
        dim: int = 2,
        combine_foreground_classes: bool = False,
        mask_filter_values: Optional[Tuple[int, ...]] = None,
        random_state: Optional[int] = None,
        only_return_true_labels: bool = False,
        **kwargs,
    ):
        super().__init__(
            data_dir,
            batch_size,
            num_workers,
            active_learning_mode=active_learning_mode,
            batch_size_unlabeled_set=batch_size_unlabeled_set,
            initial_training_set_size=initial_training_set_size,
            pin_memory=pin_memory,
            shuffle=shuffle,
            **kwargs,
        )

        self.data_folder = self.data_dir
        self.dim = dim
        self.cache_size = cache_size
        self.combine_foreground_classes = combine_foreground_classes
        self.mask_filter_values = mask_filter_values
        self.random_state = random_state
        self.only_return_true_labels = only_return_true_labels

        if self.active_learning_mode:
            # Split the training images once into the initially labeled samples and the unlabeled pool;
            # both results are later handed to the datasets as their `slice_indices`.
            train_image_paths, _ = BraTSDataModule.discover_paths(
                os.path.join(self.data_folder, "train")
            )
            (
                self.initial_training_samples,
                self.initial_unlabeled_samples,
            ) = DoublyShuffledNIfTIDataset.generate_active_learning_split(
                train_image_paths,
                dim,
                initial_training_set_size,
                random_state,
            )
        else:
            self.initial_training_samples = None
            self.initial_unlabeled_samples = None

    def multi_label(self) -> bool:
        """
        Returns:
            bool: Whether the dataset is a multi-label or a single-label dataset.
        """

        return False

    def id_to_class_names(self) -> Dict[int, str]:
        """
        Returns:
            Dict[int, str]: A mapping of class indices to descriptive class names. If the foreground classes are
                combined, only "background" and "tumor" are returned. If mask filter values are set, only the
                background class and the classes contained in the filter values are returned.
        """

        if self.combine_foreground_classes:
            return {0: "background", 1: "tumor"}

        labels = {
            0: "background",
            1: "non-enhancing tumor core",
            2: "peritumoral edema",
            4: "GD-enhancing tumor",
        }

        if self.mask_filter_values is not None:
            # the background class is always kept, regardless of the filter
            return {
                class_id: class_name
                for class_id, class_name in labels.items()
                if class_id in self.mask_filter_values or class_id == 0
            }

        return labels

    def label_items(
        self, ids: List[str], pseudo_labels: Optional[Dict[str, np.ndarray]] = None
    ) -> None:
        """
        Moves the given samples from the unlabeled dataset to the labeled dataset.

        Nothing happens if the training set or the unlabeled set has not been created.

        Args:
            ids (List[str]): Case IDs of the form ``"<image id>-<slice index>"``; the slice index part is optional.
            pseudo_labels (Dict[str, np.ndarray], optional): Pseudo-labels keyed by case ID. Samples that have an
                entry here are added to the training set together with their pseudo-label.
        """

        # create list of files as tuple of image id and slice index
        image_slice_ids = [case_id.split("-") for case_id in ids]
        image_slice_ids = [
            (split_id[0], int(split_id[1]) if len(split_id) > 1 else None)
            for split_id in image_slice_ids
        ]

        if self.training_set is not None and self.unlabeled_set is not None:
            for case_id, (image_id, slice_id) in zip(ids, image_slice_ids):
                # whole 3d images carry no slice index in their ID, fall back to index 0
                if self.dim == 3 and slice_id is None:
                    slice_id = 0

                if pseudo_labels is not None and case_id in pseudo_labels:
                    self.training_set.add_image(
                        image_id, slice_id, pseudo_labels[case_id]
                    )
                else:
                    self.training_set.add_image(image_id, slice_id)
                    # NOTE(review): only samples without a pseudo-label are removed from the unlabeled set here,
                    # so pseudo-labeled samples presumably stay available for true labeling later. The nesting
                    # of this call was reconstructed from flattened source - confirm against the upstream
                    # repository.
                    self.unlabeled_set.remove_image(image_id, slice_id)

    def _create_training_set(self) -> Optional[Dataset]:
        """
        Creates a training dataset from the ``train`` sub-directory of the data folder.
        """

        train_image_paths, train_annotation_paths = BraTSDataModule.discover_paths(
            os.path.join(self.data_folder, "train")
        )

        return DoublyShuffledNIfTIDataset(
            image_paths=train_image_paths,
            annotation_paths=train_annotation_paths,
            dim=self.dim,
            cache_size=self.cache_size,
            shuffle=self.shuffle,
            combine_foreground_classes=self.combine_foreground_classes,
            mask_filter_values=self.mask_filter_values,
            slice_indices=self.initial_training_samples,
            random_state=self.random_state,
            only_return_true_labels=self.only_return_true_labels,
        )

    def train_dataloader(self) -> Optional[DataLoader]:
        """
        Returns:
            PyTorch dataloader representing the training set, or `None` if there is no (non-empty) training set.
        """

        # Shuffling is deliberately not enabled in the dataloader: the dataset receives the `shuffle` flag on
        # construction and implements its own shuffling (according to the original author it is a subclass of
        # IterableDataset).
        if self.training_set:
            return DataLoader(
                self.training_set,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                collate_fn=self._get_collate_fn(),
            )
        return None

    def _create_validation_set(self) -> Optional[Dataset]:
        """Creates a validation dataset from the ``val`` sub-directory of the data folder."""

        val_image_paths, val_annotation_paths = BraTSDataModule.discover_paths(
            os.path.join(self.data_folder, "val")
        )

        return DoublyShuffledNIfTIDataset(
            image_paths=val_image_paths,
            annotation_paths=val_annotation_paths,
            dim=self.dim,
            cache_size=self.cache_size,
            combine_foreground_classes=self.combine_foreground_classes,
            mask_filter_values=self.mask_filter_values,
            case_id_prefix="val",
            random_state=self.random_state,
            only_return_true_labels=self.only_return_true_labels,
        )

    def _create_test_set(self) -> Optional[Dataset]:
        """Creates a test dataset. Currently this is a second instance of the validation dataset."""

        # faked test set
        # TODO: implement test set
        return self._create_validation_set()

    def _create_unlabeled_set(self) -> Optional[Dataset]:
        """
        Creates an unlabeled dataset.

        In active learning mode the dataset is built from the ``train`` sub-directory and restricted to the
        initial unlabeled samples. Otherwise a dataset built like the training set is returned with its
        ``is_unlabeled`` flag set.
        """

        if self.active_learning_mode:
            train_image_paths, train_annotation_paths = BraTSDataModule.discover_paths(
                os.path.join(self.data_folder, "train")
            )

            return DoublyShuffledNIfTIDataset(
                image_paths=train_image_paths,
                annotation_paths=train_annotation_paths,
                dim=self.dim,
                cache_size=self.cache_size,
                is_unlabeled=True,
                shuffle=self.shuffle,
                combine_foreground_classes=self.combine_foreground_classes,
                mask_filter_values=self.mask_filter_values,
                slice_indices=self.initial_unlabeled_samples,
                random_state=self.random_state,
                only_return_true_labels=self.only_return_true_labels,
            )

        # unlabeled set is empty
        # NOTE(review): the comment above is the original author's; the code itself re-uses a freshly created
        # training dataset and only flips its `is_unlabeled` flag - confirm the intended semantics.
        unlabeled_set = self._create_training_set()
        unlabeled_set.is_unlabeled = True
        return unlabeled_set