Segmentation Model - Part VI - Training the SegFormer model using PyTorch Lightning and HuggingFace
The sixth part of the Segmentation Tutorial Series, a guide to developing the SegFormer model for the segmentation problem.
- 1. Problem Description and Dataset
- 2. Data Preparation
- 3. The SegFormer model for the semantic segmentation problem
- 4. Training the SegFormer model with PyTorch Lightning and HuggingFace
This post demonstrates how to use the SegFormer model from HuggingFace. We will focus on:
- The architecture of the SegFormer model
- Training the SegFormer model with PyTorch Lightning and HuggingFace
Similar to the previous posts, we work with the nail segmentation problem. The first and second parts recall the problem description and the dataset; if you have followed the previous posts, you can skip them. The third part focuses on the advantages of the SegFormer model, and the last part covers training the SegFormer model.
1. Problem Description and Dataset
Our data is organized as follows:
```
├── Images
│   ├── 1
│   │   ├── first_image.png
│   │   ├── second_image.png
│   │   ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4
├── Masks
│   ├── 1
│   │   ├── first_image.png
│   │   ├── second_image.png
│   │   ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4
```
We have two folders: `Images` and `Masks`. `Images` is the data folder, and `Masks` is the label folder, containing the segmentation masks of the input images. Each folder has four sub-folders: `1`, `2`, `3`, and `4`, corresponding to four types of nail distribution.
We download the data from the link and put it in `data_root`, for example:
```python
data_root = "./nail-segmentation-dataset"
```
2. Data Preparation
Similar to the training pipeline of the previous post, we first build a data frame that stores the image and mask information.
| index | images |
|---|---|
| 1 | path_first_image.png |
| 2 | path_second_image.png |
| 3 | path_third_image.png |
| 4 | path_fourth_image.png |
For that, we use the `make_csv_file` function in the `data_processing.py` file.
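The actual implementation lives in `data_processing.py`; a minimal sketch of what such a function could look like (the column names and glob pattern are assumptions for illustration, not the real code):

```python
import glob
import os

import pandas as pd


def make_csv_file(data_root: str, csv_path: str) -> pd.DataFrame:
    """Collect image/mask pairs into a data frame and save it as a CSV file."""
    records = []
    for image_path in sorted(glob.glob(os.path.join(data_root, "Images", "*", "*.png"))):
        # the mask is assumed to mirror the image path under the Masks folder
        mask_path = image_path.replace(f"{os.sep}Images{os.sep}", f"{os.sep}Masks{os.sep}")
        records.append({"images": image_path, "masks": mask_path})
    df = pd.DataFrame(records)
    df.to_csv(csv_path, index=True, index_label="index")
    return df
```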
3. The SegFormer model for the semantic segmentation problem
The SegFormer model was proposed in SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, and Ping Luo. The model consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head, and achieves strong results on image segmentation benchmarks.
The figure below illustrates the architecture of SegFormer.
SegFormer has the following notable points:
- A new Transformer encoder (backbone), the Mix Transformer (MiT), that extracts both coarse and fine features
- An MLP decoder that directly fuses the multi-level features from the encoder and predicts the semantic segmentation mask
3.1 Encoder
The encoder of SegFormer is a Mix Transformer (MiT). There are six versions of the encoder, MiT-B0 to MiT-B5. They have the same architecture but different sizes: MiT-B0 is the lightweight model for fast inference, while MiT-B5 is the largest model for the best performance. The design of MiT is similar to the Vision Transformer (ViT), but it is modified to suit semantic segmentation, namely:
- Hierarchical Feature Representation: Unlike ViT, which only generates a single-resolution feature map, MiT generates multi-level features suited to semantic segmentation. Multi-level features are one of the most important ideas in semantic segmentation, used for example in HRNet, PSPNet, DeepLab, and FPN.
- Overlapped Patch Merging: In the Vision Transformer, an input image is split into non-overlapping patches. With the Mix Transformer, the input image is also split into patches, but they overlap.
Comment: To produce the overlapping patches, MiT uses a convolutional layer, which helps the model learn local features better. Perhaps that is why it is called a Mix Transformer.
```python
from torch import nn


class OverlapPatchMerging(nn.Sequential):
    def __init__(
        self, in_channels: int, out_channels: int, patch_size: int, overlap_size: int
    ):
        super().__init__(
            # a stride smaller than the kernel size produces overlapping patches
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=patch_size,
                stride=overlap_size,
                padding=patch_size // 2,
                bias=False,
            ),
            LayerNorm2d(out_channels),
        )
```
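`LayerNorm2d` used above is not a built-in `torch.nn` module; a minimal sketch, assuming it is a standard LayerNorm applied over the channel dimension of an NCHW tensor (as in the "Implementing SegFormer in PyTorch" reference):

```python
import torch
from torch import nn
from einops import rearrange


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for NCHW tensors: normalize over channels at each spatial location."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = rearrange(x, "b c h w -> b h w c")
        x = super().forward(x)
        return rearrange(x, "b h w c -> b c h w")
```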
Figure: comparison of non-overlapping partition patches (ViT) and overlapped patches (MiT).
- Efficient Self-Attention: The main computational bottleneck of the encoder is the self-attention layer. To reduce its cost, the spatial size of the keys and values is reduced by a reduction ratio $R$ before attention is computed, so their sequence length shrinks from $HW$ to $HW/R^2$. The Efficient Self-Attention is implemented as follows:
```python
import torch
from torch import nn
from einops import rearrange


class EfficientMultiHeadAttention(nn.Module):
    def __init__(self, channels: int, reduction_ratio: int = 1, num_heads: int = 8):
        super().__init__()
        # reduce the spatial size of keys/values by reduction_ratio in each dimension
        self.reducer = nn.Sequential(
            nn.Conv2d(
                channels, channels, kernel_size=reduction_ratio, stride=reduction_ratio
            ),
            LayerNorm2d(channels),
        )
        self.att = nn.MultiheadAttention(
            channels, num_heads=num_heads, batch_first=True
        )

    def forward(self, x):
        _, _, h, w = x.shape
        reduced_x = self.reducer(x)
        # attention needs tensors of shape (batch, sequence_length, channels)
        reduced_x = rearrange(reduced_x, "b c h w -> b (h w) c")
        x = rearrange(x, "b c h w -> b (h w) c")
        # queries keep the full resolution; keys and values use the reduced sequence
        out = self.att(x, reduced_x, reduced_x)[0]
        # reshape it back to (batch, channels, height, width)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
        return out
```
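A quick shape check (the channel count, reduction ratio, and input size here are illustrative values, not the ones used inside SegFormer):

```python
attn = EfficientMultiHeadAttention(channels=64, reduction_ratio=4, num_heads=8)
x = torch.randn(1, 64, 32, 32)   # (batch, channels, height, width)
print(attn(x).shape)             # torch.Size([1, 64, 32, 32])
# keys/values come from an 8x8 reduced map, so attention costs
# (32*32) x (8*8) comparisons instead of (32*32) x (32*32)
```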
- Mix-FFN: The authors do not use positional encoding (PE) to introduce location information as in ViT, arguing that positional encoding is actually not necessary for semantic segmentation. Instead, they introduce the Mix-FFN, defined as:
$$x_{out} = MLP(GELU(CONV_{3 \times 3}(MLP(x_{in})))) + x_{in}$$
More precisely,
```python
class MixMLP(nn.Sequential):
    def __init__(self, channels: int, expansion: int = 4):
        super().__init__(
            # dense layer
            nn.Conv2d(channels, channels, kernel_size=1),
            # depth-wise conv
            nn.Conv2d(
                channels,
                channels * expansion,
                kernel_size=3,
                groups=channels,
                padding=1,
            ),
            nn.GELU(),
            # dense layer
            nn.Conv2d(channels * expansion, channels, kernel_size=1),
        )
```
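Again a quick shape check with an illustrative channel count; Mix-FFN preserves the spatial size and the number of channels, and the residual connection $+ x_{in}$ from the formula is added by the encoder block that wraps this module:

```python
mix_mlp = MixMLP(channels=64, expansion=4)
x = torch.randn(1, 64, 32, 32)
print(mix_mlp(x).shape)   # torch.Size([1, 64, 32, 32])
```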
3.2 Decoder
The Mix Transformer handles the encoder part; for the decoder, an All-MLP head fuses the multi-level features produced by the encoder.
Each block of the All-MLP decoder has the following form:
```python
class SegFormerDecoderBlock(nn.Sequential):
    def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
        super().__init__(
            nn.UpsamplingBilinear2d(scale_factor=scale_factor),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
```
The full All-MLP decoder applies one such block per encoder stage, bringing every feature level to a common channel width and resolution so they can be fused before the final segmentation head, as sketched below. Now we can jump to the next part.
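A minimal sketch of how such blocks could be combined into a decoder (the stage widths and scale factors are placeholders; in SegFormer the resulting feature maps are then concatenated along the channel dimension and passed through a final fusion layer and a segmentation head):

```python
from typing import List


class SegFormerDecoder(nn.Module):
    """Apply one decoder block per encoder stage, producing same-size feature maps."""

    def __init__(self, out_channels: int, widths: List[int], scale_factors: List[int]):
        super().__init__()
        self.stages = nn.ModuleList(
            SegFormerDecoderBlock(in_ch, out_channels, scale)
            for in_ch, scale in zip(widths, scale_factors)
        )

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        return [stage(feature) for feature, stage in zip(features, self.stages)]
```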
4. Training the SegFormer model with PyTorch Lightning and HuggingFace
In this part, we will see how to train the SegFormer model. In part III, we used segmentation_models_pytorch to build a Unet model for the nail segmentation problem. Unfortunately, segmentation_models_pytorch does not yet implement the SegFormer model. There are several open-source implementations of SegFormer:
- MMSegmentation
- Transformers - HuggingFace
- Implementing SegFormer in PyTorch

The first one is the official source code, but the model is tied to the MMSegmentation platform, which makes it difficult for people unfamiliar with MMSegmentation. The third one is reimplemented from scratch, but the model is not trained on any data, so we cannot benefit from pretrained weights. We therefore choose the second one, which is implemented by the HuggingFace team and comes with pretrained weights.
We will reuse the data pipeline and model pipeline from the third part of the tutorial series, except that we will use the `transformers` library to build the SegFormer model.
```python
import torch
from torch import nn
from transformers import SegformerForSemanticSegmentation


class SegFormer(nn.Module):
    def __init__(
        self, pretrained: str = "nvidia/segformer-b4-finetuned-ade-512-512", size: int = 512, num_labels: int = 9
    ):
        super().__init__()
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained, ignore_mismatched_sizes=True, num_labels=num_labels
        )
        self.size = size

    def forward(self, x):
        outputs = self.segformer(x)
        # the logits have 1/4 of the input resolution; upsample back to (size, size)
        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits, size=(self.size, self.size), mode="bilinear", align_corners=False
        )
        return upsampled_logits
```
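A quick sanity check of the wrapper (this downloads the pretrained checkpoint on first use; the batch size and `num_labels` here are illustrative):

```python
model = SegFormer(num_labels=9, size=512)
x = torch.randn(2, 3, 512, 512)   # a dummy batch of RGB images
print(model(x).shape)             # torch.Size([2, 9, 512, 512])
```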
Here we use the pretrained checkpoint `"nvidia/segformer-b4-finetuned-ade-512-512"`. It means that:
- The MiT-B4 Mix Transformer is used to build the encoder part.
- The weights were trained on the ADE20K dataset.
- The image size is 512.
Note that the output of the SegFormer model has spatial size (128, 128), a quarter of the 512 × 512 input, so we resize it with `torch.nn.functional.interpolate`. We could also replace this resize function with a learned upsampling layer such as `nn.ConvTranspose2d`.
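For example, a minimal sketch of that alternative (the kernel size, stride, and channel count are assumptions chosen to map the 128 × 128 logits back to 512 × 512):

```python
# hypothetical learned 4x upsampling of the per-class logits
learned_upsample = nn.ConvTranspose2d(in_channels=9, out_channels=9, kernel_size=4, stride=4)

logits = torch.randn(1, 9, 128, 128)    # dummy SegFormer logits
print(learned_upsample(logits).shape)   # torch.Size([1, 9, 512, 512])
```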
We can then define the model module and data module in the same way as in part III:
```python
model = SegFormer(config.model.encoder_name, config.model.size, config.model.classes)

datamodule = NailSegmentation(
    data_root=data_root,
    csv_path=csv_path,
    test_path="",
    batch_size=batch_size,
    num_workers=4,
)

model_lightning = LitNailSegmentation(model=model, learning_rate=config.training.learning_rate)
```
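`LitNailSegmentation` is the LightningModule defined in part III; a minimal sketch of its structure (the loss function and optimizer here are placeholders, not necessarily the ones used in part III):

```python
import pytorch_lightning as pl


class LitNailSegmentationSketch(pl.LightningModule):
    """Illustrative LightningModule wrapping the segmentation model."""

    def __init__(self, model: nn.Module, learning_rate: float = 1e-4):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()   # placeholder loss

    def training_step(self, batch, batch_idx):
        images, masks = batch
        loss = self.criterion(self.model(images), masks.long())
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        loss = self.criterion(self.model(images), masks.long())
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
```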
And run the `Trainer` API:
```python
trainer = Trainer(args)
trainer.fit(
    model=model_lightning,
    datamodule=datamodule,
)
```
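Here `args` stands for the usual PyTorch Lightning `Trainer` arguments; for instance (illustrative values, not the configuration used in the repository):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=50,        # number of training epochs
    accelerator="gpu",    # or "cpu"
    devices=1,
    precision=16,         # mixed-precision training
)
trainer.fit(model=model_lightning, datamodule=datamodule)
```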
We can find the full source code on GitHub.