Tutorial#

Setting up the environment#

[37]:
import os

# Number GPUs by PCI bus ID and expose only device 1 to this process (see issue #152).
# These variables must be set before CUDA is initialised, i.e. before torch touches the GPU.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})
[38]:
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

from skimage.io import imread, imshow
import seaborn as sns
import cv2
from glob import glob
import pandas as pd
import json
import os
import numpy as np
import os
import torch
import numpy as np
from ELD.utils import (toImg, preprocess, predict_landmarks, create_target_landmarks,
                       create_target_images, download_images_urls, downscale_images, plot_images,
                       mask_background, padImg, crop_non_tissue, downsize_and_save,
                       rescale_landmarks, pad_image_and_adjust_landmarks, corr,plot_warped_images)
from ELD.model import loadFan, crop, toGrey
from ELD.warp import Homo, Rigid, TPS

Downloading data#

[39]:
# Eleven sample section images (file names suggest H&E stains) hosted on one web server.
# Replicates 1-4 have no suffix; 5 and 7-12 carry "_MOB".
# NOTE(review): replicate 6 is absent -- TODO confirm this is intentional.
# The order matters: later cells refer to images by their position in this list.
_HE_BASE_URL = "https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/"
_HE_FILES = (
    [f"HE_Rep{k}.jpg" for k in (1, 2, 3, 4)]
    + [f"HE_Rep{k}_MOB.jpg" for k in (5, 7, 8, 9, 10, 11, 12)]
)
urlList = [_HE_BASE_URL + file_name for file_name in _HE_FILES]
imageList = download_images_urls(urlList)

Downscale the data#

[40]:
# Replace the downloaded images with downscaled copies (ELD.utils.downscale_images).
# NOTE: this cell reassigns its own input, so running it twice downscales twice.
imageList = downscale_images(imageList)

Make sure all images have the same orientation#

[41]:
# Bring every section into a common orientation. The index lists refer to
# positions in `imageList`, i.e. the download order of the URLs above.
#
# Axes are given explicitly. np.flip(image) with no axis reverses *every* axis,
# which for an (H, W, C) colour image also reverses the channel order. That
# reversal would hit only the images listed here, leaving the dataset with mixed
# channel orders. np.rot90 already treats axes 0 and 1 as the spatial axes, so
# flipping exactly those is the intended 180-degree turn. For 2-D (grayscale)
# images this is identical to the previous call.
#
# NOTE: the loop writes back into imageList, so re-running this cell re-orients
# the already-oriented images.
for i, image in enumerate(imageList):
    if i in [0, 8, 9, 10]:
        image = np.flip(image, axis=(0, 1))

    if i in [1, 3, 4]:
        image = np.rot90(image, 1)
        image = np.flip(image, axis=(0, 1))

    imageList[i] = image
[42]:
# Visual check: all sections should now share the same orientation.
plot_images(imageList)
../_images/notebooks_tutorial_10_0.png

Mask background#

[43]:
# Mask out the background around the tissue (ELD.utils.mask_background).
# NOTE: reassigns its own input -- re-running applies the masking again.
imageList = mask_background(imageList)
[44]:
# Visual check of the background-masked images.
plot_images(imageList)
../_images/notebooks_tutorial_13_0.png

Crop non-tissue regions#

[45]:
# Crop away regions without tissue (ELD.utils.crop_non_tissue).
# NOTE: reassigns its own input -- re-running crops the already-cropped images.
imageList = crop_non_tissue(imageList)
[46]:
# Visual check of the cropped images.
plot_images(imageList)
../_images/notebooks_tutorial_16_0.png

Resize images to 128x128 and save them for training#

[47]:
# Directory the 128x128 training images are saved to. It must match the
# --data_path passed to eld-train below. Set ELD_DATA_DIR to use your own
# location instead of the author's absolute path.
DATA_DIR = os.environ.get("ELD_DATA_DIR", "/data/ekvall/tutorial/")
small_imgs = downsize_and_save(imageList, DATA_DIR)

[48]:
# Visual check of the downsized training images.
plot_images(small_imgs)
../_images/notebooks_tutorial_19_0.png

Train model#

python -m visdom.server -port 9006

eld-train --elastic_sigma 5 --cuda 1 --port 9006 --data_path /data/ekvall/tutorial/ --npts 14 --o scratch --step_size 5 --ws 0 --gamma 0.9 --angle 8 --model unimodal

Preprocess images into torch tensors#

[49]:
# Preprocess each 128x128 image and stack them into one batch tensor
# (new leading batch dimension) for the landmark network.
image = torch.stack(list(map(preprocess, small_imgs)))

Load model and predict landmarks#

