
Commit

added CuDNN support.
desire2020 committed Nov 18, 2018
1 parent 3707088 commit ace616c
Showing 4 changed files with 149 additions and 16 deletions.
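The headline change: the mediator now runs on a cuDNN-fused LSTM (tf.contrib.cudnn_rnn's CudnnLSTM) inside a new Mediator class, with cot.py and generator.py adjusted to match. As a quick orientation, here is a minimal sketch of the CudnnLSTM call pattern the new mediator.py relies on; the sizes are illustrative and not taken from the commit, and it assumes TF 1.x with tf.contrib and a CUDA-capable GPU.

import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import CudnnLSTM

# Illustrative sizes (not from the commit).
seq_len, batch, emb_dim, hidden = 20, 64, 64, 64
inputs = tf.random_normal([batch, seq_len, emb_dim])   # batch-major embeddings

rnn = CudnnLSTM(num_layers=1, num_units=hidden,
                kernel_initializer=tf.orthogonal_initializer())
x = tf.transpose(inputs, [1, 0, 2])           # CudnnLSTM expects time-major input
outputs, _ = rnn(x)                           # outputs: [seq_len, batch, hidden]
outputs = tf.transpose(outputs, [1, 0, 2])    # back to batch-major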
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,2 +1,4 @@
__pycache__
save/*.txt
save/*.txt
saved_model
.idea
33 changes: 23 additions & 10 deletions cot.py
@@ -3,6 +3,7 @@
import random
from dataloader import Gen_Data_loader
from generator import Generator
from mediator import Mediator
from target_lstm import TARGET_LSTM
import pickle

@@ -17,6 +18,7 @@
SEED = 88
BATCH_SIZE = 64
M_DROPOUT_RATE = 0.5 # Dropout rate of M (optional)
RESTORE = False

#########################################################################################
# Basic Training Parameters
@@ -97,6 +99,7 @@ def main():
assert START_TOKEN == 0

gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
gan_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH) # For testing
vocab_size = 5000
@@ -105,7 +108,9 @@ def main():
target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

mediator = Generator(vocab_size, BATCH_SIZE, EMB_DIM*2, HIDDEN_DIM*2, SEQ_LENGTH, START_TOKEN, name="mediator", dropout_rate=M_DROPOUT_RATE, learning_rate=3e-3)
mediator = Mediator(vocab_size, BATCH_SIZE, EMB_DIM * 2, HIDDEN_DIM * 2, SEQ_LENGTH, START_TOKEN,
name="mediator", dropout_rate=M_DROPOUT_RATE, learning_rate=3e-3,
with_professor_forcing=False)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
@@ -115,6 +120,7 @@ def main():
# First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
gen_data_loader.create_batches(positive_file)
gan_data_loader.create_batches(positive_file)
generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file)
val_data_loader.create_batches(eval_file)

@@ -124,6 +130,9 @@ def main():
# pre-train generator (default 0 epochs)(not recommended)
print('Start pre-training...')
log.write('pre-training...\n')
saver = tf.train.Saver(tf.global_variables())
if RESTORE:
saver.restore(sess, "saved_model/CoT")
for epoch in range(PRE_EPOCH_NUM):
loss = mle_epoch(sess, generator, gen_data_loader)
if epoch % 1 == 0:
@@ -164,16 +173,20 @@ def main():
# Train the mediator
for _ in range(1):
bnll_ = []
collected_x = []
ratio = 1
for it in range(ratio):
collected_x.extend([gen_data_loader.next_batch(), generator.generate(sess)])
collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
np.random.shuffle(collected_x)
collected_x = np.reshape(collected_x, [-1, BATCH_SIZE*2, SEQ_LENGTH])
"""
d_loss_ = []
for it in range(3):
feed = {
mediator.x0: gan_data_loader.next_batch(),
mediator.x1: generator.generate(sess)
}
d_loss, _ = sess.run([mediator.d_loss, mediator.d_update], feed)
d_loss_.append(d_loss)
"""
for it in range(1):
feed = {
mediator.x: collected_x[it],
mediator.x0: gen_data_loader.next_batch(),
mediator.x1: generator.generate(sess)
}
bnll = sess.run(mediator.likelihood_loss, feed)
bnll_.append(bnll)
@@ -188,7 +201,7 @@ def main():
jsd = jsd_calculate(sess, generator, target_lstm)
print('cooptrain epoch#', iter_idx // gen_data_loader.num_batch, 'jsd ', jsd)
log_jsd.write("%d\t%f\n" % (iter_idx // gen_data_loader.num_batch, jsd))

saver.save(sess, "saved_model/CoT")
log.close()
log_nll.close()
log_jsd.close()
9 changes: 4 additions & 5 deletions generator.py
@@ -112,13 +112,12 @@ def _pretrain_recurrence(i, x_t, h_tm1, g_predictions, log_predictions):
self.med_log_likelihood = tf.cumsum(tf.reduce_sum(self.rewards * one_hot_x, axis=-1), exclusive=True, axis=-1)
# pretraining loss
self.likelihood_loss = -tf.reduce_sum(
tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_emb]), 1e-20, 1.0)
)
tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) *
tf.reshape(self.log_predictions, [-1, self.num_emb])
) / (self.sequence_length * batch_size)
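# (The rewritten loss reuses the cached self.log_predictions directly instead of re-taking tf.log of clipped g_predictions.)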

# training updates
pretrain_opt = self.g_optimizer(self.learning_rate, beta1=0.9, beta2=0.99)
pretrain_opt = self.g_optimizer(self.learning_rate, beta1=0.9, beta2=0.95)


self.likelihood_updates = pretrain_opt.minimize(self.likelihood_loss, var_list=self.g_params)
@@ -129,7 +128,7 @@ def _pretrain_recurrence(i, x_t, h_tm1, g_predictions, log_predictions):
self.g_loss = -tf.reduce_sum(
self.g_predictions * (self.rewards - self.log_predictions)
) / batch_size / sequence_length
g_opt = self.g_optimizer(self.learning_rate, beta1=0.9, beta2=0.99)
g_opt = self.g_optimizer(self.learning_rate, beta1=0.9, beta2=0.95)

self.g_updates = g_opt.minimize(self.g_loss, var_list=self.g_params)

119 changes: 119 additions & 0 deletions mediator.py
@@ -0,0 +1,119 @@
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import CudnnLSTM, CUDNN_RNN_BIDIRECTION
from tensorflow.contrib import layers
class Critic(object):
def __call__(self, h):
# sequence -> [b, l, v]
_, l, v = h.get_shape().as_list()
h = tf.reshape(h, [-1, l, 1, v])
with tf.variable_scope("textmover", reuse=tf.AUTO_REUSE):
h0 = layers.convolution2d(
h, v, [4, 1], [2, 1],
activation_fn=tf.nn.softplus
)
h1 = layers.convolution2d(
h0, v, [4, 1], [1, 1],
activation_fn=tf.nn.softplus
)
h2 = layers.convolution2d(
h1, v, [4, 1], [2, 1],
activation_fn=tf.nn.softplus
)
h = layers.flatten(h2)
h = layers.fully_connected(
h, 1, activation_fn=tf.identity
)
return h

class Mediator(object):
def __init__(self, num_emb, batch_size, emb_dim, hidden_dim,
sequence_length, start_token,
learning_rate=1e-3, reward_gamma=0.95, name="mediator", dropout_rate=0.5, with_professor_forcing=False):
self.num_emb = num_emb
# self.batch_size = batch_size
self.emb_dim = emb_dim
self.hidden_dim = hidden_dim
self.sequence_length = sequence_length
self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
self.reward_gamma = reward_gamma
self.g_params = []
self.d_params = []
self.temperature = 1.0
self.name = name
self.dropout_keep_rate = tf.Variable(float(1.0), trainable=False)
self.dropout_on = self.dropout_keep_rate.assign(dropout_rate)
self.dropout_off = self.dropout_keep_rate.assign(1.0)
self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))

self.x0 = tf.placeholder(tf.int32, shape=[None, self.sequence_length])
self.x = self.x0
self.x1 = tf.placeholder(tf.int32, shape=[None, self.sequence_length])
input_x0 = tf.pad(self.x0, [[0, 0], [1, 0]])[:, 0:self.sequence_length]
input_x1 = tf.pad(self.x1, [[0, 0], [1, 0]])[:, 0:self.sequence_length]
output_x0 = tf.one_hot(
self.x0, self.num_emb, on_value=1.0, off_value=0.0
)
output_x1 = tf.one_hot(
self.x1, self.num_emb, on_value=1.0, off_value=0.0
)
with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
embedding = tf.get_variable(
name="word_embeddings",
initializer=tf.random_normal(shape=[self.num_emb, self.emb_dim], stddev=0.1)
)
Wo = tf.get_variable(
name="Weight_output",
initializer=tf.random_normal(shape=[self.hidden_dim, self.num_emb], stddev=0.1)
)
bo = tf.get_variable(
name="bias_output",
initializer=tf.random_normal(shape=[self.num_emb], stddev=0.1)
)
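# cuDNN-fused LSTM (the "CuDNN support" this commit adds); requires a GPU and time-major inputs.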
rnn = CudnnLSTM(
num_layers=1,
num_units=self.hidden_dim,
kernel_initializer=tf.orthogonal_initializer()
)
def language_modeling(input_x):
with tf.variable_scope("language_model", reuse=tf.AUTO_REUSE):
emb_x = tf.nn.embedding_lookup(
embedding, input_x
)
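# CudnnLSTM consumes time-major tensors, so reorder [batch, seq, emb] -> [seq, batch, emb].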
emb_x = tf.transpose(emb_x, [1, 0, 2])
h, _ = rnn(emb_x)
h = tf.transpose(h, [1, 0, 2])
h = tf.nn.dropout(h, self.dropout_keep_rate)
pred = tf.nn.log_softmax(
tf.reshape(h, [-1, self.hidden_dim]) @ Wo + bo,
axis=-1)
return h, tf.reshape(pred, [-1, self.sequence_length, self.num_emb])
self.h0, self.log_predictions = language_modeling(input_x0)
self.h1, self.log_predictions_ = language_modeling(input_x1)
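# Average negative log-likelihood over the real batch (x0) and the generated batch (x1).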
self.likelihood_loss = -tf.reduce_mean(
tf.reduce_sum(
self.log_predictions * output_x0 +
self.log_predictions_ * output_x1, axis=-1)
) / 2.0
self.m_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.95)
if with_professor_forcing:
with tf.variable_scope("professor_forcing", reuse=tf.AUTO_REUSE):
critic = Critic()
myu = tf.random_uniform(shape=[tf.shape(self.x0)[0], self.sequence_length, 1],
minval=0.0, maxval=1.0)
hybrid = self.h0 * myu + self.h1 * (1.0 - myu)
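# One-sided gradient penalty (WGAN-GP style) on hidden states interpolated between the real-data (h0) and generated-data (h1) runs.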
gp = tf.reduce_mean(tf.nn.relu(tf.norm(
tf.reshape(tf.gradients(critic(hybrid), [hybrid])[0], [tf.shape(self.x0)[0], -1]),
axis=-1) - 1.0) ** 2)
self.d_loss = tf.reduce_mean(critic(self.h0) - critic(self.h1))
self.d_opt = tf.train.AdamOptimizer(1e-4, beta1=0.5, beta2=0.9)
self.d_params = [v for v in tf.trainable_variables() if "professor_forcing" in v.name]
self.d_update = self.d_opt.minimize(self.d_loss + 5.0 * gp, var_list=self.d_params)
self.m_params = [v for v in tf.trainable_variables() if name in v.name]
if not with_professor_forcing:
self.likelihood_updates = self.m_opt.minimize(self.likelihood_loss, var_list=self.m_params)
else:
self.likelihood_updates = self.m_opt.minimize(self.likelihood_loss - self.d_loss, var_list=self.m_params)

def get_reward(self, sess, x):
output = sess.run(self.log_predictions, feed_dict={self.x0: x})
return output
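
A minimal usage sketch of the new Mediator, mirroring how cot.py drives it (sess, generator, gen_data_loader and the hyperparameters are assumed to be set up as in cot.py; this is illustrative, not part of the commit):

mediator = Mediator(vocab_size, BATCH_SIZE, EMB_DIM * 2, HIDDEN_DIM * 2, SEQ_LENGTH,
                    START_TOKEN, name="mediator", dropout_rate=0.5,
                    with_professor_forcing=False)
sess.run(mediator.dropout_on)                        # enable dropout while training M
feed = {mediator.x0: gen_data_loader.next_batch(),   # real (oracle) batch
        mediator.x1: generator.generate(sess)}       # generated batch
bnll, _ = sess.run([mediator.likelihood_loss, mediator.likelihood_updates], feed)
sess.run(mediator.dropout_off)                       # dropout off when estimating rewards
rewards = mediator.get_reward(sess, generator.generate(sess))  # per-token log-probabilities for the generator update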
