Skip to content

Commit

Permalink
Transition to tf 1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruben Vereecken committed Jun 22, 2017
1 parent f10e761 commit 8fbcd33
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 67 deletions.
65 changes: 33 additions & 32 deletions core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@


class CaptionGenerator(object):
def __init__(self, word_to_idx, dim_feature=[196, 512], dim_embed=512, dim_hidden=1024, n_time_step=16,
def __init__(self, word_to_idx, dim_feature=[196, 512], dim_embed=512, dim_hidden=1024, n_time_step=16,
prev2out=True, ctx2out=True, alpha_c=0.0, selector=True, dropout=True):
"""
Args:
word_to_idx: word-to-index mapping dictionary.
dim_feature: (optional) Dimension of vggnet19 conv5_3 feature vectors.
dim_embed: (optional) Dimension of word embedding.
dim_hidden: (optional) Dimension of all hidden state.
n_time_step: (optional) Time step size of LSTM.
n_time_step: (optional) Time step size of LSTM.
prev2out: (optional) previously generated word to hidden state. (see Eq (7) for explanation)
ctx2out: (optional) context to hidden state (see Eq (7) for explanation)
alpha_c: (optional) Doubly stochastic regularization coefficient. (see Section (4.2.1) for explanation)
selector: (optional) gating scalar for context vector. (see Section (4.2.1) for explanation)
dropout: (optional) If true then dropout layer is added.
"""

self.word_to_idx = word_to_idx
self.idx_to_word = {i: w for w, i in word_to_idx.iteritems()}
self.prev2out = prev2out
Expand All @@ -55,7 +55,7 @@ def __init__(self, word_to_idx, dim_feature=[196, 512], dim_embed=512, dim_hidde
# Place holder for features and captions
self.features = tf.placeholder(tf.float32, [None, self.L, self.D])
self.captions = tf.placeholder(tf.int32, [None, self.T + 1])

def _get_initial_lstm(self, features):
with tf.variable_scope('initial_lstm'):
features_mean = tf.reduce_mean(features, 1)
Expand All @@ -79,7 +79,7 @@ def _project_features(self, features):
with tf.variable_scope('project_features'):
w = tf.get_variable('w', [self.D, self.D], initializer=self.weight_initializer)
features_flat = tf.reshape(features, [-1, self.D])
features_proj = tf.matmul(features_flat, w)
features_proj = tf.matmul(features_flat, w)
features_proj = tf.reshape(features_proj, [-1, self.L, self.D])
return features_proj

Expand All @@ -91,18 +91,18 @@ def _attention_layer(self, features, features_proj, h, reuse=False):

h_att = tf.nn.relu(features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b) # (N, L, D)
out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L]) # (N, L)
alpha = tf.nn.softmax(out_att)
alpha = tf.nn.softmax(out_att)
context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') #(N, D)
return context, alpha

def _selector(self, context, h, reuse=False):
with tf.variable_scope('selector', reuse=reuse):
w = tf.get_variable('w', [self.H, 1], initializer=self.weight_initializer)
b = tf.get_variable('b', [1], initializer=self.const_initializer)
beta = tf.nn.sigmoid(tf.matmul(h, w) + b, 'beta') # (N, 1)
context = tf.mul(beta, context, name='selected_context')
context = tf.multiply(beta, context, name='selected_context')
return context, beta

def _decode_lstm(self, x, h, context, dropout=False, reuse=False):
with tf.variable_scope('logits', reuse=reuse):
w_h = tf.get_variable('w_h', [self.H, self.M], initializer=self.weight_initializer)
Expand All @@ -126,9 +126,9 @@ def _decode_lstm(self, x, h, context, dropout=False, reuse=False):
h_logits = tf.nn.dropout(h_logits, 0.5)
out_logits = tf.matmul(h_logits, w_out) + b_out
return out_logits

