ZooScan image segmentation using a UNet with a pretrained backbone¶
%load_ext autoreload
%autoreload 2
In this lab, we would like to train a deep neural network to perform semantic segmentation of ZooScan images. This lab will guide you through:
- how to load labeled ZooScan images
- how to implement a UNet for semantic segmentation
- how to train the network
Introduction¶
We are provided with labeled ZooScan data, which means we have both the scan image and a mask. The mask is an image of labels. The data are located within the data subdirectory; every scan is around $20000\times 15000$ pixels.
!ls data/train
rg20210310_mask.png rg20210407_mask.png rg20211103_mask.png taxa.csv rg20210310_scan.png rg20210407_scan.png rg20211103_scan.png
!for f in data/train/*_scan.png; do file $f; done
data/train/rg20210310_scan.png: PNG image data, 22737 x 14503, 8-bit grayscale, non-interlaced data/train/rg20210407_scan.png: PNG image data, 22707 x 14373, 8-bit grayscale, non-interlaced data/train/rg20211103_scan.png: PNG image data, 22747 x 14573, 8-bit grayscale, non-interlaced
An example ZooScan image and its mask are displayed below.
The color code of the mask depends on the class to which each pixel has been assigned by the human labeler. Note that pixel-labeling these huge images is certainly tough work, but the quality of your data and labels is of the utmost importance for our deep learning algorithms to work. The semantics of the class indices are provided in the taxa.csv file.
!cat data/train/taxa.csv
label,living,label_nb artefact,False,1 badfocus<artefact,False,2 bubble,False,3 detritus,False,4 fiber<detritus,False,5 t001,False,6 t003,False,7 Acartiidae,True,8 Actinopterygii,True,9 Aglaura,True,10 Annelida,True,11 Aulacantha,True,12 Calanidae,True,13 Calanoida,True,14 Candaciidae,True,15 Cavolinia inflexa,True,16 Centropagidae,True,17 Chaetognatha,True,18 Chelophyes appendiculata,True,19 Collodaria,True,20 Corycaeidae,True,21 Creseidae,True,22 Creseis acicula,True,23 Doliolida,True,24 Euchaetidae,True,25 Eumalacostraca,True,26 Evadne,True,27 Flaccisagitta enflata,True,28 Fritillariidae,True,29 Gammaridea,True,30 Globigerinidae,True,31 Gymnosomata,True,32 Harpacticoida,True,33 Heterorhabdidae,True,34 Hydrozoa,True,35 Hyperiidea,True,36 Insecta,True,37 Limacinidae,True,38 Metridinidae,True,39 Neoceratium,True,40 Oikopleuridae,True,41 Oithonidae,True,42 Oncaeidae,True,43 Orbulina,True,44 Ostracoda,True,45 Penilia avirostris,True,46 Podon,True,47 Rhizaria,True,48 Rhopalonema velatum,True,49 Salpida,True,50 Sapphirinidae,True,51 Temoridae,True,52 bract<Abylopsis tetragona,True,53 bract<Diphyidae,True,54 calyptopsis<Euphausiacea,True,55 colony<Phaeodaria,True,56 damaged<Aulacantha,True,57 egg<Actinopterygii,True,58 egg<Mollusca,True,59 egg<other,True,60 endostyle<Salpidae,True,61 eudoxie<Diphyidae,True,62 gonophore<Abylopsis tetragona,True,63 gonophore<Diphyidae,True,64 head<Chaetognatha,True,65 juvenile<Salpida,True,66 larvae<Porcellanidae,True,67 like<Collodaria,True,68 multiple<other,True,69 nectophore<Diphyidae,True,70 nectophore<Physonectae,True,71 nucleus<Salpidae,True,72 othertocheck,True,73 part<Cnidaria,True,74 part<Crustacea,True,75 part<Mollusca,True,76 part<Siphonophorae,True,77 part<Thaliacea,True,78 pluteus<Echinoidea,True,79 pluteus<Ophiuroidea,True,80 protozoea<Mysida,True,81 seaweed,True,82 siphonula,True,83 tail<Appendicularia,True,84 tail<Chaetognatha,True,85 trunk<Appendicularia,True,86 zoea<Brachyura,True,87 zoea<Galatheidae,True,88
We have $88$ classes of objects plus the background (label $0$). All the living organisms have been assigned a label strictly greater than $7$. The "living" column is therefore equivalent to testing whether label_nb is greater than $7$.
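As a quick sanity check, you can verify this relationship directly from taxa.csv, for example with pandas (this snippet is only an illustration and assumes pandas is installed in your environment):
import pandas as pd
# Check that the "living" flag is equivalent to label_nb > 7
taxa = pd.read_csv("data/train/taxa.csv")
assert (taxa["living"] == (taxa["label_nb"] > 7)).all()
print(f"{len(taxa)} classes, {int(taxa['living'].sum())} of which are living")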
The classes are imbalanced. For example, counting the number of occurrences of each of the $89$ classes (background + $88$ non-living objects and living organisms), we obtain the counts below, sorted by decreasing occurrence.
The most represented class is the background with almost $900$ million pixels. The most represented non-background class is Salpida (class label $50$) with $60$ million pixels, and the least represented class is Neoceratium with only $2$K pixels. This strong imbalance can cause a lot of trouble when training a neural network, as the over-represented classes can be more easily learned than the under-represented ones.
Class name | Count |
---|---|
background | 903586894 |
Salpida | 60370854 |
detritus | 2242808 |
multiple<other | 2142290 |
Calanoida | 1914332 |
endostyle<Salpidae | 1541269 |
badfocus<artefact | 1336249 |
bubble | 1302699 |
nucleus<Salpidae | 1003932 |
Calanidae | 835997 |
Chaetognatha | 807700 |
juvenile<Salpida | 701173 |
Euchaetidae | 538480 |
part<Crustacea | 485422 |
Centropagidae | 468518 |
Candaciidae | 462508 |
Corycaeidae | 446955 |
fiber<detritus | 366257 |
Eumalacostraca | 333003 |
Rhopalonema velatum | 324429 |
Metridinidae | 313694 |
Gammaridea | 310755 |
nectophore<Diphyidae | 287985 |
Creseis acicula | 285353 |
t001 | 260026 |
Flaccisagitta enflata | 259964 |
othertocheck | 259208 |
Chelophyes appendiculata | 252368 |
Oikopleuridae | 250231 |
Heterorhabdidae | 227147 |
gonophore<Diphyidae | 211023 |
part<Cnidaria | 204616 |
nectophore<Physonectae | 200246 |
Ostracoda | 195320 |
Temoridae | 180239 |
Oithonidae | 158871 |
tail<Appendicularia | 156565 |
Doliolida | 156015 |
protozoea<Mysida | 147630 |
bract<Diphyidae | 114702 |
Acartiidae | 112833 |
damaged<Aulacantha | 109413 |
zoea<Galatheidae | 107454 |
Cavolinia inflexa | 101898 |
Hyperiidea | 99523 |
calyptopsis<Euphausiacea | 92612 |
Hydrozoa | 88566 |
Aulacantha | 72204 |
seaweed | 70972 |
Sapphirinidae | 69538 |
tail<Chaetognatha | 69175 |
Gymnosomata | 65518 |
egg<other | 62427 |
Fritillariidae | 62113 |
part<Mollusca | 61776 |
Aglaura | 53920 |
Annelida | 53318 |
Creseidae | 52541 |
like<Collodaria | 52112 |
trunk<Appendicularia | 50062 |
part<Siphonophorae | 42509 |
eudoxie<Diphyidae | 34491 |
Oncaeidae | 29734 |
Limacinidae | 28553 |
gonophore<Abylopsis tetragona | 28054 |
egg<Actinopterygii | 27244 |
zoea<Brachyura | 25307 |
head<Chaetognatha | 24456 |
bract<Abylopsis tetragona | 23398 |
part<Thaliacea | 23155 |
Insecta | 22354 |
Actinopterygii | 17853 |
colony<Phaeodaria | 14070 |
artefact | 13792 |
Penilia avirostris | 13361 |
Harpacticoida | 11465 |
t003 | 10630 |
pluteus<Ophiuroidea | 10253 |
Collodaria | 8803 |
larvae<Porcellanidae | 8463 |
Podon | 7855 |
egg<Mollusca | 6938 |
pluteus<Echinoidea | 5779 |
Globigerinidae | 5704 |
Orbulina | 5534 |
Evadne | 5205 |
siphonula | 3574 |
Rhizaria | 2171 |
Neoceratium | 2046 |
The plot below shows the number of pixels per class. Note that the y-axis is in log scale. As we count the number of pixels per class, the imbalance could be explained either by some organisms being larger than others or by some classes being over-represented. In any case, this imbalance will cause you trouble when training a neural network.
If we group the classes into non-living vs living, we obtain the following counts, where you still have $10$ times more "non-living" pixels.
Class name | Count |
---|---|
Non-living | 909119355 |
Living | 78495098 |
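For reference, per-class pixel counts like the ones above can be computed from a mask image with a few lines of numpy; the sketch below uses one of the training masks listed earlier (the other masks are processed the same way):
import numpy as np
from PIL import Image
# The scans are very large, raise PIL's decompression-bomb threshold
Image.MAX_IMAGE_PIXELS = 25000 * 15000
# Count how many pixels of one mask belong to each of the 89 labels
mask = np.array(Image.open("data/train/rg20210310_mask.png"))
counts = np.bincount(mask.ravel(), minlength=89)
for label in np.argsort(counts)[::-1][:5]:
    print(f"label {label}: {counts[label]} pixels")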
Imports¶
All the code you are going to write will go into the following files:
data.py
: script responsible for handling the data pipeline, producing the dataloaders for the training and validation data
models/
: submodule responsible for providing the different models you want to experiment with
# Standard imports
import pathlib
# External imports
import albumentations as A
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml
# Local imports
from planktoseg import data
from planktoseg import models
from planktoseg import optim
from planktoseg import utils
from planktoseg import metrics
from planktoseg import main
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
logdir = pathlib.Path("./logs")
if not logdir.exists():
logdir.mkdir(exist_ok=True)
/usr/users/dce-admin/fix/GIT/2024_ml4oceans/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Data loading and exploration¶
In this part, we will explore the dataset:
- load and visualize some samples
- investigate the balance of the classes
- investigate the application of some data augmentation techniques.
The very first step when writing a pipeline for training neural networks on your data is to prepare the data pipeline. The expected output of this first step is the data loaders, i.e. the Python iterable objects able to give you minibatches for training and validation.
Dataset¶
In the data.py
script, you are provided with the PlanktonDataset
object. We wrote this class to ease your work of loading the data. This class can:
- load the data (image and mask),
- split every sample into patches and you can configure both the patch size and the patch stride (overlap = patch_size - patch_stride),
- apply transformations on the input patches and mask patches
- switch between two semantic segmentation tasks:
  - classify a pixel as belonging to one of the $88$ classes (task = SegmentationTask.LIVING_CLASSES)
  - classify a pixel as belonging to non-living vs living classes (task = SegmentationTask.LIVING_NONLIVING)
As an example, in the cell below, we invoke this dataset object, which is going to:
- split the large scans into patches of size $512 \times 512$ with a stride of $128 \times 128$
- merge the labels into two classes: Non-living (label <= 7) and Living (label >= 8)
- transform every patch and mask from an Image type to a PyTorch Tensor
transform = A.pytorch.ToTensorV2()
dataset = data.PlanktonDataset(root="./data/train", patch_size=(512, 512), patch_stride=(128, 128), task=data.SegmentationTask.LIVING_NONLIVING)
dataset = data.WrappedDataset(dataset, transform)
0%| | 0/3 [00:00<?, ?it/s]
100%|██████████| 3/3 [00:12<00:00, 4.07s/it]
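As a rough sanity check of the patch splitting (a sketch that ignores how PlanktonDataset handles the image borders), the number of patches extracted from a single scan can be estimated from the image size, the patch size and the stride:
# Approximate patch count for one scan of size 22737 x 14503,
# with 512 x 512 patches and a stride of 128 pixels
H, W = 14503, 22737
patch, stride = 512, 128
n_h = (H - patch) // stride + 1
n_w = (W - patch) // stride + 1
print(f"about {n_h * n_w} patches for this scan")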
We can now access both an image and its label by indexing the dataset. Feel free to repeat the execution of the cell below as it randomly samples the dataset for a new image/mask.
idx = np.random.randint(len(dataset))
img, mask = dataset[idx]
print(f"The img is of type {type(img)} and of shape {img.shape}")
print(f"The mask is of type {type(mask)} and of shape {mask.shape}")
plt.subplot(1, 2, 1)
plt.imshow(img.squeeze(), cmap="gray")
plt.title("Zooscan image")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Mask")
plt.axis("off")
The img is of type <class 'torch.Tensor'> and of shape torch.Size([1, 512, 512]) The mask is of type <class 'torch.Tensor'> and of shape torch.Size([512, 512])
(np.float64(-0.5), np.float64(511.5), np.float64(511.5), np.float64(-0.5))
Exploring the augmentation transforms¶
Quality data is of paramount importance but still limited in availability. We can virtually increase its availability by applying so-called "data augmentations", which are transforms of your inputs for which you know how to transform the target.
For this, we will use the albumentations python library which offers several transforms for various tasks (classification, object detection, semantic segmentation)
Your first exercise is to identify, among the available data transforms (see for example this documentation, and do not forget to scroll down to the spatial-level transforms), which ones can be suitable. For example:
- A.HorizontalFlip, which randomly flips horizontally your image/mask pair,
- A.VerticalFlip which randomly flips vertically your image/mask pair,
- A.MaskDropout, which randomly masks some objects
You could also Blur or add GaussianNoise. I believe these transforms still make sense for our plankton segmentation task given the snow we see on the images and the possibly out-of-focus objects. If you wonder "Where should I stop adding transforms?" or "How do I know my transforms are OK?", remember that the purpose of these transforms is to augment your data, hopefully to improve your validation metric.
idx = np.random.randint(len(dataset))
idx = 50807
# Sample the dataset without all the transforms
# to see what the image and mask originally look like
original_transform = A.pytorch.ToTensorV2()
dataset.transform = original_transform
orig_img, orig_mask = dataset[idx]
##########################################################################################
# TODO: You have to fill this part!
# Tune the augmented transform by prepending your chosen transforms before the
# conversion to a pytorch tensor
# Feel free to re-evaluate the cell as you add transforms to see their effect
transform = A.Compose([
A.HorizontalFlip(p=0.5),
A.RandomRotate90(p=0.5),
A.MaskDropout((1, 1), image_fill_value=255, p=1),
A.Blur(),
A.pytorch.ToTensorV2()
])
##########################################################################################
# And now sample the exact same image/mask
dataset.transform = transform
aug_img, aug_mask = dataset[idx]
plt.subplot(2, 2, 1)
plt.imshow(orig_img.squeeze(), cmap="gray", clim=(0, 255))
plt.title("Zooscan original")
plt.axis("off")
plt.subplot(2, 2, 2)
plt.imshow(orig_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Original Mask")
plt.axis("off")
plt.subplot(2, 2, 3)
plt.imshow(aug_img.squeeze(), cmap="gray", clim=(0, 255))
plt.title("Zooscan augmented image")
plt.axis("off")
plt.subplot(2, 2, 4)
plt.imshow(aug_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Augmented mask")
plt.axis("off")
(np.float64(-0.5), np.float64(511.5), np.float64(511.5), np.float64(-0.5))
Adding the transforms in the data pipeline¶
The suitable augmentation transforms you identified in the previous sections need to be plugged into the data pipeline.
The most important function in the data.py
script is the get_dataloaders
function, which:
- splits the dataset into training and validation folds,
- gives each fold its own transforms (with augmentations for training, without augmentation for validation),
- creates the dataloaders providing the mini-batches.
In the get_dataloaders
function of the data.py
script, fill in the augmentation_transformations
empty list with the suitable transforms you previously identified, as illustrated in the sketch below.
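As an illustration, the filled-in list could look like the following sketch, mirroring the transforms explored in the previous section (the exact choice of transforms and probabilities is up to you):
import albumentations as A
# A possible content for the augmentation_transformations list in data.py
augmentation_transformations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Blur(p=0.2),
]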
Once this is done, you can execute the following cell to obtain your dataloaders.
# The configuration below requires 22GB of VRAM
data_config = {"trainpath": "./data/train",
"valid_ratio": 0.2,
"batch_size": 32,
"num_workers": 7,
"patch_size": (256, 256),
"patch_stride": (128, 128),
"task": "living_nonliving",
"normalize": True}
train_loader, valid_loader, input_size, num_classes, normalizing_stats = data.get_dataloaders(data_config, use_cuda)
# The normalizing statistics must be saved because they will be needed for inference
with open(logdir / "normalizing_stats.yaml", "w") as file:
yaml.dump(normalizing_stats, file)
0%| | 0/3 [00:00<?, ?it/s]
100%|██████████| 3/3 [00:12<00:00, 4.06s/it] 100%|██████████| 1474/1474 [00:49<00:00, 29.98it/s]
print(f"The train dataloaders contains {len(train_loader)} mini batches")
print(f"The valid dataloader contains {len(valid_loader)} mini batches")
print(f"Our problem contains {num_classes} classes and our inputs have the shape {input_size}")
print(f"For normalizing the input, we will be using the following statistics {normalizing_stats}")
The train dataloaders contains 1474 mini batches The valid dataloader contains 369 mini batches Our problem contains 2 classes and our inputs have the shape (1, 256, 256) For normalizing the input, we will be using the following statistics {'mean': 0.8201782290938752, 'std': 0.09110527465244618}
Encoder/Decoder with a pretrained backbone¶
Since we now have the data to train on, the next step is to implement the model. We are going to implement a U-Net, as seen during the lecture, with a pre-trained backbone.
In our codebase, it is the responsibility of the models module to build the models. This module provides you with a model builder, the models.build_model function. The UNet model is defined in the models/unet.py script as the UNet class. If you look at the code of that class, you will notice it is built from two components: the encoder, which is a TimmEncoder, and the decoder, of type Decoder.
To better appreciate what it produces as an output, your exercise is to:
- create an encoder with cin=1 (grayscale input images) and model_name=resnet18,
- perform a forward pass through the network and check the dimensionality of the outputs.
As we are only interested in checking the dimensionalities of the output of the encoder, you can use a dummy torch.zeros
tensor to perform the forward propagation.
from planktoseg.models.unet import TimmEncoder
encoder = TimmEncoder(cin=1, model_name="resnet18")
fake_input = torch.zeros(1, 1, 512, 512)
f4, [f1, f2, f3] = encoder(fake_input)
print(f"With an input of shape {fake_input.shape}, \n our model outputs tensors of shape :")
print(f" - f1 : {f1.shape}")
print(f" - f2 : {f2.shape}")
print(f" - f3 : {f3.shape}")
print(f" - f4 : {f4.shape}")
With an input of shape torch.Size([1, 1, 512, 512]), our model outputs tensors of shape : - f1 : torch.Size([1, 64, 128, 128]) - f2 : torch.Size([1, 128, 64, 64]) - f3 : torch.Size([1, 256, 32, 32]) - f4 : torch.Size([1, 512, 16, 16])
Implementing the decoder¶
We now move on to the decoder part. The decoder receives the outputs of the encoder and progressively upscales the representation to finally obtain an output whose spatial dimensions match those of the input. With an input of shape $(B, C, H, W)$, the output of the final layer is expected to be of shape $(B, K, H, W)$, where $K$ is the number of classes you want to predict for every pixel.
The specificity of the UNet is that, along the upscaling path, you integrate features from intermediate layers of the encoder through the so-called shortcut connections, i.e. the connections that provide the $f_3, f_2, f_1$ features depicted in the figure above.
Your exercise is to finish the code of the decoder in the models/unet.py script. Like the encoder, the decoder is built from the repetition of blocks, called DecoderBlock in the code. The Decoder class itself is already coded, but you must finish the code of the DecoderBlock: you have to complete both the constructor and the forward pass.
A DecoderBlock receives, along the upscaling path, an input tensor with cin channels and is built from a sequence of layers:
- conv1, a Sequential block with a 2D convolution, batch normalization and ReLU. This sub-block should output cin channels,
- up_conv, a Sequential block with an upscaling layer followed by a 2D convolution, batch normalization and ReLU. This sub-block should output cin//2 channels,
- conv2, a Sequential block with a 2D convolution, batch normalization and ReLU. This sub-block takes cin channels as input and outputs cin//2 channels.
For the forward pass, your decoder block receives two inputs: x and f_encoder. The tensor x holds the features flowing through the upscaling path, while f_encoder holds the features received from the encoder through the shortcut connections. During the forward pass you need to:
- pass the x features through conv1 and up_conv,
- concatenate the output of the previous step with the encoder features f_encoder,
- pass the output of the previous step through the final sub-block conv2.
Note how the dimensions change during these steps. At the output of up_conv, your tensor has cin//2 channels. The f_encoder tensor contains cin//2 channels as well. When both are concatenated, your tensor has cin channels, and the final conv2 transforms these into cin//2 channels.
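If you want to check your reasoning against something concrete, here is one possible way to write such a block. It assumes 3x3 convolutions and bilinear upsampling, which are choices of this sketch rather than requirements of the provided skeleton:
import torch
import torch.nn as nn

class DecoderBlockSketch(nn.Module):
    """A possible implementation of the decoder block described above."""

    def __init__(self, cin):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin),
            nn.ReLU(inplace=True),
        )
        self.up_conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(cin, cin // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin // 2),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(cin, cin // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin // 2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, f_encoder):
        x = self.up_conv(self.conv1(x))       # cin -> cin//2 channels, spatial size x2
        x = torch.cat([x, f_encoder], dim=1)  # back to cin channels
        return self.conv2(x)                  # cin -> cin//2 channels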
To test your implementation, you can run the following cell which should complete without errors.
from planktoseg.models.unet import DecoderBlock, Decoder
# First test : we check the forward pass through a decoder block is working
dummy_input = torch.zeros(3, 4, 128, 128)
dummy_encoder_features = torch.zeros(3, 2, 256, 256)
block = DecoderBlock(4)
output = block(dummy_input, dummy_encoder_features)
print(f"With an input of shape {dummy_input.shape} with encoder features of shape {dummy_encoder_features.shape}, the output of the decoder block is of shape {output.shape}")
assert list(output.shape) == [3, 2, 256, 256]
# Second test : we check the forward pass through a complete decoder is working
batch_size = 3
K = 10
f1 = torch.zeros(batch_size, 64, 128, 128)
f2 = torch.zeros(batch_size, 128, 64, 64)
f3 = torch.zeros(batch_size, 256, 32, 32)
f4 = torch.zeros(batch_size, 512, 16, 16)
decoder = Decoder(num_classes = K)
output = decoder(f4, [f1, f2, f3])
print(f"The output of the decoder is of shape {output.shape}")
assert list(output.shape) == [batch_size, K, 512, 512]
With an input of shape torch.Size([3, 4, 128, 128]) with encoder features of shape torch.Size([3, 2, 256, 256]), the output of the decoder block is of shape torch.Size([3, 2, 256, 256]) The output of the decoder is of shape torch.Size([3, 10, 512, 512])
Building the complete model¶
Once the encoder and decoder are implemented, we can create the complete model and send it to the device used for the experiments.
%%capture
model = models.UNet({"encoder": {"model_name": "resnet18"}}, input_size, 1 if num_classes == 2 else num_classes)
model = model.to(device)
Loss function and metrics for an unbalanced classification¶
Our problem is a classification problem, albeit a pixel-wise one: you need to classify every single pixel of an image. A natural first-guess loss function in this case is the cross entropy loss, which reads:
$$ CE(\{x_i, y_i\}) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -\log(p(y_{i,h,w} | x_i)) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -\log(f_w(x_i)_{h,w,y_{i,h,w}}) $$
where we denote by $f_w(x_i)$ the probability distributions your model assigns to the $H \times W$ pixels of the image, so that $f_w(x_i)_{h, w, y_{i,h,w}}$ is the probability assigned by your model, at pixel $(h, w)$, to the correct class $y_{i,h,w}$. This loss gives an over-influence to the majority class in an imbalanced dataset. Other losses may be preferred, such as the focal loss, dice loss, weighted cross entropy loss, ...
In this lab, you are provided with an implementation of the focal loss. The focal loss strongly decreases the influence of the pixels that are correctly predicted and, usually, these correspond to the pixels belonging to the over-represented classes. It reads:
$$ FL(\{x_i, y_i\}) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -(1-p(y_{i,h,w} | x_i))^\gamma \log(p(y_{i,h,w} | x_i)) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -(1 - f_w(x_i)_{h,w,y_{i,h,w}})^\gamma \log(f_w(x_i)_{h,w,y_{i,h,w}}) $$
with, for example, $\gamma=2$. To illustrate the difference between the cross entropy loss and the focal loss, we display below the two losses as a function of $p(y_{i,h,w} | x_i)$. As you can see, as the probability assigned by your model tends to $1$, the contribution of the pixel is lowered with respect to the BCE loss.
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
gamma = 2
Nsteps = 50
p = torch.linspace(0., 1., steps=Nsteps)
labels = torch.ones_like(p)
bce_loss_values = -torch.log(p)
focal_loss_values = -(1-p)**gamma * torch.log(p)
plt.figure()
plt.plot(p, bce_loss_values, label='BCE loss')
plt.plot(p, focal_loss_values, label='Focal loss')
plt.xlabel("Probability assigned to the class to be predicted")
plt.ylabel("Loss value")
plt.legend()
<matplotlib.legend.Legend at 0x7f1ce8d17430>
Evaluate the following cell to define the loss as the focal loss.
loss = optim.FocalLoss()
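The provided optim.FocalLoss is used as a black box here. For intuition, a binary focal loss operating on logits could be sketched as follows (this is an illustration; the provided implementation may differ in its details):
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryFocalLossSketch(nn.Module):
    """Binary focal loss computed from raw logits."""

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Per-pixel binary cross entropy, one value per pixel
        bce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
        # Probability assigned to the correct class
        p_t = torch.exp(-bce)
        # Down-weight the well-classified pixels
        return ((1.0 - p_t) ** self.gamma * bce).mean()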
Last elements: optimizer, early stopping, metrics, loggers, ...¶
The final elements we need are the optimizer, the early stopping callback, some metrics computations and possibly loggers.
To define all these, you just need to evaluate the following cell.
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Metrics evaluated over the training and validation folds
train_fmetrics = {
"focal": main.deepcs.metrics.GenericBatchMetric(loss),
"accuracy": main.BatchAccuracy(),
"confusion_matrix": metrics.BinaryConfusionMatrixMetric(),
}
train_fmetrics["F1"] = metrics.BinaryF1Metric()
test_fmetrics = {
"focal": main.deepcs.metrics.GenericBatchMetric(loss),
"accuracy": main.BatchAccuracy(),
"confusion_matrix": metrics.BinaryConfusionMatrixMetric(),
}
test_fmetrics["F1"] = metrics.BinaryF1Metric()
# Define the early stopping callback
model_checkpoint = utils.ModelCheckpoint(
model, logdir, (1,) + input_size, device, min_is_best=False
)
Our first experiments¶
All right, we now have all the building blocks for running our first training. Depending on your settings and GPU, it may be more or less fast :)
On an RTX 3090, with a batch size of $16$, processing $3000$ minibatches of $512 \times 512$ images takes almost 20 minutes per epoch and 22GB of VRAM.
Also, I observed while preparing the labwork that the batch size should not be too small. Using a batch size of $7$ led to convergence to a bad-quality local minimum, while a batch size of $16$ allowed a better minimization of the focal loss. Unfortunately, this implies a lower bound on the GPU you can use.
num_epochs = 2
metrics_store = {"train": [], "valid": []}
def postprocess_metrics(fmetrics, metrics_dict):
"""
This function is used only for extracting
the precision and recall we track for displaying and plotting
"""
cm = fmetrics["confusion_matrix"]
metrics_dict["precision"] = cm.get_precision()
metrics_dict["recall"] = cm.get_recall()
# Evaluate the metrics before training
train_metrics = utils.test(model, train_loader, device, test_fmetrics)
postprocess_metrics(train_fmetrics, train_metrics)
metrics_store["train"].append(train_metrics)
valid_metrics = utils.test(model, valid_loader, device, test_fmetrics)
postprocess_metrics(test_fmetrics, valid_metrics)
metrics_store["valid"].append(valid_metrics)
for e in range(num_epochs):
# Train 1 epoch
train_metrics = utils.train(
model, train_loader, loss, optimizer, device, train_fmetrics
)
postprocess_metrics(train_fmetrics, train_metrics)
metrics_store["train"].append(train_metrics)
# Test
valid_metrics = utils.test(model, valid_loader, device, test_fmetrics)
postprocess_metrics(test_fmetrics, valid_metrics)
metrics_store["valid"].append(valid_metrics)
# Save the model if it is better
checkpoint_metric_name = "F1"
checkpoint_metric = valid_metrics[checkpoint_metric_name]
updated = model_checkpoint.update(checkpoint_metric)
# Display the metrics
metrics_msg = f" Epoch {e} / {num_epochs}\n"
metrics_msg += "- Train : \n "
metrics_msg += "\n ".join(
f" {m_name}: {m_value}" for (m_name, m_value) in train_metrics.items()
)
metrics_msg += "\n"
metrics_msg += "- Valid : \n "
metrics_msg += "\n ".join(
f" {m_name}: {m_value}"
+ ("[>> BETTER <<]" if updated and m_name == checkpoint_metric_name else "")
for (m_name, m_value) in valid_metrics.items()
)
print(metrics_msg)
0%| | 0/1474 [00:00<?, ?it/s]
100%|██████████| 1474/1474 [01:21<00:00, 18.17it/s] 100%|██████████| 369/369 [00:20<00:00, 18.01it/s] 100%|██████████| 1474/1474 [04:38<00:00, 5.29it/s] 100%|██████████| 369/369 [00:20<00:00, 17.83it/s]
Epoch 0 / 2 - Train : focal: 0.05041743235933004 accuracy: 0.9605594004834199 confusion_matrix: [[0.9954146118281357, 0.004585388171864319], [0.5870692592258612, 0.41293074077413877]] F1: 0.5416884766447951 precision: 0.7871255188568016 recall: 0.41293074077413877 - Valid : focal: 0.03458496961828569 accuracy: 0.8642685112506758 confusion_matrix: [[0.9850354573537219, 0.014964542646278133], [0.526730909307094, 0.473269090692906]] F1: 0.6034457223144885[>> BETTER <<] precision: 0.8324059615164311 recall: 0.473269090692906
100%|██████████| 1474/1474 [04:33<00:00, 5.39it/s] 100%|██████████| 369/369 [00:20<00:00, 17.89it/s]
Epoch 1 / 2 - Train : focal: 0.01360388492515849 accuracy: 0.960526061672692 confusion_matrix: [[0.9962440705799295, 0.00375592942007056], [0.5190544495534446, 0.4809455504465554]] F1: 0.6117580025823954 precision: 0.8403155563173192 recall: 0.4809455504465554 - Valid : focal: 0.017102164999813124 accuracy: 0.8642685112506758 confusion_matrix: [[0.9906546216510921, 0.009345378348907864], [0.48852997030393647, 0.5114700296960635]] F1: 0.6511491356972307[>> BETTER <<] precision: 0.8957810601232823 recall: 0.5114700296960635
During training, we recorded metrics into the metrics_store dictionary, and we can now display them. This is a very basic way to plot these metrics; at the very end of the lab, we propose wandb.ai, which is way more convenient.
import matplotlib.pyplot as plt
plt.figure(figsize=(15,5), dpi=150)
plt.subplot(1, 4, 1)
plt.plot([mi["focal"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["focal"] for mi in metrics_store["valid"]], 'k--')
plt.title("Focal loss")
plt.xlabel("Epoch")
plt.subplot(1, 4, 2)
plt.plot([mi["F1"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["F1"] for mi in metrics_store["valid"]], 'k--')
plt.title("F1 score")
plt.xlabel("Epoch")
plt.subplot(1, 4, 3)
plt.plot([mi["precision"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["precision"] for mi in metrics_store["valid"]], 'k--')
plt.title("Precision")
plt.xlabel("Epoch")
plt.subplot(1, 4, 4)
plt.plot([mi["recall"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["recall"] for mi in metrics_store["valid"]], 'k--')
plt.title("Recall")
plt.xlabel("Epoch")
Text(0.5, 0, 'Epoch')
Inference on new zooscan images¶
Now that we have trained our first UNet, we would like to apply it to new data, i.e. perform so-called "inference". In the training loop above, the model that we consider the best is the one maximizing the F1 score on the validation data. This best model has been saved during the optimization as an ONNX file. There are actually several ways to save a model, for example as a torch tensor of parameters or as an ONNX graph, and the ONNX export is certainly the more portable of the two.
ONNX graphs can be executed by runtimes available in many different languages.
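For reference, exporting a trained model to ONNX roughly amounts to the sketch below. The provided utils.ModelCheckpoint already does this for you, and the exact output path, input/output names and dynamic axes it uses may differ (the inference cell below feeds its input under the name "scan"):
import torch
# Illustrative export of the trained model to an ONNX graph,
# written to a separate file so as not to overwrite the checkpointed best model
dummy_input = torch.zeros((1, 1, 256, 256), device=device)
torch.onnx.export(
    model,
    dummy_input,
    "logs/exported_model.onnx",
    input_names=["scan"],
    output_names=["logits"],
    # Dynamic spatial axes so that inference can run on patches of another size
    dynamic_axes={"scan": {0: "batch", 2: "height", 3: "width"},
                  "logits": {0: "batch", 2: "height", 3: "width"}},
)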
The cell below is a standalone cell. It can be executed without the other cells being executed. This is really the complete inference code.
Also, the patch size for inference is arbitrary. Indeed, as a UNet is a fully convolutional model, it can be trained on patches of shape, say, $512 \times 512$, and then seamlessly be evaluated on patches of arbitrary sizes, for example $4096 \times 4096$. To test inference, my advice would be to start with small patch sizes and then increase them; you may run out of memory if you use too large a patch size. To perform inference on a large image, a naive approach is to split it into smaller non-overlapping patches, perform inference on each and then stitch the predictions back together (see the sketch after the inference cell below).
To run the following cell, you may need to restart your kernel so that the memory gets freed on the GPU.
import pathlib
import matplotlib.pyplot as plt
import onnxruntime as ort
from PIL import Image
Image.MAX_IMAGE_PIXELS = 25000 * 15000
import numpy as np
import yaml
providers = []
use_cuda = True
patch_size = 4096
if use_cuda:
providers.append("CUDAExecutionProvider")
providers.append("CPUExecutionProvider")
# You may adapt the following to either use
# the model you trained or the pretrained model you are provided
# logdir = pathlib.Path("pretrained_model")
logdir = pathlib.Path("logs")
inference_session = ort.InferenceSession(
str(logdir / "best_model.onnx"), providers=providers
)
# Load our normalizing statistics
stats = yaml.safe_load(open(str(logdir / "normalizing_stats.yaml"), "r"))
mean = stats["mean"]
std = stats["std"]
# Load our image
scan_path = "./data/test/rg20210421_scan.png"
scan_img = np.array(Image.open(scan_path))
mask_path = "./data/test/rg20210421_mask.png"
mask_img = np.array(Image.open(mask_path))
print(mask_img.shape)
crop_offset = (5048, 2048)
# Normalize our input
scan_img = ((scan_img - mean * 255.)/(std * 255.)).astype(np.float32)
scan_img = scan_img[np.newaxis, np.newaxis, ...]
scan_img = scan_img[:, :, crop_offset[0]:(crop_offset[0] + patch_size), crop_offset[1]:(crop_offset[1] + patch_size)]
# Get the ground truth mask
# print(np.unique(mask_img))
mask_img = mask_img[crop_offset[0]:(crop_offset[0] + patch_size), crop_offset[1]:(crop_offset[1] + patch_size)] >= 8
# Perform an inference
logits = inference_session.run(None, {"scan": scan_img})[0]
probs = 1.0 / (1.0 + np.exp(-logits))
pred_mask = probs >= 0.5
# Plot the results
plt.figure(dpi=300)
plt.subplot(1, 4, 1)
plt.imshow(scan_img.squeeze(), cmap="gray")
plt.title("Zooscan image")
plt.axis("off")
plt.subplot(1, 4, 2)
plt.imshow(probs.squeeze(), interpolation="none", clim=(0.0, 1.0))
plt.title("Probabilities")
plt.axis("off")
plt.subplot(1, 4, 3)
plt.imshow(pred_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Predicted Mask")
plt.axis("off")
plt.subplot(1, 4, 4)
plt.imshow(mask_img, interpolation="none", cmap="tab20c")
plt.title("Ground truth")
plt.axis("off")
plt.tight_layout()
(14573, 22817)
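Following the naive tiling approach mentioned above, the whole scan can be processed with non-overlapping patches and the predictions stitched back together. The sketch below reuses the inference_session, scan_path, mean and std defined in the cell above, and simply leaves the leftover borders unpredicted:
# Naive tiled inference over the full scan with non-overlapping patches
full_scan = np.array(Image.open(scan_path))
full_scan = ((full_scan - mean * 255.0) / (std * 255.0)).astype(np.float32)

tile = 4096
H, W = full_scan.shape
full_probs = np.zeros((H, W), dtype=np.float32)
for i in range(0, H - tile + 1, tile):
    for j in range(0, W - tile + 1, tile):
        patch = full_scan[np.newaxis, np.newaxis, i:i + tile, j:j + tile]
        logits = inference_session.run(None, {"scan": patch})[0]
        full_probs[i:i + tile, j:j + tile] = 1.0 / (1.0 + np.exp(-logits.squeeze()))
full_pred = full_probs >= 0.5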
Going further¶
Using DeepLab v3+ instead of the UNet with pretrained backbone¶
All right, you implemented and trained a custom U-Net with a pretrained backbone. What about experimenting with a more state-of-the-art model such as DeepLabV3+? For this, you can either use the segmentation models provided by torchvision or rely on an external library such as segmentation_models_pytorch.
%pip install segmentation_models_pytorch
Collecting segmentation_models_pytorch Downloading segmentation_models_pytorch-0.3.4-py3-none-any.whl (109 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 109.5/109.5 KB 1.4 MB/s eta 0:00:00a 0:00:01 Collecting efficientnet-pytorch==0.7.1 Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB) Preparing metadata (setup.py) ... done Requirement already satisfied: huggingface-hub>=0.24.6 in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (0.25.1) Collecting pretrainedmodels==0.7.4 Downloading pretrainedmodels-0.7.4.tar.gz (58 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.8/58.8 KB 1.6 MB/s eta 0:00:00 Preparing metadata (setup.py) ... done Requirement already satisfied: six in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (1.16.0) Collecting timm==0.9.7 Downloading timm-0.9.7-py3-none-any.whl (2.2 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 5.1 MB/s eta 0:00:00:00:0100:01 Requirement already satisfied: torchvision>=0.5.0 in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (0.19.1) Requirement already satisfied: pillow in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (10.4.0) Requirement already satisfied: tqdm in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (4.66.5) Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (from efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.4.1) Collecting munch Downloading munch-4.0.0-py2.py3-none-any.whl (9.9 kB) Requirement already satisfied: safetensors in ./venv/lib/python3.10/site-packages (from timm==0.9.7->segmentation_models_pytorch) (0.4.5) Requirement already satisfied: pyyaml in ./venv/lib/python3.10/site-packages (from timm==0.9.7->segmentation_models_pytorch) (6.0.2) Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.16.1) Requirement already satisfied: typing-extensions>=3.7.4.3 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (4.12.2) Requirement already satisfied: packaging>=20.9 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (24.1) Requirement already satisfied: fsspec>=2023.5.0 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (2024.9.0) Requirement already satisfied: requests in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (2.32.3) Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (from torchvision>=0.5.0->segmentation_models_pytorch) (2.1.1) Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.0.106) Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105) Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (10.3.2.106) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105) Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from 
torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.1.4) Requirement already satisfied: triton==3.0.0 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.0.0) Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (1.13.3) Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.3.1) Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.20.5) Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (11.4.5.107) Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.3) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105) Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (9.1.0.70) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105) Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (11.0.2.54) Requirement already satisfied: nvidia-nvjitlink-cu12 in ./venv/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.6.77) Requirement already satisfied: urllib3<3,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (2.2.3) Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (2024.8.30) Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.10) Requirement already satisfied: MarkupSafe>=2.0 in ./venv/lib/python3.10/site-packages (from jinja2->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.1.5) Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./venv/lib/python3.10/site-packages (from sympy->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (1.3.0) Using legacy 'setup.py install' for efficientnet-pytorch, since package 'wheel' is not installed. Using legacy 'setup.py install' for pretrainedmodels, since package 'wheel' is not installed. Installing collected packages: munch, efficientnet-pytorch, timm, pretrainedmodels, segmentation_models_pytorch Running setup.py install for efficientnet-pytorch ... 
done Attempting uninstall: timm Found existing installation: timm 1.0.9 Uninstalling timm-1.0.9: Successfully uninstalled timm-1.0.9 Running setup.py install for pretrainedmodels ... done Successfully installed efficientnet-pytorch-0.7.1 munch-4.0.0 pretrainedmodels-0.7.4 segmentation_models_pytorch-0.3.4 timm-0.9.7 Note: you may need to restart the kernel to use updated packages.
For example, below is a snippet constructing DeepLabV3+ as proposed in Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. You can now go back to your code and experiment with one of the models proposed by segmentation_models_pytorch.
import segmentation_models_pytorch as smp
model = smp.DeepLabV3Plus(
encoder_name="resnet18",
encoder_weights="imagenet",
in_channels=1,
classes=1,
)
dummy_input = torch.zeros(3, 1, 512, 512)
output = model(dummy_input)
print(f"The output has a shape {output.shape}")
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /usr/users/dce-admin/fix/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth 100%|██████████| 44.7M/44.7M [00:04<00:00, 11.7MB/s]
The output has a shape torch.Size([3, 1, 512, 512])
Interfacing with an online dashboard: wandb.ai¶
As you may expect, deep learning is about tuning your whole pipeline, experimenting with different models and so on, and you need to efficiently track all the experiments you will be running.
Part of the answer lies in using dashboards; online dashboards in particular are super useful. One such dashboard is wandb.ai. To monitor your experiments with wandb, you just need to add a few lines of code.
First, you need to get an account on wandb. This is free and academics have free plans. Once you have an account, you can create a new project, for which you will obtain both a project name and an entity name.
Then you can connect to wandb by calling
import wandb
wandb.init(project=..., entity=....)
and finally, to record your metrics during training, you simply need to call the following function after every epoch:
wandb.log({my_metric_name: my_metric_value})
In our case, we have two dictionaries, train_metrics and valid_metrics, which contain our metrics; we just need to call wandb.log on them to get the metrics logged:
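For example, something along the following lines (the key prefixes and the filtering of non-scalar values, such as the confusion matrix, are choices of this sketch):
# Log the scalar train/validation metrics of the current epoch to wandb,
# prefixing the keys so that both folds appear as separate curves
epoch_metrics = {}
for prefix, fold_metrics in [("train", train_metrics), ("valid", valid_metrics)]:
    for name, value in fold_metrics.items():
        if isinstance(value, (int, float)):  # skip e.g. the confusion matrix
            epoch_metrics[f"{prefix}_{name}"] = value
wandb.log(epoch_metrics)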
Running the code without jupyter¶
Between you and me, a jupyter notebook is great for exploring some concepts, but you definitely want to move away from notebooks when running large-scale experiments, because jupyter notebooks are not the most efficient way to perform multiple experiments.
Actually, I adapted a python library I wrote for plankton segmentation so that you are able to evaluate everything from within a notebook.
The library is available on github at https://github.com/jeremyfix/2024_ml4oceans. Feel free to make use of it. You are also provided with a python script submit_slurm.py
which allows you to easily run experiments on clusters managed by SLURM (e.g. Jean Zay). Note that you will certainly have to adapt the SLURM sbatch directives to make it work (e.g. partition name, ...).