Multimodal#

[1]:
import sys
# Add the parent directory to the module search path
# (presumably where the `utils` and `model` modules imported below live - confirm).
sys.path.append('../')
[2]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from utils import circle_size, colors, toImg, preprocess
from model import loadMultiModalFan#, predict_landmarks
from torchgeometry.contrib import spatial_soft_argmax2d
from glob import glob
import tifffile as tiff
[3]:
def load_imgs(path):
    """Load the images of both modalities found under ``path``.

    Every file matching the glob pattern ``f"{path}*"`` is assigned to
    modality 0 when its path contains ``"mod0"``, otherwise to modality 1
    when it contains ``"mod1"``; all other files are ignored. Files are
    ordered with a numeric-aware sort so that the result (and therefore the
    pairing between the two modalities) does not depend on the arbitrary
    order in which ``glob`` lists the directory.

    Args:
        path: Glob prefix, normally a directory ending with a path
            separator. The substring test is applied to the whole matched
            path, not only to the file name.

    Returns:
        A tuple ``(mod0, mod1)`` of lists of loaded images. Files ending in
        ``.tif`` are read with ``tifffile``; all other files are read with
        OpenCV and converted from BGR to RGB.

    Raises:
        FileNotFoundError: if no file is found for one of the modalities.
        OSError: if OpenCV cannot decode a file.
    """
    import re  # local import: only needed for the natural-sort key below

    def natural_key(name):
        # re.split with a capturing group alternates text / digit-run tokens,
        # so odd positions are always digit runs; compare those as integers
        # ("img10" then sorts after "img9").
        tokens = re.split(r"([0-9]+)", name)
        return [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]

    def read_image(fname):
        # The reader is chosen per file, so a folder mixing formats still loads.
        if fname.endswith('.tif'):
            return tiff.imread(fname)
        bgr = cv2.imread(fname)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"could not read image file: {fname}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # get files with glob, sorted by number
    files = sorted(glob(f"{path}*"), key=natural_key)

    mod0 = [f for f in files if "mod0" in f]
    mod1 = [f for f in files if "mod0" not in f and "mod1" in f]

    # Fail early with a readable message instead of an IndexError further down.
    for tag, names in (("mod0", mod0), ("mod1", mod1)):
        if not names:
            raise FileNotFoundError(f"no '{tag}' files found matching '{path}*'")

    # load images
    return [read_image(f) for f in mod0], [read_image(f) for f in mod1]

[4]:
#You can download the dataset from here: https://figshare.com/projects/ELD/167318
[5]:
# Root folder of the downloaded dataset; set this to your local copy.
# Keep the trailing slash: sub-folder names are appended to it directly.
PATH = "../../marcoAnalysis/"

Developing Heart#

[6]:
# Folder with the developing-heart images; the trailing slash is required
# because load_imgs globs on f"{path}*".
inpath = f'{PATH}dev_heart_mixed/'
[7]:
# Read both modalities, then turn each list of images into one batch tensor
# (preprocess is applied per image before stacking).
mod0, mod1 = load_imgs(inpath)
mod0 = torch.stack(list(map(preprocess, mod0)))
mod1 = torch.stack(list(map(preprocess, mod1)))

You can either train a model with:

python train.py --elastic_sigma 3.5 --cuda 0 --port 9100 --data_path ../marcoAnalysis/CODA_prostate/ --npts 20 --o scratch --ws 10_000 --angle 10 --model 3d

Or download a test model at: https://figshare.com/projects/ELD/167318

[8]:
# Load the pretrained multimodal landmark model for the developing-heart data.
fan = loadMultiModalFan(
    npoints=20,
    n_channels=3,
    path_to_model="../models/multimodal/dev_heart/model_40.fan.pth",
)
[9]:
def make_mask(img, is_true=True):
    """Build a per-sample boolean mask for a batch.

    Args:
        img: Object whose ``shape[0]`` gives the number of mask entries.
        is_true: When truthy every entry is ``True``, otherwise ``False``.

    Returns:
        A 1-D ``torch.bool`` tensor of length ``img.shape[0]``.
    """
    print(is_true)
    # Pick the constructor once; both variants share size and dtype.
    fill = torch.ones if is_true else torch.zeros
    return fill(img.shape[0], dtype=torch.bool)



