From 9a4d6e264794d1a7157d048c23edf6f226110281 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Wed, 8 Feb 2023 23:27:55 +0900 Subject: [PATCH] waifu2x: Add a option to use only bicubic in train --- waifu2x/training/dataset.py | 10 +++++++++- waifu2x/training/trainer.py | 3 +++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/waifu2x/training/dataset.py b/waifu2x/training/dataset.py index fd332b52..6621bff6 100644 --- a/waifu2x/training/dataset.py +++ b/waifu2x/training/dataset.py @@ -29,6 +29,7 @@ # "vision.bicubic_no_antialias", ) INTERPOLATION_NEAREST = "box" +INTERPOLATION_BICUBIC = "catrom" #INTERPOLATION_MODE_WEIGHTS = (1/3, 1/3, 1/6, 1/16, 1/3, 1/12) # noqa: E226 INTERPOLATION_MODE_WEIGHTS = (1/3, 1/3, 1/6, 1/16, 1/3) # noqa: E226 @@ -72,6 +73,7 @@ def __call__(self, x, y): interpolation = random.choices(INTERPOLATION_MODES, weights=INTERPOLATION_MODE_WEIGHTS, k=1)[0] else: interpolation = self.interpolation + if self.scale_factor == 2: if not self.training: blur = 1 + self.blur_shift / 4 @@ -156,6 +158,7 @@ def __init__(self, input_dir, scale_factor, tile_size, num_samples=None, da_jpeg_p=0, da_scale_p=0, da_chshuf_p=0, da_unsharpmask_p=0, da_grayscale_p=0, + bicubic_only=False, deblur=0, resize_blur_p=0.1, noise_level=-1, style=None, training=True): @@ -173,7 +176,12 @@ def __init__(self, input_dir, else: jpeg_transform = TP.Identity() if scale_factor > 1: + if bicubic_only: + interpolation = INTERPOLATION_BICUBIC + else: + interpolation = None # random random_downscale_x = RandomDownscaleX(scale_factor=scale_factor, + interpolation=interpolation, blur_shift=deblur, resize_blur_p=resize_blur_p) random_downscale_x_nearest = RandomDownscaleX(scale_factor=scale_factor, interpolation=INTERPOLATION_NEAREST) @@ -208,7 +216,7 @@ def __init__(self, input_dir, ]) else: self.gt_transforms = TS.Identity() - interpolation = "catrom" + interpolation = INTERPOLATION_BICUBIC if scale_factor > 1: downscale_x = RandomDownscaleX(scale_factor=scale_factor, blur_shift=deblur, diff --git a/waifu2x/training/trainer.py b/waifu2x/training/trainer.py index 8d1393aa..b4b29b26 100644 --- a/waifu2x/training/trainer.py +++ b/waifu2x/training/trainer.py @@ -58,6 +58,7 @@ def create_dataloader(self, type): input_dir=path.join(self.args.data_dir, "train"), model_offset=model_offset, scale_factor=scale_factor, + bicubic_only=self.args.b4b, style=self.args.style, noise_level=self.args.noise_level, tile_size=self.args.size, @@ -240,6 +241,8 @@ def register(subparsers, default_parser): parser.add_argument("--hard-example", type=str, default="linear", choices=["none", "linear", "top10", "top20"], help="hard example mining for training data sampleing") + parser.add_argument("--b4b", action="store_true", + help="use only bicubic downsampling for bicubic downsampling restoration") parser.set_defaults( batch_size=16,