Skip to content

Commit

Permalink
加入biLSTM测试模型
Browse files Browse the repository at this point in the history
  • Loading branch information
shuaihuaiyi committed Jul 22, 2017
1 parent d0c4c3f commit d1aea79
Show file tree
Hide file tree
Showing 7 changed files with 506 additions and 11 deletions.
6 changes: 6 additions & 0 deletions .idea/encodings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions bilstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# coding:utf-8

import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell


# define lstm model and return related features


# return n outputs of the n lstm cells
def biLSTM(x, hidden_size):
    """Add a bidirectional LSTM operation over a batch of sequences.

    Args:
        x: tensor of shape [batch, height, width], i.e.
            [batch, step, embedding_size].
        hidden_size: number of hidden units in each LSTM cell.

    Returns:
        Tensor of shape [batch, height, 2*hidden_size], i.e.
        [batch, step, 2*hidden_size].
    """
    # static_bidirectional_rnn takes a Python list with one tensor per time
    # step, so move the step axis to the front and unstack along it.
    time_major = tf.transpose(x, [1, 0, 2])
    step_inputs = tf.unstack(time_major)

    # One cell per direction.
    fw_cell = rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0, state_is_tuple=True)
    bw_cell = rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0, state_is_tuple=True)
    step_outputs, _, _ = rnn.static_bidirectional_rnn(
        fw_cell, bw_cell, step_inputs, dtype=tf.float32)

    # Re-assemble the per-step outputs and restore the batch-major layout.
    return tf.transpose(tf.stack(step_outputs), [1, 0, 2])
154 changes: 154 additions & 0 deletions data_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# coding=utf-8

import codecs
import logging
import numpy as np
import os

from collections import defaultdict

# Configure the root logger: bare messages only, INFO level and above.
logging.basicConfig(level=logging.INFO, format="%(message)s")


def load_embedding(filename, embedding_size):
    """Load pretrained word vectors from a space-separated UTF-8 text file.

    A line is accepted when splitting it on single spaces yields exactly
    ``embedding_size + 2`` fields: the word, ``embedding_size`` floats, and a
    trailing field (the line break that follows the last space). Any other
    line, or a line whose values are not floats, is logged and skipped.

    Args:
        filename: path of the embedding file.
        embedding_size: number of float components expected per word.

    Returns:
        A tuple ``(embeddings, word2idx, idx2word)``:
        ``embeddings`` is a list of float lists, ``word2idx`` maps a word to
        its row in ``embeddings`` and ``idx2word`` maps the row back to the
        word. If a word occurs more than once, ``word2idx`` points at its
        last row.
    """
    embeddings = []
    # Kept as defaultdicts for backward compatibility with existing callers;
    # use .get() for lookups so that missing words do not insert entries.
    word2idx = defaultdict(list)
    idx2word = defaultdict(list)

    # The with-statement closes the file; no explicit close() is needed.
    with codecs.open(filename, mode="r", encoding="utf-8") as rf:
        try:
            for line_no, line in enumerate(rf, 1):
                arr = line.split(" ")
                if len(arr) != (embedding_size + 2):
                    logging.error("embedding error, index is:%s" % line_no)
                    continue

                try:
                    embedding = [float(val) for val in arr[1:-1]]
                except ValueError:
                    # One malformed value must not abort the whole load.
                    logging.error("embedding error, index is:%s" % line_no)
                    continue

                # Key both maps on the row position so that they stay
                # inverse to each other and aligned with `embeddings`.
                row = len(embeddings)
                word2idx[arr[0]] = row
                idx2word[row] = arr[0]
                embeddings.append(embedding)
        except Exception as e:
            # Best effort: keep whatever was loaded before the failure.
            logging.exception("load embedding Exception, %s", e)

    logging.info("load embedding finish!")
    return embeddings, word2idx, idx2word


def sent_to_idx(sent, word2idx, sequence_len):
    """Convert an underscore-separated sentence into a list of word indices.

    Only the first ``sequence_len`` tokens are kept. A token missing from
    ``word2idx`` maps to the index of "UNKNOWN", or to 0 when "UNKNOWN" is
    not in the vocabulary either.
    """
    fallback = word2idx.get("UNKNOWN", 0)
    tokens = sent.split("_")[:sequence_len]

    indices = []
    for token in tokens:
        indices.append(word2idx.get(token, fallback))
    return indices


