
Commit

bugfix
Moebits committed Feb 16, 2024
1 parent 81c75e1 commit 874f4d0
Showing 7 changed files with 1,354 additions and 6 deletions.
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
 {
   "name": "waifu2x",
-  "version": "1.3.5",
+  "version": "1.3.6",
   "description": "2x upscaling of images with waifu2x",
   "main": "dist/waifu2x.js",
   "types": "dist/waifu2x.d.ts",
256 changes: 256 additions & 0 deletions scripts/RRDB.py
@@ -0,0 +1,256 @@
import functools
import math
import re
from collections import OrderedDict
import torch
import torch.nn as nn
import block as B
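# `block` refers to the companion block.py module expected next to this script,
# providing conv_block, RRDB, ShortcutBlock, upconv_block, pixelshuffle_block, and sequential.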


# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py,
# which enhanced the code that was already here
class RRDBNet(nn.Module):
    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: str = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.
        This is the old-arch Residual in Residual Dense Block Network and is not
        the newest revision available at github.com/xinntao/ESRGAN.
        This is on purpose; the newest network has severely limited the
        potential uses of the network with no benefits.
        This network supports model files from both the new and the old arch.
        Args:
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode
        """
        super(RRDBNet, self).__init__()

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch models
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            "model.3.weight": ("upconv1.weight", "conv_up1.weight"),
            "model.3.bias": ("upconv1.bias", "conv_up1.bias"),
            "model.6.weight": ("upconv2.weight", "conv_up2.weight"),
            "model.6.bias": ("upconv2.bias", "conv_up2.bias"),
            "model.8.weight": ("HRconv.weight", "conv_hr.weight"),
            "model.8.bias": ("HRconv.bias", "conv_hr.bias"),
            "model.10.weight": ("conv_last.weight",),
            "model.10.bias": ("conv_last.bias",),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
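        # /NB/ is substituted with the detected block count in new_to_old_arch();
        # the regex entry maps the per-block conv weights between the two key schemes.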
if "params_ema" in self.state:
self.state = self.state["params_ema"]
self.num_blocks = self.get_num_blocks()

self.plus = any("conv1x1" in k for k in self.state.keys())

self.state = self.new_to_old_arch(self.state)

self.key_arr = list(self.state.keys())
# print(self.key_arr)

self.in_nc = self.state[self.key_arr[0]].shape[1]
self.out_nc = self.state[self.key_arr[-1]].shape[0]

self.scale = self.get_scale()

self.num_filters = self.state[self.key_arr[0]].shape[0]

c2x2 = False
if self.state["model.0.weight"].shape[-2] == 2:
c2x2 = True
self.scale = math.ceil(self.scale ** (1.0 / 3))

# Detect if pixelunshuffle was used (Real-ESRGAN)
if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
self.in_nc / 4,
self.in_nc / 16,
):
self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
else:
self.shuffle_factor = None

upsample_block = {
"upconv": B.upconv_block,
"pixel_shuffle": B.pixelshuffle_block,
}.get(self.upsampler)
if upsample_block is None:
raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

if self.scale == 3:
upsample_blocks = upsample_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
upscale_factor=3,
act_type=self.act,
c2x2=c2x2,
)
else:
upsample_blocks = [
upsample_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
act_type=self.act,
c2x2=c2x2,
)
for _ in range(int(math.log(self.scale, 2)))
]

self.model = B.sequential(
# fea conv
B.conv_block(
in_nc=self.in_nc,
out_nc=self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
B.ShortcutBlock(
B.sequential(
# rrdb blocks
*[
B.RRDB(
nf=self.num_filters,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=self.norm,
act_type=self.act,
mode="CNA",
plus=self.plus,
c2x2=c2x2,
)
for _ in range(self.num_blocks)
],
# lr conv
B.conv_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
kernel_size=3,
norm_type=self.norm,
act_type=None,
mode=self.mode,
c2x2=c2x2,
),
)
),
*upsample_blocks,
# hr_conv0
B.conv_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
kernel_size=3,
norm_type=None,
act_type=self.act,
c2x2=c2x2,
),
# hr_conv1
B.conv_block(
in_nc=self.num_filters,
out_nc=self.out_nc,
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
)

self.load_state_dict(self.state, strict=False)

def new_to_old_arch(self, state):
"""Convert a new-arch model state dictionary to an old-arch dictionary."""
if "params_ema" in state:
state = state["params_ema"]

if "conv_first.weight" not in state:
# model is already old arch, this is a loose check, but should be sufficient
return state

# add nb to state keys
for kind in ("weight", "bias"):
self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
f"model.1.sub./NB/.{kind}"
]
del self.state_map[f"model.1.sub./NB/.{kind}"]

old_state = OrderedDict()
for old_key, new_keys in self.state_map.items():
for new_key in new_keys:
if r"\1" in old_key:
for k, v in state.items():
sub = re.sub(new_key, old_key, k)
if sub != k:
old_state[sub] = v
else:
if new_key in state:
old_state[old_key] = state[new_key]

# Sort by first numeric value of each layer
def compare(item1, item2):
parts1 = item1.split(".")
parts2 = item2.split(".")
int1 = int(parts1[1])
int2 = int(parts2[1])
return int1 - int2

sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))

# Rebuild the output dict in the right order
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)

return out_dict

def get_scale(self, min_part: int = 6) -> int:
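        # Every extra 2x upsampling stage pushes one more conv weight ("model.N.weight",
        # N > min_part) into the flattened model, so the overall scale works out to 2 ** n.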
        n = 0
        for part in list(self.state):
            parts = part.split(".")[1:]
            if len(parts) == 2:
                part_num = int(parts[0])
                if part_num > min_part and parts[1] == "weight":
                    n += 1
        return 2**n

    def get_num_blocks(self) -> int:
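        # The RRDB block count is the highest block index found in the keys, plus one.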
        nbs = []
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for state_key in state_keys:
            for k in self.state:
                m = re.search(state_key, k)
                if m:
                    nbs.append(int(m.group(1)))
            if nbs:
                break
        return max(nbs) + 1

    def forward(self, x):
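        # Real-ESRGAN pixel-unshuffle models expect the input folded from spatial
        # resolution into channels before it enters the network.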
        if self.shuffle_factor:
            x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
        return self.model(x)