Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
HantingChen authored Jun 27, 2021
1 parent 3d1c62c commit 32ce841
Show file tree
Hide file tree
Showing 12 changed files with 1,334 additions and 0 deletions.
56 changes: 56 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>


from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset

# This is a simple wrapper function for ConcatDataset
class MyConcatDataset(ConcatDataset):
def __init__(self, datasets):
super(MyConcatDataset, self).__init__(datasets)
self.train = datasets[0].train

def set_scale(self, idx_scale):
for d in self.datasets:
if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
def __init__(self, args):
self.loader_train = None
if not args.test_only:
datasets = []
for d in args.data_train:
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
m = import_module('data.' + module_name.lower())
datasets.append(getattr(m, module_name)(args, name=d))

self.loader_train = dataloader.DataLoader(
MyConcatDataset(datasets),
batch_size=args.batch_size,
shuffle=True,
pin_memory=not args.cpu,
num_workers=args.n_threads,
)

self.loader_test = []
for d in args.data_test:
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109','CBSD68','Rain100L','GOPRO_Large']:
m = import_module('data.benchmark')
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
else:
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
m = import_module('data.' + module_name.lower())
testset = getattr(m, module_name)(args, train=False, name=d)

self.loader_test.append(
dataloader.DataLoader(
testset,
batch_size=args.test_batch_size,
shuffle=False,
pin_memory=not args.cpu,
num_workers=args.n_threads,
)
)
28 changes: 28 additions & 0 deletions data/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>

import os

from data import common
from data import srdata

import numpy as np

import torch
import torch.utils.data as data

class Benchmark(srdata.SRData):
def __init__(self, args, name='', train=True, benchmark=True):
super(Benchmark, self).__init__(
args, name=name, train=train, benchmark=True
)

def _set_filesystem(self, dir_data):
self.apath = os.path.join(dir_data, 'benchmark', self.name)
self.dir_hr = os.path.join(self.apath, 'HR')
if self.input_large:
self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
else:
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
self.ext = ('', '.png')

70 changes: 70 additions & 0 deletions data/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>

import random

import numpy as np
import skimage.color as sc

import torch

def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
ih, iw = args[0].shape[:2]

tp = patch_size
ip = tp // scale

ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)

if not input_large:
tx, ty = scale * ix, scale * iy
else:
tx, ty = ix, iy

ret = [
args[0][iy:iy + ip, ix:ix + ip, :],
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
]

return ret

def set_channel(*args, n_channels=3):
def _set_channel(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)

c = img.shape[2]
if n_channels == 1 and c == 3:
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
elif n_channels == 3 and c == 1:
img = np.concatenate([img] * n_channels, 2)

return img[:,:,:n_channels]

return [_set_channel(a) for a in args]

def np2Tensor(*args, rgb_range=255):
def _np2Tensor(img):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
tensor = torch.from_numpy(np_transpose).float()
tensor.mul_(rgb_range / 255)

return tensor

return [_np2Tensor(a) for a in args]

def augment(*args, hflip=True, rot=True):
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5

def _augment(img):
if hflip: img = img[:, ::-1, :]
if vflip: img = img[::-1, :, :]
if rot90: img = img.transpose(1, 0, 2)

return img

return [_augment(a) for a in args]

42 changes: 42 additions & 0 deletions data/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>

import os

from data import common

import numpy as np
import imageio

import torch
import torch.utils.data as data

class Demo(data.Dataset):
def __init__(self, args, name='Demo', train=False, benchmark=False):
self.args = args
self.name = name
self.scale = args.scale
self.idx_scale = 0
self.train = False
self.benchmark = benchmark

self.filelist = []
for f in os.listdir(args.dir_demo):
if f.find('.png') >= 0 or f.find('.jp') >= 0:
self.filelist.append(os.path.join(args.dir_demo, f))
self.filelist.sort()

def __getitem__(self, idx):
filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
lr = imageio.imread(self.filelist[idx])
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

return lr_t, -1, filename

def __len__(self):
return len(self.filelist)

def set_scale(self, idx_scale):
self.idx_scale = idx_scale

35 changes: 35 additions & 0 deletions data/div2k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>

import os
from data import srdata

class DIV2K(srdata.SRData):
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
data_range = [r.split('-') for r in args.data_range.split('/')]
if train:
data_range = data_range[0]
else:
if args.test_only and len(data_range) == 1:
data_range = data_range[0]
else:
data_range = data_range[1]

self.begin, self.end = list(map(lambda x: int(x), data_range))
super(DIV2K, self).__init__(
args, name=name, train=train, benchmark=benchmark
)

def _scan(self):
names_hr, names_lr = super(DIV2K, self)._scan()
names_hr = names_hr[self.begin - 1:self.end]
names_lr = [n[self.begin - 1:self.end] for n in names_lr]

return names_hr, names_lr

def _set_filesystem(self, dir_data):
super(DIV2K, self)._set_filesystem(dir_data)
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
#if self.input_large: self.dir_lr += 'L'

23 changes: 23 additions & 0 deletions data/div2kjpeg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>

import os
from data import srdata
from data import div2k

class DIV2KJPEG(div2k.DIV2K):
def __init__(self, args, name='', train=True, benchmark=False):
self.q_factor = int(name.replace('DIV2K-Q', ''))
super(DIV2KJPEG, self).__init__(
args, name=name, train=train, benchmark=benchmark
)

def _set_filesystem(self, dir_data):
self.apath = os.path.join(dir_data, 'DIV2K')
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
self.dir_lr = os.path.join(
self.apath, 'DIV2K_Q{}'.format(self.q_factor)
)
if self.input_large: self.dir_lr += 'L'
self.ext = ('.png', '.jpg')

9 changes: 9 additions & 0 deletions data/sr291.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>

from data import srdata

class SR291(srdata.SRData):
def __init__(self, args, name='SR291', train=True, benchmark=False):
super(SR291, self).__init__(args, name=name)

Loading

0 comments on commit 32ce841

Please sign in to comment.