def load_train_data(filename, word2idx, sequence_len):
    """Load positive (question, candidate) pairs from the training corpus.

    Each line is stripped and split on single spaces. Only lines with
    exactly four fields whose first field is "1" are used; fields 3 and 4
    are converted with ``sent_to_idx``. Every other line is logged as
    invalid and skipped.

    Args:
        filename: path of the UTF-8 training file.
        word2idx: mapping from word to index.
        sequence_len: maximum number of tokens kept per sentence.

    Returns:
        A tuple ``(ori_quests, cand_quests)`` of equally long lists of
        index lists.
    """
    ori_quests, cand_quests = [], []
    # The with-statement closes the file; no explicit close() is needed.
    with codecs.open(filename, mode="r", encoding="utf-8") as rf:
        try:
            for line in rf:
                arr = line.strip().split(" ")
                if len(arr) != 4 or arr[0] != "1":
                    logging.error("invalid data:%s" % (line))
                    continue

                ori_quest = sent_to_idx(arr[2], word2idx, sequence_len)
                cand_quest = sent_to_idx(arr[3], word2idx, sequence_len)

                ori_quests.append(ori_quest)
                cand_quests.append(cand_quest)
        except Exception as e:
            # Best effort: keep the pairs read before the failure. The
            # exception is passed as a format argument so logging can render it.
            logging.exception("load train data Exception, %s", e)

    logging.info("load train data finish!")
    return ori_quests, cand_quests


def create_valid(data, proportion=0.1):
    """Shuffle ``data`` and split it into a training part and a validation part.

    Args:
        data: sequence of samples; it is converted to a numpy array.
        proportion: fraction of the samples placed in the second part.

    Returns:
        A tuple ``(train, valid)`` of numpy arrays. The first holds the first
        ``int(len(data) * (1 - proportion))`` shuffled samples, the second
        holds the rest.
    """
    if data is None:
        logging.error("data is none")
        # NOTE(review): os._exit terminates the process immediately without
        # normal interpreter cleanup - confirm this is intended rather than
        # raising an exception.
        os._exit(1)

    total = len(data)
    order = np.random.permutation(np.arange(total))
    shuffled = np.array(data)[order]
    cut = int(total * (1 - proportion))
    train_part = shuffled[:cut]
    valid_part = shuffled[cut:]
    return train_part, valid_part


def load_test_data(filename, word2idx, sequence_len):
    """Load labelled (question, candidate) pairs from a test/validation file.

    Each line is stripped and split on single spaces and must have exactly
    four fields: an integer label, a ``name:number`` field whose number
    groups the candidates of one question, the question and the candidate.
    Lines that do not match are logged as invalid and skipped.

    Args:
        filename: path of the UTF-8 test file.
        word2idx: mapping from word to index.
        sequence_len: maximum number of tokens kept per sentence.

    Returns:
        A tuple ``(ori_quests, cand_quests, labels, results)`` of equally
        long lists.
    """
    ori_quests, cand_quests, labels, results = [], [], [], []
    # The with-statement closes the file; no explicit close() is needed.
    with codecs.open(filename, mode="r", encoding="utf-8") as rf:
        try:
            for line in rf:
                arr = line.strip().split(" ")
                if len(arr) != 4:
                    logging.error("invalid data:%s" % line)
                    continue

                # Parse the numeric fields first so that a malformed line is
                # skipped on its own instead of aborting the whole load.
                try:
                    label = int(arr[0])
                    result = int(arr[1].split(":")[1])
                except (ValueError, IndexError):
                    logging.error("invalid data:%s" % line)
                    continue

                ori_quest = sent_to_idx(arr[2], word2idx, sequence_len)
                cand_quest = sent_to_idx(arr[3], word2idx, sequence_len)

                # Append all four together so the lists stay aligned.
                ori_quests.append(ori_quest)
                cand_quests.append(cand_quest)
                labels.append(label)
                results.append(result)
        except Exception as e:
            # Best effort: keep the rows read before the failure. The
            # exception is passed as a format argument so logging can render it.
            logging.exception("load test error, %s", e)

    logging.info("load test data finish!")
    return ori_quests, cand_quests, labels, results


