Add the PyTorch implementation of the paper: “Deep Single Image Manipulation”.
Eliahu Horwitz committed Jul 2, 2020
1 parent e54e36d commit 5adbb52
Showing 48 changed files with 2,607 additions and 1 deletion.
@@ -1 +1,64 @@
# DeepSIM
### [Project](http://www.vision.huji.ac.il/deepsim) | [Paper]() <br>
Official PyTorch implementation of the paper: "Deep Single Image Manipulation".

## Results

<p align='center'>
<img src='imgs/main_table.png' />
</p>

### SuperPrimitive2Image

<p align='center'>
<img src='./imgs/sp1.png' />
</p>
<p align='center'>
<img src='./imgs/sp2.png' />
</p>


### Image2VideoFrames

<p align='center'>
<img src='./imgs/im2vid1.png' />
</p>


## Getting Started
### Training
- Train a model at 640 x 640 resolution (`bash ./scripts/train.sh`):
```bash
#!./scripts/train.sh
python3.7 train.py --dataroot ./datasets/face --name DeepSIM --niter 8000 --niter_decay 8000 --label_nc 0 --no_instance --resize_or_crop none --tps_aug 1 --apply_binary_threshold 1 --loadSize 640 --fineSize 640
```
- To view training results, check out the intermediate results in `./checkpoints/DeepSIM/web/index.html`.

### Training with your own dataset
- For binary training images (i.e. edge maps), use `--apply_binary_threshold 1` to ensure that the edge input is indeed binary during both training and inference. For segmentation maps, use `--apply_binary_threshold 0`. You may also use `--edge_threshold` to control the threshold.
- For TPS augmentations, use `--tps_aug 1`. This trains on the image with a new random TPS warp every epoch.
- See `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags. See pix2pixHD for further details.
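
As a rough illustration of what `--apply_binary_threshold 1` and `--edge_threshold` do, the sketch below binarizes values that are already normalized to [-1, 1]: everything at or above the threshold becomes 1, the rest becomes -1. It uses plain Python lists rather than tensors. The function name `binarize` and the default threshold of 0.0 are assumptions made for illustration and are not taken from the repository's option defaults.

```python
def binarize(values, edge_threshold=0.0):
    """Map each value to 1.0 if it is >= edge_threshold, else to -1.0.

    `values` is a flat list of floats assumed to lie in [-1, 1]
    (that is, after a mean=0.5, std=0.5 normalization).
    The 0.0 default is a placeholder, not the repository's default.
    """
    return [1.0 if v >= edge_threshold else -1.0 for v in values]


if __name__ == "__main__":
    # A soft, anti-aliased edge profile becomes a hard binary one.
    soft_edge = [-1.0, -0.4, 0.0, 0.3, 1.0]
    print(binarize(soft_edge))
    print(binarize(soft_edge, edge_threshold=0.5))
```

Raising the threshold keeps only the strongest edge responses, which may help when hand-drawn or resized edge maps contain grey, anti-aliased pixels.
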
### Testing

- Test the model (`bash ./scripts/test.sh`):
```bash
#!./scripts/test.sh
python3.7 test.py --dataroot ./datasets/face --name DeepSIM --label_nc 0 --no_instance --resize_or_crop none --apply_binary_threshold 1 --online_tps 0 --loadSize 640 --fineSize 640
```
The test results will be saved to an HTML file here: `./results/DeepSIM/test_latest/index.html`.

More example scripts can be found in the `scripts` directory.


## Citation

If you find this useful for your research, please use the following.

```
```


## Acknowledgments
This code borrows heavily from [pix2pixHD](https://github.com/NVIDIA/pix2pixHD).
Empty file.
@@ -0,0 +1,84 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torch

class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        ### input A (label maps)
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
        self.A_paths = sorted(make_dataset(self.dir_A))

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
            self.B_paths = sorted(make_dataset(self.dir_B))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        ### input A (label maps)
        edges_path, edges_im = None, None
        A_path = self.A_paths[index]
        A = Image.open(A_path)
        params = get_params(self.opt, A.size, A)
        if self.opt.label_nc == 0:
            transform_A = get_transform(self.opt, params)
            A_img = A.convert('RGB')
            A_tensor = transform_A(A_img)
            if self.opt.apply_binary_threshold == 1:
                ones = torch.ones_like(A_tensor)
                minus_ones = torch.ones_like(A_tensor) * -1
                A_tensor = torch.where(A_tensor >= self.opt.edge_threshold, ones, minus_ones)

        else:
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            A_tensor = transform_A(A) * 255.0

        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)
            B_tensor = transform_B(B)

        ### if using instance maps
        if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)

            if self.opt.load_features:
                feat_path = self.feat_paths[index]
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_A(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': A_path}

        return input_dict

    def __len__(self):
        return len(self.A_paths)

    def name(self):
        return 'AlignedDataset'
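
The directory naming in `AlignedDataset.initialize` can be hard to read from the diff alone, so here is a minimal sketch of the same suffix rule as a standalone function. The helper name `expected_dirs` and its `with_images` parameter are hypothetical and exist only for this illustration; the suffixes themselves (`_A`/`_B` when `label_nc == 0`, otherwise `_label`/`_img`) are taken from the code above.

```python
import os.path


def expected_dirs(dataroot, phase, label_nc, with_images=True):
    """Return the input/target directories the dataset would look in.

    Mirrors the suffix choice in AlignedDataset.initialize: '_A'/'_B'
    when label_nc == 0, otherwise '_label'/'_img'. `with_images` stands
    in for the `isTrain or use_encoded_image` condition.
    """
    suffix_a = '_A' if label_nc == 0 else '_label'
    dirs = {'A': os.path.join(dataroot, phase + suffix_a)}
    if with_images:
        suffix_b = '_B' if label_nc == 0 else '_img'
        dirs['B'] = os.path.join(dataroot, phase + suffix_b)
    return dirs


if __name__ == "__main__":
    # With --label_nc 0 and the training phase, inputs and targets
    # are expected under train_A and train_B.
    print(expected_dirs('./datasets/face', 'train', label_nc=0))
```

Because the A and B path lists are each sorted and then indexed with the same `index`, corresponding input and target files should have names that sort into the same order.
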
@@ -0,0 +1,14 @@

class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None

@@ -0,0 +1,124 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
from util import tps_warp
import math
from PIL import ImageDraw


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_params(opt, size, input_im):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

    flip = random.random() > 0.5

    if opt.tps_aug:
        np_im = np.array(input_im)
        src = tps_warp._get_regular_grid(np_im,
                                         points_per_dim=opt.tps_points_per_dim)
        dst = tps_warp._generate_random_vectors(np_im, src, scale=0.1 * w)
        return {'crop_pos': (x, y), 'flip': flip,
                'tps': {'src': src, 'dst': dst}}
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if opt.tps_aug:
        transform_list.append(
            transforms.Lambda(lambda img: __apply_tps(img, params['tps'])))

    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        # NOTE: transforms.Scale is the legacy name; newer torchvision
        # releases provide transforms.Resize instead.
        transform_list.append(transforms.Scale(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(
            lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(
            transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(
            transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __apply_tps(img, tps_params):
    np_im = np.array(img)
    np_im = tps_warp.tps_warp_2(np_im, tps_params['dst'], tps_params['src'])
    new_im = Image.fromarray(np_im)
    return new_im
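
With `--resize_or_crop none`, the only size adjustment in `get_transform` is the rounding done by `__make_power_2`, so it may help to see that arithmetic in isolation. The sketch below reproduces it on plain `(width, height)` tuples without PIL. The helper names `generator_base` and `nearest_multiple`, and the value `n_downsample_global=4` used in the example, are assumptions for illustration only.

```python
def generator_base(n_downsample_global, n_local_enhancers=0):
    """Divisor for the image sides: 2**n_downsample_global, multiplied by
    2**n_local_enhancers when a local-enhancer generator is used
    (pass 0 otherwise)."""
    return float(2 ** n_downsample_global) * (2 ** n_local_enhancers)


def nearest_multiple(size, base):
    """Round each side of (width, height) to the nearest multiple of base,
    using the same int(round(side / base) * base) rule as the code above."""
    w, h = size
    return (int(round(w / base) * base), int(round(h / base) * base))


if __name__ == "__main__":
    base = generator_base(4)                   # assumed example value
    print(nearest_multiple((640, 640), base))  # already a multiple of 16
    print(nearest_multiple((500, 375), base))  # both sides get rounded
```

An image whose sides are already multiples of the base, such as the 640 x 640 case in the README commands, passes through unchanged; other sizes are resampled, which slightly changes the aspect ratio.
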
@@ -0,0 +1,33 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader
import numpy as np

def CreateDataset(opt):
    dataset = None
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset

class CustomDatasetDataLoader(BaseDataLoader):

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=1,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
            worker_init_fn=lambda _: np.random.seed())

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)
@@ -0,0 +1,7 @@

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
@@ -0,0 +1,65 @@
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
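
One detail of `is_image_file` worth noting is that the match is case-sensitive and the list spells out each accepted casing, so a name such as `photo.TIFF` or `photo.Jpg` would be skipped. The following is a minimal sketch of a case-insensitive alternative, not the repository's implementation; the names `is_image_file_ci` and `list_images` are hypothetical.

```python
import os

# Lower-case extensions only; the filename is lower-cased before matching.
IMG_EXTENSIONS_CI = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tiff')


def is_image_file_ci(filename):
    """Case-insensitive extension check (hypothetical helper)."""
    return filename.lower().endswith(IMG_EXTENSIONS_CI)


def list_images(directory):
    """Walk `directory` recursively and return the image paths, sorted."""
    assert os.path.isdir(directory), '%s is not a valid directory' % directory
    images = []
    for root, _, fnames in os.walk(directory):
        for fname in fnames:
            if is_image_file_ci(fname):
                images.append(os.path.join(root, fname))
    return sorted(images)


if __name__ == "__main__":
    for name in ('edge.png', 'photo.TIFF', 'photo.Jpg', 'notes.txt'):
        print(name, is_image_file_ci(name))
```

Sorting the final list, rather than the `os.walk` iterator, gives a deterministic order of full paths regardless of how the filesystem lists entries within a directory.
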
Empty file.