Commit

First commit for training
nagadomi committed Dec 8, 2022
1 parent 4766e1a commit 86698c8
Showing 28 changed files with 1,102 additions and 212 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -8,9 +8,13 @@ tmp
*.tar
data
!data/.keep
models/
!models/.keep
waifu2x/pretrained_models
waifu2x/web-config.ini
waifu2x/public_html
!waifu2x/webgen/assets/*.png
!waifu2x/webgen/assets/*.jpg

# ----- Python.gitignore
# Byte-compiled / optimized / DLL files
24 changes: 24 additions & 0 deletions build_training_data.py
@@ -0,0 +1,24 @@
import argparse
from nunif.addon import load_addons


def add_default_options(parser):
    parser.add_argument("--dataset-dir", "-i", type=str, required=True, help="input dataset dir")
    parser.add_argument("--data-dir", "-o", type=str, required=True, help="output data dir")


def main():
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(help="task", required=True)
    for addon in load_addons():
        subparser = addon.register_build_training_data(subparsers)
        if subparser is not None:
            add_default_options(subparser)
    args = parser.parse_args()
    assert (args.handler is not None)

    args.handler(args)


if __name__ == "__main__":
    main()
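
For context, a minimal sketch of how an addon is expected to plug into this dispatcher: register_build_training_data() returns a configured subparser whose handler main() later invokes via args.handler(args). The "example" task name and my_build_handler below are illustrative, not part of the repository.

import argparse


def my_build_handler(args):
    # illustrative handler; a real addon would read args.dataset_dir
    # and write its processed files under args.data_dir
    print(f"building training data: {args.dataset_dir} -> {args.data_dir}")


def register_build_training_data(subparsers):
    # what an Addon.register_build_training_data() override is expected to do:
    # create a sub-command and attach a handler for main() to dispatch to
    subparser = subparsers.add_parser("example", help="example task")
    subparser.set_defaults(handler=my_build_handler)
    return subparser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(help="task")
    subparser = register_build_training_data(subparsers)
    subparser.add_argument("--dataset-dir", "-i", type=str, required=True)
    subparser.add_argument("--data-dir", "-o", type=str, required=True)
    args = parser.parse_args()
    args.handler(args)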
42 changes: 30 additions & 12 deletions nunif/addon.py
@@ -1,28 +1,46 @@
 import os
 from os import path

 # TODO: train


 class Addon():
-    def __init__(self, name, **kwargs):
-        self.config = {"name": name}
-        self.config.update(kwargs)
+    def __init__(self, name):
+        self.name = name

     def name(self):
-        return self.config["name"]
+        return self.name
+
+    def register_build_training_data(self, subparser):
+        pass
+
+    def register_train(self, subparser):
+        pass


-def load_addon(addon_dir):
-    addon_py = path.join(addon_dir, "nunif_addon.py")
-    addon_module_path = path.splitext(path.relpath(addon_py))[0].replace(os.sep, ".")
-    addon_module = __import__(addon_module_path, globals(), fromlist=["addon_config"])
-
-    return addon_module.addon_config()
+def load_addon(addon_dir):
+    addon_py = path.join(addon_dir, "nunif_addon.py")
+    if path.exists(addon_py):
+        addon_module_path = path.splitext(path.relpath(addon_py))[0].replace(os.sep, ".")
+        addon_module = __import__(addon_module_path, globals(), fromlist=["addon_config"])
+        return addon_module.addon_config()
+    else:
+        return None


-def load_addons(addon_dirs):
+def load_addons(addon_dirs=None):
+    if addon_dirs is None:
+        root_dir = path.join(path.dirname(__file__), "..")
+        addon_dirs = []
+        for subdir in os.listdir(root_dir):
+            subdir = path.join(root_dir, subdir)
+            if path.isdir(subdir):
+                addon_file = path.join(root_dir, subdir, "nunif_addon.py")
+                if path.exists(addon_file):
+                    addon_dirs.append(subdir)

     addons = []
     for addon_dir in addon_dirs:
-        addons.append(load_addon(addon_dir))
+        addon = load_addon(addon_dir)
+        if addon is not None:
+            addons.append(addon)
     return addons
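
For reference, a sketch of the per-directory entry point that load_addon() looks for: a nunif_addon.py module exposing addon_config(). The ExampleAddon class, its task name, and the lambda handler are illustrative assumptions.

from nunif.addon import Addon


class ExampleAddon(Addon):
    def __init__(self):
        super().__init__("example")

    def register_build_training_data(self, subparsers):
        # return a configured subparser, or None to skip this task
        subparser = subparsers.add_parser("example", help="build example data")
        subparser.set_defaults(handler=lambda args: print("build", args.dataset_dir))
        return subparser


def addon_config():
    # load_addon() imports this module and calls addon_config()
    return ExampleAddon()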
15 changes: 15 additions & 0 deletions nunif/models/model.py
@@ -48,3 +48,18 @@ def get_config(self):
"i2i_in_size": self.i2i_in_size
})
return config


class SoftMaxBaseModel(Model):
name = "nunif.softmax_base_model"

def __init__(self, kwargs, class_names):
super(SoftMaxBaseModel, self).__init__(kwargs)
self.softmax_class_names = class_names

def get_config(self):
config = dict(super().get_config())
config.update({
"softmax_class_names": self.softmax_class_names
})
return config
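
The stored softmax_class_names allows a prediction index to be mapped back to a label once the model is restored. A small illustrative sketch (the label set and the random logits are made up, not taken from the repository):

import torch

class_names = ["scan", "photo", "illustration"]   # hypothetical labels
logits = torch.randn(2, len(class_names))          # stand-in model output
indices = torch.softmax(logits, dim=1).argmax(dim=1)
labels = [class_names[i] for i in indices.tolist()]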
17 changes: 10 additions & 7 deletions nunif/models/utils.py
@@ -6,19 +6,22 @@
 from .. logger import logger


-def save_model(model, model_path, updated_at=None, train_kwargs=None):
+def save_model(model, model_path, updated_at=None, train_kwargs=None, **kwargs):
     if isinstance(model, nn.DataParallel):
         model = model.model
     assert (isinstance(model, Model))
     updated_at = str(updated_at or datetime.now(timezone.utc))
     if train_kwargs is not None and not isinstance(train_kwargs, dict):
         train_kwargs = vars(train_kwargs)  # Namespace
-    torch.save({"nunif_model": 1,
-                "name": model.name,
-                "updated_at": updated_at,
-                "kwargs": model.get_kwargs(),
-                "train_kwargs": train_kwargs,
-                "state_dict": model.state_dict()}, model_path)
+    data = {
+        "nunif_model": 1,
+        "name": model.name,
+        "updated_at": updated_at,
+        "kwargs": model.get_kwargs(),
+        "train_kwargs": train_kwargs,
+        "state_dict": model.state_dict()}
+    data.update(kwargs)
+    torch.save(data, model_path)


def load_model(model_path, device_ids=None, strict=True, map_location="cpu"):
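
A usage sketch for the new **kwargs pass-through in save_model(): any extra keyword arguments are merged into the saved dict next to the standard fields. The model variable, the output path, and the last_epoch field below are illustrative assumptions, not part of the API.

from nunif.models.utils import save_model

# `my_model` stands in for any nunif Model instance; "last_epoch" is a
# hypothetical extra field that ends up in the checkpoint via data.update(kwargs)
save_model(my_model, "models/example.checkpoint.pth",
           train_kwargs=args,   # an argparse.Namespace is converted with vars()
           last_epoch=10)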
9 changes: 5 additions & 4 deletions nunif/modules/__init__.py
@@ -1,7 +1,8 @@
 from . se_block import SEBlock
 from . inplace_clip import InplaceClip
-from . weighted_huber_loss import WeightedHuberLoss
+from . lbp_loss import LBPLoss, RandomBinaryConvolution
+from . clip_loss import ClipLoss
 from . auxiliary_loss import AuxiliaryLoss
-from . lbp_2x2_loss import LBP2x2Loss
-
-__all__ = ["SEBlock", "InplaceClip", "AuxiliaryLoss", "WeightedHuberLoss", "LBP2x2Loss"]
+from . channel_weighted_loss import ChannelWeightedLoss, LuminanceWeightedLoss
+from . jaccard import JaccardIndex
+from . psnr import PSNR
26 changes: 18 additions & 8 deletions nunif/modules/auxiliary_loss.py
@@ -1,16 +1,26 @@
 import torch
 import torch.nn as nn
-from . functional import auxiliary_loss
+
+
+def auxiliary_loss(inputs, targets, modules, weights):
+    assert (len(inputs) == len(targets) and len(modules) == len(weights))
+    return sum([modules[i].forward(inputs[i], targets[i]) * weights[i] for i in range(len(inputs))])


 class AuxiliaryLoss(nn.Module):
-    def __init__(self, loss_modules, loss_weights=None):
+    def __init__(self, losses, weight=None):
         super(AuxiliaryLoss, self).__init__()
-        if loss_weights is not None:
-            loss_weights = torch.ones(len(loss_modules)).float()
-        assert (len(loss_modules) == len(loss_modules))
-        self.loss_modules = nn.ModuleList(loss_modules)
-        self.loss_weights = loss_weights
+        if weight is None:
+            weight = torch.tensor([1.0 / len(losses)] * len(losses), dtype=torch.float)
+        if isinstance(weight, (tuple, list)):
+            weight = torch.tensor(weight, dtype=torch.float)
+
+        assert (len(losses) == len(weight))
+        self.losses = nn.ModuleList(losses)
+        self.weight = weight

     def forward(self, inputs, targets):
-        return auxiliary_loss(inputs, targets, self.loss_modules, self.loss_weights)
+        assert (isinstance(inputs, (list, tuple)))
+        if not isinstance(targets, (list, tuple)):
+            targets = [targets] * len(inputs)
+        return auxiliary_loss(inputs, targets, self.losses, self.weight)
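
A usage sketch, assuming two network outputs (for example a main head and an auxiliary head) scored against a single shared target; the shapes and the L1 criteria below are illustrative.

import torch
import torch.nn as nn
from nunif.modules import AuxiliaryLoss

criterion = AuxiliaryLoss([nn.L1Loss(), nn.L1Loss()], weight=[1.0, 0.5])
main_out = torch.rand(4, 3, 32, 32)    # stand-in main output
aux_out = torch.rand(4, 3, 32, 32)     # stand-in auxiliary output
target = torch.rand(4, 3, 32, 32)
loss = criterion([main_out, aux_out], target)   # the single target is reused for both heads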
24 changes: 24 additions & 0 deletions nunif/modules/channel_weighted_loss.py
@@ -0,0 +1,24 @@
from torch import nn


class ChannelWeightedLoss(nn.Module):
    """ Wrapper Module for channel weight
    """
    def __init__(self, module, weight):
        super().__init__()
        self.module = module
        self.weight = weight

    def forward(self, input, target):
        b, ch, *_ = input.shape
        assert (ch == len(self.weight))
        return sum([self.module(input[:, i:i+1, :, :], target[:, i:i+1, :, :]) * self.weight[i]
                    for i in range(ch)])


LUMINANCE_WEIGHT = [0.29891, 0.58661, 0.11448]


class LuminanceWeightedLoss(ChannelWeightedLoss):
    def __init__(self, module):
        super().__init__(module, weight=LUMINANCE_WEIGHT)
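
A usage sketch: wrap an elementwise criterion so each RGB channel's contribution follows the BT.601-style luminance coefficients above. The batch shapes and the L1 criterion are illustrative.

import torch
import torch.nn as nn
from nunif.modules import LuminanceWeightedLoss

criterion = LuminanceWeightedLoss(nn.L1Loss())
input = torch.rand(4, 3, 32, 32)    # must have 3 channels to match LUMINANCE_WEIGHT
target = torch.rand(4, 3, 32, 32)
loss = criterion(input, target)     # G errors count most, B errors least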
20 changes: 20 additions & 0 deletions nunif/modules/clip_loss.py
@@ -0,0 +1,20 @@
import torch
from torch import nn


class ClipLoss(nn.Module):
    """ Wrapper Module for `(clamp(input, 0, 1) - clamp(target, 0, 1))`
    """
    def __init__(self, module, min_value=0, max_value=1, eta=0.001):
        super().__init__()
        self.module = module
        self.min_value = min_value
        self.max_value = max_value
        self.eta = eta

    def forward(self, input, target):
        noclip_loss = self.module(input, target)
        clip_loss = self.module(torch.clamp(input, self.min_value, self.max_value),
                                torch.clamp(target, self.min_value, self.max_value))

        return clip_loss + noclip_loss * self.eta
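
A usage sketch: the dominant term is computed on values clamped to [min_value, max_value], while the unclamped term is kept with a small eta so predictions outside the valid range still receive a gradient. The inputs below are illustrative.

import torch
import torch.nn as nn
from nunif.modules import ClipLoss

criterion = ClipLoss(nn.L1Loss())
input = torch.rand(4, 3, 16, 16) * 1.2 - 0.1   # intentionally slightly out of [0, 1]
target = torch.rand(4, 3, 16, 16)
loss = criterion(input, target)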
26 changes: 0 additions & 26 deletions nunif/modules/functional.py
@@ -5,31 +5,5 @@ def inplace_clip(x, min_value, max_value):
    return torch.clamp_(x, min_value, max_value)


def weighted_huber_loss(input, target, weight, gamma=1.0, reduction='mean'):
    t = torch.abs(input - target).mul_(weight)
    loss = torch.where(t < gamma, 0.5 * t ** 2, (t - 0.5 * gamma) * gamma)
    if reduction == 'mean':
        loss = torch.mean(loss)
    elif reduction == 'sum':
        loss = torch.sum(loss)
    elif reduction == 'spatial_mean':
        bs, ch, h, w = input.shape
        loss = loss.view(bs, ch, -1).mean(dim=2).sum(dim=1).mean()
    elif reduction == 'none':
        pass
    else:
        raise ValueError(f"undefined reduction: {reduction}")
    return loss


def auxiliary_loss(inputs, targets, modules, weights):
    assert (len(inputs) == len(targets) == len(modules) == len(weights))
    n = len(inputs)
    loss = None
    for i in range(n):
        z = modules[i].forward(inputs[i], targets[i]) * weights[i]
        if loss is None:
            loss = z
        else:
            loss = loss + z
    return loss
30 changes: 30 additions & 0 deletions nunif/modules/jaccard.py
@@ -0,0 +1,30 @@
from torch import nn


class JaccardIndex(nn.Module):
    """ aka IoU
    """
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, input, target, batch=True):
        assert(input.shape == target.shape)
        score = 0.0
        count = 0.0
        if not batch:
            input = input.unsqueeze(0)
            target = target.unsqueeze(0)
        for y, t in zip(input, target):
            a = (y >= self.threshold).long()
            b = (t >= self.threshold).long()
            a_count = a.sum().item()
            b_count = b.sum().item()
            a_and_b = a.mul(b).sum().item()
            ab = (a_count + b_count - a_and_b)
            if ab > 0.0:
                score += (a_and_b / ab)
            else:
                score += 1.0
            count += 1.0
        return score / count
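
A usage sketch: both tensors are thresholded to binary masks and the intersection-over-union is averaged over the batch; an empty-vs-empty pair scores 1.0. The shapes below are illustrative.

import torch
from nunif.modules import JaccardIndex

metric = JaccardIndex(threshold=0.5)
pred = torch.rand(2, 1, 8, 8)                   # e.g. sigmoid outputs
mask = (torch.rand(2, 1, 8, 8) > 0.5).float()   # ground-truth binary mask
iou = metric(pred, mask)                        # float in [0, 1]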
81 changes: 0 additions & 81 deletions nunif/modules/lbp_2x2_loss.py

This file was deleted.

