Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
star-kwon committed Nov 25, 2021
0 parents commit 7360533
Show file tree
Hide file tree
Showing 15 changed files with 770 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/TCI_CyclefreeCycleGAN.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# TCI_Cycle-free CycleGAN

## Paper
[Cycle-free CycleGAN using Invertible Generator for Unsupervised Low-Dose CT Denoising][paper link] (IEEE TCI, T. Kwon et al.)

[paper link]: https://ieeexplore.ieee.org/document/9622180


## Sample Data
Public dataset that we used was from [Low Dose CT Grand Challenge][aapm link].

[aapm link]: https://www.aapm.org/grandchallenge/lowdosect/


## Train
You can use train.py for cycle-free CycleGAN.

For train.py,
training input & target data, test input & target data folder directory is required.


## Test
You can use inference.py for pre-trained cycle-free CycleGAN.

For inference.py,
test input & target data folder directory and pre-trained weight file directory is required.
81 changes: 81 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import argparse
import torch
from utils.dataloader import dataloader
from model import Model
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np

parser = argparse.ArgumentParser(description='Cycle-free CycleGAN')

# Invertible Neural Network Architecture
parser.add_argument('--batch_size', default=1, type=int, help='batch size')
parser.add_argument('--in_channel', default=1, type=int, help='channel number of input image')
parser.add_argument('--n_block', default=4, type=int, help='number of stable coupling layers')
parser.add_argument('--squeeze_num', default=2, type=int, help='image divide for coupling layer')
parser.add_argument('--conv_lu', default=True, help='use LU decomposed convolution instead of plain version')
parser.add_argument('--block_type', default='dense', help='simple: Simple block, dense: Dense block, residual: Residual block')

# Training details
parser.add_argument('--gpu_id', default='0', type=str, help='device_ids for training')
parser.add_argument('--total_iter', default=150001, type=int, help='total training iterations')
parser.add_argument('--visualize_iter', default=5000, type=int, help='visualize test result iterations')
parser.add_argument('--save_model_iter', default=5000, type=int, help='save model weights iterations')
parser.add_argument('--test_num', default=421, type=int, help='test set number')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--lambda_gan', default=1.0, type=float, help='lambda value for gan loss')
parser.add_argument('--lambda_l1', default=10.0, type=float, help='lambda value for L1 loss')

# Image directories
parser.add_argument('--path_SDCT_test', required=True, type=str, help='Path to target test image folder')
parser.add_argument('--path_LDCT_test', required=True, type=str, help='Path to input test image folder')

# Checkpoint directories
parser.add_argument('--path_weight', required=True, type=str, help='Path to validation weight file')


opt = parser.parse_args()

model = Model(opt)
model.netG.load_state_dict(torch.load(opt.path_weight, map_location=torch.device('cuda:{}'.format(opt.gpu_id))))

model.eval()

LDCT_test_dataset = iter(dataloader(opt.path_LDCT_test, 1, do_shuffle=False))
SDCT_test_dataset = iter(dataloader(opt.path_SDCT_test, 1, do_shuffle=False))

psnr_out = 0
ssim_out = 0

psnr_LDCT = 0
ssim_LDCT = 0
with torch.no_grad():
for num in range(opt.test_num):
LDCT_test, _ = next(LDCT_test_dataset)
SDCT_test, _ = next(SDCT_test_dataset)

model.set_input_val(LDCT_test, SDCT_test)
model.val_forward()

LDCT_numpy = np.squeeze(model.real_LDCT.cpu().numpy())
SDCT_numpy = np.squeeze(model.real_SDCT.cpu().numpy())
output_numpy = np.squeeze(model.fake_SDCT.cpu().numpy())

psnr_LDCT += psnr(SDCT_numpy, LDCT_numpy, data_range=np.amax(SDCT_numpy))
ssim_LDCT += ssim(SDCT_numpy, LDCT_numpy, data_range=np.amax(SDCT_numpy))

psnr_out += psnr(SDCT_numpy, output_numpy, data_range=np.amax(SDCT_numpy))
ssim_out += ssim(SDCT_numpy, output_numpy, data_range=np.amax(SDCT_numpy))

np.save(f'result/inference_numpy/{"output" + str(num+1).zfill(3)}', (output_numpy * 4000))

print("All test data inferenced and saved!")

mean_psnr_out = psnr_out / opt.test_num
mean_ssim_out = ssim_out / opt.test_num

mean_psnr_LDCT = psnr_LDCT / opt.test_num
mean_ssim_LDCT = ssim_LDCT / opt.test_num

print("\nMetrics on test set \t LDCT_PSNR/SSIM: %2.4f / %1.4f \t Output_PSNR/SSIM: %2.4f / %1.4f"
% (mean_psnr_LDCT, mean_ssim_LDCT, mean_psnr_out, mean_ssim_out))
96 changes: 96 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from networks.invertible_generator import InvertibleGenerator
from networks.discriminator import Discriminator_patch
import torch
from torch import nn, optim

