ZooScan image segmentation using a UNet with a pretrained backbone¶

In [ ]:
%load_ext autoreload
%autoreload 2

In this lab, we would like to train a deep neural network for performing semantic segmentation of Zooscan images. This labwork will guide you through :

  • how to load labeled zooscan images
  • how to implement a UNet for performing the semantic segmentation
  • how to train the network
  • how to perform inference on a new image with a trained network

The cell below will collect a subset of a dataset for training and testing. The complete dataset we can use is made of $23$ labeled zooplankton images provided by :

Sorbonne Université/CNRS - Institut de la Mer de Villefranche (IMEV), Sorbonne Université/CNRS - Laboratoire d'Océanographie de Villefranche (LOV); 2020; Plankton community in Régent (680µm) net, Point B, Villefranche-sur-Mer, France https://dx.doi.org/10.14284/477

In [ ]:
!rm -rf data
In [ ]:
import os
import glob
import random
import shutil

max_numfiles = 10
train_fraction = 0.8

datarootdir = "/mounts/Datasets3/2024-2025-ChallengePlankton/train"

if "TMPDIR" in os.environ:
    # On the DCE, $TMPDIR refers to a temporary local storage on the compute node
    tmpdir = os.environ["TMPDIR"]
    os.symlink(tmpdir, "./data")
else:
    tmpdir = "./data"

if os.path.exists(tmpdir):
    shutil.rmtree(tmpdir)

os.makedirs(tmpdir)

# Locate all the xxxxx_scan.png.ppm from the datarootdir with the labels xxxxx_mask.png.ppm
# split it randomly in 80% train and 20% test
# Only consider up to max_numfiles files
# Copy these data in their respective folders into the tmpdir

scan_files = glob.glob(os.path.join(datarootdir, "*_scan.png.ppm"))
# Compute the mask filenames from the scan files
mask_files = [f.replace("_scan.png.ppm", "_mask.png.ppm") for f in scan_files]

# Random shuffle the files for the split. Apply the same random to the masks to keep
# both lists aligned

idx = list(range(len(scan_files)))[:max_numfiles]
random.shuffle(idx)

train_idx = idx[: int(len(idx) * train_fraction)]
test_idx = idx[int(len(idx) * train_fraction) :]

train_scan_files = [scan_files[i] for i in train_idx]
train_mask_files = [mask_files[i] for i in train_idx]
test_scan_files = [scan_files[i] for i in test_idx]
test_mask_files = [mask_files[i] for i in test_idx]


def copy_files(files, target_dir):
    """
    Copy the files to the target directory.
    """
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    for f in files:
        os.system(f"cp {f} {target_dir}")


# Copy the files to the tmpdir
copy_files(train_scan_files, os.path.join(tmpdir, "train"))
copy_files(train_mask_files, os.path.join(tmpdir, "train"))

copy_files(test_scan_files, os.path.join(tmpdir, "test"))
copy_files(test_mask_files, os.path.join(tmpdir, "test"))

# Copy the taxa file as well
os.system(f"cp {datarootdir}/taxa.csv {os.path.join(tmpdir, 'train')}")
os.system(f"cp {datarootdir}/taxa.csv {os.path.join(tmpdir, 'test')}")
Out[ ]:
0

Introduction¶

We are provided with labeled zooscans, which means we have both the zooscan image and a mask. The mask is an image of labels. The data are located within the data subdirectory; every scan is around $22800 \times 14500$ pixels.

In [ ]:
!ls data/train
rg20090204_mask.png.ppm  rg20090617_mask.png.ppm  rg20091028_mask.png.ppm
rg20090204_scan.png.ppm  rg20090617_scan.png.ppm  rg20091028_scan.png.ppm
rg20090318_mask.png.ppm  rg20090715_mask.png.ppm  rg20091125_mask.png.ppm
rg20090318_scan.png.ppm  rg20090715_scan.png.ppm  rg20091125_scan.png.ppm
rg20090610_mask.png.ppm  rg20090902_mask.png.ppm  taxa.csv
rg20090610_scan.png.ppm  rg20090902_scan.png.ppm
In [ ]:
!for f in data/train/*_scan.png.ppm; do file $f; done
data/train/rg20090204_scan.png.ppm: Netpbm image data, size = 22817 x 14569, rawbits, greymap
data/train/rg20090318_scan.png.ppm: Netpbm image data, size = 22797 x 14549, rawbits, greymap
data/train/rg20090610_scan.png.ppm: Netpbm image data, size = 22717 x 14459, rawbits, greymap
data/train/rg20090617_scan.png.ppm: Netpbm image data, size = 22817 x 14529, rawbits, greymap
data/train/rg20090715_scan.png.ppm: Netpbm image data, size = 22807 x 14549, rawbits, greymap
data/train/rg20090902_scan.png.ppm: Netpbm image data, size = 22807 x 14319, rawbits, greymap
data/train/rg20091028_scan.png.ppm: Netpbm image data, size = 22787 x 14379, rawbits, greymap
data/train/rg20091125_scan.png.ppm: Netpbm image data, size = 22817 x 14519, rawbits, greymap

An example ZooScan image and its mask are displayed below

Zooscan image and mask example

The color code of the mask depends on the class to which the pixel has been assigned by the human labeler. Note that pixel-labeling such huge images is certainly tough work, but the quality of your data and labels is of the utmost importance for our deep learning algorithms to work. The semantics of the class indices are provided in the taxa.csv file.

In [ ]:
!cat data/train/taxa.csv
label,living,label_nb
artefact,False,1
badfocus<artefact,False,2
bubble,False,3
detritus,False,4
fiber<detritus,False,5
t001,False,6
t003,False,7
Abylopsis tetragona,True,8
Acartiidae,True,9
Actiniaria,True,10
Actinopterygii,True,11
Aglaura,True,12
Amphipoda,True,13
Annelida,True,14
Atlanta,True,15
Aulacantha,True,16
Bivalvia<Mollusca,True,17
Calanidae,True,18
Calanoida,True,19
Calocalanus pavo,True,20
Candaciidae,True,21
Cavolinia inflexa,True,22
Centropagidae,True,23
Chaetognatha,True,24
Chelophyes appendiculata,True,25
Cliidae,True,26
Collodaria,True,27
Copilia,True,28
Corycaeidae,True,29
Creseidae,True,30
Creseis acicula,True,31
Ctenophora<Metazoa,True,32
Cymbulia peroni,True,33
Doliolida,True,34
Echinodermata,True,35
Eucalanidae,True,36
Euchaetidae,True,37
Eumalacostraca,True,38
Evadne,True,39
Fritillariidae,True,40
Gammaridea,True,41
Globigerinidae,True,42
Gymnosomata,True,43
Haloptilus,True,44
Harpacticoida,True,45
Heterorhabdidae,True,46
Hydrozoa,True,47
Hyperiidea,True,48
Insecta,True,49
Lensia subtilis,True,50
Limacinidae,True,51
Liriope<Geryoniidae,True,52
Metridinidae,True,53
Mollusca,True,54
Obelia,True,55
Oikopleuridae,True,56
Oithonidae,True,57
Oncaeidae,True,58
Orbulina,True,59
Ostracoda,True,60
Penilia avirostris,True,61
Physonectae,True,62
Podon,True,63
Pontellidae,True,64
Pontellina plumata,True,65
Pyrosomatida,True,66
Rhincalanidae,True,67
Rhizaria,True,68
Rhopalonema velatum,True,69
Salpida,True,70
Sapphirinidae,True,71
Solmundella bitentaculata,True,72
Temoridae,True,73
bract<Abylidae,True,74
bract<Abylopsis tetragona,True,75
bract<Diphyidae,True,76
calyptopsis<Euphausiacea,True,77
chain<Salpida,True,78
cirrus,True,79
colony<Phaeodaria,True,80
cypris,True,81
damaged<Aulacantha,True,82
division,True,83
egg<Actinopterygii,True,84
egg<Mollusca,True,85
egg<other,True,86
ephyra<Scyphozoa,True,87
eudoxie<Abylopsis tetragona,True,88
eudoxie<Bassia bassensis,True,89
eudoxie<Diphyidae,True,90
gonophore<Abylopsis tetragona,True,91
gonophore<Bassia bassensis,True,92
gonophore<Diphyidae,True,93
head<Chaetognatha,True,94
juvenile<Salpida,True,95
larvae<Annelida,True,96
larvae<Porcellanidae,True,97
larvae<Squillidae,True,98
like<Collodaria,True,99
like<Laomediidae,True,100
megalopa,True,101
multiple<Copepoda,True,102
multiple<other,True,103
nauplii<Cirripedia,True,104
nauplii<Crustacea,True,105
nectophore<Abylopsis tetragona,True,106
nectophore<Bassia bassensis,True,107
nectophore<Diphyidae,True,108
nectophore<Hippopodiidae,True,109
nectophore<Physonectae,True,110
nucleus<Salpidae,True,111
othertocheck,True,112
part<Annelida,True,113
part<Cnidaria,True,114
part<Crustacea,True,115
part<Ctenophora,True,116
part<Mollusca,True,117
part<Siphonophorae,True,118
phyllosoma,True,119
pluteus<Echinoidea,True,120
pluteus<Ophiuroidea,True,121
protozoea<Mysida,True,122
protozoea<Penaeidae,True,123
protozoea<Sergestidae,True,124
scale,True,125
seaweed,True,126
siphonula,True,127
t002,True,128
tail<Appendicularia,True,129
tail<Chaetognatha,True,130
trunk<Appendicularia,True,131
wing,True,132
zoea<Brachyura,True,133
zoea<Galatheidae,True,134

We have $134$ classes of objects plus the background (label $0$). All the living organisms have been assigned a label strictly greater than $7$. The "living" column is therefore equivalent to testing whether label_nb is greater than $7$.
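As a quick sanity check, the short cell below (a sketch, assuming the notebook's working directory contains the data/train/taxa.csv file shown above) verifies this equivalence with the standard csv module :

In [ ]:
import csv

# Check that the "living" column is True exactly when label_nb > 7
with open("data/train/taxa.csv") as f:
    rows = list(csv.DictReader(f))

assert all((row["living"] == "True") == (int(row["label_nb"]) > 7) for row in rows)
print(f"{len(rows)} classes, {sum(row['living'] == 'True' for row in rows)} of them living")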

The classes are unbalanced. For example, counting the number of occurrences of all the $135$ classes (background plus the $134$ non-living and living classes), we obtain the counts below, sorted by decreasing occurrence.

The most represented class is the background with almost $6.7$ billion pixels. The most represented non-background class is Chaetognatha (class label $24$) with $58$ million pixels and the least represented non-empty class is Bivalvia<Mollusca with only $1663$ pixels. This strong imbalance can cause a lot of trouble when training a neural network, as the over-represented classes can be learned more easily than the under-represented ones.

Class name Count
background 6648643689
Chaetognatha 58412864
Salpida 56586684
detritus 33649190
Calanidae 17664122
Aulacantha 12856260
Creseis acicula 8225786
Eumalacostraca 7416376
nectophore<Diphyidae 7036773
badfocus<artefact 6540988
multiple<other 6236479
Rhopalonema velatum 5682772
Calanoida 5602157
Metridinidae 5429665
damaged<Aulacantha 4346996
Chelophyes appendiculata 4061841
nectophore<Physonectae 3743692
Candaciidae 3646939
ephyra<Scyphozoa 3104710
Euchaetidae 2741439
gonophore<Diphyidae 2502463
Abylopsis tetragona 2443476
Doliolida 2442549
Hydrozoa 2418177
Centropagidae 1803662
nectophore<Abylopsis tetragona 1760329
artefact 1672201
othertocheck 1434265
zoea<Galatheidae 1407337
Eucalanidae 1393273
zoea<Brachyura 1360657
Temoridae 1258350
t001 1142346
bract<Abylopsis tetragona 1135173
fiber<detritus 1094546
Aglaura 1071800
Ctenophora<Metazoa 1046977
part<Crustacea 999993
egg<other 995258
bract<Diphyidae 991480
Cavolinia inflexa 954769
tail<Chaetognatha 951886
Creseidae 914248
gonophore<Abylopsis tetragona 908079
Corycaeidae 853293
Oikopleuridae 846252
protozoea<Mysida 780997
calyptopsis<Euphausiacea 780051
bubble 659786
Hyperiidea 570687
eudoxie<Abylopsis tetragona 548613
Sapphirinidae 490605
Heterorhabdidae 487610
Fritillariidae 481907
nucleus<Salpidae 468802
part<Siphonophorae 461031
Liriope<Geryoniidae 428388
Copilia 413538
egg<Actinopterygii 400252
Pyrosomatida 393117
eudoxie<Diphyidae 391068
Actinopterygii 386618
part<Cnidaria 380051
seaweed 361969
Lensia subtilis 331479
larvae<Squillidae 329856
division 326004
head<Chaetognatha 299446
siphonula 286749
part<Mollusca 285145
larvae<Porcellanidae 278677
megalopa 259496
Penilia avirostris 259458
tail<Appendicularia 234629
Solmundella bitentaculata 232568
like<Collodaria 224458
Rhizaria 217177
juvenile<Salpida 201881
pluteus<Ophiuroidea 178319
Annelida 166785
like<Laomediidae 164274
nauplii<Crustacea 133174
Gammaridea 131295
Collodaria 118535
Cymbulia peroni 117545
protozoea<Sergestidae 115471
Gymnosomata 114724
chain<Salpida 109933
Oithonidae 96581
Ostracoda 96258
trunk<Appendicularia 91666
Actiniaria 86592
Haloptilus 84510
egg<Mollusca 76979
Limacinidae 70271
Rhincalanidae 70044
Pontellidae 65340
multiple<Copepoda 64726
Physonectae 61635
colony<Phaeodaria 60429
part<Ctenophora 55696
gonophore<Bassia bassensis 51563
bract<Abylidae 50050
Cliidae 49886
Atlanta 48107
Acartiidae 47749
Amphipoda 45436
pluteus<Echinoidea 45399
t003 40408
phyllosoma 36785
protozoea<Penaeidae 33334
Podon 29154
Obelia 28436
Globigerinidae 24077
Pontellina plumata 22435
Oncaeidae 20112
Orbulina 18316
Insecta 16989
nectophore<Bassia bassensis 13441
nauplii<Cirripedia 9843
Echinodermata 9496
cypris 6894
Evadne 6758
Harpacticoida 5712
t002 4580
Calocalanus pavo 4258
wing 4236
Bivalvia<Mollusca 1663
eudoxie<Bassia bassensis 0
cirrus 0
larvae<Annelida 0
part<Annelida 0
scale 0
nectophore<Hippopodiidae 0
Mollusca 0

The plot below shows the number of pixels per class. Note that the y-axis is in log scale. As we count the number of pixels per class, the imbalance could be explained by some organisms being larger than others or by their over-representation. In any case, this imbalance will cause you trouble when training a neural network.

Number of pixels per class

If we group the classes into non-living vs living, we obtain the following counts, where non-living pixels are still roughly $25$ times more numerous than living ones.

Class name Count
Non living 6,693,443,154
Living 259,647,119

Imports¶

All the code you are going to write goes into the following files :

  • data.py : script responsible for handling the data pipeline, producing the dataloaders for the training and validation data
  • models/ : submodule responsible for providing the different models you want to experiment with
In [ ]:
# Standard imports
import pathlib

# External imports
import albumentations as A
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml

# Local imports
import wandb
import planktoseg
from planktoseg import data
from planktoseg import models
from planktoseg import optim
from planktoseg import utils
from planktoseg import metrics
from planktoseg import main

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

logdir = pathlib.Path("./logs")
if not logdir.exists():
    logdir.mkdir(exist_ok=True)

Data loading and exploration¶

In this part, we will be exploring the dataset :

  • load and visualize some samples
  • investigate the balance of the classes
  • investigate the application of some data augmentation techniques.

The very first step when you write a pipeline for training neural networks on your data is to prepare your data pipeline.

In Pytorch, the expected output of this first step is to provide the data loaders, i.e. the python iterable objects able to give you minibatches for training and validation.

Dataset¶

In the data.py script, you are provided with the PlanktonDataset object. We wrote this class to ease your work of loading the data. This class can :

  • load the data (image and mask),
  • split every sample into patches, with configurable patch size and patch stride (overlap = patch_size - patch_stride),
  • apply transformations to the input patches and mask patches,
  • switch between two semantic segmentation tasks :
    1. classify a pixel as belonging to one of the $135$ classes (task = SegmentationTask.LIVING_CLASSES)
    2. classify a pixel as belonging to non living vs living classes (task = SegmentationTask.LIVING_NONLIVING)

As an example, in the cell below, we basically invoke this dataset object which is going to :

  • split the large scans into patches of size $512 \times 512$ with a stride of $128 \times 128$ (a short sketch after this list estimates how many patches this yields per scan)
  • merge the labels into two classes : Non living (label $\leq 7$) and Living (label $\geq 8$)
  • transform every patch and mask from an Image type to a pytorch Tensor
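A rough back-of-the-envelope estimate of how many patches this produces per scan (a sketch, assuming only patches that fully fit inside the image are kept ; the PlanktonDataset implementation may handle the borders differently) :

In [ ]:
# Patch count estimate for one scan, e.g. rg20090204 (22817 x 14569 pixels, see above)
H, W = 14569, 22817
patch, stride = 512, 128   # overlap = patch - stride = 384 pixels

n_rows = (H - patch) // stride + 1
n_cols = (W - patch) // stride + 1
print(f"{n_rows} x {n_cols} = {n_rows * n_cols} patches for this scan")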
In [ ]:
transform = A.pytorch.ToTensorV2()

dataset = data.PlanktonDataset(root="./data/train", patch_size=(512, 512), patch_stride=(128, 128), task=data.SegmentationTask.LIVING_NONLIVING)
dataset = data.WrappedDataset(dataset, transform)
100%|██████████| 8/8 [00:00<00:00, 7584.64it/s]

We can now access both an image and its label by indexing the dataset. Feel free to run the cell below multiple times : every time, it randomly samples a new patch and its binary mask.

In [ ]:
idx = np.random.randint(len(dataset)) 

img, mask = dataset[idx]

print(f"The img is of type {type(img)} and of shape {img.shape}")
print(f"The mask is of type {type(mask)} and of shape {mask.shape}")

plt.subplot(1, 2, 1)
plt.imshow(img.squeeze(), cmap="gray")
plt.title("Zooscan image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Mask")
plt.axis("off")
The img is of type <class 'torch.Tensor'> and of shape torch.Size([1, 512, 512])
The mask is of type <class 'torch.Tensor'> and of shape torch.Size([512, 512])
Out[ ]:
(np.float64(-0.5), np.float64(511.5), np.float64(511.5), np.float64(-0.5))
[Figure : a sampled Zooscan patch (left) and its mask (right)]

Exploring the augmentation transforms¶

Data (of quality) is of paramount importance but still limited in availability. We can virtually increase its availability by applying so-called "data augmentations", i.e. transforms of your inputs for which you know how to transform the target.

For this, we will use the albumentations python library which offers several transforms for various tasks (classification, object detection, semantic segmentation)

Your first exercise is to identify, among the available data transforms (see for example this documentation, and do not forget to scroll down to the spatial-level transforms), which ones are suitable. For example :

  • A.HorizontalFlip, which randomly flips horizontally your image/mask pair,
  • A.VerticalFlip which randomly flips vertically your image/mask pair,
  • A.MaskDropout which randomly masks some objects

You could also Blur or add GaussianNoise. I believe these transforms still make sense for our plankton segmentation task given the snow we see on the images and the possibly out-of-focus objects. If you wonder "Where should I stop adding transforms ?" or "How do I know my transforms are ok ?", remember that the purpose of these transforms is to augment your data, hopefully to improve your validation metric. We look for transforms that will improve the ability of your neural network to generalize its responses to new images.

Exercise

In the sample code below, you can experiment with different augmentations. Some are already provided, feel free to :

  • run the cell below multiple times. The transforms are randomly sampled and applied to the inputs. Every time you run the cell, the same input image/mask pair will be transformed differently
  • modify the content of the transform = A.Compose([..]) to add or remove transforms and observe its impact
In [ ]:
idx = np.random.randint(len(dataset)) 
idx = 15000

# Sample the dataset without all the transforms
# to see what the image and mask originally look like
original_transform = A.pytorch.ToTensorV2()
dataset.transform = original_transform
orig_img, orig_mask = dataset[idx]

##########################################################################################
# TODO: Feel free to change the code below
# Tune the augmented transform by prepending your chosen transforms before the
# conversion to pytorch tensor
# Feel free to re-evaluate the cell as you add transforms to see their effect
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.MaskDropout((1, 1), image_fill_value=255, p=1), 
    A.Blur(),
    A.pytorch.ToTensorV2()
])
##########################################################################################

# And now sample the exact same image/mask 
dataset.transform = transform
aug_img, aug_mask = dataset[idx]

plt.subplot(2, 2, 1)
plt.imshow(orig_img.squeeze(), cmap="gray", clim=(0, 255))
plt.title("Zooscan original")
plt.axis("off")

plt.subplot(2, 2, 2)
plt.imshow(orig_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Original Mask")
plt.axis("off")

plt.subplot(2, 2, 3)
plt.imshow(aug_img.squeeze(), cmap="gray", clim=(0, 255))
plt.title("Zooscan augmented image")
plt.axis("off")

plt.subplot(2, 2, 4)
plt.imshow(aug_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Augmented mask")
plt.axis("off")
/tmp/fix-108932/ipykernel_2324214/2288155435.py:18: UserWarning: Argument(s) 'image_fill_value' are not valid for transform MaskDropout
  A.MaskDropout((1, 1), image_fill_value=255, p=1),
Out[ ]:
(np.float64(-0.5), np.float64(511.5), np.float64(511.5), np.float64(-0.5))
[Figure : original patch and mask (top) vs augmented patch and mask (bottom)]

Adding the transforms in the data pipeline¶

The suitable augmentation transforms you identified in the previous section need to be added to the data pipeline.

This is already defined in the planktoseg/data.py script. The most important function in the data.py script is the get_dataloaders function where :

  • the dataset is split into training and validation folds
  • each fold gets its own transforms (with augmentations for training, without augmentation for validation)
  • the dataloaders providing the mini-batches are created

get_dataloaders is the function that will be called by the main script to build up the dataloaders, and that's the only thing we need from the data pipeline.
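For intuition, here is a possible sketch of what get_dataloaders may look like (the actual implementation in planktoseg/data.py is the reference and certainly differs, e.g. in the transforms it applies and in how it computes the normalizing statistics) :

In [ ]:
import numpy as np
import albumentations as A
from torch.utils.data import DataLoader, Subset

from planktoseg import data

# Build the patched dataset once, then wrap it twice with different transforms
base = data.PlanktonDataset(
    root="./data/train",
    patch_size=(256, 256),
    patch_stride=(128, 128),
    task=data.SegmentationTask.LIVING_NONLIVING,
)

train_transform = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.pytorch.ToTensorV2()])
valid_transform = A.pytorch.ToTensorV2()

# Disjoint index subsets define the training and validation folds
indices = np.random.permutation(len(base))
n_valid = int(0.2 * len(base))
valid_idx, train_idx = indices[:n_valid].tolist(), indices[n_valid:].tolist()

train_ds = Subset(data.WrappedDataset(base, train_transform), train_idx)
valid_ds = Subset(data.WrappedDataset(base, valid_transform), valid_idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=7)
valid_loader = DataLoader(valid_ds, batch_size=32, shuffle=False, num_workers=7)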

Exercise

Locate, in the get_dataloaders function of the planktoseg/data.py script, where and which transforms are defined. Once done, move on to the next cell, which will create your dataloaders.

In [ ]:
# The configuration below requires 11GB of VRAM
data_config = {"trainpath": "./data/train", 
               "valid_ratio": 0.2,
               "batch_size": 32,
               "num_workers": 7,
               "patch_size": (256, 256),
               "patch_stride": (128, 128),
               "task": "living_nonliving",
               "normalize": True,
               "max_num_samples": 50000
               }

train_loader, valid_loader, input_size, num_classes, normalizing_stats = data.get_dataloaders(data_config, use_cuda)

# The normalizing statistics must be saved because they will be needed for inference
with open(logdir / "normalizing_stats.yaml", "w") as file:
    yaml.dump(normalizing_stats, file)
100%|██████████| 8/8 [00:00<00:00, 4497.91it/s]
100%|██████████| 1250/1250 [01:24<00:00, 14.83it/s]
In [ ]:
print(f"The train dataloader contains {len(train_loader)} mini batches")
print(f"The valid dataloader contains {len(valid_loader)} mini batches")
print(f"Our problem contains {num_classes} classes and our inputs have the shape {input_size}")
print(f"For normalizing the input, we will be using the following statistics {normalizing_stats}")
The train dataloader contains 1250 mini batches
The valid dataloader contains 313 mini batches
Our problem contains 2 classes and our inputs have the shape (1, 256, 256)
For normalizing the input, we will be using the following statistics {'mean': 0.8375427017211914, 'std': 0.09082679715796328}

Encoder/Decoder with a pretrained backbone¶

Since we now have the data to train on, the next step is to implement the model. We are going to implement a U-Net, as seen during the lecture, with a pre-trained backbone.

UNet architecture

In our codebase, it is the responsibility of the models module to build the models. This module provides you with a model builder, the models.build_model function. The UNet model is defined in the models/unet.py script as the UNet class. If you look at the code of that class, you will notice it is built from two components : the encoder, which is a TimmEncoder, and the decoder, of type Decoder.
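Schematically, the UNet forward pass chains these two components, roughly as in the sketch below (the real UNet class in planktoseg/models/unet.py takes a configuration dictionary and certainly differs in its details) :

In [ ]:
import torch
import torch.nn as nn
from planktoseg.models.unet import TimmEncoder, Decoder

class UNetSketch(nn.Module):
    def __init__(self, cin=1, num_classes=1):
        super().__init__()
        self.encoder = TimmEncoder(cin=cin, model_name="resnet18")
        self.decoder = Decoder(num_classes=num_classes)

    def forward(self, x):
        # The encoder returns the bottleneck features f4 and the shortcut features [f1, f2, f3]
        f4, skips = self.encoder(x)
        # The decoder upscales f4 back to the input resolution, using the shortcuts on the way
        return self.decoder(f4, skips)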

Experimenting with the encoder¶

The code for the encoder is already given to you. This code will load a pretrained model provided by the timm library.

Exercise

Locate in planktoseg/models/unet.py the class TimmEncoder. This class uses the external library timm to access pretrained neural networks. In particular, an encoder has exactly the structure of a classification network, so it makes sense to use, for example, a ResNet pretrained on ImageNet as an initial encoder.

Exercise

To better appreciate what the encoder produces as an output, the cell below will :

  • create an encoder with cin=1 (grayscale input images) and model_name=resnet18
  • perform a forward pass through the network and check the dimensionality of the outputs.

As we are only interested in checking the dimensionalities of the encoder outputs, you can use a dummy torch.zeros tensor to perform the forward propagation. This is a handy way to trigger the computation of a network.

In [ ]:
from planktoseg.models.unet import TimmEncoder

encoder = TimmEncoder(cin=1, model_name="resnet18")
fake_input = torch.zeros(1, 1, 512, 512) # This must be in the pytorch format (B, C, H, W)
f4, [f1, f2, f3] = encoder(fake_input)

print(f"With an input of shape {fake_input.shape}, \n our model outputs tensors of shape :")
print(f" - f1 : {f1.shape}")
print(f" - f2 : {f2.shape}")
print(f" - f3 : {f3.shape}")
print(f" - f4 : {f4.shape}")
With an input of shape torch.Size([1, 1, 512, 512]), 
 our model outputs tensors of shape :
 - f1 : torch.Size([1, 64, 128, 128])
 - f2 : torch.Size([1, 128, 64, 64])
 - f3 : torch.Size([1, 256, 32, 32])
 - f4 : torch.Size([1, 512, 16, 16])

Implementing the decoder¶

We now move on to the decoder part. The decoder receives the outputs of the encoder and progressively upscales the representation spatially to finally obtain an output whose spatial dimensions match those of the input, producing, for every pixel, a score for each of our $2$ classes.

With an input of shape $(B, C, H, W)$, the output of the final layer is expected to be of shape $(B, K, H, W)$ where $K$ is the number of classes you want to predict for every pixel.

The specificity of the UNet is that along the upscaling path, you integrate features from intermediate layers of the encoder through the so-called shortcut connections, the connections that provide the $f_3, f_2, f_1$ features depicted on the figure above.

Exercise

Locate the code of the Decoder class in the planktoseg/models/unet.py script. Like the encoder, the decoder is built from the repetition of blocks, called DecoderBlock in the code. All the code is already provided to you and you can take time to understand what is going on.

A Decoder block receives, along the upscaling path, an input tensor with cin channels and is built from a sequence of layers :

  • conv1 which is a Sequential block with 2D convolution, Batch normalization layer and ReLU. This sub-block should output cin channels,
  • up_conv which is a Sequential block with an upscaling layer followed by a 2D convolution, Batch normalization and ReLU. This sub-block should output cin//2 channels
  • conv2 which is a sequential block with a 2D convolution, Batch normalization layer and ReLU. This sub-block takes as input cin channels and outputs cin//2 channels.

For the forward pass, your decoder block receives two inputs : x and f_encoder. The tensor x holds the features flowing through the upscaling path while f_encoder holds the features received from the encoder through the shortcut connections. During the forward pass you need to :

  • pass the x features through conv1 and up_conv,
  • concatenate the output of the previous step with the encoder features f_encoder,
  • pass the output of the previous step through the final sub-block conv2

Note how the dimensions change during these steps. At the output of up_conv, your tensor will have cin//2 channels. The f_encoder contains cin//2 channels as well. When both are concatenated, your tensor has cin channels and the final conv2 transforms these into cin//2 channels.
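To fix ideas, here is a minimal sketch of such a block (the kernel sizes, paddings and the bilinear upsampling are assumptions on my side ; the reference DecoderBlock in planktoseg/models/unet.py is the code you should actually study) :

In [ ]:
import torch
import torch.nn as nn

class DecoderBlockSketch(nn.Module):
    def __init__(self, cin):
        super().__init__()
        # conv1 : keeps cin channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin),
            nn.ReLU(inplace=True),
        )
        # up_conv : doubles the spatial resolution and halves the channels
        self.up_conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(cin, cin // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin // 2),
            nn.ReLU(inplace=True),
        )
        # conv2 : takes the concatenation (cin channels) back to cin // 2 channels
        self.conv2 = nn.Sequential(
            nn.Conv2d(cin, cin // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(cin // 2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, f_encoder):
        x = self.up_conv(self.conv1(x))        # (B, cin//2, 2H, 2W)
        x = torch.cat([x, f_encoder], dim=1)   # (B, cin,    2H, 2W)
        return self.conv2(x)                   # (B, cin//2, 2H, 2W)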

To test your implementation, you can run the following cell, which should complete without errors. When you code your own network, it is definitely good practice to locally test your implementation : you want to locate potential errors as soon as possible. Imagine if you had to debug your code only once the whole pipeline is set up, what a nightmare 😱 ! Just never do that !

In [ ]:
from planktoseg.models.unet import DecoderBlock, Decoder

# First test : we check the forward pass through a decoder block is working
print("First test")
dummy_input = torch.zeros(3, 4, 128, 128)
dummy_encoder_features = torch.zeros(3, 2, 256, 256)
block = DecoderBlock(4)
output = block(dummy_input, dummy_encoder_features)

print(f"With an input of shape {dummy_input.shape} with encoder features of shape {dummy_encoder_features.shape}, the output of the decoder block is of shape {output.shape}")

assert list(output.shape) == [3, 2, 256, 256]

# Second test : we check the forward pass through a complete decoder is working
print("\nSecond test")
batch_size = 3
K = 10
f1 = torch.zeros(batch_size, 64, 128, 128)
f2 = torch.zeros(batch_size, 128, 64, 64)
f3 = torch.zeros(batch_size, 256, 32, 32)
f4 = torch.zeros(batch_size, 512, 16, 16)

decoder = Decoder(num_classes = K)
output = decoder(f4, [f1, f2, f3])

print(f"The output of the decoder is of shape {output.shape}")

assert list(output.shape) == [batch_size, K, 512, 512]
print("If we reach that point, we are good to go !")
First test
With an input of shape torch.Size([3, 4, 128, 128]) with encoder features of shape torch.Size([3, 2, 256, 256]), the output of the decoder block is of shape torch.Size([3, 2, 256, 256])

Second test
The output of the decoder is of shape torch.Size([3, 10, 512, 512])
If we reach that point, we are good to go !

Building the complete model¶

Once the encoder and decoder are implemented, we can create the complete model and send it to the device used for the experiments.

In [ ]:
%%capture

model = models.UNet({"encoder": {"model_name": "resnet18"}}, input_size, 1 if num_classes == 2 else num_classes)
model = model.to(device)

Loss function and metrics for an unbalanced classification¶

Great, we have the data and we have a model. It is now time to define the loss, which quantifies the mismatch between the predictions of our network and the ground truth, and the metrics, which will help us monitor the training.

Our problem is a classification problem, although a pixel-wise one : you need to classify every single pixel of an image. A natural first-guess loss function in this case is the cross-entropy loss, which reads :

$$ CE(\{x_i, y_i\}) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -\log(p(y_{i,h,w} | x_i)) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -\log(f_w(x_i)_{h,w,y_{i,h,w}}) $$

where we denote by $f_w(x_i)$ the probability distribution your model assigns to each class for every one of the $H \times W$ pixels of your image, so that $f_w(x_i)_{h, w, y_{i,h,w}}$ is the probability assigned by your model, at pixel $(h, w)$, to the class $y_{i,h,w}$. That loss induces an over-influence of the majority class in an unbalanced dataset. Other losses may be preferred, such as the focal loss, the dice loss, the weighted cross-entropy loss, ...

In this lab, you are provided with an implementation of the focal loss. The focal loss will strongly decrease the influence of the pixels that are correctly predicted and, usually, these correspond to the pixels belonging to the overrepresented classes. It reads :

$$ Focal(\{x_i, y_i\}) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -(1-p(y_{i,h,w} | x_i))^\gamma \log(p(y_{i,h,w} | x_i)) = \frac{1}{N\times H \times W} \sum_{i=0}^{N-1}\sum_{h=0}^{H-1}\sum_{w=0}^{W-1} -(1 - f_w(x_i)_{h,w,y_{i,h,w}})^\gamma \log(f_w(x_i)_{h,w,y_{i,h,w}}) $$

with, for example, $\gamma=2$. In order to illustrate the difference between the cross-entropy loss and the focal loss, we display below the two losses as a function of $p(y_{i,h,w} | x_i)$. As you can see, as the probability assigned by your model tends to $1$, its contribution to the loss is lowered with respect to the BCE loss.

In [ ]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

gamma = 2
Nsteps = 50
p = torch.linspace(0., 1., steps=Nsteps)
labels = torch.ones_like(p)

bce_loss_values = -torch.log(p)
focal_loss_values = -(1-p)**gamma * torch.log(p)

plt.figure()
plt.plot(p, bce_loss_values, label='BCE loss')
plt.plot(p, focal_loss_values, label='Focal loss')
plt.xlabel("Probability assigned to the class to be predicted")
plt.ylabel("Loss value")
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x7f6c880ac6d0>
[Figure : BCE loss and focal loss as a function of the probability assigned to the correct class]
Exercise

Evaluate the following cell to define the loss as the Focal loss

In [ ]:
loss = optim.BinaryFocalLoss()
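For reference, a from-scratch sketch of a binary focal loss built on top of the binary cross-entropy with logits could look like the cell below (the provided optim.BinaryFocalLoss is the implementation actually used in this lab and may differ, e.g. it may include a class-balancing weight) :

In [ ]:
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, gamma=2.0):
    # Per-pixel BCE = -log(p_t), where p_t is the probability assigned to the correct class
    bce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
    p_t = torch.exp(-bce)
    # Down-weight the well-classified pixels by (1 - p_t)^gamma, then average
    return ((1.0 - p_t) ** gamma * bce).mean()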

Last elements : optimizer, early stopping, metrics, loggers, ...¶

The final elements we need are the optimizer, the early stopping callback, some metrics computations and possibly loggers.

To define all these, you just need to evaluate the following cell.

In [ ]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Metrics evaluated over the training and validation folds
train_fmetrics = {
    "focal": main.deepcs.metrics.GenericBatchMetric(loss),
    "accuracy": main.BatchAccuracy(),
    "confusion_matrix": metrics.BinaryConfusionMatrixMetric(),
}
train_fmetrics["F1"] = metrics.BinaryF1Metric()
test_fmetrics = {
    "focal": main.deepcs.metrics.GenericBatchMetric(loss),
    "accuracy": main.BatchAccuracy(),
    "confusion_matrix": metrics.BinaryConfusionMatrixMetric(),
}
test_fmetrics["F1"] = metrics.BinaryF1Metric()

# Define the early stopping callback
model_checkpoint = utils.ModelCheckpoint(
    model, logdir, (1,) + input_size, device, min_is_best=False
)

Our first experiments¶

All right, we now have all the building blocks for running our first training. Depending on your settings and GPU, it may be more or less fast :)

On an RTX 3090, with a batch size of $16$, processing $3000$ minibatches of images of size $512 \times 512$ takes almost 20 minutes per epoch and 22GB of VRAM.

Also, while preparing the labwork, I observed that the batch size should not be too small. Using a batch size of $7$ led to convergence to a bad-quality local minimum, while a batch size of $16$ allowed the focal loss to be minimized further. Unfortunately, this implies a lower bound on the GPU memory to use.
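The training cell below calls utils.train with use_autocast=True and gradient_accumulation=2 : mixed precision reduces the VRAM footprint while gradient accumulation emulates a larger batch size on a smaller GPU. For illustration only, a rough sketch of what such a training step might do internally (the actual utils.train implementation may differ) :

In [ ]:
# Illustration only : do not run this instead of utils.train below
scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 2

optimizer.zero_grad()
for it, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    # Forward pass in mixed precision
    with torch.autocast(device_type="cuda"):
        logits = model(inputs)
        batch_loss = loss(logits, targets) / accumulation_steps
    scaler.scale(batch_loss).backward()
    # Only step the optimizer every accumulation_steps mini-batches
    if (it + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()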

Important Training with the code below may fail. Sometimes, the optimizer does not find a good solution. If the F1 score is below $0.1$ after the first epoch, the optimization will generally fail completely. In this case, you are advised to recreate the network, which will draw a new initialization, and then re-run the training code below. It means re-evaluating two cells above : the one where the model is created and the one where the optimizer is created.

In [ ]:
num_epochs = 20

metrics_store = {"train": [], "valid": []}

def postprocess_metrics(fmetrics, metrics_dict):
    """
    This function is used only for extracting
    the precision and recall we track for displaying and ploting
    """
    cm = fmetrics["confusion_matrix"]
    metrics_dict["precision"] = cm.get_precision()
    metrics_dict["recall"] = cm.get_recall()

# Evaluate the metrics before training
train_metrics = utils.test(model, train_loader, device, test_fmetrics)
postprocess_metrics(test_fmetrics, train_metrics)
metrics_store["train"].append(train_metrics)

valid_metrics = utils.test(model, valid_loader, device, test_fmetrics)
postprocess_metrics(test_fmetrics, valid_metrics)
metrics_store["valid"].append(valid_metrics)

for e in range(num_epochs):
    # Train 1 epoch
    train_metrics = utils.train(
        model, train_loader, loss, optimizer, device, train_fmetrics, use_autocast=True, gradient_accumulation=2
    )
    postprocess_metrics(train_fmetrics, train_metrics)
    metrics_store["train"].append(train_metrics)

    # Test
    valid_metrics = utils.test(model, valid_loader, device, test_fmetrics)
    postprocess_metrics(test_fmetrics, valid_metrics)
    metrics_store["valid"].append(valid_metrics)

    # Save the model if it is better
    checkpoint_metric_name = "F1"
    checkpoint_metric = valid_metrics[checkpoint_metric_name]
    updated = model_checkpoint.update(checkpoint_metric)

    # Display the metrics
    metrics_msg = f" Epoch {e} / {num_epochs}\n"
    metrics_msg += "- Train : \n  "
    metrics_msg += "\n  ".join(
        f" {m_name}: {m_value}" for (m_name, m_value) in train_metrics.items()
    )
    metrics_msg += "\n"
    metrics_msg += "- Valid : \n  "
    metrics_msg += "\n  ".join(
        f" {m_name}: {m_value}"
        + ("[>> BETTER <<]" if updated and m_name == checkpoint_metric_name else "")
        for (m_name, m_value) in valid_metrics.items()
    )
    print(metrics_msg)
100%|██████████| 1250/1250 [01:17<00:00, 16.06it/s]
100%|██████████| 313/313 [00:19<00:00, 16.07it/s]
1250it [02:22,  8.80it/s]
100%|██████████| 313/313 [00:19<00:00, 16.00it/s]
 Epoch 0 / 20
- Train : 
   focal: 0.08842981127202511
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.984560818434984, 0.015439181565016006], [0.14412051680573137, 0.8558794831942687]]
   F1: 0.7784466352306805
   precision: 0.7137056591530477
   recall: 0.8558794831942687
- Valid : 
   focal: 0.05701432123780251
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9918627047558664, 0.008137295244133596], [0.14919320684204732, 0.8508067931579527]]
   F1: 0.8353033018339561[>> BETTER <<]
   precision: 0.8203547115599382
   recall: 0.8508067931579527
1250it [02:21,  8.86it/s]
100%|██████████| 313/313 [00:19<00:00, 16.04it/s]
 Epoch 1 / 20
- Train : 
   focal: 0.03922448422163725
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9904625031552303, 0.009537496844769644], [0.13008181098011845, 0.8699181890198815]]
   F1: 0.8356533342599359
   precision: 0.8039862340299617
   recall: 0.8699181890198815
- Valid : 
   focal: 0.029857855677604676
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9932920011982709, 0.006707998801729129], [0.13468788886477076, 0.8653121111352292]]
   F1: 0.8572113073535104[>> BETTER <<]
   precision: 0.8492607716083391
   recall: 0.8653121111352292
1250it [02:20,  8.87it/s]
100%|██████████| 313/313 [00:19<00:00, 16.21it/s]
 Epoch 2 / 20
- Train : 
   focal: 0.024323564728349446
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9917350603795003, 0.008264939620499722], [0.1359301555745717, 0.8640698444254283]]
   F1: 0.8438503026707409
   precision: 0.8246044818847238
   recall: 0.8640698444254283
- Valid : 
   focal: 0.021647291553020476
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9908758379932292, 0.009124162006770839], [0.13249373759876096, 0.8675062624012391]]
   F1: 0.8355805587025873
   precision: 0.8059212428402837
   recall: 0.8675062624012391
1250it [02:20,  8.88it/s]
100%|██████████| 313/313 [00:19<00:00, 16.35it/s]
 Epoch 3 / 20
- Train : 
   focal: 0.017641140395402908
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9920694084442917, 0.007930591555708258], [0.1376783865567481, 0.8623216134432519]]
   F1: 0.8459301773457703
   precision: 0.8302119059302435
   recall: 0.8623216134432519
- Valid : 
   focal: 0.013542211799323559
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9941612251769757, 0.005838774823024345], [0.1360545160263835, 0.8639454839736165]]
   F1: 0.8649697379272963[>> BETTER <<]
   precision: 0.8659963897237278
   recall: 0.8639454839736165
1250it [02:20,  8.89it/s]
100%|██████████| 313/313 [00:19<00:00, 16.16it/s]
 Epoch 4 / 20
- Train : 
   focal: 0.013638624830171467
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9926297385773518, 0.007370261422648223], [0.12705171000269547, 0.8729482899973046]]
   F1: 0.857121663716136
   precision: 0.8419291206236177
   recall: 0.8729482899973046
- Valid : 
   focal: 0.011095934394001961
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9951373790442328, 0.004862620955767134], [0.16323752516072546, 0.8367624748392746]]
   F1: 0.8590556422492047
   precision: 0.8825692001849259
   recall: 0.8367624748392746
1250it [02:20,  8.89it/s]
100%|██████████| 313/313 [00:19<00:00, 16.47it/s]
 Epoch 5 / 20
- Train : 
   focal: 0.01242508983053267
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9924712930203758, 0.007528706979624189], [0.13744219662504664, 0.8625578033749534]]
   F1: 0.849772129978209
   precision: 0.8374543683141202
   recall: 0.8625578033749534
- Valid : 
   focal: 0.010794548659026623
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9966169482205985, 0.0033830517794014915], [0.2367741356548104, 0.7632258643451896]]
   F1: 0.8292845199497126
   precision: 0.9078616816617145
   recall: 0.7632258643451896
1250it [02:20,  8.91it/s]
100%|██████████| 313/313 [00:19<00:00, 16.42it/s]
 Epoch 6 / 20
- Train : 
   focal: 0.011730377678200602
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9924439163056096, 0.0075560836943904175], [0.14460572334697896, 0.855394276653021]]
   F1: 0.8454502023786447
   precision: 0.8358184544117562
   recall: 0.855394276653021
- Valid : 
   focal: 0.01126092714741826
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9951496462410684, 0.004850353758931638], [0.2583452952711382, 0.7416547047288617]]
   F1: 0.8006160646754155
   precision: 0.8697621565842611
   recall: 0.7416547047288617
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:19<00:00, 16.43it/s]
 Epoch 7 / 20
- Train : 
   focal: 0.00965581653751433
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.992733850456108, 0.0072661495438920845], [0.12477549324825661, 0.8752245067517433]]
   F1: 0.8593696661909014
   precision: 0.8441561661561292
   recall: 0.8752245067517433
- Valid : 
   focal: 0.007794975908100605
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9946186580063453, 0.005381341993654733], [0.12643160849646837, 0.8735683915035316]]
   F1: 0.8749766079043589[>> BETTER <<]
   precision: 0.8763893718034284
   recall: 0.8735683915035316
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:19<00:00, 16.43it/s]
 Epoch 8 / 20
- Train : 
   focal: 0.00840166380442679
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9934070462559824, 0.006592953744017635], [0.11425221017753541, 0.8857477898224646]]
   F1: 0.871604891263012
   precision: 0.857985731486906
   recall: 0.8857477898224646
- Valid : 
   focal: 0.00729588970541954
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9948697971305516, 0.00513020286944847], [0.12204731687660936, 0.8779526831233907]]
   F1: 0.879969894359631[>> BETTER <<]
   precision: 0.8819963965569119
   recall: 0.8779526831233907
1250it [02:20,  8.92it/s]
100%|██████████| 313/313 [00:18<00:00, 16.65it/s]
 Epoch 9 / 20
- Train : 
   focal: 0.00903737815283239
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9928777442356218, 0.007122255764378206], [0.1271027596486579, 0.8728972403513421]]
   F1: 0.8594265863134615
   precision: 0.846423692785223
   recall: 0.8728972403513421
- Valid : 
   focal: 81.92049057636187
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9965938311714925, 0.0034061688285075614], [0.21197029022559302, 0.788029709774407]]
   F1: 0.8446109021563942
   precision: 0.9099457696234862
   recall: 0.788029709774407
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:18<00:00, 16.76it/s]
 Epoch 10 / 20
- Train : 
   focal: 0.010216786290705203
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9920643734216613, 0.00793562657833868], [0.14364471531299328, 0.8563552846870067]]
   F1: 0.8424685543851981
   precision: 0.8291410860471871
   recall: 0.8563552846870067
- Valid : 
   focal: 0.007851592037826776
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9939980920025971, 0.006001907997402858], [0.1307855630352585, 0.8692144369647415]]
   F1: 0.866339992173695
   precision: 0.8634844959853186
   recall: 0.8692144369647415
1250it [02:20,  8.92it/s]
100%|██████████| 313/313 [00:18<00:00, 16.79it/s]
 Epoch 11 / 20
- Train : 
   focal: 0.012365453279763461
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.990900827167916, 0.009099172832083967], [0.18188237161674434, 0.8181176283832556]]
   F1: 0.8097510928015189
   precision: 0.801715631269967
   recall: 0.8181176283832556
- Valid : 
   focal: 0.01211442045941949
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9866634692434963, 0.013336530756503664], [0.11881718346236742, 0.8811828165376325]]
   F1: 0.8060066449038713
   precision: 0.7426491665671002
   recall: 0.8811828165376325
1250it [02:20,  8.91it/s]
100%|██████████| 313/313 [00:19<00:00, 16.44it/s]
 Epoch 12 / 20
- Train : 
   focal: 0.00994514806009829
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9916616401745947, 0.008338359825405267], [0.1430034611996807, 0.8569965388003193]]
   F1: 0.8391456287440654
   precision: 0.8221228212669075
   recall: 0.8569965388003193
- Valid : 
   focal: 0.008191654963046312
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9951566549303158, 0.004843345069684199], [0.17867614593492412, 0.8213238540650759]]
   F1: 0.8501356020769013
   precision: 0.8810422501341807
   recall: 0.8213238540650759
1250it [02:20,  8.89it/s]
100%|██████████| 313/313 [00:18<00:00, 16.58it/s]
 Epoch 13 / 20
- Train : 
   focal: 0.008770923649333417
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9927187649208038, 0.007281235079196158], [0.1266496330985091, 0.8733503669014909]]
   F1: 0.858172658250107
   precision: 0.8436005066422713
   recall: 0.8733503669014909
- Valid : 
   focal: 610.7457845468864
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9932479249673458, 0.006752075032654225], [0.10077234489191923, 0.8992276551080808]]
   F1: 0.8756613779627713
   precision: 0.8532987702372062
   recall: 0.8992276551080808
1250it [02:20,  8.91it/s]
100%|██████████| 313/313 [00:19<00:00, 16.38it/s]
 Epoch 14 / 20
- Train : 
   focal: 0.008078433918114752
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9930969118227039, 0.006903088177296168], [0.11827914878939208, 0.881720851210608]]
   F1: 0.8664157908929995
   precision: 0.85171751596145
   recall: 0.881720851210608
- Valid : 
   focal: 0.008209147394448519
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9926701755096341, 0.0073298244903658525], [0.11140162674575407, 0.888598373254246]]
   F1: 0.8642170277256166
   precision: 0.8411379017261984
   recall: 0.888598373254246
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:18<00:00, 16.63it/s]
 Epoch 15 / 20
- Train : 
   focal: 0.007690220666583627
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9934242070290201, 0.0065757929709798916], [0.11429300379981688, 0.8857069962001831]]
   F1: 0.871741296839137
   precision: 0.8582974003434525
   recall: 0.8857069962001831
- Valid : 
   focal: 3685.4800201559765
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9963474393553177, 0.0036525606446823284], [0.16411204986307656, 0.8358879501369234]]
   F1: 0.8709352461982078
   precision: 0.9090501012175851
   recall: 0.8358879501369234
1250it [02:20,  8.91it/s]
100%|██████████| 313/313 [00:18<00:00, 16.76it/s]
 Epoch 16 / 20
- Train : 
   focal: 0.008115153606235982
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9934597096986847, 0.006540290301315302], [0.1216568022097556, 0.8783431977902444]]
   F1: 0.8679831469771179
   precision: 0.8579400410367266
   recall: 0.8783431977902444
- Valid : 
   focal: 0.007096120508015156
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.99353291155378, 0.006467088446220048], [0.09646189051865214, 0.9035381094813478]]
   F1: 0.8808086676332931[>> BETTER <<]
   precision: 0.8591947160059871
   recall: 0.9035381094813478
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:18<00:00, 16.50it/s]
 Epoch 17 / 20
- Train : 
   focal: 0.007484080256335437
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9937075754398506, 0.006292424560149496], [0.11211651956260728, 0.8878834804373927]]
   F1: 0.8756691379980188
   precision: 0.8638595600334612
   recall: 0.8878834804373927
- Valid : 
   focal: 0.0064941814891994
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9940529750876993, 0.00594702491230073], [0.09126185229101297, 0.908738147708987]]
   F1: 0.8887836537879457[>> BETTER <<]
   precision: 0.8696866703279454
   recall: 0.908738147708987
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:18<00:00, 16.62it/s]
 Epoch 18 / 20
- Train : 
   focal: 0.007285832682531327
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9939532045004701, 0.0060467954995299085], [0.11159671556204534, 0.8884032844379547]]
   F1: 0.8783230073492208
   precision: 0.8685417650328868
   recall: 0.8884032844379547
- Valid : 
   focal: 0.006532670154795051
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9944549133035253, 0.005545086696474739], [0.09599148344789532, 0.9040085165521047]]
   F1: 0.8902232299668931[>> BETTER <<]
   precision: 0.8768520541015115
   recall: 0.9040085165521047
1250it [02:20,  8.90it/s]
100%|██████████| 313/313 [00:19<00:00, 16.43it/s]
 Epoch 19 / 20
- Train : 
   focal: 0.007049820588435978
   accuracy: 0.9569657318115234
   confusion_matrix: [[0.9939999722908326, 0.006000027709167373], [0.10616309726312911, 0.8938369027368709]]
   F1: 0.881788731239684
   precision: 0.8701164061801988
   recall: 0.8938369027368709
- Valid : 
   focal: 0.009244078635424376
   accuracy: 0.9581524917602539
   confusion_matrix: [[0.9921400886423535, 0.007859911357646538], [0.07643128580035363, 0.9235687141996464]]
   F1: 0.8781125640789897
   precision: 0.8369210343202492
   recall: 0.9235687141996464

During training, we recorded metrics into the metrics_store dictionary and we can display them. This is a very basic way of plotting the metrics; at the very end of the lab, we propose wandb.ai, which is far more convenient.

In [ ]:
import matplotlib.pyplot as plt

plt.figure(figsize=(15,5), dpi=150)
plt.subplot(1, 4, 1)
plt.plot([mi["focal"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["focal"] for mi in metrics_store["valid"]], 'k--')
plt.title("Focal loss")
plt.xlabel("Epoch")

plt.subplot(1, 4, 2)
plt.plot([mi["F1"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["F1"] for mi in metrics_store["valid"]], 'k--')
plt.title("F1 score")
plt.xlabel("Epoch")

plt.subplot(1, 4, 3)
plt.plot([mi["precision"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["precision"] for mi in metrics_store["valid"]], 'k--')
plt.title("Precision")
plt.xlabel("Epoch")

plt.subplot(1, 4, 4)
plt.plot([mi["recall"] for mi in metrics_store["train"]], 'k-')
plt.plot([mi["recall"] for mi in metrics_store["valid"]], 'k--')
plt.title("Recall")
plt.xlabel("Epoch")
Out[ ]:
Text(0.5, 0, 'Epoch')
[Figure : focal loss, F1 score, precision and recall over the epochs for the train (solid) and validation (dashed) folds]
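As a teaser for the wandb.ai logging mentioned above, a minimal sketch could look like the cell below (it assumes you have a wandb account and already ran wandb login ; the project name is arbitrary) :

In [ ]:
# Stream the recorded metrics to wandb instead of plotting them by hand
run = wandb.init(project="planktoseg")
for epoch, (train_m, valid_m) in enumerate(zip(metrics_store["train"], metrics_store["valid"])):
    logged = {f"train/{k}": v for k, v in train_m.items() if k != "confusion_matrix"}
    logged.update({f"valid/{k}": v for k, v in valid_m.items() if k != "confusion_matrix"})
    run.log(logged, step=epoch)
run.finish()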

Inference on new zooscan images with non overlapping tiles¶

Now that we have trained our first UNet, we would like to apply it to new data, i.e. to perform so-called "inference". In the training loop above, the model we consider the best is the one maximizing the F1 score on the validation data. This best model has been saved during the optimization as an ONNX file. There are actually several ways to save a model, for example as either a torch tensor of parameters or an ONNX graph, and the ONNX export is certainly the more portable of the two.
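For reference, exporting a pytorch model to ONNX takes only a few lines ; the sketch below shows what utils.ModelCheckpoint presumably does when it saves best_model.onnx (the input name "scan" matches what the inference cell expects, while the output name and the dynamic axes are assumptions on my side) :

In [ ]:
# Sketch of an ONNX export of the trained model (model, input_size, device and logdir
# come from the training part of the notebook)
model.eval()
dummy = torch.zeros((1,) + input_size, device=device)
torch.onnx.export(
    model,
    dummy,
    str(logdir / "best_model.onnx"),
    input_names=["scan"],
    output_names=["logits"],
    # Free spatial dimensions : this is what allows inference on patches of arbitrary size
    dynamic_axes={"scan": {0: "batch", 2: "height", 3: "width"}},
)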

ONNX graphs can be executed with any ONNX runtime in a lot of different languages, even in JavaScript running in a static webpage !

The cell below is a standalone cell. It can be executed without the other cells being executed. This is really the complete inference code.

Also, the patch size for inference is arbitrary. Indeed, as a UNet is a fully convolutional model, it can be trained on patches of shape, say, $512 \times 512$, and then seamlessly be evaluated on patches of arbitrary sizes, for example $4096 \times 4096$. To test inference, my advice would be to start with small patch sizes and then increase them; you may fill up your GPU memory if you use too large a patch size. To perform inference on a large image, a naive approach is to split it into smaller non-overlapping patches, perform inference on each and then stitch the predictions back together (see the sketch at the end of this section).

To run the following cell, you may need to restart your kernel so that the memory gets freed on the GPU.

Also, you can either use the network you trained or the network I have trained for you, available in the pretrained_model folder. Just adapt the value of the logdir variable below to choose one or the other.

In [ ]:
import pathlib
import matplotlib.pyplot as plt
import onnxruntime as ort
from PIL import Image
Image.MAX_IMAGE_PIXELS = 25000 * 15000
import numpy as np
import yaml
import glob

providers = []
use_cuda = True
patch_size = 2048

if use_cuda:
    providers.append("CUDAExecutionProvider")
providers.append("CPUExecutionProvider")

# You may adapt the following to either use 
# the model you trained  or the pretrained model you are provided
logdir = pathlib.Path("pretrained_model")
#logdir = pathlib.Path("logs")

inference_session = ort.InferenceSession(
    str(logdir / "best_model.onnx"), providers=providers
)

# Load our normalizing statistics
stats = yaml.safe_load(open(str(logdir / "normalizing_stats.yaml"), "r"))
mean = stats["mean"]
std = stats["std"]

# Load our image
scan_files = list(glob.glob("./data/test/*_scan.png.ppm"))
mask_files = [f.replace("_scan.png.ppm", "_mask.png.ppm") for f in scan_files]
test_idx = 0

scan_path = scan_files[test_idx]
scan_img = np.array(Image.open(scan_path))

mask_path = mask_files[test_idx]
mask_img = np.array(Image.open(mask_path))
print(mask_img.shape)

crop_offset = (5048, 2048)
# Normalize our input
scan_img = ((scan_img - mean * 255.)/(std * 255.)).astype(np.float32)
scan_img = scan_img[np.newaxis, np.newaxis, ...]
scan_img = scan_img[:, :, crop_offset[0]:(crop_offset[0] + patch_size), crop_offset[1]:(crop_offset[1] + patch_size)]

# Get the ground truth mask
# print(np.unique(mask_img))
mask_img = mask_img[crop_offset[0]:(crop_offset[0] + patch_size), crop_offset[1]:(crop_offset[1] + patch_size)] >= 8

# Perform an inference
logits = inference_session.run(None, {"scan": scan_img})[0]
probs = 1.0 / (1.0 + np.exp(-logits))

pred_mask = probs >= 0.5

# Plot the results
plt.figure(dpi=300)
plt.subplot(1, 4, 1)
plt.imshow(scan_img.squeeze(), cmap="gray")
plt.title("Zooscan image")
plt.axis("off")

plt.subplot(1, 4, 2)
plt.imshow(probs.squeeze(), interpolation="none", clim=(0.0, 1.0))
plt.title("Probabilities")
plt.axis("off")

plt.subplot(1, 4, 3)
plt.imshow(pred_mask.squeeze(), interpolation="none", cmap="tab20c")
plt.title("Predicted Mask")
plt.axis("off")

plt.subplot(1, 4, 4)
plt.imshow(mask_img, interpolation="none", cmap="tab20c")
plt.title("Ground truth")
plt.axis("off")

plt.tight_layout()
(14519, 22777)
[Figure : Zooscan crop, predicted probabilities, predicted mask and ground truth mask]
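To go from the single crop above to a whole scan, the naive non-overlapping tiling mentioned earlier could be sketched as below (it reuses inference_session, mean, std and scan_path from the cell above ; tile_size is arbitrary, the scan borders that do not fit a full tile are simply skipped here, and the ONNX graph is assumed to accept arbitrary spatial sizes) :

In [ ]:
tile_size = 2048

# Load and normalize the full scan (this allocates a few GB of RAM)
full_scan = np.array(Image.open(scan_path)).astype(np.float32)
full_scan = (full_scan - mean * 255.0) / (std * 255.0)

H, W = full_scan.shape
full_probs = np.zeros((H, W), dtype=np.float32)

# Run the network tile by tile and stitch the predictions back together
for i in range(0, H - tile_size + 1, tile_size):
    for j in range(0, W - tile_size + 1, tile_size):
        tile = full_scan[np.newaxis, np.newaxis, i:i + tile_size, j:j + tile_size]
        logits = inference_session.run(None, {"scan": tile})[0]
        full_probs[i:i + tile_size, j:j + tile_size] = 1.0 / (1.0 + np.exp(-logits.squeeze()))

full_pred_mask = full_probs >= 0.5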

Going further¶

Using DeepLab v3+ instead of the UNet with pretrained backbone¶

All right, you implemented and trained a custom UNet with a pretrained backbone. What about experimenting with a more state-of-the-art model such as DeepLab v3+ ? For this, you have the possibility to use the segmentation models provided by torchvision or to rely on an external library such as segmentation_models_pytorch.

In [ ]:
%pip install segmentation_models_pytorch
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.4-py3-none-any.whl (109 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 109.5/109.5 KB 1.4 MB/s eta 0:00:00a 0:00:01
Collecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: huggingface-hub>=0.24.6 in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (0.25.1)
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.8/58.8 KB 1.6 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Requirement already satisfied: six in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (1.16.0)
Collecting timm==0.9.7
  Downloading timm-0.9.7-py3-none-any.whl (2.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 5.1 MB/s eta 0:00:00:00:0100:01
Requirement already satisfied: torchvision>=0.5.0 in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (0.19.1)
Requirement already satisfied: pillow in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (10.4.0)
Requirement already satisfied: tqdm in ./venv/lib/python3.10/site-packages (from segmentation_models_pytorch) (4.66.5)
Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (from efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.4.1)
Collecting munch
  Downloading munch-4.0.0-py2.py3-none-any.whl (9.9 kB)
Requirement already satisfied: safetensors in ./venv/lib/python3.10/site-packages (from timm==0.9.7->segmentation_models_pytorch) (0.4.5)
Requirement already satisfied: pyyaml in ./venv/lib/python3.10/site-packages (from timm==0.9.7->segmentation_models_pytorch) (6.0.2)
Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.16.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (4.12.2)
Requirement already satisfied: packaging>=20.9 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (24.1)
Requirement already satisfied: fsspec>=2023.5.0 in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (2024.9.0)
Requirement already satisfied: requests in ./venv/lib/python3.10/site-packages (from huggingface-hub>=0.24.6->segmentation_models_pytorch) (2.32.3)
Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (from torchvision>=0.5.0->segmentation_models_pytorch) (2.1.1)
Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.0.106)
Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (10.3.2.106)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.1.4)
Requirement already satisfied: triton==3.0.0 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.0.0)
Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (1.13.3)
Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.3.1)
Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.20.5)
Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (11.4.5.107)
Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (3.3)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (9.1.0.70)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.1.105)
Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./venv/lib/python3.10/site-packages (from torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (11.0.2.54)
Requirement already satisfied: nvidia-nvjitlink-cu12 in ./venv/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (12.6.77)
Requirement already satisfied: urllib3<3,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (2024.8.30)
Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests->huggingface-hub>=0.24.6->segmentation_models_pytorch) (3.10)
Requirement already satisfied: MarkupSafe>=2.0 in ./venv/lib/python3.10/site-packages (from jinja2->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./venv/lib/python3.10/site-packages (from sympy->torch->efficientnet-pytorch==0.7.1->segmentation_models_pytorch) (1.3.0)
Using legacy 'setup.py install' for efficientnet-pytorch, since package 'wheel' is not installed.
Using legacy 'setup.py install' for pretrainedmodels, since package 'wheel' is not installed.
Installing collected packages: munch, efficientnet-pytorch, timm, pretrainedmodels, segmentation_models_pytorch
  Running setup.py install for efficientnet-pytorch ... done
  Attempting uninstall: timm
    Found existing installation: timm 1.0.9
    Uninstalling timm-1.0.9:
      Successfully uninstalled timm-1.0.9
  Running setup.py install for pretrainedmodels ... done
Successfully installed efficientnet-pytorch-0.7.1 munch-4.0.0 pretrainedmodels-0.7.4 segmentation_models_pytorch-0.3.4 timm-0.9.7
Note: you may need to restart the kernel to use updated packages.

For example, below is a snippet to construct DeepLab v3+ as proposed in Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. You can now go back to your code and experiment with one of the models proposed by segmentation_models_pytorch.

In [ ]:
import segmentation_models_pytorch as smp
import torch

model = smp.DeepLabV3Plus(
    encoder_name="resnet18", 
    encoder_weights="imagenet", 
    in_channels=1, 
    classes=1, 
)

dummy_input = torch.zeros(3, 1, 512, 512)
output = model(dummy_input)
print(f"The output has a shape {output.shape}")
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /usr/users/dce-admin/fix/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:04<00:00, 11.7MB/s]
The output has a shape torch.Size([3, 1, 512, 512])

Interfacing with an online dashboard: wandb.ai¶

As you may expect, deep learning is about tuning your whole pipeline, experimenting with different models, and so on, so you need to efficiently track all the experiments you will be running.

One part of the answer lies in using dashboards; online dashboards in particular are super useful. One such dashboard is wandb.ai. To monitor your experiments with wandb, you just need to add a few lines of code.

First, you need to get an account on wandb. This is free, and academics have access to free plans. Once you have an account, you can create a new project, for which you will obtain both a project name and an entity name.

Then you can connect to wandb by calling

import wandb
wandb.init(project=..., entity=....)

and finally, to record your metrics during training, you simply need to call the following function after every epoch:

wandb.log({my_metric_name: my_metric_value})

In our case, we have two dictionaries, train_metrics and test_metrics, which contain our metrics; we just need to call wandb.log on them to get them logged:
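For instance, here is a minimal sketch, assuming train_metrics and test_metrics map metric names to scalar values and that epoch holds the current epoch index:

import wandb

# Prefix the keys so train and test curves appear as separate groups in the dashboard
logged_metrics = {f"train_{name}": value for name, value in train_metrics.items()}
logged_metrics.update({f"test_{name}": value for name, value in test_metrics.items()})
wandb.log(logged_metrics, step=epoch)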

wandb dashboard example

Running the code without jupyter¶

Between you and me, a Jupyter notebook is great for exploring some concepts, but you definitely want to move away from notebooks if you really want to run large-scale experiments: running Jupyter notebooks is not the most efficient way to perform multiple experiments.

Actually, I adapted a Python library I wrote for plankton segmentation so that you can run everything we did in this notebook without Jupyter.

The library is available on GitHub at https://github.com/jeremyfix/2024_ml4oceans. Feel free to make use of it. You are also provided with a Python script submit_slurm.py which allows you to easily run experiments on clusters managed by SLURM (e.g. Jean Zay). Note that you will certainly have to adapt the SLURM sbatch directives to make it work (e.g. partition name, ...).

In [ ]: