Skip to content

Commit

Permalink
Update inference_diffir.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Zj-BinXia authored Oct 28, 2023
1 parent 47653c2 commit 7d3c3cc
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions DiffIR-RealSR/inference_diffir.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,18 @@ def pad_test(lq,scale):
im_list.sort()
im_list = [name for name in im_list if name.endswith('.png')]

with torch.no_grad():
for name in im_list:
path = os.path.join(args.im_path, name)
im = cv2.imread(path)
im = img2tensor(im)
im = im.unsqueeze(0).cuda(0)/255.
lq,mod_pad_h,mod_pad_w= pad_test(im,args.scale)
with torch.no_grad():
sr = model(lq)
_, _, h, w = sr.size()
sr = sr[:, :, 0:h - mod_pad_h * args.scale, 0:w - mod_pad_w * args.scale]
im_sr = tensor2img(sr, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))
save_path = os.path.join(args.res_path, name.split('.')[0]+'_out.png')
cv2.imwrite(save_path, im_sr)
print(save_path)

for name in im_list:
path = os.path.join(args.im_path, name)
im = cv2.imread(path)
im = img2tensor(im)
im = im.unsqueeze(0).cuda(0)/255.
lq,mod_pad_h,mod_pad_w= pad_test(im,args.scale)
with torch.no_grad():
sr = model(lq)
_, _, h, w = sr.size()
sr = sr[:, :, 0:h - mod_pad_h * args.scale, 0:w - mod_pad_w * args.scale]
im_sr = tensor2img(sr, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))
save_path = os.path.join(args.res_path, name.split('.')[0]+'_out.png')
cv2.imwrite(save_path, im_sr)
print(save_path)

0 comments on commit 7d3c3cc

Please sign in to comment.