forked from codeslake/IFAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
executable file
·105 lines (77 loc) · 3.95 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import cog
import torch
from configs.config_IFAN import get_config
from ckpt_manager import CKPT_Manager
from models import create_model
from utils import *
from data_loader.utils import load_file_list, refine_image, read_frame, preproc_frame
from pathlib import Path
import tempfile
import cv2
class UnsupportedImageError(ValueError, AssertionError):
    """Raised when ``Predictor.predict`` is given a file with an unsupported extension.

    It also subclasses ``AssertionError`` so callers that caught the ``AssertionError``
    raised by the former ``assert`` keep working. Unlike an ``assert``, the check is
    not removed when Python runs with ``-O``.
    """


class Predictor(cog.BasePredictor):
    """Cog predictor that runs the IFAN network (config ``IFAN_CVPR2021``) on one image.

    ``setup`` builds the network and loads its checkpoint once. The ``predict*`` methods
    share one pipeline: read/preprocess, downscale, run the network, convert to uint8.
    They differ only in how the input is obtained and how the result is returned.
    """

    #: Longest side (pixels) an input may have before it is downscaled.
    #: Can be overridden per instance or subclass.
    MAX_SIDE = 1920
    #: File extensions accepted by ``predict`` (compared case-insensitively).
    SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def setup(self):
        """Build the IFAN network on the available device and load its checkpoint."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.config = get_config('IFAN_CVPR2021', 'IFAN', 'config_IFAN')
        self.config.network = 'IFAN'
        model = create_model(self.config)
        self.network = model.get_network().eval().to(self.device)
        # The checkpoint lives in ./ckpt next to this file. Build the path with pathlib
        # rather than string-replacing 'predict.py' in __file__, which breaks if the
        # module is renamed or that substring occurs elsewhere in the path.
        ckpt_path = Path(__file__).resolve().parent / 'ckpt' / 'IFAN.pytorch'
        # NOTE(review): `cuda` used to be hard-coded to True. Presumably it selects the
        # device the checkpoint is mapped to when loaded, so it must be False on
        # CPU-only machines. Confirm against CKPT_Manager.load_ckpt.
        ckpt_manager = CKPT_Manager(root_dir='', model_name='IFAN', cuda=self.device == 'cuda')
        ckpt_manager.load_ckpt(self.network, abs_name=str(ckpt_path))

    def _downscale(self, frames):
        """Shrink a rank-4 (N, H, W, C) batch so its longest side is at most MAX_SIDE.

        Returns ``frames`` unchanged when it is already small enough. When resizing,
        only the first frame is kept (wrapped back into a batch of one).
        """
        h, w = frames.shape[1:3]
        longest = max(h, w)
        if longest <= self.MAX_SIDE:
            return frames
        scale = self.MAX_SIDE / longest
        # round() instead of int(): int() can truncate a value like 1919.9999 to 1919.
        # max(1, ...) keeps very thin images from collapsing to a 0-pixel side, which
        # cv2.resize rejects.
        dsize = (max(1, round(w * scale)), max(1, round(h * scale)))
        resized = cv2.resize(frames[0], dsize=dsize, interpolation=cv2.INTER_AREA)
        return np.expand_dims(resized, 0)

    def _infer(self, frames):
        """Run the network on an (N, H, W, C) float batch.

        Returns the first ``'result'`` output as an (H, W, C) float numpy array.
        """
        frames = self._downscale(frames)
        # NHWC -> NCHW. The copy gives torch a contiguous array with positive strides.
        batch = refine_image(frames, self.config.refine_val).transpose(0, 3, 1, 2).copy()
        tensor = torch.from_numpy(batch).float().to(self.device)
        with torch.no_grad():
            result = self.network(tensor)['result']
        return result.cpu().numpy()[0].transpose(1, 2, 0)

    @staticmethod
    def _to_uint8(image, flip_channels):
        """Convert an (H, W, C) float image to uint8, optionally reversing channel order.

        The scaling by 255 assumes values are meant to lie in [0, 1].
        """
        if flip_channels:
            image = np.flip(image, 2)
        # Clip first: the float -> uint8 cast does not saturate, so values slightly
        # outside [0, 1] would wrap around and show up as speckle pixels.
        return (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    def predict(self, image):
        """Run the network on an image file and write the result to a temporary JPEG.

        Args:
            image: path to the input image. Only .png, .jpg and .jpeg (any case) are
                accepted.

        Returns:
            ``pathlib.Path`` to ``out.jpg`` inside a new temporary directory. The
            directory is not removed here.

        Raises:
            UnsupportedImageError: the file extension is not supported.
            RuntimeError: OpenCV failed to write the output file.
        """
        if Path(str(image)).suffix.lower() not in self.SUPPORTED_EXTENSIONS:
            raise UnsupportedImageError('image should end with ".jpg", ".jpeg" or ".png"')
        frames = read_frame(str(image), self.config.norm_val, None)
        # Reverse the channel order for cv2.imwrite, as the original code did.
        output = self._to_uint8(self._infer(frames), flip_channels=True)
        out_path = Path(tempfile.mkdtemp()) / 'out.jpg'
        if not cv2.imwrite(str(out_path), output):
            raise RuntimeError(f'failed to write output image to {out_path}')
        return out_path

    def predict_image(self, image):
        """Run the network on an in-memory image and return the result as a uint8 array.

        ``image`` is handed to ``preproc_frame`` (NOTE(review): presumably a numpy
        image; confirm). Unlike ``predict``, the channel order is not reversed.
        """
        frames = preproc_frame(image, self.config.norm_val, None)
        return self._to_uint8(self._infer(frames), flip_channels=False)

    def predict_image_old(self, image):
        """Legacy variant: read ``image`` from disk and return the result as uint8.

        The channel order is reversed exactly as in ``predict``. No extension check is
        performed.
        """
        frames = read_frame(str(image), self.config.norm_val, None)
        return self._to_uint8(self._infer(frames), flip_channels=True)