This post demonstrates how to use the SegFormer model from HuggingFace. We will focus on:

  • Architecture of the SegFormer model
  • Training the SegFormer model with PyTorch Lightning and HuggingFace

As in the previous post, we work on a segmentation problem (nail segmentation). The first and second parts recall the problem description, the dataset, and the data preparation; if you have followed the previous posts, you can skip them. The third part focuses on the architecture and advantages of the SegFormer model. The last part covers training the SegFormer model.

1. Problem Description and Dataset

We work on a nail semantic segmentation problem: for each image, we want to predict the segmentation mask of the nails in the image.

Figure: example images and their masks.

Our data is organized as

├── Images
│   ├── 1
│       ├── first_image.png
│       ├── second_image.png
│       ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4
├── Masks
│   ├── 1
│       ├── first_image.png
│       ├── second_image.png
│       ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4

We have two folders: Images and Masks. Images is the data folder, and Masks is the label folder, which contains the segmentation masks of the input images. Each folder has four sub-folders: 1, 2, 3, and 4, corresponding to four types of nail distribution.

We download data from link and put it in data_root, for example

data_root = "./nail-segmentation-dataset"

2. Data Preparation

As in the training pipeline of the previous post, we first build a data frame that stores the image and mask paths.

index images
1 path_first_image.png
2 path_second_image.png
3 path_third_image.png
4 path_fourth_image.png

For that we use the make_csv_file function in the data_processing.py file.
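
The real make_csv_file lives in data_processing.py and is not reproduced here. As a rough idea of what such a helper could look like, here is a minimal sketch using only the standard library. The name make_csv_file_sketch, its signature, and the choice to store paths relative to the Images folder are assumptions made for illustration, not the author's implementation.

```python
import csv
from pathlib import Path


def make_csv_file_sketch(data_root: str, csv_path: str) -> int:
    """Hypothetical helper: write one row per image and return the row count."""
    image_dir = Path(data_root) / "Images"
    # Sort so that the row order is reproducible across runs and machines.
    image_paths = sorted(image_dir.glob("*/*.png"))
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "images"])
        for index, path in enumerate(image_paths, start=1):
            # Paths are stored relative to Images, so the same value also
            # locates the matching file under Masks.
            writer.writerow([index, path.relative_to(image_dir).as_posix()])
    return len(image_paths)
```

Because Images and Masks mirror each other, one column of relative paths is enough to find both the image and its mask.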

3. The SegFormer Model for the semantic segmentation problem

The SegFormer model was proposed in SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. The model consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head to achieve great results on image segmentation benchmarks.

The figure below illustrates the architecture of SegFormer

SegFormer has the following notable points:

  • A new Transformer encoder (backbone), the Mix Transformer (MiT), which extracts both coarse and fine features
  • A decoder that is an MLP network, which directly fuses the multi-level features of the encoder and predicts the semantic segmentation mask

3.1 Encoder

The encoder of SegFormer is a Mix Transformer (MiT). There are six versions of the encoder, MiT-B0 to MiT-B5. They share the same architecture but differ in size: MiT-B0 is the lightweight model for fast inference, while MiT-B5 is the largest model for the best performance. The design of MiT is similar to that of the Vision Transformer (ViT), but it is modified to suit semantic segmentation, namely:

  • Hierarchical Feature Representation: Unlike ViT, which can only generate a single-resolution feature map, MiT generates multi-level features suited to semantic segmentation. Multi-level features are one of the most important ideas in semantic segmentation; see for example HRNet, PSPNet, DeepLab, FPN, ...
  • Overlapped Patch Merging: In the Vision Transformer, an input image is split into non-overlapping patches. In the Mix Transformer, the input image is also split into patches, but the patches overlap.

    Comment: With overlapping patches, MiT effectively uses a convolutional (CNN) layer, which helps the model learn local features better. Is that why it is called Mix Transformer?

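from torch import nn


class OverlapPatchMerging(nn.Sequential):
    # LayerNorm2d is not part of torch.nn. It is assumed to be defined elsewhere
    # as a LayerNorm over the channel dimension of a (batch, channels, H, W) tensor.
    def __init__(
        self, in_channels: int, out_channels: int, patch_size: int, overlap_size: int
    ):
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=patch_size,
                # overlap_size is used as the stride: neighbouring patches
                # overlap whenever it is smaller than patch_size
                stride=overlap_size,
                padding=patch_size // 2,
                bias=False
            ),
            LayerNorm2d(out_channels)
        )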
Figure: partition patches (ViT) versus overlapped patches (MiT).
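
To see how OverlapPatchMerging produces a hierarchy of feature maps, we can compute the output size of each stage by hand. The sketch below uses the standard output-size formula for a convolution (dilation 1). The per-stage (patch_size, stride) pairs (7, 4), (3, 2), (3, 2), (3, 2) are the values reported in the SegFormer paper, and the 512 input size is chosen to match the checkpoint used later in this post.

```python
def conv_output_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (patch_size, stride) per encoder stage; the stride is the argument
# called overlap_size in OverlapPatchMerging.
stages = [(7, 4), (3, 2), (3, 2), (3, 2)]

size = 512
feature_sizes = []
for patch_size, stride in stages:
    size = conv_output_size(size, patch_size, stride, padding=patch_size // 2)
    feature_sizes.append(size)

print(feature_sizes)  # [128, 64, 32, 16]
```

The four feature maps are at 1/4, 1/8, 1/16 and 1/32 of the input resolution. In every stage the kernel is larger than the stride, which is exactly what makes neighbouring patches overlap.
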
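  • Efficient Self-Attention: The main computational bottleneck of the encoder is the self-attention layer. Efficient Self-Attention shortens the key and value sequences with a strided convolution before attention is computed. It is implemented as follows:

from einops import rearrange

class EfficientMultiHeadAttention(nn.Module):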
    def __init__(self, channels: int, reduction_ratio: int = 1, num_heads: int = 8):
        super().__init__()
        self.reducer = nn.Sequential(
            nn.Conv2d(
                channels, channels, kernel_size=reduction_ratio, stride=reduction_ratio
            ),
            LayerNorm2d(channels),
        )
        self.att = nn.MultiheadAttention(
            channels, num_heads=num_heads, batch_first=True
        )

    def forward(self, x):
        _, _, h, w = x.shape
        reduced_x = self.reducer(x)
        # attention needs tensor of shape (batch, sequence_length, channels)
        reduced_x = rearrange(reduced_x, "b c h w -> b (h w) c")
        x = rearrange(x, "b c h w -> b (h w) c")
        out = self.att(x, reduced_x, reduced_x)[0]
        # reshape it back to (batch, channels, height, width)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
        return out
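
To get a feeling for how much the reducer saves, we can count the entries of the attention score matrix (number of queries times number of keys) with and without reduction. This is a back-of-the-envelope sketch: the 128 x 128 feature map and reduction_ratio=8 are values chosen for illustration, and the helper name is hypothetical.

```python
def attention_score_entries(height: int, width: int, reduction_ratio: int = 1) -> int:
    """Number of query-key pairs for one attention head on an (H, W) feature map."""
    queries = height * width
    # The reducer is a convolution with kernel_size == stride == reduction_ratio
    # and no padding, so each spatial side shrinks by that factor.
    keys = (height // reduction_ratio) * (width // reduction_ratio)
    return queries * keys


full = attention_score_entries(128, 128, reduction_ratio=1)
reduced = attention_score_entries(128, 128, reduction_ratio=8)
print(full, reduced, full // reduced)  # 268435456 4194304 64
```

Since the reducer shrinks both height and width, the key and value sequences become reduction_ratio squared times shorter, while the number of queries (and therefore the output resolution) is unchanged.
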
  • Mix-FFN: The authors do not use positional encoding (PE) to introduce location information as ViT does, arguing that positional encoding is not actually necessary for semantic segmentation. Instead, they introduce Mix-FFN, defined as:

$$x_{out} = MLP(GELU(CONV_{3 \times 3}(MLP(x_{in})))) + x_{in}$$

More precisely,

class MixMLP(nn.Sequential):
    def __init__(self, channels: int, expansion: int = 4):
        super().__init__(
            # dense layer
            nn.Conv2d(channels, channels, kernel_size=1),
            # depth wise conv
            nn.Conv2d(
                channels,
                channels * expansion,
                kernel_size=3,
                groups=channels,
                padding=1,
            ),
            nn.GELU(),
            # dense layer
            nn.Conv2d(channels * expansion, channels, kernel_size=1),
        )
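
Why can a 3 x 3 convolution stand in for positional encoding? The SegFormer paper points to zero padding, which leaks location information. The toy 1D sketch below (plain Python, with a hypothetical helper name) illustrates the effect: on a perfectly constant input, a zero-padded convolution still responds differently at the borders than in the interior.

```python
def conv1d_zero_padded(signal, kernel):
    """Same-size 1D convolution (cross-correlation) with zero padding."""
    pad = len(kernel) // 2
    padded = [0.0] * pad + list(signal) + [0.0] * pad
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(signal))
    ]


constant_input = [1.0] * 6
response = conv1d_zero_padded(constant_input, kernel=[1.0, 1.0, 1.0])
print(response)  # [2.0, 3.0, 3.0, 3.0, 3.0, 2.0]
```

The border positions see the padded zeros, so their response differs from the interior even though the input carries no positional signal at all.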

3.2 Decoder

The Mix Transformer does the heavy lifting in the encoder, so for the decoder we use a simple All-MLP head that fuses the multi-level features of the encoder.

Each block of the All-MLP decoder has the following form:

class SegFormerDecoderBlock(nn.Sequential):
    def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
        super().__init__(
            nn.UpsamplingBilinear2d(scale_factor=scale_factor),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
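
Each encoder stage needs its own scale_factor so that all feature maps end up at the same resolution before they are fused. The sketch below works this out, assuming a 512 x 512 input, encoder features at strides 4, 8, 16 and 32, and fusion at 1/4 of the input resolution as described in the SegFormer paper.

```python
input_size = 512
encoder_strides = [4, 8, 16, 32]

# All features are brought to the resolution of the first (finest) stage.
target_size = input_size // encoder_strides[0]

scale_factors = []
upsampled_sizes = []
for stride in encoder_strides:
    feature_size = input_size // stride
    scale_factor = target_size // feature_size
    scale_factors.append(scale_factor)
    upsampled_sizes.append(feature_size * scale_factor)

print(scale_factors)    # [1, 2, 4, 8]
print(upsampled_sizes)  # [128, 128, 128, 128]
```

Once the four maps share the same spatial size, they can be concatenated along the channel dimension and passed to the final prediction layer.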

That completes the All-MLP decoder. Now we can move on to the next part.

4. Training the SegFormer model with PyTorch Lightning and HuggingFace

In this part we will see how to train the SegFormer model. In Part III, we used segmentation_models_pytorch to build a Unet model for the nail segmentation problem. Unfortunately, at the time of writing, segmentation_models_pytorch does not yet implement the SegFormer model. There are several open-source implementations of SegFormer:

  • MMSegmentation
  • Transformers - HuggingFace
  • Implementing SegFormer in PyTorch

The first one is the official source code, but the model is tied to the MMSegmentation platform, which makes it difficult to use for people who are unfamiliar with that platform. The third one is reimplemented from scratch, but the model is not trained on any data, so we cannot benefit from pretrained weights. We choose the second one, which is implemented in the HuggingFace transformers library and comes with pretrained weights.

We will reuse the data pipeline and model pipeline from the third part of the tutorial series, except that we use the transformers library to build the SegFormer model.

import torch
from torch import nn
from transformers import SegformerForSemanticSegmentation

class SegFormer(nn.Module):
    def __init__(
        self, pretrained: str = "nvidia/segformer-b4-finetuned-ade-512-512", size: int = 512, num_labels: int = 9
    ):
        super().__init__()
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained, ignore_mismatched_sizes=True, num_labels=num_labels
        )
        self.size = size

    def forward(self, x):
        outputs = self.segformer(x)

        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits, size=(self.size, self.size), mode="bilinear", align_corners=False
        )
        return upsampled_logits

Here we use the pretrained checkpoint "nvidia/segformer-b4-finetuned-ade-512-512". The name means that:

  • The MiT-B4 Mix Transformer is used as the encoder.
  • The weights were fine-tuned on the ADE20K dataset.
  • The input image size is 512 x 512.

    Note that the SegFormer logits have a spatial size of (128, 128), a quarter of the 512 x 512 input, so we resize them with torch.nn.functional.interpolate. The resize function can be replaced with a learnable upsampling layer such as nn.ConvTranspose2d.
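
If we wanted to swap the interpolation for nn.ConvTranspose2d, the layer hyperparameters would have to map 128 back to 512. The sketch below only checks the arithmetic, using the output-size formula from the PyTorch documentation of ConvTranspose2d. The two configurations are illustrative choices, not settings used in this post.

```python
def conv_transpose_output_size(
    size: int,
    kernel_size: int,
    stride: int,
    padding: int = 0,
    output_padding: int = 0,
    dilation: int = 1,
) -> int:
    """Spatial output size of a transposed convolution."""
    return (
        (size - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


# Non-overlapping 4x upsampling: kernel_size == stride == 4.
print(conv_transpose_output_size(128, kernel_size=4, stride=4))  # 512
# Overlapping kernels also work if the padding is chosen accordingly.
print(conv_transpose_output_size(128, kernel_size=8, stride=4, padding=2))  # 512
```

Either setting restores the 512 x 512 resolution. Unlike bilinear interpolation, such a layer has trainable weights, so it adds parameters that must be learned during fine-tuning.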

We can then define the model and the data module in the same way as in Part III:

model = SegFormer(config.model.encoder_name, config.model.size, config.model.classes)

datamodule = NailSegmentation(
    data_root=data_root,
    csv_path=csv_path,
    test_path="",
    batch_size=batch_size,
    num_workers=4,
    )

model_lighning = LitNailSegmentation(model=model, learning_rate=config.training.learning_rate)

Then we run the Trainer API:

trainer = Trainer(args)
trainer.fit(
    model=model_lighning,
    datamodule=datamodule
)

The full source code is available on GitHub.