# TCI_Cycle-free CycleGAN

## Paper
[Cycle-free CycleGAN using Invertible Generator for Unsupervised Low-Dose CT Denoising][paper link] (IEEE TCI, T. Kwon et al.)

[paper link]: https://ieeexplore.ieee.org/document/9622180

## Sample Data
The public dataset we used is from the [Low Dose CT Grand Challenge][aapm link].

[aapm link]: https://www.aapm.org/grandchallenge/lowdosect/

## Train
You can use train.py to train the cycle-free CycleGAN.

train.py requires the directories of the training input & target data folders and the test input & target data folders.
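For example (a sketch only: train.py's argument parser is not shown in this commit, so the exact flag names for the training paths are assumptions mirroring inference.py; check `python train.py --help`):

```
python train.py \
    --path_LDCT_train ./data/train/LDCT \
    --path_SDCT_train ./data/train/SDCT \
    --path_LDCT_test ./data/test/LDCT \
    --path_SDCT_test ./data/test/SDCT
```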

## Test
You can use inference.py to run a pre-trained cycle-free CycleGAN.

inference.py requires the directories of the test input & target data folders and the path to a pre-trained weight file.
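For example (the paths are placeholders; the flag names come from the argparse options in inference.py below):

```
python inference.py \
    --path_LDCT_test ./data/test/LDCT \
    --path_SDCT_test ./data/test/SDCT \
    --path_weight ./checkpoints/weight.pth
```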
## inference.py
import argparse
import os

import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from model import Model
from utils.dataloader import dataloader

parser = argparse.ArgumentParser(description='Cycle-free CycleGAN')

# Invertible neural network architecture
parser.add_argument('--batch_size', default=1, type=int, help='batch size')
parser.add_argument('--in_channel', default=1, type=int, help='channel number of input image')
parser.add_argument('--n_block', default=4, type=int, help='number of stable coupling layers')
parser.add_argument('--squeeze_num', default=2, type=int, help='image division factor for the coupling layers')
# note: no type= here, so any value passed on the CLI is truthy; change in code if needed
parser.add_argument('--conv_lu', default=True, help='use LU-decomposed convolution instead of the plain version')
parser.add_argument('--block_type', default='dense', help='simple: Simple block, dense: Dense block, residual: Residual block')

# Training details
parser.add_argument('--gpu_id', default='0', type=str, help='device id for inference')
parser.add_argument('--total_iter', default=150001, type=int, help='total training iterations')
parser.add_argument('--visualize_iter', default=5000, type=int, help='iterations between test-result visualizations')
parser.add_argument('--save_model_iter', default=5000, type=int, help='iterations between model checkpoints')
parser.add_argument('--test_num', default=421, type=int, help='number of test images')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--lambda_gan', default=1.0, type=float, help='weight of the GAN loss')
parser.add_argument('--lambda_l1', default=10.0, type=float, help='weight of the L1 loss')

# Image directories
parser.add_argument('--path_SDCT_test', required=True, type=str, help='path to the target (standard-dose) test image folder')
parser.add_argument('--path_LDCT_test', required=True, type=str, help='path to the input (low-dose) test image folder')

# Checkpoint directory
parser.add_argument('--path_weight', required=True, type=str, help='path to the pre-trained weight file')

opt = parser.parse_args()

model = Model(opt)
model.netG.load_state_dict(torch.load(opt.path_weight, map_location=torch.device('cuda:{}'.format(opt.gpu_id))))

model.eval()

LDCT_test_dataset = iter(dataloader(opt.path_LDCT_test, 1, do_shuffle=False))
SDCT_test_dataset = iter(dataloader(opt.path_SDCT_test, 1, do_shuffle=False))

# Running sums of the metrics over the test set
psnr_out = 0
ssim_out = 0

psnr_LDCT = 0
ssim_LDCT = 0

# Make sure the output directory exists before saving results
os.makedirs('result/inference_numpy', exist_ok=True)

with torch.no_grad():
    for num in range(opt.test_num):
        LDCT_test, _ = next(LDCT_test_dataset)
        SDCT_test, _ = next(SDCT_test_dataset)

        model.set_input_val(LDCT_test, SDCT_test)
        model.val_forward()

        LDCT_numpy = np.squeeze(model.real_LDCT.cpu().numpy())
        SDCT_numpy = np.squeeze(model.real_SDCT.cpu().numpy())
        output_numpy = np.squeeze(model.fake_SDCT.cpu().numpy())

        # Baseline: raw low-dose input against the standard-dose target
        psnr_LDCT += psnr(SDCT_numpy, LDCT_numpy, data_range=np.amax(SDCT_numpy))
        ssim_LDCT += ssim(SDCT_numpy, LDCT_numpy, data_range=np.amax(SDCT_numpy))

        # Denoised output against the standard-dose target
        psnr_out += psnr(SDCT_numpy, output_numpy, data_range=np.amax(SDCT_numpy))
        ssim_out += ssim(SDCT_numpy, output_numpy, data_range=np.amax(SDCT_numpy))

        # Rescale before saving (inverse of the dataloader's normalization, assumed to be 1/4000)
        np.save(f'result/inference_numpy/output{num + 1:03d}', output_numpy * 4000)

print("All test data inferred and saved!")

mean_psnr_out = psnr_out / opt.test_num
mean_ssim_out = ssim_out / opt.test_num

mean_psnr_LDCT = psnr_LDCT / opt.test_num
mean_ssim_LDCT = ssim_LDCT / opt.test_num

print("\nMetrics on test set \t LDCT_PSNR/SSIM: %2.4f / %1.4f \t Output_PSNR/SSIM: %2.4f / %1.4f"
      % (mean_psnr_LDCT, mean_ssim_LDCT, mean_psnr_out, mean_ssim_out))
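Since the generator is invertible, cycle consistency holds by construction rather than through a cycle loss. A minimal sanity check, reusing the `model` object defined above (the tolerance is an expectation for invertible networks, not a guarantee):

```python
# The inverse generator should reconstruct the input up to floating-point error,
# which is why no explicit cycle-consistency loss is needed during training.
with torch.no_grad():
    x = model.real_LDCT                        # last test slice from the loop above
    recon = model.netG.inverse(model.netG(x))  # forward, then inverse
    print((x - recon).abs().max().item())      # expected to be near zero
```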
## model.py
import torch
from torch import nn, optim

from networks.discriminator import Discriminator_patch
from networks.invertible_generator import InvertibleGenerator


class Model():
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda:{}'.format(opt.gpu_id))

        self.model_names = ['netG', 'netD']

        # Invertible generator (LDCT -> SDCT) and PatchGAN discriminator
        self.netG = InvertibleGenerator(in_channel=opt.in_channel, n_block=opt.n_block,
                                        squeeze_num=opt.squeeze_num, conv_lu=opt.conv_lu,
                                        block_type=opt.block_type)
        self.netD = Discriminator_patch()

        self.netG.to(self.device)
        self.netD.to(self.device)

        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr)

        # Halve the learning rate every 50k iterations
        self.scheduler_G = optim.lr_scheduler.StepLR(self.optimizer_G, step_size=50000, gamma=0.5, last_epoch=-1)
        self.scheduler_D = optim.lr_scheduler.StepLR(self.optimizer_D, step_size=50000, gamma=0.5, last_epoch=-1)

        self.criterion_GAN = nn.MSELoss()  # least-squares GAN loss
        self.criterion_L1 = nn.L1Loss()

    def set_input(self, LDCT, SDCT):
        self.real_LDCT = LDCT.to(self.device)
        self.real_SDCT = SDCT.to(self.device)

    def set_input_val(self, LDCT, SDCT):
        self.real_LDCT = LDCT.to(self.device)
        self.real_SDCT = SDCT.to(self.device)

    def forward(self):
        self.fake_SDCT = self.netG(self.real_LDCT)

    def val_forward(self):
        # Forward pass plus both inverse passes; the inverse of the generator
        # plays the role of the second generator in a standard CycleGAN.
        self.fake_SDCT = self.netG(self.real_LDCT)
        self.cycle_LDCT = self.netG.inverse(self.fake_SDCT)
        self.fake_LDCT = self.netG.inverse(self.real_SDCT)

    def backward_G(self):
        fake_SDCT_logits = self.netD(self.fake_SDCT)
        self.Gan_loss = self.criterion_GAN(fake_SDCT_logits, torch.ones_like(fake_SDCT_logits))
        self.L1_loss = self.criterion_L1(self.real_LDCT, self.fake_SDCT)

        self.total_g_loss = self.opt.lambda_gan * self.Gan_loss + self.opt.lambda_l1 * self.L1_loss
        self.total_g_loss.backward()

    def backward_D(self):
        pred_real = self.netD(self.real_SDCT)
        D_real_loss = self.criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = self.netD(self.fake_SDCT.detach())
        D_fake_loss = self.criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        self.total_D_loss = (D_real_loss + D_fake_loss) / 2
        self.total_D_loss.backward()

    def set_requires_grad(self, net, requires_grad=False):
        for param in net.parameters():
            param.requires_grad = requires_grad

    # main loop
    def optimize_parameters(self):
        self.forward()

        # Generator step (discriminator frozen)
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.scheduler_G.step()

        # Discriminator step
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        self.scheduler_D.step()

    def test(self):
        with torch.no_grad():
            self.forward()

    def train(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, name)
                net.train()

    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, name)
                net.eval()
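Model bundles everything a training script needs: set_input feeds a batch and optimize_parameters runs one generator update followed by one discriminator update. Below is a minimal training-loop sketch, not the authors' train.py (which is absent from this commit); the `path_*_train` names and checkpoint path are hypothetical, and the hyperparameters copy inference.py's defaults:

```python
import os
from argparse import Namespace

import torch

from model import Model
from utils.dataloader import dataloader

# Defaults copied from inference.py; the two *_train paths are hypothetical names.
opt = Namespace(batch_size=1, in_channel=1, n_block=4, squeeze_num=2, conv_lu=True,
                block_type='dense', gpu_id='0', lr=1e-4, lambda_gan=1.0, lambda_l1=10.0,
                total_iter=150001, save_model_iter=5000,
                path_LDCT_train='./data/train/LDCT', path_SDCT_train='./data/train/SDCT')

model = Model(opt)
model.train()

LDCT_loader = dataloader(opt.path_LDCT_train, opt.batch_size, do_shuffle=True)
SDCT_loader = dataloader(opt.path_SDCT_train, opt.batch_size, do_shuffle=True)

def endless(loader):
    # Re-iterate indefinitely; a single iter() would exhaust long before total_iter.
    while True:
        for batch in loader:
            yield batch

LDCT_train, SDCT_train = endless(LDCT_loader), endless(SDCT_loader)

os.makedirs('checkpoints', exist_ok=True)
for it in range(opt.total_iter):
    LDCT, _ = next(LDCT_train)
    SDCT, _ = next(SDCT_train)
    model.set_input(LDCT, SDCT)
    model.optimize_parameters()  # G step (GAN + L1), then D step

    if (it + 1) % opt.save_model_iter == 0:
        torch.save(model.netG.state_dict(), f'checkpoints/netG_{it + 1}.pth')
```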
## networks/discriminator.py
from torch import nn


class Discriminator_patch(nn.Module):
    """PatchGAN discriminator: maps an image to a grid of per-patch logits."""

    def __init__(self):
        super().__init__()

        # DCGAN-style initialization: conv weights ~ N(0, 0.02), biases zero,
        # batch-norm weights ~ N(1.0, 0.02)
        self.conv1_64 = nn.Conv2d(1, 64, 4, 2, padding=1, bias=False)
        self.conv1_64.weight.data.normal_(0, 0.02)

        self.conv64_128 = nn.Conv2d(64, 128, 4, 2, padding=1)
        self.conv64_128.weight.data.normal_(0, 0.02)
        self.conv64_128.bias.data.zero_()

        self.batchnorm128 = nn.BatchNorm2d(128)
        self.batchnorm128.weight.data.normal_(1.0, 0.02)
        self.batchnorm128.bias.data.zero_()

        self.conv128_256 = nn.Conv2d(128, 256, 4, 1, padding=1)
        self.conv128_256.weight.data.normal_(0, 0.02)
        self.conv128_256.bias.data.zero_()

        self.batchnorm256 = nn.BatchNorm2d(256)
        self.batchnorm256.weight.data.normal_(1.0, 0.02)
        self.batchnorm256.bias.data.zero_()

        self.conv256_1 = nn.Conv2d(256, 1, 4, 1, padding=1)
        self.conv256_1.weight.data.normal_(0, 0.02)
        self.conv256_1.bias.data.zero_()

        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, input):
        x = self.conv1_64(input)
        x = self.leakyrelu(x)

        x = self.conv64_128(x)
        x = self.batchnorm128(x)
        x = self.leakyrelu(x)

        x = self.conv128_256(x)
        x = self.batchnorm256(x)
        x = self.leakyrelu(x)

        # No sigmoid here: the LSGAN loss (MSE) is applied directly to the logits
        x = self.conv256_1(x)

        return x
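The final conv layer has a single output channel, so the discriminator returns one logit per overlapping receptive-field patch rather than a single scalar per image. A quick shape check (the 512×512 slice size is an assumption about the data, not stated in this commit):

```python
import torch
from networks.discriminator import Discriminator_patch

netD = Discriminator_patch()
x = torch.randn(1, 1, 512, 512)  # dummy single-channel CT slice (size assumed)
logits = netD(x)
print(logits.shape)              # torch.Size([1, 1, 126, 126]): one logit per patch
```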