Tutorial#
Setting up environment#
[37]:
import os

# Must run before torch initialises CUDA. Number GPUs by PCI bus ID (the order nvidia-smi
# shows) instead of CUDA's default fastest-first order, so the index below names a stable
# physical card (see issue #152).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    # Expose only GPU 1 to this kernel; inside torch it is then addressed as cuda:0.
    "CUDA_VISIBLE_DEVICES": "1",
})
[38]:
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from skimage.io import imread, imshow
import seaborn as sns
import cv2
from glob import glob
import pandas as pd
import json
import os
import numpy as np
import os
import torch
import numpy as np
from ELD.utils import (toImg, preprocess, predict_landmarks, create_target_landmarks,
create_target_images, download_images_urls, downscale_images, plot_images,
mask_background, padImg, crop_non_tissue, downsize_and_save,
rescale_landmarks, pad_image_and_adjust_landmarks, corr,plot_warped_images)
from ELD.model import loadFan, crop, toGrey
from ELD.warp import Homo, Rigid, TPS
Downloading data#
[39]:
# Source images for the tutorial (HE_Rep*.jpg / HE_Rep*_MOB.jpg on the same host).
# Order matters: later cells refer to images by their position in this list (indices 0-10).
urlList = [
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep1.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep2.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep3.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep4.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep5_MOB.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep7_MOB.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep8_MOB.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep9_MOB.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep10_MOB.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep11_MOB.jpg",
"https://9b0ce2.p3cdn1.secureserver.net/wp-content/uploads/2016/07/HE_Rep12_MOB.jpg"
]
# Fetch every URL (network access required); result is one image per URL, in list order
# (presumably numpy arrays — confirm against ELD.utils.download_images_urls).
imageList = download_images_urls(urlList)
Downscale the data#
[40]:
imageList = downscale_images(imageList)  # rebinds imageList: re-running this cell downscales the already-downscaled images again
Make sure all images have the same orientation#
[41]:
# Bring every image into a common orientation. Keys are positions in `urlList`; values are
# the `k` argument of np.rot90, which rotates only within the spatial plane (axes 0 and 1).
#   k=2  : 180-degree turn (images 0, 8, 9, 10)
#   k=-1 : a quarter turn (k=1) followed by a 180-degree turn (images 1, 3, 4)
# A bare np.flip(image) must NOT be used for the 180-degree turn: with no `axis` it reverses
# *every* axis, so for an (H, W, C) colour image it also reverses the channel order
# (RGB -> BGR) and leaves the dataset with mixed colour orderings.
# NOTE(review): assumes images are (H, W) or (H, W, C); a model trained on the old,
# channel-swapped output should be retrained after this fix.
# Not idempotent: re-running the cell rotates the listed images again.
quarter_turns = {0: 2, 8: 2, 9: 2, 10: 2, 1: -1, 3: -1, 4: -1}
for i, image in enumerate(imageList):
    k = quarter_turns.get(i, 0)
    if k:
        image = np.rot90(image, k)
    imageList[i] = image
[42]:
plot_images(imageList)  # visual check that all images now share the same orientation
Mask background#
[43]:
imageList = mask_background(imageList)  # rebinds imageList: re-running applies masking again to already-masked images
[44]:
plot_images(imageList)  # inspect the result of background masking
Crop non-tissue regions#
[45]:
imageList = crop_non_tissue(imageList)  # rebinds imageList: re-running crops the already-cropped images again
[46]:
plot_images(imageList)  # inspect the cropped images
Resize images to 128x128 and save them for training#
[47]:
# Write the 128x128 training images. The directory defaults to the author's path but can be
# overridden via the ELD_DATA_DIR environment variable instead of editing the notebook;
# it must match the --data_path passed to eld-train in the "Train model" step.
data_dir = os.environ.get("ELD_DATA_DIR", "/data/ekvall/tutorial/")
small_imgs = downsize_and_save(imageList, data_dir)
[48]:
plot_images(small_imgs)  # inspect the downsized images that were written for training
Train model#
python -m visdom.server -port 9006
eld-train --elastic_sigma 5 --cuda 1 --port 9006 --data_path /data/ekvall/tutorial/ --npts 14 --o scratch --step_size 5 --ws 0 --gamma 0.9 --angle 8 --model unimodal
Preprocess images into torch tensors#
[49]:
# Turn each downsized image into a tensor and stack them along a new leading batch axis.
image = torch.stack(list(map(preprocess, small_imgs)))
Load model and predict landmarks#
[50]:
# Checkpoint and landmark count for the detector; the count must equal --npts used for eld-train.
# NOTE(review): the training command writes to `scratch` (--o), yet this loads from ../Exp_2 — confirm the path.
model_path = "../Exp_2/model_158.fan.pth"
n_landmarks = 14
fan = loadFan(npoints=n_landmarks, n_channels=3, path_to_model=model_path)

