Skip to content

Commit e27e933

Browse files
committed
Update neatly organized codes
1 parent 5a4850f commit e27e933

File tree

3 files changed

+76
-88
lines changed

3 files changed

+76
-88
lines changed

README.md

+20-31
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,30 @@
1-
# vdsr_pytorch
2-
3-
PyTorch Implementation
4-
no muliti-scale
5-
but you can use multi-GPUs
1+
# vdsr_pytorch_lms
2+
VDSR PyTorch Implementation
3+
You can use multi-gpus.
4+
but no multi-scale.
5+
And you can input gaussian noise to input images.
66

77
## Requirement
8-
98
`torch`
109
`torchvision`
11-
`python-tk` (or `pyhton3-tk`)
12-
13-
## Training
10+
`python-tk` (or `python3-tk`)
1411

15-
> python main.py --batch_size 40 --test_batch_size 10 --epochs 100 --cuda --gpuids 0 --upscale_factor 2
12+
## Download dataset
13+
1. Download [DF2K dataset](https://drive.google.com/file/d/1P9pcaGjvq3xiF22GXIq7ciZta3rjZxaY/view?usp=sharing).
14+
2. move under dataset directory i.e. vdsr_pytorch_lms/dataset/DF2K
1615

17-
or
18-
19-
> python3 main.py --batch_size 40 --test_batch_size 10 --epochs 100 --cuda --gpuids 0 --upscale_factor 2
16+
## Training
17+
```
18+
$ python main.py --dataset DF2K --cuda --gpuids 0 1 --upscale_factor 2 --crop_size 256 --batch_size 128 --test_batch_size 32
19+
```
2020

2121
## Test
22+
```
23+
$ python main.py --dataset Urban100 --cuda --gpuids 0 1 --upscale_factor 2 --crop_size 256 --test_batch_size 32 --test --model model_epoch_100.pth
24+
```
2225

23-
> python main.py --batch_size 40 --test_batch_size 10 --epochs 100 --cuda --gpuids 0 --upscale_factor 2 --test --model model_epoch_100.pth
24-
25-
or
26-
27-
> python3 main.py --batch_size 40 --test_batch_size 10 --epochs 100 --cuda --gpuids 0 --upscale_factor 2 --test --model model_epoch_100.pth
28-
29-
## Sample Usage
30-
31-
> python run.py --input_image test_scale2x.jpg --scale_factor 2 --model model_epoch_100.pth --cuda --gpuids 0 --output_filename test_scale2x_out.jpg
32-
33-
or
34-
35-
> python3 run.py --input_image test_scale2x.jpg --scale_factor 2 --model model_epoch_100.pth --cuda --gpuids 0 --output_filename test_scale2x_out.jpg
36-
37-
## 주의
26+
## Sample usage
27+
```
28+
$ python run.py --cuda --gpuids 0 1 --scale_factor 2 --model model_epoch_100.pth --input_image test_scale2x.jpg --output_filename test_scale2x_out.jpg
29+
```
3830

39-
test시에는 Urban100의 데이터가 사용된다.
40-
sample에서 input image는 학습에 사용된 BSDS300 data가 아닌 인터넷에서 가져온 이미지 등 다른 이미지를 사용해야 정확한 성능을 확인할 수 있다.
41-
Urban100데이터는 학습시 사용하지 않으므로 Urban100의 사진들 중 하나를 사용해도 됨.

data.py

+18-46
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,17 @@
1-
from os.path import exists, join, basename
2-
from os import makedirs, remove
3-
from six.moves import urllib
4-
import tarfile
1+
from os.path import join
52
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
63

74
from data_utils import DatasetFromFolder
85

96

10-
def download_bsd300(dest="dataset"):
11-
output_image_dir = join(dest, "BSDS300/images")
12-
13-
if not exists(output_image_dir):
14-
makedirs(dest)
15-
url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
16-
print("downloading url ", url)
17-
18-
data = urllib.request.urlopen(url)
19-
20-
file_path = join(dest, basename(url))
21-
with open(file_path, 'wb') as f:
22-
f.write(data.read())
23-
24-
print("Extracting data")
25-
with tarfile.open(file_path) as tar:
26-
for item in tar:
27-
tar.extract(item, dest)
28-
29-
remove(file_path)
30-
31-
return output_image_dir
32-
33-
347
def calculate_valid_crop_size(crop_size, upscale_factor):
358
return crop_size - (crop_size % upscale_factor)
369

3710

3811
def input_transform(crop_size, upscale_factor):
3912
return Compose([
4013
CenterCrop(crop_size),
41-
Resize((crop_size//(upscale_factor*2), crop_size//upscale_factor)),
14+
Resize((crop_size//upscale_factor, crop_size//upscale_factor)),
4215
Resize((crop_size, crop_size)),
4316
ToTensor(),
4417
])
@@ -51,36 +24,35 @@ def target_transform(crop_size):
5124
])
5225

5326

54-
def get_training_set(upscale_factor, add_noise=None, noise_std=3.0):
55-
root_dir = download_bsd300()
56-
27+
def get_training_set(dataset, crop_size, upscale_factor, add_noise=None, noise_std=3.0):
28+
root_dir = join("dataset", dataset)
5729
train_dir = join(root_dir, "train")
58-
crop_size = calculate_valid_crop_size(256, upscale_factor)
30+
cropsize = calculate_valid_crop_size(crop_size, upscale_factor)
5931

6032
return DatasetFromFolder(train_dir,
6133
input_transform=input_transform(
62-
crop_size, upscale_factor),
63-
target_transform=target_transform(crop_size),
34+
cropsize, upscale_factor),
35+
target_transform=target_transform(cropsize),
6436
add_noise=add_noise,
6537
noise_std=noise_std)
6638

6739

68-
def get_validation_set(upscale_factor):
69-
root_dir = download_bsd300()
70-
validation_dir = join(root_dir, "test")
71-
crop_size = calculate_valid_crop_size(256, upscale_factor)
40+
def get_validation_set(dataset, crop_size, upscale_factor):
41+
root_dir = join("dataset", dataset)
42+
validation_dir = join(root_dir, "valid")
43+
cropsize = calculate_valid_crop_size(crop_size, upscale_factor)
7244

7345
return DatasetFromFolder(validation_dir,
7446
input_transform=input_transform(
75-
crop_size, upscale_factor),
76-
target_transform=target_transform(crop_size))
47+
cropsize, upscale_factor),
48+
target_transform=target_transform(cropsize))
7749

