* working example of fast-neural-style
* updated .travis.yml
* working with python2.7
* flake8 warnings fixed
* changed num_workers = 0 in data_loader
* incorporated comments from PR review by vfdev-5
* removed cifar10, using fakedataset from pytorch library for fast computation
* changed .travis.yml with correct test
* added handlers.py with ProgBar handler
* updated neural_style.py to incorporate vfdev-5 comments
* lint check failed, made corrections
* minor fixes and typos
* incorporated alykhantejani comments
* updated .travis.yml test
* Updated README.md
* updated .travis.yml with correct image_size argument
* fixed README.md with correct arguments
* zero_grad() now on optimizer instead of model
* Update README.md
Commit 283897c (parent ffe9b6e), in a fork of pytorch/ignite — 10 changed files with 508 additions and 1 deletion.
`README.md` (new file, +78 lines):
# fast-neural-style

### Introduction

This example is ported over from [pytorch-examples](https://github.com/pytorch/examples/tree/master/fast_neural_style).

It uses `ignite` to implement an algorithm for artistic style transfer, as described in [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155).

<p align="center">
  <img src="images/style_images/mosaic.jpg" height="200px">
  <img src="images/content_images/amber.jpg" height="200px">
  <img src="images/output_images/mosaic_amber.jpg" height="200px">
</p>

### Requirements

* `torch`
* `torchvision`
* `ignite`

Example of a `virtualenv` setup:

`virtualenv --python=/usr/bin/python3.5 env`

`source env/bin/activate`

`pip install torch torchvision pytorch-ignite`

The code runs on CPU, but a GPU makes it run much faster. If using a GPU, please ensure the proper CUDA libraries are installed.

### Documentation

#### Training

The code can be used to train a style transfer model for any style image. To run it correctly, make sure the [MSCOCO dataset](http://images.cocodataset.org/zips/train2014.zip) and a style image are downloaded first.

Since the code uses PyTorch's `ImageFolder` dataset, the directory containing the MSCOCO images must be formatted as shown below: the dataset location is `MSCOCO`, which contains a single folder `0` holding all the images.

```bash
├── MSCOCO
│   ├── 0
│   │   ├── RY48TY43YT.jpg
│   │   ├── 4324J0FNFL.jpg
│   │   ├── Y9REWJKNFE.jpg
```
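If you start from `train2014.zip`, one way to produce this layout (a sketch only; it assumes the archive extracts to a `train2014/` folder of `.jpg` files):

`mkdir -p MSCOCO/0`

`unzip -q train2014.zip`

`mv train2014/*.jpg MSCOCO/0/`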
##### Example

`python neural_style.py train --epochs 2 --cuda 1 --dataset mscoco --dataroot /path/to/mscoco --style_image ./images/style_images/mosaic.jpg`

##### Flags

* `--epochs`: number of training epochs, default is 2.
* `--batch_size`: batch size for training, default is 8.
* `--dataset`: type of dataset (`mscoco`, `folder`, or `test`).
* `--dataroot`: path to the training dataset; it should point to a folder containing another folder with all the training images.
* `--style_image`: path to the style image.
* `--checkpoint_model_dir`: path to the folder where checkpoints of trained models will be saved.
* `--checkpoint_interval`: how often a checkpoint of the trained model is created; the checkpoint handler fires on epoch completion, so the interval counts epochs (see the example below).
* `--image_size`: size of training images, default is 256 x 256.
* `--style_size`: size of the style image, default is the original size of the style image.
* `--cuda`: set it to 1 for running on GPU, 0 for CPU.
* `--seed`: random seed for training.
* `--content_weight`: weight for the content loss, default is 1e5.
* `--style_weight`: weight for the style loss, default is 1e10.
* `--lr`: learning rate, default is 1e-3.
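For example, to write a checkpoint into a custom directory at the end of every epoch (a variation on the command above; the directory path is illustrative):

`python neural_style.py train --epochs 2 --cuda 1 --dataset mscoco --dataroot /path/to/mscoco --style_image ./images/style_images/mosaic.jpg --checkpoint_model_dir ./checkpoints --checkpoint_interval 1`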
#### Evaluation

The code can be used to stylize an image using a trained style transfer model.

##### Example

`python neural_style.py eval --content_image ./images/content_images/amber.jpg --output_image test.png --cuda 1 --model /tmp/checkpoints/checkpoint_net_2.pth`

##### Flags

* `--content_image`: path to the content image you want to stylize.
* `--content_scale`: factor for scaling down the content image.
* `--output_image`: path for saving the output image.
* `--model`: saved model to be used for stylizing the image.
* `--cuda`: set it to 1 for running on GPU, 0 for CPU.
`handlers.py` (new file, +39 lines):
```python
import sys


class Progbar(object):

    def __init__(self, loader, metrics):
        self.num_iterations = len(loader)
        self.output_stream = sys.stdout
        self.metrics = metrics
        self.alpha = 0.98

    def _calc_running_avg(self, engine):
        # Exponential moving average of each value in the engine's output dict.
        for k, v in engine.state.output.items():
            old_v = self.metrics.get(k, v)
            new_v = self.alpha * old_v + (1 - self.alpha) * v
            self.metrics[k] = new_v

    def __call__(self, engine):
        self._calc_running_avg(engine)
        # engine.state.iteration counts across epochs, so recover the
        # position within the current epoch.
        num_seen = engine.state.iteration - self.num_iterations * (engine.state.epoch - 1)

        percent_seen = 100 * float(num_seen) / self.num_iterations
        equal_to = int(percent_seen / 10)
        done = int(percent_seen) == 100

        bar = '[' + '=' * equal_to + '>' * (not done) + ' ' * (10 - equal_to) + ']'
        message = 'Epoch {epoch} | {percent_seen:.2f}% | {bar}'.format(epoch=engine.state.epoch,
                                                                       percent_seen=percent_seen,
                                                                       bar=bar)
        for key, value in self.metrics.items():
            message += ' | {name}: {value:.2e}'.format(name=key, value=value)

        message += '\r'

        self.output_stream.write(message)
        self.output_stream.flush()

        if done:
            self.output_stream.write('\n')
```
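The handler smooths each reported loss with an exponential moving average (`new = 0.98 * old + 0.02 * current`) before printing. A minimal, self-contained sketch of how it attaches to an `Engine` follows; the dataset and step function here are toy stand-ins, not part of the commit (the real wiring is in `neural_style.py` below):

```python
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events

from handlers import Progbar  # the module added in this commit

# Toy data and a fake "training" step that returns a dict of scalar
# losses, which is the output format Progbar expects.
loader = DataLoader(TensorDataset(torch.randn(64, 3)), batch_size=8)

def step(engine, batch):
    (x,) = batch
    return {'total_loss': float(x.pow(2).mean())}

trainer = Engine(step)
trainer.add_event_handler(Events.ITERATION_COMPLETED,
                          Progbar(loader=loader, metrics=OrderedDict()))
trainer.run(loader, max_epochs=2)
```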
(Three binary files in this commit — presumably the style, content, and output images referenced in the README — cannot be rendered in the diff view.)
`neural_style.py` (new file, +216 lines):
```python
# coding: utf-8
from __future__ import print_function, division

import argparse
import os
import sys

import numpy as np
import random
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

import utils
from transformer_net import TransformerNet
from vgg import Vgg16
from handlers import Progbar

from collections import OrderedDict


def check_paths(args):
    try:
        if args.checkpoint_model_dir is not None and not (os.path.exists(args.checkpoint_model_dir)):
            os.makedirs(args.checkpoint_model_dir)
    except OSError as e:
        raise OSError(e)


def check_manual_seed(args):
    seed = args.seed or random.randint(1, 10000)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def check_dataset(args):
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    if args.dataset in {'folder', 'mscoco'}:
        train_dataset = datasets.ImageFolder(args.dataroot, transform)
    elif args.dataset == 'test':
        train_dataset = datasets.FakeData(size=args.batch_size, image_size=(3, 32, 32),
                                          num_classes=1, transform=transform)
    else:
        raise RuntimeError("Invalid dataset name: {}".format(args.dataset))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)

    return train_loader


def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    train_loader = check_dataset(args)
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # The style targets (Gram matrices of the VGG features of the style
    # image) are fixed, so they are computed once before training starts.
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    running_avgs = OrderedDict()

    def step(engine, batch):
        x, _ = batch
        x = x.to(device)

        n_batch = len(x)

        optimizer.zero_grad()

        y = transformer(x)

        x = utils.normalize_batch(x)
        y = utils.normalize_batch(y)

        features_x = vgg(x)
        features_y = vgg(y)

        # Content loss compares features at a single VGG layer (relu2_2).
        content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

        # Style loss compares Gram matrices at every layer returned by Vgg16.
        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        return {
            'content_loss': content_loss.item(),
            'style_loss': style_loss.item(),
            'total_loss': total_loss.item()
        }

    trainer = Engine(step)
    # Note: this handler is attached at EPOCH_COMPLETED below, so
    # save_interval effectively counts epochs.
    checkpoint_handler = ModelCheckpoint(args.checkpoint_model_dir, 'checkpoint',
                                         save_interval=args.checkpoint_interval,
                                         n_saved=10, require_empty=False, create_dir=True)
    progress_bar = Progbar(loader=train_loader, metrics=running_avgs)

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={'net': transformer})
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=progress_bar)
    trainer.run(train_loader, max_epochs=args.epochs)


def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = torch.load(args.model)
        style_model.to(device)
        output = style_model(content_image).cpu()
        utils.save_image(args.output_image, output[0])


def main():
    main_arg_parser = argparse.ArgumentParser(description="parser for fast-neural-style")
    subparsers = main_arg_parser.add_subparsers(title="subcommands", dest="subcommand")

    train_arg_parser = subparsers.add_parser("train", help="parser for training arguments")
    train_arg_parser.add_argument("--epochs", type=int, default=2, help="number of training epochs, default is 2")
    train_arg_parser.add_argument("--batch_size", type=int, default=8,
                                  help="batch size for training, default is 8")
    train_arg_parser.add_argument("--dataset", type=str, required=True, choices={'test', 'folder', 'mscoco'},
                                  help="type of dataset to be used.")
    train_arg_parser.add_argument("--dataroot", type=str, required=True,
                                  help="path to training dataset, the path should point to a folder "
                                       "containing another folder with all the training images")
    train_arg_parser.add_argument("--style_image", type=str, default="test",
                                  help="path to style-image")
    train_arg_parser.add_argument("--test_image", type=str, default="test",
                                  help="path to test-image")
    train_arg_parser.add_argument("--checkpoint_model_dir", type=str, default='/tmp/checkpoints',
                                  help="path to folder where checkpoints of trained models will be saved")
    train_arg_parser.add_argument("--checkpoint_interval", type=int, default=1,
                                  help="number of epochs after which a checkpoint of trained model will be created")
    train_arg_parser.add_argument("--image_size", type=int, default=256,
                                  help="size of training images, default is 256 x 256")
    train_arg_parser.add_argument("--style_size", type=int, default=None,
                                  help="size of style-image, default is the original size of style image")
    train_arg_parser.add_argument("--cuda", type=int, default=1,
                                  help="set it to 1 for running on GPU, 0 for CPU")
    train_arg_parser.add_argument("--seed", type=int, default=42,
                                  help="random seed for training")
    train_arg_parser.add_argument("--content_weight", type=float, default=1e5,
                                  help="weight for content-loss, default is 1e5")
    train_arg_parser.add_argument("--style_weight", type=float, default=1e10,
                                  help="weight for style-loss, default is 1e10")
    train_arg_parser.add_argument("--lr", type=float, default=1e-3,
                                  help="learning rate, default is 1e-3")

    eval_arg_parser = subparsers.add_parser("eval", help="parser for evaluation/stylizing arguments")
    eval_arg_parser.add_argument("--content_image", type=str, required=True,
                                 help="path to content image you want to stylize")
    eval_arg_parser.add_argument("--content_scale", type=float, default=None,
                                 help="factor for scaling down the content image")
    eval_arg_parser.add_argument("--output_image", type=str, required=True,
                                 help="path for saving the output image")
    eval_arg_parser.add_argument("--model", type=str, required=True,
                                 help="saved model to be used for stylizing the image.")
    eval_arg_parser.add_argument("--cuda", type=int, required=True,
                                 help="set it to 1 for running on GPU, 0 for CPU")

    args = main_arg_parser.parse_args()

    if args.subcommand is None:
        raise ValueError("ERROR: specify either train or eval")
    if args.cuda and not torch.cuda.is_available():
        raise ValueError("ERROR: cuda is not available, try running on CPU")

    if args.subcommand == "train":
        check_manual_seed(args)
        check_paths(args)
        train(args)
    else:
        stylize(args)


if __name__ == '__main__':
    main()
```
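The script imports `utils.load_image`, `utils.save_image`, `utils.normalize_batch`, and `utils.gram_matrix` from a `utils` module that is not shown in this diff. Since the example is ported from pytorch-examples, the two math helpers presumably look like the following sketch; these are assumed implementations given for reference, not part of this commit:

```python
# Assumed implementations of the utils helpers used above, modeled on the
# upstream pytorch-examples fast_neural_style code; not part of this commit.
import torch


def gram_matrix(y):
    # Gram matrix of a batch of feature maps: (b, ch, h, w) -> (b, ch, ch),
    # normalized by the number of elements per channel map.
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)


def normalize_batch(batch):
    # Map images from the [0, 255] range used by the data pipeline to the
    # ImageNet-normalized range the pretrained VGG network expects.
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch / 255.0 - mean) / std
```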