[50]:
# Checkpoint of a 14-landmark, 3-channel model, matching --npts 14 in the
# training command above.
# NOTE(review): the default path is a different directory than the `--o scratch`
# used by that command (presumably its output directory). Set ELD_MODEL_PATH to
# load the checkpoint you trained yourself.
MODEL_PATH = os.environ.get("ELD_MODEL_PATH", "../Exp_2/model_158.fan.pth")
fan = loadFan(npoints=14, n_channels=3, path_to_model=MODEL_PATH)
# Predict landmarks for every image in the 128x128 batch.
pts = predict_landmarks(fan, image)

Show landmarks#

[51]:
# Overlay the predicted landmarks on the 128x128 network inputs (first three
# channels) and show the results in a 4-column grid.
np_img = toImg(image.cuda()[:,:3], pts, 128)

n_cols = 4
# Enough rows for every image (3 rows for the 11 tutorial images), and at least
# one so plt.subplots always receives a valid grid.
n_rows = max(1, -(-len(np_img) // n_cols))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
axs = axs.ravel()

for i, ax in enumerate(axs):
    # Hide ticks and frame on every panel, including unused trailing slots,
    # which would otherwise show up as empty framed axes.
    ax.axis('off')
    if i < len(np_img):
        ax.imshow(np_img[i])
        ax.set_title(f"Image {i+1}")

plt.tight_layout()
plt.show()
../_images/notebooks_tutorial_27_0.png

Scale landmarks back to the original images#

[52]:


# Map the landmarks predicted on the 128x128 inputs back to the coordinate frame
# of the corresponding images in imageList (ELD.utils.rescale_landmarks).
scaled_pts = rescale_landmarks(pts, imageList)

Zero pad all images so they have the same size#

[53]:
# Zero-pad all images to a common size and return them as one torch tensor,
# together with landmarks adjusted for the padding (presumably shifted by the
# same per-image offsets -- see ELD.utils.pad_image_and_adjust_landmarks).
padded_images_torch, adjusted_landmarks = pad_image_and_adjust_landmarks(imageList, scaled_pts)

Plot the full-size (padded) images with their landmarks#

[54]:
# Overlay the rescaled landmarks on the padded full-size images (first three
# channels) and show the results in a 4-column grid.
# NOTE(review): 5 * 128 is a hard-coded display size -- confirm it matches the
# padded image size (the warps below use 863).
np_img = toImg(padded_images_torch.cuda()[:,:3], adjusted_landmarks, 5 * 128)

n_cols = 4
# Enough rows for every image (3 rows for the 11 tutorial images), and at least
# one so plt.subplots always receives a valid grid.
n_rows = max(1, -(-len(np_img) // n_cols))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
axs = axs.ravel()

for i, ax in enumerate(axs):
    # Hide ticks and frame on every panel, including unused trailing slots,
    # which would otherwise show up as empty framed axes.
    ax.axis('off')
    if i < len(np_img):
        ax.imshow(np_img[i])
        ax.set_title(f"Image {i+1}")

plt.tight_layout()
plt.show()
../_images/notebooks_tutorial_33_0.png

Create destination image and landmarks for registration#

[55]:
# From here on `image` and `pts` hold the padded full-size batch and its adjusted
# landmarks. Earlier they held the 128x128 network inputs and their predictions.
image = padded_images_torch
# Registration target built from index 0 -- presumably the first image is the
# reference every other image is warped onto; TODO confirm in ELD.utils.
dst_image = create_target_images(image, 0)

pts = adjusted_landmarks
dst_pts = create_target_landmarks(pts, 0)

Homography#

[56]:
# Homography warper from ELD.warp.
homo_transform = Homo()
[57]:
# Warp every image towards the target using the landmark correspondences
# pts -> dst_pts (presumably a homography fitted per image).
# NOTE(review): size=863 is a hard-coded output size -- presumably the padded
# image side; confirm against padded_images_torch.shape.
mapped_imgs = homo_transform.warp_img(image.cuda(), pts, dst_pts, size=863)
# Warp landmarks: the third argument is the set of points to map with the same transform.
mapped_pts = homo_transform.warp_pts(pts, dst_pts, pts)
[58]:
# Per-image correlation between warped images and the target. [1:] drops the
# first entry (presumably the reference image compared with itself).
homo_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[59]:
# Show the homography-warped images with their mapped landmarks and per-image scores.
# NOTE(review): 5 * 128 is a hard-coded display size; the warp itself used 863.
plot_warped_images(mapped_imgs, mapped_pts, homo_loss, 5 * 128, 'Homography')
../_images/notebooks_tutorial_40_0.png

Rigid transformation#

[60]:
# Rigid warper from ELD.warp.
rigid_transform = Rigid()
[61]:
# Warp every image towards the target with a rigid transform fitted to pts -> dst_pts.
# NOTE(review): the output size is passed positionally as (863, 863) here, while
# Homo/TPS take size=863 -- confirm both mean the same canvas.
mapped_imgs = rigid_transform.warp_img(image.cuda(), pts, dst_pts, (863, 863))
# Warp landmarks: the third argument is the set of points to map with the same transform.
mapped_pts = rigid_transform.warp_pts(pts, dst_pts, pts)
[62]:
# Per-image correlation with the target. [1:] drops the first entry (presumably
# the reference image itself). Alternative not used here: score only the
# overlapping region via corr(*crop(mapped_imgs, dst_image.cuda())).
rigid_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[63]:
# Show the rigidly warped images with their mapped landmarks and per-image scores.
plot_warped_images(mapped_imgs, mapped_pts, rigid_loss, 5 * 128, 'Rigid')
../_images/notebooks_tutorial_45_0.png

Affine transform#

[64]:
# One TPS warper is reused below for both the "affine" run (reg=1e20) and the
# thin-plate-spline run (reg=0).
tps_transform = TPS()
[65]:
# "Affine" registration: TPS with a huge regularisation weight (reg=1e20), which
# presumably suppresses the non-linear bending part and leaves an affine map.
# The same reg must be used for images and landmarks so both get the same transform.
mapped_imgs = tps_transform.warp_img(image.cuda(), pts, dst_pts, reg=1e20, norm=True, size=863)
mapped_pts = tps_transform.warp_pts(pts, dst_pts, pts, reg=1e20)
[66]:
# Per-image correlation with the target. [1:] drops the first entry (presumably
# the reference image itself). Alternative not used here: score only the
# overlapping region via corr(*crop(mapped_imgs, dst_image.cuda())).
affine_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[67]:
# Show the affine-warped images with their mapped landmarks and per-image scores.
plot_warped_images(mapped_imgs, mapped_pts, affine_loss, 5 * 128, 'Affine')
../_images/notebooks_tutorial_50_0.png

Thin-plate splines#

[68]:
# Thin-plate-spline registration with no regularisation (reg=0), presumably
# mapping each landmark exactly onto its target position.
mapped_imgs = tps_transform.warp_img(image.cuda(), pts, dst_pts, reg=0, norm=True, size=863)
mapped_pts = tps_transform.warp_pts(pts, dst_pts, pts, reg=0)
[69]:
# Per-image correlation with the target. [1:] drops the first entry (presumably
# the reference image itself). Alternative not used here: score only the
# overlapping region via corr(*crop(mapped_imgs, dst_image.cuda())).
tps_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[70]:
# Show the TPS-warped images with their mapped landmarks and per-image scores.
plot_warped_images(mapped_imgs, mapped_pts, tps_loss, 5 * 128, 'TPS')
../_images/notebooks_tutorial_54_0.png

Comparison of registration methods#

[71]:
# Convert each method's per-image scores to plain Python lists for the table below.
# np.asarray(...) accepts both the NumPy arrays produced by the scoring cells and
# lists, so this cell can be re-run without failing (a bare .tolist() raised
# AttributeError on the second run because the names already held lists).
rigid_loss, affine_loss, tps_loss, homo_loss = (
    np.asarray(scores).tolist()
    for scores in (rigid_loss, affine_loss, tps_loss, homo_loss)
)
[72]:
# Per-image correlation scores of each registration method. Dict order fixes both
# the category order on the x-axis and the colour assignment.
method_scores = {
    'Rigid': rigid_loss,
    'Affine': affine_loss,
    'TPS': tps_loss,
    'Homography': homo_loss,
}

# Long-format table: one row per (method, registered image). The comprehensions
# work whether the scores are lists or NumPy arrays. With arrays, `+` would add
# elementwise instead of concatenating.
losses = [score for scores in method_scores.values() for score in scores]
methods = [name for name, scores in method_scores.items() for _ in scores]
data = pd.DataFrame({'Method': methods, 'Loss': losses})

# Map each method to a colour in a fixed order. Iterating over set(methods)
# depends on string hashing, so colours could change between kernel sessions.
method_order = list(method_scores)
custom_palette = sns.color_palette("husl", n_colors=len(method_order))
method_color_dict = dict(zip(method_order, custom_palette))

fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(data=data, x='Method', y='Loss', hue='Method',
                palette=method_color_dict, alpha=0.6, ax=ax)

ax.set_title('Comparison of Correlation by Method')
ax.set_xlabel('Method')
# The plotted values come from corr(...): higher means better alignment.
ax.set_ylabel('Correlation with target image')
ax.grid(True)
ax.legend(title='Method', loc='lower left')

plt.show()
../_images/notebooks_tutorial_57_0.png