Skip to content

Commit

Permalink
use torch.cuda.set_device
Browse files Browse the repository at this point in the history
  • Loading branch information
Naoto Inoue committed Jan 10, 2018
1 parent bb8c68f commit 937134b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ python convert_torch.py --model models/decoder.t7
### Test
Use `--content` and `--style` to provide the respective path to the content and style image.
```
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
python test.py --gpu <gpu_id> --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
```

You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every possible combination of content and styles to the output directory.
```
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style
python test.py --gpu <gpu_id> --content_dir input/content --style_dir input/style
```

This is an example of mixing four styles by specifying `--style` and `--style_interpolation_weights` option.
```
CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
python test.py --gpu <gpu_id> --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
```

Some other options:
Expand All @@ -51,7 +51,7 @@ Please install tensorflow, tqdm, and scipy for progress bar and logging

Use `--content_dir` and `--style_dir` to provide the respective directory to the content and style images.
```
CUDA_VISIBLE_DEVICES=<gpu_id> python train.py --content_dir <content_dir> --style_dir <style_dir>
python train.py --gpu <gpu_id> --content_dir <content_dir> --style_dir <style_dir>
```

For more details and parameters, please refer to --help option.
Expand Down
9 changes: 6 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse

import os
from os.path import basename
from os.path import splitext

import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image
Expand Down Expand Up @@ -47,6 +47,7 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--content', type=str,
help='File path to the content image')
parser.add_argument('--content_dir', type=str,
Expand Down Expand Up @@ -85,6 +86,8 @@ def style_transfer(vgg, decoder, content, style, alpha=1.0,
help='The weight for blending the style of multiple style images')

args = parser.parse_args()
if args.gpu >= 0:
torch.cuda.set_device(args.gpu)

do_interpolation = False

Expand Down
17 changes: 10 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,16 @@ def name(self):
return 'FlatFolderDataset'


def adjust_learning_rate(optimizer, iteration_count):
"""Imitating the original implementation"""
lr = args.lr / (1.0 + args.lr_decay * iteration_count)
for param_group in optimizer.param_groups:
param_group['lr'] = lr


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--content_dir', type=str, required=True,
help='Directory path to a batch of content images')
parser.add_argument('--style_dir', type=str, required=True,
Expand All @@ -70,13 +78,8 @@ def name(self):
parser.add_argument('--n_threads', type=int, default=16)
args = parser.parse_args()


def adjust_learning_rate(optimizer, iteration_count):
"""Imitating the original implementation"""
lr = args.lr / (1.0 + args.lr_decay * iteration_count)
for param_group in optimizer.param_groups:
param_group['lr'] = lr

if args.gpu >= 0:
torch.cuda.set_device(args.gpu)

if not os.path.exists(args.save_dir):
os.mkdir(args.save_dir)
Expand Down

0 comments on commit 937134b

Please sign in to comment.