forked from nagadomi/nunif
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
1,102 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import argparse | ||
from nunif.addon import load_addons | ||
|
||
|
||
def add_default_options(parser): | ||
subparser.add_argument("--dataset-dir", "-i", type=str, required=True, help="input dataset dir") | ||
subparser.add_argument("--data-dir", "-o", type=str, required=True, help="output data dir") | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
subparsers = parser.add_subparsers(help="task", required=True) | ||
for addon in load_addons(): | ||
subparser = addon.register_build_training_data(subparsers) | ||
if subparser is not None: | ||
add_default_options(subparser) | ||
args = parser.parse_args() | ||
assert(args.handler is not None) | ||
|
||
args.handler(args) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,46 @@ | ||
import os | ||
from os import path | ||
|
||
# TODO: train | ||
|
||
|
||
class Addon(): | ||
def __init__(self, name, **kwargs): | ||
self.config = {"name": name} | ||
self.config.update(kwargs) | ||
def __init__(self, name): | ||
self.name = name | ||
|
||
def name(self): | ||
return self.config["name"] | ||
return self.name | ||
|
||
def register_build_training_data(self, subparser): | ||
pass | ||
|
||
def load_addon(addon_dir): | ||
addon_py = path.join(addon_dir, "nunif_addon.py") | ||
addon_module_path = path.splitext(path.relpath(addon_py))[0].replace(os.sep, ".") | ||
addon_module = __import__(addon_module_path, globals(), fromlist=["addon_config"]) | ||
def register_train(self, subparser): | ||
pass | ||
|
||
return addon_module.addon_config() | ||
|
||
def load_addon(addon_dir): | ||
addon_py = path.join(addon_dir, "nunif_addon.py") | ||
if path.exists(addon_py): | ||
addon_module_path = path.splitext(path.relpath(addon_py))[0].replace(os.sep, ".") | ||
addon_module = __import__(addon_module_path, globals(), fromlist=["addon_config"]) | ||
return addon_module.addon_config() | ||
else: | ||
return None | ||
|
||
|
||
def load_addons(addon_dirs=None): | ||
if addon_dirs is None: | ||
root_dir = path.join(path.dirname(__file__), "..") | ||
addon_dirs = [] | ||
for subdir in os.listdir(root_dir): | ||
subdir = path.join(root_dir, subdir) | ||
if path.isdir(subdir): | ||
addon_file = path.join(root_dir, subdir, "nunif_addon.py") | ||
if path.exists(addon_file): | ||
addon_dirs.append(subdir) | ||
|
||
def load_addons(addon_dirs): | ||
addons = [] | ||
for addon_dir in addon_dirs: | ||
addons.append(load_addon(addon_dir)) | ||
addon = load_addon(addon_dir) | ||
if addon is not None: | ||
addons.append(addon) | ||
return addons |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from . se_block import SEBlock | ||
from . inplace_clip import InplaceClip | ||
from . weighted_huber_loss import WeightedHuberLoss | ||
from . lbp_loss import LBPLoss, RandomBinaryConvolution | ||
from . clip_loss import ClipLoss | ||
from . auxiliary_loss import AuxiliaryLoss | ||
from . lbp_2x2_loss import LBP2x2Loss | ||
|
||
__all__ = ["SEBlock", "InplaceClip", "AuxiliaryLoss", "WeightedHuberLoss", "LBP2x2Loss"] | ||
from . channel_weighted_loss import ChannelWeightedLoss, LuminanceWeightedLoss | ||
from . jaccard import JaccardIndex | ||
from . psnr import PSNR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,26 @@ | ||
import torch | ||
import torch.nn as nn | ||
from . functional import auxiliary_loss | ||
|
||
|
||
def auxiliary_loss(inputs, targets, modules, weights): | ||
assert (len(inputs) == len(targets) and len(modules) == len(weights)) | ||
return sum([modules[i].forward(inputs[i], targets[i]) * weights[i] for i in range(len(inputs))]) | ||
|
||
|
||
class AuxiliaryLoss(nn.Module): | ||
def __init__(self, loss_modules, loss_weights=None): | ||
def __init__(self, losses, weight=None): | ||
super(AuxiliaryLoss, self).__init__() | ||
if loss_weights is not None: | ||
loss_weights = torch.ones(len(loss_modules)).float() | ||
assert (len(loss_modules) == len(loss_modules)) | ||
self.loss_modules = nn.ModuleList(loss_modules) | ||
self.loss_weights = loss_weights | ||
if weight is None: | ||
weight = torch.tensor([1.0 / len(losses)] * len(losses), dtype=torch.float) | ||
if isinstance(weight, (tuple, list)): | ||
weight = torch.tensor(weight, dtype=torch.float) | ||
|
||
assert (len(losses) == len(weight)) | ||
self.losses = nn.ModuleList(losses) | ||
self.weight = weight | ||
|
||
def forward(self, inputs, targets): | ||
return auxiliary_loss(inputs, targets, self.loss_modules, self.loss_weights) | ||
assert (isinstance(inputs, (list, tuple))) | ||
if not isinstance(targets, (list, tuple)): | ||
targets = [targets] * len(inputs) | ||
return auxiliary_loss(inputs, targets, self.losses, self.weight) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from torch import nn | ||
|
||
|
||
class ChannelWeightedLoss(nn.Module): | ||
""" Wrapper Module for channel weight | ||
""" | ||
def __init__(self, module, weight): | ||
super().__init__() | ||
self.module = module | ||
self.weight = weight | ||
|
||
def forward(self, input, target): | ||
b, ch, *_ = input.shape | ||
assert (ch == len(self.weight)) | ||
return sum([self.module(input[:, i:i+1, :, :], target[:, i:i+1, :, :]) * self.weight[i] | ||
for i in range(ch)]) | ||
|
||
|
||
LUMINANCE_WEIGHT = [0.29891, 0.58661, 0.11448] | ||
|
||
|
||
class LuminanceWeightedLoss(ChannelWeightedLoss): | ||
def __init__(self, module): | ||
super().__init__(module, weight=LUMINANCE_WEIGHT) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class ClipLoss(nn.Module): | ||
""" Wrapper Module for `(clamp(input, 0, 1) - clamp(target, 0, 1))` | ||
""" | ||
def __init__(self, module, min_value=0, max_value=1, eta=0.001): | ||
super().__init__() | ||
self.module = module | ||
self.min_value = min_value | ||
self.max_value = max_value | ||
self.eta = eta | ||
|
||
def forward(self, input, target): | ||
noclip_loss = self.module(input, target) | ||
clip_loss = self.module(torch.clamp(input, self.min_value, self.max_value), | ||
torch.clamp(target, self.min_value, self.max_value)) | ||
|
||
return clip_loss + noclip_loss * self.eta |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from torch import nn | ||
|
||
|
||
class JaccardIndex(nn.Module): | ||
""" aka IoU | ||
""" | ||
def __init__(self, threshold=0.5): | ||
super().__init__() | ||
self.threshold = threshold | ||
|
||
def forward(self, input, target, batch=True): | ||
assert(input.shape == target.shape) | ||
score = 0.0 | ||
count = 0.0 | ||
if not batch: | ||
input = input.unsqueeze(0) | ||
target = target.unsqueeze(0) | ||
for y, t in zip(input, target): | ||
a = (y >= self.threshold).long() | ||
b = (t >= self.threshold).long() | ||
a_count = a.sum().item() | ||
b_count = b.sum().item() | ||
a_and_b = a.mul(b).sum().item() | ||
ab = (a_count + b_count - a_and_b) | ||
if ab > 0.0: | ||
score += (a_and_b / ab) | ||
else: | ||
score += 1.0 | ||
count += 1.0 | ||
return score / count |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.