ZooScan image segmentation using a UNet with a pretrained backbone¶

In [ ]:
%load_ext autoreload
%autoreload 2

In this lab, we would like to train a deep neural network to perform semantic segmentation of ZooScan images. This labwork will guide you through:

  • how to load the labeled ZooScan images
  • how to implement a UNet for semantic segmentation
  • how to train the network

Introduction¶

We are provided with labeled ZooScan data, which means we have both the ZooScan image and a mask. The mask is an image of labels. The data are located within the data subdirectory; every scan is around $20000\times 15000$ pixels.

In [ ]:
!ls data/train
rg20210310_mask.png  rg20210407_mask.png  rg20211103_mask.png  taxa.csv
rg20210310_scan.png  rg20210407_scan.png  rg20211103_scan.png
In [ ]:
!for f in data/train/*_scan.png; do file $f; done
data/train/rg20210310_scan.png: PNG image data, 22737 x 14503, 8-bit grayscale, non-interlaced
data/train/rg20210407_scan.png: PNG image data, 22707 x 14373, 8-bit grayscale, non-interlaced
data/train/rg20211103_scan.png: PNG image data, 22747 x 14573, 8-bit grayscale, non-interlaced

An example ZooScan image and its mask are displayed below

Zooscan image and mask example

The color code of the mask depends on the class to which each pixel has been assigned by the human labeler. Note that pixel-labeling these huge images is certainly tedious work, but the quality of your data and labels is of the utmost importance for our deep learning algorithms to work. The semantics of the class indices are provided in the taxa.csv file.

In [ ]:
!cat data/train/taxa.csv
label,living,label_nb
artefact,False,1
badfocus<artefact,False,2
bubble,False,3
detritus,False,4
fiber<detritus,False,5
t001,False,6
t003,False,7
Acartiidae,True,8
Actinopterygii,True,9
Aglaura,True,10
Annelida,True,11
Aulacantha,True,12
Calanidae,True,13
Calanoida,True,14
Candaciidae,True,15
Cavolinia inflexa,True,16
Centropagidae,True,17
Chaetognatha,True,18
Chelophyes appendiculata,True,19
Collodaria,True,20
Corycaeidae,True,21
Creseidae,True,22
Creseis acicula,True,23
Doliolida,True,24
Euchaetidae,True,25
Eumalacostraca,True,26
Evadne,True,27
Flaccisagitta enflata,True,28
Fritillariidae,True,29
Gammaridea,True,30
Globigerinidae,True,31
Gymnosomata,True,32
Harpacticoida,True,33
Heterorhabdidae,True,34
Hydrozoa,True,35
Hyperiidea,True,36
Insecta,True,37
Limacinidae,True,38
Metridinidae,True,39
Neoceratium,True,40
Oikopleuridae,True,41
Oithonidae,True,42
Oncaeidae,True,43
Orbulina,True,44
Ostracoda,True,45
Penilia avirostris,True,46
Podon,True,47
Rhizaria,True,48
Rhopalonema velatum,True,49
Salpida,True,50
Sapphirinidae,True,51
Temoridae,True,52
bract<Abylopsis tetragona,True,53
bract<Diphyidae,True,54
calyptopsis<Euphausiacea,True,55
colony<Phaeodaria,True,56
damaged<Aulacantha,True,57
egg<Actinopterygii,True,58
egg<Mollusca,True,59
egg<other,True,60
endostyle<Salpidae,True,61
eudoxie<Diphyidae,True,62
gonophore<Abylopsis tetragona,True,63
gonophore<Diphyidae,True,64
head<Chaetognatha,True,65
juvenile<Salpida,True,66
larvae<Porcellanidae,True,67
like<Collodaria,True,68
multiple<other,True,69
nectophore<Diphyidae,True,70
nectophore<Physonectae,True,71
nucleus<Salpidae,True,72
othertocheck,True,73
part<Cnidaria,True,74
part<Crustacea,True,75
part<Mollusca,True,76
part<Siphonophorae,True,77
part<Thaliacea,True,78
pluteus<Echinoidea,True,79
pluteus<Ophiuroidea,True,80
protozoea<Mysida,True,81
seaweed,True,82
siphonula,True,83
tail<Appendicularia,True,84
tail<Chaetognatha,True,85
trunk<Appendicularia,True,86
zoea<Brachyura,True,87
zoea<Galatheidae,True,88

We have $88$ classes of objects plus the background (label $0$). All the living organisms have been assigned a label strictly higher than $7$. The "living" column is therefore equivalent to testing whether label_nb is strictly larger than $7$.
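
As a quick sanity check, this equivalence can be verified directly from the csv file with nothing more than the standard library (a minimal sketch; the file path is the one listed above):

import csv

with open("data/train/taxa.csv") as f:
    rows = list(csv.DictReader(f))

# The "living" column should be equivalent to label_nb being strictly larger than 7
for row in rows:
    assert (row["living"] == "True") == (int(row["label_nb"]) > 7)
print(f"Checked {len(rows)} classes")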

The classes are unbalanced. For example, counting the number of occurrences of all the $89$ classes (background + $88$ non-living and living classes), we obtain the counts below, in decreasing order of occurrence.

The most represented class is the background with almost $900$ million pixels. The most represented non-background class is Salpida (class label $50$) with $60$ million pixels and the least represented class is Neoceratium with only $2$K pixels. This strong imbalance can cause a lot of trouble when training a neural network, as the over-represented classes are more easily learned than the under-represented ones.

Class name Count
background 903586894
Salpida 60370854
detritus 2242808
multiple<other 2142290
Calanoida 1914332
endostyle<Salpidae 1541269
badfocus<artefact 1336249
bubble 1302699
nucleus<Salpidae 1003932
Calanidae 835997
Chaetognatha 807700
juvenile<Salpida 701173
Euchaetidae 538480
part<Crustacea 485422
Centropagidae 468518
Candaciidae 462508
Corycaeidae 446955
fiber<detritus 366257
Eumalacostraca 333003
Rhopalonema velatum 324429
Metridinidae 313694
Gammaridea 310755
nectophore<Diphyidae 287985
Creseis acicula 285353
t001 260026
Flaccisagitta enflata 259964
othertocheck 259208
Chelophyes appendiculata 252368
Oikopleuridae 250231
Heterorhabdidae 227147
gonophore<Diphyidae 211023
part<Cnidaria 204616
nectophore<Physonectae 200246
Ostracoda 195320
Temoridae 180239
Oithonidae 158871
tail<Appendicularia 156565
Doliolida 156015
protozoea<Mysida 147630
bract<Diphyidae 114702
Acartiidae 112833
damaged<Aulacantha 109413
zoea<Galatheidae 107454
Cavolinia inflexa 101898
Hyperiidea 99523
calyptopsis<Euphausiacea 92612
Hydrozoa 88566
Aulacantha 72204
seaweed 70972
Sapphirinidae 69538
tail<Chaetognatha 69175
Gymnosomata 65518
egg<other 62427
Fritillariidae 62113
part<Mollusca 61776
Aglaura 53920
Annelida 53318
Creseidae 52541
like<Collodaria 52112
trunk<Appendicularia 50062
part<Siphonophorae 42509
eudoxie<Diphyidae 34491
Oncaeidae 29734
Limacinidae 28553
gonophore<Abylopsis tetragona 28054
egg<Actinopterygii 27244
zoea<Brachyura 25307
head<Chaetognatha 24456
bract<Abylopsis tetragona 23398
part<Thaliacea 23155
Insecta 22354
Actinopterygii 17853
colony<Phaeodaria 14070
artefact 13792
Penilia avirostris 13361
Harpacticoida 11465
t003 10630
pluteus<Ophiuroidea 10253
Collodaria 8803
larvae<Porcellanidae 8463
Podon 7855
egg<Mollusca 6938
pluteus<Echinoidea 5779
Globigerinidae 5704
Orbulina 5534
Evadne 5205
siphonula 3574
Rhizaria 2171
Neoceratium 2046
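
The counts in the table above can be obtained by accumulating a histogram of the mask labels over the three training masks; a minimal sketch, assuming the masks are single-channel label images with values in $0 \dots 88$:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# The scans are huge, so raise PIL's decompression-bomb limit
Image.MAX_IMAGE_PIXELS = 25000 * 15000

counts = np.zeros(89, dtype=np.int64)  # background + 88 classes
for mask_path in ["data/train/rg20210310_mask.png",
                  "data/train/rg20210407_mask.png",
                  "data/train/rg20211103_mask.png"]:
    mask = np.array(Image.open(mask_path))
    counts += np.bincount(mask.ravel(), minlength=89)

# Number of pixels per class, with a log-scaled y-axis as in the plot below
plt.bar(np.arange(89), counts)
plt.yscale("log")
plt.xlabel("Class label")
plt.ylabel("Number of pixels")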

The plot below shows the number of pixels per class. Note that the y-axis is in log scale. As we count pixels rather than objects, the imbalance could be explained by some organisms being larger than others or by their over-representation. In any case, this imbalance will cause you trouble when training a neural network.

Number of pixels per class

If we group the classes into non-living vs living, we obtain the following counts, where there are still more than $10$ times more non-living pixels than living ones.

Class name Count
Non living 909119355
Living 78495098

Imports¶

All the code you are going to write will go into the following files:

  • data.py : script responsible for handling the data pipeline, producing the dataloaders for the training and validation data
  • models/ : submodule responsible for providing the different models you want to experiment with
In [ ]:
# Standard imports
import pathlib

# External imports
import albumentations as A
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml

# Local imports
from planktoseg import data
from planktoseg import models
from planktoseg import optim
from planktoseg import utils
from planktoseg import metrics
from planktoseg import main

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

logdir = pathlib.Path("./logs")
if not logdir.exists():
    logdir.mkdir(exist_ok=True)
/usr/users/dce-admin/fix/GIT/2024_ml4oceans/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Data loading and exploration¶

In this part, we will be exploring the dataset :

  • load and visualize some samples
  • investigate the balance of the classes
  • investigate the application of some data augmentation techniques.

The very first step when training neural networks on your data is to prepare the data pipeline. The expected output of this first step is the data loaders, i.e. the python iterable objects able to give you minibatches for training and validation.

Dataset¶

In the data.py script, you are provided with the PlanktonDataset object. We wrote this class to ease your work of loading the data. This class can :

  • load the data (image and mask),
  • split every sample into patches; you can configure both the patch size and the patch stride (overlap = patch_size - patch_stride, see the short sketch after this list),
  • apply transformations on the input patches and mask patches
  • switch between two semantic segmentation tasks :
    1. classify a pixel as belonging to one of the $88$ classes (task = SegmentationTask.LIVING_CLASSES)
    2. classify a pixel as belonging to non living vs living classes (task = SegmentationTask.LIVING_NONLIVING)
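
To get a feel for the patching, here is a quick back-of-the-envelope sketch of the patch grid for one scan (it ignores any padding or border handling the class may actually implement):

patch_size, patch_stride = 512, 128
H, W = 14503, 22737  # rg20210310_scan.png (height, width)

n_h = (H - patch_size) // patch_stride + 1
n_w = (W - patch_size) // patch_stride + 1
print(f"{n_h} x {n_w} = {n_h * n_w} patches, each overlapping its neighbour by {patch_size - patch_stride} pixels")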

As an example, in the cell below, we instantiate this dataset object, which is going to:

  • split the large scans into patches of size $512 \times 512$ with a stride of $128 \times 128$
  • merge the labels into two classes: Non-living (label <= 7) and Living (label >= 8)
  • transform every patch and mask from an Image type to a pytorch Tensor
In [ ]:
transform = A.pytorch.ToTensorV2()

dataset = data.PlanktonDataset(root="./data/train", patch_size=(512, 512), patch_stride=(128, 128), task=data.SegmentationTask.LIVING_NONLIVING)
dataset = data.WrappedDataset(dataset, transform)
  0%|          | 0/3 [00:00<?, ?it/s]
100%|██████████| 3/3 [00:12<00:00,  4.07s/it]

We can now access both an image and its label by indexing the dataset. Feel free to repeat the execution of the cell below as it randomly samples the dataset for a new image/mask.

In [ ]:
idx = np.random.randint(len(dataset)) 

img, mask = dataset[idx]

print(f"The img is of type {type(img)} and of shape {img.shape}")
print(f"The mask is of type {type(mask)} and of shape {mask.shape}")

plt.subplot(1, 2, 1)
plt.imshow(img.squeeze(), cmap="gray")
plt.title("Zooscan image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Mask")
plt.axis("off")
The img is of type <class 'torch.Tensor'> and of shape torch.Size([1, 512, 512])
The mask is of type <class 'torch.Tensor'> and of shape torch.Size([512, 512])
Out[ ]:
(np.float64(-0.5), np.float64(511.5), np.float64(511.5), np.float64(-0.5))
Sampled Zooscan patch and its mask

Exploring the augmentation transforms¶

Data (of good quality) is of paramount importance but is usually limited in availability. We can virtually increase its availability by applying so-called "data augmentations", which are transforms of your inputs for which you know how to transform the target.

For this, we will use the albumentations python library which offers several transforms for various tasks (classification, object detection, semantic segmentation)

Exercise

Your first exercise is to identify, among the available data transforms (see for example this documentation, and do not forget to scroll down to the spatial-level transforms), which ones can be suitable. For example:

  • A.HorizontalFlip, which randomly flips your image/mask pair horizontally,
  • A.VerticalFlip, which randomly flips your image/mask pair vertically,
  • A.MaskDropout, which randomly masks some objects

You could also Blur or add GaussianNoise. I believe these transforms make sense for our plankton segmentation task, given the snow we see in the images and the possibly out-of-focus objects. If you wonder "Where should I stop adding transforms?" or "How do I know my transforms are ok?", remember that the purpose of these transforms is to augment your data, hopefully improving your validation metric.

In [ ]:
idx = np.random.randint(len(dataset)) 
idx = 50807

# Sample the dataset without all the transforms
# to see what the image and mask originally look like
original_transform = A.pytorch.ToTensorV2()
dataset.transform = original_transform
orig_img, orig_mask = dataset[idx]

##########################################################################################
# TODO: You have to fill this part !
# Tune the augmented transform by prepending your chosen transforms before the 
# conversion to pytorch tensor
# Feel free to evaluate the cell as you add transforms to see their effect
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.MaskDropout((1, 1), image_fill_value=255, p=1), 
    A.Blur(),
    A.pytorch.ToTensorV2()
])
##########################################################################################

# And now sample the exact same image/mask 
dataset.transform = transform
aug_img, aug_mask = dataset[idx]

plt.subplot(2, 2, 1)
plt.imshow(orig_img.squeeze(), cmap="gray", clim=(0, 255))
plt.title("Zooscan original")
plt.axis("off")

plt.subplot(2, 2, 2)
plt.imshow(orig_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Original Mask")
plt.axis("off")

plt.subplot(2, 2, 3)
plt.imshow(aug_img.squeeze(), cmap="gray", clim=(0, 255))
plt.title("Zooscan augmented image")
plt.axis("off")

plt.subplot(2, 2, 4)
plt.imshow(aug_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Augmented mask")
plt.axis("off")
Out[ ]:
(np.float64(-0.5), np.float64(511.5), np.float64(511.5), np.float64(-0.5))
Original and augmented image/mask pairs

Adding the transforms in the data pipeline¶

The suitable augmentation transforms you identified in the previous section need to be plugged into the data pipeline.

The most important function in the data.py script is the get_dataloaders function, which:

  • splits the dataset into training and validation folds,
  • gives each fold its own transforms (with augmentations for training, without augmentation for validation),
  • creates the dataloaders providing the mini-batches.
Exercise

In the get_dataloaders function of the data.py script, fill in the empty augmentation_transformations list with the suitable transforms you previously identified.
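
For example, the list could look like the following (this is only an illustration; keep the transforms you validated visually in the previous exercise):

augmentation_transformations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
]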

Once this is done, you can execute the following cell to obtain your dataloaders.

In [ ]:
# The configuration below requires 22GB of VRAM
data_config = {"trainpath": "./data/train", 
               "valid_ratio": 0.2,
               "batch_size": 32,
               "num_workers": 7,
               "patch_size": (256, 256),
               "patch_stride": (128, 128),
               "task": "living_nonliving",
               "normalize": True}

train_loader, valid_loader, input_size, num_classes, normalizing_stats = data.get_dataloaders(data_config, use_cuda)

# The normalizing statistics must be saved because they will be needed for inference
with open(logdir / "normalizing_stats.yaml", "w") as file:
    yaml.dump(normalizing_stats, file)
  0%|          | 0/3 [00:00<?, ?it/s]
100%|██████████| 3/3 [00:12<00:00,  4.06s/it]
100%|██████████| 1474/1474 [00:49<00:00, 29.98it/s]
In [ ]:
print(f"The train dataloaders contains {len(train_loader)} mini batches")
print(f"The valid dataloader contains {len(valid_loader)} mini batches")
print(f"Our problem contains {num_classes} classes and our inputs have the shape {input_size}")
print(f"For normalizing the input, we will be using the following statistics {normalizing_stats}")
The train dataloaders contains 1474 mini batches
The valid dataloader contains 369 mini batches
Our problem contains 2 classes and our inputs have the shape (1, 256, 256)
For normalizing the input, we will be using the following statistics {'mean': 0.8201782290938752, 'std': 0.09110527465244618}

Encoder/Decoder with a pretrained backbone¶

Since we now have the data to train on, the next step is to implement the model. We are going to implement a U-Net, as seen during the lecture, with a pre-trained backbone.

UNet architecture

In our codebase, it is the responsibility of the models module to build the models. This module provides you with a model builder, the models.build_model function. The UNet model is defined in the models/unet.py script as the UNet class. If you look at the code of that class, you will notice it is built from two components: the encoder, which is a TimmEncoder, and the decoder, of type Decoder.
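
Schematically, the UNet class only glues these two components together; a simplified sketch of the idea (the actual class in models/unet.py may differ in its constructor arguments and details):

import torch
from planktoseg.models.unet import TimmEncoder, Decoder

class UNetSketch(torch.nn.Module):
    def __init__(self, cin, num_classes):
        super().__init__()
        self.encoder = TimmEncoder(cin=cin, model_name="resnet18")
        self.decoder = Decoder(num_classes=num_classes)

    def forward(self, x):
        # The encoder returns the deepest features and the intermediate ones
        f4, [f1, f2, f3] = self.encoder(x)
        # The decoder upscales f4, merging f3, f2, f1 along the way
        return self.decoder(f4, [f1, f2, f3])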

Experimenting with the encoder¶

The code for the encoder is already given to you. This code will load a pretrained model provided by the timm library.

Exercise

To better appreciate what the encoder produces as output, your exercise is to:

  • create an encoder with cin=1 (grayscale input images) and model_name=resnet18
  • perform a forward pass through the network and check the dimensionality of the outputs.

As we are only interested in checking the dimensionalities of the output of the encoder, you can use a dummy torch.zeros tensor to perform the forward propagation.

In [ ]:
from planktoseg.models.unet import TimmEncoder

encoder = TimmEncoder(cin=1, model_name="resnet18")
fake_input = torch.zeros(1, 1, 512, 512)
f4, [f1, f2, f3] = encoder(fake_input)

print(f"With an input of shape {fake_input.shape}, \n our model outputs tensors of shape :")
print(f" - f1 : {f1.shape}")
print(f" - f2 : {f2.shape}")
print(f" - f3 : {f3.shape}")
print(f" - f4 : {f4.shape}")
With an input of shape torch.Size([1, 1, 512, 512]), 
 our model outputs tensors of shape :
 - f1 : torch.Size([1, 64, 128, 128])
 - f2 : torch.Size([1, 128, 64, 64])
 - f3 : torch.Size([1, 256, 32, 32])
 - f4 : torch.Size([1, 512, 16, 16])

Implementing the decoder¶

We now move on to the decoder part. The decoder receives the outputs of the encoder and progressively upscales the representation to finally obtain an output whose spatial dimensions match those of the input. With an input of shape $(B, C, H, W)$, the output of the final layer is expected to be of shape $(B, K, H, W)$ where $K$ is the number of classes you want to predict for every pixel.

The specificity of the UNet is that along the upscaling path, you integrate features from intermediate layers of the encoder through the so-called shortcut connections, the connections that provide the $f_3, f_2, f_1$ features depicted on the figure above.

Exercise

Your exercise is to finish the code for the decoder in the models/unet.py script. Like the encoder, the decoder is built from repeated blocks, called DecoderBlock in the code. The Decoder class itself is already coded but you must finish the code of the DecoderBlock. You have to complete the code both in the constructor and for the forward pass.

A Decoder block receives, along the upscaling path, an input tensor with cin channels and is built from a sequence of layers :

  • conv1 which is a Sequential block with 2D convolution, Batch normalization layer and ReLU. This sub-block should output cin channels,
  • up_conv which is a Sequential block with an upscaling layer followed by a 2D convolution, Batch normalization and ReLU. This sub-block should output cin//2 channels
  • conv2 which is a sequential block with a 2D convolution, Batch normalization layer and ReLU. This sub-block takes as input cin channels and outputs cin//2 channels.

For the forward pass, your decoder block receives two inputs: x and f_encoder. The tensor x holds the features flowing through the upscaling path, while f_encoder holds the features received from the encoder through the shortcut connections. During the forward pass you need to:

  • pass the x features through conv1 and up_conv,
  • concatenate the output of the previous step with the encoder features f_encoder,
  • pass the output of the previous step through the final sub-block conv2

Note how the dimensions change during these steps. At the output of up_conv, your tensor will have cin//2 channels. The f_encoder contains cin//2 channels as well. When both are concatenated, your tensor has cin channels and the final conv2 transforms these into cin//2 channels.
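
If you get stuck, the block described above could be implemented along the following lines (one possible sketch; the expected code in models/unet.py may differ, e.g. in the kernel sizes or in the choice of the upsampling layer):

import torch
import torch.nn as nn

class DecoderBlockSketch(nn.Module):
    def __init__(self, cin):
        super().__init__()
        # conv1 keeps cin channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin),
            nn.ReLU(inplace=True),
        )
        # up_conv doubles the spatial resolution and halves the number of channels
        self.up_conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(cin, cin // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin // 2),
            nn.ReLU(inplace=True),
        )
        # conv2 maps the cin channels obtained after concatenation to cin // 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(cin, cin // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin // 2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, f_encoder):
        x = self.up_conv(self.conv1(x))
        # Concatenate with the shortcut features along the channel dimension
        x = torch.cat([x, f_encoder], dim=1)
        return self.conv2(x)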

To test your implementation, you can run the following cell which should complete without errors.

In [ ]:
from planktoseg.models.unet import DecoderBlock, Decoder

# First test : we check the forward pass through a decoder block is working
dummy_input = torch.zeros(3, 4, 128, 128)
dummy_encoder_features = torch.zeros(3, 2, 256, 256)
block = DecoderBlock(4)
output = block(dummy_input, dummy_encoder_features)

print(f"With an input of shape {dummy_input.shape} with encoder features of shape {dummy_encoder_features.shape}, the output of the decoder block is of shape {output.shape}")

assert list(output.shape) == [3, 2, 256, 256]

# Second test : we check the forward pass through a complete decoder is working
batch_size = 3
K = 10
f1 = torch.zeros(batch_size, 64, 128, 128)
f2 = torch.zeros(batch_size, 128, 64, 64)
f3 = torch.zeros(batch_size, 256, 32, 32)
f4 = torch.zeros(batch_size, 512, 16, 16)

decoder = Decoder(num_classes = K)
output = decoder(f4, [f1, f2, f3])

print(f"The output of the decoder is of shape {output.shape}")

assert list(output.shape) == [batch_size, K, 512, 512]
With an input of shape torch.Size([3, 4, 128, 128]) with encoder features of shape torch.Size([3, 2, 256, 256]), the output of the decoder block is of shape torch.Size([3, 2, 256, 256])
The output of the decoder is of shape torch.Size([3, 10, 512, 512])

Building the complete model¶

Once the encoder and decoder are implemented, we can create the complete model and send it to the device used for the experiments.

In [ ]:
%%capture

model = models.UNet({"encoder": {"model_name": "resnet18"}}, input_size, 1 if num_classes == 2 else num_classes)
model = model.to(device)

Loss function and metrics for an unbalanced classification¶

Our problem is a classification problem, although a pixel-wise one: you need to classify every single pixel of an image. A natural first-guess loss function in this case is the cross entropy loss, which reads:

$$ CE(\{x_i, y_i\}) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -\log(p(y_{i,h,w} | x_i)) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -\log(f_w(x_i)_{h,w,y_{i,h,w}}) $$

where we denote by $f_w(x_i)$ the probability distributions your model assigns to the $H \times W$ pixels of your image, so that $f_w(x_i)_{h, w, y_{i,h,w}}$ is the probability assigned by your model, at pixel $(h, w)$, to the target class $y_{i,h,w}$. This loss gives an over-influence to the majority class in an unbalanced dataset. Other losses may be preferred, such as the focal loss, dice loss, weighted cross entropy loss, ...

In this lab, you are provided with an implementation of the focal loss. The focal loss strongly decreases the influence of the pixels that are already correctly predicted and, usually, these correspond to pixels belonging to the over-represented classes. It reads:

$$ FL(\{x_i, y_i\}) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -(1-p(y_{i,h,w} | x_i))^\gamma \log(p(y_{i,h,w} | x_i)) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -(1 - f_w(x_i)_{h,w,y_{i,h,w}})^\gamma \log(f_w(x_i)_{h,w,y_{i,h,w}}) $$

with, for example, $\gamma=2$. To illustrate the difference between the cross entropy loss and the focal loss, we display below the two losses as a function of $p(y_{i,h,w} | x_i)$. As you can see, as the probability assigned by your model tends to $1$, its contribution to the loss is lowered with respect to the BCE loss.
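
You are provided with optim.FocalLoss, so there is nothing to implement here; for reference only, a binary focal loss on logits could be sketched as follows (an illustration, not necessarily identical to the provided implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryFocalLossSketch(nn.Module):
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # logits: (B, 1, H, W), targets: (B, H, W) with values in {0, 1}
        targets = targets.float().unsqueeze(1)
        # Probability assigned to the target class of every pixel
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        # Per-pixel cross entropy, down-weighted by (1 - p_t) ** gamma
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        return ((1 - p_t) ** self.gamma * bce).mean()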

In [ ]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

gamma = 2
Nsteps = 50
p = torch.linspace(0., 1., steps=Nsteps)
labels = torch.ones_like(p)

bce_loss_values = -torch.log(p)
focal_loss_values = -(1-p)**gamma * torch.log(p)

plt.figure()
plt.plot(p, bce_loss_values, label='BCE loss')
plt.plot(p, focal_loss_values, label='Focal loss')
plt.xlabel("Probability assigned to the class to be predicted")
plt.ylabel("Loss value")
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x7f1ce8d17430>
BCE loss and focal loss as a function of the assigned probability
Exercise

Evaluate the following cell to define the loss as the Focal loss

In [ ]:
loss = optim.FocalLoss()

Last elements : optimizer, model checkpointing, metrics, loggers, ...¶

The final elements we need are the optimizer, the model checkpoint callback that saves the best model, some metrics computations and possibly loggers.

To define all these, you just need to evaluate the following cell.

In [ ]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Metrics evaluated over the training and validation folds
train_fmetrics = {
    "focal": main.deepcs.metrics.GenericBatchMetric(loss),
    "accuracy": main.BatchAccuracy(),
    "confusion_matrix": metrics.BinaryConfusionMatrixMetric(),
}
train_fmetrics["F1"] = metrics.BinaryF1Metric()
test_fmetrics = {
    "focal": main.deepcs.metrics.GenericBatchMetric(loss),
    "accuracy": main.BatchAccuracy(),
    "confusion_matrix": metrics.BinaryConfusionMatrixMetric(),
}
test_fmetrics["F1"] = metrics.BinaryF1Metric()

# Define the model checkpoint callback, which saves the best model
model_checkpoint = utils.ModelCheckpoint(
    model, logdir, (1,) + input_size, device, min_is_best=False
)

Our first experiments¶

All right, we now have all the building blocks for running our first training. Depending on your settings and GPU, it may be more or less fast :)

On an RTX 3090, with a batch size of $16$ and images of size $512 \times 512$, processing $3000$ minibatches takes almost 20 minutes per epoch and 22GB of VRAM.

Also, while preparing the labwork, I observed that the batch size should not be too small. Using a batch size of $7$ led to convergence to a bad-quality local minimum, while a batch size of $16$ allowed the focal loss to be minimized further. Unfortunately, this implies a lower bound on the GPU memory required.

In [ ]:
num_epochs = 2

metrics_store = {"train": [], "valid": []}

def postprocess_metrics(fmetrics, metrics_dict):
    """
    This function is used only for extracting
    the precision and recall we track for displaying and plotting
    """
    cm = fmetrics["confusion_matrix"]
    metrics_dict["precision"] = cm.get_precision()
    metrics_dict["recall"] = cm.get_recall()

# Evaluate the metrics before training
train_metrics = utils.test(model, train_loader, device, train_fmetrics)
postprocess_metrics(train_fmetrics, train_metrics)
metrics_store["train"].append(train_metrics)

valid_metrics = utils.test(model, valid_loader, device, test_fmetrics)
postprocess_metrics(test_fmetrics, valid_metrics)
metrics_store["valid"].append(valid_metrics)

for e in range(num_epochs):
        # Train 1 epoch
        train_metrics = utils.train(
            model, train_loader, loss, optimizer, device, train_fmetrics
        )
        postprocess_metrics(train_fmetrics, train_metrics)
        metrics_store["train"].append(train_metrics)

        # Test
        valid_metrics = utils.test(model, valid_loader, device, test_fmetrics)
        postprocess_metrics(test_fmetrics, valid_metrics)
        metrics_store["valid"].append(valid_metrics)

        # Save the model if it is better
        checkpoint_metric_name = "F1"
        checkpoint_metric = valid_metrics[checkpoint_metric_name]
        updated = model_checkpoint.update(checkpoint_metric)

        # Display the metrics
        metrics_msg = f" Epoch {e} / {num_epochs}\n"
        metrics_msg += "- Train : \n  "
        metrics_msg += "\n  ".join(
            f" {m_name}: {m_value}" for (m_name, m_value) in train_metrics.items()
        )
        metrics_msg += "\n"
        metrics_msg += "- Valid : \n  "
        metrics_msg += "\n  ".join(
            f" {m_name}: {m_value}"
            + ("[>> BETTER <<]" if updated and m_name == checkpoint_metric_name else "")
            for (m_name, m_value) in valid_metrics.items()
        )
        print(metrics_msg)
        
    
  0%|          | 0/1474 [00:00<?, ?it/s]
100%|██████████| 1474/1474 [01:21<00:00, 18.17it/s]
100%|██████████| 369/369 [00:20<00:00, 18.01it/s]
100%|██████████| 1474/1474 [04:38<00:00,  5.29it/s]
100%|██████████| 369/369 [00:20<00:00, 17.83it/s]
 Epoch 0 / 2
- Train : 
   focal: 0.05041743235933004
   accuracy: 0.9605594004834199
   confusion_matrix: [[0.9954146118281357, 0.004585388171864319], [0.5870692592258612, 0.41293074077413877]]
   F1: 0.5416884766447951
   precision: 0.7871255188568016
   recall: 0.41293074077413877
- Valid : 
   focal: 0.03458496961828569
   accuracy: 0.8642685112506758
   confusion_matrix: [[0.9850354573537219, 0.014964542646278133], [0.526730909307094, 0.473269090692906]]
   F1: 0.6034457223144885[>> BETTER <<]
   precision: 0.8324059615164311
   recall: 0.473269090692906
100%|██████████| 1474/1474 [04:33<00:00,  5.39it/s]
100%|██████████| 369/369 [00:20<00:00, 17.89it/s]
 Epoch 1 / 2
- Train : 
   focal: 0.01360388492515849
   accuracy: 0.960526061672692
   confusion_matrix: [[0.9962440705799295, 0.00375592942007056], [0.5190544495534446, 0.4809455504465554]]
   F1: 0.6117580025823954
   precision: 0.8403155563173192
   recall: 0.4809455504465554
- Valid : 
   focal: 0.017102164999813124
   accuracy: 0.8642685112506758
   confusion_matrix: [[0.9906546216510921, 0.009345378348907864], [0.48852997030393647, 0.5114700296960635]]
   F1: 0.6511491356972307[>> BETTER <<]
   precision: 0.8957810601232823
   recall: 0.5114700296960635

During training, we recorded metrics into the metrics_store dictionary and we can now display them. This is a very basic way to plot these metrics; at the very end of the lab, we propose wandb.ai, which is far more convenient.

In [ ]:
import matplotlib.pyplot as plt

plt.figure(figsize=(15,5), dpi=150)
plt.subplot(1, 4, 1)
plt.plot([mi["focal"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["focal"] for mi in metrics_store["valid"]], 'k--')
plt.title("Focal loss")
plt.xlabel("Epoch")

plt.subplot(1, 4, 2)
plt.plot([mi["F1"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["F1"] for mi in metrics_store["valid"]], 'k--')
plt.title("F1 score")
plt.xlabel("Epoch")

plt.subplot(1, 4, 3)
plt.plot([mi["precision"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["precision"] for mi in metrics_store["valid"]], 'k--')
plt.title("Precision")
plt.xlabel("Epoch")

plt.subplot(1, 4, 4)
plt.plot([mi["recall"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["recall"] for mi in metrics_store["valid"]], 'k--')
plt.title("Recall")
plt.xlabel("Epoch")
Out[ ]:
Text(0.5, 0, 'Epoch')
Focal loss, F1 score, precision and recall over the epochs

Inference on new zooscan images¶

Now that we have trained our first UNet, we would like to apply it to new data, to perform so-called "inference". In the training loop above, the model that we consider the best is the one maximizing the F1 score on the validation data. This best model has been saved during the optimization as an ONNX file. There are actually several ways to save a model, for example as either a torch tensor of parameters or an ONNX graph, and the ONNX export is certainly the more portable of the two.

ONNX graphs can be executed with ONNX runtimes available in many different languages.

The cell below is a standalone cell. It can be executed without the other cells being executed. This is really the complete inference code.

Also, the patch size for inference is arbitrary. Indeed, as a UNet is a fully convolutional model, it can be trained on patches of shape, say, $512 \times 512$, and then seamlessly be evaluated on patches of arbitrary size, for example $4096 \times 4096$. To test inference, my advice would be to start with a small patch size and then increase it; you may fill your memory if the patch size is too large. To perform inference on a large image, a naive approach is to split it into smaller non-overlapping patches, perform inference on each and then stitch the results together.
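
A minimal sketch of this naive tiling strategy is given below; it reuses the ONNX session and the "scan" input name from the cell that follows and, to keep things simple, ignores the right and bottom borders that do not fit a full patch:

import numpy as np

def predict_full_scan(session, scan, patch_size=2048):
    """Naive tiled inference. scan is the full, already normalized, (H, W) float32 image."""
    H, W = scan.shape
    pred = np.zeros((H, W), dtype=bool)
    for i in range(0, H - patch_size + 1, patch_size):
        for j in range(0, W - patch_size + 1, patch_size):
            patch = scan[np.newaxis, np.newaxis, i:i + patch_size, j:j + patch_size]
            logits = session.run(None, {"scan": patch})[0]
            # sigmoid(logit) >= 0.5 is equivalent to logit >= 0
            pred[i:i + patch_size, j:j + patch_size] = logits.squeeze() >= 0.0
    return pred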

To run the following cell, you may need to restart your kernel so that the memory gets freed on the GPU.

In [ ]:
import pathlib
import matplotlib.pyplot as plt
import onnxruntime as ort
from PIL import Image
Image.MAX_IMAGE_PIXELS = 25000 * 15000
import numpy as np
import yaml

providers = []
use_cuda = True
patch_size = 4096

if use_cuda:
    providers.append("CUDAExecutionProvider")
providers.append("CPUExecutionProvider")

# You may adapt the following to either use 
# the model you trained  or the pretrained model you are provided
# logdir = pathlib.Path("pretrained_model")
logdir = pathlib.Path("logs")

inference_session = ort.InferenceSession(
    str(logdir / "best_model.onnx"), providers=providers
)

# Load our normalizing statistics
stats = yaml.safe_load(open(str(logdir / "normalizing_stats.yaml"), "r"))
mean = stats["mean"]
std = stats["std"]

# Load our image
scan_path = "./data/test/rg20210421_scan.png"
scan_img = np.array(Image.open(scan_path))

mask_path = "./data/test/rg20210421_mask.png"
mask_img = np.array(Image.open(mask_path))
print(mask_img.shape)

crop_offset = (5048, 2048)
# Normalize our input
scan_img = ((scan_img - mean * 255.)/(std * 255.)).astype(np.float32)
scan_img = scan_img[np.newaxis, np.newaxis, ...]
scan_img = scan_img[:, :, crop_offset[0]:(crop_offset[0] + patch_size), crop_offset[1]:(crop_offset[1] + patch_size)]

# Get the ground truth mask
# print(np.unique(mask_img))
mask_img = mask_img[crop_offset[0]:(crop_offset[0] + patch_size), crop_offset[1]:(crop_offset[1] + patch_size)] >= 8

# Perform an inference
logits = inference_session.run(None, {"scan": scan_img})[0]
probs = 1.0 / (1.0 + np.exp(-logits))

pred_mask = probs >= 0.5

# Plot the results
plt.figure(dpi=300)
plt.subplot(1, 4, 1)
plt.imshow(scan_img.squeeze(), cmap="gray")
plt.title("Zooscan image")
plt.axis("off")

plt.subplot(1, 4, 2)
plt.imshow(probs.squeeze(), interpolation="none", clim=(0.0, 1.0))
plt.title("Probabilities")
plt.axis("off")

plt.subplot(1, 4, 3)
plt.imshow(pred_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Predicted Mask")
plt.axis("off")

plt.subplot(1, 4, 4)
plt.imshow(mask_img, interpolation="none", cmap="tab20c")
plt.title("Ground truth")
plt.axis("off")

plt.tight_layout()
(14573, 22817)
Zooscan crop, predicted probabilities, predicted mask and ground truth

Going further¶

Using DeepLab v3+ instead of the UNet with pretrained backbone¶

All right, you implemented and trained a custom U-Net with a pretrained backbone. What about experimenting with a more state-of-the-art model such as DeepLabV3+? For this, you can either use the segmentation models provided by torchvision or rely on an external library such as segmentation_models_pytorch.

In [ ]:
%pip install segmentation_models_pytorch
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.4-py3-none-any.whl (109 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 109.5/109.5 KB 1.4 MB/s eta 0:00:00a 0:00:01
Collecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: huggingface-hub>=0.24.6 in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (0.25.1)
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.8/58.8 KB 1.6 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Requirement already satisfied: six in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (1.16.0)
Collecting timm==0.9.7
  Downloading timm-0.9.7-py3-none-any.whl (2.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 5.1 MB/s eta 0:00:00:00:0100:01
Requirement already satisfied: torchvision>=0.5.0 in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (0.19.1)
Requirement already satisfied: pillow in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (10.4.0)
Requirement already satisfied: tqdm in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (4.66.5)
Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (from efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.4.1)
Collecting munch
  Downloading munch-4.0.0-py2.py3-none-any.whl (9.9 kB)
Requirement already satisfied: safetensors in ./venv/lib/python3.10/site-packages (from timm==0.9.7->segmentation_models_pytorch) (0.4.5)
Requirement already satisfied: pyyaml in ./venv/lib/python3.10/site-packages (from timm==0.9.7->segmentation_models_pytorch) (6.0.2)
Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.16.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (4.12.2)
Requirement already satisfied: packaging>=20.9 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (24.1)
Requirement already satisfied: fsspec>=2023.5.0 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (2024.9.0)
Requirement already satisfied: requests in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (2.32.3)
Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (from torchvision>=0.5.0->segmentation_models_pytorch) (2.1.1)
Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.0.106)
Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (10.3.2.106)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.1.4)
Requirement already satisfied: triton==3.0.0 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.0.0)
Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (1.13.3)
Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.3.1)
Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.20.5)
Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (11.4.5.107)
Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.3)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (9.1.0.70)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (11.0.2.54)
Requirement already satisfied: nvidia-nvjitlink-cu12 in ./venv/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.6.77)
Requirement already satisfied: urllib3<3,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (2024.8.30)
Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.10)
Requirement already satisfied: MarkupSafe>=2.0 in ./venv/lib/python3.10/site-packages (from jinja2->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./venv/lib/python3.10/site-packages (from sympy->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (1.3.0)
Using legacy 'setup.py install' for efficientnet-pytorch, since package 'wheel' is not installed.
Using legacy 'setup.py install' for pretrainedmodels, since package 'wheel' is not installed.
Installing collected packages: munch, efficientnet-pytorch, timm, pretrainedmodels, segmentation_models_pytorch
  Running setup.py install for efficientnet-pytorch ... done
  Attempting uninstall: timm
    Found existing installation: timm 1.0.9
    Uninstalling timm-1.0.9:
      Successfully uninstalled timm-1.0.9
  Running setup.py install for pretrainedmodels ... done
Successfully installed efficientnet-pytorch-0.7.1 munch-4.0.0 pretrainedmodels-0.7.4 segmentation_models_pytorch-0.3.4 timm-0.9.7
Note: you may need to restart the kernel to use updated packages.

For example, below is a snippet to construct DeepLabV3+ as proposed in Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. You can now go back to your code and experiment with one of the models proposed by segmentation_models_pytorch.

In [ ]:
import torch
import segmentation_models_pytorch as smp

model = smp.DeepLabV3Plus(
    encoder_name="resnet18", 
    encoder_weights="imagenet", 
    in_channels=1, 
    classes=1, 
)

dummy_input = torch.zeros(3, 1, 512, 512)
output = model(dummy_input)
print(f"The output has a shape {output.shape}")
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /usr/users/dce-admin/fix/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:04<00:00, 11.7MB/s]
The output has a shape torch.Size([3, 1, 512, 512])

Interfacing with an online dashboard : wandb.ai¶

As you may expect, deep learning is about tuning your whole pipeline, experimenting with different models, and so on, and you need to efficiently track all the experiments you will be running.

One part of the answer lies in using dashboards; online dashboards are especially useful. One such dashboard is wandb.ai. To monitor your experiments with wandb, you just need to add a few lines of code.

First, you need to get an account on wandb. This is free and academics have free plans. Once you have an account, you can create a new project, for which you will obtain both a project name and an entity name.

Then you can connect to your wandb account by calling

import wandb
wandb.init(project=..., entity=....)

and finally, to record your metrics during training, you simply need to call the following function after every epoch :

wandb.log({my_metric_name: my_metric_value})

In our case, we have two dictionaries, train_metrics and valid_metrics, which contain our metrics; we just need to call wandb.log on them to get the metrics logged:
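
For example, at the end of every epoch, something along the following lines would log the scalar metrics (the confusion matrices are skipped here because they are not scalars):

wandb.log(
    {**{f"train_{name}": value for name, value in train_metrics.items() if name != "confusion_matrix"},
     **{f"valid_{name}": value for name, value in valid_metrics.items() if name != "confusion_matrix"}}
)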

wandb dashboard example

Running the code without jupyter¶

Between you and me, a jupyter notebook is great for exploring concepts, but you definitely want to move away from it for large-scale experiments: running jupyter notebooks is not the most efficient way to perform multiple experiments.

Actually, I adapted a python library I wrote for plankton segmentation so that you are able to evaluate everything from within a notebook.

The library is available on github at https://github.com/jeremyfix/2024_ml4oceans. Feel free to make use of it. You are also provided with a python script submit_slurm.py which allows you to easily run experiments on clusters managed by SLURM (e.g. Jean Zay). Note you will certainly have to adapt the slurm sbatch directives to make it work (e.g. partition name, ...)

In [ ]: