""" Base classes to implement models with pytorch """
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
import numpy
import torch
import torchmetrics
from pytorch_lightning.core.lightning import LightningModule
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import wandb
import functional
from metric_tracking import CombinedPerEpochMetric
from .loss_weight_scheduler import LossWeightScheduler
Optimizer = Union[Adam, SGD]
LRScheduler = Union[ReduceLROnPlateau, CosineAnnealingLR]
LRSchedulerDict = Dict[str, Union[str, LRScheduler]]
# pylint: disable=too-many-instance-attributes
class PytorchModel(LightningModule, ABC):
"""
    Base class for implementing PyTorch models.
Args:
learning_rate (float, optional): The step size at each iteration while moving towards a minimum of the loss.
Defaults to `0.0001`.
lr_scheduler (string, optional): Algorithm used for dynamically updating the learning rate during training:
``"reduceLROnPlateau"`` | ``"cosineAnnealingLR"``. Defaults to using no scheduler.
        loss_config (Dict[str, Any], optional): Dictionary with loss parameters. Must contain a ``"type"``
            key selecting the loss function. May additionally contain the pseudo-label weighting keys
            ``"weight_pseudo_labels_start"``, ``"weight_pseudo_labels_end"``, ``"weight_pseudo_labels_scheduler"``,
            and ``"weight_pseudo_labels_decay_steps"``. Defaults to ``{"type": "cross_entropy"}``.
        optimizer (string, optional): Algorithm used to update the model weights: ``"adam"`` |
            ``"sgd"``. Defaults to ``"adam"``.
train_metrics (Iterable[str], optional): A list with the names of the metrics that should be computed and logged
in each training and validation epoch of the training loop. Available options: ``"dice_score"`` |
``"sensitivity"`` | ``"specificity"`` | ``"hausdorff95"``. Defaults to `["dice_score"]`.
train_metric_confidence_levels (Iterable[float], optional): A list of confidence levels for which the metrics
specified in the `train_metrics` parameter should be computed in the training loop (`trainer.fit()`). This
parameter is used only for multi-label classification tasks. Defaults to `[0.5]`.
test_metrics (Iterable[str], optional): A list with the names of the metrics that should be computed and logged
in the model validation or testing loop (`trainer.validate()`, `trainer.test()`). Available options:
            ``"dice_score"`` | ``"sensitivity"`` | ``"specificity"`` | ``"hausdorff95"``. Defaults to
`["dice_score", "sensitivity", "specificity", "hausdorff95"]`.
test_metric_confidence_levels (Iterable[float], optional): A list of confidence levels for which the metrics
specified in the `test_metrics` parameter should be computed in the validation or testing loop. This
parameter is used only for multi-label classification tasks. Defaults to `[0.5]`.
        **kwargs: Further dataset-specific parameters.
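
    Example:
        A minimal sketch of a concrete subclass (the toy architecture, batch unpacking, and metric
        updates below are illustrative assumptions, not requirements of this base class)::

            class ToySegmenter(PytorchModel):
                def __init__(self, **kwargs):
                    super().__init__(**kwargs)
                    # single 1x1 convolution as a stand-in for a real segmentation network
                    self.conv = torch.nn.Conv2d(1, 2, kernel_size=1)

                def forward(self, x):
                    return self.conv(x)

                def training_step(self, batch, batch_idx):
                    x, y = batch
                    # self.loss_module is created in setup() from the loss_config
                    return self.loss_module(self(x), y)

                def validation_step(self, batch, batch_idx):
                    x, y = batch
                    return self.loss_module(self(x), y)

                def test_step(self, batch, batch_idx, dataloader_idx=None):
                    x, y = batch
                    for metric in self.get_test_metrics():
                        metric.update(self(x), y)

                def reset_parameters(self):
                    self.conv.reset_parameters()

            model = ToySegmenter(learning_rate=1e-3, optimizer="adam")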
"""
# pylint: disable=too-many-ancestors,arguments-differ,too-many-arguments
def __init__(
self,
learning_rate: float = 0.0001,
lr_scheduler: Optional[
Literal["reduceLROnPlateau", "cosineAnnealingLR"]
] = None,
        loss_config: Optional[Dict[str, Any]] = None,
optimizer: Literal["adam", "sgd"] = "adam",
train_metrics: Optional[Iterable[str]] = None,
train_metric_confidence_levels: Optional[Iterable[float]] = None,
test_metrics: Optional[Iterable[str]] = None,
test_metric_confidence_levels: Optional[Iterable[float]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.learning_rate = learning_rate
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
        # copy the config so that popping keys below does not mutate the caller's dictionary
        self.loss_config = dict(
            loss_config if loss_config is not None else {"type": "cross_entropy"}
        )
        self.loss = self.loss_config.pop("type")
        # optionally set up scheduling of the pseudo-label loss weight over the course of training
        if self.loss_config.get("weight_pseudo_labels_start") is not None:
self.loss_weight_pseudo_labels_scheduler = LossWeightScheduler(
self.loss_config.get("weight_pseudo_labels_scheduler", "fixed"),
self.loss_config.get("weight_pseudo_labels_start"),
self.loss_config.get("weight_pseudo_labels_end", None),
0,
self.loss_config.get("weight_pseudo_labels_decay_steps", None),
)
else:
self.loss_weight_pseudo_labels_scheduler = None
self.loss_config.pop("weight_pseudo_labels_scheduler", None)
self.loss_config.pop("weight_pseudo_labels_start", None)
self.loss_config.pop("weight_pseudo_labels_end", None)
self.loss_config.pop("weight_pseudo_labels_decay_steps", None)
self.loss_module = None
self.train_metric_confidence_levels = train_metric_confidence_levels
self.test_metric_confidence_levels = test_metric_confidence_levels
if train_metrics is None:
train_metrics = ["dice_score"]
self.train_metric_names = train_metrics
if test_metrics is None:
test_metrics = ["dice_score", "sensitivity", "specificity", "hausdorff95"]
self.test_metric_names = test_metrics
self.train_metrics = torch.nn.ModuleList([])
self.val_metrics = torch.nn.ModuleList([])
self.test_metrics = torch.nn.ModuleList([])
self.stage = None
self.start_epoch = 0
self.iteration = 0
    def setup(self, stage: Optional[str] = None) -> None:
"""
Setup hook as defined by PyTorch Lightning. Called at the beginning of fit (train + validate), validate, test,
or predict.
Args:
            stage (string, optional): Either ``"fit"``, ``"validate"``, ``"test"``, or ``"predict"``.
"""
self.stage = stage
self.configure_metrics()
if stage in ["fit", "validate", "test"]:
self.loss_module = self.configure_loss(
self.loss, self.trainer.datamodule.multi_label(), **self.loss_config
)
    @property
    def current_epoch(self) -> int:
        """Current epoch, offset by :attr:`start_epoch`."""
        return self.start_epoch + super().current_epoch
    @abstractmethod
def training_step(self, batch: torch.Tensor, batch_idx: int) -> float:
"""
Trains the model on a given batch of model inputs.
        This method should match the requirements of the PyTorch Lightning framework. See the
        `PyTorch Lightning documentation
        <https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html>`_ for more details.
Args:
batch: A batch of model inputs.
batch_idx: Index of the current batch.
Returns:
Training loss.
"""
    @abstractmethod
def validation_step(self, batch: torch.Tensor, batch_idx: int) -> float:
"""
Validates the model on a given batch of model inputs.
        This method should match the requirements of the PyTorch Lightning framework. See the
        `PyTorch Lightning documentation
<https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html>`_ for more details.
Args:
batch: A batch of model inputs.
batch_idx: Index of the current batch.
Returns:
Validation loss.
"""
    def predict_step(
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0
) -> Any:
"""
Compute the model's predictions on a given batch of model inputs.
        This method should match the requirements of the PyTorch Lightning framework. See the
        `PyTorch Lightning documentation
<https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html>`_ for more details.
Args:
batch: A batch of model inputs.
batch_idx: Index of the current batch.
            dataloader_idx: The index of the dataloader that produced this batch
                (only if multiple predict dataloaders are used).
        Returns:
            The model's predictions for the given batch.
"""
return self.predict(batch)
    @abstractmethod
def test_step(
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None
) -> None:
"""
Compute the model's predictions on a given batch of model inputs from the test set.
Args:
batch: The output of your :class:`~torch.utils.data.DataLoader`.
batch_idx: The index of this batch.
            dataloader_idx: The index of the dataloader that produced this batch
                (only if multiple test dataloaders are used).
"""
# pylint: disable=too-many-return-statements
    def predict(self, batch: torch.Tensor) -> numpy.ndarray:
"""
Computes predictions for a given batch of model inputs.
Args:
batch: A batch of model inputs.
Returns:
Predictions for the given inputs.
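        Example:
            A hypothetical call on a random batch of single-channel images (the input shape is an
            assumption about the concrete model)::

                predictions = model.predict(torch.randn(4, 1, 64, 64))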
"""
self.eval()
with torch.no_grad():
return self(batch)
    def get_train_metrics(self) -> Iterable[torchmetrics.Metric]:
"""
Returns:
A list of metrics to be updated in each training step.
"""
return self.train_metrics
    def get_val_metrics(self) -> Iterable[torchmetrics.Metric]:
"""
Returns:
A list of metrics to be updated in each validation step.
"""
return self.val_metrics
    def get_test_metrics(self) -> Iterable[torchmetrics.Metric]:
"""
Returns:
A list of metrics to be updated in each testing step.
"""
return self.test_metrics
    def training_epoch_end(
self,
outputs: Union[torch.Tensor, List[torch.Tensor], List[Dict[str, torch.Tensor]]],
) -> None:
"""
        This method is called by the PyTorch Lightning framework at the end of each training epoch.
Args:
outputs: List of return values of all training steps of the current training epoch.
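        Example:
            For ``outputs = [{"loss": torch.tensor(0.5)}, {"loss": torch.tensor(0.3)}]``, the value
            logged as ``train/mean_loss`` is ``0.4``.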
"""
train_metrics = {
"trainer/epoch": self.current_epoch,
"trainer/iteration": self.iteration,
"trainer/training_set_size": self.trainer.datamodule.training_set_size(),
"trainer/training_set_num_pseudo_labels": self.trainer.datamodule.training_set_num_pseudo_labels(),
"trainer/training_set_n_cases": len(
set(self.trainer.datamodule.training_set.image_ids())
),
"trainer/unlabeled_set_size": self.trainer.datamodule.unlabeled_set_size(),
"trainer/unlabeled_set_n_cases": len(
set(self.trainer.datamodule.unlabeled_set.image_ids())
),
}
# collect loss values
if isinstance(outputs, torch.Tensor):
losses = outputs
else:
losses = torch.Tensor(
[item["loss"] if isinstance(item, dict) else item for item in outputs]
)
train_metrics["train/mean_loss"] = losses.mean()
for train_metric in self.train_metrics:
train_metrics = {**train_metrics, **train_metric.compute()}
train_metric.reset()
self.logger.log_metrics(train_metrics)
    def validation_epoch_end(
self,
outputs: Union[torch.Tensor, List[torch.Tensor], List[Dict[str, torch.Tensor]]],
):
"""
        This method is called by the PyTorch Lightning framework at the end of each validation epoch.
Args:
outputs: List of return values of all validation steps of the current validation epoch.
"""
val_metrics = {
"trainer/epoch": self.current_epoch,
"trainer/iteration": self.iteration,
"trainer/training_set_size": self.trainer.datamodule.training_set_size(),
"trainer/training_set_num_pseudo_labels": self.trainer.datamodule.training_set_num_pseudo_labels(),
"trainer/training_set_n_cases": len(
set(self.trainer.datamodule.training_set.image_ids())
),
"trainer/unlabeled_set_size": self.trainer.datamodule.unlabeled_set_size(),
"trainer/unlabeled_set_n_cases": len(
set(self.trainer.datamodule.unlabeled_set.image_ids())
),
}
for val_metric in self.val_metrics:
val_metrics = {**val_metrics, **val_metric.compute()}
val_metric.reset()
# collect loss values
if isinstance(outputs, torch.Tensor):
losses = outputs
else:
losses = torch.Tensor(
[item["loss"] if isinstance(item, dict) else item for item in outputs]
)
val_metrics = {**val_metrics, "val/mean_loss": losses.mean()}
if self.stage == "fit":
# log to trainer to allow model selection
self.log_dict(val_metrics, logger=False)
if not self.trainer.sanity_checking:
# log to Weights and Biases
self.logger.experiment.log(val_metrics, commit=False)
else:
self.logger.log_metrics(val_metrics)
    def test_epoch_end(self, outputs: Any) -> None:
"""
        This method is called by the PyTorch Lightning framework at the end of each testing epoch.
        Args:
            outputs: List of return values of all test steps of the current testing epoch.
"""
test_metrics = {
"trainer/epoch": self.current_epoch,
"trainer/iteration": self.iteration,
"train/training_set_size": self.trainer.datamodule.training_set_size(),
"trainer/training_set_num_pseudo_labels": self.trainer.datamodule.training_set_num_pseudo_labels(),
"train/training_set_n_cases": len(
set(self.trainer.datamodule.training_set.image_ids())
),
"train/unlabeled_set_size": self.trainer.datamodule.unlabeled_set_size(),
"train/unlabeled_set_n_cases": len(
set(self.trainer.datamodule.unlabeled_set.image_ids())
),
}
for test_metric in self.test_metrics:
test_metrics = {**test_metrics, **test_metric.compute()}
            test_metric.reset()
self.logger.log_metrics(test_metrics)
    @abstractmethod
def reset_parameters(self) -> None:
"""
        This method is called when resetting the model weights is activated for the active learning loop.
"""
    def step_loss_weight_pseudo_labels_scheduler(self) -> None:
"""
        Advances the pseudo-label loss weight scheduler by one step if
        :attr:`loss_weight_pseudo_labels_scheduler` is not `None`.
"""
if self.loss_weight_pseudo_labels_scheduler is not None:
self.loss_weight_pseudo_labels_scheduler.step()
@property
def loss_weight_pseudo_labels(self) -> Union[float, None]:
"""
Returns:
Union[float, None]: Pseudo-label loss weight if :attr:`loss_weight_pseudo_labels_scheduler` is not `None`,
else `None`.
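        Example:
            Passing ``loss_config={"type": "cross_entropy", "weight_pseudo_labels_start": 0.5}`` to the
            constructor creates a scheduler (with the default ``"fixed"`` schedule), and this property
            then returns that scheduler's current weight instead of `None`.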
"""
if self.loss_weight_pseudo_labels_scheduler is not None:
return self.loss_weight_pseudo_labels_scheduler.current_weight()
return None