class Model():
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda:{}'.format(opt.gpu_id))

self.model_names = ['netG', 'netD']

self.netG = InvertibleGenerator(in_channel=opt.in_channel, n_block=opt.n_block, squeeze_num=opt.squeeze_num, conv_lu=opt.conv_lu, block_type=opt.block_type)
self.netD = Discriminator_patch()

self.netG.to(self.device)
self.netD.to(self.device)

self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr)

self.scheduler_G = optim.lr_scheduler.StepLR(self.optimizer_G, step_size=50000, gamma=0.5, last_epoch=-1)
self.scheduler_D = optim.lr_scheduler.StepLR(self.optimizer_D, step_size=50000, gamma=0.5, last_epoch=-1)

self.criterion_GAN = nn.MSELoss()
self.criterion_L1 = nn.L1Loss()

def set_input(self, LDCT, SDCT):
self.real_LDCT = LDCT.to(self.device)
self.real_SDCT = SDCT.to(self.device)

def set_input_val(self, LDCT, SDCT):
self.real_LDCT = LDCT.to(self.device)
self.real_SDCT = SDCT.to(self.device)

def forward(self):
self.fake_SDCT = self.netG(self.real_LDCT)

def val_forward(self):
self.fake_SDCT = self.netG(self.real_LDCT)
self.cycle_LDCT = self.netG.inverse(self.fake_SDCT)
self.fake_LDCT = self.netG.inverse(self.real_SDCT)

def backward_G(self):
fake_SDCT_logits = self.netD(self.fake_SDCT)
self.Gan_loss = self.criterion_GAN(fake_SDCT_logits, torch.ones_like(fake_SDCT_logits))
self.L1_loss = self.criterion_L1(self.real_LDCT, self.fake_SDCT)

self.total_g_loss = self.opt.lambda_gan * self.Gan_loss + self.opt.lambda_l1 * self.L1_loss
self.total_g_loss.backward()

def backward_D(self):
pred_real = self.netD(self.real_SDCT)
D_real_loss = self.criterion_GAN(pred_real, torch.ones_like(pred_real))

pred_fake = self.netD(self.fake_SDCT.detach())
D_fake_loss = self.criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

self.total_D_loss = (D_real_loss + D_fake_loss)/2
self.total_D_loss.backward()

def set_requires_grad(self, net, requires_grad=False):
for param in net.parameters():
param.requires_grad = requires_grad

# main loop
def optimize_parameters(self):

self.forward()
self.set_requires_grad(self.netD, False)
self.optimizer_G.zero_grad()
self.backward_G()
self.optimizer_G.step()
self.scheduler_G.step()

self.set_requires_grad(self.netD, True)
self.optimizer_D.zero_grad()
self.backward_D()
self.optimizer_D.step()
self.scheduler_D.step()

def test(self):
with torch.no_grad():
self.forward()

def train(self):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
net.train()

def eval(self):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
net.eval()
47 changes: 47 additions & 0 deletions networks/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from torch import nn

class Discriminator_patch(nn.Module):
def __init__(self):
super().__init__()

self.conv1_64 = nn.Conv2d(1,64,4,2,padding=1, bias=False)
self.conv1_64.weight.data.normal_(0, 0.02)

self.conv64_128 = nn.Conv2d(64,128,4,2,padding=1)
self.conv64_128.weight.data.normal_(0, 0.02)
self.conv64_128.bias.data.zero_()

self.batchnorm128 = nn.BatchNorm2d(128)
self.batchnorm128.weight.data.normal_(1.0, 0.02)
self.batchnorm128.bias.data.zero_()

self.conv128_256 = nn.Conv2d(128, 256, 4, 1, padding=1)
self.conv128_256.weight.data.normal_(0, 0.02)
self.conv128_256.bias.data.zero_()

self.batchnorm256 = nn.BatchNorm2d(256)
self.batchnorm256.weight.data.normal_(1.0, 0.02)
self.batchnorm256.bias.data.zero_()

self.conv256_1 = nn.Conv2d(256,1,4,1,padding=1)
self.conv256_1.weight.data.normal_(0, 0.02)
self.conv256_1.bias.data.zero_()

self.leakyrelu = nn.LeakyReLU(0.2)

def forward(self, input):

x = self.conv1_64(input)
x = self.leakyrelu(x)

x = self.conv64_128(x)
x = self.batchnorm128(x)
x = self.leakyrelu(x)

x = self.conv128_256(x)
x = self.batchnorm256(x)
x = self.leakyrelu(x)

x = self.conv256_1(x)

return x
Loading

0 comments on commit 7360533

Please sign in to comment.