why cycle consistency is not robust #9

Open · JunMa11 opened this issue Jan 24, 2023 · 5 comments

JunMa11 commented Jan 24, 2023

Dear @suxuann,

Thanks for sharing the awesome work.

I tried DDIB for modality transfer: CT images to MR images.

I trained a CT model and an MR model separately on my own dataset, based on guided-diffusion, and verified that both can generate good samples.

Then I tried cycle consistency and CT->MR modality transfer. However, cycle consistency is not robust, and the transferred MR images have very different structures from the source CT images.

Here are some examples:

[Attached example figures: ct_0, ct_1, ct_2, ct_3]

What could be the possible reason? Any comments are highly appreciated.

JunMa11 (Author) commented Jan 24, 2023

I'm also attaching the code:

import argparse
import numpy as np
import os
join = os.path.join
import pathlib
import torch.distributed as dist
from skimage import io, color
import torch
from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    create_model_and_diffusion,
    args_to_dict
)
import matplotlib.pyplot as plt

def create_argparser():
    defaults = dict(
        image_size=256,
        batch_size=1,
        num_channels=64,
        num_res_blocks=3,
        num_heads=1,
        diffusion_steps=1000,
        noise_schedule='linear',
        lr=1e-4,
        clip_denoised=False,
        num_samples=1, # 10000
        use_ddim=True,
        # timestep_respacing='ddim250',
        model_path="",
    )
    ori = model_and_diffusion_defaults()
    # defaults.update(model_and_diffusion_defaults())
    ori.update(defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, ori)
    return parser

# Parse command-line arguments (the script runs at module level, no main()).
args = create_argparser().parse_args()

logger.log(f"args: {args}")

dist_util.setup_dist()
logger.configure(dir='./log')

code_folder = './'
# data_folder = './datasets' # get_code_and_dataset_folders()


#%% load model
def read_model_and_diffusion(args, model_path):
    """Builds the model and diffusion from args and loads weights from model_path."""

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )
    model.load_state_dict(dist_util.load_state_dict(model_path, map_location="cuda"))
    model.to(dist_util.dev())
    # if args.use_fp16:
    #     model.convert_to_fp16()
    model.eval()
    return model, diffusion

ct_model_path = './work_dir/abdomenCT256/ema_0.9999_480000.pt'
s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)
mr_model_path = './work_dir/abdomenMR256/ema_0.9999_480000.pt'
t_model, t_diffusion = read_model_and_diffusion(args, mr_model_path)
save_path = './log'
#%% translate image
s_img_path = './demo-img'
names = sorted(os.listdir(s_img_path))
# names = ['ct_ori.png']
def sample2img(sample):
    """Maps a sample tensor in [-1, 1] (NCHW) to a uint8 HWC image array."""
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()[0]
    
    return sample

for name in names:
    ct_data = io.imread(join(s_img_path, name))

    s_np = ct_data / np.max(ct_data)   # scale to [0, 1]
    s_np = (s_np - 0.5) * 2.0          # then to [-1, 1], the model's input range
    # s_np = np.repeat(np.expand_dims(s_np, -1), 3, -1)
    assert s_np.shape == (256, 256, 3), f'shape error! Current shape: {s_np.shape}'
    s_np = np.expand_dims(s_np, 0)
    
    source = torch.from_numpy(s_np.astype(np.float32)).permute(0,3,1,2).to('cuda')
    # print(f"{source.shape=}")
    # Encode the source image to a latent with the source model's reverse DDIM ODE.
    noise = s_diffusion.ddim_reverse_sample_loop(
        s_model, source,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    # Sanity check (cycle consistency): decode the latent back with the source model.
    source_recon = s_diffusion.ddim_sample_loop(
        s_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    # Translate: decode the same latent with the target (MR) model.
    target = t_diffusion.ddim_sample_loop(
        t_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    #%% plot
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8,8))
    images = [ct_data, color.rgb2gray(sample2img(noise)), sample2img(source_recon), sample2img(target)]
    titles = ['CT image', 'CT noise encode', \
        'CT reconstruction', 'CT2MR']
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i], cmap='gray')
        ax.set_title(titles[i])
        ax.axis('off')
    plt.suptitle(name)

    plt.savefig(join(save_path, name), dpi=300)
    plt.close(fig)  # avoid accumulating open figures across the loop
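
To quantify the failure, I also log the round-trip error at the end of the loop (source and source_recon are the tensors from the loop above):

    # Mean absolute round-trip error; large values mean the DDIM
    # encode/decode cycle is not recovering the input.
    mae = (source_recon - source).abs().mean().item()
    logger.log(f"{name}: cycle-consistency MAE = {mae:.4f}")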

suxuann (Owner) commented Jan 24, 2023

Hi Jun, thanks for your interest in our work and for attempting to validate our method on CT & MR images.

DDIBs translate images via a (regularized) optimal transport process. This is both an advantage and a limitation of our method: training diffusion models on the two domains independently decouples the training process, but the resulting optimal-transport-based translation may not produce the images you desire.

You can refer to Appendix B of our paper, https://arxiv.org/pdf/2203.08382.pdf, for a detailed explanation of the phenomenon you observe. Let us know if you have additional questions!
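
Schematically, translation is just two deterministic ODE solves (a minimal sketch of Algorithm 1 in the paper; ode_solve below is an illustrative stand-in for the probability-flow / DDIM integrator, not a function in this repo):

# Minimal sketch of DDIB translation; ode_solve is an illustrative stand-in
# for the probability-flow (DDIM) ODE integrator, not an actual API.
x_latent = ode_solve(x_source, model=source_model, t_from=0.0, t_to=1.0)  # source image -> shared latent
x_target = ode_solve(x_latent, model=target_model, t_from=1.0, t_to=0.0)  # latent -> target domain

Both solves are deterministic, so the composed map defines a coupling between the two domains, but nothing in it explicitly enforces that the structures you care about are preserved.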

JunMa11 (Author) commented Jan 27, 2023

Hi @suxuann,

Thank you very much for your answer.
Now I understand the reason for the second issue (the structure change during translation).

Could you please explain the following question a bit more?

Why is cycle consistency not robust, i.e., why can the noise encoding fail to reconstruct the original image? Based on the proof, it should be robust for different images.
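
For context, here is how I probe this on my side (a sketch built on the script above; it assumes timestep_respacing is honored when the models are rebuilt, so if discretization error were the only cause, the round-trip error should shrink as the number of DDIM steps grows):

# Sketch: rebuild the source diffusion with finer DDIM respacing and measure
# the round-trip error on one image; read_model_and_diffusion, ct_model_path
# and source are from the script above.
for respacing in ["ddim50", "ddim250", "ddim1000"]:
    args.timestep_respacing = respacing
    s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)
    latent = s_diffusion.ddim_reverse_sample_loop(
        s_model, source, clip_denoised=False, device=dist_util.dev(),
    )
    recon = s_diffusion.ddim_sample_loop(
        s_model, source.shape, noise=latent, clip_denoised=False,
        device=dist_util.dev(),
    )
    print(respacing, (recon - source).abs().mean().item())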

leoil commented Mar 10, 2023

Hi @JunMa11, I'd like to ask some questions about model training.
If I want to train a new model on my own dataset, just like your ./work_dir/abdomenCT256/ema_0.9999_480000.pt,
could you please tell me how I should prepare the training script?
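
For reference, I was planning to follow guided-diffusion's standard training entry point, with flags mirroring your sampling script above and a placeholder data path; does this match what you did?

MODEL_FLAGS="--image_size 256 --num_channels 64 --num_res_blocks 3 --num_heads 1"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 1"
python scripts/image_train.py --data_dir /path/to/ct_pngs $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS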

yang1173350896 commented

Hi @JunMa11,
I tried to reconstruct the original MR as well, but my reconstruction has a color problem.
I tried normalizing the image to [0, 1], but it still can't reconstruct the original image.
Could you please tell me what the possible reason could be?
[attached image: MR reconstruction showing the color problem]
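
Here is my current normalization for reference (img_path is a placeholder for my MR slice); should the input instead be mapped to [-1, 1] as in your script?

import numpy as np
from skimage import io

img = io.imread(img_path).astype(np.float32)
img = img / img.max()   # my attempt: scale to [0, 1]
# your script maps to [-1, 1] instead:
# img = (img / img.max() - 0.5) * 2.0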