def batch_iter(ori_quests, cand_quests, batch_size, epoches, is_valid=False):
    """Yield ``(ori_batch, cand_batch, neg_batch)`` triples of full batches.

    Only ``len(ori_quests) // batch_size`` full batches are produced per
    epoch; a trailing partial batch is dropped.

    Args:
        ori_quests: sequence of original questions.
        cand_quests: sequence of candidate questions, aligned with
            ``ori_quests``.
        batch_size: number of samples per batch.
        epoches: number of passes over the data.
        is_valid: when true, the data is not shuffled and the negatives are
            simply the candidates of the batch; when false, the data is
            reshuffled every epoch and negatives are sampled at random from
            the candidates of the current batch.

    Yields:
        ``(ori_batch, cand_batch, neg_quests)``. The first two are numpy
        slices. ``neg_quests`` is a numpy slice in validation mode and a
        list of candidate rows in training mode.
    """
    ori_all = np.array(ori_quests)
    cand_all = np.array(cand_quests)
    data_len = len(ori_all)
    batch_num = data_len // batch_size

    for _ in range(epoches):
        if is_valid:
            ori_epoch, cand_epoch = ori_all, cand_all
        else:
            # Shuffle the full data set on every epoch without overwriting
            # the inputs, so the dropped remainder differs between epochs.
            order = np.random.permutation(data_len)
            ori_epoch, cand_epoch = ori_all[order], cand_all[order]

        for batch in range(batch_num):
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            ori_batch = ori_epoch[start_idx:end_idx]
            cand_batch = cand_epoch[start_idx:end_idx]

            if is_valid:
                neg_quests = cand_batch
            else:
                # Negatives are drawn uniformly from the indices strictly
                # between start_idx and end_idx, as before, but sampled
                # directly instead of by rejection over the whole data set.
                # NOTE(review): excluding start_idx itself looks accidental;
                # confirm the intended sampling range.
                # With batch_size == 1 that interval is empty (the old loop
                # never terminated), so fall back to the batch's only index.
                low = start_idx + 1 if batch_size > 1 else start_idx
                neg_idx = np.random.randint(low, end_idx, batch_size)
                neg_quests = [cand_epoch[idx] for idx in neg_idx]

            yield (ori_batch, cand_batch, neg_quests)
173 changes: 173 additions & 0 deletions execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# coding=utf-8

import logging
import datetime
import time
import tensorflow as tf
import operator

from data_helper import load_train_data, load_test_data, load_embedding, batch_iter
from polymerization import LSTM_QA

# ------------------------- define parameter -----------------------------
# Data locations.
tf.flags.DEFINE_string("train_file", "../insuranceQA/train", "train corpus file")
tf.flags.DEFINE_string("test_file", "../insuranceQA/test1", "test corpus file")
tf.flags.DEFINE_string("valid_file", "../insuranceQA/test1.sample", "validation corpus file")
tf.flags.DEFINE_string("embedding_file", "../insuranceQA/vectors.nobin", "embedding file")
# Model and training hyper-parameters.
tf.flags.DEFINE_integer("embedding_size", 100, "embedding size")
tf.flags.DEFINE_float("dropout", 1, "the proportion of dropout")
tf.flags.DEFINE_float("lr", 0.1, "learning rate")
tf.flags.DEFINE_integer("batch_size", 100, "batch size of each batch")
tf.flags.DEFINE_integer("epoches", 300, "epoches")
tf.flags.DEFINE_integer("rnn_size", 300, "rnn hidden size")
tf.flags.DEFINE_integer("num_rnn_layers", 1, "number of rnn layers")
tf.flags.DEFINE_integer("evaluate_every", 1000, "run evaluation every this many steps")
tf.flags.DEFINE_integer("num_unroll_steps", 100, "number of unroll steps (max tokens kept per sentence)")
tf.flags.DEFINE_integer("max_grad_norm", 5, "max global norm for gradient clipping")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", True, "Log placement of ops on devices")
tf.flags.DEFINE_float("gpu_options", 0.75, "fraction of GPU memory the process may use")

FLAGS = tf.flags.FLAGS
# ----------------------------- define parameter end ----------------------------------

