Skip to content

Commit 9ece74e

Browse files
committed
Update main.py
1 parent bf3241c commit 9ece74e

File tree

1 file changed

+88
-54
lines changed

1 file changed

+88
-54
lines changed

main.py

+88-54
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
from __future__ import print_function
22
import argparse
33
from math import log10, sqrt
4+
import time
45
import os
56
from os import errno
7+
from os.path import join
68

79
import torch
810
import torch.nn as nn
911
import torch.optim as optim
1012
import torch.backends.cudnn as cudnn
1113
from torch.autograd import Variable
1214
from torch.utils.data import DataLoader
13-
from data import get_training_set, get_test_set
15+
from data import get_training_set, get_validation_set, get_test_set
1416
from model import VDSR
1517

1618

@@ -41,43 +43,62 @@
4143
help='add gaussian noise?')
4244
parser.add_argument('--noise_std', type=float, default=3.0,
4345
help='standard deviation of gaussian noise')
44-
opt = parser.parse_args()
45-
46-
opt.gpuids = list(map(int, opt.gpuids))
47-
print(opt)
48-
49-
50-
use_cuda = opt.cuda
51-
if use_cuda and not torch.cuda.is_available():
52-
raise Exception("No GPU found, please run without --cuda")
53-
54-
55-
cudnn.benchmark = True
56-
57-
58-
train_set = get_training_set(opt.upscale_factor, opt.add_noise, opt.noise_std)
59-
test_set = get_test_set(opt.upscale_factor)
60-
training_data_loader = DataLoader(
61-
dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
62-
testing_data_loader = DataLoader(
63-
dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
64-
65-
66-
vdsr = VDSR()
67-
criterion = nn.MSELoss()
68-
69-
70-
if(use_cuda):
71-
torch.cuda.set_device(opt.gpuids[0])
72-
with torch.cuda.device(opt.gpuids[0]):
73-
vdsr = vdsr.cuda()
74-
criterion = criterion.cuda()
75-
vdsr = nn.DataParallel(vdsr, device_ids=opt.gpuids,
76-
output_device=opt.gpuids[0])
77-
78-
79-
optimizer = optim.Adam(vdsr.parameters(), lr=opt.lr,
80-
weight_decay=opt.weight_decay)
46+
parser.add_argument('--test', action='store_true', help='test mode')
47+
parser.add_argument('--model', default='', type=str, metavar='PATH',
48+
help='path to test or resume model')
49+
50+
51+
def main():
52+
global opt
53+
opt = parser.parse_args()
54+
opt.gpuids = list(map(int, opt.gpuids))
55+
56+
print(opt)
57+
58+
if opt.cuda and not torch.cuda.is_available():
59+
raise Exception("No GPU found, please run without --cuda")
60+
cudnn.benchmark = True
61+
62+
train_set = get_training_set(
63+
opt.upscale_factor, opt.add_noise, opt.noise_std)
64+
validation_set = get_validation_set(opt.upscale_factor)
65+
test_set = get_test_set(opt.upscale_factor)
66+
training_data_loader = DataLoader(
67+
dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
68+
validating_data_loader = DataLoader(
69+
dataset=validation_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
70+
testing_data_loader = DataLoader(
71+
dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
72+
73+
model = VDSR()
74+
criterion = nn.MSELoss()
75+
76+
if opt.cuda:
77+
torch.cuda.set_device(opt.gpuids[0])
78+
with torch.cuda.device(opt.gpuids[0]):
79+
model = model.cuda()
80+
criterion = criterion.cuda()
81+
model = nn.DataParallel(model, device_ids=opt.gpuids,
82+
output_device=opt.gpuids[0])
83+
84+
optimizer = optim.Adam(model.parameters(), lr=opt.lr,
85+
weight_decay=opt.weight_decay)
86+
87+
if opt.test:
88+
model_name = join("model", opt.model)
89+
model = torch.load(model_name)
90+
start_time = time.time()
91+
test(model, criterion, testing_data_loader)
92+
elapsed_time = time.time() - start_time
93+
print("===> average {:.2f} image/sec for processing".format(
94+
100.0/elapsed_time))
95+
return
96+
97+
for epoch in range(1, opt.epochs + 1):
98+
train(model, criterion, epoch, optimizer, training_data_loader)
99+
validate(model, criterion, validating_data_loader)
100+
if epoch % 10 == 0:
101+
checkpoint(model, epoch)
81102

82103

83104
def adjust_learning_rate(epoch):
@@ -86,7 +107,7 @@ def adjust_learning_rate(epoch):
86107
return lr
87108

88109

89-
def train(epoch):
110+
def train(model, criterion, epoch, optimizer, training_data_loader):
90111
lr = adjust_learning_rate(epoch-1)
91112

92113
for param_group in optimizer.param_groups:
@@ -98,42 +119,58 @@ def train(epoch):
98119
for iteration, batch in enumerate(training_data_loader, 1):
99120
input, target = Variable(batch[0]), Variable(
100121
batch[1], requires_grad=False)
101-
if use_cuda:
122+
if opt.cuda:
102123
input = input.cuda()
103124
target = target.cuda()
104125

105126
optimizer.zero_grad()
106-
model_out = vdsr(input)
127+
model_out = model(input)
107128
loss = criterion(model_out, target)
108129
epoch_loss += loss.item()
109130
loss.backward()
110-
nn.utils.clip_grad_norm_(vdsr.parameters(), opt.clip/lr)
131+
nn.utils.clip_grad_norm_(model.parameters(), opt.clip/lr)
111132
optimizer.step()
112133

113-
print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch,
114-
iteration, len(training_data_loader), loss.item()))
134+
print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
135+
epoch, iteration, len(training_data_loader), loss.item()))
115136

116137
print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
117138
epoch, epoch_loss / len(training_data_loader)))
118139

