forked from shuaihuaiyi/QA
Commit
Showing 14 changed files with 130 additions and 5,937 deletions.
@@ -1,52 +1,118 @@
-import readData
 import os
 import time
+import qaData
+import taevaluation
 import tensorflow as tf
+
+from qaLSTM import QaLstm
+
+
+def restore():
+    try:
+        saver.restore(sess, trainedModel)
+    except Exception as e:
+        print("Failed to load the model; training from scratch")
+        train()
+
+
+def train():
+    global learningRate  # learningRate is decayed below, so it must be declared global
+    # Prepare the training data
+    qTrain, aTrain, lTrain, qIdTrain = qaData.loadData(trainingFile, word2idx, unrollSteps, True)
+    qDevelop, aDevelop, lDevelop, qIdDevelop = qaData.loadData(developFile, word2idx, unrollSteps, True)
+    trainQuestionCounts = qIdTrain[-1] + 1
+    for i in range(len(qIdDevelop)):
+        qIdDevelop[i] += trainQuestionCounts
+    tqs, tta, tfa = [], [], []
+    for question, trueAnswer, falseAnswer in qaData.batchIter(qTrain + qDevelop, aTrain + aDevelop,
+                                                              lTrain + lDevelop, qIdTrain + qIdDevelop, batchSize):
+        tqs.append(question), tta.append(trueAnswer), tfa.append(falseAnswer)
+    # Start training
+    sess.run(tf.global_variables_initializer())
+    for i in range(lrDownCount):
+        optimizer = tf.train.GradientDescentOptimizer(learningRate)
+        trainOp = optimizer.apply_gradients(zip(grads, tvars), global_step=globalStep)
+        for epoch in range(epochs):
+            for question, trueAnswer, falseAnswer in zip(tqs, tta, tfa):
+                startTime = time.time()
+                feed_dict = {
+                    lstm.ori_input_quests: question,
+                    lstm.cand_input_quests: trueAnswer,
+                    lstm.neg_input_quests: falseAnswer,
+                    lstm.keep_prob: dropout
+                }
+                _, step, _, _, loss, acc = \
+                    sess.run([trainOp, globalStep, lstm.ori_cand, lstm.ori_neg, lstm.loss, lstm.acc], feed_dict)
+                timeUsed = time.time() - startTime
+                print("step:", step, "loss:", loss, "acc:", acc, "time:", timeUsed)
+            saver.save(sess, saveFile)
+        learningRate *= lrDownRate
+
+
 if __name__ == '__main__':
     # Define the parameters
     trainingFile = "data/training.data"
-    validFile = "data/develop.data"
+    developFile = "data/develop.data"
     testFile = "data/testing.data"
-    saveFile = "savedModel"
+    saveFile = "newModel/savedModel"
+    trainedModel = "trainedModel/savedModel"
     embeddingFile = "word2vec/zhwiki_2017_03.sg_50d.word2vec"
-    embeddingSize = 50 #dimensionality of the word vectors
+    embeddingSize = 50  # dimensionality of the word vectors
 
     dropout = 1.0
-    learningRate = 0.4
-    batchSize = 20  # number of questions per batch
-    epochs = 20
-    tf.flags.DEFINE_integer("rnn_size", 100, "rnn size")
-    tf.flags.DEFINE_integer("num_rnn_layers", 1, "embedding size")
-    tf.flags.DEFINE_integer("num_unroll_steps", 100, "maximum number of tokens per sentence")
-    tf.flags.DEFINE_integer("max_grad_norm", 5, "max grad norm")
-    # Misc Parameters
-    tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
-    tf.flags.DEFINE_float("gpu_options", 0.75, "use memory rate")
-
-    gpuMemUsage = 0.8
-    gpuDevice = "/gpu:0"
-    # Read the data
-    trainingList = readData.readFile(trainingFile)
-    testList = readData.readFile(testFile)
-    embeddingDict = readData.readEmbeddingFile(embeddingFile, embeddingSize)
-
-    # Preprocessing
-    trainingVec = readData.textToVec(trainingList, embeddingDict)
-    testVec = readData.textToVec(testList, embeddingDict)
-    del embeddingDict  # free memory
-
-    # Define the model  todo
+    learningRate = 0.4  # initial learning rate
+    lrDownRate = 0.5  # learning-rate decay factor
+    lrDownCount = 4  # number of times the learning rate is decayed
+    epochs = 20  # full epochs to run before each learning-rate decay
+    batchSize = 20  # number of questions per batch
 
-    # Start training
-    with tf.Graph().as_default():
-        with tf.device(gpuDevice):
-            gpuOptions = tf.GPUOptions(per_process_gpu_memory_fraction=gpuMemUsage)
-            session_conf = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,
-                                          log_device_placement=FLAGS.log_device_placement,
-                                          gpu_options=gpuOptions)
-            with tf.Session(config=session_conf).as_default() as sess:
-                pass  # todo
-
-    # Evaluation  todo
-    pass
+    rnnSize = 100  # number of hidden units per LSTM cell
+
+    unrollSteps = 100  # maximum number of tokens per sentence
+    max_grad_norm = 5
+
+    allow_soft_placement = True  # allow soft device placement
+    gpuMemUsage = 0.8  # maximum fraction of GPU memory to use
+    gpuDevice = "/gpu:0"  # GPU device name
+
+    # Read the test data
+    embedding, word2idx, idx2word = qaData.loadEmbedding(embeddingFile, embeddingSize)
+    qTest, aTest, _, qIdTest = qaData.loadData(testFile, word2idx, unrollSteps)
+
+    # Configure TensorFlow
+    with tf.Graph().as_default(), tf.device(gpuDevice):
+        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpuMemUsage)
+        session_conf = tf.ConfigProto(allow_soft_placement=allow_soft_placement, gpu_options=gpu_options)
+        with tf.Session(config=session_conf).as_default() as sess:
+            # Load the LSTM network
+            globalStep = tf.Variable(0, name="globalStep", trainable=False)
+            lstm = QaLstm(batchSize, unrollSteps, embedding, embeddingSize, rnnSize)
+            tvars = tf.trainable_variables()
+            grads, _ = tf.clip_by_global_norm(tf.gradients(lstm.loss, tvars), max_grad_norm)
+            saver = tf.train.Saver()
+
+            # Load an existing model or train a new one
+            if os.path.exists(trainedModel + '.index'):
+                while True:
+                    choice = input("Found a trained model. Load it? (y/n) ")
+                    if choice.strip().lower() == 'y':
+                        restore()
+                        break
+                    elif choice.strip().lower() == 'n':
+                        choice = input("Are you sure? Retraining will take a lot of time and hardware resources. (yes/no) ")
+                        if choice.strip().lower() == 'yes':
+                            train()
+                            break
+                        elif choice.strip().lower() == 'no':
+                            restore()
+                            break
+                        else:
+                            print("Invalid input!\n")
+                    else:
+                        print("Invalid input!\n")
+            else:
+                train()
+            # Run the test set and output the results
+            pass
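
The training loop above fetches lstm.loss and lstm.acc, but their definitions live in qaLSTM.py, which is not part of this diff. QA-LSTM models of this kind typically score question/answer pairs by cosine similarity and train on (question, true answer, false answer) triples with a margin-based hinge loss. A minimal sketch of that objective, assuming ori_cand and ori_neg are the two cosine similarities returned by the graph and using a hypothetical margin of 0.1 (an illustration, not the repository's actual qaLSTM.py):

import tensorflow as tf

def marginLoss(ori_cand, ori_neg, margin=0.1):
    # Hinge loss: push the true answer at least `margin` above the false one.
    losses = tf.maximum(0.0, margin - ori_cand + ori_neg)
    loss = tf.reduce_sum(losses)
    # "Accuracy" here is the fraction of triples already separated by the margin.
    acc = tf.reduce_mean(tf.cast(tf.equal(losses, 0.0), tf.float32))
    return loss, acc

Under such a loss, lstm.acc approaching 1.0 only means that most sampled negative answers are already ranked below the true answer by at least the margin.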
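
The step behind "# Run the test set and output the results" is still a pass. A minimal sketch of what that evaluation could look like, reusing only placeholders already referenced above (ori_input_quests, cand_input_quests, keep_prob, ori_cand); the helper name scoreTestSet, the numpy usage, and the assumption that the number of test pairs is a multiple of batchSize are illustrative, not taken from the repository:

import numpy as np

def scoreTestSet(sess, lstm, qTest, aTest, batchSize):
    # Score every (question, candidate answer) pair with the trained network.
    scores = []
    for start in range(0, len(qTest), batchSize):  # assumes len(qTest) % batchSize == 0
        feed_dict = {
            lstm.ori_input_quests: qTest[start:start + batchSize],
            lstm.cand_input_quests: aTest[start:start + batchSize],
            lstm.keep_prob: 1.0,  # no dropout at test time
        }
        scores.extend(sess.run(lstm.ori_cand, feed_dict))
    return np.asarray(scores)

Grouping the returned scores by qIdTest and taking the argmax within each question would then give one predicted answer per test question.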