In this notebook, we'll implement a 3-dimensional UNet image segmentation model in order to predict brain tumor regions from MRI scan data.
We'll use the Training dataset from the 2020 BraTS (Brain Tumor Segmentation) Challenge, which ran in conjunction with the 23rd annual International Conference on Medical Image Computing & Computer Assisted Intervention (MICCAI).
Each sample in the dataset consists of four 3-dimensional MRI volumes: native (T1), post-contrast T1-weighted (T1-CE), T2-weighted (T2), and T2 Fluid Attenuated Inversion Recovery (T2-FLAIR), together with a ground-truth segmentation mask that assigns a class label to each voxel. For our study, we omit T1 and use only the T1-CE, T2, and FLAIR channels.
For more information about the challenge and the dataset, please visit the BraTS 2020 website.
import numpy as np
import os
import glob
from tqdm.auto import tqdm
import random
from IPython.display import HTML
import torch
import matplotlib.pyplot as plt
import pickle
import pandas as pd
random_seed = 42
# A useful function for re-seeding all pseudorandom generators
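def set_seed(seed=random_seed):
    torch.manual_seed(seed)           # PyTorch CPU generator
    torch.cuda.manual_seed_all(seed)  # PyTorch generators on all GPUs
    np.random.seed(seed)              # NumPy global generator
    random.seed(seed)                 # Python standard-library generator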
Before proceeding, make sure the BraTS2020_TrainingData directory is in the same directory as this notebook. We will not use the samples in BraTS2020_ValidationData, because segmentation masks (i.e. targets) are not available for them; they were held internally in the BraTS system for online validation.
# DATA_PATH = 'BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'
# mask_filenames = sorted(glob.glob(DATA_PATH+'*/*seg.nii'))
# path_prefixes = [f.replace('seg.nii','') for f in mask_filenames]
The following block of code preprocesses the samples from the BraTS2020_TrainingData directory and saves the results as NumPy arrays. Later, our PyTorch DataLoaders will load data from these arrays.
After initial cropping, each segmentation mask will have shape (160,192,128) and its entries take values in {0,1,2,3}, which stand for the following: 0 (unlabelled), 1 (NCR/NET), 2 (ED), 3 (ET).
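We'll train our network for a common segmentation task in this domain, which consists of delineating the following three semantically meaningful tumor parts: enhancing tumor (ET, label 3), tumor core (TC, labels 1 and 3), and whole tumor (WT, labels 1, 2, and 3).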
We use a function encode_mask to encode an original mask of shape (160,192,128) as a binary mask of shape (3,160,192,128), where the three binary entries of the vector encoded_mask[:,i,j,k] indicate whether the voxel (i,j,k) of the original mask belongs to ET, TC, and WT respectively. Note that because these three regions overlap, we are considering a multi-label rather than a multi-class classification problem: e.g. if a voxel originally had label 1, then its vector in the encoded mask is (0,1,1).
The original mask can be recovered using the decode_mask function.
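The implementations of encode_mask and decode_mask live in lib.image_utils and are not shown in this notebook. As a rough illustration only, here is a minimal sketch of how such an encoding and its inverse could be written, assuming the label convention 1 (NCR/NET), 2 (ED), 3 (ET) and the channel order ET, TC, WT. The names encode_mask_sketch and decode_mask_sketch are hypothetical, and the actual library functions may differ.

```python
import numpy as np

def encode_mask_sketch(mask):
    """Hypothetical stand-in for encode_mask: (H,W,D) labels -> (3,H,W,D) binary."""
    et = mask == 3                    # enhancing tumor
    tc = (mask == 1) | (mask == 3)    # tumor core = NCR/NET + ET
    wt = mask > 0                     # whole tumor = any labelled voxel
    return np.stack([et, tc, wt]).astype(np.uint8)

def decode_mask_sketch(encoded):
    """Hypothetical stand-in for decode_mask: (3,H,W,D) binary -> (H,W,D) labels."""
    et, tc, wt = encoded.astype(bool)
    mask = np.zeros(et.shape, dtype=np.uint8)
    # Assign from the largest region to the smallest so that
    # the nested regions overwrite each other in the right order.
    mask[wt] = 2
    mask[tc] = 1
    mask[et] = 3
    return mask

# Toy example on a small random label volume
rng = np.random.default_rng(0)
toy_mask = rng.integers(0, 4, size=(4, 4, 4))
toy_encoded = encode_mask_sketch(toy_mask)
toy_decoded = decode_mask_sketch(toy_encoded)
```

Because WT contains TC, which in turn contains ET, the decoding step can simply paint the regions from largest to smallest, and the round trip recovers the toy mask exactly.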
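So, for each sample, we'll load the mask and the stacked image, crop away the border of zeros, pad to a common size, encode the mask, and save both arrays to disk: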
# from lib.image_utils import get_mask, encode_mask, get_stacked_image, pad_crop_to_size, crop_border
# counter = 0
# shapes = {}
# for prefix in tqdm(path_prefixes):
#     # Load mask and stacked image
#     mask = get_mask(prefix)
#     image = get_stacked_image(prefix)
#     # Remove border of zeros, and use 0-padding
#     # to round all physical dimensions up to the next
#     # multiple of 16
#     image, mask = crop_border(image, mask)
#     image, mask = pad_crop_to_size(image, mask)
#     # Flag any sample whose shape exceeds (160,192,160)
#     shape = image.shape
#     if (shape[1] > 160) or (shape[2] > 192) or (shape[3] > 160):
#         print(f'Image {counter} shape: {shape}')
#         shapes[counter] = shape
#     # Each mask has entries in {0,1,2,3}. For a mask with
#     # shape (H,W,D), encode to a binary mask of shape (3,H,W,D)
#     # with channels corresponding to labels ET, TC, WT
#     mask = encode_mask(mask)
#     np.save(f'masks/mask_{counter}.npy', mask)
#     np.save(f'images/image_{counter}.npy', image)
#     counter += 1
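The helpers crop_border and pad_crop_to_size come from lib.image_utils and are not shown here. To illustrate the idea described in the comments of the loop above (remove the border of zeros, then zero-pad each spatial dimension up to the next multiple of 16), here is a minimal sketch under stated assumptions: the image has shape (C,H,W,D), the mask has shape (H,W,D), the image has at least one nonzero voxel, and padding is split evenly between both sides. The function names crop_to_nonzero and pad_to_multiple are hypothetical, and the real helpers may behave differently (for example, pad_crop_to_size may also crop to a target size).

```python
import numpy as np

def crop_to_nonzero(image, mask):
    """Crop image (C,H,W,D) and mask (H,W,D) to the bounding box of nonzero voxels."""
    nonzero = np.any(image != 0, axis=0)       # (H,W,D): nonzero in any channel
    coords = np.argwhere(nonzero)              # assumes at least one nonzero voxel
    lo = coords.min(axis=0)
    hi = coords.max(axis=0) + 1
    box = tuple(slice(a, b) for a, b in zip(lo, hi))
    return image[(slice(None),) + box], mask[box]

def pad_to_multiple(image, mask, multiple=16):
    """Zero-pad each spatial dimension up to the next multiple of `multiple`."""
    pads = []
    for size in image.shape[1:]:
        total = (-size) % multiple             # amount of padding needed
        pads.append((total // 2, total - total // 2))
    image = np.pad(image, [(0, 0)] + pads)     # default mode pads with zeros
    mask = np.pad(mask, pads)
    return image, mask

# Toy example: a nonzero block of size (20,33,13) inside a larger zero volume
demo_image = np.zeros((3, 40, 50, 60), dtype=np.float32)
demo_image[:, 5:25, 10:43, 7:20] = 1.0
demo_mask = np.zeros((40, 50, 60), dtype=np.uint8)
cropped_image, cropped_mask = crop_to_nonzero(demo_image, demo_mask)
padded_image, padded_mask = pad_to_multiple(cropped_image, cropped_mask)
```

Rounding every spatial dimension to a multiple of 16 keeps the volume divisible through four successive factor-of-2 downsampling steps, which is the usual motivation for this choice in UNet-style architectures.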
Image 339 has a reasonable shape, but visual inspection reveals that image 340 has erroneous nonzero voxels near the edges. We manually crop that image down to shape (3,144,160,144).
# image = np.load('images/image_340.npy')
# mask = np.load('masks/mask_340.npy')
# image = image[:,46:190,47:207,:]
# mask = mask[:,46:190,47:207,:]
# np.save('images/image_340.npy',image)
# np.save('masks/mask_340.npy',mask)
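The slice bounds in the manual crop determine the resulting shape directly: 190 - 46 = 144 and 207 - 47 = 160, while the last axis is left untouched. The quick check below applies the same slicing to a dummy array. The pre-crop shape (3,192,208,144) used here is an assumption made for illustration only, since the actual shape of image 340 before cropping is not stated in this notebook.

```python
import numpy as np

# Dummy stand-in for image 340; the pre-crop shape is a hypothetical choice
dummy_image = np.zeros((3, 192, 208, 144), dtype=np.uint8)

# Same slicing as in the manual crop above
dummy_cropped = dummy_image[:, 46:190, 47:207, :]
print(dummy_cropped.shape)
```

Each cropped spatial dimension is still a multiple of 16, so the cropped sample remains compatible with the padding convention used for the other samples.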
Now that we've saved the images and masks as NumPy arrays, we visually inspect some samples.
import matplotlib.animation as animation
from lib.plot import plot_sample, animate_sample
from lib.image_utils import decode_mask
The function plot_sample produces a figure consisting of subplots of T1-CE, T2, FLAIR, and the segmentation mask at a specified axial slice [:,:,nslice].
Note that we decode our masks before plotting, i.e. the color-coding in the segmentation mask plots represents the original segmentation labels 0 (unlabelled), 1 (NCR/NET), 2 (ED), 3 (ET).
set_seed()
img_list = os.listdir('images')
idx = np.random.choice(len(img_list))
image = np.load(f'images/image_{idx}.npy')
mask = np.load(f'masks/mask_{idx}.npy')
mask = decode_mask(mask)
nslice = np.random.randint(0,mask.shape[2])
plot_sample(image,mask,nslice=nslice)
Slice 92 out of 160:
The function animate_sample can be used to produce a figure consisting of subplots of T1-CE, T2, FLAIR, and segmentation mask which animates over the entire range of axial slices.
img_list = os.listdir('images')
idx = np.random.choice(len(img_list))
image = np.load(f'images/image_{idx}.npy')
mask = np.load(f'masks/mask_{idx}.npy')
mask = decode_mask(mask)
ani = animate_sample(image,mask)
HTML(ani.to_jshtml())