from model import VDSR


- parser = argparse.ArgumentParser(
-     description='PyTorch VDSR')
+ parser = argparse.ArgumentParser(description='PyTorch VDSR')
+ parser.add_argument('--dataset', type=str, default='BSDS300',
+                     required=True, help="dataset directory name")
+ parser.add_argument('--crop_size', type=int, default=256,
+                     required=True, help="network input size")
parser.add_argument('--upscale_factor', type=int, default=2,
                    required=True, help="super resolution upscale factor")
- parser.add_argument('--batch_size', type=int, default=64,
-                     help='training batch size')
+ parser.add_argument('--batch_size', type=int, default=128,
+                     help="training batch size")
parser.add_argument('--test_batch_size', type=int,
-                     default=10, help='testing batch size')
- parser.add_argument('--epochs', type=int, default=100,
+                     default=32, help="testing batch size")
+ parser.add_argument('--epochs', type=int, default=10,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001,
                    help='Learning Rate. Default=0.001')
parser.add_argument("--weight-decay", "--wd", default=1e-4,
                    type=float, help="Weight decay, Default: 1e-4")
parser.add_argument('--cuda', action='store_true', help='use cuda?')
- parser.add_argument('--threads', type=int, default=16,
+ parser.add_argument('--threads', type=int, default=128,
                    help='number of threads for data loader to use')
parser.add_argument('--gpuids', default=[0], nargs='+',
                    help='GPU ID for using')
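For context, a minimal self-contained sketch of how the reworked option list parses (flag names and defaults are taken from the diff above; the example argument values are illustrative only). Because the new --dataset, --crop_size, and --upscale_factor options are declared with required=True, their defaults never apply and the flags must always be passed on the command line.

import argparse

parser = argparse.ArgumentParser(description='PyTorch VDSR')
parser.add_argument('--dataset', type=str, default='BSDS300',
                    required=True, help="dataset directory name")
parser.add_argument('--crop_size', type=int, default=256,
                    required=True, help="network input size")
parser.add_argument('--upscale_factor', type=int, default=2,
                    required=True, help="super resolution upscale factor")

# Required options must be given explicitly, so the defaults above are effectively unused.
opt = parser.parse_args(['--dataset', 'BSDS300', '--crop_size', '256', '--upscale_factor', '2'])
print(opt.dataset, opt.crop_size, opt.upscale_factor)  # -> BSDS300 256 2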
@@ -57,12 +60,15 @@ def main():

    if opt.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
+
    cudnn.benchmark = True

-     train_set = get_training_set(
+     train_set = get_training_set(opt.dataset, opt.crop_size,
        opt.upscale_factor, opt.add_noise, opt.noise_std)
-     validation_set = get_validation_set(opt.upscale_factor)
-     test_set = get_test_set(opt.upscale_factor)
+     validation_set = get_validation_set(
+         opt.dataset, opt.crop_size, opt.upscale_factor)
+     test_set = get_test_set(
+         opt.dataset, opt.crop_size, opt.upscale_factor)
    training_data_loader = DataLoader(
        dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
    validating_data_loader = DataLoader(
@@ -90,16 +96,34 @@ def main():
        start_time = time.time()
        test(model, criterion, testing_data_loader)
        elapsed_time = time.time() - start_time
-         print("===> average {:.2f} image/sec for processing".format(
+         print("===> average {:.2f} image/sec for test".format(
            100.0 / elapsed_time))
        return

+     train_time = 0.0
+     validate_time = 0.0
    for epoch in range(1, opt.epochs + 1):
+         start_time = time.time()
        train(model, criterion, epoch, optimizer, training_data_loader)
+         elapsed_time = time.time() - start_time
+         train_time += elapsed_time
+         print("===> {:.2f} seconds to train this epoch".format(
+             elapsed_time))
+         start_time = time.time()
        validate(model, criterion, validating_data_loader)
+         elapsed_time = time.time() - start_time
+         validate_time += elapsed_time
+         print("===> {:.2f} seconds to validate this epoch".format(
+             elapsed_time))
        if epoch % 10 == 0:
            checkpoint(model, epoch)

+     print("===> average training time per epoch: {:.2f} seconds".format(train_time / opt.epochs))
+     print("===> average validation time per epoch: {:.2f} seconds".format(validate_time / opt.epochs))
+     print("===> training time: {:.2f} seconds".format(train_time))
+     print("===> validation time: {:.2f} seconds".format(validate_time))
+     print("===> total training time: {:.2f} seconds".format(train_time + validate_time))
+

def adjust_learning_rate(epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
@@ -185,4 +209,7 @@ def checkpoint(model, epoch):


if __name__ == '__main__':
+     start_time = time.time()
    main()
+     elapsed_time = time.time() - start_time
+     print("===> total time: {:.2f} seconds".format(elapsed_time))