Skip to content

Commit

Permalink
Merge pull request naoto0804#4 from naoto0804/develop
Browse files Browse the repository at this point in the history
Always add epsilon when computing var
  • Loading branch information
naoto0804 authored Jan 9, 2018
2 parents cadcaf5 + f3b9a78 commit 2165a33
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 19 deletions.
33 changes: 19 additions & 14 deletions function.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
import torch


def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    ``feat`` must be a 4D tensor. Its first two dimensions are read as
    ``(N, C)`` and every remaining element of an ``(n, c)`` slice is reduced
    together.

    eps is a small value added to the variance to avoid divide-by-zero.

    Returns ``(feat_mean, feat_std)``, each viewed as ``(N, C, 1, 1)``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # Flatten everything after (N, C) once and reuse it for both statistics.
    flat = feat.view(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat, eps=1e-5):
    """Re-normalize ``content_feat`` to the channel statistics of ``style_feat``.

    The content features are shifted and scaled per ``(n, c)`` slice to zero
    mean / unit std (as computed by ``calc_mean_std``), then scaled by the
    style std and shifted by the style mean.

    Both inputs must be 4D and agree on their first two dimensions.
    ``eps`` is optional and is forwarded to ``calc_mean_std``; it is kept so
    that callers written against either signature keep working.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat, eps)
    content_mean, content_std = calc_mean_std(content_feat, eps)

    # eps is already folded into content_std, so this division is safe even
    # for a constant channel.
    normalized_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def calc_feat_flatten_mean_std(feat):
def _calc_feat_flatten_mean_std(feat):
# takes a 3D feat (C, H, W) and returns the mean and std of the array within each channel
assert (feat.size()[0] == 3)
assert (isinstance(feat, torch.FloatTensor))
Expand All @@ -28,7 +33,7 @@ def calc_feat_flatten_mean_std(feat):
return feat_flatten, mean, std


def _mat_sqrt(x):
    """Return a square root of the 2D matrix ``x`` built from its SVD.

    With ``x = U diag(D) V^T`` as returned by ``torch.svd``, the result is
    ``U diag(sqrt(D)) V^T``.
    """
    U, D, V = torch.svd(x)
    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())


# Former public name, kept so call sites using either spelling still resolve.
mat_sqrt = _mat_sqrt

Expand All @@ -37,21 +42,21 @@ def coral(source, target):
# assumes both source and target are 3D arrays (C, H, W)
# Note: flatten -> f

source_f, source_f_mean, source_f_std = calc_feat_flatten_mean_std(source)
source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
source_f_norm = (source_f - source_f_mean.expand_as(
source_f)) / source_f_std.expand_as(source_f)
source_f_cov_eye = \
torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

target_f, target_f_mean, target_f_std = calc_feat_flatten_mean_std(target)
target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
target_f_norm = (target_f - target_f_mean.expand_as(
target_f)) / target_f_std.expand_as(target_f)
target_f_cov_eye = \
torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

source_f_norm_transfer = torch.mm(
mat_sqrt(target_f_cov_eye),
torch.mm(torch.inverse(mat_sqrt(source_f_cov_eye)),
_mat_sqrt(target_f_cov_eye),
torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
source_f_norm)
)

Expand Down
12 changes: 7 additions & 5 deletions net.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch.autograd import Variable

from function import adaptive_instance_normalization as adain
from function import calc_mean_std

decoder = nn.Sequential(
nn.ReflectionPad2d((1, 1, 1, 1)),
Expand Down Expand Up @@ -121,11 +122,12 @@ def mse(input, target):
return nn.MSELoss()(input,
Variable(target.data, requires_grad=False))

def calc_style_loss(input, target):
    """Return the style loss between two equally-sized feature tensors.

    The loss is ``mse`` of the per-channel means plus ``mse`` of the
    per-channel stds, with both statistics taken from ``calc_mean_std``.

    ``target`` must not require gradients; ``input`` and ``target`` must
    have identical sizes.
    """
    assert (input.data.size() == target.data.size())
    assert (target.requires_grad is False)
    input_mean, input_std = calc_mean_std(input)
    target_mean, target_std = calc_mean_std(target)
    return mse(input_mean, target_mean) + mse(input_std, target_std)

style_feats = self.encode_with_intermediate(style)
t = adain(self.encode(content), style_feats[-1])
Expand Down

0 comments on commit 2165a33

Please sign in to comment.