def _batch_norm(self, x, mode='train', name=None):
return tf.contrib.layers.batch_norm(inputs=x,
return tf.contrib.layers.batch_norm(inputs=x,
decay=0.95,
center=True,
scale=True,
Expand All @@ -141,14 +141,14 @@ def build_model(self):
captions = self.captions
batch_size = tf.shape(features)[0]

captions_in = captions[:, :self.T]
captions_out = captions[:, 1:]
captions_in = captions[:, :self.T]
captions_out = captions[:, 1:]
mask = tf.to_float(tf.not_equal(captions_out, self._null))


# batch normalize feature vectors
features = self._batch_norm(features, mode='train', name='conv_features')

c, h = self._get_initial_lstm(features=features)
x = self._word_embedding(inputs=captions_in)
features_proj = self._project_features(features=features)
Expand All @@ -162,28 +162,29 @@ def build_model(self):
alpha_list.append(alpha)

if self.selector:
context, beta = self._selector(context, h, reuse=(t!=0))
context, beta = self._selector(context, h, reuse=(t!=0))

with tf.variable_scope('lstm', reuse=(t!=0)):
_, (c, h) = lstm_cell(inputs=tf.concat(1, [x[:,t,:], context]), state=[c, h])
_, (c, h) = lstm_cell(inputs=tf.concat( [x[:,t,:], context],1), state=[c, h])

logits = self._decode_lstm(x[:,t,:], h, context, dropout=self.dropout, reuse=(t!=0))
loss += tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, captions_out[:, t]) * mask[:, t])


loss += tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=captions_out[:, t],logits=logits)*mask[:, t] )

if self.alpha_c > 0:
alphas = tf.transpose(tf.pack(alpha_list), (1, 0, 2)) # (N, T, L)
alphas = tf.transpose(tf.stack(alpha_list), (1, 0, 2)) # (N, T, L)
alphas_all = tf.reduce_sum(alphas, 1) # (N, L)
alpha_reg = self.alpha_c * tf.reduce_sum((16./196 - alphas_all) ** 2)
alpha_reg = self.alpha_c * tf.reduce_sum((16./196 - alphas_all) ** 2)
loss += alpha_reg

return loss / tf.to_float(batch_size)

def build_sampler(self, max_len=20):
features = self.features

# batch normalize feature vectors
features = self._batch_norm(features, mode='test', name='conv_features')

c, h = self._get_initial_lstm(features=features)
features_proj = self._project_features(features=features)

Expand All @@ -196,23 +197,23 @@ def build_sampler(self, max_len=20):
if t == 0:
x = self._word_embedding(inputs=tf.fill([tf.shape(features)[0]], self._start))
else:
x = self._word_embedding(inputs=sampled_word, reuse=True)
x = self._word_embedding(inputs=sampled_word, reuse=True)

context, alpha = self._attention_layer(features, features_proj, h, reuse=(t!=0))
alpha_list.append(alpha)

if self.selector:
context, beta = self._selector(context, h, reuse=(t!=0))
context, beta = self._selector(context, h, reuse=(t!=0))
beta_list.append(beta)

with tf.variable_scope('lstm', reuse=(t!=0)):
_, (c, h) = lstm_cell(inputs=tf.concat(1, [x, context]), state=[c, h])
_, (c, h) = lstm_cell(inputs=tf.concat( [x, context],1), state=[c, h])

logits = self._decode_lstm(x, h, context, reuse=(t!=0))
sampled_word = tf.argmax(logits, 1)
sampled_word_list.append(sampled_word)
sampled_word = tf.argmax(logits, 1)
sampled_word_list.append(sampled_word)

