Skip to content

Commit

Permalink
preprocess_caption
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Jan 31, 2017
1 parent 65fec0e commit 5268a04
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 46 deletions.
5 changes: 1 addition & 4 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
dataset = '102flowers' #
need_256 = True # set to True for stackGAN

def preprocess_caption(line):
prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
prep_line = prep_line.replace('-', ' ')
return prep_line


if dataset == '102flowers':
"""
Expand Down
63 changes: 30 additions & 33 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def generator_txt2img_resnet(input_z, net_rnn_embed=None, is_train=True, reuse=F
gf_dim = 128

w_init = tf.random_normal_initializer(stddev=0.02)
b_init = None
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("generator", reuse=reuse):
Expand All @@ -203,27 +202,27 @@ def generator_txt2img_resnet(input_z, net_rnn_embed=None, is_train=True, reuse=F

if net_rnn_embed is not None:
net_rnn_embed = DenseLayer(net_rnn_embed, n_units=t_dim,
act=lambda x: tl.act.lrelu(x, 0.2), W_init = w_init, name='g_reduce_text/dense')
act=lambda x: tl.act.lrelu(x, 0.2), W_init=w_init, name='g_reduce_text/dense')
net_in = ConcatLayer([net_in, net_rnn_embed], concat_dim=1, name='g_concat_z_seq')
else:
print("No text info is used, i.e. DCGAN")

net_h0 = DenseLayer(net_in, gf_dim*8*s16*s16, act=tf.identity,
W_init=w_init, b_init=b_init, name='g_h0/dense')
net_h0 = ReshapeLayer(net_h0, [-1, s16, s16, gf_dim*8], name='g_h0/reshape')
W_init=w_init, b_init=None, name='g_h0/dense')
net_h0 = BatchNormLayer(net_h0, #act=tf.nn.relu,
is_train=is_train, gamma_init=gamma_init, name='g_h0/batch_norm')
net_h0 = ReshapeLayer(net_h0, [-1, s16, s16, gf_dim*8], name='g_h0/reshape')

net_h1 = Conv2d(net_h0, gf_dim*2, (1, 1), (1, 1),
padding='VALID', act=None, W_init=w_init, b_init=b_init, name='g_h1/conv2d')
padding='VALID', act=None, W_init=w_init, b_init=None, name='g_h1/conv2d')
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g_h1/batch_norm')
net_h2 = Conv2d(net_h1, gf_dim*2, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h2/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h2/conv2d')
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g_h2/batch_norm')
net_h3 = Conv2d(net_h2, gf_dim*8, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h3/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h3/conv2d')
net_h3 = BatchNormLayer(net_h3, # act=tf.nn.relu,
is_train=is_train, gamma_init=gamma_init, name='g_h3/batch_norm')
net_h3 = ElementwiseLayer(layer=[net_h3, net_h0], combine_fn=tf.add, name='g_h3/add')
Expand All @@ -233,20 +232,20 @@ def generator_txt2img_resnet(input_z, net_rnn_embed=None, is_train=True, reuse=F
net_h4 = UpSampling2dLayer(net_h3, size=[s8, s8], is_scale=False, method=1,
align_corners=False, name='g_h4/upsample2d')
net_h4 = Conv2d(net_h4, gf_dim*4, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h4/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h4/conv2d')
net_h4 = BatchNormLayer(net_h4,# act=tf.nn.relu,
is_train=is_train, gamma_init=gamma_init, name='g_h4/batch_norm')

net_h5 = Conv2d(net_h4, gf_dim, (1, 1), (1, 1),
padding='VALID', act=None, W_init=w_init, b_init=b_init, name='g_h5/conv2d')
padding='VALID', act=None, W_init=w_init, b_init=None, name='g_h5/conv2d')
net_h5 = BatchNormLayer(net_h5, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g_h5/batch_norm')
net_h6 = Conv2d(net_h5, gf_dim, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h6/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h6/conv2d')
net_h6 = BatchNormLayer(net_h6, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g_h6/batch_norm')
net_h7 = Conv2d(net_h6, gf_dim*4, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h7/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h7/conv2d')
net_h7 = BatchNormLayer(net_h7, #act=tf.nn.relu,
is_train=is_train, gamma_init=gamma_init, name='g_h7/batch_norm')
net_h7 = ElementwiseLayer(layer=[net_h7, net_h4], combine_fn=tf.add, name='g_h7/add')
Expand All @@ -257,7 +256,7 @@ def generator_txt2img_resnet(input_z, net_rnn_embed=None, is_train=True, reuse=F
net_h8 = UpSampling2dLayer(net_h7, size=[s4, s4], is_scale=False, method=1,
align_corners=False, name='g_h8/upsample2d')
net_h8 = Conv2d(net_h8, gf_dim*2, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h8/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h8/conv2d')
net_h8 = BatchNormLayer(net_h8, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g_h8/batch_norm')

Expand All @@ -266,7 +265,7 @@ def generator_txt2img_resnet(input_z, net_rnn_embed=None, is_train=True, reuse=F
net_h9 = UpSampling2dLayer(net_h8, size=[s2, s2], is_scale=False, method=1,
align_corners=False, name='g_h9/upsample2d')
net_h9 = Conv2d(net_h9, gf_dim, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_h9/conv2d')
padding='SAME', act=None, W_init=w_init, b_init=None, name='g_h9/conv2d')
net_h9 = BatchNormLayer(net_h9, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g_h9/batch_norm')

Expand All @@ -275,18 +274,17 @@ def generator_txt2img_resnet(input_z, net_rnn_embed=None, is_train=True, reuse=F
net_ho = UpSampling2dLayer(net_h9, size=[s, s], is_scale=False, method=1,
align_corners=False, name='g_ho/upsample2d')
net_ho = Conv2d(net_ho, c_dim, (3, 3), (1, 1),
padding='SAME', act=None, W_init=w_init, b_init=b_init, name='g_ho/conv2d')
padding='SAME', act=None, W_init=w_init, name='g_ho/conv2d')
logits = net_ho.outputs
net_ho.outputs = tf.nn.tanh(net_ho.outputs)
return net_ho, logits

def discriminator_txt2img_resnet(input_images, net_rnn_embed=None, is_train=True, reuse=False):
# https://github.com/hanzhanggit/StackGAN/blob/master/stageI/model.py
# Discriminator with ResNet : line 197 https://github.com/reedscot/icml2016/blob/master/main_cls.lua
w_init = tf.random_normal_initializer(stddev=0.02) # 73
b_init = None
gamma_init=tf.random_normal_initializer(1., 0.02) # 74
df_dim = 64 # 64 for flower, 196 for MSCOCO # number of conv in the first layer discriminator [196] https://github.com/reedscot/icml2016/blob/master/scripts/train_coco.sh

w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init=tf.random_normal_initializer(1., 0.02)
df_dim = 64 # 64 for flower, 196 for MSCOCO
s = 64 # output image size [64]
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)

Expand All @@ -297,28 +295,28 @@ def discriminator_txt2img_resnet(input_images, net_rnn_embed=None, is_train=True
padding='SAME', W_init=w_init, name='d_h0/conv2d')

net_h1 = Conv2d(net_h0, df_dim*2, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='d_h1/conv2d')
padding='SAME', W_init=w_init, b_init=None, name='d_h1/conv2d')
net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h1/batchnorm')
net_h2 = Conv2d(net_h1, df_dim*4, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='d_h2/conv2d')
padding='SAME', W_init=w_init, b_init=None, name='d_h2/conv2d')
net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h2/batchnorm')
net_h3 = Conv2d(net_h2, df_dim*8, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='d_h3/conv2d')
padding='SAME', W_init=w_init, b_init=None, name='d_h3/conv2d')
net_h3 = BatchNormLayer(net_h3, #act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h3/batchnorm')

net_h = Conv2d(net_h3, df_dim*2, (1, 1), (1, 1), act=None,
padding='VALID', W_init=w_init, b_init=b_init, name='d_h3/conv2d2')
padding='VALID', W_init=w_init, b_init=None, name='d_h3/conv2d2')
net_h = BatchNormLayer(net_h, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h3/batchnorm2')
net_h = Conv2d(net_h, df_dim*2, (3, 3), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='d_h3/conv2d3')
padding='SAME', W_init=w_init, b_init=None, name='d_h3/conv2d3')
net_h = BatchNormLayer(net_h, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h3/batchnorm3')
net_h = Conv2d(net_h, df_dim*8, (3, 3), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='d_h3/conv2d4')
padding='SAME', W_init=w_init, b_init=None, name='d_h3/conv2d4')
net_h = BatchNormLayer(net_h, #act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h3/batchnorm4')
net_h3 = ElementwiseLayer(layer=[net_h3, net_h], combine_fn=tf.add, name='d_h3/add')
Expand All @@ -334,7 +332,7 @@ def discriminator_txt2img_resnet(input_images, net_rnn_embed=None, is_train=True
net_h3_concat = ConcatLayer([net_h3, net_reduced_text], concat_dim=3, name='d_h3_concat')
# 243 (ndf*8 + 128 or 256) x 4 x 4
net_h3 = Conv2d(net_h3_concat, df_dim*8, (1, 1), (1, 1),
padding='VALID', W_init=w_init, b_init=b_init, name='d_h3/conv2d_2')
padding='VALID', W_init=w_init, b_init=None, name='d_h3/conv2d_2')
net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d_h3/batch_norm_2')
else:
Expand All @@ -349,7 +347,6 @@ def discriminator_txt2img_resnet(input_images, net_rnn_embed=None, is_train=True
def cnn_encoder_resnet(input_images, is_train=True, reuse=False, name='cnn'):
# https://github.com/hanzhanggit/StackGAN/blob/master/stageI/model.py d_encode_image
w_init = tf.random_normal_initializer(stddev=0.02)
b_init = None
gamma_init=tf.random_normal_initializer(1., 0.02)
df_dim = 64
with tf.variable_scope(name, reuse=reuse):
Expand All @@ -360,28 +357,28 @@ def cnn_encoder_resnet(input_images, is_train=True, reuse=False, name='cnn'):
padding='SAME', W_init=w_init, name='p_h0/conv2d')

net_h1 = Conv2d(net_h0, df_dim*2, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='p_h1/conv2d')
padding='SAME', W_init=w_init, b_init=None, name='p_h1/conv2d')
net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='p_h1/batchnorm')
net_h2 = Conv2d(net_h1, df_dim*4, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='p_h2/conv2d')
padding='SAME', W_init=w_init, b_init=None, name='p_h2/conv2d')
net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='p_h2/batchnorm')
net_h3 = Conv2d(net_h2, df_dim*8, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='p_h3/conv2d')
padding='SAME', W_init=w_init, b_init=None, name='p_h3/conv2d')
net_h3 = BatchNormLayer(net_h3, #act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='p_h3/batchnorm')

net_h = Conv2d(net_h3, df_dim*2, (1, 1), (1, 1), act=None,
padding='VALID', W_init=w_init, b_init=b_init, name='p_h3/conv2d2')
padding='VALID', W_init=w_init, b_init=None, name='p_h3/conv2d2')
net_h = BatchNormLayer(net_h, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='p_h3/batchnorm2')
net_h = Conv2d(net_h, df_dim*2, (3, 3), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='p_h3/conv2d3')
padding='SAME', W_init=w_init, b_init=None, name='p_h3/conv2d3')
net_h = BatchNormLayer(net_h, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='p_h3/batchnorm3')
net_h = Conv2d(net_h, df_dim*8, (3, 3), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='p_h3/conv2d4')
padding='SAME', W_init=w_init, b_init=None, name='p_h3/conv2d4')
net_h = BatchNormLayer(net_h, #act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='p_h3/batchnorm4')
net_h3 = ElementwiseLayer(layer=[net_h3, net_h], combine_fn=tf.add, name='p_h3/add')
Expand Down
14 changes: 6 additions & 8 deletions train_txt2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
from model import *




generator_txt2img = generator_txt2img_resnet
discriminator_txt2img = discriminator_txt2img_resnet
# generator_txt2img = generator_txt2img_resnet
# discriminator_txt2img = discriminator_txt2img_resnet
# # cnn_encoder = cnn_encoder_resnet # for text-image mapping

os.system("mkdir samples")
Expand Down Expand Up @@ -65,7 +63,6 @@
images_test_256 = np.array(images_test_256)
images_train = np.array(images_train)
images_test = np.array(images_test)
# exit()

###======================== DEFIINE MODEL ===================================###

Expand Down Expand Up @@ -210,15 +207,16 @@
# sample_sentence = captions_ids_test[0:sample_size]
for i, sentence in enumerate(sample_sentence):
print("seed: %s" % sentence)
sentence = preprocess_caption(sentence)
sample_sentence[i] = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(sentence)] + [vocab.end_id] # add END_ID
# sample_sentence[i] = [vocab.word_to_id(word) for word in sentence]
# print(sample_sentence[i])
sample_sentence = tl.prepro.pad_sequences(sample_sentence, padding='post')


n_epoch = 1000 # 600 when pre-trained rnn
n_epoch = 600
print_freq = 1
n_batch_epoch = int(n_captions_train / batch_size)
n_batch_epoch = int(n_images_train / batch_size)

for epoch in range(n_epoch+1):
start_time = time.time()

Expand Down
19 changes: 18 additions & 1 deletion train_uim2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,24 @@
import argparse

## Load Oxford 102 flowers dataset
from data_loader import *
# from data_loader import *
import pickle

with open("_vocab.pickle", 'rb') as f:
vocab = pickle.load(f)
with open("_image_train.pickle", 'rb') as f:
images_train_256, images_train = pickle.load(f)
with open("_image_test.pickle", 'rb') as f:
images_test_256, images_test = pickle.load(f)
with open("_n.pickle", 'rb') as f:
n_captions_train, n_captions_test, n_captions_per_image, n_images_train, n_images_test = pickle.load(f)
with open("_caption.pickle", 'rb') as f:
captions_ids_train, captions_ids_test = pickle.load(f)
images_train_256 = np.array(images_train_256)
images_test_256 = np.array(images_test_256)
images_train = np.array(images_train)
images_test = np.array(images_test)


generator_txt2img = generator_txt2img_resnet

Expand Down
6 changes: 6 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import scipy
import scipy.misc
import numpy as np
import re
import string

""" The functions here will be merged into TensorLayer after finishing this project.
"""
Expand Down Expand Up @@ -37,6 +39,10 @@ def get_random_int(min=0, max=10, number=5):
"""
return [random.randint(min,max) for p in range(0,number)]

def preprocess_caption(line):
    """Normalize a raw caption string for tokenization.

    Trailing whitespace is stripped, then every ASCII punctuation
    character is turned into a single space.  The final ``replace`` is a
    belt-and-braces pass for hyphens (already covered by
    ``string.punctuation``, so it is a no-op after the regex).
    """
    cleaned = line.rstrip()
    punct_pattern = '[%s]' % re.escape(string.punctuation)
    cleaned = re.sub(punct_pattern, ' ', cleaned)
    return cleaned.replace('-', ' ')


## Save images
Expand Down

0 comments on commit 5268a04

Please sign in to comment.