forked from huawei-noah/Pretrained-IPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3d1c62c
commit 32ce841
Showing
12 changed files
with
1,334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
|
||
from importlib import import_module | ||
#from dataloader import MSDataLoader | ||
from torch.utils.data import dataloader | ||
from torch.utils.data import ConcatDataset | ||
|
||
# This is a simple wrapper function for ConcatDataset | ||
class MyConcatDataset(ConcatDataset): | ||
def __init__(self, datasets): | ||
super(MyConcatDataset, self).__init__(datasets) | ||
self.train = datasets[0].train | ||
|
||
def set_scale(self, idx_scale): | ||
for d in self.datasets: | ||
if hasattr(d, 'set_scale'): d.set_scale(idx_scale) | ||
|
||
class Data: | ||
def __init__(self, args): | ||
self.loader_train = None | ||
if not args.test_only: | ||
datasets = [] | ||
for d in args.data_train: | ||
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' | ||
m = import_module('data.' + module_name.lower()) | ||
datasets.append(getattr(m, module_name)(args, name=d)) | ||
|
||
self.loader_train = dataloader.DataLoader( | ||
MyConcatDataset(datasets), | ||
batch_size=args.batch_size, | ||
shuffle=True, | ||
pin_memory=not args.cpu, | ||
num_workers=args.n_threads, | ||
) | ||
|
||
self.loader_test = [] | ||
for d in args.data_test: | ||
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109','CBSD68','Rain100L','GOPRO_Large']: | ||
m = import_module('data.benchmark') | ||
testset = getattr(m, 'Benchmark')(args, train=False, name=d) | ||
else: | ||
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' | ||
m = import_module('data.' + module_name.lower()) | ||
testset = getattr(m, module_name)(args, train=False, name=d) | ||
|
||
self.loader_test.append( | ||
dataloader.DataLoader( | ||
testset, | ||
batch_size=args.test_batch_size, | ||
shuffle=False, | ||
pin_memory=not args.cpu, | ||
num_workers=args.n_threads, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
import os | ||
|
||
from data import common | ||
from data import srdata | ||
|
||
import numpy as np | ||
|
||
import torch | ||
import torch.utils.data as data | ||
|
||
class Benchmark(srdata.SRData): | ||
def __init__(self, args, name='', train=True, benchmark=True): | ||
super(Benchmark, self).__init__( | ||
args, name=name, train=train, benchmark=True | ||
) | ||
|
||
def _set_filesystem(self, dir_data): | ||
self.apath = os.path.join(dir_data, 'benchmark', self.name) | ||
self.dir_hr = os.path.join(self.apath, 'HR') | ||
if self.input_large: | ||
self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') | ||
else: | ||
self.dir_lr = os.path.join(self.apath, 'LR_bicubic') | ||
self.ext = ('', '.png') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
import random | ||
|
||
import numpy as np | ||
import skimage.color as sc | ||
|
||
import torch | ||
|
||
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): | ||
ih, iw = args[0].shape[:2] | ||
|
||
tp = patch_size | ||
ip = tp // scale | ||
|
||
ix = random.randrange(0, iw - ip + 1) | ||
iy = random.randrange(0, ih - ip + 1) | ||
|
||
if not input_large: | ||
tx, ty = scale * ix, scale * iy | ||
else: | ||
tx, ty = ix, iy | ||
|
||
ret = [ | ||
args[0][iy:iy + ip, ix:ix + ip, :], | ||
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] | ||
] | ||
|
||
return ret | ||
|
||
def set_channel(*args, n_channels=3): | ||
def _set_channel(img): | ||
if img.ndim == 2: | ||
img = np.expand_dims(img, axis=2) | ||
|
||
c = img.shape[2] | ||
if n_channels == 1 and c == 3: | ||
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) | ||
elif n_channels == 3 and c == 1: | ||
img = np.concatenate([img] * n_channels, 2) | ||
|
||
return img[:,:,:n_channels] | ||
|
||
return [_set_channel(a) for a in args] | ||
|
||
def np2Tensor(*args, rgb_range=255): | ||
def _np2Tensor(img): | ||
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) | ||
tensor = torch.from_numpy(np_transpose).float() | ||
tensor.mul_(rgb_range / 255) | ||
|
||
return tensor | ||
|
||
return [_np2Tensor(a) for a in args] | ||
|
||
def augment(*args, hflip=True, rot=True): | ||
hflip = hflip and random.random() < 0.5 | ||
vflip = rot and random.random() < 0.5 | ||
rot90 = rot and random.random() < 0.5 | ||
|
||
def _augment(img): | ||
if hflip: img = img[:, ::-1, :] | ||
if vflip: img = img[::-1, :, :] | ||
if rot90: img = img.transpose(1, 0, 2) | ||
|
||
return img | ||
|
||
return [_augment(a) for a in args] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
import os | ||
|
||
from data import common | ||
|
||
import numpy as np | ||
import imageio | ||
|
||
import torch | ||
import torch.utils.data as data | ||
|
||
class Demo(data.Dataset): | ||
def __init__(self, args, name='Demo', train=False, benchmark=False): | ||
self.args = args | ||
self.name = name | ||
self.scale = args.scale | ||
self.idx_scale = 0 | ||
self.train = False | ||
self.benchmark = benchmark | ||
|
||
self.filelist = [] | ||
for f in os.listdir(args.dir_demo): | ||
if f.find('.png') >= 0 or f.find('.jp') >= 0: | ||
self.filelist.append(os.path.join(args.dir_demo, f)) | ||
self.filelist.sort() | ||
|
||
def __getitem__(self, idx): | ||
filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] | ||
lr = imageio.imread(self.filelist[idx]) | ||
lr, = common.set_channel(lr, n_channels=self.args.n_colors) | ||
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) | ||
|
||
return lr_t, -1, filename | ||
|
||
def __len__(self): | ||
return len(self.filelist) | ||
|
||
def set_scale(self, idx_scale): | ||
self.idx_scale = idx_scale | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
import os | ||
from data import srdata | ||
|
||
class DIV2K(srdata.SRData): | ||
def __init__(self, args, name='DIV2K', train=True, benchmark=False): | ||
data_range = [r.split('-') for r in args.data_range.split('/')] | ||
if train: | ||
data_range = data_range[0] | ||
else: | ||
if args.test_only and len(data_range) == 1: | ||
data_range = data_range[0] | ||
else: | ||
data_range = data_range[1] | ||
|
||
self.begin, self.end = list(map(lambda x: int(x), data_range)) | ||
super(DIV2K, self).__init__( | ||
args, name=name, train=train, benchmark=benchmark | ||
) | ||
|
||
def _scan(self): | ||
names_hr, names_lr = super(DIV2K, self)._scan() | ||
names_hr = names_hr[self.begin - 1:self.end] | ||
names_lr = [n[self.begin - 1:self.end] for n in names_lr] | ||
|
||
return names_hr, names_lr | ||
|
||
def _set_filesystem(self, dir_data): | ||
super(DIV2K, self)._set_filesystem(dir_data) | ||
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') | ||
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') | ||
#if self.input_large: self.dir_lr += 'L' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
import os | ||
from data import srdata | ||
from data import div2k | ||
|
||
class DIV2KJPEG(div2k.DIV2K): | ||
def __init__(self, args, name='', train=True, benchmark=False): | ||
self.q_factor = int(name.replace('DIV2K-Q', '')) | ||
super(DIV2KJPEG, self).__init__( | ||
args, name=name, train=train, benchmark=benchmark | ||
) | ||
|
||
def _set_filesystem(self, dir_data): | ||
self.apath = os.path.join(dir_data, 'DIV2K') | ||
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') | ||
self.dir_lr = os.path.join( | ||
self.apath, 'DIV2K_Q{}'.format(self.q_factor) | ||
) | ||
if self.input_large: self.dir_lr += 'L' | ||
self.ext = ('.png', '.jpg') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# 2021.05.07-Changed for IPT | ||
# Huawei Technologies Co., Ltd. <[email protected]> | ||
|
||
from data import srdata | ||
|
||
class SR291(srdata.SRData): | ||
def __init__(self, args, name='SR291', train=True, benchmark=False): | ||
super(SR291, self).__init__(args, name=name) | ||
|
Oops, something went wrong.