Skip to content

Commit

Permalink
implement style interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
Naoto Inoue committed Nov 30, 2017
1 parent 09f81d3 commit cfd17c6
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 59 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ python convert_torch.py --model models/decoder.t7
### Test
Use `--content` and `--style` to provide the respective path to the content and style image.
```
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg --gpu
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
```

You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every possible combination of content and style images to the output directory.
```
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style --gpu
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style
```

Some other options:
Expand All @@ -42,8 +42,8 @@ Some other options:

## TODO
- [x] Implement the preserve color option
- [x] Implement the style interpolation option
- [ ] Implement the spatial control option
- [ ] Implement the style interpolation option

## References
- [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization.", in ICCV, 2017.
Expand Down
9 changes: 0 additions & 9 deletions function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@ def adaptive_instance_normalization(content_feat, style_feat):
return normalized_feat * style_std + style_mean


def style_transfer(vgg, decoder, content, style, alpha=1.0):
    """Stylize ``content`` with ``style`` and decode the result.

    Both inputs are encoded with ``vgg``. The content features are passed
    through ``adaptive_instance_normalization`` together with the style
    features, then blended with the untouched content features before being
    handed to ``decoder``.

    Args:
        vgg: Callable encoder applied to ``content`` and ``style``.
        decoder: Callable that maps the blended features to the output.
        content: Content input for ``vgg``.
        style: Style input for ``vgg``.
        alpha: Blend factor in ``[0, 1]``; ``1.0`` uses only the normalized
            features, ``0.0`` only the original content features.

    Returns:
        Whatever ``decoder`` returns for the blended features.
    """
    assert 0.0 <= alpha <= 1.0
    encoded_content = vgg(content)
    encoded_style = vgg(style)
    stylized = adaptive_instance_normalization(encoded_content, encoded_style)
    blended = stylized * alpha + encoded_content * (1 - alpha)
    return decoder(blended)


def calc_feat_flatten_mean_std(feat):
# takes 3D feat (C, H, W), return mean and std of array within channels
assert (feat.size()[0] == 3)
Expand Down
137 changes: 90 additions & 47 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,58 @@
import os
import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image

import net
from function import style_transfer
from function import adaptive_instance_normalization
from function import coral
from PIL import Image


def custom_transform(size):
def custom_transform(size, crop):
    """Build the preprocessing pipeline applied to an input image.

    Args:
        size: Target size handed to ``transforms.Scale``. ``0`` means
            "keep the original size", so no scaling step is added.
        crop: When true, center-crop the scaled image to ``size``.

    Returns:
        A ``transforms.Compose`` whose last step is ``transforms.ToTensor``.
    """
    steps = []
    if size != 0:
        steps.append(transforms.Scale(size))
        # Cropping needs a real target size: with size == 0 the image keeps
        # its original dimensions and CenterCrop(0) would request an empty
        # crop, so the crop step is only added alongside the scale step.
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)


