Skip to content

Commit

Permalink
The code for no4 exp
Browse files Browse the repository at this point in the history
fix PSNR
  • Loading branch information
yippp committed Feb 28, 2018
1 parent bf02c01 commit 022ef91
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 28 deletions.
12 changes: 8 additions & 4 deletions butterfly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
from dataset.dataset import load_img

img = load_img('./dataset/test/Set5/butterfly_GT.bmp')
y90 = img.resize((90, 90), resample=Image.BICUBIC)
y90.save('butterfly90.bmp')
bicubic = y90.resize((256, 256), resample=Image.BICUBIC)
bicubic.save('bicubic.bmp')
# y90 = img.resize((90, 90), resample=Image.BICUBIC)
# y90.save('butterfly90.bmp')
img244 = img.crop((6, 6, 250, 250))
img244.save('butterfly_crop244.bmp')
y86 = img.resize((86, 86), resample=Image.BICUBIC)
y86.save('butterfly86.bmp')
bicubic = y86.resize((256, 256), resample=Image.BICUBIC)
bicubic.save('bicubic86.bmp')
a=0
2 changes: 1 addition & 1 deletion dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __getitem__(self, index):
target = input_image.copy()
target = self.to_tensor(target)

self.resize = Resize((x_re, y_re))
self.resize = Resize((x_re - 4, y_re - 4))
input_image = self.resize(input_image)
input_image = self.to_tensor(input_image)

Expand Down
14 changes: 13 additions & 1 deletion loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch.nn as nn
from torch import mean, sqrt, pow

class HuberLoss(nn.Module):
def __init__(self, delta=1):
Expand All @@ -9,4 +10,15 @@ def __init__(self, delta=1):

def forward(self, input, target):
loss = self.SmoothL1Loss(input / self.delta, target / self.delta)
return loss * self.delta * self.delta
return loss * self.delta * self.delta

class CharbonnierLoss(nn.Module):
def __init__(self, delta=1e-3):
super(CharbonnierLoss, self).__init__()
# self.MSELoss = nn.MSELoss ()
self.delta = delta
return

def forward(self, input, target):
# return torch.sqrt(self.MSELoss(input, target) + self.delta * self.delta)
return mean(sqrt(pow((input - target), 2) + self.delta * self.delta))
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# hyper-parameters
parser.add_argument('--batch_size', type=int, default=128, help='trainingbatch size. Default=128')
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs to train for. Default=100')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning Rate. Default=0.001')
parser.add_argument('--lr', type=float, default=1e-2, help='Learning Rate. Default=0.001')
parser.add_argument('--mom', type=float, default=0.9, help='Momentum. Default=0.9')
parser.add_argument('--seed', type=int, default=1, help='random seed to use. Default=1')
parser.add_argument('--train_set', type=str, default='train/91-aug.h5', help='name of train set h5 file.')
parser.add_argument('--logs', type=str, default='./logs/no4/huber0.9',
parser.add_argument('--logs', type=str, default='./logs/no4/CLoss0.001',
help='folder to save the log file. Default=./logs/')

args = parser.parse_args()
Expand Down
22 changes: 9 additions & 13 deletions misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,19 @@ def progress_bar(current, total, msg=None):
if current == 0:
BEGIN_T = time.time() # Reset for new bar.

sys.stdout.write('%d batches' % total)

current_time = time.time()
LAST_T = current_time
total_time = current_time - BEGIN_T

time_used = ' '
time_used += 'Time: {:.3f}s'.format(total_time)
if msg:
time_used += ' | ' + msg

msg = time_used
sys.stdout.write(msg)
if current == total - 1:
time_used = ' '
time_used += 'Time: {:.3f}s'.format(total_time)
if msg:
time_used += ' | ' + msg

if current < total - 1:
sys.stdout.write('\r')
else:
msg = time_used
sys.stdout.write(' %d batches' % total)
sys.stdout.write(msg)
sys.stdout.write('\n')
sys.stdout.flush()
sys.stdout.flush()

15 changes: 8 additions & 7 deletions solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchvision.transforms import ToTensor
from scipy.misc import imsave
from torchviz import make_dot
from loss import HuberLoss
from loss import HuberLoss, CharbonnierLoss

class solver(object):
def __init__(self, config, train_loader, set5_h5_loader, set14_h5_loader, set5_img_loader, set14_img_loader):
Expand Down Expand Up @@ -51,7 +51,8 @@ def build_model(self):


# self.criterion = nn.MSELoss()
self.criterion = HuberLoss(delta=.9) # Huber loss
# self.criterion = HuberLoss(delta=0.9) # Huber loss
self.criterion = CharbonnierLoss(delta=0.001) # Charbonnier Loss
torch.manual_seed(self.seed)

if self.GPU:
Expand Down Expand Up @@ -149,7 +150,7 @@ def test_set5_img(self):
avg_psnr = 0
for batch_num, (data, target) in enumerate(self.set5_img_loader):
target = target.numpy()
target = target[:, :, 5:target.shape[2] - 6, 5:target.shape[3] - 6]
target = target[:, :, 6:target.shape[2] - 6, 6:target.shape[3] - 6]
# target = Variable(torch.from_numpy(target))
if self.GPU:
data, target = Variable(data).cuda(), Variable(torch.from_numpy(target)).cuda()
Expand All @@ -158,7 +159,7 @@ def test_set5_img(self):

prediction = self.model(data)
prediction = prediction.data.cpu().numpy()
prediction = prediction[:, :, 5:prediction.shape[2] - 6, 5:prediction.shape[3] - 6]
# prediction = prediction[:, :, 6:prediction.shape[2] - 6, 6:prediction.shape[3] - 6]
if self.GPU:
prediction = Variable(torch.from_numpy(prediction)).cuda()
else:
Expand All @@ -175,7 +176,7 @@ def test_set14_img(self):
avg_psnr = 0
for batch_num, (data, target) in enumerate(self.set14_img_loader):
target = target.numpy()
target = target[:, :, 5:target.shape[2] - 6, 5:target.shape[3] - 6]
target = target[:, :, 6:target.shape[2] - 6, 6:target.shape[3] - 6]
# target = Variable(torch.from_numpy(target))
if self.GPU:
data, target = Variable(data).cuda(), Variable(torch.from_numpy(target)).cuda()
Expand All @@ -184,7 +185,7 @@ def test_set14_img(self):

prediction = self.model(data)
prediction = prediction.data.cpu().numpy()
prediction = prediction[:, :, 5:prediction.shape[2] - 6, 5:prediction.shape[3] - 6]
# prediction = prediction[:, :, 6:prediction.shape[2] - 6, 6:prediction.shape[3] - 6]
if self.GPU:
prediction = Variable(torch.from_numpy(prediction)).cuda()
else:
Expand All @@ -198,7 +199,7 @@ def test_set14_img(self):

def predict(self, epoch):
self.model.eval()
butterfly = load_img('./butterfly90.bmp')
butterfly = load_img('./butterfly86.bmp')
butterfly = torch.unsqueeze(self.to_tensor(butterfly), 0)
if self.GPU:
data = Variable(butterfly).cuda()
Expand Down

0 comments on commit 022ef91

Please sign in to comment.