
Commit 3b5f6d0
Incorporating the changes for pre-calculating the image features
nikhilmaram committed Jun 1, 2018
1 parent b3cedb8 commit 3b5f6d0
Showing 5 changed files with 54 additions and 32 deletions.
config.py: 8 changes (4 additions, 4 deletions)
@@ -1,11 +1,11 @@
 class Config(object):
     def __init__(self):
         ## Questions and Annotataions JSON files
-        self.DATA_DIR ='./datasets/'
+        self.DATA_DIR ='../datasets/'
         self.TRAIN_QUESTIONS_FILE='v2_OpenEnded_mscoco_train2014_questions.json'
         self.TRAIN_ANNOTATIONS_FILE='v2_mscoco_train2014_annotations.json'
-        #self.TRAIN_IMAGE_DIR = self.DATA_DIR + '/train2014/'
-        self.TRAIN_IMAGE_DIR = '/Users/sainikhilmaram/Desktop/train2014'
+        self.TRAIN_IMAGE_DIR = self.DATA_DIR + '/train2014/'
+        #self.TRAIN_IMAGE_DIR = '/Users/sainikhilmaram/Desktop/train2014'


         self.VAL_QUESTIONS_FILE='v2_OpenEnded_mscoco_val2014_questions.json'
@@ -62,7 +62,7 @@ def __init__(self):

         ## Testing Parameters
         self.TEST_QUESTION_FILE = 'test_question_file.txt'
-        self.TEST_IMAGE_DIR = 'test_image_dir/'
+        self.TEST_IMAGE_DIR = self.DATA_DIR+'test_image_dir/'


         ## LSTM parameters
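For orientation, a small illustration (hypothetical, not part of the commit) of how the relocated paths resolve once DATA_DIR points one directory up; note that the test image directory is now anchored under DATA_DIR as well:

# Hypothetical illustration of the resolved paths after this change; the real
# Config class in config.py defines many more fields than shown here.
DATA_DIR = '../datasets/'

TRAIN_IMAGE_DIR = DATA_DIR + '/train2014/'      # '../datasets//train2014/' (double slash is harmless)
TEST_IMAGE_DIR = DATA_DIR + 'test_image_dir/'   # '../datasets/test_image_dir/'

# Because TEST_IMAGE_DIR now carries the DATA_DIR prefix itself, the matching
# os.listdir call in vqa_preprocessing.py drops its own DATA_DIR prefix.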
vqa_dataset.py: 27 changes (17 additions, 10 deletions)
@@ -6,16 +6,16 @@ class DataSet(object):
     def __init__(self,
                  image_id_list,
                  image_file_list,
-                 question_id_list,
-                 question_idxs_list,
-                 question_masks_list,
+                 question_id_list=None,
+                 question_idxs_list=None,
+                 question_masks_list=None,
                  question_type_list=None,
                  answer_id_list=None,
                  answer_idxs_list=None,
                  answer_masks_list=None,
                  answer_type_list=None,
                  batch_size=1,
-                 is_train=False,
+                 phase="train",
                  shuffle=False):


@@ -33,7 +33,7 @@ def __init__(self,
         self.answer_type_list = np.array(answer_type_list)

         self.batch_size = batch_size
-        self.is_train = is_train
+        self.phase = phase
         self.shuffle = shuffle
         self.setup()

@@ -66,18 +66,25 @@ def next_batch(self):

         image_files = self.image_file_list[current_idxs]
         image_idxs = self.image_id_list[current_idxs]
-        question_idxs = self.question_idxs_list[current_idxs]
-        question_masks = self.question_masks_list[current_idxs]


-        if self.is_train:
+        if self.phase == "train":
+            question_idxs = self.question_idxs_list[current_idxs]
+            question_masks = self.question_masks_list[current_idxs]
             answer_idxs = self.answer_idxs_list[current_idxs]
             answer_masks = self.answer_masks_list[current_idxs]
             self.current_idx += self.batch_size
             return image_files,image_idxs, question_idxs, question_masks, answer_idxs, answer_masks
-        else:
+        elif self.phase == "test":
+            question_idxs = self.question_idxs_list[current_idxs]
+            question_masks = self.question_masks_list[current_idxs]
             self.current_idx += self.batch_size
-            return image_files,question_idxs,question_masks
+            return image_files,image_idxs,question_idxs,question_masks
+        elif self.phase == "cnn_features":
+            self.current_idx += self.batch_size
+            return image_files, image_idxs


     def has_next_batch(self):
         """ Determine whether there is a batch left. """
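With the boolean is_train replaced by a phase string, next_batch() now returns a tuple whose shape depends on the phase. A condensed consumer sketch, assuming a DataSet instance data_set built by one of the prepare_*_data helpers in vqa_preprocessing.py:

# Sketch of how a caller unpacks next_batch() for each phase after this change.
if data_set.phase == "train":
    (image_files, image_idxs,
     question_idxs, question_masks,
     answer_idxs, answer_masks) = data_set.next_batch()
elif data_set.phase == "test":
    image_files, image_idxs, question_idxs, question_masks = data_set.next_batch()
elif data_set.phase == "cnn_features":
    # Only image information is needed when pre-computing CNN features.
    image_files, image_idxs = data_set.next_batch()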
vqa_main.py: 17 changes (9 additions, 8 deletions)
@@ -122,30 +122,31 @@ def assign_args(args):
         # Build the model
         model.build()
         sess.run(tf.global_variables_initializer())

+        ## Load the Pre-trained CNN file
+        model.encoder.cnn.load_cnn(sess,config.CNN_PRETRAINED_FILE)
         if (config.LOAD_MODEL):
             model.load(sess,config.MODEL_FILE_NAME)
         # Train the data with the data set and embedding matrix
         model.train(sess,data_set)

     elif config.PHASE=="cnn_features":

-        ## Create Vocabulary object
-        vocabulary = Vocabulary(config)
-        ## Build the vocabulary to get the indexes
-        vocabulary.build(config.DATA_DIR + config.TRAIN_QUESTIONS_FILE)
+        # ## Create Vocabulary object
+        # vocabulary = Vocabulary(config)
+        # ## Build the vocabulary to get the indexes
+        # vocabulary.build(config.DATA_DIR + config.TRAIN_QUESTIONS_FILE)

         ## Create the data set
-        data_set = prepare_train_data(config, vocabulary)
+        data_set = prepare_cnn_data(config)

         model = vqa_model_static_cnn(config)
         model.build()
         sess.run(tf.global_variables_initializer())
+        ## Load Pre-trained CNN file
+        model.cnn.load_cnn(sess, config.CNN_PRETRAINED_FILE)
         model.train(sess,data_set)



     elif config.PHASE == 'test':
         config.set_batch_size(1)
         print("Config.LSTM Size : {}".format(config.LSTM_BATCH_SIZE))
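Taken together, the cnn_features phase now builds an image-only dataset and runs the static CNN over it. A condensed sketch of that driver path, assuming a tf.Session sess and a Config instance config have already been created as elsewhere in vqa_main.py:

# Condensed sketch of the "cnn_features" branch after this commit.
data_set = prepare_cnn_data(config)                   # images only, no vocabulary needed

model = vqa_model_static_cnn(config)
model.build()
sess.run(tf.global_variables_initializer())
model.cnn.load_cnn(sess, config.CNN_PRETRAINED_FILE)  # load the pre-trained CNN weights

model.train(sess, data_set)                           # forward passes that cache the features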
vqa_model_static_cnn.py: 13 changes (6 additions, 7 deletions)
@@ -48,10 +48,10 @@ def train(self, sess, train_data):
         self.fc_dict = {}


-        #for _ in tqdm(list(range(train_data.num_batches)), desc='batch'):
-        for _ in tqdm(list(range(self.config.NUM_BATCHES)), desc='batch'):
+        for _ in tqdm(list(range(train_data.num_batches)), desc='batch'):
+        #for _ in tqdm(list(range(self.config.NUM_BATCHES)), desc='batch'):
             batch = train_data.next_batch()
-            image_files, image_idxs, _, _, _, _ = batch
+            image_files, image_idxs = batch
             images = self.image_loader.load_images(image_files)

             feed_dict = {self.images:images}
@@ -61,11 +61,10 @@ def train(self, sess, train_data):
             ## Save conv5_3 and fc2 into two dictionaries
             i = 0
             for idx in image_idxs:
-                if idx not in self.conv_dict:
-                    self.conv_dict[str(idx)] = self.conv5_3[i]
+                self.conv_dict[str(idx)] = self.conv5_3[i]

-                if idx not in self.fc_dict:
-                    self.fc_dict[str(idx)] = self.fc2[i]
+                self.fc_dict[str(idx)] = self.fc2[i]

                 i = i + 1

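The hunk above only shows conv_dict and fc_dict being filled in memory; how the cached features are written out is not part of this diff. As one hedged possibility (file names and helper functions below are assumptions, not code from the repository), a dictionary of per-image features keyed by image id can be persisted with NumPy:

import numpy as np

# Hypothetical persistence step, not shown in this commit: save the cached
# features so later training runs can skip the CNN forward pass entirely.
def save_feature_dicts(conv_dict, fc_dict, prefix="vgg_features"):
    # np.save wraps a dict in a 0-d object array; load it back with allow_pickle=True.
    np.save(prefix + "_conv5_3.npy", conv_dict)
    np.save(prefix + "_fc2.npy", fc_dict)

def load_feature_dicts(prefix="vgg_features"):
    conv_dict = np.load(prefix + "_conv5_3.npy", allow_pickle=True).item()
    fc_dict = np.load(prefix + "_fc2.npy", allow_pickle=True).item()
    return conv_dict, fc_dict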
vqa_preprocessing.py: 21 changes (18 additions, 3 deletions)
@@ -257,7 +257,7 @@ def prepare_train_data(config,vocabulary):
                       answer_masks_list,
                       answer_type_list,
                       config.BATCH_SIZE,
-                      True,
+                      config.PHASE,
                       True)
     print("Training Data prepared")
     return dataset
@@ -283,7 +283,7 @@ def prepare_test_data(config,vocabulary):
     question_masks[config.MAX_QUESTION_LENGTH - question_num_words:] = 1

     ## Get the Image Files, Currently we will have only one
-    files = os.listdir(config.DATA_DIR + config.TEST_IMAGE_DIR)
+    files = os.listdir(config.TEST_IMAGE_DIR)
     image_file_list = [os.path.join(config.TEST_IMAGE_DIR, f) for f in files
                        if f.lower().endswith('.jpg') or f.lower().endswith('.jpeg')]

@@ -299,7 +299,7 @@ def prepare_test_data(config,vocabulary):
                       question_id_list,
                       question_idxs_list,
                       question_masks_list,
-                      batch_size=1,is_train=False)
+                      batch_size=1,phase=config.PHASE)
     print("Testing Data prepared")

     ## Get the Top answers
@@ -309,3 +309,18 @@ def prepare_test_data(config,vocabulary):
     return dataset,top_answers


+def prepare_cnn_data(config):
+    files = os.listdir(config.TRAIN_IMAGE_DIR)
+    image_file_list = [os.path.join(config.TRAIN_IMAGE_DIR, f) for f in files
+                       if f.lower().endswith('.jpg') or f.lower().endswith('.jpeg')]
+
+    image_id_list = [f[:-4] for f in files]
+    dataset = DataSet(image_id_list,
+                      image_file_list,
+                      batch_size=config.BATCH_SIZE, phase=config.PHASE)
+    print("CNN Dataset prepared")
+
+    return dataset
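prepare_cnn_data derives each image id by slicing off the last four characters of the file name, i.e. the '.jpg' suffix, so any later lookup of a pre-computed feature has to use the same convention. A small hypothetical helper (not part of this commit) illustrating that round trip:

import os

# Mirror the id convention used in prepare_cnn_data:
# 'COCO_train2014_000000123456.jpg' -> key 'COCO_train2014_000000123456'.
# Caveat: for a '.jpeg' file, f[:-4] leaves a trailing dot ('name.jpeg' -> 'name.'),
# so the same slicing has to be applied consistently on both sides.
def feature_for_image(image_file, fc_dict):
    image_id = os.path.basename(image_file)[:-4]
    return fc_dict.get(str(image_id))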