# ----------------------------- define a logger -------------------------------
# Record layout and timestamp format for the run log.
fmt = "%(asctime)-15s %(levelname)s %(filename)s %(lineno)d %(process)d %(message)s"
datefmt = "%a %d %b %Y %H:%M:%S"
formatter = logging.Formatter(fmt, datefmt)

# File handler for ./run.log; mode "w" starts a fresh log on every run.
fh = logging.FileHandler("./run.log", mode="w")
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)

# The "execute" logger used throughout this script.
logger = logging.getLogger("execute")
logger.setLevel(logging.INFO)
logger.addHandler(fh)
# ----------------------------- define a logger end ----------------------------------

# ------------------------------------load data -------------------------------
# Vocabulary and pretrained vectors come first: every corpus below is
# converted to indices through word2idx.
embedding, word2idx, idx2word = load_embedding(FLAGS.embedding_file, FLAGS.embedding_size)

# Training pairs.
ori_quests, cand_quests = load_train_data(
    FLAGS.train_file, word2idx, FLAGS.num_unroll_steps)

# Test set, followed by the validation sample.
test_ori_quests, test_cand_quests, labels, results = load_test_data(
    FLAGS.test_file, word2idx, FLAGS.num_unroll_steps)
valid_ori_quests, valid_cand_quests, valid_labels, valid_results = load_test_data(
    FLAGS.valid_file, word2idx, FLAGS.num_unroll_steps)


# ----------------------------------- load data end ----------------------

# ----------------------------------- execute train model ---------------------------------
def run_step(sess, ori_batch, cand_batch, neg_batch, lstm, dropout=1.):
    """Run one training step and log its statistics.

    Relies on the module-level ``train_op``, ``global_step`` and ``logger``.

    Args:
        sess: session used to run the step.
        ori_batch: batch fed to ``lstm.ori_input_quests``.
        cand_batch: batch fed to ``lstm.cand_input_quests``.
        neg_batch: batch fed to ``lstm.neg_input_quests``.
        lstm: model exposing the input placeholders, ``keep_prob``,
            ``ori_cand``, ``ori_neg``, ``loss`` and ``acc``.
        dropout: value fed to ``lstm.keep_prob``.

    Returns:
        A tuple ``(cur_loss, ori_cand_score)`` for this step.
    """
    began = time.time()
    feeds = {
        lstm.ori_input_quests: ori_batch,
        lstm.cand_input_quests: cand_batch,
        lstm.neg_input_quests: neg_batch,
        lstm.keep_prob: dropout,
    }
    fetches = [train_op, global_step, lstm.ori_cand, lstm.ori_neg, lstm.loss, lstm.acc]
    _, step, ori_cand_score, ori_neg_score, cur_loss, cur_acc = sess.run(fetches, feeds)
    time_str = datetime.datetime.now().isoformat()

    # A sample counts as wrong unless the candidate score is above 0.55 and
    # the negative score is below 0.4; score sums the margins over the batch.
    wrong = 0.0
    score = 0.0
    for i in range(len(ori_batch)):
        pos_score = ori_cand_score[i]
        neg_score = ori_neg_score[i]
        if not (pos_score > 0.55 and neg_score < 0.4):
            wrong += 1.0
        score += pos_score - neg_score

    time_elapsed = time.time() - began
    logger.info("%s: step %s, loss %s, acc %s, score %s, wrong %s, %6.7f secs/batch" % (
        time_str, step, cur_loss, cur_acc, score, wrong, time_elapsed))

    return cur_loss, ori_cand_score


def valid_run_step(sess, ori_batch, cand_batch, lstm, dropout=1.):
    """Score one batch of (question, candidate) pairs without training.

    Feeds ``ori_batch`` to ``lstm.test_input_q``, ``cand_batch`` to
    ``lstm.test_input_a`` and ``dropout`` to ``lstm.keep_prob``, then returns
    the evaluated ``lstm.test_q_a``. Relies on the module-level
    ``global_step``.
    """
    feeds = {
        lstm.test_input_q: ori_batch,
        lstm.test_input_a: cand_batch,
        lstm.keep_prob: dropout,
    }
    # global_step is fetched alongside the scores; its value is not used.
    _, scores = sess.run([global_step, lstm.test_q_a], feeds)
    return scores


# ---------------------------------- execute train model end --------------------------------------

