Source code for datasets.data_module
""" Module containing abstract classes for the data modules"""
from abc import ABC, abstractmethod
import warnings
from typing import Any, Callable, Dict, List, Optional
from pytorch_lightning.core.datamodule import LightningDataModule
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings(
"ignore",
".*DataModule.setup has already been called, so it will not be called again.*",
)


class ActiveLearningDataModule(LightningDataModule, ABC):
    """
    Abstract base class that structures dataset creation for active learning.

    Args:
        data_dir (str): Path of the directory that contains the data.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for the data loaders.
        active_learning_mode (bool, optional): Whether the data module should be configured for active learning or
            for conventional model training. Defaults to `False`.
        batch_size_unlabeled_set (int, optional): Batch size for the unlabeled set. Defaults to `batch_size`.
        initial_training_set_size (int, optional): Initial size of the training set if the active learning mode is
            activated. Defaults to `1`.
        pin_memory (bool, optional): `pin_memory` parameter as defined by the PyTorch `DataLoader` class.
        shuffle (bool, optional): Whether the training data should be shuffled.
        **kwargs: Further dataset-specific parameters.
    """

    # pylint: disable=too-many-instance-attributes, too-many-arguments
    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        num_workers: int,
        active_learning_mode: bool = False,
        batch_size_unlabeled_set: Optional[int] = None,
        initial_training_set_size: int = 1,
        pin_memory: bool = True,
        shuffle: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.batch_size_unlabeled_set = (
            batch_size_unlabeled_set
            if batch_size_unlabeled_set is not None
            else batch_size
        )
        self.num_workers = num_workers
        self.active_learning_mode = active_learning_mode
        self.initial_training_set_size = initial_training_set_size
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.training_set = None
        self.validation_set = None
        self.test_set = None
        self.unlabeled_set = None
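
    # Usage sketch (hypothetical): this class is abstract, so it is used through a
    # concrete subclass. Assuming a subclass named `MyDataModule` exists (see the
    # illustrative sketch at the end of this module), a typical setup looks like:
    #
    #     data_module = MyDataModule(
    #         data_dir="./data", batch_size=16, num_workers=4, active_learning_mode=True
    #     )
    #     data_module.setup()
    #     train_loader = data_module.train_dataloader()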

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Creates the datasets managed by this data module.

        Args:
            stage: Current training stage.
        """

        self.training_set = self._create_training_set()
        self.validation_set = self._create_validation_set()
        self.test_set = self._create_test_set()
        self.unlabeled_set = self._create_unlabeled_set()

    @staticmethod
    def data_channels() -> int:
        """
        Can be overwritten by subclasses if the data has multiple channels.

        Returns:
            The number of data channels. Defaults to 1.
        """

        return 1

    @staticmethod
    def _get_collate_fn() -> Optional[Callable[[List[Any]], Any]]:
        """
        Can be overwritten by subclasses to pass a custom collate function to the dataloaders.

        Returns:
            A `Callable[[List[Any]], Any]` that combines samples into batches. Defaults to `None`.
        """

        return None
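
    # Override sketch (hypothetical): a subclass whose samples are `(image, case_id)`
    # pairs of varying size could return a collate function that keeps them as lists
    # instead of stacking them. The pair structure is an assumption for illustration:
    #
    #     @staticmethod
    #     def _get_collate_fn() -> Optional[Callable[[List[Any]], Any]]:
    #         def collate(batch: List[Any]) -> Any:
    #             images, case_ids = zip(*batch)
    #             return list(images), list(case_ids)
    #
    #         return collate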

    @abstractmethod
    def multi_label(self) -> bool:
        """
        Returns:
            bool: Whether the dataset is a multi-label or a single-label dataset.
        """

    @abstractmethod
    def id_to_class_names(self) -> Dict[int, str]:
        """
        Returns:
            Dict[int, str]: A mapping of class indices to descriptive class names.
        """

    @abstractmethod
    def _create_training_set(self) -> Optional[Dataset]:
        """
        Returns:
            PyTorch dataset representing the training set, or `None` if there is none.
        """

    @abstractmethod
    def _create_validation_set(self) -> Optional[Dataset]:
        """
        Returns:
            PyTorch dataset representing the validation set, or `None` if there is none.
        """

    @abstractmethod
    def _create_test_set(self) -> Optional[Dataset]:
        """
        Returns:
            PyTorch dataset representing the test set, or `None` if there is none.
        """

    @abstractmethod
    def _create_unlabeled_set(self) -> Optional[Dataset]:
        """
        Returns:
            PyTorch dataset representing the unlabeled set, or `None` if there is none.
        """

    @abstractmethod
    def label_items(
        self, ids: List[str], pseudo_labels: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Moves data items from the unlabeled set to one of the labeled sets (training, validation or test set).

        Args:
            ids (List[str]): IDs of the items to be labeled.
            pseudo_labels (Dict[str, Any], optional): Optional pseudo-labels for (some of) the selected data items.

        Returns:
            None.
        """

    def train_dataloader(self) -> Optional[DataLoader]:
        """
        Returns:
            PyTorch dataloader representing the training set, or `None` if there is no training set.
        """

        if self.training_set:
            return DataLoader(
                self.training_set,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                collate_fn=self._get_collate_fn(),
            )
        return None

    def val_dataloader(self) -> Optional[DataLoader]:
        """
        Returns:
            PyTorch dataloader representing the validation set, or `None` if there is no validation set.
        """

        if self.validation_set:
            return DataLoader(
                self.validation_set,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                collate_fn=self._get_collate_fn(),
            )
        return None

    def test_dataloader(self) -> Optional[DataLoader]:
        """
        Returns:
            PyTorch dataloader representing the test set, or `None` if there is no test set.
        """

        if self.test_set:
            return DataLoader(
                self.test_set,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                collate_fn=self._get_collate_fn(),
            )
        return None

    def unlabeled_dataloader(self) -> Optional[DataLoader]:
        """
        Returns:
            PyTorch dataloader representing the unlabeled set, or `None` if there is no unlabeled set.
        """

        if self.unlabeled_set:
            return DataLoader(
                self.unlabeled_set,
                batch_size=self.batch_size_unlabeled_set,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                collate_fn=self._get_collate_fn(),
            )
        return None
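
    # Query-loop sketch (hypothetical): in active-learning mode, the unlabeled dataloader
    # is typically consumed by a query strategy that scores items and then calls
    # `label_items` with the selected IDs:
    #
    #     unlabeled_loader = data_module.unlabeled_dataloader()
    #     if unlabeled_loader is not None:
    #         for batch, case_ids in unlabeled_loader:
    #             scores = query_strategy.score(model, batch)
    #
    # The `(batch, case_ids)` unpacking and the `query_strategy`/`model` objects are
    # illustrative assumptions, not part of this abstract interface.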

    def training_set_size(self) -> int:
        """
        Returns:
            Size of training set.
        """

        if self.training_set:
            return self.training_set.size()
        return 0

    def training_set_num_pseudo_labels(self) -> int:
        """
        Returns:
            Number of pseudo-labels in training set.
        """

        if self.training_set:
            return self.training_set.num_pseudo_labels()
        return 0

    def validation_set_size(self) -> int:
        """
        Returns:
            Size of validation set.
        """

        if self.validation_set:
            return self.validation_set.size()
        return 0

    def test_set_size(self) -> int:
        """
        Returns:
            Size of test set.
        """

        if self.test_set:
            return self.test_set.size()
        return 0

    def unlabeled_set_size(self) -> int:
        """
        Returns:
            Number of unlabeled items.
        """

        if self.unlabeled_set:
            return self.unlabeled_set.size()
        return 0

    def num_classes(self) -> int:
        """
        Returns:
            Number of classes.
        """

        return len(self.id_to_class_names())
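

# Illustrative sketch (not part of the library): a minimal concrete subclass showing how
# the abstract hooks above fit together. The `_ListDataset` helper, the random tensors,
# and the case-ID scheme are assumptions made only for this example.
import torch  # needed only for the illustrative example below


class _ExampleDataModule(ActiveLearningDataModule):
    """Toy data module that serves random image tensors, for illustration only."""

    class _ListDataset(Dataset):
        """Wraps a list of `(case_id, image)` pairs and exposes the `size`/`num_pseudo_labels` API used above."""

        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, index):
            case_id, image = self.items[index]
            return image, case_id

        def size(self) -> int:
            return len(self.items)

        def num_pseudo_labels(self) -> int:
            return 0

    def __init__(self, data_dir: str, batch_size: int, num_workers: int, **kwargs):
        super().__init__(data_dir, batch_size, num_workers, **kwargs)
        self._labeled = [("case_000", torch.rand(1, 32, 32))]
        self._unlabeled = [(f"case_{i:03d}", torch.rand(1, 32, 32)) for i in range(1, 10)]

    def multi_label(self) -> bool:
        return False

    def id_to_class_names(self) -> Dict[int, str]:
        return {0: "background", 1: "foreground"}

    def _create_training_set(self) -> Optional[Dataset]:
        return self._ListDataset(self._labeled)

    def _create_validation_set(self) -> Optional[Dataset]:
        return None

    def _create_test_set(self) -> Optional[Dataset]:
        return None

    def _create_unlabeled_set(self) -> Optional[Dataset]:
        return self._ListDataset(self._unlabeled)

    def label_items(self, ids: List[str], pseudo_labels: Optional[Dict[str, Any]] = None) -> None:
        # move the selected items from the unlabeled pool to the labeled pool and rebuild the datasets
        self._labeled.extend(item for item in self._unlabeled if item[0] in ids)
        self._unlabeled = [item for item in self._unlabeled if item[0] not in ids]
        self.setup()


# Example usage of the sketch above:
#
#     dm = _ExampleDataModule(data_dir="./data", batch_size=4, num_workers=0)
#     dm.setup()
#     dm.label_items(["case_001", "case_002"])
#     print(dm.training_set_size(), dm.unlabeled_set_size())  # 3 7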