forked from yunjey/show-attend-and-tell
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepro.py
212 lines (169 loc) · 8.47 KB
/
prepro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from scipy import ndimage
from collections import Counter
from core.vggnet import Vgg19
from core.utils import *
import tensorflow as tf
import numpy as np
import pandas as pd
import hickle
import os
import json
def _process_caption_data(caption_file, image_dir, max_length):
with open(caption_file) as f:
caption_data = json.load(f)
# id_to_filename is a dictionary such as {image_id: filename]}
id_to_filename = {image['id']: image['file_name'] for image in caption_data['images']}
# data is a list of dictionary which contains 'captions', 'file_name' and 'image_id' as key.
data = []
for annotation in caption_data['annotations']:
image_id = annotation['image_id']
annotation['file_name'] = os.path.join(image_dir, id_to_filename[image_id])
data += [annotation]
# convert to pandas dataframe (for later visualization or debugging)
caption_data = pd.DataFrame.from_dict(data)
del caption_data['id']
caption_data.sort_values(by='image_id', inplace=True)
caption_data = caption_data.reset_index(drop=True)
del_idx = []
for i, caption in enumerate(caption_data['caption']):
caption = caption.replace('.','').replace(',','').replace("'","").replace('"','')
caption = caption.replace('&','and').replace('(','').replace(")","").replace('-',' ')
caption = " ".join(caption.split()) # replace multiple spaces
caption_data.set_value(i, 'caption', caption.lower())
if len(caption.split(" ")) > max_length:
del_idx.append(i)
# delete captions if size is larger than max_length
print "The number of captions before deletion: %d" %len(caption_data)
caption_data = caption_data.drop(caption_data.index[del_idx])
caption_data = caption_data.reset_index(drop=True)
print "The number of captions after deletion: %d" %len(caption_data)
return caption_data
def _build_vocab(annotations, threshold=1):
counter = Counter()
max_len = 0
for i, caption in enumerate(annotations['caption']):
words = caption.split(' ') # caption contrains only lower-case words
for w in words:
counter[w] +=1
if len(caption.split(" ")) > max_len:
max_len = len(caption.split(" "))
vocab = [word for word in counter if counter[word] >= threshold]
print ('Filtered %d words to %d words with word count threshold %d.' % (len(counter), len(vocab), threshold))
word_to_idx = {u'<NULL>': 0, u'<START>': 1, u'<END>': 2}
idx = 3
for word in vocab:
word_to_idx[word] = idx
idx += 1
print "Max length of caption: ", max_len
return word_to_idx
def _build_caption_vector(annotations, word_to_idx, max_length=15):
n_examples = len(annotations)
captions = np.ndarray((n_examples,max_length+2)).astype(np.int32)
for i, caption in enumerate(annotations['caption']):
words = caption.split(" ") # caption contrains only lower-case words
cap_vec = []
cap_vec.append(word_to_idx['<START>'])
for word in words:
if word in word_to_idx:
cap_vec.append(word_to_idx[word])
cap_vec.append(word_to_idx['<END>'])
# pad short caption with the special null token '<NULL>' to make it fixed-size vector
if len(cap_vec) < (max_length + 2):
for j in range(max_length + 2 - len(cap_vec)):
cap_vec.append(word_to_idx['<NULL>'])
captions[i, :] = np.asarray(cap_vec)
print "Finished building caption vectors"
return captions
def _build_file_names(annotations):
image_file_names = []
id_to_idx = {}
idx = 0
image_ids = annotations['image_id']
file_names = annotations['file_name']
for image_id, file_name in zip(image_ids, file_names):
if not image_id in id_to_idx:
id_to_idx[image_id] = idx
image_file_names.append(file_name)
idx += 1
file_names = np.asarray(image_file_names)
return file_names, id_to_idx
def _build_image_idxs(annotations, id_to_idx):
image_idxs = np.ndarray(len(annotations), dtype=np.int32)
image_ids = annotations['image_id']
for i, image_id in enumerate(image_ids):
image_idxs[i] = id_to_idx[image_id]
return image_idxs
def _build_references(annotations):
    """Group captions by image for later BLEU scoring.

    Returns a dict that maps each image's position (0, 1, ... in order of
    first appearance of its image_id, the same order _build_file_names uses)
    to the list of that image's captions. Each caption is lower-cased and
    terminated with ' .'.
    """
    image_pos = {}
    feature_to_captions = {}
    for caption, image_id in zip(annotations['caption'], annotations['image_id']):
        if image_id not in image_pos:
            image_pos[image_id] = len(image_pos)
            feature_to_captions[image_pos[image_id]] = []
        # Key by the caption's own image, not by "most recently seen image".
        feature_to_captions[image_pos[image_id]].append(caption.lower() + ' .')
    return feature_to_captions


