Segmentation Model - Part III: Training deep learning segmentation models in PyTorch Lightning
The third part of the Segmentation Tutorial Series, a step-by-step guide to developing deep learning segmentation models in PyTorch Lightning
- 1. Problem Description and Dataset
- 2. Data Preparation
- 3. PyTorch Lightning
- 4. DataLoader
- 5. Model Module
- 6. Trainer
Continuing the Segmentation Model series, in this post we discuss how to train a segmentation model in PyTorch Lightning. PyTorch Lightning is a deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. It is built on top of PyTorch.
We keep working with the segmentation problem (nail segmentation) and explore some valuable tools that PyTorch Lightning provides. From this part on, we focus on the PyTorch platform. For convenience, we first recall two tasks from the previous posts: Problem Description and Dataset, and Data Preparation.
1. Problem Description and Dataset
We want to cover a nail semantic segmentation problem. For each image, we want to detect the segmentation of the nail in the image.
Each input image in `Images` is paired with its corresponding binary nail mask in `Masks`.
Our data is organized as follows:

```
├── Images
│   ├── 1
│   │   ├── first_image.png
│   │   ├── second_image.png
│   │   ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4
├── Masks
│   ├── 1
│   │   ├── first_image.png
│   │   ├── second_image.png
│   │   ├── third_image.png
│   ├── 2
│   ├── 3
│   ├── 4
```
We have two folders: `Images` and `Masks`. `Images` is the data folder, and `Masks` is the label folder containing the segmentation masks of the input images. Each folder has four sub-folders, `1`, `2`, `3`, and `4`, corresponding to four types of nail distribution.
We download the data from the link and put it in `data_root`, for example:

```python
data_root = "./nail-segmentation-dataset"
```
2. Data Preparation
We want a CSV file that stores the image and mask paths. In this project, images and masks share the same file names, so we only need to store the `images` path and switch the `data_root` between images and masks when we define the dataset.
index | images |
---|---|
1 | path_first_image.png |
2 | path_second_image.png |
3 | path_third_image.png |
4 | path_fourth_image.png |
For that, we use the `make_csv_file` function in the `data_processing.py` file. More precisely:
```python
import os
from pathlib import Path
from typing import Union

import pandas as pd

from utils import get_all_items, mkdir


def make_csv_file(data_root: Union[str, Path]) -> None:
    # keep only training images that have a matching mask
    list_images_train_masks = get_all_items(os.path.join(data_root, "train", "masks"))
    list_images_train_images = get_all_items(os.path.join(data_root, "train", "images"))
    list_images_train = [
        i for i in list_images_train_images if i in list_images_train_masks
    ]
    print(len(list_images_train))  # quick sanity check on the number of training pairs

    list_images_valid = get_all_items(os.path.join(data_root, "valid", "masks"))

    # mark the split and write one CSV per split
    train_frame = pd.DataFrame(list_images_train, columns=["images"])
    train_frame["train"] = 1
    valid_frame = pd.DataFrame(list_images_valid, columns=["images"])
    valid_frame["train"] = 0

    mkdir(f"{data_root}/csv_file")
    train_frame.to_csv(f"{data_root}/csv_file/train.csv", index=False)
    valid_frame.to_csv(f"{data_root}/csv_file/valid.csv", index=False)
```
Here, `get_all_items` and `mkdir` are two helper functions (defined in the `utils.py` file) that find all items in a given folder and create a new folder, respectively.
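The exact contents of `utils.py` are not shown in this post. A minimal sketch of the two helpers, assuming `get_all_items` returns file paths relative to the given folder with a leading slash (so they can later be appended to the images/masks root), could look like this:

```python
import os
from pathlib import Path
from typing import List, Union


def get_all_items(folder: Union[str, Path]) -> List[str]:
    """Recursively collect file paths relative to `folder`, with a leading slash.

    Assumed implementation: the leading slash matches how the paths are later
    concatenated with the images/masks root in NailDataset.
    """
    folder = Path(folder)
    return [
        "/" + p.relative_to(folder).as_posix()
        for p in sorted(folder.rglob("*"))
        if p.is_file()
    ]


def mkdir(folder: Union[str, Path]) -> None:
    """Create a folder (and its parents) if it does not already exist."""
    os.makedirs(folder, exist_ok=True)
```

With these in place, calling `make_csv_file(data_root)` writes `train.csv` and `valid.csv` into `{data_root}/csv_file`.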
Before defining the dataloader and the model, let's recall some main features of PyTorch Lightning. For more information, see the PyTorch Lightning documentation.
3. PyTorch Lightning
PyTorch Lightning is an open-source, lightweight Python wrapper for machine learning researchers that is built on top of PyTorch.
With this framework, you don't have to remember all the tiny details of the PyTorch framework, because PyTorch Lightning handles them for you.
Three main features of PyTorch Lightning:
- LightningDataModule
- LightningModule
- Trainer
3.1 LightningDataModule
`LightningDataModule` is a shareable, reusable class that encapsulates all the steps needed to process data:
- Data processing
- Load inside Dataset
- Apply transforms
- Wrap inside a DataLoader
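The concrete data module for the nail data comes in section 4.2. As a minimal, self-contained illustration of these steps (using random tensors in place of a real dataset, so the skeleton actually runs), a data module can look like this:

```python
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset


class RandomDataModule(LightningDataModule):
    """Minimal data module: random tensors stand in for a real dataset."""

    def __init__(self, batch_size: int = 16):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self) -> None:
        # one-time work such as downloading or unpacking the data would go here
        pass

    def setup(self, stage=None) -> None:
        # build the Dataset objects (and apply transforms) for each split
        self.train_dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randn(64, 1, 32, 32))
        self.valid_dataset = TensorDataset(torch.randn(16, 3, 32, 32), torch.randn(16, 1, 32, 32))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.valid_dataset, batch_size=self.batch_size)
```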
3.2 LightningModule
A lightning module is composed of some components that fully define the system:
- The model or system of models
- The optimizer(s)
- The train loop
- The validation loop
3.3 Trainer
Once we have declared the LightningDataModule and the LightningModule, we can train the model with the Trainer API.
A basic use of the Trainer:
```python
modelmodule = LightningModule(*args_model)
datamodule = LightningDataModule(*args_data)
trainer = Trainer(*args_trainer)

trainer.fit(modelmodule, datamodule)
```
4. DataLoader
4.1 Define torch.utils.data.Dataset for the Nail Data
To define the LightningDataModule for our dataset, we first define the `torch.utils.data.Dataset` for the nail data.

```python
import os
from typing import Any

import albumentations as A
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from utils import img2tensor, read_image, read_mask  # project helpers (assumed to live in utils.py)


class NailDataset(Dataset):
    def __init__(self, data_root: str, csv_folder: str, train: str, tfms: A.Compose):
        self.data_root = data_root
        self.csv_folder = csv_folder
        self.train = train  # "train" or "valid"
        self.tfms = tfms
        if self.train == "train":
            self.ids = pd.read_csv(os.path.join(self.csv_folder, "train.csv"))["images"]
        else:
            self.ids = pd.read_csv(os.path.join(self.csv_folder, "valid.csv"))["images"]

    def __len__(self) -> int:
        return len(self.ids)

    def __getitem__(self, idx: int) -> Any:
        fname = self.ids[idx]
        image = read_image(self.data_root + f"/{self.train}/images" + fname)
        mask = read_mask(self.data_root + f"/{self.train}/masks" + fname)
        mask = (mask > 0).astype(np.uint8)  # binarize the mask
        if self.tfms is not None:
            augmented = self.tfms(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]
        return {
            "image": img2tensor(image),
            "label": img2tensor(mask),
        }
```
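`NailDataset` relies on three small helpers: `read_image`, `read_mask`, and `img2tensor`. Their implementation is not shown in this post; a plausible OpenCV-based sketch (an assumption, not the project's exact code) is:

```python
import cv2
import numpy as np
import torch


def read_image(path: str) -> np.ndarray:
    """Load an RGB image as an HxWx3 uint8 array."""
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def read_mask(path: str) -> np.ndarray:
    """Load a segmentation mask as a single-channel HxW array."""
    return cv2.imread(path, cv2.IMREAD_GRAYSCALE)


def img2tensor(img: np.ndarray, dtype: np.dtype = np.float32) -> torch.Tensor:
    """Convert an HxW or HxWxC numpy array into a CxHxW torch tensor."""
    if img.ndim == 2:
        img = np.expand_dims(img, 2)
    img = np.transpose(img, (2, 0, 1))
    return torch.from_numpy(img.astype(dtype, copy=False))
```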
4.2 Define LightningDataModule for the Nail Data
We then use `LightningDataModule` to wrap our `NailDataset` into a data module of PyTorch Lightning.
```python
import multiprocessing as mproc
import os

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

# NailDataset is the Dataset defined above; train_transform / valid_transform are the
# project's Albumentations pipelines (a sketch of them follows this block)


class NailSegmentation(LightningDataModule):
    def __init__(self, data_root: str, csv_path: str, test_path: str, batch_size: int = 16, num_workers: int = 4):
        super().__init__()
        assert os.path.isdir(csv_path), f"missing folder: {csv_path}"
        assert os.path.isdir(data_root), f"missing folder: {data_root}"
        self.data_root = str(data_root)
        self.csv_path = str(csv_path)
        self.test_path = str(test_path)
        self.valid_transform = valid_transform()
        self.train_transform = train_transform()
        # other configs
        self.batch_size = batch_size
        self.num_workers = num_workers if num_workers is not None else mproc.cpu_count()

    def prepare_data(self) -> None:
        pass

    def setup(self, *_, **__) -> None:
        self.train_dataset = NailDataset(
            self.data_root,
            self.csv_path,
            train="train",
            tfms=self.train_transform,
        )
        self.valid_dataset = NailDataset(
            self.data_root,
            self.csv_path,
            train="valid",
            tfms=self.valid_transform,
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
```
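The `train_transform()` and `valid_transform()` calls above return Albumentations pipelines (hence the `A.Compose` type hint in `NailDataset`). The exact augmentations are defined elsewhere in the project; a reasonable sketch, with an assumed image size, could be:

```python
import albumentations as A

IMAGE_SIZE = 512  # assumed input resolution


def train_transform() -> A.Compose:
    # resize plus light geometric/photometric augmentation for training
    return A.Compose(
        [
            A.Resize(IMAGE_SIZE, IMAGE_SIZE),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
            A.RandomBrightnessContrast(p=0.3),
        ]
    )


def valid_transform() -> A.Compose:
    # deterministic resize only for validation
    return A.Compose([A.Resize(IMAGE_SIZE, IMAGE_SIZE)])
```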
Here we need to define three main functions (a quick sanity check of the module follows this list):
- setup(): data loading and preprocessing
- train_dataloader(): defines the `train_loader` object in PyTorch
- val_dataloader(): defines the `val_loader` object in PyTorch
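A quick way to sanity-check the data module before training is to pull one batch and inspect its shapes. The paths below follow the layout used in this post and may need adjusting to your setup:

```python
# assumed paths; adjust to where the dataset and the CSV files live
data_root = "./nail-segmentation-dataset"
csv_path = f"{data_root}/csv_file"

datamodule = NailSegmentation(data_root=data_root, csv_path=csv_path, test_path="", batch_size=4)
datamodule.setup()

batch = next(iter(datamodule.train_dataloader()))
print(batch["image"].shape, batch["label"].shape)
# e.g. torch.Size([4, 3, H, W]) torch.Size([4, 1, H, W]), where H and W depend on the transforms
```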
5. Model Module
In this part we define:
- A segmentation model
- A wrapper around the model using LightningModule, for which we implement these main functions:
  - training_step: calculates {loss, metric} and logs them at each training step
  - validation_step: calculates {loss, metric} and logs them at each validation step
  - validation_epoch_end: calculates {loss, metric} per epoch and logs them, using the outputs of validation_step
  - configure_optimizers: which optimizer and learning rate scheduler do we use for training?
5.1 Define the model by using segmentation_models_pytorch
For convenience, we use segmentation_models_pytorch to define our model. `segmentation_models_pytorch` is a high-level API that helps us build a semantic segmentation model with only a few lines of code.
```python
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="timm-efficientnet-b4",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",           # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                        # model input channels (1 for gray-scale images, 3 for RGB)
    classes=1,                            # model output channels (number of classes in your dataset)
)
```
Here we use:
- The Unet architecture for the segmentation model. Unet has two components: an encoder and a decoder
- encoder: EfficientNet-B4, taken from the timm library
- input channels: 3 for RGB images, 1 for gray-scale images
- classes: 1 for binary segmentation; for multi-class segmentation, set this to the number of classes
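A quick shape check with a dummy batch (the 256x256 size is only an example; any size divisible by 32 works for this encoder) shows that the decoder brings the output back to the input resolution, with one logit map per image:

```python
import torch

# dummy batch: 2 RGB images of 256x256
dummy = torch.randn(2, 3, 256, 256)

model.eval()
with torch.no_grad():
    out = model(dummy)

print(out.shape)  # torch.Size([2, 1, 256, 256]) -- one logit map per image
```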
5.2 Define LightningModule
We next use `LightningModule` to wrap the model into the model module of PyTorch Lightning.
```python
import torch
from pytorch_lightning import LightningModule
from torch import nn

from Loss import symmetric_lovasz
from metrics import binary_dice_coefficient, binary_mean_iou  # assumed module name; a sketch follows below


class LitNailSegmentation(LightningModule):
    def __init__(self, model: nn.Module, learning_rate: float = 1e-4):
        super().__init__()
        self.model = model
        self.loss_function = symmetric_lovasz
        self.dice_soft = binary_dice_coefficient
        self.learning_rate = learning_rate
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        imgs, masks = batch["image"], batch["label"]
        imgs, masks = imgs.float(), masks.float()
        logits = self(imgs)
        train_loss = self.loss_function(logits, masks)
        train_dice_soft = self.dice_soft(logits, masks)
        self.log("train_loss", train_loss, prog_bar=True)
        self.log("train_dice_soft", train_dice_soft, prog_bar=True)
        return {"loss": train_loss, "train_dice_soft": train_dice_soft}

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch["image"], batch["label"]
        imgs, masks = imgs.float(), masks.float()
        logits = self(imgs)
        valid_loss = self.loss_function(logits, masks)
        valid_dice_soft = self.dice_soft(logits, masks)
        valid_iou = binary_mean_iou(logits, masks)
        self.log("valid_loss", valid_loss, prog_bar=True)
        self.log("valid_dice", valid_dice_soft, prog_bar=True)
        self.log("valid_iou", valid_iou, prog_bar=True)
        return {
            "valid_loss": valid_loss,
            "valid_dice": valid_dice_soft,
            "valid_iou": valid_iou,
        }

    def validation_epoch_end(self, outputs):
        # aggregate the per-batch outputs of validation_step into epoch-level values
        logs = {"epoch": self.trainer.current_epoch}
        valid_losses = torch.stack([x["valid_loss"] for x in outputs]).mean()
        valid_dices = torch.stack([x["valid_dice"] for x in outputs]).mean()
        valid_ious = torch.stack([x["valid_iou"] for x in outputs]).mean()
        logs["valid_losses"] = valid_losses
        logs["valid_dices"] = valid_dices
        logs["valid_ious"] = valid_ious
        return {
            "valid_losses": valid_losses,
            "valid_dices": valid_dices,
            "valid_ious": valid_ious,
            "log": logs,
        }

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0)
        self.optimizer = [optimizer]
        return self.optimizer, [scheduler]
```
Here we use:
- AdamW as the optimizer, together with a CosineAnnealingLR learning rate scheduler
- symmetric_lovasz as the loss function, which is defined in the Loss.py file:

```python
def symmetric_lovasz(outputs, targets):
    # average of the Lovasz hinge on the prediction and on its complement
    return 0.5 * (lovasz_hinge(outputs, targets) + lovasz_hinge(-outputs, 1.0 - targets))
```

where lovasz_hinge is the Lovasz loss for binary segmentation.
- Metrics: Dice and IoU
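`binary_dice_coefficient` and `binary_mean_iou` are small metric helpers defined elsewhere in the project. Their exact code is not shown here; typical logit-based versions (an assumption, not the project's exact implementation) look roughly like this:

```python
import torch


def binary_dice_coefficient(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice on sigmoid probabilities (no thresholding)."""
    probs = torch.sigmoid(logits).reshape(-1)
    targets = targets.reshape(-1)
    intersection = (probs * targets).sum()
    return (2.0 * intersection + eps) / (probs.sum() + targets.sum() + eps)


def binary_mean_iou(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """IoU on binarized predictions (threshold 0.5 on the sigmoid output)."""
    preds = (torch.sigmoid(logits) > 0.5).float().reshape(-1)
    targets = targets.reshape(-1)
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum() - intersection
    return (intersection + eps) / (union + eps)
```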
6. Trainer
Once we have the data module and the model module, we can train the model with the Trainer API:
```python
datamodule = NailSegmentation(
    data_root=data_root,
    csv_path=csv_path,
    test_path="",
    batch_size=batch_size,
    num_workers=4,
)
model_lightning = LitNailSegmentation(model=model, learning_rate=config.training.learning_rate)

trainer = Trainer(*args_trainer)
trainer.fit(
    model=model_lightning,
    datamodule=datamodule,
    ckpt_path=ckpt_path,
)
```
Here, `args_trainer` holds the arguments of the Trainer. More precisely, it contains:
```yaml
gpus: [0]                 # GPU device(s) to train on
max_epochs: 300           # number of epochs
precision: 16             # use mixed precision for training
auto_lr_find: True        # automatically find a good initial learning rate
limit_train_batches: 1.0  # fraction of the train dataset used for training, here 100%
# ...
```
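With the pre-2.0 PyTorch Lightning API used in this post, such a configuration translates into a Trainer call roughly like the following (showing only the arguments listed above; the real configuration contains more):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=[0],                 # train on the first GPU
    max_epochs=300,
    precision=16,             # mixed-precision training
    auto_lr_find=True,        # picked up by trainer.tune() to search for a good initial learning rate
    limit_train_batches=1.0,  # use 100% of the training batches each epoch
)
```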
Lightning implements various techniques that help make training smoother.
For more details, you can find the source code on GitHub.