
Source code for active_learning

""" Module containing the active learning pipeline """

import math
import os
import shutil
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import numpy as np
import wandb

from query_strategies import QueryStrategy
from datasets import ActiveLearningDataModule
from models import PytorchModel
from functional.interpretation import HeatMaps


class ActiveLearningPipeline:
    """
    The pipeline or simulation environment to run active learning experiments.

    Args:
        data_module (ActiveLearningDataModule): A data module object providing data.
        model (PytorchModel): A model object with an architecture able to be fitted with the data.
        strategy (QueryStrategy): An active learning strategy to query for new labels.
        epochs (int): The number of epochs the model should be trained.
        gpus (int): Number of GPUs to use for model training.
        checkpoint_dir (str, optional): Directory where the model checkpoints are to be saved.
        early_stopping (bool, optional): Enable/disable early stopping when the model is no
            longer learning. Defaults to False.
        logger: A logger object as defined by PyTorch Lightning.
        lr_scheduler (str, optional): Algorithm used for dynamically updating the learning rate
            during training, e.g. 'reduceLROnPlateau' or 'cosineAnnealingLR'.
        active_learning_mode (bool, optional): Enable/disable active learning mode.
            Defaults to False.
        initial_epochs (int, optional): Number of epochs the initial model should be trained.
            Defaults to `epochs`.
        items_to_label (int, optional): Number of items that should be selected for labeling in
            each active learning iteration. Defaults to 1.
        iterations (int, optional): Number of iterations the active learning pipeline should be
            executed for. If None, the active learning pipeline is run until the whole dataset
            is labeled. Defaults to None.
        reset_weights (bool, optional): Enable/disable resetting of the model weights after every
            active learning iteration. Defaults to False.
        epochs_increase_per_query (int, optional): Increase in the number of epochs for every
            query to compensate for the increased training dataset size. Defaults to 0.
        heatmaps_per_iteration (int, optional): Number of heatmaps that should be generated per
            iteration. Defaults to 0.
        deterministic_mode (bool, optional): Whether only deterministic CUDA operations should be
            used. Defaults to `True`.
        save_model_every_epoch (bool, optional): Whether the model files of all epochs are to be
            saved or only the model file of the best epoch. Defaults to `False`.
        clear_wandb_cache (bool, optional): Whether the whole Weights and Biases cache should be
            deleted when the run is finished. Should only be used when no other runs are running
            in parallel. Defaults to False.
        **kwargs: Additional, strategy-specific parameters.
""" # pylint: disable=too-few-public-methods,too-many-arguments,too-many-instance-attributes,too-many-locals # pylint: disable=protected-access def __init__( self, data_module: ActiveLearningDataModule, model: PytorchModel, strategy: QueryStrategy, epochs: int, gpus: int, checkpoint_dir: Optional[str] = None, active_learning_mode: bool = False, initial_epochs: Optional[int] = None, items_to_label: int = 1, iterations: Optional[int] = None, reset_weights: bool = False, epochs_increase_per_query: int = 0, heatmaps_per_iteration: int = 0, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, early_stopping: bool = False, lr_scheduler: str = None, model_selection_criterion="loss", deterministic_mode: bool = True, save_model_every_epoch: bool = False, clear_wandb_cache: bool = False, **kwargs, ) -> None: self.data_module = data_module self.model = model self.model_trainer = None # log gradients, parameter histogram and model topology logger.watch(self.model, log="all") self.selected_items_table = None self.strategy = strategy self.epochs = epochs self.logger = logger self.gpus = gpus self.active_learning_mode = active_learning_mode self.checkpoint_dir = checkpoint_dir self.early_stopping = early_stopping self.initial_epochs = initial_epochs if initial_epochs is not None else epochs self.items_to_label = items_to_label self.iterations = iterations self.heatmaps_per_iteration = heatmaps_per_iteration self.lr_scheduler = lr_scheduler self.model_selection_criterion = model_selection_criterion self.reset_weights = reset_weights self.epochs_increase_per_query = epochs_increase_per_query self.deterministic_mode = deterministic_mode self.save_model_every_epoch = save_model_every_epoch self.clear_wandb_cache = clear_wandb_cache self.kwargs = kwargs
    def run(self) -> None:
        """Run the pipeline"""

        self.data_module.setup()

        # pylint: disable=too-many-nested-blocks
        if self.active_learning_mode:
            self.model_trainer = self.setup_trainer(self.initial_epochs, iteration=0)

            if self.iterations is None:
                self.iterations = math.ceil(
                    self.data_module.unlabeled_set_size() / self.items_to_label
                )

            # run pipeline
            for iteration in range(0, self.iterations + 1):
                self.selected_items_table = wandb.Table(
                    columns=[
                        "iteration",
                        "case_id",
                        "image_path",
                        "image_id",
                        "slice_index",
                        "pseudo_label",
                    ]
                )

                # skip labeling in the first iteration because the model hasn't trained yet
                if iteration != 0:
                    # query batch selection
                    if self.data_module.unlabeled_set_size() > 0:
                        (
                            items_to_label,
                            pseudo_labels,
                        ) = self.strategy.select_items_to_label(
                            self.model,
                            self.data_module,
                            self.items_to_label,
                            **self.kwargs,
                        )
                        # label batch
                        self.data_module.label_items(items_to_label, pseudo_labels)
                        # log selected items to wandb table
                        self.__log_selected_items(iteration, items_to_label, pseudo_labels)

                        if self.heatmaps_per_iteration > 0:
                            # get the latest added items from the dataset
                            items_to_inspect = (
                                self.data_module.training_set.get_images_by_id(
                                    case_ids=items_to_label[: self.heatmaps_per_iteration],
                                )
                            )
                            # generate heatmaps using a gradient-based method and the
                            # final predictions
                            if len(items_to_inspect) > 0:
                                self.__generate_and_log_heatmaps(
                                    items_to_inspect=items_to_inspect, iteration=iteration
                                )

                    self.model_trainer = self.setup_trainer(
                        self.epochs, iteration=iteration
                    )

                # optionally reset weights after fitting on new data
                if self.reset_weights and iteration != 0:
                    self.model.reset_parameters()

                self.model.start_epoch = self.model.current_epoch + 1
                self.model.iteration = iteration
                # train model on labeled batch
                self.model_trainer.fit(self.model, self.data_module)
                # compute metrics for the best model on the validation set
                self.model_trainer.validate(
                    ckpt_path="best", dataloaders=self.data_module
                )
                self.model.step_loss_weight_pseudo_labels_scheduler()
        else:
            self.model_trainer = self.setup_trainer(self.epochs, iteration=0)
            # run a regular fit on all the data when active learning mode is disabled
            self.model_trainer.fit(self.model, self.data_module)
            # compute metrics for the best model on the validation set
            self.model_trainer.validate(ckpt_path="best", dataloaders=self.data_module)

        wandb.run.finish()
        if self.clear_wandb_cache:
            self.remove_wandb_cache()
    def setup_trainer(self, epochs: int, iteration: Optional[int] = None) -> Trainer:
        """
        Initializes a new PyTorch Lightning trainer object.

        Args:
            epochs (int): Number of training epochs.
            iteration (Optional[int], optional): Current active learning iteration.
                Defaults to None.

        Returns:
            pytorch_lightning.Trainer: A trainer object.
        """

        callbacks = []
        if self.lr_scheduler is not None:
            callbacks.append(LearningRateMonitor(logging_interval="step"))
        if self.early_stopping:
            callbacks.append(EarlyStopping("validation/loss"))

        monitoring_mode = "min" if "loss" in self.model_selection_criterion else "max"

        if self.checkpoint_dir is not None and iteration is not None:
            checkpoint_dir = os.path.join(self.checkpoint_dir, str(iteration))
        else:
            checkpoint_dir = self.checkpoint_dir

        num_sanity_val_steps = 2 if iteration is None or iteration == 0 else 0

        best_model_checkpoint_callback = ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="best_model_epoch_{epoch}",
            auto_insert_metric_name=False,
            monitor=f"val/{self.model_selection_criterion}",
            mode=monitoring_mode,
            save_last=True,
            every_n_epochs=1,
            every_n_train_steps=0,
            save_on_train_epoch_end=False,
        )
        callbacks.append(best_model_checkpoint_callback)

        if self.save_model_every_epoch:
            all_models_checkpoint_callback = ModelCheckpoint(
                dirpath=os.path.join(checkpoint_dir, "all_models"),
                filename="epoch_{epoch}",
                auto_insert_metric_name=False,
                save_top_k=-1,
                every_n_epochs=1,
                every_n_train_steps=0,
                save_on_train_epoch_end=False,
            )
            callbacks.append(all_models_checkpoint_callback)

        # PyTorch currently does not support deterministic 3d max pooling, therefore this
        # option is only enabled for the 2d case
        # see https://pytorch.org/docs/stable/notes/randomness.html
        deterministic_mode = (
            self.deterministic_mode if self.model.input_dimensionality() == 2 else False
        )

        return Trainer(
            deterministic=deterministic_mode,
            benchmark=not self.deterministic_mode,
            profiler="simple",
            max_epochs=epochs + iteration * self.epochs_increase_per_query
            if iteration is not None
            else epochs,
            logger=self.logger,
            log_every_n_steps=20,
            gpus=self.gpus,
            callbacks=callbacks,
            num_sanity_val_steps=num_sanity_val_steps,
        )
    def __log_selected_items(
        self,
        iteration: int,
        selected_items: List[str],
        pseudo_labels: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Logs the iteration, case_id, image_path, image_id, slice_index and pseudo_label for all
        slices selected in an iteration.

        Args:
            iteration (int): The current active learning iteration.
            selected_items (List[str]): A list of all case_ids selected by the strategy in this
                iteration.
            pseudo_labels (Dict[str, Any], optional): A mapping of the case_ids selected for
                pseudo-labeling in this iteration to their pseudo labels.
        """

        if self.selected_items_table is not None:
            if pseudo_labels is not None:
                selected_items_with_true_labels = [
                    item for item in selected_items if item not in pseudo_labels
                ]
            else:
                selected_items_with_true_labels = selected_items
            items = self.data_module.training_set.get_items_for_logging(
                selected_items_with_true_labels
            )
            items = [[iteration, *i, False] for i in items]

            if (
                pseudo_labels is not None
                and not self.data_module.training_set.only_return_true_labels
            ):
                pseudo_items = self.data_module.training_set.get_items_for_logging(
                    list(pseudo_labels.keys())
                )
                pseudo_items = [[iteration, *i, True] for i in pseudo_items]
                items.extend(pseudo_items)

            for row in items:
                self.selected_items_table.add_data(*row)
            wandb.log({"selected_items": self.selected_items_table})

    def __generate_and_log_heatmaps(
        self, items_to_inspect: List[Tuple[np.ndarray, str]], iteration: int
    ) -> None:
        """
        Generates heatmaps using a gradient-based method and the prediction of the last layer of
        the model.

        Args:
            items_to_inspect (List[Tuple[np.ndarray, str]]): A list of the items to generate
                heatmaps for.
            iteration (int): The iteration of the active learning loop.
        """

        # generate heatmaps using the final predictions and a gradient-based method
        gcam_images, logit_images = [], []
        for img, case_id in items_to_inspect:
            gcam_heatmap, logit_heatmap = self.__generate_heatmaps(
                img=img, case_id=case_id
            )
            gcam_images.append(
                wandb.Image(
                    gcam_heatmap,
                    caption=f"AL Iteration: {iteration}, Case ID: {case_id}",
                )
            )
            logit_images.append(
                wandb.Image(
                    logit_heatmap,
                    caption=f"AL Iteration: {iteration}, Case ID: {case_id}",
                )
            )
        wandb.log({"GradCam heatmaps": gcam_images})
        wandb.log({"Logit heatmaps": logit_images})

    def __generate_heatmaps(
        self,
        img: np.ndarray,
        case_id: str,
        target_category: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generates two heatmaps: one based on the GradCam method and one based on the predictions
        of the last layer.

        Args:
            img (np.ndarray): The image as numpy array.
            case_id (str): The id of the current image.
            target_category (int, optional): The label of the target class to analyze.

        Returns:
            A tuple of both heatmaps: the GradCam heatmap and the prediction heatmap.
        """

        input_tensor = torch.from_numpy(img)
        heatmap = HeatMaps(model=self.model)
        target_layers = [self.model.model.conv]

        gcam_gray = heatmap.generate_grayscale_cam(
            input_tensor=input_tensor,
            target_category=target_category,
            target_layers=target_layers,
        )
        logits_gray = heatmap.generate_grayscale_logits(
            input_tensor=input_tensor, target_category=target_category
        )
        gcam_img = heatmap.show_grayscale_heatmap_on_image(
            image=img, grayscale_heatmap=gcam_gray
        )
        logits_img = heatmap.show_grayscale_heatmap_on_image(
            image=img, grayscale_heatmap=logits_gray
        )
        print(f"Generated heatmaps for case {case_id}")
        return gcam_img, logits_img
    @staticmethod
    def remove_wandb_cache() -> None:
        """
        Deletes the Weights and Biases cache directory. This is necessary since the Weights and
        Biases client currently does not implement proper cache cleanup itself. See `this GitHub
        issue <https://github.com/wandb/client/issues/1193>`_ for more details.
        """

        wandb_cache_dir = wandb.env.get_cache_dir()
        if wandb_cache_dir is not None:
            shutil.rmtree(wandb_cache_dir)
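

# ---------------------------------------------------------------------------
# Usage sketch (not part of the pipeline itself): the helper below illustrates
# how the class above is wired together. All argument values are illustrative
# placeholders, and the caller is expected to pass concrete subclasses of
# ActiveLearningDataModule, PytorchModel and QueryStrategy.
# ---------------------------------------------------------------------------
def example_active_learning_run(
    data_module: ActiveLearningDataModule,
    model: PytorchModel,
    strategy: QueryStrategy,
) -> None:
    """Minimal example of configuring and running the active learning pipeline."""

    pipeline = ActiveLearningPipeline(
        data_module=data_module,
        model=model,
        strategy=strategy,
        epochs=50,  # placeholder: epochs per active learning iteration
        gpus=1,
        checkpoint_dir="./checkpoints",
        active_learning_mode=True,
        initial_epochs=25,  # placeholder: shorter initial training run
        items_to_label=10,  # placeholder: items queried per iteration
        iterations=5,  # placeholder: number of query iterations
        reset_weights=True,  # retrain from scratch after each query
    )
    pipeline.run()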
