Skip to content

Commit

Permalink
修复bug
Browse files Browse the repository at this point in the history
应该没有bug了,可以拿来调参
  • Loading branch information
shuaihuaiyi committed Jul 27, 2017
1 parent 9d7dd77 commit 026ed5d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,4 @@ ENV/
# User defined files
*.score
word2vec/zhwiki_2017_03.sg_50d.word2vec
*.model
/model/
9 changes: 5 additions & 4 deletions execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
tf.flags.DEFINE_string("test_file", "data/testing.data", "test corpus file")
tf.flags.DEFINE_string("valid_file", "data/develop.data", "test corpus file")
tf.flags.DEFINE_string("result_file", "predictRst.score", "result file")
saveFile = "savedModel"
saveFile = "model/savedModel"
tf.flags.DEFINE_string("embedding_file", "word2vec\zhwiki_2017_03.sg_50d.word2vec", "embedding file")
tf.flags.DEFINE_integer("embedding_size", 50, "embedding size")
tf.flags.DEFINE_float("dropout", 1, "the proportion of dropout")
Expand Down Expand Up @@ -103,14 +103,14 @@ def valid_model(sess, lstm, valid_questions, valid_answers, valid_file, result_f
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(lstm.loss, tvars),
FLAGS.max_grad_norm)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

tqs, tta, tfa = [], [], []
for ori_train, cand_train, neg_train in batch_iter(train_questions, train_answers,
train_labels, train_questionId, FLAGS.batch_size):
tqs.append(ori_train), tta.append(cand_train), tfa.append(neg_train)
saver = tf.train.Saver()
for i in range(1,4):
for i in range(3):
train_op = tf.train.GradientDescentOptimizer(learningRate).apply_gradients(zip(grads, tvars),
global_step=global_step)
for epoch in range(FLAGS.epochs):
Expand All @@ -119,4 +119,5 @@ def valid_model(sess, lstm, valid_questions, valid_answers, valid_file, result_f
valid_model(sess, lstm, valid_questions, valid_answers, FLAGS.valid_file, FLAGS.result_file)
saver.save(sess, saveFile + str(i*FLAGS.epochs+epoch) + '.model')
learningRate /= 2
valid_model(sess, lstm, test_questions, test_answers, FLAGS.test_file, FLAGS.result_file, False)
#saver.restore(sess,saveFile+"20.model")
valid_model(sess, lstm, test_questions, test_answers, FLAGS.test_file, FLAGS.result_file,False)

0 comments on commit 026ed5d

Please sign in to comment.