forked from diyiiyiii/StyTR-2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
summary.py
116 lines (99 loc) · 4.32 KB
/
summary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import os
import torch
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from PIL import ImageFile
from tensorboardX import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from pathlib import Path
import models.transformer as transformer
import models.StyTR as StyTR
from sampler import InfiniteSamplerWrapper
from torchvision.utils import save_image
def train_transform():
    """Build the preprocessing pipeline applied to each training image.

    The pipeline resizes the image to 512x512, takes a random 256x256 crop,
    and converts the result to a tensor.

    Returns:
        A ``transforms.Compose`` chaining the three steps in that order.
    """
    return transforms.Compose([
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ])
class FlatFolderDataset(data.Dataset):
    """Image dataset read from a flat folder or from a folder of sub-folders.

    Two layouts are supported:

    * nested -- ``root`` contains at least one sub-directory; every entry
      found directly inside each sub-directory becomes a sample (stored as
      a ``str`` path). Plain files sitting next to the sub-directories are
      ignored.
    * flat -- ``root`` contains no sub-directory; every entry of ``root``
      becomes a sample (stored as a ``pathlib.Path``).

    Listings are sorted so that the index -> file mapping is reproducible
    across runs and file systems. An empty ``root`` yields an empty dataset
    instead of raising ``IndexError``.

    Args:
        root: Directory that holds the images (or the sub-folders of images).
        transform: Callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``; its return value is the sample.
    """

    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        print(self.root)
        # Top-level listing, kept as an attribute as in the original code.
        self.path = sorted(os.listdir(self.root))
        # Inspect every entry rather than only the first one: the order
        # returned by os.listdir is arbitrary, so a single stray file must
        # not decide which layout is assumed.
        sub_dirs = [
            entry for entry in self.path
            if os.path.isdir(os.path.join(self.root, entry))
        ]
        if sub_dirs:
            self.paths = []
            for sub_dir in sub_dirs:
                folder = os.path.join(self.root, sub_dir)
                self.paths.extend(
                    os.path.join(folder, file_name)
                    for file_name in sorted(os.listdir(folder))
                )
        else:
            self.paths = sorted(Path(self.root).glob('*'))
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` as RGB and return it after ``transform``."""
        path = self.paths[index]
        img = Image.open(str(path)).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of collected sample paths."""
        return len(self.paths)

    def name(self):
        """Return the dataset's display name."""
        return 'FlatFolderDataset'
def adjust_learning_rate(optimizer, iteration_count, base_lr=2e-4,
                         lr_decay=None, decay_start=1e4):
    """Set an inverse-time-decayed learning rate on every parameter group.

    The rate is ``base_lr / (1 + lr_decay * (iteration_count - decay_start))``.
    With the default arguments this is exactly the schedule of the original
    implementation.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        iteration_count: Current training iteration.
        base_lr: Rate produced when ``iteration_count == decay_start``.
        lr_decay: Decay coefficient. ``None`` (the default) falls back to the
            module-level ``args.lr_decay``.
        decay_start: Iteration at which the decay term is zero.
            NOTE(review): 1e4 presumably matches the length of the warm-up
            phase -- confirm against the training loop.
    """
    if lr_decay is None:
        # Keep the original behaviour: read the coefficient from the CLI args.
        lr_decay = args.lr_decay
    lr = base_lr / (1.0 + lr_decay * (iteration_count - decay_start))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def warmup_learning_rate(optimizer, iteration_count, base_lr=None):
    """Set a linearly increasing warm-up learning rate on every parameter group.

    The rate is ``base_lr * 0.1 * (1 + 3e-4 * iteration_count)``: it starts
    at one tenth of ``base_lr`` and grows linearly with the iteration count.
    With the default arguments this is exactly the schedule of the original
    implementation.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        iteration_count: Current training iteration.
        base_lr: Reference rate. ``None`` (the default) falls back to the
            module-level ``args.lr``.
    """
    if base_lr is None:
        # Keep the original behaviour: read the reference rate from the CLI args.
        base_lr = args.lr
    lr = base_lr * 0.1 * (1.0 + 3e-4 * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
# Command-line interface. Options are registered in a fixed order because
# that order is what ``--help`` prints.
parser = argparse.ArgumentParser()

# --- Data and pretrained weights ---
parser.add_argument(
    '--content_dir',
    type=str,
    default='./images1',
    help='Directory path to a batch of content images',
)
# The style images were crawled from https://www.wikiart.org/.
parser.add_argument(
    '--style_dir',
    type=str,
    default='./images2',
    help='Directory path to a batch of style images',
)
# The pretrained VGG checkpoint has to be downloaded before running.
parser.add_argument(
    '--vgg',
    type=str,
    default='./experiments/vgg_normalised.pth',
)

# --- Output locations ---
parser.add_argument(
    '--save_dir',
    default='./experiments',
    help='Directory to save the model',
)
parser.add_argument(
    '--log_dir',
    default='./logs',
    help='Directory to save the log',
)

# --- Optimisation hyper-parameters ---
parser.add_argument('--lr', default=5e-4, type=float)
parser.add_argument('--lr_decay', default=1e-5, type=float)
parser.add_argument('--max_iter', default=160000, type=int)
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--style_weight', default=10.0, type=float)
parser.add_argument('--content_weight', default=7.0, type=float)
parser.add_argument('--n_threads', default=16, type=int)
parser.add_argument('--save_model_interval', default=10000, type=int)

# --- Transformer configuration ---
parser.add_argument(
    '--position_embedding',
    type=str,
    default='sine',
    choices=('sine', 'learned'),
    help="Type of positional embedding to use on top of the image features",
)
parser.add_argument(
    '--hidden_dim',
    type=int,
    default=512,
    help="Size of the embeddings (dimension of the transformer)",
)

args = parser.parse_args()
# Select the first CUDA device when one is available. ``device`` is defined
# for parity with the training script; nothing in the lines below moves the
# model or the dummy inputs onto it.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda:0" if USE_CUDA else "cpu")

# Create both output directories, including missing parents. ``exist_ok``
# replaces the racy "check, then create" pattern, and ``makedirs`` (instead of
# ``mkdir``) lets a nested --log_dir work the same way a nested --save_dir does.
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=args.log_dir)

# Load the pretrained VGG weights, then keep only the first 44 child modules
# as the encoder. ``map_location="cpu"`` lets a checkpoint that was saved from
# a GPU load on a CPU-only host; the model stays on the CPU here anyway.
vgg = StyTR.vgg
vgg.load_state_dict(torch.load(args.vgg, map_location="cpu"))
vgg = nn.Sequential(*list(vgg.children())[:44])

# Assemble the full network from its four components plus the parsed options.
decoder = StyTR.decoder
embedding = StyTR.PatchEmbed()
Trans = transformer.Transformer()
network = StyTR.StyTrans(vgg, decoder, embedding, Trans, args)

# Print a depth-1 summary of the network, driven by two random
# 1x3x256x256 tensors (presumably the content/style image pair -- confirm
# against StyTrans.forward).
from torchinfo import summary
summary(
    network,
    input_data=(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)),
    depth=1,
)