p8_TextRNN_predict.py (forked from brightmart/text_classification)
# -*- coding: utf-8 -*-
# Prediction using the trained TextRNN model.
# Process ---> 1.load data (X: list of int, y: int). 2.create session. 3.feed data. 4.predict
import sys
if sys.version_info[0] == 2:  # Python 2 only: make utf-8 the default string encoding
    reload(sys)
    sys.setdefaultencoding('utf8')
import os
import codecs
import numpy as np
import tensorflow as tf
from tflearn.data_utils import pad_sequences
from p8_TextRNN_model import TextRNN
from data_util_zhihu import load_data_predict, load_final_test_data, create_voabulary, create_voabulary_label
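# Note: this script targets the TensorFlow 1.x API (tf.app.flags, tf.Session) and assumes the
# zhihu word2vec/vocabulary files referenced by the flags below are available on disk.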
# configuration
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer("num_classes", 1999, "number of labels")
tf.app.flags.DEFINE_float("learning_rate", 0.01, "learning rate")
tf.app.flags.DEFINE_integer("batch_size", 80, "Batch size for training/evaluating.")  # batch size 32-->128
tf.app.flags.DEFINE_integer("decay_steps", 12000, "how many steps before decaying the learning rate.")
tf.app.flags.DEFINE_float("decay_rate", 0.9, "Rate of decay for learning rate.")  # how much to decay each time, e.g. 0.5
tf.app.flags.DEFINE_string("ckpt_dir", "text_rnn_checkpoint/", "checkpoint location for the model")
tf.app.flags.DEFINE_integer("sequence_length", 100, "max sentence length")
tf.app.flags.DEFINE_integer("embed_size", 100, "embedding size")
tf.app.flags.DEFINE_boolean("is_training", False, "is training. true: training, false: testing/inference")
tf.app.flags.DEFINE_string("traning_data_path", "train-zhihu4-only-title-all.txt", "path of training data.")  # train-zhihu4-only-title-all.txt. training-data/test-zhihu4-only-title.txt--->'training-data/train-zhihu5-only-title-multilabel.txt'
tf.app.flags.DEFINE_string("word2vec_model_path", "zhihu-word2vec.bin-100", "word2vec's vocabulary and vectors")
tf.app.flags.DEFINE_string("predict_target_file", "text_rnn_checkpoint/zhihu_result_rnn5.csv", "target file path for final prediction")
tf.app.flags.DEFINE_string("predict_source_file", 'test-zhihu-forpredict-v4only-title.txt', "source file path to predict on")
# 1.load data (X: list of int, y: int). 2.create session. 3.feed data. 4.predict
def main(_):
    # 1.load data with vocabulary of words and labels
    vocabulary_word2index, vocabulary_index2word = create_voabulary(simple='simple', word2vec_model_path=FLAGS.word2vec_model_path, name_scope="rnn")
    vocab_size = len(vocabulary_word2index)
    vocabulary_word2index_label, vocabulary_index2word_label = create_voabulary_label(name_scope="rnn")
    questionid_question_lists = load_final_test_data(FLAGS.predict_source_file)
    test = load_data_predict(vocabulary_word2index, vocabulary_word2index_label, questionid_question_lists)
    testX = []
    question_id_list = []
    for tuple in test:
        question_id, question_string_list = tuple
        question_id_list.append(question_id)
        testX.append(question_string_list)
    # 2.Data preprocessing: Sequence padding
    print("start padding....")
    testX2 = pad_sequences(testX, maxlen=FLAGS.sequence_length, value=0.)  # pad/truncate every index sequence to exactly FLAGS.sequence_length entries
    print("end padding...")
    # 3.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # 4.Instantiate Model
        textRNN = TextRNN(FLAGS.num_classes, FLAGS.learning_rate, FLAGS.batch_size, FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.sequence_length,
                          vocab_size, FLAGS.embed_size, FLAGS.is_training)
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint for TextRNN")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print("Can't find the checkpoint. Going to stop.")
            return
        # 5.feed data, to get logits
        number_of_training_data = len(testX2)
        print("number_of_training_data:", number_of_training_data)
        index = 0
        predict_target_file_f = codecs.open(FLAGS.predict_target_file, 'a', 'utf8')
        # Note: zip() stops at the shorter range, so any leftover examples that do not fill a full batch are skipped.
        for start, end in zip(range(0, number_of_training_data, FLAGS.batch_size), range(FLAGS.batch_size, number_of_training_data + 1, FLAGS.batch_size)):
            logits = sess.run(textRNN.logits, feed_dict={textRNN.input_x: testX2[start:end], textRNN.dropout_keep_prob: 1})  # shape of logits: (batch_size, num_classes) = (80, 1999)
            # 6. get labels using logits
            #predicted_labels=get_label_using_logits(logits[0],vocabulary_index2word_label) #logits[0]
            # 7. write question ids and labels to the file system.
            #write_question_id_with_labels(question_id_list[index],predicted_labels,predict_target_file_f)
            print("start:", start, ";end:", end)
            question_id_sublist = question_id_list[start:end]
            get_label_using_logits_batch(question_id_sublist, logits, vocabulary_index2word_label, predict_target_file_f)
            index = index + 1
        predict_target_file_f.close()
# get labels for a single example's logits
def get_label_using_logits(logits, vocabulary_index2word_label, top_number=5):
    #print("get_label_using_logits:", logits)
    print("get_label_using_logits.shape:", logits.shape)  # (10, 1999) = [batch_size, num_labels] ===> we need (10, 5)
    index_list = np.argsort(logits)[-top_number:]  # indices of the top_number largest logits #print("sum_p", np.sum(1.0 / (1 + np.exp(-logits))))
    index_list = index_list[::-1]  # reverse so the highest-scoring label comes first
    label_list = []
    for index in index_list:
        label = vocabulary_index2word_label[index]
        label_list.append(label)  # e.g. [u'-3423450385060590478', u'2838091149470021485', ...]
    print("get_label_using_logits.label_list", label_list)
    return label_list
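# Toy illustration of the top-k selection used above (values invented for the example):
#   logits = np.array([0.1, 2.0, 0.5, 3.0])
#   np.argsort(logits)[-2:][::-1]  ->  array([3, 1]), i.e. indices of the two largest logits, best first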
# get labels for a whole batch of logits and write them to the output file
def get_label_using_logits_batch(question_id_sublist, logits_batch, vocabulary_index2word_label, f, top_number=5):
    #print("get_label_using_logits.shape:", logits_batch.shape)  # (10, 1999) = [batch_size, num_labels] ===> we need (10, 5)
    for i, logits in enumerate(logits_batch):
        index_list = np.argsort(logits)[-top_number:]  #print("sum_p", np.sum(1.0 / (1 + np.exp(-logits))))
        index_list = index_list[::-1]
        label_list = []
        for index in index_list:
            label = vocabulary_index2word_label[index]
            label_list.append(label)  # e.g. [u'-3423450385060590478', u'2838091149470021485', ...]
        #print("get_label_using_logits.label_list", label_list)
        write_question_id_with_labels(question_id_sublist[i], label_list, f)
    f.flush()
    #return label_list
# write question id and labels to the file system.
def write_question_id_with_labels(question_id, labels_list, f):
    labels_string = ",".join(labels_list)
    f.write(question_id + "," + labels_string + "\n")
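# Each line of the resulting CSV therefore looks like: <question_id>,<label_1>,<label_2>,...,<label_5>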
if __name__ == "__main__":
    tf.app.run()