Fix numerical bugs and add the two time-scale rule.
desire2020 committed Nov 9, 2018
1 parent ecce009 commit decfd1f
Showing 2 changed files with 7 additions and 9 deletions.
cot.py: 2 changes (1 addition & 1 deletion)
@@ -105,7 +105,7 @@ def main():
     target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
     target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model
 
-    mediator = Generator(vocab_size, BATCH_SIZE*2, EMB_DIM*2, HIDDEN_DIM*2, SEQ_LENGTH, START_TOKEN, name="mediator", dropout_rate=M_DROPOUT_RATE)
+    mediator = Generator(vocab_size, BATCH_SIZE*2, EMB_DIM*2, HIDDEN_DIM*2, SEQ_LENGTH, START_TOKEN, name="mediator", dropout_rate=M_DROPOUT_RATE, learning_rate=3e-3)
 
     config = tf.ConfigProto()
     config.gpu_options.allow_growth = True
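This is the "two time-scale" half of the commit: the mediator now gets an explicit learning_rate=3e-3, while the generator falls back on the new 1e-3 default introduced in generator.py below, so the mediator is updated on a roughly 3x faster time scale. A minimal sketch of the resulting pairing, reusing the constants already defined in cot.py (the exact generator construction there may differ):

    # Generator keeps the new default learning_rate=1e-3 from generator.py;
    # the mediator is trained on a ~3x faster time scale with 3e-3.
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    mediator = Generator(vocab_size, BATCH_SIZE*2, EMB_DIM*2, HIDDEN_DIM*2,
                         SEQ_LENGTH, START_TOKEN, name="mediator",
                         dropout_rate=M_DROPOUT_RATE, learning_rate=3e-3)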
generator.py: 14 changes (6 additions & 8 deletions)
@@ -5,7 +5,7 @@
 class Generator(object):
     def __init__(self, num_emb, batch_size, emb_dim, hidden_dim,
                  sequence_length, start_token,
-                 learning_rate=1e-2, reward_gamma=0.95, name="generator", dropout_rate=0.5):
+                 learning_rate=1e-3, reward_gamma=0.95, name="generator", dropout_rate=0.5):
         self.num_emb = num_emb
         self.batch_size = batch_size
         self.emb_dim = emb_dim
@@ -18,7 +18,6 @@ def __init__(self, num_emb, batch_size, emb_dim, hidden_dim,
         self.d_params = []
         self.temperature = 1.0
         self.create_recurrent_unit = self.create_recurrent_unit_LSTM
-        self.grad_clip = 5.0
         self.name = name
         self.dropout_keep_rate = tf.Variable(float(1.0), trainable=False)
         self.dropout_on = self.dropout_keep_rate.assign(dropout_rate)
@@ -115,21 +114,20 @@ def _pretrain_recurrence(i, x_t, h_tm1, g_predictions, log_predictions):
         ) / (self.sequence_length * self.batch_size)
 
         # training updates
-        pretrain_opt = self.g_optimizer(self.learning_rate)
+        pretrain_opt = self.g_optimizer(self.learning_rate, beta1=0.9, beta2=0.99)
 
-        self.pretrain_grad, _ = tf.clip_by_global_norm(tf.gradients(self.likelihood_loss, self.g_params), self.grad_clip)
-        self.likelihood_updates = pretrain_opt.apply_gradients(zip(self.pretrain_grad, self.g_params))
+        self.likelihood_updates = pretrain_opt.minimize(self.likelihood_loss, var_list=self.g_params)
 
         #######################################################################################################
         # Unsupervised Training
         #######################################################################################################
         self.g_loss = -tf.reduce_sum(
             self.g_predictions * (self.rewards - self.log_predictions)
         ) / batch_size #/ sequence_length
-        g_opt = self.g_optimizer(self.learning_rate)
+        g_opt = self.g_optimizer(self.learning_rate, beta1=0.9, beta2=0.99)
 
-        self.g_grad, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss, self.g_params), self.grad_clip)
-        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))
+        self.g_updates = g_opt.minimize(self.g_loss, var_list=self.g_params)
 
     def generate(self, sess):
         outputs = sess.run(self.gen_x)
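The "numerical bugs" half of the commit is in the optimizer path: the default learning rate drops from 1e-2 to 1e-3, Adam's moment decays are set explicitly (beta1=0.9, beta2=0.99), and the manual tf.clip_by_global_norm / apply_gradients pipeline (together with the now-removed grad_clip = 5.0) is replaced by a plain minimize() call. A minimal TF 1.x sketch of the before/after pattern, assuming g_optimizer wraps tf.train.AdamOptimizer (the beta1/beta2 keywords imply an Adam-style optimizer) and using toy stand-ins for the loss and parameters:

    import tensorflow as tf  # TF 1.x API, matching the repo

    # Toy stand-ins for self.g_params and self.likelihood_loss / self.g_loss.
    params = [tf.Variable(tf.random_normal([4, 4])) for _ in range(2)]
    loss = tf.add_n([tf.reduce_sum(tf.square(p)) for p in params])

    # Before this commit: clip gradients by global norm, then apply them manually.
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, params), clip_norm=5.0)
    old_updates = tf.train.AdamOptimizer(1e-2).apply_gradients(list(zip(grads, params)))

    # After this commit: smaller learning rate, explicit Adam moments, no clipping.
    opt = tf.train.AdamOptimizer(learning_rate=1e-3, beta1=0.9, beta2=0.99)
    new_updates = opt.minimize(loss, var_list=params)

With the smaller learning rate, Adam's per-step updates are already bounded more tightly, which presumably is why the global-norm clipping could be dropped.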
