Skip to content

Commit

Permalink
waifu2x: Add a option to use only bicubic in train
Browse files Browse the repository at this point in the history
  • Loading branch information
nagadomi committed Feb 8, 2023
1 parent f8618d8 commit 9a4d6e2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
10 changes: 9 additions & 1 deletion waifu2x/training/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# "vision.bicubic_no_antialias",
)
INTERPOLATION_NEAREST = "box"
INTERPOLATION_BICUBIC = "catrom"
#INTERPOLATION_MODE_WEIGHTS = (1/3, 1/3, 1/6, 1/16, 1/3, 1/12) # noqa: E226
INTERPOLATION_MODE_WEIGHTS = (1/3, 1/3, 1/6, 1/16, 1/3) # noqa: E226

Expand Down Expand Up @@ -72,6 +73,7 @@ def __call__(self, x, y):
interpolation = random.choices(INTERPOLATION_MODES, weights=INTERPOLATION_MODE_WEIGHTS, k=1)[0]
else:
interpolation = self.interpolation

if self.scale_factor == 2:
if not self.training:
blur = 1 + self.blur_shift / 4
Expand Down Expand Up @@ -156,6 +158,7 @@ def __init__(self, input_dir,
scale_factor,
tile_size, num_samples=None,
da_jpeg_p=0, da_scale_p=0, da_chshuf_p=0, da_unsharpmask_p=0, da_grayscale_p=0,
bicubic_only=False,
deblur=0, resize_blur_p=0.1,
noise_level=-1, style=None,
training=True):
Expand All @@ -173,7 +176,12 @@ def __init__(self, input_dir,
else:
jpeg_transform = TP.Identity()
if scale_factor > 1:
if bicubic_only:
interpolation = INTERPOLATION_BICUBIC
else:
interpolation = None # random
random_downscale_x = RandomDownscaleX(scale_factor=scale_factor,
interpolation=interpolation,
blur_shift=deblur, resize_blur_p=resize_blur_p)
random_downscale_x_nearest = RandomDownscaleX(scale_factor=scale_factor,
interpolation=INTERPOLATION_NEAREST)
Expand Down Expand Up @@ -208,7 +216,7 @@ def __init__(self, input_dir,
])
else:
self.gt_transforms = TS.Identity()
interpolation = "catrom"
interpolation = INTERPOLATION_BICUBIC
if scale_factor > 1:
downscale_x = RandomDownscaleX(scale_factor=scale_factor,
blur_shift=deblur,
Expand Down
3 changes: 3 additions & 0 deletions waifu2x/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def create_dataloader(self, type):
input_dir=path.join(self.args.data_dir, "train"),
model_offset=model_offset,
scale_factor=scale_factor,
bicubic_only=self.args.b4b,
style=self.args.style,
noise_level=self.args.noise_level,
tile_size=self.args.size,
Expand Down Expand Up @@ -240,6 +241,8 @@ def register(subparsers, default_parser):
parser.add_argument("--hard-example", type=str, default="linear",
choices=["none", "linear", "top10", "top20"],
help="hard example mining for training data sampleing")
parser.add_argument("--b4b", action="store_true",
help="use only bicubic downsampling for bicubic downsampling restoration")

parser.set_defaults(
batch_size=16,
Expand Down

0 comments on commit 9a4d6e2

Please sign in to comment.