def cal_acc(labels, results, total_ori_cand):
    """Compute top-1 accuracy over groups of candidates.

    Samples are grouped by their ``results`` value. A group counts as
    correct when its highest-scoring sample has label 1; on equal scores the
    sample that appears first wins.

    Args:
        labels: label of each sample (1 marks a correct candidate).
        results: group key of each sample.
        total_ori_cand: score of each sample.

    Returns:
        The fraction of correct groups as a float, ``0.0`` when there are no
        samples, or ``0`` (after logging "data error") when the three inputs
        differ in length.
    """
    if not (len(labels) == len(results) == len(total_ori_cand)):
        logger.info("data error")
        return 0

    # best[group] = (score, label) of the top-scoring sample seen so far.
    # The strict comparison keeps the earliest sample on ties.
    best = {}
    for label, result, ori_cand in zip(labels, results, total_ori_cand):
        if result not in best or ori_cand > best[result][0]:
            best[result] = (ori_cand, label)

    if not best:
        # Empty input: avoid dividing by zero.
        return 0.0

    correct = sum(1 for _, flag in best.values() if flag == 1)
    return 1. * correct / len(best)


# ---------------------------------- execute valid model ------------------------------------------
def valid_model(sess, lstm, valid_ori_quests, valid_cand_quests, labels, results):
    """Score an evaluation set batch by batch and log its accuracy.

    Batches come from ``batch_iter`` in validation mode with
    ``FLAGS.batch_size``. Because only the scored samples are available,
    ``labels`` and ``results`` are cut to the number of scores before being
    passed to ``cal_acc``.
    """
    logger.info("start to validate model")

    batches = batch_iter(valid_ori_quests, valid_cand_quests, FLAGS.batch_size, 1,
                         is_valid=True)
    scores = []
    for ori_valid, cand_valid, _ in batches:
        scores.extend(valid_run_step(sess, ori_valid, cand_valid, lstm))

    scored = len(scores)
    acc = cal_acc(labels[:scored], results[:scored], scores)
    timestr = datetime.datetime.now().isoformat()
    logger.info("%s, evaluation acc:%s" % (timestr, acc))


# ---------------------------------- execute valid model end --------------------------------------

# ----------------------------------- begin to train -----------------------------------
# NOTE(review): the nesting below was reconstructed from statement order
# (the source had lost its indentation) - confirm that the final test-set
# evaluation is meant to run once after all epochs.
with tf.Graph().as_default():
    with tf.device("/gpu:0"):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_options)
        session_conf = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,
                                      log_device_placement=FLAGS.log_device_placement,
                                      gpu_options=gpu_options)
        # Using the session itself as the context manager installs it as the
        # default session and also closes it on exit.
        with tf.Session(config=session_conf) as sess:
            lstm = LSTM_QA(FLAGS.batch_size, FLAGS.num_unroll_steps, embedding, FLAGS.embedding_size,
                           FLAGS.rnn_size, FLAGS.num_rnn_layers, FLAGS.max_grad_norm)
            global_step = tf.Variable(0, name="globle_step", trainable=False)

            # Clip the gradients by global norm before applying them.
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(lstm.loss, tvars),
                                              FLAGS.max_grad_norm)

            # The learning rate comes from the --lr flag (default 0.1, the
            # value that used to be hard-coded). A single apply_gradients op
            # is built; it also increments global_step.
            optimizer = tf.train.GradientDescentOptimizer(FLAGS.lr)
            train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

            sess.run(tf.global_variables_initializer())

            for epoch in range(FLAGS.epoches):
                for ori_train, cand_train, neg_train in batch_iter(ori_quests, cand_quests,
                                                                   FLAGS.batch_size, epoches=1):
                    run_step(sess, ori_train, cand_train, neg_train, lstm)
                    cur_step = tf.train.global_step(sess, global_step)

                    # Periodic evaluation on the validation sample.
                    if cur_step % FLAGS.evaluate_every == 0 and cur_step != 0:
                        valid_model(sess, lstm, valid_ori_quests, valid_cand_quests,
                                    valid_labels, valid_results)

            # Final evaluation on the test set.
            valid_model(sess, lstm, test_ori_quests, test_cand_quests, labels, results)
# ---------------------------------- end train -----------------------------------
Loading

0 comments on commit d1aea79

Please sign in to comment.