def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Stylize ``content`` with one style, or with a blend of several.

    Both inputs are encoded with ``vgg`` and the content features are passed
    through ``adaptive_instance_normalization`` together with the style
    features. Without ``interpolation_weights`` the normalized features are
    used as they are. With weights, entry ``i`` of the normalized batch is
    scaled by ``interpolation_weights[i]`` and the scaled entries are summed
    into a single feature map.

    Args:
        vgg: Callable encoder applied to ``content`` and ``style``.
        decoder: Callable that maps the blended features to the output.
        content: Content batch. In interpolation mode only the features of
            its first entry are used for the alpha blend (presumably the
            caller repeats one content image per style -- confirm at the
            call site).
        style: Style batch; in interpolation mode, one entry per weight.
        alpha: Blend factor in ``[0, 1]`` between the stylized features
            (``1.0``) and the original content features (``0.0``).
        interpolation_weights: Optional sequence of per-style weights. An
            empty or ``None`` value selects plain single-style transfer.

    Returns:
        Whatever ``decoder`` returns for the blended features.
    """
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        base_feat = adaptive_instance_normalization(content_f, style_f)
        # Accumulate straight from the weighted slices instead of a
        # pre-allocated CUDA FloatTensor, so the result lives on the same
        # device and has the same dtype as the features themselves.
        feat = None
        for i, w in enumerate(interpolation_weights):
            weighted = w * base_feat[i:i + 1]
            feat = weighted if feat is None else feat + weighted
        # The blend below needs a single content feature map to match the
        # single interpolated feature map.
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
help='File path to the content image')
parser.add_argument('--content_dir', type=str,
help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
help='File path to the style image')
# parser.add_argument('--style', type=str,
# help='File path to the style image, or multiple style \
# images separated by commas if you want to do style \
# interpolation or spatial control')
help='File path to the style image, or multiple style \
images separated by commas if you want to do style \
interpolation or spatial control')
parser.add_argument('--style_dir', type=str,
help='Directory path to a batch of style images')
parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')
Expand All @@ -46,10 +67,10 @@ def custom_transform(size):
parser.add_argument('--style_size', type=int, default=512,
help='New (minimum) size for the style image, \
keeping the original size if set to 0')
parser.add_argument('--crop', action='store_true',
help='do center crop to create squared image')
parser.add_argument('--save_ext', default='.jpg',
help='The extension name of the output image')
parser.add_argument('--gpu', action='store_true',
help='Zero-indexed ID of the GPU to use')
parser.add_argument('--output', type=str, default='output',
help='Directory to save the output image(s)')

Expand All @@ -59,17 +80,39 @@ def custom_transform(size):
parser.add_argument('--alpha', type=float, default=1.0,
help='The weight that controls the degree of \
stylization. Should be between 0 and 1')
# parser.add_argument(
# '--style_interpolation_weights', type=str, default='',
# help='The weight for blending the style of multiple style images')
parser.add_argument(
'--style_interpolation_weights', type=str, default='',
help='The weight for blending the style of multiple style images')

args = parser.parse_args()

do_interpolation = False

# Either --content or --content_dir should be given.
assert (args.content or args.content_dir)
# Either --style or --style_dir should be given.
assert (args.style or args.style_dir)

if args.content:
content_paths = [args.content]
else:
content_paths = [os.path.join(args.content_dir, f) for f in
os.listdir(args.content_dir)]

if args.style:
    # --style holds either one path or several comma-separated paths; more
    # than one path switches the script into style-interpolation mode.
    style_paths = args.style.split(',')
    if len(style_paths) > 1:
        do_interpolation = True
        assert (args.style_interpolation_weights != ''), \
            'Please specify interpolation weights'
        # Parse as float so fractional weights ("0.3,0.7") are accepted and
        # the normalization below is true division on Python 2 as well.
        weights = [float(w) for w in
                   args.style_interpolation_weights.split(',')]
        weight_sum = sum(weights)
        # Normalize so the weights sum to 1.
        interpolation_weights = [w / weight_sum for w in weights]
else:
    style_paths = [os.path.join(args.style_dir, f) for f in
                   os.listdir(args.style_dir)]

if not os.path.exists(args.output):
os.mkdir(args.output)

Expand All @@ -81,47 +124,47 @@ def custom_transform(size):

decoder.load_state_dict(torch.load(args.decoder))
vgg.load_state_dict(torch.load(args.vgg))

vgg = nn.Sequential(*list(vgg.children())[:31])
if args.gpu:
decoder.cuda()
vgg.cuda()

content_transform = custom_transform(args.content_size)
style_transform = custom_transform(args.style_size)
vgg.cuda()
decoder.cuda()

if args.content:
content_paths = [args.content]
else:
content_paths = [os.path.join(args.content_dir, f) for f in
os.listdir(args.content_dir)]

if args.style:
# style_paths = args.style.split(',')
# if len(style_paths) == 1:
style_paths = [args.style]
else:
style_paths = [os.path.join(args.style_dir, f) for f in
os.listdir(args.style_dir)]
content_tf = custom_transform(args.content_size, args.crop)
style_tf = custom_transform(args.style_size, args.crop)

for content_path in content_paths:
for style_path in style_paths:
content = content_transform(Image.open(content_path))
style = style_transform(Image.open(style_path))
if args.preserve_color:
style = coral(style, content)
if args.gpu:
if do_interpolation: # one content image, N style image
style = torch.stack([style_tf(Image.open(p)) for p in style_paths])
content = content_tf(Image.open(content_path)) \
.unsqueeze(0).expand_as(style)
style = style.cuda()
content = content.cuda()
output = style_transfer(vgg, decoder,
Variable(content, volatile=True),
Variable(style, volatile=True),
args.alpha, interpolation_weights).data
output = output.cpu()
output_name = '{:s}/{:s}_interpolation{:s}'.format(
args.output, splitext(basename(content_path))[0], args.save_ext)
save_image(output, output_name)

else: # process one content and one style
for style_path in style_paths:
content = content_tf(Image.open(content_path))
style = style_tf(Image.open(style_path))
if args.preserve_color:
style = coral(style, content)
style = style.cuda()
content = content.cuda()
content = Variable(content.unsqueeze(0), volatile=True)
style = Variable(style.unsqueeze(0), volatile=True)
content = Variable(content.unsqueeze(0), volatile=True)
style = Variable(style.unsqueeze(0), volatile=True)

output = style_transfer(vgg, decoder, content, style, args.alpha).data
if args.gpu:
output = style_transfer(vgg, decoder, content, style,
args.alpha).data
output = output.cpu()

output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
args.output, os.path.splitext(os.path.basename(content_path))[0],
os.path.splitext(os.path.basename(style_path))[0], args.save_ext
)
save_image(output, output_name)
output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
args.output, splitext(basename(content_path))[0],
splitext(basename(style_path))[0], args.save_ext
)
save_image(output, output_name)

0 comments on commit cfd17c6

Please sign in to comment.