permutedAdaIN.py
"""
Implementation of the pAdaIN layer. This layer can be added after every convolutional layer and acts
as a regularization which increases overall performance.
It is only applied during training.
"""
import random
import torch
import torch.nn as nn
class PermuteAdaptiveInstanceNorm2d(nn.Module):
def __init__(self, p=0.01, eps=1e-5):
super(PermuteAdaptiveInstanceNorm2d, self).__init__()
self.p = p
self.eps = eps
def forward(self, x):
permute = random.random() < self.p
if permute and self.training:
perm_indices = torch.randperm(x.size()[0])
else:
return x
size = x.size()
N, C, H, W = size
if (H, W) == (1, 1):
print('encountered bad dims')
return x
return adaptive_instance_normalization(x, x[perm_indices], self.eps)
def extra_repr(self) -> str:
return 'p={}'.format(
self.p
)
def calc_mean_std(feat, eps=1e-5):
size = feat.size()
assert (len(size) == 4)
N, C, H, W = size
feat_std = torch.sqrt(feat.view(N, C, -1).var(dim=2).view(N, C, 1, 1) + eps)
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat, eps=1e-5):
assert (content_feat.size()[:2] == style_feat.size()[:2])
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat.detach(), eps)
content_mean, content_std = calc_mean_std(content_feat, eps)
content_std = content_std + eps # to avoid division by 0
normalized_feat = (content_feat - content_mean.expand(
size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
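

# A minimal usage sketch (not part of the original file): pAdaIN is typically placed
# right after a convolution, before the normalization/activation. The small block
# below is purely illustrative; the layer sizes, input shape, and p value are
# assumptions, not values taken from the paper or this repository.
if __name__ == '__main__':
    block = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1),
        PermuteAdaptiveInstanceNorm2d(p=0.01),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
    )
    block.train()  # pAdaIN only permutes in training mode, and only with probability p
    out = block(torch.randn(8, 3, 32, 32))
    print(out.shape)  # torch.Size([8, 64, 32, 32])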