This repository has been archived by the owner on Feb 14, 2024. It is now read-only.

Merge pull request #11 from naoto0804/develop
Pytorch V0.4
naoto0804 authored May 17, 2018
2 parents d5be09d + 4efad93 commit 2763a8a
Showing 6 changed files with 62 additions and 63 deletions.
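
In short, this merge ports the codebase from PyTorch 0.3 to the 0.4 API. A hedged summary of the migration patterns applied throughout the diff (a standalone sketch, not code from the repository):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 8, 8)

# Variable is merged into Tensor, so tensors are passed around directly;
# volatile=True is replaced by the torch.no_grad() context.
with torch.no_grad():
    y = x * 2

# 0-dim tensors: .item() replaces .data.cpu()[0] for scalar extraction.
loss = (x ** 2).mean()
value = loss.item()

# Device-agnostic code: .to(device) replaces .cuda() / torch.cuda.set_device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = x.to(device)

# nn.UpsamplingNearest2d is deprecated in favor of nn.Upsample, and
# torchvision's transforms.Scale is renamed transforms.Resize.
up = nn.Upsample(scale_factor=2, mode='nearest')
assert up(x).shape == (1, 3, 16, 16)
```
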
14 changes: 8 additions & 6 deletions README.md
@@ -5,8 +5,12 @@ I'm really grateful to the [original implementation](https://github.com/xunhuang

## Requirements
- Python 3.5+
- PyTorch 0.3
- PyTorch 0.4+
- TorchVision
- Pillow

(optional, for training)
- tqdm
- TensorboardX

## Usage
@@ -27,17 +31,17 @@ python torch_to_pytorch.py --model models/decoder.t7
### Test
Use `--content` and `--style` to provide the respective path to the content and style image.
```
python test.py --gpu <gpu_id> --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
```

You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every possible combination of content and styles to the output directory.
```
python test.py --gpu <gpu_id> --content_dir input/content --style_dir input/style
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style
```

This is an example of mixing four styles by specifying `--style` and `--style_interpolation_weights` option.
```
python test.py --gpu <gpu_id> --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
```

Some other options:
@@ -48,8 +52,6 @@ Some other options:


### Train
Please install tensorboardX, tqdm, and scipy for progress bar and logging

Use `--content_dir` and `--style_dir` to provide the respective directory to the content and style images.
```
python train.py --gpu <gpu_id> --content_dir <content_dir> --style_dir <style_dir>
6 changes: 3 additions & 3 deletions function.py
@@ -3,7 +3,7 @@

def calc_mean_std(feat, eps=1e-5):
# eps is a small value added to the variance to avoid divide-by-zero.
size = feat.data.size()
size = feat.size()
assert (len(size) == 4)
N, C = size[:2]
feat_var = feat.view(N, C, -1).var(dim=2) + eps
@@ -13,8 +13,8 @@ def calc_mean_std(feat, eps=1e-5):


def adaptive_instance_normalization(content_feat, style_feat):
assert (content_feat.data.size()[:2] == style_feat.data.size()[:2])
size = content_feat.data.size()
assert (content_feat.size()[:2] == style_feat.size()[:2])
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)

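For reference, the body of `adaptive_instance_normalization` is elided from this hunk; a minimal reconstruction of what it computes with `calc_mean_std`, assuming the standard AdaIN formulation (whiten the content features, then re-scale with style statistics):

```python
def adaptive_instance_normalization(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    # Normalize the content features channel-wise, then apply the
    # style's per-channel mean and standard deviation.
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
```
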
25 changes: 16 additions & 9 deletions net.py
@@ -1,5 +1,4 @@
import torch.nn as nn
from torch.autograd import Variable

from function import adaptive_instance_normalization as adain
from function import calc_mean_std
@@ -8,7 +7,7 @@
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 256, (3, 3)),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(),
@@ -21,14 +20,14 @@
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 128, (3, 3)),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 128, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 64, (3, 3)),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(64, 64, (3, 3)),
nn.ReLU(),
@@ -104,6 +103,11 @@ def __init__(self, encoder, decoder):
self.decoder = decoder
self.mse_loss = nn.MSELoss()

# fix the encoder
for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
for param in getattr(self, name).parameters():
param.requires_grad = False

# extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
def encode_with_intermediate(self, input):
results = [input]
@@ -119,23 +123,26 @@ def encode(self, input):
return input

def calc_content_loss(self, input, target):
assert (input.data.size() == target.data.size())
assert (input.size() == target.size())
assert (target.requires_grad is False)
return self.mse_loss(input, target)

def calc_style_loss(self, input, target):
assert (input.data.size() == target.data.size())
assert (input.size() == target.size())
assert (target.requires_grad is False)
input_mean, input_std = calc_mean_std(input)
target_mean, target_std = calc_mean_std(target)
return self.mse_loss(input_mean, target_mean) + \
self.mse_loss(input_std, target_std)

def forward(self, content, style):
def forward(self, content, style, alpha=1.0):
assert 0 <= alpha <= 1
style_feats = self.encode_with_intermediate(style)
t = adain(self.encode(content), style_feats[-1])
content_feat = self.encode(content)
t = adain(content_feat, style_feats[-1])
t = alpha * t + (1 - alpha) * content_feat

g_t = self.decoder(Variable(t.data, requires_grad=True))
g_t = self.decoder(t)
g_t_feats = self.encode_with_intermediate(g_t)

loss_c = self.calc_content_loss(g_t_feats[-1], t)
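
Two behavioural changes land in net.py: the encoder is now frozen inside `__init__` rather than by the training script, and `forward` exposes the content-style trade-off as `t = alpha * t + (1 - alpha) * content_feat`, so `alpha=1.0` reproduces the old fully-stylized target while smaller values blend back toward the raw content features. A usage sketch under those changes (weight loading is commented out; the checkpoint path is an assumption):

```python
import torch
import torch.nn as nn

import net  # net.py from this repository

vgg = net.vgg
# vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))  # assumed path
vgg = nn.Sequential(*list(vgg.children())[:31])  # keep layers up to relu4_1
network = net.Net(vgg, net.decoder)  # enc_1..enc_4 are frozen in __init__

content = torch.randn(2, 3, 256, 256)
style = torch.randn(2, 3, 256, 256)
loss_c, loss_s = network(content, style, alpha=1.0)  # matches old behaviour
```
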
43 changes: 18 additions & 25 deletions test.py
@@ -1,12 +1,10 @@
import argparse
import os
from os.path import basename
from os.path import splitext

import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image

@@ -18,7 +16,7 @@
def test_transform(size, crop):
transform_list = []
if size != 0:
transform_list.append(transforms.Scale(size))
transform_list.append(transforms.Resize(size))
if crop:
transform_list.append(transforms.CenterCrop(size))
transform_list.append(transforms.ToTensor())
@@ -33,8 +31,7 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
style_f = vgg(style)
if interpolation_weights:
_, C, H, W = content_f.size()
feat = Variable(torch.FloatTensor(1, C, H, W).zero_().cuda(),
volatile=True)
feat = torch.FloatTensor(1, C, H, W).zero_().to(torch.device('cuda'))
base_feat = adaptive_instance_normalization(content_f, style_f)
for i, w in enumerate(interpolation_weights):
feat = feat + w * base_feat[i:i + 1]
@@ -47,7 +44,6 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--content', type=str,
help='File path to the content image')
parser.add_argument('--content_dir', type=str,
@@ -86,11 +82,11 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
help='The weight for blending the style of multiple style images')

args = parser.parse_args()
if args.gpu >= 0:
torch.cuda.set_device(args.gpu)

do_interpolation = False

device = torch.device('cuda')

# Either --content or --contentDir should be given.
assert (args.content or args.content_dir)
# Either --style or --styleDir should be given.
@@ -129,8 +125,8 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.cuda()
decoder.cuda()
vgg.to(device)
decoder.to(device)

content_tf = test_transform(args.content_size, args.crop)
style_tf = test_transform(args.style_size, args.crop)
@@ -140,12 +136,11 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
style = torch.stack([style_tf(Image.open(p)) for p in style_paths])
content = content_tf(Image.open(content_path)) \
.unsqueeze(0).expand_as(style)
style = style.cuda()
content = content.cuda()
output = style_transfer(vgg, decoder,
Variable(content, volatile=True),
Variable(style, volatile=True),
args.alpha, interpolation_weights).data
style = style.to(device)
content = content.to(device)
with torch.no_grad():
output = style_transfer(vgg, decoder, content, style,
args.alpha, interpolation_weights)
output = output.cpu()
output_name = '{:s}/{:s}_interpolation{:s}'.format(
args.output, splitext(basename(content_path))[0], args.save_ext)
@@ -157,13 +152,11 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
style = style_tf(Image.open(style_path))
if args.preserve_color:
style = coral(style, content)
style = style.cuda()
content = content.cuda()
content = Variable(content.unsqueeze(0), volatile=True)
style = Variable(style.unsqueeze(0), volatile=True)

output = style_transfer(vgg, decoder, content, style,
args.alpha).data
style = style.to(device).unsqueeze(0)
content = content.to(device).unsqueeze(0)
with torch.no_grad():
output = style_transfer(vgg, decoder, content, style,
args.alpha)
output = output.cpu()

output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
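
One point worth noting in the interpolation path above: the blend happens in AdaIN feature space, not on the output images, and the weights are assumed to be normalized to sum to one (test.py does this where it parses `--style_interpolation_weights`, outside this hunk). The same logic in isolation, as a sketch:

```python
import torch

def blend_stylized_features(base_feat, weights):
    """base_feat: (N, C, H, W), one AdaIN-stylized feature map per style image;
    weights: N floats, blended as a convex combination."""
    total = sum(weights)
    feat = torch.zeros_like(base_feat[0:1])
    for i, w in enumerate(weights):
        feat = feat + (w / total) * base_feat[i:i + 1]
    return feat
```
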
4 changes: 2 additions & 2 deletions torch_to_pytorch.py
@@ -125,7 +125,7 @@ def lua_recursive_model(module, seq):
elif name == 'SpatialCrossMapLRN':
lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta,
m.k)
n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
n = Lambda(lambda x, lrn=lrn: lrn.forward(x))
add_submodule(seq, n)
elif name == 'Sequential':
n = nn.Sequential()
@@ -213,7 +213,7 @@ def lua_recursive_source(module):
lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format(
(m.size, m.alpha, m.beta, m.k))
s += [
'Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(
'Lambda(lambda x,lrn={}: Variable(lrn.forward(x)))'.format(
lrn)]

elif name == 'Sequential':
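The `Variable` wrapper also drops out of the `Lambda` used for legacy LRN layers. `Lambda` here is the usual callable-to-module adapter; a minimal sketch of the idea, assuming the definition earlier in torch_to_pytorch.py looks roughly like this:

```python
import torch.nn as nn

class Lambda(nn.Module):
    """Wrap a plain callable so it can live inside nn.Sequential."""

    def __init__(self, fn):
        super(Lambda, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)
```
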
33 changes: 15 additions & 18 deletions train.py
@@ -1,20 +1,21 @@
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from PIL import ImageFile
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchvision import transforms
from tqdm import tqdm

import net
from sampler import InfiniteSamplerWrapper

cudnn.benchmark = True
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated


def train_transform():
@@ -55,7 +56,6 @@ def adjust_learning_rate(optimizer, iteration_count):

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--content_dir', type=str, required=True,
help='Directory path to a batch of content images')
parser.add_argument('--style_dir', type=str, required=True,
@@ -77,8 +77,7 @@ def adjust_learning_rate(optimizer, iteration_count):
parser.add_argument('--save_model_interval', type=int, default=10000)
args = parser.parse_args()

if args.gpu >= 0:
torch.cuda.set_device(args.gpu)
device = torch.device('cuda')

if not os.path.exists(args.save_dir):
os.mkdir(args.save_dir)
@@ -93,12 +92,8 @@ def adjust_learning_rate(optimizer, iteration_count):
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:31])
network = net.Net(vgg, decoder)

for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
for param in getattr(network, name).parameters():
param.requires_grad = False
network.train()
network.cuda()
network.to(device)

content_tf = train_transform()
style_tf = train_transform()
@@ -119,8 +114,8 @@ def adjust_learning_rate(optimizer, iteration_count):

for i in tqdm(range(args.max_iter)):
adjust_learning_rate(optimizer, iteration_count=i)
content_images = Variable(next(content_iter).cuda())
style_images = Variable(next(style_iter).cuda())
content_images = next(content_iter).to(device)
style_images = next(style_iter).to(device)
loss_c, loss_s = network(content_images, style_images)
loss_c = args.content_weight * loss_c
loss_s = args.style_weight * loss_s
Expand All @@ -130,12 +125,14 @@ def adjust_learning_rate(optimizer, iteration_count):
loss.backward()
optimizer.step()

writer.add_scalar('loss_content', loss_c.data.cpu()[0], i + 1)
writer.add_scalar('loss_style', loss_s.data.cpu()[0], i + 1)
writer.add_scalar('loss_content', loss_c.item(), i + 1)
writer.add_scalar('loss_style', loss_s.item(), i + 1)

if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
torch.save(
net.decoder.state_dict(),
'{:s}/decoder_iter_{:d}.pth.tar'.format(args.save_dir, i + 1)
)
state_dict = net.decoder.state_dict()
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(torch.device('cpu'))
torch.save(state_dict,
'{:s}/decoder_iter_{:d}.pth.tar'.format(args.save_dir,
i + 1))
writer.close()
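
Saving a CPU copy of the state dict, as the new checkpointing code does, detaches the checkpoint from any particular GPU so it loads cleanly on CPU-only machines. The same pattern in isolation (a sketch):

```python
import torch

def save_cpu_state_dict(module, path):
    # Move every parameter/buffer to CPU before serializing so that
    # torch.load(path) works without CUDA available.
    state_dict = {k: v.to(torch.device('cpu'))
                  for k, v in module.state_dict().items()}
    torch.save(state_dict, path)
```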