# One set of landmark coordinates per image in the batch
pts = predict_landmarks(fan, image)
Show landmarks#
[51]:
# Render the first three channels of each 128x128 network input with its predicted landmarks
# drawn on top (128 presumably the image side length toImg expects — confirm).
np_img = toImg(image.cuda()[:,:3], pts, 128)

# Grid sized from the number of images (4 per row). With 11 images the old fixed 3x4 grid left
# a 12th panel showing an empty frame with ticks, and more than 12 images raised IndexError.
n_cols = 4
n_rows = max(1, (len(np_img) + n_cols - 1) // n_cols)
# Same 15x10 inches as before for 3 rows; height scales with the row count.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
axs = axs.ravel()
for i, ax in enumerate(axs):
    if i < len(np_img):
        ax.imshow(np_img[i])
        ax.set_title(f"Image {i+1}")
    ax.axis('off')  # hide axes on every panel, including unused trailing ones
plt.tight_layout()
plt.show()
Scale landmarks back to the original image size#
[52]:
scaled_pts = rescale_landmarks(pts, imageList)  # map landmarks predicted on the 128x128 inputs back to each full-size image
Zero pad all images so they have the same size#
[53]:
padded_images_torch, adjusted_landmarks = pad_image_and_adjust_landmarks(imageList, scaled_pts)  # zero-pad to a common size and shift landmarks to match
Plot the original-size images with their landmarks#
[54]:
# Render the zero-padded full-size images with their adjusted landmarks.
# NOTE(review): 5 * 128 is passed as toImg's size argument here, while the warp cells use 863 —
# confirm which value toImg expects for the padded images.
np_img = toImg(padded_images_torch.cuda()[:,:3], adjusted_landmarks, 5 * 128)

# Grid sized from the number of images (4 per row); unused trailing panels are hidden rather
# than left as empty frames, and more than 12 images no longer raise IndexError.
n_cols = 4
n_rows = max(1, (len(np_img) + n_cols - 1) // n_cols)
# Same 15x10 inches as before for 3 rows; height scales with the row count.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
axs = axs.ravel()
for i, ax in enumerate(axs):
    if i < len(np_img):
        ax.imshow(np_img[i])
        ax.set_title(f"Image {i+1}")
    ax.axis('off')  # hide axes on every panel, including unused trailing ones
plt.tight_layout()
plt.show()
Create destination image and landmarks for registration#
[55]:
# Registration inputs: every padded image and its landmarks, plus target image/landmark sets
# built from entry `reference_idx` (presumably image 0 acts as the fixed reference).
# The correlation cells below drop entry 0 with `[1:]`, so keep this index at 0.
reference_idx = 0
image, pts = padded_images_torch, adjusted_landmarks
dst_image = create_target_images(image, reference_idx)
dst_pts = create_target_landmarks(pts, reference_idx)
Homography#
[56]:
homo_transform = Homo()  # homography warper from ELD.warp
[57]:
# Warp every padded image using the landmark correspondence pts -> dst_pts (homography).
# size=863 sets the warped output size; NOTE(review): presumably the padded image size — confirm.
mapped_imgs = homo_transform.warp_img(image.cuda(), pts, dst_pts, size=863)
# Apply the same pts -> dst_pts mapping to the landmarks themselves
mapped_pts = homo_transform.warp_pts(pts, dst_pts, pts)
[58]:
homo_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]  # per-image correlation with the target; [1:] drops entry 0 (the reference itself)
[59]:
plot_warped_images(mapped_imgs, mapped_pts, homo_loss, 5 * 128, 'Homography')  # warped images, warped landmarks and their correlation scores
Rigid transformation#
[60]:
rigid_transform = Rigid()  # rigid warper from ELD.warp
[61]:
# Warp every padded image using the landmark correspondence pts -> dst_pts (rigid model).
# Unlike Homo/TPS, Rigid takes the output size positionally as an (863, 863) tuple.
mapped_imgs = rigid_transform.warp_img(image.cuda(), pts, dst_pts, (863, 863))
# Apply the same pts -> dst_pts mapping to the landmarks themselves
mapped_pts = rigid_transform.warp_pts(pts, dst_pts, pts)
[62]:
# Per-image correlation with the target; [1:] drops entry 0 (the reference itself).
# Alternative evaluation: crop both images with `crop` first, i.e. corr(*crop(mapped_imgs, dst_image.cuda())).
rigid_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[63]:
plot_warped_images(mapped_imgs, mapped_pts, rigid_loss, 5 * 128, 'Rigid')  # warped images, warped landmarks and their correlation scores
Affine transformation#
[64]:
tps_transform = TPS()  # thin-plate-spline warper, used for both the affine and the TPS sections via `reg`
[65]:
# "Affine" registration via TPS with an enormous regularisation weight (reg=1e20); presumably the
# heavy penalty suppresses the non-affine bending part of the spline, leaving an affine fit.
mapped_imgs = tps_transform.warp_img(image.cuda(), pts, dst_pts, reg=1e20, norm=True, size=863)
mapped_pts = tps_transform.warp_pts(pts, dst_pts, pts, reg=1e20)
[66]:
# Per-image correlation with the target; [1:] drops entry 0 (the reference itself).
# Alternative evaluation: crop both images with `crop` first, i.e. corr(*crop(mapped_imgs, dst_image.cuda())).
affine_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[67]:
plot_warped_images(mapped_imgs, mapped_pts, affine_loss, 5 * 128, 'Affine')  # warped images, warped landmarks and their correlation scores
Thin-plate splines#
[68]:
# Full thin-plate-spline registration: same warper as the affine section but with reg=0
# (no regularisation, so presumably the landmarks are matched exactly).
mapped_imgs = tps_transform.warp_img(image.cuda(), pts, dst_pts, reg=0, norm=True, size=863)
mapped_pts = tps_transform.warp_pts(pts, dst_pts, pts, reg=0)
[69]:
# Per-image correlation with the target; [1:] drops entry 0 (the reference itself).
# Alternative evaluation: crop both images with `crop` first, i.e. corr(*crop(mapped_imgs, dst_image.cuda())).
tps_loss = corr(mapped_imgs, dst_image.cuda()).cpu().numpy()[1:]
[70]:
plot_warped_images(mapped_imgs, mapped_pts, tps_loss, 5 * 128, 'TPS')  # warped images, warped landmarks and their correlation scores
Comparison of registration methods#
[71]:
# Convert each method's per-image scores to plain Python lists. Going through np.asarray keeps
# this cell safe to re-run: calling .tolist() directly fails once the names already hold lists.
rigid_loss, affine_loss, tps_loss, homo_loss = [
    np.asarray(scores).tolist() for scores in (rigid_loss, affine_loss, tps_loss, homo_loss)
]
[72]:
# Compare registration methods by the per-image corr() scores computed above
# (entry 0, the reference itself, was already dropped with [1:]).
method_scores = {
    'Rigid': rigid_loss,
    'Affine': affine_loss,
    'TPS': tps_loss,
    'Homography': homo_loss,
}

# Long format: one row per (method, image). Works for lists or numpy arrays alike.
methods = [name for name, scores in method_scores.items() for _ in scores]
losses = [float(score) for scores in method_scores.values() for score in scores]
data = pd.DataFrame({'Method': methods, 'Loss': losses})

# One colour per method. Iterate the insertion-ordered dict rather than set(methods): set
# iteration order of strings changes between interpreter runs (hash randomisation), which
# made the method-to-colour assignment non-reproducible.
custom_palette = sns.color_palette("husl", n_colors=len(method_scores))
method_color_dict = dict(zip(method_scores, custom_palette))

fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(data=data, x='Method', y='Loss', hue='Method', palette=method_color_dict, alpha=0.6, ax=ax)
ax.set_title('Comparison of Correlation by Method')
ax.set_xlabel('Method')
# Values come from corr(); label the axis as correlation to agree with the title.
ax.set_ylabel('Correlation')
ax.grid(True)
# Legend sits inside the axes, so no extra right-hand margin is needed.
ax.legend(title='Method', loc='lower left')
plt.show()