Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
Add the PyTorch PyTorch implementation of the paper: “Deep Single Image Manipulation”.
  • Loading branch information
Eliahu Horwitz committed Jul 2, 2020
1 parent e54e36d commit 5adbb52
Show file tree
Hide file tree
Showing 48 changed files with 2,607 additions and 1 deletion.
65 changes: 64 additions & 1 deletion README.md
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1,64 @@
# DeepSIM
# DeepSIM
### [Project](http://www.vision.huji.ac.il/deepsim) | [Paper]() <br>
Official PyTorch implementation of the paper: "Deep Single Image Manipulation".

## Results

<p align='center'>
<img src='imgs/main_table.png' />
</p>

### SuperPrimitive2Image

<p align='center'>
<img src='./imgs/sp1.png' />
</p>
<p align='center'>
<img src='./imgs/sp2.png' />
</p>


### Image2VideoFrames

<p align='center'>
<img src='./imgs/im2vid1.png' />
</p>



## Getting Started
### Training
- Train a model at 640 x 640 resolution (`bash ./scripts/train.sh`):
```bash
#!./scripts/train.sh
python3.7 train.py --dataroot ./datasets/face --name DeepSIM --niter 8000 --niter_decay 8000 --label_nc 0 --no_instance --resize_or_crop none --tps_aug 1 --apply_binary_threshold 1 --resize_or_crop none --loadSize 640 --fineSize 640
```
- To view training results, please checkout intermediate results in `./checkpoints/DeepSIM/web/index.html`.

### Training with your own dataset
- For binary training images (i.e. edge maps) use `--apply_binary_threshold 1`, to ensure that the edges input is indeed binary (during both training and inference), for segmentation maps use `--apply_binary_threshold 0`, you may also use `--edge_threshold` to control the threshold.
- For TPS augmentations use `--tps_aug 1`, this will train the image with a new random TPS warp every epoch.
- See `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags. See pix2pixHD for further details.
### Testing


- Test the model (`bash ./scripts/test.sh`):
```bash
#!./scripts/test.sh
python3.7 test.py --dataroot ./datasets/face --name DeepSIM --label_nc 0 --no_instance --resize_or_crop none --apply_binary_threshold 1 --online_tps 0 --no_instance --loadSize 640 --fineSize 640
```
The test results will be saved to a html file here: `./results/DeepSIM/test_latest/index.html`.

More example scripts can be found in the `scripts` directory.


## Citation

If you find this useful for your research, please use the following.

```
```


## Acknowledgments
This code borrows heavily from [pix2pixHD](https://github.com/NVIDIA/pix2pixHD).
Empty file added data/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions data/aligned_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torch

class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot

### input A (label maps)
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
self.A_paths = sorted(make_dataset(self.dir_A))

### input B (real images)
if opt.isTrain or opt.use_encoded_image:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))

### instance maps
if not opt.no_instance:
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
self.inst_paths = sorted(make_dataset(self.dir_inst))

### load precomputed instance-wise encoded features
if opt.load_features:
self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
print('----------- loading features from %s ----------' % self.dir_feat)
self.feat_paths = sorted(make_dataset(self.dir_feat))

self.dataset_size = len(self.A_paths)

def __getitem__(self, index):
### input A (label maps)
edges_path, edges_im = None, None
A_path = self.A_paths[index]
A = Image.open(A_path)
params = get_params(self.opt, A.size, A)
if self.opt.label_nc == 0:
transform_A = get_transform(self.opt, params)
A_img = A.convert('RGB')
A_tensor = transform_A(A_img)
if self.opt.apply_binary_threshold == 1:
ones = torch.ones_like(A_tensor)
minus_ones = torch.ones_like(A_tensor)*-1
A_tensor = torch.where(A_tensor>=self.opt.edge_threshold, ones,minus_ones)

else:
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
A_tensor = transform_A(A) * 255.0

B_tensor = inst_tensor = feat_tensor = 0
### input B (real images)
if self.opt.isTrain or self.opt.use_encoded_image:
B_path = self.B_paths[index]
B = Image.open(B_path).convert('RGB')
transform_B = get_transform(self.opt, params)
B_tensor = transform_B(B)

### if using instance maps
if not self.opt.no_instance:
inst_path = self.inst_paths[index]
inst = Image.open(inst_path)
inst_tensor = transform_A(inst)

if self.opt.load_features:
feat_path = self.feat_paths[index]
feat = Image.open(feat_path).convert('RGB')
norm = normalize()
feat_tensor = norm(transform_A(feat))

input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
'feat': feat_tensor, 'path': A_path}

return input_dict

def __len__(self):
return len(self.A_paths)

def name(self):
return 'AlignedDataset'
14 changes: 14 additions & 0 deletions data/base_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

class BaseDataLoader():
def __init__(self):
pass

def initialize(self, opt):
self.opt = opt
pass

def load_data():
return None



124 changes: 124 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
from util import tps_warp
import math
from PIL import ImageDraw


class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()

def name(self):
return 'BaseDataset'

def initialize(self, opt):
pass


def get_params(opt, size, input_im):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
new_h = new_w = opt.loadSize
elif opt.resize_or_crop == 'scale_width_and_crop':
new_w = opt.loadSize
new_h = opt.loadSize * h // w

x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

flip = random.random() > 0.5

if opt.tps_aug:
np_im = np.array(input_im)
src = tps_warp._get_regular_grid(np_im,
points_per_dim=opt.tps_points_per_dim)
dst = tps_warp._generate_random_vectors(np_im, src, scale=0.1 * w)
return {'crop_pos': (x, y), 'flip': flip,
'tps': {'src': src, 'dst': dst}}
return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if opt.tps_aug:
transform_list.append(
transforms.Lambda(lambda img: __apply_tps(img, params['tps'])))

if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.loadSize, method)))

if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(
lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(
transforms.Lambda(lambda img: __make_power_2(img, base, method)))

if opt.isTrain and not opt.no_flip:
transform_list.append(
transforms.Lambda(lambda img: __flip(img, params['flip'])))

transform_list += [transforms.ToTensor()]

if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)


def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)


def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img


def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img


def __apply_tps(img, tps_params):
np_im = np.array(img)
np_im = tps_warp.tps_warp_2(np_im, tps_params['dst'], tps_params['src'])
new_im = Image.fromarray(np_im)
return new_im
33 changes: 33 additions & 0 deletions data/custom_dataset_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader
import numpy as np

def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()

print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset

class CustomDatasetDataLoader(BaseDataLoader):

def name(self):
return 'CustomDatasetDataLoader'

def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=1,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads),
worker_init_fn=lambda _: np.random.seed())

def load_data(self):
return self.dataloader

def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
7 changes: 7 additions & 0 deletions data/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
65 changes: 65 additions & 0 deletions data/image_folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir

for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)

return images


def default_loader(path):
return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))

self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader

def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img

def __len__(self):
return len(self.imgs)
Binary file added datasets/car/test_A/car_test1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/car/test_A/car_test2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/car/test_A/car_test_original.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/car/train_A/car.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/car/train_B/car.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/test_A/face_test1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/test_A/face_test2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/test_A/face_test3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/test_A/face_test5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/test_A/face_test6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/train_A/face.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/face/train_B/face.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/im2vid1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/main_table.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/sp1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/sp2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added models/__init__.py
Empty file.
Loading

0 comments on commit 5adbb52

Please sign in to comment.