alphas = tf.transpose(tf.pack(alpha_list), (1, 0, 2)) # (N, T, L)
alphas = tf.transpose(tf.stack(alpha_list), (1, 0, 2)) # (N, T, L)
betas = tf.transpose(tf.squeeze(beta_list), (1, 0)) # (N, T)
sampled_captions = tf.transpose(tf.pack(sampled_word_list), (1, 0)) # (N, max_len)
sampled_captions = tf.transpose(tf.stack(sampled_word_list), (1, 0)) # (N, max_len)
return alphas, betas, sampled_captions
79 changes: 44 additions & 35 deletions core/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import skimage.transform
import numpy as np
import time
import os
import os
import cPickle as pickle
from scipy import ndimage
from utils import *
Expand All @@ -18,9 +18,9 @@ def __init__(self, model, data, val_data, **kwargs):
- data: Training data; dictionary with the following keys:
- features: Feature vectors of shape (82783, 196, 512)
- file_names: Image file names of shape (82783, )
- captions: Captions of shape (400000, 17)
- image_idxs: Indices for mapping caption to image of shape (400000, )
- word_to_idx: Mapping dictionary from word to index
- captions: Captions of shape (400000, 17)
- image_idxs: Indices for mapping caption to image of shape (400000, )
- word_to_idx: Mapping dictionary from word to index
- val_data: validation data; for print out BLEU scores for each epoch.
Optional Arguments:
- n_epochs: The number of epochs to run for training.
Expand All @@ -29,9 +29,9 @@ def __init__(self, model, data, val_data, **kwargs):
- learning_rate: Learning rate; default value is 0.01.
- print_every: Integer; training losses will be printed every print_every iterations.
- save_every: Integer; model variables will be saved every save_every epoch.
- pretrained_model: String; pretrained model path
- model_path: String; model path for saving
- test_model: String; model path for test
- pretrained_model: String; pretrained model path
- model_path: String; model path for saving
- test_model: String; model path for test
"""

self.model = model
Expand All @@ -55,7 +55,7 @@ def __init__(self, model, data, val_data, **kwargs):
elif self.update_rule == 'momentum':
self.optimizer = tf.train.MomentumOptimizer
elif self.update_rule == 'rmsprop':
self.optimizer = tf.train.RMSPropOptimizer
self.optimizer = tf.train.RMSPropOptimizer

if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
Expand All @@ -65,7 +65,9 @@ def __init__(self, model, data, val_data, **kwargs):

def train(self):
# train/val dataset
n_examples = self.data['captions'].shape[0]
# Changed this because I keep less features than captions, see prepro
# n_examples = self.data['captions'].shape[0]
n_examples = self.data['features'].shape[0]
n_iters_per_epoch = int(np.ceil(float(n_examples)/self.batch_size))
features = self.data['features']
captions = self.data['captions']
Expand All @@ -74,37 +76,44 @@ def train(self):
n_iters_val = int(np.ceil(float(val_features.shape[0])/self.batch_size))

# build graphs for training model and sampling captions
loss = self.model.build_model()
tf.get_variable_scope().reuse_variables()
_, _, generated_captions = self.model.build_sampler(max_len=20)
# This scope fixed things!!
with tf.variable_scope(tf.get_variable_scope()):
loss = self.model.build_model()
tf.get_variable_scope().reuse_variables()
_, _, generated_captions = self.model.build_sampler(max_len=20)

# train op
with tf.name_scope('optimizer'):
with tf.variable_scope(tf.get_variable_scope(), reuse=False):
optimizer = self.optimizer(learning_rate=self.learning_rate)
grads = tf.gradients(loss, tf.trainable_variables())
grads_and_vars = list(zip(grads, tf.trainable_variables()))
train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars)

# summary op
tf.scalar_summary('batch_loss', loss)

# summary op
# tf.scalar_summary('batch_loss', loss)
tf.summary.scalar('batch_loss', loss)
for var in tf.trainable_variables():
tf.histogram_summary(var.op.name, var)
#tf.histogram_summary(var.op.name, var)
tf.summary.histogram(var.op.name, var)
for grad, var in grads_and_vars:
tf.histogram_summary(var.op.name+'/gradient', grad)

summary_op = tf.merge_all_summaries()
#tf.histogram_summary(var.op.name+'/gradient', grad)
tf.summary.histogram(var.op.name+'/gradient', grad)

#summary_op = tf.merge_all_summaries()
summary_op = tf.summary.merge_all()

print "The number of epoch: %d" %self.n_epochs
print "Data size: %d" %n_examples
print "Batch size: %d" %self.batch_size
print "Iterations per epoch: %d" %n_iters_per_epoch

config = tf.ConfigProto(allow_soft_placement = True)
#config.gpu_options.per_process_gpu_memory_fraction=0.9
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
tf.initialize_all_variables().run()
summary_writer = tf.train.SummaryWriter(self.log_path, graph=tf.get_default_graph())
tf.global_variables_initializer().run()
#summary_writer = tf.train.SummaryWriter(self.log_path, graph=tf.get_default_graph())
summary_writer = tf.summary.FileWriter(self.log_path, graph=tf.get_default_graph())
saver = tf.train.Saver(max_to_keep=40)

if self.pretrained_model is not None:
Expand Down Expand Up @@ -138,7 +147,7 @@ def train(self):
ground_truths = captions[image_idxs == image_idxs_batch[0]]
decoded = decode_captions(ground_truths, self.model.idx_to_word)
for j, gt in enumerate(decoded):
print "Ground truth %d: %s" %(j+1, gt)
print "Ground truth %d: %s" %(j+1, gt)
gen_caps = sess.run(generated_captions, feed_dict)
decoded = decode_captions(gen_caps, self.model.idx_to_word)
print "Generated caption: %s\n" %decoded[0]
Expand All @@ -148,16 +157,16 @@ def train(self):
print "Elapsed time: ", time.time() - start_t
prev_loss = curr_loss
curr_loss = 0

# print out BLEU scores and file write
if self.print_bleu:
all_gen_cap = np.ndarray((val_features.shape[0], 20))
for i in range(n_iters_val):
features_batch = val_features[i*self.batch_size:(i+1)*self.batch_size]
feed_dict = {self.model.features: features_batch}
gen_cap = sess.run(generated_captions, feed_dict=feed_dict)
gen_cap = sess.run(generated_captions, feed_dict=feed_dict)
all_gen_cap[i*self.batch_size:(i+1)*self.batch_size] = gen_cap

all_decoded = decode_captions(all_gen_cap, self.model.idx_to_word)
save_pickle(all_decoded, "./data/val/val.candidate.captions.pkl")
scores = evaluate(data_path='./data', split='val', get_scores=True)
Expand All @@ -167,16 +176,16 @@ def train(self):
if (e+1) % self.save_every == 0:
saver.save(sess, os.path.join(self.model_path, 'model'), global_step=e+1)
print "model-%s saved." %(e+1)


def test(self, data, split='train', attention_visualization=True, save_sampled_captions=True):
'''
Args:
- data: dictionary with the following keys:
- features: Feature vectors of shape (5000, 196, 512)
- file_names: Image file names of shape (5000, )
- captions: Captions of shape (24210, 17)
- image_idxs: Indices for mapping caption to image of shape (24210, )
- captions: Captions of shape (24210, 17)
- image_idxs: Indices for mapping caption to image of shape (24210, )
- features_to_captions: Mapping feature to captions (5000, 4~5)
- split: 'train', 'val' or 'test'
- attention_visualization: If True, visualize attention weights with images for each sampled word. (ipthon notebook)
Expand All @@ -187,7 +196,7 @@ def test(self, data, split='train', attention_visualization=True, save_sampled_c

# build a graph to sample captions
alphas, betas, sampled_captions = self.model.build_sampler(max_len=20) # (N, max_len, L), (N, max_len)

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
Expand All @@ -208,7 +217,7 @@ def test(self, data, split='train', attention_visualization=True, save_sampled_c
plt.imshow(img)
plt.axis('off')

# Plot images with attention weights
# Plot images with attention weights
words = decoded[n].split(" ")
for t in range(len(words)):
if t > 18:
Expand All @@ -228,6 +237,6 @@ def test(self, data, split='train', attention_visualization=True, save_sampled_c
for i in range(num_iter):
features_batch = features[i*self.batch_size:(i+1)*self.batch_size]
feed_dict = { self.model.features: features_batch }
all_sam_cap[i*self.batch_size:(i+1)*self.batch_size] = sess.run(sampled_captions, feed_dict)
all_sam_cap[i*self.batch_size:(i+1)*self.batch_size] = sess.run(sampled_captions, feed_dict)
all_decoded = decode_captions(all_sam_cap, self.model.idx_to_word)
save_pickle(all_decoded, "./data/%s/%s.candidate.captions.pkl" %(split,split))
save_pickle(all_decoded, "./data/%s/%s.candidate.captions.pkl" %(split,split))

0 comments on commit 8fbcd33

Please sign in to comment.