""" Module containing the active learning pipeline """
import math
import os
import shutil
from typing import Iterable, Optional, Union, Tuple, List, Dict, Any
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import numpy as np
import wandb
from query_strategies import QueryStrategy
from datasets import ActiveLearningDataModule
from models import PytorchModel
from functional.interpretation import HeatMaps
class ActiveLearningPipeline:
"""
The pipeline or simulation environment to run active learning experiments.
Args:
data_module (ActiveLearningDataModule): A data module object providing data.
model (PytorchModel): A model object with architecture able to be fitted with the data.
strategy (QueryStrategy): An active learning strategy to query for new labels.
epochs (int): The number of epochs the model should be trained.
gpus (int): Number of GPUs to use for model training.
checkpoint_dir (str, optional): Directory where the model checkpoints are to be saved.
early_stopping (bool, optional): Enable/disable early stopping when the model
is not learning anymore. Defaults to False.
logger: A logger object as defined by Pytorch Lightning.
lr_scheduler (str, optional): Algorithm used for dynamically updating the
learning rate during training, e.g. 'reduceLROnPlateau' or 'cosineAnnealingLR'.
model_selection_criterion (str, optional): Criterion used to select the best model checkpoint; the
corresponding `val/<criterion>` metric is monitored during training. Defaults to 'loss'.
active_learning_mode (bool, optional): Enable/disable the active learning mode. Defaults to False.
initial_epochs (int, optional): Number of epochs the initial model should be trained. Defaults to `epochs`.
items_to_label (int, optional): Number of items that should be selected for labeling in the active learning run.
Defaults to 1.
iterations (int, optional): Number of iterations the active learning pipeline should be
executed for. If None, the active learning pipeline is run until the whole dataset is labeled. Defaults to None.
reset_weights (bool, optional): Enable/disable resetting of the model weights after every active learning run.
Defaults to False.
epochs_increase_per_query (int, optional): Increase number of epochs for every query to compensate for
the increased training dataset size. Defaults to 0.
heatmaps_per_iteration (int, optional): Number of heatmaps that should be generated per iteration. Defaults to
0.
deterministic_mode (bool, optional): Whether only deterministic CUDA operations should be used. Defaults to
`True`.
save_model_every_epoch (bool, optional): Whether the model files of all epochs are to be saved or only the
model file of the best epoch. Defaults to `False`.
clear_wandb_cache (bool, optional): Whether the whole Weights and Biases cache should be deleted when the run
is finished. Should only be used when no other runs are running in parallel. Defaults to False.
**kwargs: Additional, strategy-specific parameters.
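Example:
    A minimal usage sketch: the ``data_module``, ``model`` and ``strategy`` objects shown here
    are placeholders that have to be constructed from the corresponding classes of this
    repository before the pipeline can be run::

        pipeline = ActiveLearningPipeline(
            data_module=data_module,
            model=model,
            strategy=strategy,
            epochs=50,
            gpus=1,
            active_learning_mode=True,
            items_to_label=10,
            iterations=5,
        )
        pipeline.run()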
"""
# pylint: disable=too-few-public-methods,too-many-arguments,too-many-instance-attributes,too-many-locals
# pylint: disable=protected-access
def __init__(
self,
data_module: ActiveLearningDataModule,
model: PytorchModel,
strategy: QueryStrategy,
epochs: int,
gpus: int,
checkpoint_dir: Optional[str] = None,
active_learning_mode: bool = False,
initial_epochs: Optional[int] = None,
items_to_label: int = 1,
iterations: Optional[int] = None,
reset_weights: bool = False,
epochs_increase_per_query: int = 0,
heatmaps_per_iteration: int = 0,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
early_stopping: bool = False,
lr_scheduler: Optional[str] = None,
model_selection_criterion="loss",
deterministic_mode: bool = True,
save_model_every_epoch: bool = False,
clear_wandb_cache: bool = False,
**kwargs,
) -> None:
self.data_module = data_module
self.model = model
self.model_trainer = None
# log gradients, parameter histogram and model topology if the logger supports it
# (e.g. a WandbLogger); `logger` may also be a bool or an iterable of loggers
if hasattr(logger, "watch"):
    logger.watch(self.model, log="all")
self.selected_items_table = None
self.strategy = strategy
self.epochs = epochs
self.logger = logger
self.gpus = gpus
self.active_learning_mode = active_learning_mode
self.checkpoint_dir = checkpoint_dir
self.early_stopping = early_stopping
self.initial_epochs = initial_epochs if initial_epochs is not None else epochs
self.items_to_label = items_to_label
self.iterations = iterations
self.heatmaps_per_iteration = heatmaps_per_iteration
self.lr_scheduler = lr_scheduler
self.model_selection_criterion = model_selection_criterion
self.reset_weights = reset_weights
self.epochs_increase_per_query = epochs_increase_per_query
self.deterministic_mode = deterministic_mode
self.save_model_every_epoch = save_model_every_epoch
self.clear_wandb_cache = clear_wandb_cache
self.kwargs = kwargs
def run(self) -> None:
"""Run the pipeline"""
self.data_module.setup()
# pylint: disable=too-many-nested-blocks
if self.active_learning_mode:
self.model_trainer = self.setup_trainer(self.initial_epochs, iteration=0)
if self.iterations is None:
self.iterations = math.ceil(
self.data_module.unlabeled_set_size() / self.items_to_label
)
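# e.g. 95 unlabeled items with items_to_label=10 -> ceil(95 / 10) = 10 query iterations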
# run pipeline
for iteration in range(0, self.iterations + 1):
self.selected_items_table = wandb.Table(
columns=[
"iteration",
"case_id",
"image_path",
"image_id",
"slice_index",
"pseudo_label",
]
)
# skip labeling in the first iteration because the model hasn't trained yet
if iteration != 0:
# query batch selection
if self.data_module.unlabeled_set_size() > 0:
(
items_to_label,
pseudo_labels,
) = self.strategy.select_items_to_label(
self.model,
self.data_module,
self.items_to_label,
**self.kwargs,
)
# label batch
self.data_module.label_items(items_to_label, pseudo_labels)
# Log selected items to wandb table
self.__log_selected_items(iteration, items_to_label, pseudo_labels)
if self.heatmaps_per_iteration > 0:
# Get latest added items from dataset
items_to_inspect = (
self.data_module.training_set.get_images_by_id(
case_ids=items_to_label[: self.heatmaps_per_iteration],
)
)
# Generate heatmaps using the final predictions and a gradient-based method
if len(items_to_inspect) > 0:
self.__generate_and_log_heatmaps(
items_to_inspect=items_to_inspect, iteration=iteration
)
self.model_trainer = self.setup_trainer(
self.epochs, iteration=iteration
)
# optionally reset weights after fitting on new data
if self.reset_weights and iteration != 0:
self.model.reset_parameters()
self.model.start_epoch = self.model.current_epoch + 1
self.model.iteration = iteration
# train model on labeled batch
self.model_trainer.fit(self.model, self.data_module)
# compute metrics for the best model on the validation set
self.model_trainer.validate(
ckpt_path="best", dataloaders=self.data_module
)
self.model.step_loss_weight_pseudo_labels_scheduler()
else:
self.model_trainer = self.setup_trainer(self.epochs, iteration=0)
# run regular fit run with all the data if no active learning mode
self.model_trainer.fit(self.model, self.data_module)
# compute metrics for the best model on the validation set
self.model_trainer.validate(ckpt_path="best", dataloaders=self.data_module)
wandb.run.finish()
if self.clear_wandb_cache:
self.remove_wandb_cache()
def setup_trainer(self, epochs: int, iteration: Optional[int] = None) -> Trainer:
"""
Initializes a new Pytorch Lightning trainer object.
Args:
epochs (int): Number of training epochs.
iteration (Optional[int], optional): Current active learning iteration. Defaults to None.
Returns:
pytorch_lightning.Trainer: A trainer object.
"""
callbacks = []
if self.lr_scheduler is not None:
callbacks.append(LearningRateMonitor(logging_interval="step"))
if self.early_stopping:
callbacks.append(EarlyStopping("validation/loss"))
monitoring_mode = "min" if "loss" in self.model_selection_criterion else "max"
if self.checkpoint_dir is not None and iteration is not None:
checkpoint_dir = os.path.join(self.checkpoint_dir, str(iteration))
else:
checkpoint_dir = self.checkpoint_dir
# run validation sanity checks only before the initial training, not in later active learning iterations
num_sanity_val_steps = 2 if iteration is None or iteration == 0 else 0
best_model_checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dir,
filename="best_model_epoch_{epoch}",
auto_insert_metric_name=False,
monitor=f"val/{self.model_selection_criterion}",
mode=monitoring_mode,
save_last=True,
every_n_epochs=1,
every_n_train_steps=0,
save_on_train_epoch_end=False,
)
callbacks.append(best_model_checkpoint_callback)
if self.save_model_every_epoch:
all_models_checkpoint_callback = ModelCheckpoint(
dirpath=os.path.join(checkpoint_dir, "all_models"),
filename="epoch_{epoch}",
auto_insert_metric_name=False,
save_top_k=-1,
every_n_epochs=1,
every_n_train_steps=0,
save_on_train_epoch_end=False,
)
callbacks.append(all_models_checkpoint_callback)
# PyTorch currently does not provide a deterministic implementation of 3d max pooling,
# therefore this option is only enabled for the 2d case
# see https://pytorch.org/docs/stable/notes/randomness.html
deterministic_mode = (
self.deterministic_mode if self.model.input_dimensionality() == 2 else False
)
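# the epoch budget grows by `epochs_increase_per_query` with every query iteration to
# compensate for the growing training set, e.g. epochs=50 and epochs_increase_per_query=5
# yield max_epochs=60 in iteration 2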
return Trainer(
deterministic=deterministic_mode,
benchmark=not self.deterministic_mode,
profiler="simple",
max_epochs=epochs + iteration * self.epochs_increase_per_query
if iteration is not None
else epochs,
logger=self.logger,
log_every_n_steps=20,
gpus=self.gpus,
callbacks=callbacks,
num_sanity_val_steps=num_sanity_val_steps,
)
def __log_selected_items(
self,
iteration: int,
selected_items: List[str],
pseudo_labels: Optional[Dict[str, Any]] = None,
) -> None:
"""
Log the iteration, case_id, image_path, image_id, slice_index and pseudo_label for all slices
selected in an iteration.
Args:
iteration (int): The current active learning iteration.
selected_items (List[str]): A list of all case_ids selected by the strategy in this iteration.
pseudo_labels (Dict[str, Any], optional): A mapping of the case_ids selected for pseudo-labeling
in this iteration to their pseudo-labels.
"""
if self.selected_items_table is not None:
if pseudo_labels is not None:
selected_items_with_true_labels = [
item for item in selected_items if item not in pseudo_labels
]
else:
selected_items_with_true_labels = selected_items
items = self.data_module.training_set.get_items_for_logging(
selected_items_with_true_labels
)
items = [[iteration, *i, False] for i in items]
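# each row follows the table columns defined above; the trailing flag marks whether an item
# was added with a pseudo-label (False here, True for the pseudo-labeled rows below)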
if (
pseudo_labels is not None
and not self.data_module.training_set.only_return_true_labels
):
pseudo_items = self.data_module.training_set.get_items_for_logging(
list(pseudo_labels.keys())
)
pseudo_items = [[iteration, *i, True] for i in pseudo_items]
items.extend(pseudo_items)
for row in items:
self.selected_items_table.add_data(*row)
wandb.log({"selected_items": self.selected_items_table})
def __generate_and_log_heatmaps(
self, items_to_inspect: List[Tuple[np.ndarray, str]], iteration: int
) -> None:
"""
Generates heatmaps using gradient based method and the prediction of the last layer of the model.
Args:
items_to_inspect (List[Tuple[np.ndarray, str]]): A list with the items to generate heatmaps for.
iteration (int): The iteration of the active learning loop.
"""
# Generate heatmaps using final predictions and gradient based method
gcam_images, logit_images = [], []
for img, case_id in items_to_inspect:
gcam_heatmap, logit_heatmap = self.__generate_heatmaps(
img=img, case_id=case_id
)
gcam_images.append(
wandb.Image(
gcam_heatmap,
caption=f"AL Iteration: {iteration}, Case ID: {case_id}",
)
)
logit_images.append(
wandb.Image(
logit_heatmap,
caption=f"AL Iteration: {iteration}, Case ID: {case_id}",
)
)
wandb.log({"GradCam heatmaps": gcam_images})
wandb.log({"Logit heatmaps": logit_images})
def __generate_heatmaps(
self,
img: np.ndarray,
case_id: str,
target_category: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Generates two heatmaps: One based on the GradCam method and one based on the predictions of the last layer.
Args:
img (np.ndarray): The image as numpy array.
case_id (str): The id of the current image.
target_category (int, optional): The label of the target class to analyze.
Returns:
A tuple of both heatmaps: the GradCAM heatmap and the prediction heatmap.
"""
input_tensor = torch.from_numpy(img)
heatmap = HeatMaps(model=self.model)
target_layers = [self.model.model.conv]
gcam_gray = heatmap.generate_grayscale_cam(
input_tensor=input_tensor,
target_category=target_category,
target_layers=target_layers,
)
logits_gray = heatmap.generate_grayscale_logits(
input_tensor=input_tensor, target_category=target_category
)
gcam_img = heatmap.show_grayscale_heatmap_on_image(
image=img, grayscale_heatmap=gcam_gray
)
logits_img = heatmap.show_grayscale_heatmap_on_image(
image=img, grayscale_heatmap=logits_gray
)
print(f"Generated heatmaps for case {case_id}")
return gcam_img, logits_img
@staticmethod
def remove_wandb_cache() -> None:
"""
Deletes Weights and Biases cache directory. This is necessary since the Weights and Biases client currently does
not implement proper cache cleanup itself. See
`this github issue <https://github.com/wandb/client/issues/1193>`_ for more details.
"""
wandb_cache_dir = wandb.env.get_cache_dir()
if wandb_cache_dir is not None:
shutil.rmtree(wandb_cache_dir)