Source code for main

""" Main module to execute active learning pipeline from CLI """
import copy
import json
import os.path
from typing import Any, Dict, Iterable, Optional

import torch
import fire
import pytorch_lightning
from pytorch_lightning.loggers import WandbLogger
import wandb

from active_learning import ActiveLearningPipeline
from inferencing import Inferencer
from models import PytorchUNet
from datasets import (
    BraTSDataModule,
    DecathlonDataModule,
    BCSSDataModule,
)
from query_strategies import (
    RandomSamplingStrategy,
    UncertaintySamplingStrategy,
    InterpolationSamplingStrategy,
    DistanceBasedRepresentativenessSamplingStrategy,
    ClusteringBasedRepresentativenessSamplingStrategy,
    UncertaintyRepresentativenessSamplingStrategy,
)


def create_data_module(
    dataset: str,
    data_dir: str,
    batch_size: int,
    num_workers: int,
    random_state: int,
    active_learning_config: Dict[str, Any],
    dataset_config: Dict[str, Any],
):
    """
    Creates the correct data module.

    Args:
        dataset (string): Name of the dataset. E.g. 'brats'.
        data_dir (string): Main directory with the dataset. E.g. './data'.
        batch_size (int): Size of training examples passed in one training step.
        num_workers (int): Number of workers.
        random_state (int): Random constant for shuffling the data.
        active_learning_config (Dict[str, Any]): Dictionary with active learning specific
            parameters.
        dataset_config (Dict[str, Any]): Dictionary with dataset specific parameters.

    Returns:
        The data module.
    """

    if dataset == "brats":
        data_module = BraTSDataModule(
            data_dir,
            batch_size,
            num_workers,
            active_learning_mode=active_learning_config.get(
                "active_learning_mode", False
            ),
            batch_size_unlabeled_set=min(
                active_learning_config.get("batch_size_unlabeled_set", batch_size),
                active_learning_config.get("items_to_label", 1),
            ),
            initial_training_set_size=active_learning_config.get(
                "initial_training_set_size", 10
            ),
            random_state=random_state,
            **dataset_config,
        )
    elif dataset == "decathlon":
        data_module = DecathlonDataModule(
            data_dir,
            batch_size,
            num_workers,
            active_learning_mode=active_learning_config.get(
                "active_learning_mode", False
            ),
            batch_size_unlabeled_set=min(
                active_learning_config.get("batch_size_unlabeled_set", batch_size),
                active_learning_config.get("items_to_label", 1),
            ),
            initial_training_set_size=active_learning_config.get(
                "initial_training_set_size", 10
            ),
            random_state=random_state,
            **dataset_config,
        )
    elif dataset == "bcss":
        dataset_config.pop("dim")
        data_module = BCSSDataModule(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            active_learning_mode=active_learning_config.get(
                "active_learning_mode", False
            ),
            initial_training_set_size=active_learning_config.get(
                "initial_training_set_size", 10
            ),
            random_state=random_state,
            **dataset_config,
        )
    else:
        raise ValueError("Invalid data_module name.")

    return data_module
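For example, a BraTS data module could be created directly as in the following sketch; the active learning values shown are illustrative assumptions, not defaults of this module:

active_learning_config = {
    "active_learning_mode": True,
    "batch_size_unlabeled_set": 16,
    "items_to_label": 8,
    "initial_training_set_size": 10,
}
data_module = create_data_module(
    dataset="brats",
    data_dir="./data",
    batch_size=16,
    num_workers=4,
    random_state=42,
    active_learning_config=active_learning_config,
    dataset_config={},  # dataset-specific keyword arguments, if any
)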
# pylint: disable=too-many-arguments,too-many-locals
def run_active_learning_pipeline(
    architecture: str,
    dataset: str,
    strategy_config: dict,
    experiment_name: str,
    batch_size: int = 16,
    checkpoint_dir: Optional[str] = None,
    data_dir: str = "./data",
    dataset_config: Optional[Dict[str, Any]] = None,
    model_config: Optional[Dict[str, Any]] = None,
    model_selection_criterion: Optional[str] = "mean_dice_score_0.5",
    active_learning_config: Optional[Dict[str, Any]] = None,
    epochs: int = 50,
    experiment_tags: Optional[Iterable[str]] = None,
    gpus: int = 1,
    num_workers: int = 4,
    learning_rate: float = 0.0001,
    lr_scheduler: Optional[str] = None,
    num_levels: int = 4,
    prediction_count: Optional[int] = None,
    prediction_dir: str = "./predictions",
    wandb_project_name: str = "active-segmentation",
    early_stopping: bool = False,
    random_state: int = 42,
    deterministic_mode: bool = True,
    save_model_every_epoch: bool = False,
    clear_wandb_cache: bool = False,
) -> None:
    """
    Main function to execute an active learning pipeline run, or start an active learning
    simulation.

    Args:
        architecture (string): Name of the desired model architecture. E.g. 'u_net'.
        dataset (string): Name of the dataset. E.g. 'brats'.
        strategy_config (dict): Configuration of the query strategy.
        experiment_name (string): Name of the experiment.
        batch_size (int, optional): Size of training examples passed in one training step.
        checkpoint_dir (str, optional): Directory where the model checkpoints are to be saved.
        data_dir (string, optional): Main directory with the dataset. E.g. './data'.
        dataset_config (Dict[str, Any], optional): Dictionary with dataset specific parameters.
        model_config (Dict[str, Any], optional): Dictionary with model specific parameters.
        model_selection_criterion (str, optional): Metric used to select the best model
            checkpoint. Defaults to 'mean_dice_score_0.5'.
        active_learning_config (Dict[str, Any], optional): Dictionary with active learning
            specific parameters.
        epochs (int, optional): Number of iterations with the full dataset.
        experiment_tags (Iterable[string], optional): Tags with which to label the experiment.
        gpus (int): Number of GPUs to use for model training.
        num_workers (int, optional): Number of workers.
        learning_rate (float): The step size at each iteration while moving towards a minimum
            of the loss.
        lr_scheduler (string, optional): Algorithm used for dynamically updating the learning
            rate during training. E.g. 'reduceLROnPlateau' or 'cosineAnnealingLR'.
        num_levels (int, optional): Number of levels (encoder and decoder blocks) in the U-Net.
            Defaults to 4.
        prediction_count (int, optional): Number of predictions to run after training. If
            `None`, the inference step is skipped.
        prediction_dir (string, optional): Directory where the predictions are to be saved.
        wandb_project_name (string, optional): Name of the project that the W&B runs are
            stored in.
        early_stopping (bool, optional): Enable/disable early stopping when the model is not
            learning anymore. Defaults to `False`.
        random_state (int): Random constant for shuffling the data.
        deterministic_mode (bool, optional): Whether only deterministic CUDA operations should
            be used. Defaults to `True`.
        save_model_every_epoch (bool, optional): Whether the model files of all epochs are to
            be saved or only the model file of the best epoch. Defaults to `False`.
        clear_wandb_cache (bool, optional): Whether the whole Weights and Biases cache should
            be deleted when the run is finished. Should only be used when no other runs are
            running in parallel. Defaults to `False`.

    Returns:
        None.
    """

    # set global seeds for reproducibility
    pytorch_lightning.utilities.seed.seed_everything(random_state, workers=True)
    torch.cuda.manual_seed(random_state)
    torch.cuda.manual_seed_all(random_state)  # for multi-GPU runs

    wandb_logger = WandbLogger(
        project=wandb_project_name,
        entity="active-segmentation",
        name=experiment_name
        if random_state is None
        else f"{experiment_name}-{random_state}",
        tags=experiment_tags,
        log_model="all",
        config=copy.deepcopy(locals()),
        group=strategy_config.get("type", None),
        job_type=strategy_config.get("description", None),
    )

    if dataset_config is None:
        dataset_config = {}
    if active_learning_config is None:
        active_learning_config = {}

    data_module = create_data_module(
        dataset,
        data_dir,
        batch_size,
        num_workers,
        random_state,
        active_learning_config,
        dataset_config,
    )

    model = create_model(
        data_module,
        architecture,
        learning_rate,
        lr_scheduler,
        num_levels,
        model_config,
        loss_weight_scheduler_max_steps=active_learning_config.get("iterations", None),
    )

    strategy = create_query_strategy(strategy_config=strategy_config)

    if checkpoint_dir is not None:
        checkpoint_dir = os.path.join(checkpoint_dir, f"{wandb_logger.experiment.id}")
    prediction_dir = os.path.join(prediction_dir, f"{wandb_logger.experiment.id}")

    pipeline = ActiveLearningPipeline(
        data_module,
        model,
        strategy,
        epochs,
        gpus,
        checkpoint_dir,
        active_learning_mode=active_learning_config.get("active_learning_mode", False),
        initial_epochs=active_learning_config.get("initial_epochs", epochs),
        items_to_label=active_learning_config.get("items_to_label", 1),
        iterations=active_learning_config.get("iterations", None),
        reset_weights=active_learning_config.get("reset_weights", False),
        epochs_increase_per_query=active_learning_config.get(
            "epochs_increase_per_query", 0
        ),
        heatmaps_per_iteration=active_learning_config.get("heatmaps_per_iteration", 0),
        logger=wandb_logger,
        early_stopping=early_stopping,
        lr_scheduler=lr_scheduler,
        model_selection_criterion=model_selection_criterion,
        deterministic_mode=deterministic_mode,
        save_model_every_epoch=save_model_every_epoch,
        clear_wandb_cache=clear_wandb_cache,
        **active_learning_config.get("strategy_config", {}),
    )
    pipeline.run()

    if prediction_count is None:
        return

    inferencer = Inferencer(
        model,
        dataset,
        data_dir,
        prediction_dir,
        prediction_count,
        dataset_config=dataset_config,
    )
    inferencer.inference()
def create_model(
    data_module,
    architecture,
    learning_rate,
    lr_scheduler,
    num_levels,
    model_config,
    loss_weight_scheduler_max_steps: Optional[int] = None,
):
    """
    Creates the specified model.

    Args:
        data_module (ActiveLearningDataModule): A data module object providing data.
        architecture (string): Name of the desired model architecture. E.g. 'u_net'.
        learning_rate (float): The step size at each iteration while moving towards a minimum
            of the loss.
        lr_scheduler (string, optional): Algorithm used for dynamically updating the learning
            rate during training. E.g. 'reduceLROnPlateau' or 'cosineAnnealingLR'.
        num_levels (int, optional): Number of levels (encoder and decoder blocks) in the U-Net.
            Defaults to 4.
        model_config (Dict[str, Any], optional): Dictionary with model specific parameters.
        loss_weight_scheduler_max_steps (int, optional): Number of steps for the pseudo-label
            loss weight scheduler.

    Returns:
        The model.
    """

    if (
        "loss_config" in model_config
        and "weight_pseudo_labels_scheduler" in model_config["loss_config"]
    ):
        model_config["loss_config"][
            "weight_pseudo_labels_decay_steps"
        ] = loss_weight_scheduler_max_steps

    if architecture == "u_net":
        model = PytorchUNet(
            learning_rate=learning_rate,
            lr_scheduler=lr_scheduler,
            num_levels=num_levels,
            in_channels=data_module.data_channels(),
            out_channels=data_module.num_classes(),
            multi_label=data_module.multi_label(),
            **model_config,
        )
    else:
        raise ValueError("Invalid model architecture.")

    return model
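The pseudo-label branch above can be exercised as in the following sketch; it assumes `PytorchUNet` accepts a `loss_config` keyword argument, which is not shown in this module:

model_config = {"loss_config": {"weight_pseudo_labels_scheduler": True}}
model = create_model(
    data_module,  # e.g. the module returned by create_data_module
    architecture="u_net",
    learning_rate=0.0001,
    lr_scheduler="reduceLROnPlateau",
    num_levels=4,
    model_config=model_config,
    loss_weight_scheduler_max_steps=10,  # typically the number of AL iterations
)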
def create_query_strategy(strategy_config: dict):
    """
    Initialises the chosen query strategy.

    Args:
        strategy_config (dict): Configuration of the query strategy.

    Returns:
        The query strategy.
    """

    strategy_type = strategy_config.get("type")
    if strategy_type == "random":
        return RandomSamplingStrategy(**strategy_config)
    if strategy_type == "interpolation":
        return InterpolationSamplingStrategy(**strategy_config)
    if strategy_type == "uncertainty":
        return UncertaintySamplingStrategy(**strategy_config)
    if strategy_type == "representativeness_distance":
        return DistanceBasedRepresentativenessSamplingStrategy(**strategy_config)
    if strategy_type == "representativeness_clustering":
        return ClusteringBasedRepresentativenessSamplingStrategy(**strategy_config)
    if strategy_type == "representativeness_uncertainty":
        return UncertaintyRepresentativenessSamplingStrategy(**strategy_config)
    raise ValueError("Invalid query strategy.")
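A usage sketch; note that the full dictionary, including the "type" key, is unpacked into the strategy constructor, so the constructors are assumed to tolerate these keys:

strategy = create_query_strategy(
    strategy_config={"type": "uncertainty"}  # "type" selects the strategy class
)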
def run_active_learning_pipeline_from_config(
    config_file_name: str, hp_optimisation: bool = False
) -> None:
    """
    Runs the active learning pipeline based on a config file.

    Args:
        config_file_name: Name of or path to the config file.
        hp_optimisation: If this flag is set, run the pipeline with different hyperparameters
            based on the configured sweep file.
    """

    if not os.path.isfile(config_file_name):
        print("Config file could not be found.")
        raise FileNotFoundError(f"{config_file_name} is not a valid filename.")

    with open(config_file_name, encoding="utf-8") as config_file:
        hyperparameter_defaults = json.load(config_file)
        config = hyperparameter_defaults

        if "dataset_config" in config and "dataset" in config["dataset_config"]:
            config["dataset"] = config["dataset_config"]["dataset"]
            del config["dataset_config"]["dataset"]
        if "dataset_config" in config and "data_dir" in config["dataset_config"]:
            config["data_dir"] = config["dataset_config"]["data_dir"]
            del config["dataset_config"]["data_dir"]
        if (
            "dataset_config" in config
            and "mask_filter_values" in config["dataset_config"]
        ):
            config["dataset_config"]["mask_filter_values"] = tuple(
                config["dataset_config"]["mask_filter_values"]
            )
        if (
            "model_config" in config
            and "dim" in config["model_config"]
            and "dataset_config" in config
        ):
            config["dataset_config"]["dim"] = config["model_config"]["dim"]
        if "model_config" in config and "architecture" in config["model_config"]:
            config["architecture"] = config["model_config"]["architecture"]
            del config["model_config"]["architecture"]
        if "model_config" in config and "learning_rate" in config["model_config"]:
            config["learning_rate"] = config["model_config"]["learning_rate"]
            del config["model_config"]["learning_rate"]
        if "model_config" in config and "lr_scheduler" in config["model_config"]:
            config["lr_scheduler"] = config["model_config"]["lr_scheduler"]
            del config["model_config"]["lr_scheduler"]
        if "model_config" in config and "num_levels" in config["model_config"]:
            config["num_levels"] = config["model_config"]["num_levels"]
            del config["model_config"]["num_levels"]
        if (
            "model_config" in config
            and "model_selection_criterion" in config["model_config"]
        ):
            config["model_selection_criterion"] = config["model_config"][
                "model_selection_criterion"
            ]
            del config["model_config"]["model_selection_criterion"]

        if hp_optimisation:
            print("Start Hyperparameter Optimisation using sweep.yaml file")
            wandb.init(
                config=hyperparameter_defaults,
                project=config["wandb_project_name"],
                entity="active-segmentation",
            )
            # Config parameters are automatically set by W&B sweep agent
            config = wandb.config

        run_active_learning_pipeline(**config)
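A config file for this loader might look as follows; the nested `dataset`, `data_dir`, `architecture`, and `learning_rate` keys are hoisted to the top level by the code above, and all values are illustrative:

{
    "experiment_name": "uncertainty-brats",
    "wandb_project_name": "active-segmentation",
    "strategy_config": {"type": "uncertainty"},
    "dataset_config": {"dataset": "brats", "data_dir": "./data"},
    "model_config": {"architecture": "u_net", "learning_rate": 0.0001},
    "active_learning_config": {"active_learning_mode": true, "items_to_label": 8}
}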
if __name__ == "__main__":
    fire.Fire(run_active_learning_pipeline_from_config)
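With `fire.Fire`, the function's parameters become CLI arguments; assuming the module is saved as main.py, a run could be started like this:

python main.py ./configs/experiment.json
python main.py --config_file_name=./configs/experiment.json --hp_optimisation=True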
