Source code for datasets.bcss_dataset
""" Module to load and batch the BCSS dataset """
from typing import List, Optional, Tuple, Union
from pathlib import Path
import math
from multiprocessing import Manager
from PIL import Image, ImageOps
import numpy as np
import torch
from torch.utils.data import IterableDataset
from .dataset_hooks import DatasetHooks
class BCSSDataset(IterableDataset, DatasetHooks):
    """
    The BCSS dataset contains over 20,000 segmentation annotations of tissue regions from breast cancer images
    from TCGA. A detailed description can be found either `at the challenge website
    <https://bcsegmentation.grand-challenge.org>`_ or on `GitHub <https://github.com/PathologyDataScience/BCSS>`_ .

    Args:
        image_paths (List[Path]): List with all images to load, can be obtained by
            :py:meth:`datasets.bcss_data_module.BCSSDataModule.discover_paths` .
        annotation_paths (List[Path]): List with all annotations to load, can be obtained by
            :py:meth:`datasets.bcss_data_module.BCSSDataModule.discover_paths` .
        cache_size (int, optional): Maximum number of images and masks to keep in the in-memory cache.
        target_label (int, optional): The label to use for learning. The following labels are in the annotations:

            * outside_roi 0
            * tumor 1
            * stroma 2
            * lymphocytic_infiltrate 3
            * necrosis_or_debris 4
            * glandular_secretions 5
            * blood 6
            * exclude 7
            * metaplasia_NOS 8
            * fat 9
            * plasma_cells 10
            * other_immune_infiltrate 11
            * mucoid_material 12
            * normal_acinus_or_duct 13
            * lymphatics 14
            * undetermined 15
            * nerve 16
            * skin_adnexa 17
            * blood_vessel 18
            * angioinvasion 19
            * dcis 20
            * other 21
        is_unlabeled (bool, optional): Whether the dataset is used as "unlabeled" for the active learning loop.
        shuffle (bool, optional): Whether the data should be shuffled.
        channels (int, optional): Number of channels of the images. 3 means RGB, 2 means greyscale.
        image_shape (tuple, optional): Shape of the image.
        random_state (int, optional): Controls the data shuffling. Pass an int for reproducible output across
            multiple runs.
    """
    # pylint: disable=too-many-instance-attributes,abstract-method

    @staticmethod
    def normalize(img: np.ndarray) -> np.ndarray:
        """
        Normalizes an image by:

        #. Subtracting the mean value
        #. Dividing by the standard deviation

        Args:
            img: The input image that should be normalized.

        Returns:
            Normalized image with background values normalized to -1
        """
        return (img - np.mean(img)) / np.std(img)
    @staticmethod
    def __align_axis(img: np.ndarray) -> np.ndarray:
        """Align the axes of the image based on the dimension"""
        if len(img.shape) == 2:
            img = np.expand_dims(img, axis=0)
        if img.shape[2] == 3:
            img = np.moveaxis(img, 2, 0)
        return img
    @staticmethod
    def get_case_id(filepath: Union[str, Path]) -> str:
        """Gets the case ID for a given filepath."""
        return Path(filepath).name.split("_")[0]

    @staticmethod
    def get_institute_name(filepath: Union[str, Path]) -> str:
        """Gets the name of the institute which donated the image."""
        return Path(filepath).name.split("-")[1]
    def __init__(
        self,
        image_paths: List[Path],
        annotation_paths: List[Path],
        cache_size: int = 0,
        target_label: int = 1,
        is_unlabeled: bool = False,
        shuffle: bool = True,
        channels: int = 3,
        image_shape: tuple = (300, 300),
        random_state: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.image_paths = image_paths
        self.annotation_paths = annotation_paths
        self.target_label = target_label
        self.channels = channels
        self.image_shape = tuple(image_shape)
        self.cache_size = cache_size

        manager = Manager()
        # since the PyTorch dataloader uses multiple processes for data loading (if num_workers > 0),
        # a shared dict is used to share the cache between all worker processes
        # see https://github.com/ptrblck/pytorch_misc/blob/master/shared_dict.py and
        # https://discuss.pytorch.org/t/reuse-of-dataloader-worker-process-and-caching-in-dataloader/30620/14
        # for more information
        self.image_cache = manager.dict()
        self.mask_cache = manager.dict()

        self.num_images = len(self.image_paths)
        self.num_masks = len(self.annotation_paths)
        assert self.num_images == self.num_masks

        self.is_unlabeled = is_unlabeled

        self._current_image = None
        self._current_mask = None
        self._current_image_index = None

        self.indices = list(np.arange(self.num_images))
        if shuffle:
            rng = np.random.default_rng(random_state)
            rng.shuffle(self.indices)

        self.start_index = 0
        self.end_index = self.__len__()
        self.current_index = 0
    def __load_image_and_mask(self, index: int) -> None:
        """Loads images and annotations into _current_image and _current_mask variables."""
        self._current_image_index = index

        # check if image and mask are in cache
        if index in self.image_cache and index in self.mask_cache:
            self._current_image = self.image_cache[index]
            self._current_mask = self.mask_cache[index]
        # read image and mask from disk otherwise
        else:
            self._current_image = self.__load_image_as_array(
                filepath=self.image_paths[self._current_image_index].as_posix(),
                norm=True,
                is_mask=False,
            )
            self._current_mask = self.__load_image_as_array(
                filepath=self.annotation_paths[self._current_image_index].as_posix(),
                norm=False,
                is_mask=True,
            )

            # cache image and mask if there is still space in the cache
            if len(self.image_cache.keys()) < self.cache_size:
                self.image_cache[index] = self._current_image
                self.mask_cache[index] = self._current_mask
    def __load_image_as_array(
        self, filepath: str, norm: bool = True, is_mask: bool = False
    ) -> np.ndarray:
        """Loads one image into memory."""
        img = Image.open(filepath).resize((self.image_shape[0], self.image_shape[1]))
        if self.channels == 2:
            img = ImageOps.grayscale(img)
        img = np.asarray(img)
        if norm:
            img = BCSSDataset.normalize(img=img)
        if is_mask:
            img = self.__restrict_on_target_class(img=img)
        img = self.__align_axis(img)
        return img

    def __restrict_on_target_class(self, img: np.ndarray) -> np.ndarray:
        """Only keeps the set target class and sets the rest of the image to the background value 0."""
        img[np.where(img != self.target_label)] = 0
        return img
    def __iter__(self):
        """
        Returns:
            Iterator: Iterator that yields the whole dataset if a single process is used for data loading
            or a subset of the dataset if the data loading is split across multiple worker processes.
        """
        worker_info = torch.utils.data.get_worker_info()

        # check whether data loading is split across multiple workers
        if worker_info is not None:
            # code adapted from https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
            per_worker = int(
                math.ceil(
                    (self.end_index - self.start_index) / float(worker_info.num_workers)
                )
            )
            worker_id = worker_info.id
            self.start_index = self.start_index + worker_id * per_worker
            self.current_index = self.start_index
            self.end_index = min(self.start_index + per_worker, self.end_index)

        return self
    def __next__(
        self,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, bool, str], Tuple[torch.Tensor, str]]:
        """One iteration yields a tuple of image, annotation, a flag, and case ID
        (or a tuple of image and case ID if the dataset is used as unlabeled)."""
        if self.current_index >= self.__len__():
            raise StopIteration

        index = self.indices[self.current_index]
        case_id = self.get_case_id(filepath=self.image_paths[index].as_posix())

        self.__load_image_and_mask(index=index)
        x = torch.from_numpy(self._current_image)
        y = torch.from_numpy(self._current_mask)

        self.current_index += 1

        if self.is_unlabeled:
            return x, case_id
        return x, y, False, case_id

    def __len__(self) -> int:
        """Returns the length of the dataset."""
        return len(self.indices)
    def add_image(self, image_path: Path, annotation_path: Path) -> None:
        """
        Adds an image to this dataset.

        Args:
            image_path: Path of the image to be added.
            annotation_path: Path of the annotation of the image to be added.

        Returns:
            None. Raises ValueError if image already exists.
        """
        if image_path not in self.image_paths:
            self.image_paths.append(image_path)
        if annotation_path not in self.annotation_paths:
            self.annotation_paths.append(annotation_path)
        image_index = self.image_paths.index(image_path)
        if image_index not in self.indices:
            # add new image index to existing ones
            self.indices.append(image_index)
        else:
            raise ValueError("The image already belongs to this dataset.")
    def remove_image(self, image_path: Path, annotation_path: Path) -> None:
        """
        Removes an image from this dataset.

        Args:
            image_path: Path of the image to be removed.
            annotation_path: Path of the annotation of the image to be removed.

        Returns:
            None. Raises ValueError if the image does not belong to this dataset.
        """
        if image_path in self.image_paths and annotation_path in self.annotation_paths:
            self.image_paths.remove(image_path)
            self.annotation_paths.remove(annotation_path)
            self.num_images -= 1
            self.indices = list(np.arange(self.num_images))
        else:
            raise ValueError("Image does not belong to this dataset.")
    def slices_per_image(self, **kwargs) -> List[int]:
        """Returns the number of slices for each image."""
        return [1] * len(self.indices)

    def image_ids(self) -> List[str]:
        """Returns the case ID for each image."""
        return [self.get_case_id(filepath=path) for path in self.image_paths]

    def size(self) -> int:
        """
        Returns:
            int: Size of the dataset.
        """
        return self.__len__()

    def num_pseudo_labels(self) -> int:
        """
        Returns:
            int: Number of items with pseudo-labels in the dataset.
        """
        return 0
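
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how the dataset might be
# constructed and iterated directly. The data directory and file pattern below are
# hypothetical; in the project, image and annotation paths are typically discovered
# via datasets.bcss_data_module.BCSSDataModule.discover_paths. Because of the
# relative import above, run this with `python -m datasets.bcss_dataset`.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    data_dir = Path("/path/to/bcss")  # hypothetical location of images and masks
    example_image_paths = sorted((data_dir / "images").glob("*.png"))
    example_annotation_paths = sorted((data_dir / "masks").glob("*.png"))

    dataset = BCSSDataset(
        image_paths=example_image_paths,
        annotation_paths=example_annotation_paths,
        target_label=1,  # restrict the masks to the "tumor" class
        cache_size=16,   # keep up to 16 image/mask pairs in the shared cache
        shuffle=True,
        random_state=42,
    )

    # BCSSDataset is an IterableDataset, so the DataLoader consumes it via __iter__/__next__;
    # num_workers=0 keeps loading in the main process.
    loader = DataLoader(dataset, batch_size=4, num_workers=0)
    for image, mask, _, case_id in loader:
        print(case_id, image.shape, mask.shape)
        break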