119140

120-
def test():
141+
def validate(model, criterion, validating_data_loader):
142+
avg_psnr = 0
143+
for batch in validating_data_loader:
144+
input, target = Variable(batch[0]), Variable(batch[1])
145+
if opt.cuda:
146+
input = input.cuda()
147+
target = target.cuda()
148+
149+
prediction = model(input)
150+
mse = criterion(prediction, target)
151+
psnr = 10 * log10(1.0 / mse.item())
152+
avg_psnr += psnr
153+
print("===> Avg. PSNR: {:.4f} dB".format(
154+
avg_psnr / len(validating_data_loader)))
155+
156+
157+
def test(model, criterion, testing_data_loader):
121158
avg_psnr = 0
122159
for batch in testing_data_loader:
123160
input, target = Variable(batch[0]), Variable(batch[1])
124-
if use_cuda:
161+
if opt.cuda:
125162
input = input.cuda()
126163
target = target.cuda()
127164

128-
prediction = vdsr(input)
165+
prediction = model(input)
129166
mse = criterion(prediction, target)
130167
psnr = 10 * log10(1.0 / mse.item())
131168
avg_psnr += psnr
132169
print("===> Avg. PSNR: {:.4f} dB".format(
133170
avg_psnr / len(testing_data_loader)))
134171

135172

136-
def checkpoint(epoch):
173+
def checkpoint(model, epoch):
137174
try:
138175
if not(os.path.isdir('model')):
139176
os.makedirs(os.path.join('model'))
@@ -143,12 +180,9 @@ def checkpoint(epoch):
143180
raise
144181

145182
model_out_path = "model/model_epoch_{}.pth".format(epoch)
146-
torch.save(vdsr, model_out_path)
183+
torch.save(model, model_out_path)
147184
print("Checkpoint saved to {}".format(model_out_path))
148185

149186

150-
for epoch in range(1, opt.epochs + 1):
151-
train(epoch)
152-
test()
153-
if epoch % 10 == 0:
154-
checkpoint(epoch)
187+
if __name__ == '__main__':
188+
main()

0 commit comments

Comments
 (0)