
Commit

small change
test authored and test committed Apr 23, 2018
1 parent 2ddaaa1 commit a266ade
Showing 2 changed files with 14 additions and 11 deletions.
4 changes: 2 additions & 2 deletions a02_TextCNN/p7_TextCNN_model.py

@@ -61,7 +61,7 @@ def instantiate_weights(self):
         self.b_projection = tf.get_variable("b_projection",shape=[self.num_classes]) #[label_size] #ADD 2017.06.09
 
     def inference(self):
-        """main computation graph here: 1.embedding-->2.CONV-RELU-MAX_POOLING-->3.linear classifier"""
+        """main computation graph here: 1.embedding-->2.CONV-BN-RELU-MAX_POOLING-->3.linear classifier"""
         # 1.=====>get embedding of words in the sentence
         self.embedded_words = tf.nn.embedding_lookup(self.Embedding,self.input_x) #[None,sentence_length,embed_size]
         self.sentence_embeddings_expanded=tf.expand_dims(self.embedded_words,-1) #[None,sentence_length,embed_size,1]. expand dimension to meet input requirement of 2d-conv
@@ -99,7 +99,7 @@ def inference(self):
         #4.=====>add dropout: use tf.nn.dropout
         with tf.name_scope("dropout"):
             self.h_drop=tf.nn.dropout(self.h_pool_flat,keep_prob=self.dropout_keep_prob) #[None,num_filters_total]
-
+        self.h_drop=tf.layers.dense(self.h_drop,self.num_filters_total,activation=tf.nn.tanh,use_bias=True)
         #5. logits(use linear layer)and predictions(argmax)
         with tf.name_scope("output"):
             logits = tf.matmul(self.h_drop,self.W_projection) + self.b_projection #shape:[None,self.num_classes]==tf.matmul([None,num_filters_total],[num_filters_total,self.num_classes])
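Note: for context, here is a minimal sketch of the inference() pipeline as it stands after this commit: embedding lookup, a conv/BN/ReLU/max-pool branch per filter size, concatenation, dropout, the newly added tanh dense layer, then the linear projection. The conv/pool details sit in collapsed lines of the diff, so they are assumed here; treat this as an illustration, not the repository's exact code.

    # Hedged sketch of inference() after this commit (TF 1.x idioms).
    import tensorflow as tf

    def textcnn_inference(input_x, Embedding, filter_sizes, num_filters,
                          sentence_length, embed_size, num_classes, keep_prob):
        embedded = tf.nn.embedding_lookup(Embedding, input_x)   # [None, sentence_length, embed_size]
        expanded = tf.expand_dims(embedded, -1)                 # [None, sentence_length, embed_size, 1]
        pooled = []
        for fs in filter_sizes:
            filt = tf.get_variable("filter-%d" % fs, shape=[fs, embed_size, 1, num_filters])
            conv = tf.nn.conv2d(expanded, filt, strides=[1, 1, 1, 1], padding="VALID")
            conv = tf.layers.batch_normalization(conv)          # the "BN" the new docstring mentions
            h = tf.nn.relu(conv)
            pooled.append(tf.nn.max_pool(h, ksize=[1, sentence_length - fs + 1, 1, 1],
                                         strides=[1, 1, 1, 1], padding="VALID"))
        num_filters_total = num_filters * len(filter_sizes)
        h_pool_flat = tf.reshape(tf.concat(pooled, 3), [-1, num_filters_total])
        h_drop = tf.nn.dropout(h_pool_flat, keep_prob=keep_prob)
        # the dense bottleneck this commit inserts between dropout and the projection:
        h_drop = tf.layers.dense(h_drop, num_filters_total, activation=tf.nn.tanh, use_bias=True)
        W_projection = tf.get_variable("W_projection", shape=[num_filters_total, num_classes])
        b_projection = tf.get_variable("b_projection", shape=[num_classes])
        return tf.matmul(h_drop, W_projection) + b_projection  # logits: [None, num_classes]
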
21 changes: 12 additions & 9 deletions a02_TextCNN/p7_TextCNN_train.py

@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+#import sys
+#reload(sys)
+#sys.setdefaultencoding('utf-8') #gb2312
 #training the model.
 #process--->1.load data(X:list of int,y:int). 2.create session. 3.feed data. 4.training (5.validation) ,(6.prediction)
 #import sys
@@ -14,11 +17,11 @@
 #configuration
 FLAGS=tf.app.flags.FLAGS
 
-tf.app.flags.DEFINE_string("traning_data_path","../data/train_label_single100_merge.txt","path of training data.") #sample_multiple_label.txt-->train_label_single100_merge
+tf.app.flags.DEFINE_string("traning_data_path","../data/sample_multiple_label.txt","path of training data.") #sample_multiple_label.txt-->train_label_single100_merge
 tf.app.flags.DEFINE_integer("vocab_size",100000,"maximum vocab size.")
 
-tf.app.flags.DEFINE_float("learning_rate",0.0001,"learning rate")
-tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size for training/evaluating.") #batch size 32-->128
+tf.app.flags.DEFINE_float("learning_rate",0.0003,"learning rate")
+tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size for training/evaluating.") #batch size 32-->128
 tf.app.flags.DEFINE_integer("decay_steps", 1000, "how many steps before decay learning rate.") #6000
 tf.app.flags.DEFINE_float("decay_rate", 1.0, "Rate of decay for learning rate.") #0.65; how much to decay each time
 tf.app.flags.DEFINE_string("ckpt_dir","text_cnn_title_desc_checkpoint/","checkpoint location for the model")
@@ -32,7 +35,7 @@
 tf.app.flags.DEFINE_string("word2vec_model_path","word2vec-title-desc.bin","word2vec's vocabulary and vectors")
 tf.app.flags.DEFINE_string("name_scope","cnn","name scope value.")
 tf.app.flags.DEFINE_boolean("multi_label_flag",True,"use multi label or single label.")
-filter_sizes=[7]
+filter_sizes=[6,7,8]
 
 #1.load data(X:list of int,y:int). 2.create session. 3.feed data. 4.training (5.validation) ,(6.prediction)
 def main(_):
@@ -61,9 +64,9 @@ def main(_):
         if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
             print("Restoring Variables from Checkpoint.")
             saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
-            for i in range(3): #decay learning rate if necessary.
-                print(i,"Going to decay learning rate by half.")
-                sess.run(textCNN.learning_rate_decay_half_op)
+            #for i in range(3): #decay learning rate if necessary.
+            #    print(i,"Going to decay learning rate by half.")
+            #    sess.run(textCNN.learning_rate_decay_half_op)
         else:
             print('Initializing Variables')
             sess.run(tf.global_variables_initializer())
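
Note: learning_rate_decay_half_op is not defined anywhere in this diff; in TF 1.x it is usually an in-graph assign that halves a learning-rate Variable, along these assumed lines:

    # Assumed definition of the decay-by-half op invoked above (not in this diff).
    import tensorflow as tf

    learning_rate = tf.Variable(0.0003, trainable=False, name="learning_rate")
    learning_rate_decay_half_op = tf.assign(learning_rate, learning_rate * 0.5)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(learning_rate_decay_half_op)
        print(sess.run(learning_rate))  # 0.00015
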
@@ -91,7 +94,7 @@ def main(_):
                     print("Epoch %d\tBatch %d\tTrain Loss:%.3f\tLearning rate:%.5f" %(epoch,counter,loss/float(counter),lr))
 
                 ########################################################################################################
-                if start%(1000*FLAGS.batch_size)==0: # eval every 1000 batches.
+                if start%(2000*FLAGS.batch_size)==0: # eval every 2000 batches.
                     eval_loss, f1_score, precision, recall = do_eval(sess, textCNN, testX, testY,iteration)
                     print("Epoch %d Validation Loss:%.3f\tF1 Score:%.3f\tPrecision:%.3f\tRecall:%.3f" % (epoch, eval_loss, f1_score, precision, recall))
                     # save model to checkpoint
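
Note: start advances by batch_size on every step of the batch loop, so start % (2000*FLAGS.batch_size) == 0 holds once every 2000 batches. A minimal sketch of the assumed loop shape:

    # Assumed shape of the training loop: start advances by batch_size, so the
    # modulus condition fires once every 2000 batches.
    batch_size = 64
    number_of_training_data = 1000000  # illustrative size
    for start, end in zip(range(0, number_of_training_data, batch_size),
                          range(batch_size, number_of_training_data, batch_size)):
        if start % (2000 * batch_size) == 0:
            pass  # run do_eval(...) and save a checkpoint here
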
@@ -112,7 +115,7 @@ def main(_):
             saver.save(sess,save_path,global_step=epoch)
 
         # 5. Finally, evaluate on the test set and report test accuracy.
-    test_loss,_,_,_ = do_eval(sess, textCNN, testX, testY)
+    test_loss,_,_,_ = do_eval(sess, textCNN, testX, testY,iteration)
     print("Test Loss:%.3f" % (test_loss))
     pass

