@@ -1,16 +1,18 @@
from __future__ import print_function
import argparse
from math import log10, sqrt
+ import time
import os
from os import errno
+ from os.path import join

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
- from data import get_training_set, get_test_set
+ from data import get_training_set, get_validation_set, get_test_set
from model import VDSR


@@ -41,43 +43,62 @@
                    help='add gaussian noise?')
parser.add_argument('--noise_std', type=float, default=3.0,
                    help='standard deviation of gaussian noise')
- opt = parser.parse_args()
-
- opt.gpuids = list(map(int, opt.gpuids))
- print(opt)
-
-
- use_cuda = opt.cuda
- if use_cuda and not torch.cuda.is_available():
-     raise Exception("No GPU found, please run without --cuda")
-
-
- cudnn.benchmark = True
-
-
- train_set = get_training_set(opt.upscale_factor, opt.add_noise, opt.noise_std)
- test_set = get_test_set(opt.upscale_factor)
- training_data_loader = DataLoader(
-     dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
- testing_data_loader = DataLoader(
-     dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
-
-
- vdsr = VDSR()
- criterion = nn.MSELoss()
-
-
- if (use_cuda):
-     torch.cuda.set_device(opt.gpuids[0])
-     with torch.cuda.device(opt.gpuids[0]):
-         vdsr = vdsr.cuda()
-         criterion = criterion.cuda()
-         vdsr = nn.DataParallel(vdsr, device_ids=opt.gpuids,
-                                output_device=opt.gpuids[0])
-
-
- optimizer = optim.Adam(vdsr.parameters(), lr=opt.lr,
-                        weight_decay=opt.weight_decay)
+ parser.add_argument('--test', action='store_true', help='test mode')
+ parser.add_argument('--model', default='', type=str, metavar='PATH',
+                     help='path to test or resume model')
+
+
+ def main():
+     global opt
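+     # opt is kept as a module-level global so train(), validate(), and test() can read opt.cuda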
+     opt = parser.parse_args()
+     opt.gpuids = list(map(int, opt.gpuids))
+
+     print(opt)
+
+     if opt.cuda and not torch.cuda.is_available():
+         raise Exception("No GPU found, please run without --cuda")
+     cudnn.benchmark = True
+
+     train_set = get_training_set(
+         opt.upscale_factor, opt.add_noise, opt.noise_std)
+     validation_set = get_validation_set(opt.upscale_factor)
+     test_set = get_test_set(opt.upscale_factor)
+     training_data_loader = DataLoader(
+         dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
+     validating_data_loader = DataLoader(
+         dataset=validation_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
+     testing_data_loader = DataLoader(
+         dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
+
+     model = VDSR()
+     criterion = nn.MSELoss()
+
+     if opt.cuda:
+         torch.cuda.set_device(opt.gpuids[0])
+         with torch.cuda.device(opt.gpuids[0]):
+             model = model.cuda()
+             criterion = criterion.cuda()
+             model = nn.DataParallel(model, device_ids=opt.gpuids,
+                                     output_device=opt.gpuids[0])
+
+     optimizer = optim.Adam(model.parameters(), lr=opt.lr,
+                            weight_decay=opt.weight_decay)
+
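+     # test mode: load a saved model from the model/ directory and evaluate it on the test set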
+     if opt.test:
+         model_name = join("model", opt.model)
+         model = torch.load(model_name)
+         start_time = time.time()
+         test(model, criterion, testing_data_loader)
+         elapsed_time = time.time() - start_time
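+         # NOTE: the throughput figure below assumes the test set contains 100 images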
+         print("===> average {:.2f} image/sec for processing".format(
+             100.0 / elapsed_time))
+         return
+
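+     # training mode: train and validate every epoch, checkpoint every 10 epochs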
+     for epoch in range(1, opt.epochs + 1):
+         train(model, criterion, epoch, optimizer, training_data_loader)
+         validate(model, criterion, validating_data_loader)
+         if epoch % 10 == 0:
+             checkpoint(model, epoch)


def adjust_learning_rate(epoch):
@@ -86,7 +107,7 @@ def adjust_learning_rate(epoch):
    return lr


- def train(epoch):
+ def train(model, criterion, epoch, optimizer, training_data_loader):
    lr = adjust_learning_rate(epoch - 1)

    for param_group in optimizer.param_groups:
@@ -98,42 +119,58 @@ def train(epoch):
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(
            batch[1], requires_grad=False)
-         if use_cuda:
+         if opt.cuda:
            input = input.cuda()
            target = target.cuda()

        optimizer.zero_grad()
-         model_out = vdsr(input)
+         model_out = model(input)
        loss = criterion(model_out, target)
        epoch_loss += loss.item()
        loss.backward()
-         nn.utils.clip_grad_norm_(vdsr.parameters(), opt.clip / lr)
+         nn.utils.clip_grad_norm_(model.parameters(), opt.clip / lr)
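+         # adjustable gradient clipping: the max norm opt.clip is scaled by 1/lr, so a smaller lr permits larger gradients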
        optimizer.step()

-         print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch,
-             iteration, len(training_data_loader), loss.item()))
+         print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
+             epoch, iteration, len(training_data_loader), loss.item()))

    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
        epoch, epoch_loss / len(training_data_loader)))


- def test():
+ def validate(model, criterion, validating_data_loader):
+     avg_psnr = 0
+     for batch in validating_data_loader:
+         input, target = Variable(batch[0]), Variable(batch[1])
+         if opt.cuda:
+             input = input.cuda()
+             target = target.cuda()
+
+         prediction = model(input)
+         mse = criterion(prediction, target)
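+         # PSNR in dB; the 1.0 numerator assumes images are normalized to [0, 1]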
+         psnr = 10 * log10(1.0 / mse.item())
+         avg_psnr += psnr
+     print("===> Avg. PSNR: {:.4f} dB".format(
+         avg_psnr / len(validating_data_loader)))
+
+
+ def test(model, criterion, testing_data_loader):
    avg_psnr = 0
    for batch in testing_data_loader:
        input, target = Variable(batch[0]), Variable(batch[1])
-         if use_cuda:
+         if opt.cuda:
            input = input.cuda()
            target = target.cuda()

-         prediction = vdsr(input)
+         prediction = model(input)
        mse = criterion(prediction, target)
        psnr = 10 * log10(1.0 / mse.item())
        avg_psnr += psnr
    print("===> Avg. PSNR: {:.4f} dB".format(
        avg_psnr / len(testing_data_loader)))


- def checkpoint(epoch):
+ def checkpoint(model, epoch):
    try:
        if not (os.path.isdir('model')):
            os.makedirs(os.path.join('model'))
@@ -143,12 +180,9 @@ def checkpoint(epoch):
            raise

    model_out_path = "model/model_epoch_{}.pth".format(epoch)
-     torch.save(vdsr, model_out_path)
+     torch.save(model, model_out_path)
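+     # saves the whole model object rather than a state_dict, so the VDSR class must be importable when loading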
    print("Checkpoint saved to {}".format(model_out_path))


- for epoch in range(1, opt.epochs + 1):
-     train(epoch)
-     test()
-     if epoch % 10 == 0:
-         checkpoint(epoch)
+ if __name__ == '__main__':
+     main()