PyTorch v0.3 -> v0.4
Naoto Inoue committed May 17, 2018
1 parent 756dafd commit 4a5eef6
Showing 6 changed files with 53 additions and 59 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -5,7 +5,7 @@ I'm really grateful to the [original implementation](https://github.com/xunhuang

## Requirements
- Python 3.5+
- PyTorch 0.3
- PyTorch 0.4+
- TorchVision
- TensorboardX

@@ -27,17 +27,17 @@ python torch_to_pytorch.py --model models/decoder.t7
### Test
Use `--content` and `--style` to specify the paths to the content and style images, respectively.
```
python test.py --gpu <gpu_id> --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
```

You can also run the code on whole directories of content and style images using `--content_dir` and `--style_dir`; every possible combination of content and style images is saved to the output directory.
```
python test.py --gpu <gpu_id> --content_dir input/content --style_dir input/style
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style
```

Here is an example of mixing four styles by specifying the `--style` and `--style_interpolation_weights` options; a sketch of the underlying blending follows the command.
```
python test.py --gpu <gpu_id> --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
```
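
Under the hood, each style image is AdaIN-normalized against the content features and the results are combined as a weighted sum before decoding. A minimal sketch of that blending, mirroring the loop in `test.py` below and assuming the weights already sum to 1 (the names `blend_styles` and `base_feat` are illustrative):
```
import torch

# Weighted blend of per-style AdaIN features: base_feat is assumed to
# stack one (1, C, H, W) feature map per style along the batch axis.
def blend_styles(base_feat, weights):
    feat = torch.zeros_like(base_feat[0:1])
    for i, w in enumerate(weights):
        feat = feat + w * base_feat[i:i + 1]
    return feat
```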

Some other options:
6 changes: 3 additions & 3 deletions function.py
@@ -3,7 +3,7 @@

def calc_mean_std(feat, eps=1e-5):
# eps is a small value added to the variance to avoid divide-by-zero.
size = feat.data.size()
size = feat.size()
assert (len(size) == 4)
N, C = size[:2]
feat_var = feat.view(N, C, -1).var(dim=2) + eps
@@ -13,8 +13,8 @@ def calc_mean_std(feat, eps=1e-5):


def adaptive_instance_normalization(content_feat, style_feat):
assert (content_feat.data.size()[:2] == style_feat.data.size()[:2])
size = content_feat.data.size()
assert (content_feat.size()[:2] == style_feat.size()[:2])
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)

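This `function.py` change is the heart of the 0.4 migration: `Tensor` and `Variable` were merged, so shape queries no longer need to unwrap `.data`. A quick smoke test of the updated function, assuming PyTorch 0.4+ (the tensor shape is illustrative):
```
import torch

from function import calc_mean_std

# With the Tensor/Variable merge, .size() and .var() work directly on
# the input tensor; no .data unwrapping is needed.
feat = torch.randn(2, 512, 32, 32)  # (N, C, H, W) feature map
mean, std = calc_mean_std(feat)
print(mean.size(), std.size())  # both torch.Size([2, 512, 1, 1])
```
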
18 changes: 11 additions & 7 deletions net.py
@@ -1,5 +1,4 @@
import torch.nn as nn
from torch.autograd import Variable

from function import adaptive_instance_normalization as adain
from function import calc_mean_std
@@ -8,7 +7,7 @@
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 256, (3, 3)),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(),
@@ -21,14 +20,14 @@
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 128, (3, 3)),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 128, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 64, (3, 3)),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(64, 64, (3, 3)),
nn.ReLU(),
@@ -104,6 +103,11 @@ def __init__(self, encoder, decoder):
self.decoder = decoder
self.mse_loss = nn.MSELoss()

# fix the encoder
for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
for param in getattr(self, name).parameters():
param.requires_grad = False

# extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
def encode_with_intermediate(self, input):
results = [input]
@@ -119,12 +123,12 @@ def encode(self, input):
return input

def calc_content_loss(self, input, target):
assert (input.data.size() == target.data.size())
assert (input.size() == target.size())
assert (target.requires_grad is False)
return self.mse_loss(input, target)

def calc_style_loss(self, input, target):
assert (input.data.size() == target.data.size())
assert (input.size() == target.size())
assert (target.requires_grad is False)
input_mean, input_std = calc_mean_std(input)
target_mean, target_std = calc_mean_std(target)
@@ -138,7 +142,7 @@ def forward(self, content, style, alpha=1.0):
t = adain(content_feat, style_feats[-1])
t = alpha * t + (1 - alpha) * content_feat

g_t = self.decoder(Variable(t.data, requires_grad=True))
g_t = self.decoder(t)
g_t_feats = self.encode_with_intermediate(g_t)

loss_c = self.calc_content_loss(g_t_feats[-1], t)
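Two 0.4 patterns appear in `net.py`: the deprecated `nn.UpsamplingNearest2d` gives way to the equivalent `nn.Upsample(mode='nearest')`, and the encoder is frozen inside `Net.__init__` rather than in the training script. A minimal sketch of both, assuming PyTorch 0.4+ (the stand-in encoder is illustrative):
```
import torch
import torch.nn as nn

# 1) nn.Upsample with mode='nearest' behaves like the deprecated
#    nn.UpsamplingNearest2d:
up = nn.Upsample(scale_factor=2, mode='nearest')
x = torch.randn(1, 64, 8, 8)
print(up(x).size())  # torch.Size([1, 64, 16, 16])

# 2) Freezing a submodule: requires_grad=False keeps its parameters out
#    of the autograd graph, so only the decoder receives gradients.
encoder = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU())  # stand-in
for param in encoder.parameters():
    param.requires_grad = False
```
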
43 changes: 18 additions & 25 deletions test.py
@@ -1,12 +1,10 @@
import argparse
import os
from os.path import basename
from os.path import splitext

import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image

@@ -18,7 +16,7 @@
def test_transform(size, crop):
transform_list = []
if size != 0:
transform_list.append(transforms.Scale(size))
transform_list.append(transforms.Resize(size))
if crop:
transform_list.append(transforms.CenterCrop(size))
transform_list.append(transforms.ToTensor())
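
`transforms.Scale` is deprecated in torchvision in favor of the equivalent `transforms.Resize`. With `--content_size 512 --crop`, the assembled pipeline would look like this (a sketch; the closing `transforms.Compose` call is collapsed in this view):
```
from torchvision import transforms

# Test-time transform assuming size=512 and crop=True (illustrative).
tf = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
])
```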
@@ -33,8 +31,7 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
style_f = vgg(style)
if interpolation_weights:
_, C, H, W = content_f.size()
feat = Variable(torch.FloatTensor(1, C, H, W).zero_().cuda(),
volatile=True)
feat = torch.FloatTensor(1, C, H, W).zero_().to(torch.device('cuda'))
base_feat = adaptive_instance_normalization(content_f, style_f)
for i, w in enumerate(interpolation_weights):
feat = feat + w * base_feat[i:i + 1]
@@ -47,7 +44,6 @@

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--content', type=str,
help='File path to the content image')
parser.add_argument('--content_dir', type=str,
@@ -86,11 +82,11 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
help='The weight for blending the style of multiple style images')

args = parser.parse_args()
if args.gpu >= 0:
torch.cuda.set_device(args.gpu)

do_interpolation = False

device = torch.device('cuda')

# Either --content or --contentDir should be given.
assert (args.content or args.content_dir)
# Either --style or --styleDir should be given.
@@ -129,8 +125,8 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.cuda()
decoder.cuda()
vgg.to(device)
decoder.to(device)

content_tf = test_transform(args.content_size, args.crop)
style_tf = test_transform(args.style_size, args.crop)
@@ -140,12 +136,11 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
style = torch.stack([style_tf(Image.open(p)) for p in style_paths])
content = content_tf(Image.open(content_path)) \
.unsqueeze(0).expand_as(style)
style = style.cuda()
content = content.cuda()
output = style_transfer(vgg, decoder,
Variable(content, volatile=True),
Variable(style, volatile=True),
args.alpha, interpolation_weights).data
style = style.to(device)
content = content.to(device)
with torch.no_grad():
output = style_transfer(vgg, decoder, content, style,
args.alpha, interpolation_weights)
output = output.cpu()
output_name = '{:s}/{:s}_interpolation{:s}'.format(
args.output, splitext(basename(content_path))[0], args.save_ext)
@@ -157,13 +152,11 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
style = style_tf(Image.open(style_path))
if args.preserve_color:
style = coral(style, content)
style = style.cuda()
content = content.cuda()
content = Variable(content.unsqueeze(0), volatile=True)
style = Variable(style.unsqueeze(0), volatile=True)

output = style_transfer(vgg, decoder, content, style,
args.alpha).data
style = style.to(device).unsqueeze(0)
content = content.to(device).unsqueeze(0)
with torch.no_grad():
output = style_transfer(vgg, decoder, content, style,
args.alpha)
output = output.cpu()

output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
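The `test.py` rewrite collects the remaining 0.4 inference idioms: a single `torch.device` object, `.to(device)` in place of `.cuda()`, and a `torch.no_grad()` block in place of `volatile=True` Variables, which no longer have any effect in 0.4. A minimal sketch with a toy model (the device fallback is illustrative; the script above assumes CUDA):
```
import torch
import torch.nn as nn

# Gradient-free evaluation in PyTorch 0.4: no Variable wrapping, and
# torch.no_grad() disables autograd tracking for the forward pass.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2).to(device)
x = torch.randn(1, 4).to(device)
with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False
```
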
4 changes: 2 additions & 2 deletions torch_to_pytorch.py
@@ -125,7 +125,7 @@ def lua_recursive_model(module, seq):
elif name == 'SpatialCrossMapLRN':
lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta,
m.k)
n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
n = Lambda(lambda x, lrn=lrn: lrn.forward(x))
add_submodule(seq, n)
elif name == 'Sequential':
n = nn.Sequential()
@@ -213,7 +213,7 @@ def lua_recursive_source(module):
lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format(
(m.size, m.alpha, m.beta, m.k))
s += [
'Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(
'Lambda(lambda x,lrn={}: Variable(lrn.forward(x)))'.format(
lrn)]

elif name == 'Sequential':
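With the Tensor/Variable merge, the legacy-LRN lambdas in the converter can pass tensors straight through instead of unwrapping `.data`. A minimal stand-in for the converter's `Lambda` wrapper shows why the plain-tensor form works (the class body is illustrative, not the repo's exact definition):
```
import torch
import torch.nn as nn

# Wraps an arbitrary function as an nn.Module; in 0.4 the function can
# take and return plain tensors, so no Variable(...) wrapping is needed.
class Lambda(nn.Module):
    def __init__(self, fn):
        super(Lambda, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)

double = Lambda(lambda x: x * 2)
print(double(torch.ones(2)))  # tensor([2., 2.])
```
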
33 changes: 15 additions & 18 deletions train.py
@@ -1,20 +1,21 @@
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from PIL import ImageFile
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchvision import transforms
from tqdm import tqdm

import net
from sampler import InfiniteSamplerWrapper

cudnn.benchmark = True
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated


def train_transform():
@@ -55,7 +56,6 @@ def adjust_learning_rate(optimizer, iteration_count):

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--content_dir', type=str, required=True,
help='Directory path to a batch of content images')
parser.add_argument('--style_dir', type=str, required=True,
@@ -77,8 +77,7 @@ def adjust_learning_rate(optimizer, iteration_count):
parser.add_argument('--save_model_interval', type=int, default=10000)
args = parser.parse_args()

if args.gpu >= 0:
torch.cuda.set_device(args.gpu)
device = torch.device('cuda')

if not os.path.exists(args.save_dir):
os.mkdir(args.save_dir)
@@ -93,12 +92,8 @@ def adjust_learning_rate(optimizer, iteration_count):
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:31])
network = net.Net(vgg, decoder)

for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
for param in getattr(network, name).parameters():
param.requires_grad = False
network.train()
network.cuda()
network.to(device)

content_tf = train_transform()
style_tf = train_transform()
@@ -119,8 +114,8 @@

for i in tqdm(range(args.max_iter)):
adjust_learning_rate(optimizer, iteration_count=i)
content_images = Variable(next(content_iter).cuda())
style_images = Variable(next(style_iter).cuda())
content_images = next(content_iter).to(device)
style_images = next(style_iter).to(device)
loss_c, loss_s = network(content_images, style_images)
loss_c = args.content_weight * loss_c
loss_s = args.style_weight * loss_s
@@ -130,12 +125,14 @@ def adjust_learning_rate(optimizer, iteration_count):
loss.backward()
optimizer.step()

writer.add_scalar('loss_content', loss_c.data.cpu()[0], i + 1)
writer.add_scalar('loss_style', loss_s.data.cpu()[0], i + 1)
writer.add_scalar('loss_content', loss_c.item(), i + 1)
writer.add_scalar('loss_style', loss_s.item(), i + 1)

if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
torch.save(
net.decoder.state_dict(),
'{:s}/decoder_iter_{:d}.pth.tar'.format(args.save_dir, i + 1)
)
state_dict = net.decoder.state_dict()
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(torch.device('cpu'))
torch.save(state_dict,
'{:s}/decoder_iter_{:d}.pth.tar'.format(args.save_dir,
i + 1))
writer.close()
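
The `train.py` changes follow the same merge: scalar losses are read with `.item()` (indexing a zero-dimensional tensor with `[0]` is deprecated in 0.4), batches move to the GPU with `.to(device)`, and the decoder's weights are copied to the CPU before `torch.save` so the checkpoint loads on machines without a GPU. A condensed sketch with a toy decoder (illustrative):
```
import torch
import torch.nn as nn

# .item() extracts a Python float from a zero-dimensional loss tensor.
decoder = nn.Linear(4, 4)
loss = nn.MSELoss()(decoder(torch.randn(2, 4)), torch.randn(2, 4))
print(loss.item())

# Move the state_dict to CPU before saving so the checkpoint is
# loadable without a GPU (mirrors the save block above).
state_dict = {k: v.cpu() for k, v in decoder.state_dict().items()}
torch.save(state_dict, 'decoder_iter_1.pth.tar')
```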