def _extract_split_features(sess, vggnet, split, batch_size):
    """Run the VGG net over every unique image of `split` and hickle-dump the features."""
    anno_path = './data/%s/%s.annotations.pkl' % (split, split)
    save_path = './data/%s/%s.features.hkl' % (split, split)
    annotations = load_pickle(anno_path)
    # unique() keeps first-appearance order; assuming file_name and image_id are
    # one-to-one, this matches the order of the saved file names.
    image_path = list(annotations['file_name'].unique())
    n_examples = len(image_path)
    # NOTE(review): 196 x 512 presumably is a flattened 14x14 conv5_3 map — confirm against core.vggnet.
    all_feats = np.ndarray([n_examples, 196, 512], dtype=np.float32)
    for start in range(0, n_examples, batch_size):
        end = min(start + batch_size, n_examples)
        image_batch_file = image_path[start:end]
        # A list (not map) so np.array receives real images on Python 3 as well.
        # NOTE(review): scipy.ndimage.imread was removed in SciPy 1.2; needs an older SciPy.
        image_batch = np.array([ndimage.imread(x, mode='RGB') for x in image_batch_file]).astype(
            np.float32)
        feats = sess.run(vggnet.features, feed_dict={vggnet.images: image_batch})
        all_feats[start:end, :] = feats
        print("Processed %d %s features.." % (end, split))
    # use hickle to save huge feature vectors
    hickle.dump(all_feats, save_path)
    print("Saved %s.." % (save_path))


def main():
    """Preprocess captions and extract VGG-19 conv features for train/val/test.

    For each split, the following files are written under ./data/<split>/:
    annotations, caption index vectors, unique image file names, per-caption
    image indices, reference captions and conv feature arrays. The vocabulary
    (word_to_idx.pkl) is built from the train split only. The val and test
    splits are both carved out of the captions_val2014 annotations.
    """
    # batch size for extracting feature vectors from vggnet.
    batch_size = 100
    # maximum length of caption (number of words). Longer captions are deleted.
    max_length = 15
    # words occurring fewer than word_count_threshold times in the training set are
    # left out of the vocabulary (and are skipped when building caption vectors).
    word_count_threshold = 1
    # vgg model path
    vgg_model_path = './data/imagenet-vgg-verydeep-19.mat'

    # about 80000 images and 400000 captions for train dataset
    train_dataset = _process_caption_data(caption_file='data/annotations/captions_train2014.json',
                                          image_dir='image/train2014_resized/',
                                          max_length=max_length)

    # about 40000 images and 200000 captions
    val_dataset = _process_caption_data(caption_file='data/annotations/captions_val2014.json',
                                        image_dir='image/val2014_resized/',
                                        max_length=max_length)

    # about 4000 images and 20000 captions for val / test dataset
    # NOTE(review): the cutoffs count captions, not images, so the captions of the image at a
    # boundary may be split between val and test — confirm whether that matters.
    val_cutoff = int(0.1 * len(val_dataset))
    test_cutoff = int(0.2 * len(val_dataset))
    print('Finished processing caption data')

    save_pickle(train_dataset, 'data/train/train.annotations.pkl')
    save_pickle(val_dataset[:val_cutoff], 'data/val/val.annotations.pkl')
    save_pickle(val_dataset[val_cutoff:test_cutoff].reset_index(drop=True), 'data/test/test.annotations.pkl')

    word_to_idx = None
    for split in ['train', 'val', 'test']:
        annotations = load_pickle('./data/%s/%s.annotations.pkl' % (split, split))

        if split == 'train':
            # 'train' is processed first; its vocabulary is reused for val/test.
            word_to_idx = _build_vocab(annotations=annotations, threshold=word_count_threshold)
            save_pickle(word_to_idx, './data/%s/word_to_idx.pkl' % split)

        captions = _build_caption_vector(annotations=annotations, word_to_idx=word_to_idx, max_length=max_length)
        save_pickle(captions, './data/%s/%s.captions.pkl' % (split, split))

        file_names, id_to_idx = _build_file_names(annotations)
        save_pickle(file_names, './data/%s/%s.file.names.pkl' % (split, split))

        image_idxs = _build_image_idxs(annotations, id_to_idx)
        save_pickle(image_idxs, './data/%s/%s.image.idxs.pkl' % (split, split))

        # prepare reference captions to compute bleu scores later
        feature_to_captions = _build_references(annotations)
        save_pickle(feature_to_captions, './data/%s/%s.references.pkl' % (split, split))
        print("Finished building %s caption dataset" % split)

    # extract conv5_3 feature vectors
    vggnet = Vgg19(vgg_model_path)
    vggnet.build()
    with tf.Session() as sess:
        # NOTE(review): initialize_all_variables is deprecated in favor of
        # tf.global_variables_initializer (TF >= 0.12); kept for the TF version targeted here.
        tf.initialize_all_variables().run()
        for split in ['train', 'val', 'test']:
            _extract_split_features(sess, vggnet, split, batch_size)


if __name__ == "__main__":
    main()