def predict_landmarks(fan, image, ismod0=True, device="cuda"):
    """Predict landmark coordinates for a batch of images.

    The model is switched to eval mode (and left in it), then called without
    gradient tracking as ``fan(image, mask)``, where ``mask`` comes from
    ``make_mask(image, ismod0)``. The first element of the model output is
    passed to ``spatial_soft_argmax2d`` and the result is multiplied by 4.

    Args:
        fan: Model to run; must provide ``eval()`` and be callable.
        image: Batch tensor; ``image.shape[0]`` sets the mask length.
        ismod0: Forwarded to ``make_mask``: ``True`` gives an all-True mask,
            ``False`` an all-False one.
        device: Device the inputs are moved to before the forward pass.
            Defaults to ``"cuda"``, matching the previous hard-coded
            behavior; pass e.g. ``"cpu"`` when the model lives elsewhere.

    Returns:
        Tensor of landmark coordinates, on ``device``.
    """
    fan.eval()
    # A single no_grad context is enough; the original nested a second one.
    with torch.no_grad():
        batch = image.to(device)
        mask = make_mask(image, ismod0).to(device)
        heatmaps = fan(batch, mask)[0]
        # False is passed positionally, as in the original call; presumably it
        # disables normalized coordinates - confirm against torchgeometry.
        # The factor 4 presumably maps heatmap to input resolution - confirm.
        pts = 4 * spatial_soft_argmax2d(heatmaps, False)
    return pts
[10]:
# Predict landmarks for each modality; the flag selects the mask used by the model.
pts_mod0, pts_mod1 = (
    predict_landmarks(fan, batch, ismod0=flag)
    for batch, flag in ((mod0, True), (mod1, False))
)
# Combine landmarks and image (first three channels only).
np_mod0, np_mod1 = (
    toImg(batch.cuda()[:, :3], pts, 128)
    for batch, pts in ((mod0, pts_mod0), (mod1, pts_mod1))
)

True
False
[11]:
# Stack both modalities along the first axis: mod0 images first, then mod1.
images = np.concatenate((np_mod0, np_mod1), axis=0)
[12]:
# Show the first four images on a 2x2 grid, filled row by row.
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
ax[0][0].imshow(images[0])
ax[0][1].imshow(images[1])
ax[1][0].imshow(images[2])
ax[1][1].imshow(images[3])

[12]:
<matplotlib.image.AxesImage at 0x7f31c5163d68>
../_images/notebooks_multimodal_14_1.png

SMA#

[13]:
# Folder with the SMA multi-modal expression images; the trailing slash is
# required because load_imgs globs on f"{path}*".
inpath = f'{PATH}SMA_multi_modal_expression/'
[14]:
# Read both modalities, then turn each list of images into one batch tensor
# (preprocess is applied per image before stacking).
mod0, mod1 = load_imgs(inpath)
mod0 = torch.stack(list(map(preprocess, mod0)))
mod1 = torch.stack(list(map(preprocess, mod1)))

[15]:
# List the contents of a local experiment folder.
# NOTE(review): the next cell loads its model from ../models/multimodal/sma/,
# not from this folder - presumably a leftover exploration cell; confirm.
!ls ../Exp_24
args__0.pkl   args__2.pkl  args__9.pkl       model_13.fan.pth  model_6.fan.pth
args__10.pkl  args__3.pkl  args_.txt         model_14.fan.pth  model_7.fan.pth
args__11.pkl  args__4.pkl  code              model_1.fan.pth   model_8.fan.pth
args__12.pkl  args__5.pkl  model_0.fan.pth   model_2.fan.pth   model_9.fan.pth
args__13.pkl  args__6.pkl  model_10.fan.pth  model_3.fan.pth
args__14.pkl  args__7.pkl  model_11.fan.pth  model_4.fan.pth
args__1.pkl   args__8.pkl  model_12.fan.pth  model_5.fan.pth
[16]:
# Load the pretrained multimodal landmark model for the SMA data (10 landmarks).
fan = loadMultiModalFan(
    npoints=10,
    n_channels=3,
    path_to_model="../models/multimodal/sma/model_10.fan.pth",
)
[17]:
# Predict landmarks for each modality; the flag selects the mask used by the model.
pts_mod0, pts_mod1 = (
    predict_landmarks(fan, batch, ismod0=flag)
    for batch, flag in ((mod0, True), (mod1, False))
)
# Combine landmarks and image (first three channels only).
np_mod0, np_mod1 = (
    toImg(batch.cuda()[:, :3], pts, 128)
    for batch, pts in ((mod0, pts_mod0), (mod1, pts_mod1))
)

True
False
[18]:
# Stack both modalities along the first axis: mod0 images first, then mod1.
images = np.concatenate((np_mod0, np_mod1), axis=0)
[19]:
# Show six images on a 2x3 grid.
# NOTE(review): the bottom row shows images[5] before images[4]; this looks
# like a manual fix for the file order returned by glob - confirm it is intended.
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10, 10))
ax[0][0].imshow(images[0])
ax[0][1].imshow(images[1])
ax[0][2].imshow(images[2])
ax[1][0].imshow(images[3])
ax[1][1].imshow(images[5])
ax[1][2].imshow(images[4])

[19]:
<matplotlib.image.AxesImage at 0x7f374bedbfd0>
../_images/notebooks_multimodal_22_1.png