In this post, we show how to use Kornia modules to perform data augmentation on the GPU in batch mode. Kornia is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia consists of many components. One of them is kornia.augmentation, a module for performing data augmentation on the GPU.

We will work on a segmentation problem (nail segmentation). For that, we use PyTorch Lightning to train the model and Kornia to build the data augmentation on the GPU.

1. Problem Description and Dataset

We want to solve a nail semantic segmentation problem. For each image, we want to predict the segmentation mask of the nail.

[Figure: sample images and their masks]

Our data is organized as follows:

├── Images
│   ├── 1
│       ├── first_image.png
│       ├── second_image.png
│       ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4
├── Masks
│   ├── 1
│       ├── first_image.png
│       ├── second_image.png
│       ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4

We have two folders: Images and Masks. Images is the data folder, and Masks is the label folder, which contains the segmentation masks of the input images. Each folder has four sub-folders (1, 2, 3, and 4), corresponding to four distribution patterns of nails.

We download the data from link and put it in data_root, for example:

data_root = "./nail-segmentation-dataset"

2. Data Preparation

As in the training pipeline of the previous post, we first build a data frame that stores the information about the images and masks.

index images
1 path_first_image.png
2 path_second_image.png
3 path_third_image.png
4 path_fourth_image.png

For that, we use the make_csv_file function in file.
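The indexing step can be sketched with the standard library alone. This is a minimal illustration, not the make_csv_file function from the post: the name make_index_csv, the "masks" column, and the assumption that each mask shares its image's relative path are all hypothetical.

```python
import csv
from pathlib import Path
from typing import Dict, List


def make_index_csv(data_root: str, csv_path: str) -> List[Dict[str, str]]:
    """Pair every PNG under Images/ with the same relative path under Masks/."""
    root = Path(data_root)
    rows = []
    for image_path in sorted((root / "Images").rglob("*.png")):
        # assumption: a mask has the same sub-folder and file name as its image
        mask_path = root / "Masks" / image_path.relative_to(root / "Images")
        if mask_path.exists():  # skip images that have no label
            rows.append({"images": str(image_path), "masks": str(mask_path)})

    # write one row per (image, mask) pair
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["images", "masks"])
        writer.writeheader()
        writer.writerows(rows)
    return rows
```

Calling make_index_csv(data_root, "index.csv") would then produce a file that a data frame library can load directly.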

3. The CPU bottleneck

In a typical pipeline, augmentation transforms are applied one input at a time on the CPU. This makes them slow, and the data pipeline can become the bottleneck of training.

3.1 A naive approach model training

The naive training pipeline includes:

  • The pre-processing of the data occurs on the CPU
  • The model is typically trained on a GPU/TPU.
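To make the contrast concrete, here is a small sketch comparing a per-sample transform with a single batched call. The function names and the choice of a horizontal flip are illustrative assumptions, not code from the post.

```python
import torch


def flip_per_sample(batch: torch.Tensor) -> torch.Tensor:
    """Naive style: loop over the batch and transform one image at a time."""
    return torch.stack([torch.flip(img, dims=[-1]) for img in batch])


def flip_batched(batch: torch.Tensor) -> torch.Tensor:
    """Batched style: one call handles the whole BxCxHxW tensor."""
    return torch.flip(batch, dims=[-1])


# use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
batch = torch.rand(8, 3, 32, 32, device=device)

# both styles compute the same result; only the number of calls differs
same = torch.equal(flip_per_sample(batch), flip_batched(batch))
```

The batched version issues one operation for the whole batch, which is the pattern that GPU augmentation libraries build on.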

3.2 Data Augmentation using GPU

To improve the training speed, we can shift the data augmentation task to the GPU.

To do that, we can use the kornia.augmentation module or the DALI library.

  • kornia.augmentation is the Kornia module that performs augmentation on the GPU. It speeds up training in most cases.
  • DALI is a library for data loading and pre-processing to accelerate deep learning applications. Data processing pipelines implemented using DALI can easily be retargeted to TensorFlow, PyTorch, MXNet and PaddlePaddle.

This post focuses on how to use Kornia. A guide to using DALI will be introduced in the next post.

4. Data Augmentation using Kornia

In this part, we cover how to use Kornia for data augmentation. To augment data on the GPU, we wrap the transforms (augmentations) in a transform_module (in PyTorch, an nn.Module object) whose input is a batch tensor of size $B \times C\times H \times W$ and whose output is also a tensor of size $B \times C\times H \times W$.

That transform_module is placed between the pre-processing task (reading images, bringing the images of a batch to the same size, converting images to tensors) and the training model. More precisely,

import torch.nn as nn
from torch import Tensor


class ModelWithAugumentation(nn.Module):
    """Module to perform data augmentation on torch tensors."""

    def __init__(self, transform_module: nn.Module, model: nn.Module) -> None:
        super().__init__()  # must run before submodules are assigned
        self.transform_module = transform_module
        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        augmented_x = self.transform_module(x)  # BxCxHxW
        x_out = self.model(augmented_x)
        return x_out

where transform_module is defined using Kornia or torchvision. For example:

transform_module = K.augmentation.AugmentationSequential(
    K.augmentation.Normalize(Tensor((0.485, 0.456, 0.406)), Tensor((0.229, 0.224, 0.225)), p=1),
)

Note: we can also use torchvision to define the transform module.

transform_module = torch.nn.Sequential(
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)

Each transform module is an nn.Module object, so we can use nn.Sequential to define a sequence of transforms.
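As an illustration of that idea, the sketch below chains two hand-written transforms with nn.Sequential using plain PyTorch. BatchNormalize and RandomHorizontalFlipBatch are hypothetical helpers written for this example; they are not part of Kornia or torchvision.

```python
import torch
import torch.nn as nn


class BatchNormalize(nn.Module):
    """Normalize a BxCxHxW batch channel-wise (hypothetical helper)."""

    def __init__(self, mean, std) -> None:
        super().__init__()
        # buffers follow the module when it is moved with .to(device)
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


class RandomHorizontalFlipBatch(nn.Module):
    """Flip each sample of the batch independently with probability p."""

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # one random draw per sample, created on the same device as the input
        flip = torch.rand(x.shape[0], device=x.device) < self.p
        return torch.where(flip.view(-1, 1, 1, 1), torch.flip(x, dims=[-1]), x)


transform_module = nn.Sequential(
    RandomHorizontalFlipBatch(p=0.5),
    BatchNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)

x = torch.rand(4, 3, 16, 16)  # BxCxHxW
out = transform_module(x)     # same shape as the input
```

Because both transforms are nn.Module objects, moving transform_module to the GPU with .to("cuda") moves the normalization constants as well.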

We now apply that strategy to our problem. Compared with the pipeline in the previous post (Training deep learning segmentation models in Pytorch Lightning), we make the following modifications:

  • On the CPU, we only use Resize or Padding. In the previous post, we defined the whole augmentation with albumentations and applied it as the transform before the model.
import albumentations as A

def resize(p: float = 1):
    return A.Resize(384, 384, always_apply=True)

self.valid_transform = resize()
self.train_transform = resize()
  • We use Kornia to define the augmentation (an nn.Module object). Here we have train_transform_K and valid_transform_K.
import kornia as K
from torch import Tensor

valid_transform_K = K.augmentation.AugmentationSequential(
    K.augmentation.Normalize(Tensor((0.485, 0.456, 0.406)), Tensor((0.229, 0.224, 0.225)), p=1),
    data_keys=["input", "mask"],
)

train_transform_K = K.augmentation.AugmentationSequential(
    K.augmentation.container.ImageSequential(  # OneOf
        # list the two augmentations to choose from here, one per weight
        random_apply_weights=[0.5, 0.5],
    ),
    K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.5),
    # K.augmentation.RandomAffine(degrees=(-15.0, 15.0), p=0.3),
    K.augmentation.Normalize(Tensor((0.485, 0.456, 0.406)), Tensor((0.229, 0.224, 0.225)), p=1),
    data_keys=["input", "mask"],
)

  • In the LightningModule, we define two new attributes (two nn.Module objects):
self.train_transform = train_transform_K
self.valid_transform = valid_transform_K

Note: we add the transforms to the training and validation loops (training_step and validation_step).

def training_step(self, batch, batch_idx):
    imgs, masks = batch["image"], batch["label"]
    if self.train_transform is not None:
        imgs, masks = self.train_transform(imgs, masks) # add the transform before going to the model
        imgs, masks = imgs.float(), masks.float()
    logits = self(imgs)

    train_loss = self.loss_function(logits, masks)
    train_dice_soft = self.dice_soft(logits, masks)

    self.log("train_loss", train_loss, prog_bar=True)
    self.log("train_dice_soft", train_dice_soft, prog_bar=True)
    return {"loss": train_loss, "train_dice_soft": train_dice_soft}

def validation_step(self, batch, batch_idx):
    imgs, masks = batch["image"], batch["label"]
    if self.valid_transform is not None:
        imgs, masks = self.valid_transform(imgs, masks) # add the transform before going to the model
        imgs, masks = imgs.float(), masks.float()
    logits = self(imgs)

    valid_loss = self.loss_function(logits, masks)
    valid_dice_soft = self.dice_soft(logits, masks)
    valid_iou = binary_mean_iou(logits, masks)

    self.log("valid_loss", valid_loss, prog_bar=True)
    self.log("valid_dice", valid_dice_soft, prog_bar=True)
    self.log("valid_iou", valid_iou, prog_bar=True)

    return {
        "valid_loss": valid_loss,
        "valid_dice": valid_dice_soft,
        "valid_iou": valid_iou,
    }

We keep the rest of the pipeline (LightningDataModule, Trainer) unchanged.

For more details, the source code is available at github
