Shortcuts

Source code for datasets.collate

""" Module to collate batches """

import math
from typing import List, Union

import torch
import torch.nn.functional as F


def batch_padding_collate_fn(
    batch: List[Union[tuple, str, torch.Tensor]], pad_value: int = 0
) -> Union[tuple, list, torch.Tensor]:
    """
    Collates a batch and pads tensors to the same size before stacking them.

    Every tensor is padded up to the per-dimension maximum size found in the
    batch. The padding of each dimension is split between both sides, with
    the extra element going to the trailing side when the amount is odd.

    Args:
        batch (List[Union[tuple, str, torch.Tensor]]): The batch in List form.
            Tuple samples are collated field by field, ``str`` / ``bool``
            samples are passed through unchanged and tensors are padded and
            stacked.
        pad_value (int, optional): Constant used to fill the padded area of
            tensors. Defaults to 0. The field at index 1 of tuple samples is
            always padded with -1 instead.

    Returns:
        The batch collated: a stacked tensor for tensor samples, a tuple of
        collated fields for tuple samples, or ``batch`` itself for ``str`` /
        ``bool`` samples.

    Raises:
        IndexError: If ``batch`` is empty.
    """

    # ``len`` instead of truthiness so that a tensor passed as batch does not
    # trigger an ambiguous boolean conversion.
    if len(batch) == 0:
        raise IndexError("Cannot collate an empty batch.")

    elem = batch[0]

    if isinstance(elem, tuple):
        return tuple(
            # Always pad labels with -1.
            batch_padding_collate_fn(samples, -1 if idx == 1 else pad_value)
            for idx, samples in enumerate(zip(*batch))
        )

    if isinstance(elem, (bool, str)):
        return batch

    num_dims = elem.dim()
    max_size = [max(item.size(dim) for item in batch) for dim in range(num_dims)]

    padded_items = []
    for item in batch:
        pad = []
        # F.pad takes (before, after) pairs ordered from the last dimension
        # to the first one, hence the reversed iteration.
        for dim in reversed(range(num_dims)):
            missing = max_size[dim] - item.size(dim)
            before = missing // 2
            pad.extend((before, missing - before))
        padded_items.append(F.pad(item, tuple(pad), "constant", pad_value))

    return torch.stack(padded_items)

Docs

Access comprehensive developer documentation for Active Segmentation

View Docs