test.py
import os
import argparse
import torch
import torchvision
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomVerticalFlip
from basicsr.archs.mprnet_arch import MPRNet
from basicsr.utils.flare_util import predict_flare_from_6_channel, RandomGammaCorrection


def inference(input_path, output_path, model_path):
    # Gamma correction plus deterministic horizontal and vertical flips (p=1.0),
    # used to build the second half of the 6-channel input.
    rot_transform = Compose([
        RandomGammaCorrection(10.0),
        RandomHorizontalFlip(1.0),
        RandomVerticalFlip(1.0)
    ])
    to_tensor = transforms.ToTensor()
    resize = transforms.Resize(512)
    crop = transforms.CenterCrop(512)
    # Gamma value passed to predict_flare_from_6_channel.
    gamma = torch.Tensor([2.2])

    # Load the 6-channel-in / 6-channel-out MPRNet checkpoint.
    model = MPRNet(img_ch=6, output_ch=6).cuda()
    model.load_state_dict(torch.load(model_path)['params'])
    model.eval()

    input_name_list = os.listdir(input_path)
    os.makedirs(os.path.join(output_path, "input"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "deflare"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "flare"), exist_ok=True)

    for cur_input_name in tqdm(input_name_list):
        torch.cuda.empty_cache()
        cur_input_path = os.path.join(input_path, cur_input_name)
        cur_input_save_path = os.path.join(output_path, "input", cur_input_name)
        cur_deflare_path = os.path.join(output_path, "deflare", cur_input_name)
        cur_flare_path = os.path.join(output_path, "flare", cur_input_name)

        # Resize the shorter side to 512, then center-crop to 512x512.
        cur_input_img = Image.open(cur_input_path).convert("RGB")
        cur_input_img = crop(resize(to_tensor(cur_input_img)))
        cur_input_img = cur_input_img.cuda().unsqueeze(0)

        with torch.no_grad():
            # Concatenate the original image with its transformed copy
            # along the channel dimension to form the 6-channel input.
            lq_rot = rot_transform(cur_input_img)
            lq = torch.cat((cur_input_img, lq_rot), 1)
            output_img = model(lq)[0]
            # Split the 6-channel prediction into the deflared image and the flare.
            deflare_img, flare_img_predicted, merge_img_predicted = predict_flare_from_6_channel(output_img, gamma)

        torchvision.utils.save_image(cur_input_img, cur_input_save_path)
        torchvision.utils.save_image(flare_img_predicted, cur_flare_path)
        torchvision.utils.save_image(deflare_img, cur_deflare_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_path', type=str, default='test/lq/',
                        help='Input image folder.')
    parser.add_argument('-o', '--output_path', type=str, default='results/',
                        help='Output folder.')
    parser.add_argument('-m', '--model_path', type=str, default='expirements/net_g_last.pth',
                        help='Checkpoint path.')
    args = parser.parse_args()
    inference(args.input_path, args.output_path, args.model_path)
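
# Example invocation (a sketch using the argparse defaults above; the input folder,
# output folder, and checkpoint path are assumptions and should be adjusted to your setup):
#   python test.py -i test/lq/ -o results/ -m expirements/net_g_last.pth
# Results are written to results/input/, results/deflare/, and results/flare/.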