7850

79-
def get_test_set(upscale_factor):
80-
test_dir = "dataset/Urban100"
81-
crop_size = calculate_valid_crop_size(256, upscale_factor)
51+
def get_test_set(dataset, crop_size, upscale_factor):
52+
test_dir = join("dataset", dataset)
53+
cropsize = calculate_valid_crop_size(crop_size, upscale_factor)
8254

8355
return DatasetFromFolder(test_dir,
8456
input_transform=input_transform(
85-
crop_size, upscale_factor),
86-
target_transform=target_transform(crop_size))
57+
cropsize, upscale_factor),
58+
target_transform=target_transform(cropsize))

main.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616
from model import VDSR
1717

1818

19-
parser = argparse.ArgumentParser(
20-
description='PyTorch VDSR')
19+
parser = argparse.ArgumentParser(description='PyTorch VDSR')
20+
parser.add_argument('--dataset', type=str, default='BSDS300',
21+
required=True, help="dataset directory name")
22+
parser.add_argument('--crop_size', type=int, default=256,
23+
required=True, help="network input size")
2124
parser.add_argument('--upscale_factor', type=int, default=2,
2225
required=True, help="super resolution upscale factor")
23-
parser.add_argument('--batch_size', type=int, default=64,
24-
help='training batch size')
26+
parser.add_argument('--batch_size', type=int, default=128,
27+
help="training batch size")
2528
parser.add_argument('--test_batch_size', type=int,
26-
default=10, help='testing batch size')
27-
parser.add_argument('--epochs', type=int, default=100,
29+
default=32, help="testing batch size")
30+
parser.add_argument('--epochs', type=int, default=10,
2831
help='number of epochs to train for')
2932
parser.add_argument('--lr', type=float, default=0.001,
3033
help='Learning Rate. Default=0.001')
@@ -35,7 +38,7 @@
3538
parser.add_argument("--weight-decay", "--wd", default=1e-4,
3639
type=float, help="Weight decay, Default: 1e-4")
3740
parser.add_argument('--cuda', action='store_true', help='use cuda?')
38-
parser.add_argument('--threads', type=int, default=16,
41+
parser.add_argument('--threads', type=int, default=128,
3942
help='number of threads for data loader to use')
4043
parser.add_argument('--gpuids', default=[0], nargs='+',
4144
help='GPU ID for using')
@@ -57,12 +60,15 @@ def main():
5760

5861
if opt.cuda and not torch.cuda.is_available():
5962
raise Exception("No GPU found, please run without --cuda")
63+
6064
cudnn.benchmark = True
6165

62-
train_set = get_training_set(
66+
train_set = get_training_set(opt.dataset, opt.crop_size,
6367
opt.upscale_factor, opt.add_noise, opt.noise_std)
64-
validation_set = get_validation_set(opt.upscale_factor)
65-
test_set = get_test_set(opt.upscale_factor)
68+
validation_set = get_validation_set(
69+
opt.dataset, opt.crop_size, opt.upscale_factor)
70+
test_set = get_test_set(
71+
opt.dataset, opt.crop_size, opt.upscale_factor)
6672
training_data_loader = DataLoader(
6773
dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
6874
validating_data_loader = DataLoader(
@@ -90,16 +96,34 @@ def main():
9096
start_time = time.time()
9197
test(model, criterion, testing_data_loader)
9298
elapsed_time = time.time() - start_time
93-
print("===> average {:.2f} image/sec for processing".format(
99+
print("===> average {:.2f} image/sec for test".format(
94100
100.0/elapsed_time))
95101
return
96102

103+
train_time = 0.0
104+
validate_time = 0.0
97105
for epoch in range(1, opt.epochs + 1):
106+
start_time = time.time()
98107
train(model, criterion, epoch, optimizer, training_data_loader)
108+
elapsed_time = time.time() - start_time
109+
train_time += elapsed_time
110+
print("===> {:.2f} seconds to train this epoch".format(
111+
elapsed_time))
112+
start_time = time.time()
99113
validate(model, criterion, validating_data_loader)
114+
elapsed_time = time.time() - start_time
115+
validate_time += elapsed_time
116+
print("===> {:.2f} seconds to validate this epoch".format(
117+
elapsed_time))
100118
if epoch % 10 == 0:
101119
checkpoint(model, epoch)
102120

121+
print("===> average training time per epoch: {:.2f} seconds".format(train_time/opt.epochs))
122+
print("===> average validation time per epoch: {:.2f} seconds".format(validate_time/opt.epochs))
123+
print("===> training time: {:.2f} seconds".format(train_time))
124+
print("===> validation time: {:.2f} seconds".format(validate_time))
125+
print("===> total training time: {:.2f} seconds".format(train_time+validate_time))
126+
103127

104128
def adjust_learning_rate(epoch):
105129
"""Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
@@ -185,4 +209,7 @@ def checkpoint(model, epoch):
185209

186210

187211
if __name__ == '__main__':
212+
start_time = time.time()
188213
main()
214+
elapsed_time = time.time() - start_time
215+
print("===> total time: {:.2f} seconds".format(elapsed_time))

0 commit comments

Comments
 (0)