Skip to content

Commit

Permalink
waifu2x: Support art_scan model(--style scan) in cli
Browse files (browse the repository at this point in the history)
  • Loading branch information
nagadomi committed May 4, 2023
1 parent 0e01f70 commit cd3e4ba
Showing 1 changed file with 33 additions and 21 deletions.
54 changes: 33 additions & 21 deletions waifu2x/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@
path.join(path.dirname(path.abspath(__file__)), "pretrained_models"),
"swin_unet", "art"))

# Absolute path of the default "art_scan" model directory:
# <directory of this file>/pretrained_models/swin_unet/art_scan
DEFAULT_ART_SCAN_MODEL_DIR = path.abspath(
    path.join(path.dirname(path.abspath(__file__)),
              "pretrained_models", "swin_unet", "art_scan"))

# Absolute path of the default "photo" model directory:
# <directory of this file>/pretrained_models/swin_unet/photo
DEFAULT_PHOTO_MODEL_DIR = path.abspath(
    path.join(path.dirname(path.abspath(__file__)),
              "pretrained_models", "swin_unet", "photo"))


def antialias(x):
    """Resize an image to twice its size and back to the original size.

    The first pass enlarges to (2 * height, 2 * width) with bilinear
    interpolation; the second pass returns to (height, width) with bicubic
    interpolation. Both passes call ``TF.resize`` with ``antialias=True``.

    Args:
        x: image exposing ``.size`` as a ``(width, height)`` pair
            (presumably a PIL image -- confirm against the loader).

    Returns:
        The result of the second ``TF.resize`` call, at the input's size.
    """
    width, height = x.size
    # TF.resize takes the target size in (height, width) order.
    enlarged = TF.resize(x, (height * 2, width * 2),
                         interpolation=InterpolationMode.BILINEAR, antialias=True)
    return TF.resize(enlarged, (height, width),
                     interpolation=InterpolationMode.BICUBIC, antialias=True)


def convert_files(ctx, files, args, enable_amp):
loader = ImageLoader(files=files, max_queue_size=128,
load_func=IL.load_image,
Expand All @@ -42,8 +39,6 @@ def convert_files(ctx, files, args, enable_amp):
futures = []
with torch.no_grad(), PoolExecutor(max_workers=cpu_count() // 2 or 1) as pool:
for im, meta in tqdm(loader, ncols=60):
if args.pre_antialias:
im = antialias(im)
rgb, alpha = IL.to_tensor(im, return_alpha=True)
rgb, alpha = ctx.convert(
rgb, alpha, args.method, args.noise_level,
Expand All @@ -53,6 +48,8 @@ def convert_files(ctx, files, args, enable_amp):
if args.depth is not None:
meta["depth"] = args.depth
depth = meta["depth"] if "depth" in meta and meta["depth"] is not None else 8
if args.grayscale:
meta["grayscale"] = True
futures.append(pool.submit(
IL.save_image,
IL.to_image(rgb, alpha, depth=depth),
Expand All @@ -71,15 +68,15 @@ def convert_file(ctx, args, enable_amp):

with torch.no_grad():
im, meta = IL.load_image(args.input, color="rgb", keep_alpha=True)
if args.pre_antialias:
im = antialias(im)
rgb, alpha = IL.to_tensor(im, return_alpha=True)
rgb, alpha = ctx.convert(rgb, alpha, args.method, args.noise_level,
args.tile_size, args.batch_size,
args.tta, enable_amp=enable_amp)
if args.depth is not None:
meta["depth"] = args.depth
depth = meta["depth"] if "depth" in meta and meta["depth"] is not None else 8
if args.grayscale:
meta["grayscale"] = True
IL.save_image(IL.to_image(rgb, alpha, depth=depth),
filename=args.output, meta=meta,
format=fmt)
Expand All @@ -98,6 +95,8 @@ def main(args):
if args.model_dir is None:
if args.style == "photo":
model_dir = DEFAULT_PHOTO_MODEL_DIR
elif args.style in {"scan", "art_scan"}:
model_dir = DEFAULT_ART_SCAN_MODEL_DIR
else:
model_dir = DEFAULT_ART_MODEL_DIR
else:
Expand All @@ -120,21 +119,34 @@ def main(args):
parser.add_argument("--model-dir", type=str, help="model dir")
parser.add_argument("--noise-level", "-n", type=int, default=0, choices=[0, 1, 2, 3], help="noise level")
parser.add_argument("--method", "-m", type=str,
choices=["scale4x", "scale", "noise", "noise_scale", "noise_scale4x", "scale2x", "noise_scale2x"],
choices=["scale4x", "scale2x",
"noise_scale4x", "noise_scale2x",
"scale", "noise", "noise_scale"],
default="noise_scale", help="method")
parser.add_argument("--gpu", "-g", type=int, nargs="+", default=[0], help="GPU device ids. -1 for CPU")
parser.add_argument("--batch-size", type=int, default=4, help="minibatch_size")
parser.add_argument("--tile-size", type=int, default=256, help="tile size for tiled render")
parser.add_argument("--output", "-o", type=str, required=True, help="output file or directory")
parser.add_argument("--input", "-i", type=str, required=True, help="input file or directory. (*.txt, *.csv) for image list")
parser.add_argument("--gpu", "-g", type=int, nargs="+", default=[0],
help="GPU device ids. -1 for CPU")
parser.add_argument("--batch-size", type=int, default=4,
help="minibatch_size")
parser.add_argument("--tile-size", type=int, default=256,
help="tile size for tiled render")
parser.add_argument("--output", "-o", type=str, required=True,
help="output file or directory")
parser.add_argument("--input", "-i", type=str, required=True,
help="input file or directory. (*.txt, *.csv) for image list")
parser.add_argument("--tta", action="store_true", help="use TTA mode")
parser.add_argument("--disable-amp", action="store_true", help="disable AMP for some special reason")
parser.add_argument("--image-lib", type=str, choices=["pil", "wand"], default="pil",
help="image library to encode/decode images")
parser.add_argument("--depth", type=int, help="bit-depth of output image. enabled only with `--image-lib wand`")
parser.add_argument("--format", "-f", type=str, default="png", choices=["png", "webp", "jpeg"], help="output image format")
parser.add_argument("--pre-antialias", action="store_true", help="Removing sharp artifacts before run.")
parser.add_argument("--style", type=str, choices=["art", "photo"], help="style for default model (art/photo). Ignored when --model-dir option is specified.")
parser.add_argument("--depth", type=int,
help="bit-depth of output image. enabled only with `--image-lib wand`")
parser.add_argument("--format", "-f", type=str, default="png", choices=["png", "webp", "jpeg"],
help="output image format")
parser.add_argument("--style", type=str, choices=["art", "photo", "scan", "art_scan"],
help=("style for default model (art/scan/photo). "
"Ignored when --model-dir option is specified."))
parser.add_argument("--grayscale", action="store_true",
help="Convert to grayscale format")

args = parser.parse_args()
logger.debug(f"waifu2x.cli.main: {str(args)}")
if args.image_lib == "wand":
Expand Down

0 comments on commit cd3e4ba

Please sign in to comment.