Merge branch 'dev'
nagadomi committed Oct 8, 2024
2 parents 6171490 + 319fcea commit 23ad9a0
Showing 9 changed files with 284 additions and 160 deletions.
4 changes: 2 additions & 2 deletions iw3/gui.py
@@ -992,8 +992,8 @@ def parse_args(self):
if not validate_number(self.cbo_fps.GetValue(), 0.25, 1000.0, allow_empty=False):
self.show_validation_error_message(T("Max FPS"), 0.25, 1000.0)
return None
if not validate_number(self.cbo_crf.GetValue(), 0, 30, is_int=True):
self.show_validation_error_message(T("CRF"), 0, 30)
if not validate_number(self.cbo_crf.GetValue(), 0, 51, is_int=True):
self.show_validation_error_message(T("CRF"), 0, 51)
return None
if not validate_number(self.cbo_ema_decay.GetValue(), 0.1, 0.999):
self.show_validation_error_message(T("Flicker Reduction"), 0.1, 0.999)
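The new bounds cover the full 0–51 CRF range accepted by x264/x265-style encoders (lower is higher quality), instead of capping the GUI at 30. As a rough sketch of the kind of check being widened here — this standalone validate_number is an assumption for illustration, not the repository's implementation:

```python
def validate_number(value, min_value, max_value, is_int=False, allow_empty=False):
    # Return True if `value` parses as a number within [min_value, max_value].
    if value is None or value == "":
        return allow_empty
    try:
        num = int(value) if is_int else float(value)
    except ValueError:
        return False
    return min_value <= num <= max_value


assert validate_number("23", 0, 51, is_int=True)      # typical CRF value passes
assert not validate_number("52", 0, 51, is_int=True)  # still rejected as out of range
```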
24 changes: 24 additions & 0 deletions nunif/modules/permute.py
@@ -145,6 +145,21 @@ def window_partition2d(x, window_size):
return x


def window_reverse2d(x, out_shape, window_size):
# reverse of window_partition2d
# x: B, N, C, SH, SW
OB, OC, OH, OW = out_shape
assert OC == x.shape[2]
SH, SW = window_size if isinstance(window_size, (list, tuple)) else [window_size, window_size]
assert OH % SH == 0 and OW % SW == 0
H = OH // SH
W = OW // SW
x = x.reshape(OB, H, W, OC, SH, SW)
x = x.permute(0, 3, 1, 4, 2, 5)
x = x.reshape(OB, OC, OH, OW)
return x


def _test_bhwc():
src = x = torch.rand((4, 3, 2, 2))
x = bchw_to_bhwc(x)
@@ -197,7 +212,16 @@ def _test_bnc():
print("pass _test_bnc")


def _test_window():
x = torch.rand((4, 3, 6, 6))
y = window_partition2d(x, window_size=2)
z = window_reverse2d(y, x.shape, window_size=2)
assert x.shape == z.shape
assert (x - z).abs().sum() == 0


if __name__ == "__main__":
_test_bhwc()
_test_pixel_shuffle()
_test_bnc()
_test_window()
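window_reverse2d undoes window_partition2d exactly, so the new _test_window check can demand a bit-identical roundtrip. A self-contained sketch of the pair — the window_partition2d below is inferred from the reverse path and the B, N, C, SH, SW layout, not copied from the repository:

```python
import torch


def window_partition2d(x, window_size):
    # (B, C, H, W) -> (B, num_windows, C, SH, SW), non-overlapping windows
    B, C, H, W = x.shape
    SH, SW = (window_size, window_size) if isinstance(window_size, int) else window_size
    x = x.reshape(B, C, H // SH, SH, W // SW, SW)
    x = x.permute(0, 2, 4, 1, 3, 5)  # B, H//SH, W//SW, C, SH, SW
    return x.reshape(B, (H // SH) * (W // SW), C, SH, SW)


def window_reverse2d(x, out_shape, window_size):
    # inverse of window_partition2d (same steps as the function added above)
    OB, OC, OH, OW = out_shape
    SH, SW = (window_size, window_size) if isinstance(window_size, int) else window_size
    x = x.reshape(OB, OH // SH, OW // SW, OC, SH, SW)
    x = x.permute(0, 3, 1, 4, 2, 5)  # B, C, H//SH, SH, W//SW, SW
    return x.reshape(OB, OC, OH, OW)


x = torch.rand((4, 3, 6, 6))
y = window_partition2d(x, window_size=2)        # (4, 9, 3, 2, 2)
z = window_reverse2d(y, x.shape, window_size=2)
assert torch.equal(x, z)                        # exact roundtrip, no interpolation involved
```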
2 changes: 1 addition & 1 deletion nunif/modules/reflection_pad2d.py
@@ -28,7 +28,7 @@ def reflection_pad2d_naive(x, padding, detach=False):
elif bottom < 0:
x = x[:, :, :bottom, :]

return x
return x.contiguous()


class ReflectionPad2dNaive(nn.Module):
2 changes: 1 addition & 1 deletion nunif/modules/replication_pad2d.py
@@ -42,7 +42,7 @@ def replication_pad2d_naive(x, padding, detach=False):
elif bottom < 0:
x = x[:, :, :bottom, :]

return x
return x.contiguous()


class ReplicationPad2dNaive(nn.Module):
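Both naive pad implementations now return a contiguous tensor. Negative padding in these functions is done by slicing (see the elif bottom < 0 branches above), and slicing returns a non-contiguous view; a quick illustration of why a caller may need the compaction — plain PyTorch behavior, not repository code:

```python
import torch

x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
cropped = x[:, :, 1:-1, 1:-1]             # the kind of slicing used for negative padding
print(cropped.is_contiguous())            # False: still a strided view of x's storage
flat = cropped.contiguous().view(2, -1)   # .view() would fail without .contiguous()
print(flat.shape)                         # torch.Size([2, 12])
```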
162 changes: 107 additions & 55 deletions nunif/modules/transforms.py
@@ -3,25 +3,26 @@
import torch.nn.functional as F
import math
from functools import lru_cache

from . replication_pad2d import replication_pad2d_naive
from . reflection_pad2d import reflection_pad2d_naive

# differentiable transforms for loss function


def rotate_grid(batch, height, width, angle, device):
def rotate_grid(batch, height, width, angle, device, dtype):
with torch.no_grad():
angle = math.radians(angle)
py, px = torch.meshgrid(torch.linspace(-1, 1, height, device=device),
torch.linspace(-1, 1, width, device=device), indexing="ij")
py, px = torch.meshgrid(torch.linspace(-1, 1, height, device=device, dtype=dtype),
torch.linspace(-1, 1, width, device=device, dtype=dtype), indexing="ij")
mesh_x = px * math.cos(angle) - py * math.sin(angle)
mesh_y = px * math.sin(angle) + py * math.cos(angle)
grid = torch.stack((mesh_x, mesh_y), 2).unsqueeze(0).repeat(batch, 1, 1, 1).detach()
grid = torch.stack((mesh_x, mesh_y), 2).unsqueeze(0).repeat(batch, 1, 1, 1).contiguous().detach()
return grid


@lru_cache
def rotate_grid_cache(batch, height, width, angle, device):
return rotate_grid(batch, height, width, angle, device)
def rotate_grid_cache(batch, height, width, angle, device, dtype):
return rotate_grid(batch, height, width, angle, device, dtype)
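rotate_grid and its cached wrapper now take dtype explicitly, so the sampling grid is built in the same precision as the input (F.grid_sample generally requires the input and grid to share a dtype), and functools.lru_cache keeps separate entries per dtype because dtype becomes part of the cache key. A small check, assuming the repository is importable as nunif:

```python
import torch
from nunif.modules.transforms import rotate_grid_cache

g32 = rotate_grid_cache(1, 8, 8, 45, device=torch.device("cpu"), dtype=torch.float32)
g64 = rotate_grid_cache(1, 8, 8, 45, device=torch.device("cpu"), dtype=torch.float64)
assert g32.dtype == torch.float32 and g64.dtype == torch.float64
print(rotate_grid_cache.cache_info())  # two misses: one cached grid per dtype
```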


PAD_MODE_NN = {
@@ -31,24 +32,36 @@ def rotate_grid_cache(batch, height, width, angle, device):
}


def _pad(input, pad, mode="constant", value=0):
if mode == "reflect":
return reflection_pad2d_naive(input, pad, detach=True)
elif mode == "replicate":
return replication_pad2d_naive(input, pad, detach=True)
else:
return F.pad(input, pad, mode=mode, value=value).contiguous()


def diff_rotate(x, angle, mode="bilinear", padding_mode="zeros", align_corners=False, expand=False, cache=True):
if expand and padding_mode not in {"zeros", "constant"}:
raise ValueError(f"expand=True does not support padding_mode={padding_mode}")

# x: BCHW
B, _, H, W = x.shape
if expand:
pad_h = (int(2 ** 0.5 * H) - H) // 2 + 1
pad_w = (int(2 ** 0.5 * W) - W) // 2 + 1
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0)
B, _, H, W = x.shape
pad_h = (int(2 ** 0.5 * H) - H) // 2 + 1
pad_w = (int(2 ** 0.5 * W) - W) // 2 + 1
x = _pad(x, (pad_w, pad_w, pad_h, pad_h),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0)
B, _, H, W = x.shape

if cache:
grid = rotate_grid_cache(B, H, W, angle, x.device)
grid = rotate_grid_cache(B, H, W, angle, device=x.device, dtype=x.dtype)
else:
grid = rotate_grid(B, H, W, angle, x.device)

x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
grid = rotate_grid(B, H, W, angle, device=x.device, dtype=x.dtype)
x = F.grid_sample(x, grid, mode=mode, padding_mode="zeros", align_corners=align_corners)
if not expand:
x = F.pad(x, (-pad_w, -pad_w, -pad_h, -pad_h))

return x
return x.contiguous()
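diff_rotate stays differentiable end to end: the padding, the fixed (detached) sampling grid, and grid_sample are all autograd-friendly, so gradients from a loss on the rotated output reach the original input. A usage sketch, assuming the repository is importable as nunif:

```python
import torch
from nunif.modules.transforms import diff_rotate

x = torch.rand((2, 3, 64, 64), requires_grad=True)
y = diff_rotate(x, angle=30, mode="bilinear", padding_mode="zeros", expand=True)
loss = y.mean()
loss.backward()
print(y.shape)        # (2, 3, 92, 92): the canvas is expanded before rotating
print(x.grad.shape)   # gradients flow back to the un-rotated input: (2, 3, 64, 64)
```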


def diff_random_rotate(x, angle=45, mode="bilinear", padding_mode="zeros", align_corners=False, expand=False):
@@ -58,31 +71,41 @@ def diff_random_rotate(x, angle=45, mode="bilinear", padding_mode="zeros", align
return torch.cat([diff_rotate(x[i:i + 1, :, :, :], angle=angle[i].item(),
mode=mode, padding_mode=padding_mode,
align_corners=align_corners, expand=expand, cache=False)
for i in range(B)], dim=0)
for i in range(B)], dim=0).contiguous()


def diff_random_rotate_pair(x, y, angle=45, mode="bilinear", padding_mode="zeros", align_corners=False, expand=False):
assert x.shape[0] == y.shape[0]
B, _, H, W = x.shape
angle = (torch.rand((B,), device=x.device) * 2 - 1) * angle
xy = torch.stack((x, y), dim=1)
xys = []
for i in range(B):
xys.append(diff_rotate(
xy[i], angle=angle[i].item(),

if x.dtype == y.dtype:
xy = torch.stack((x, y), dim=1)
xys = []
for i in range(B):
xys.append(diff_rotate(
xy[i], angle=angle[i].item(),
mode=mode, padding_mode=padding_mode,
align_corners=align_corners, expand=expand, cache=False))
x = torch.stack([xyi[0] for xyi in xys], dim=0)
y = torch.stack([xyi[1] for xyi in xys], dim=0)
else:
x = torch.cat([diff_rotate(
x[i:i + 1, :, :, :], angle=angle[i].item(),
mode=mode, padding_mode=padding_mode,
align_corners=align_corners, expand=expand, cache=False) for i in range(B)], dim=0)
y = torch.cat([diff_rotate(
y[i:i + 1, :, :, :], angle=angle[i].item(),
mode=mode, padding_mode=padding_mode,
align_corners=align_corners, expand=expand, cache=False))
x = torch.stack([xyi[0] for xyi in xys], dim=0)
y = torch.stack([xyi[1] for xyi in xys], dim=0)
align_corners=align_corners, expand=expand, cache=False) for i in range(B)], dim=0)

return x, y
return x.contiguous(), y.contiguous()
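The new dtype branch appears to exist because torch.stack requires matching dtypes: when input and target differ (say a float32 image and a float64 target), they are now rotated separately but still by the same per-sample random angle. A usage sketch, assuming the nunif import path:

```python
import torch
from nunif.modules.transforms import diff_random_rotate_pair

img = torch.rand((4, 3, 32, 32))                       # float32
tgt = torch.rand((4, 3, 32, 32), dtype=torch.float64)  # different dtype
img_r, tgt_r = diff_random_rotate_pair(img, tgt, angle=45, expand=True)
print(img_r.shape, tgt_r.shape)  # same expanded spatial size, same angle per sample
```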


def diff_translate(x, x_shift, y_shift, padding_mode="zeros", expand_x=0, expand_y=0):
# NOTE: padded values with reflect or replicate have copied gradients.
# there may be cases where that is undesirable.
return F.pad(x, (x_shift + expand_x, -x_shift + expand_x,
y_shift + expand_y, -y_shift + expand_y),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0)
return _pad(x, (x_shift + expand_x, -x_shift + expand_x,
y_shift + expand_y, -y_shift + expand_y),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0).contiguous()
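diff_translate shifts by padding one side and negatively padding (cropping) the other, which preserves the spatial size and stays differentiable. A tiny numeric check, assuming the nunif import path and the default zero padding:

```python
import torch
from nunif.modules.transforms import diff_translate

x = torch.arange(16.0).reshape(1, 1, 4, 4)
y = diff_translate(x, x_shift=1, y_shift=0)  # shift content right by 1, zero-fill on the left
print(y[0, 0, 0])                            # tensor([0., 0., 1., 2.])
```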


def diff_random_translate(x, ratio=0.15, size=None, padding_mode="zeros", expand=False):
@@ -102,10 +125,11 @@ def diff_random_translate(x, ratio=0.15, size=None, padding_mode="zeros", expand
# FIXME: remove loop
return torch.cat([diff_translate(x[i:i + 1, :, :, :], x_shift=x_shift[i], y_shift=y_shift[i],
padding_mode=padding_mode,
expand_x=expand_x, expand_y=expand_y) for i in range(B)], dim=0)
expand_x=expand_x, expand_y=expand_y) for i in range(B)], dim=0).contiguous()


def diff_random_translate_pair(x, y, ratio=0.15, size=None, padding_mode="zeros", expand=False):
assert x.shape[0] == y.shape[0]
B, _, H, W = x.shape
if size is not None:
x_shift = torch.randint(low=-size, high=size + 1, size=(B,), device=x.device)
@@ -119,17 +143,27 @@ def diff_random_translate_pair(x, y, ratio=0.15, size=None, padding_mode="zeros"
else:
expand_x = expand_y = 0

xy = torch.stack((x, y), dim=1)
xys = []
for i in range(B):
xys.append(diff_translate(
xy[i], x_shift=x_shift[i], y_shift=y_shift[i],
if x.dtype == y.dtype:
xy = torch.stack((x, y), dim=1)
xys = []
for i in range(B):
xys.append(diff_translate(
xy[i], x_shift=x_shift[i], y_shift=y_shift[i],
padding_mode=padding_mode,
expand_x=expand_x, expand_y=expand_y))
x = torch.stack([xyi[0] for xyi in xys], dim=0)
y = torch.stack([xyi[1] for xyi in xys], dim=0)
else:
x = torch.cat([diff_translate(
x[i:i + 1, :, :, :], x_shift=x_shift[i], y_shift=y_shift[i],
padding_mode=padding_mode,
expand_x=expand_x, expand_y=expand_y) for i in range(B)], dim=0)
y = torch.cat([diff_translate(
y[i:i + 1, :, :, :], x_shift=x_shift[i], y_shift=y_shift[i],
padding_mode=padding_mode,
expand_x=expand_x, expand_y=expand_y))
x = torch.stack([xyi[0] for xyi in xys], dim=0)
y = torch.stack([xyi[1] for xyi in xys], dim=0)
expand_x=expand_x, expand_y=expand_y) for i in range(B)], dim=0)

return x, y
return x.contiguous(), y.contiguous()


class DiffPairRandomTranslate(nn.Module):
@@ -145,13 +179,14 @@ def __init__(self, ratio=0.15, size=None, padding_mode="zeros", expand=False, in
def expand_pad(input, target, ratio=0.15, size=None, padding_mode="zeros"):
size = size if size else int(input.shape[2:] * ratio)
expand_x = expand_y = size
input = F.pad(input, (expand_x, expand_x, expand_y, expand_y),
mode=PAD_MODE_NN.get(padding_mode, padding_mode))
target = F.pad(target, (expand_x, expand_x, expand_y, expand_y),
mode=PAD_MODE_NN.get(padding_mode, padding_mode))
return input, target
padding_mode = PAD_MODE_NN.get(padding_mode, padding_mode)
pad = (expand_x, expand_x, expand_y, expand_y)
input = _pad(input, pad, mode=padding_mode, value=0)
target = _pad(target, pad, mode=padding_mode, value=0)
return input.contiguous(), target.contiguous()

def forward(self, input, target):
assert input.shape[0] == target.shape[0]
if self.training:
if self.instance_random:
return diff_random_translate_pair(input, target, ratio=self.ratio, size=self.size,
@@ -194,11 +229,11 @@ def expand_pad(input, target, padding_mode):
H, W = input.shape[:2]
pad_h = (int(2 ** 0.5 * H) - H) // 2 + 1
pad_w = (int(2 ** 0.5 * W) - W) // 2 + 1
input = F.pad(input, (pad_w, pad_w, pad_h, pad_h),
input = _pad(input, (pad_w, pad_w, pad_h, pad_h),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0)
target = _pad(input, (pad_w, pad_w, pad_h, pad_h),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0)
target = F.pad(input, (pad_w, pad_w, pad_h, pad_h),
mode=PAD_MODE_NN.get(padding_mode, padding_mode), value=0)
return input, target
return input.contiguous(), target.contiguous()

def forward(self, input, target):
if self.training:
@@ -211,9 +246,9 @@ def forward(self, input, target):
# batch random
angle = (torch.rand(1).item() * 2 - 1) * self.angle
input = diff_rotate(input, angle, mode=self.mode, padding_mode=self.padding_mode,
align_corners=self.align_corners, expand=self.expand)
align_corners=self.align_corners, expand=self.expand, cache=False)
target = diff_rotate(target, angle, mode=self.mode, padding_mode=self.padding_mode,
align_corners=self.align_corners, expand=self.expand)
align_corners=self.align_corners, expand=self.expand, cache=False)
return input, target
else:
if self.expand:
@@ -222,6 +257,23 @@
return input, target
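Passing cache=False in the batch-random branch looks intentional (an assumption on my part): rotate_grid_cache is an lru_cache keyed on the exact angle value, and a freshly sampled float angle on every forward pass would essentially never hit the cache, only churn it. A minimal illustration of the keying behavior:

```python
from functools import lru_cache


@lru_cache  # keyed on exact argument values, default maxsize=128
def build_grid(angle):
    return object()  # stand-in for an expensive grid tensor


build_grid(30.0)
build_grid(30.0)                # hit: the same angle repeats
build_grid(29.9999)             # miss: random angles rarely repeat exactly
print(build_grid.cache_info())  # hits=1, misses=2
```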


class DiffPairRandomDownsample(nn.Module):
def __init__(self, scale_factor_min=0.5, scale_factor_max=1.0):
super().__init__()
self.scale_factor_min = scale_factor_min
self.scale_factor_max = scale_factor_max

def forward(self, input, target):
if self.training:
scale_factor = (self.scale_factor_max - self.scale_factor_min) * torch.rand(1).item() + self.scale_factor_min
else:
scale_factor = (self.scale_factor_max - self.scale_factor_min) * 0.5 + self.scale_factor_min

input = F.interpolate(input, scale_factor=scale_factor, mode="bilinear", align_corners=False, antialias=True)
target = F.interpolate(target, scale_factor=scale_factor, mode="bilinear", align_corners=False, antialias=True)
return input, target
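A usage sketch for the new DiffPairRandomDownsample module (import path assumed from this repository): in training mode both tensors are resized by one shared random factor, and in eval mode the midpoint of the range is used so validation stays deterministic.

```python
import torch
from nunif.modules.transforms import DiffPairRandomDownsample

aug = DiffPairRandomDownsample(scale_factor_min=0.5, scale_factor_max=1.0)
x = torch.rand((2, 3, 64, 64))
t = torch.rand((2, 3, 64, 64))

aug.train()
x2, t2 = aug(x, t)
assert x2.shape == t2.shape   # same randomly drawn factor for input and target

aug.eval()
x3, t3 = aug(x, t)
print(x3.shape)               # midpoint factor 0.75 -> torch.Size([2, 3, 48, 48])
```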


def _test_rotate():
import torchvision.io as IO
import torchvision.transforms.functional as TF
@@ -236,7 +288,7 @@ def _test_rotate():
TF.to_pil_image(z[0]).show()
time.sleep(0.5)

z = diff_rotate(x, 45, expand=True, padding_mode="reflection")
z = diff_rotate(x, 45, expand=False, padding_mode="reflection")
TF.to_pil_image(z[0]).show()

x = x.repeat(4, 1, 1, 1)
5 changes: 5 additions & 0 deletions nunif/utils/video.py
@@ -289,12 +289,17 @@ def test_output_size(test_callback, video_stream, vf):
src_color_range=ColorRange.JPEG, dst_color_range=video_stream.codec_context.color_range)
pts_step = int((1. / video_stream.time_base) / 30) or 1
test_frame.pts = pts_step

try_count = 0
while True:
while True:
frame = video_filter.update(test_frame)
test_frame.pts = (test_frame.pts + pts_step)
if frame is not None:
break
try_count += 1
if try_count * video_stream.codec_context.width * video_stream.codec_context.height * 3 > 300 * 1024 * 1024:
raise RuntimeError("Unable to estimate output size of video filter")
output_frame = get_new_frames(test_callback(frame))
if output_frame:
output_frame = output_frame[0]
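The new guard is a rough memory bound on test frames fed into the filter graph before it produces any output: each frame is counted as width × height × 3 bytes of 8-bit RGB, and estimation is abandoned past roughly 300 MiB. A purely illustrative back-of-the-envelope check:

```python
width, height = 1920, 1080            # example 1080p stream
bytes_per_frame = width * height * 3  # ~6.2 MB of 8-bit RGB per test frame
limit = 300 * 1024 * 1024             # the 300 MiB cap used above
print(limit // bytes_per_frame)       # 50 -> gives up after ~50 fruitless frames
```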
