forked from shuaihuaiyi/QA
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d0c4c3f
commit d1aea79
Showing
7 changed files
with
506 additions
and
11 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# coding:utf-8 | ||
|
||
import tensorflow as tf | ||
from tensorflow.python.ops import rnn, rnn_cell | ||
|
||
|
||
# define lstm model and return related features | ||
|
||
|
||
# return n outputs of the n lstm cells | ||
def biLSTM(x, hidden_size):
    """Add a bidirectional LSTM operation to the graph.

    Args:
        x: input tensor; the original notes give its layout as
            [batch, height, width] / [batch, step, embedding_size].
        hidden_size: number of hidden units of each LSTM cell.

    Returns:
        The per-step outputs stacked back into one tensor; the original notes
        give its layout as [batch, step, 2*hidden_size].
    """
    # static_bidirectional_rnn takes a Python list of per-step tensors, so
    # move the step axis to the front and unstack along it.
    step_inputs = tf.unstack(tf.transpose(x, [1, 0, 2]))

    # One cell per direction, configured identically.
    forward_cell = rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0, state_is_tuple=True)
    backward_cell = rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0, state_is_tuple=True)
    step_outputs, _, _ = rnn.static_bidirectional_rnn(
        forward_cell, backward_cell, step_inputs, dtype=tf.float32)

    # Undo the input transformation: stack the steps, then swap the first two
    # axes back.
    return tf.transpose(tf.stack(step_outputs), [1, 0, 2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# coding=utf-8 | ||
|
||
import codecs | ||
import logging | ||
import numpy as np | ||
import os | ||
|
||
from collections import defaultdict | ||
|
||
# Configure the root logger: INFO and above, bare message text only.
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||
|
||
def load_embedding(filename, embedding_size):
    """Load word embeddings from a space-separated UTF-8 text file.

    Every line is split on single spaces and must yield exactly
    ``embedding_size + 2`` fields: the word, ``embedding_size`` float values,
    and one trailing field that is discarded. Lines with any other field
    count are logged (with their 1-based line number) and skipped.

    Args:
        filename: path of the embedding file.
        embedding_size: number of float values expected per word.

    Returns:
        A tuple ``(embeddings, word2idx, idx2word)``. ``embeddings`` is a list
        of float lists; ``word2idx`` maps a word to its row in ``embeddings``
        and ``idx2word`` maps that row back to the word. If an exception
        occurs while reading, it is logged and whatever was loaded so far is
        returned.
    """
    embeddings = []
    # defaultdict is kept so the returned mapping types stay what callers got
    # before.
    word2idx = defaultdict(list)
    idx2word = defaultdict(list)
    with codecs.open(filename, mode="r", encoding="utf-8") as rf:
        try:
            for line_no, line in enumerate(rf, start=1):
                arr = line.split(" ")
                if len(arr) != (embedding_size + 2):
                    logging.error("embedding error, index is:%s", line_no)
                    continue

                embedding = [float(val) for val in arr[1:-1]]
                # Use the embedding row as the word id so that word2idx,
                # idx2word and embeddings always agree (previously idx2word
                # was keyed one higher than word2idx, and a repeated word
                # shifted ids away from their rows).
                row = len(embeddings)
                word2idx[arr[0]] = row
                idx2word[row] = arr[0]
                embeddings.append(embedding)
        except Exception as e:
            # Best effort: report the problem and return what was loaded.
            logging.error("load embedding Exception, %s", e)

    logging.info("load embedding finish!")
    return embeddings, word2idx, idx2word
|
||
|
||
def sent_to_idx(sent, word2idx, sequence_len):
    """Convert an underscore-separated sentence into a list of word ids.

    Only the first ``sequence_len`` words are kept. A word missing from
    ``word2idx`` gets the id stored under "UNKNOWN", or 0 when that entry is
    missing as well.
    """
    fallback_id = word2idx.get("UNKNOWN", 0)
    words = sent.split("_")[:sequence_len]
    ids = []
    for token in words:
        ids.append(word2idx.get(token, fallback_id))
    return ids
|
||
|
||
def load_train_data(filename, word2idx, sequence_len):
    """Load training question pairs from a space-separated UTF-8 text file.

    A usable line has exactly four space-separated fields and its first field
    is "1"; the third and fourth fields are the two sentences, each converted
    with ``sent_to_idx``. Every other line is logged as invalid and skipped.

    Args:
        filename: path of the training file.
        word2idx: mapping from word to id, passed to ``sent_to_idx``.
        sequence_len: maximum number of words kept per sentence.

    Returns:
        A tuple ``(ori_quests, cand_quests)`` of equally long lists of id
        lists. If an exception occurs while reading, it is logged and the
        pairs loaded so far are returned.
    """
    ori_quests, cand_quests = [], []
    with codecs.open(filename, mode="r", encoding="utf-8") as rf:
        try:
            for line in rf:
                arr = line.strip().split(" ")
                if len(arr) != 4 or arr[0] != "1":
                    logging.error("invalid data:%s", line)
                    continue

                # Convert both sentences before appending either, so the two
                # lists cannot get out of step if a conversion fails.
                ori_quest = sent_to_idx(arr[2], word2idx, sequence_len)
                cand_quest = sent_to_idx(arr[3], word2idx, sequence_len)

                ori_quests.append(ori_quest)
                cand_quests.append(cand_quest)
        except Exception as e:
            # Best effort: report the problem and return what was loaded.
            logging.error("load train data Exception, %s", e)
    logging.info("load train data finish!")

    return ori_quests, cand_quests
|
||
|
||
def create_valid(data, proportion=0.1):
    """Shuffle ``data`` and split it into two parts.

    Args:
        data: sequence to split. ``None`` is logged as an error and terminates
            the process with exit status 1.
        proportion: fraction of the items placed in the second part.

    Returns:
        A tuple of two numpy arrays: the first ``int(len(data) * (1 -
        proportion))`` shuffled items, and the remaining ones.
    """
    if data is None:
        logging.error("data is none")
        os._exit(1)

    total = len(data)
    order = np.random.permutation(np.arange(total))
    shuffled = np.array(data)[order]
    cut = int(total * (1 - proportion))
    first_part = shuffled[:cut]
    second_part = shuffled[cut:]
    return first_part, second_part
|
||
|
||
def load_test_data(filename, word2idx, sequence_len):
    """Load labelled test question pairs from a space-separated UTF-8 file.

    A usable line has exactly four space-separated fields:

    1. an integer label;
    2. a field containing ":"; the integer after the first colon is read as
       the "result" value;
    3. the original sentence;
    4. the candidate sentence.

    Both sentences are converted with ``sent_to_idx``. Lines with another
    field count are logged as invalid and skipped.

    Args:
        filename: path of the test file.
        word2idx: mapping from word to id, passed to ``sent_to_idx``.
        sequence_len: maximum number of words kept per sentence.

    Returns:
        A tuple ``(ori_quests, cand_quests, labels, results)`` of equally long
        lists. If an exception occurs while reading (for example a field that
        is not an integer), it is logged and the rows loaded so far are
        returned.
    """
    ori_quests, cand_quests, labels, results = [], [], [], []
    with codecs.open(filename, mode="r", encoding="utf-8") as rf:
        try:
            for line in rf:
                arr = line.strip().split(" ")
                if len(arr) != 4:
                    logging.error("invalid data:%s", line)
                    continue

                # Parse the whole row before appending anything, so the four
                # lists cannot get out of step if a field fails to parse.
                ori_quest = sent_to_idx(arr[2], word2idx, sequence_len)
                cand_quest = sent_to_idx(arr[3], word2idx, sequence_len)
                label = int(arr[0])
                result = int(arr[1].split(":")[1])

                ori_quests.append(ori_quest)
                cand_quests.append(cand_quest)
                labels.append(label)
                results.append(result)
        except Exception as e:
            # Best effort: report the problem and return what was loaded.
            logging.error("load test error, %s", e)
    logging.info("load test data finish!")
    return ori_quests, cand_quests, labels, results
|
||
|
||
def batch_iter(ori_quests, cand_quests, batch_size, epoches, is_valid=False):
    """Yield ``(ori, cand, neg)`` batches of exactly ``batch_size`` rows.

    Only full batches are produced: ``len(ori_quests) // batch_size`` per
    epoch, so a trailing partial batch is dropped.

    Args:
        ori_quests: sequence of original questions.
        cand_quests: sequence of candidate questions, aligned with
            ``ori_quests``.
        batch_size: number of rows per batch.
        epoches: number of passes over the data.
        is_valid: when true the data is not shuffled and the "negative" slot
            simply repeats the candidate batch; otherwise the data is
            reshuffled every epoch and negatives are sampled.

    Yields:
        ``(ori_batch, cand_batch, neg_quests)``. The first two are numpy array
        slices; ``neg_quests`` is the candidate slice when ``is_valid`` is
        true, else a list of ``batch_size`` sampled candidate rows.

    Raises:
        ValueError: when negatives must be sampled and ``batch_size`` is less
            than 2 (there is then nothing to sample from).
    """
    data_len = len(ori_quests)
    batch_num = data_len // batch_size
    ori_quests = np.array(ori_quests)
    cand_quests = np.array(cand_quests)

    for _ in range(epoches):
        if not is_valid:
            # Shuffle ALL rows so that the rows left over after the last full
            # batch differ from epoch to epoch instead of never being used.
            shuffle_idx = np.random.permutation(data_len)
            ori_quests = ori_quests[shuffle_idx]
            cand_quests = cand_quests[shuffle_idx]
        for batch in range(batch_num):
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            ori_batch = ori_quests[start_idx:end_idx]
            cand_batch = cand_quests[start_idx:end_idx]

            if is_valid:
                neg_quests = cand_batch
            else:
                if batch_size < 2:
                    raise ValueError(
                        "batch_size must be at least 2 to sample negative questions")
                # NOTE(review): as before, negatives are drawn uniformly from
                # the candidates of this same batch, excluding its first row
                # (start_idx < idx < end_idx) -- confirm this is intended.
                # Sampling the range directly replaces the former rejection
                # loop over the whole data set, which never terminated for
                # batch_size == 1.
                neg_idx = np.random.randint(start_idx + 1, end_idx, batch_size)
                neg_quests = [cand_quests[idx] for idx in neg_idx]

            yield (ori_batch, cand_batch, neg_quests)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# coding=utf-8 | ||
|
||
import logging | ||
import datetime | ||
import time | ||
import tensorflow as tf | ||
import operator | ||
|
||
from data_helper import load_train_data, load_test_data, load_embedding, batch_iter | ||
from polymerization import LSTM_QA | ||
|
||
# ------------------------- define parameter ----------------------------- | ||
# Command-line flags. Help strings that were copy-pasted from other flags have
# been corrected to describe the flag they belong to.
# Data files.
tf.flags.DEFINE_string("train_file", "../insuranceQA/train", "train corpus file")
tf.flags.DEFINE_string("test_file", "../insuranceQA/test1", "test corpus file")
tf.flags.DEFINE_string("valid_file", "../insuranceQA/test1.sample", "validation corpus file")
tf.flags.DEFINE_string("embedding_file", "../insuranceQA/vectors.nobin", "embedding file")
# Model and training parameters.
tf.flags.DEFINE_integer("embedding_size", 100, "embedding size")
tf.flags.DEFINE_float("dropout", 1, "the proportion of dropout")
tf.flags.DEFINE_float("lr", 0.1, "learning rate")
tf.flags.DEFINE_integer("batch_size", 100, "batch size of each batch")
tf.flags.DEFINE_integer("epoches", 300, "epoches")
tf.flags.DEFINE_integer("rnn_size", 300, "rnn size")
tf.flags.DEFINE_integer("num_rnn_layers", 1, "number of rnn layers")
tf.flags.DEFINE_integer("evaluate_every", 1000, "run evaluation")
tf.flags.DEFINE_integer("num_unroll_steps", 100, "number of unroll steps (maximum words kept per sentence)")
tf.flags.DEFINE_integer("max_grad_norm", 5, "maximum global norm used for gradient clipping")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", True, "Log placement of ops on devices")
tf.flags.DEFINE_float("gpu_options", 0.75, "use memory rate")

FLAGS = tf.flags.FLAGS
# ----------------------------- define parameter end ---------------------------------- | ||
|
||
# ----------------------------- define a logger ------------------------------- | ||
# Dedicated "execute" logger: INFO and above are written to ./run.log, which
# is opened with mode "w" and therefore overwritten on every run.
fmt = "%(asctime)-15s %(levelname)s %(filename)s %(lineno)d %(process)d %(message)s"
datefmt = "%a %d %b %Y %H:%M:%S"
formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)

fh = logging.FileHandler("./run.log", mode="w")
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)

logger = logging.getLogger("execute")
logger.setLevel(logging.INFO)
logger.addHandler(fh)
# ----------------------------- define a logger end ---------------------------------- | ||
|
||
# ------------------------------------load data ------------------------------- | ||
# Embedding table and vocabulary.
embedding, word2idx, idx2word = load_embedding(
    FLAGS.embedding_file, FLAGS.embedding_size)

# Training pairs; sentences are cut to at most num_unroll_steps words.
ori_quests, cand_quests = load_train_data(
    FLAGS.train_file, word2idx, FLAGS.num_unroll_steps)

# Test and validation sets are both read with the test-data loader.
test_ori_quests, test_cand_quests, labels, results = load_test_data(
    FLAGS.test_file, word2idx, FLAGS.num_unroll_steps)
valid_ori_quests, valid_cand_quests, valid_labels, valid_results = load_test_data(
    FLAGS.valid_file, word2idx, FLAGS.num_unroll_steps)
|
||
|
||
# ----------------------------------- load data end ---------------------- | ||
|
||
# ----------------------------------- execute train model --------------------------------- | ||
def run_step(sess, ori_batch, cand_batch, neg_batch, lstm, dropout=1.):
    """Run one training step on a batch and log its statistics.

    Relies on the module-level ``train_op`` and ``global_step``.

    Args:
        sess: session used to run the step.
        ori_batch: batch fed to ``lstm.ori_input_quests``.
        cand_batch: batch fed to ``lstm.cand_input_quests``.
        neg_batch: batch fed to ``lstm.neg_input_quests``.
        lstm: model object exposing the placeholders and the ``ori_cand``,
            ``ori_neg``, ``loss`` and ``acc`` tensors.
        dropout: value fed to ``lstm.keep_prob``.

    Returns:
        A tuple ``(cur_loss, ori_cand_score)`` from this step.
    """
    began_at = time.time()
    fetches = [train_op, global_step, lstm.ori_cand, lstm.ori_neg, lstm.loss, lstm.acc]
    inputs = {
        lstm.ori_input_quests: ori_batch,
        lstm.cand_input_quests: cand_batch,
        lstm.neg_input_quests: neg_batch,
        lstm.keep_prob: dropout,
    }
    _, step, ori_cand_score, ori_neg_score, cur_loss, cur_acc = sess.run(fetches, inputs)
    time_str = datetime.datetime.now().isoformat()

    # A row counts as "right" when the candidate scores above 0.55 and the
    # negative below 0.4; only the "wrong" count and the summed score margin
    # end up in the log line.
    right = 0.0
    wrong = 0.0
    score = 0.0
    for i in range(len(ori_batch)):
        cand_score = ori_cand_score[i]
        neg_score = ori_neg_score[i]
        if cand_score > 0.55 and neg_score < 0.4:
            right += 1.0
        else:
            wrong += 1.0
        score += cand_score - neg_score

    time_elapsed = time.time() - began_at
    logger.info("%s: step %s, loss %s, acc %s, score %s, wrong %s, %6.7f secs/batch" % (
        time_str, step, cur_loss, cur_acc, score, wrong, time_elapsed))

    return cur_loss, ori_cand_score
|
||
|
||
def valid_run_step(sess, ori_batch, cand_batch, lstm, dropout=1.):
    """Score one evaluation batch without running the training op.

    Relies on the module-level ``global_step``.

    Args:
        sess: session used to run the step.
        ori_batch: batch fed to ``lstm.test_input_q``.
        cand_batch: batch fed to ``lstm.test_input_a``.
        lstm: model object exposing those placeholders, ``keep_prob`` and the
            ``test_q_a`` tensor.
        dropout: value fed to ``lstm.keep_prob``.

    Returns:
        The evaluated ``lstm.test_q_a`` value for the batch.
    """
    inputs = {
        lstm.test_input_q: ori_batch,
        lstm.test_input_a: cand_batch,
        lstm.keep_prob: dropout,
    }
    # global_step is fetched alongside the scores, as before; its value is
    # not used.
    _, ori_cand_score = sess.run([global_step, lstm.test_q_a], inputs)
    return ori_cand_score
|
||
|
||
# ---------------------------------- execute train model end -------------------------------------- | ||
|
||
def cal_acc(labels, results, total_ori_cand):
    """Compute top-1 accuracy over groups of scored candidates.

    Rows are grouped by their ``results`` value. Within each group the row
    with the highest score is selected (the earliest one on ties), and the
    group counts as correct when that row's label equals 1.

    Args:
        labels: label of each row.
        results: group key of each row.
        total_ori_cand: score of each row.

    Returns:
        The fraction of groups that are correct. Returns 0 (after logging
        "data error") when the three inputs differ in length, and 0.0 when
        they are empty.
    """
    if not (len(labels) == len(results) == len(total_ori_cand)):
        logger.info("data error")
        return 0

    # Track the best-scoring (score, label) pair per group in a single pass;
    # the strict ">" keeps the earliest row on ties.
    best = {}
    for label, result, ori_cand in zip(labels, results, total_ori_cand):
        if result not in best or ori_cand > best[result][0]:
            best[result] = (ori_cand, label)

    if not best:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    correct = sum(1 for _, flag in best.values() if flag == 1)
    return 1. * correct / len(best)
|
||
|
||
# ---------------------------------- execute valid model ------------------------------------------ | ||
def valid_model(sess, lstm, valid_ori_quests, valid_cand_quests, labels, results):
    """Score an evaluation set batch by batch and log its accuracy.

    Args:
        sess: session used to run the model.
        lstm: model object passed on to ``valid_run_step``.
        valid_ori_quests: original questions of the evaluation set.
        valid_cand_quests: candidate questions of the evaluation set.
        labels: label of each evaluation row.
        results: group key of each evaluation row, as used by ``cal_acc``.
    """
    logger.info("start to validate model")

    batches = batch_iter(valid_ori_quests, valid_cand_quests, FLAGS.batch_size, 1,
                         is_valid=True)
    total_ori_cand = []
    for ori_valid, cand_valid, _ in batches:
        total_ori_cand.extend(valid_run_step(sess, ori_valid, cand_valid, lstm))

    # Fewer rows than labels may have been scored, so cut labels and results
    # down to the number of scores before comparing.
    scored = len(total_ori_cand)
    acc = cal_acc(labels[:scored], results[:scored], total_ori_cand)
    timestr = datetime.datetime.now().isoformat()
    logger.info("%s, evaluation acc:%s" % (timestr, acc))
|
||
|
||
# ---------------------------------- execute valid model end -------------------------------------- | ||
|
||
# ----------------------------------- begin to train ----------------------------------- | ||
# Build the graph, train for FLAGS.epoches epochs, validating every
# FLAGS.evaluate_every steps, then evaluate on the test set.
with tf.Graph().as_default():
    with tf.device("/gpu:0"):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_options)
        session_conf = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,
                                      log_device_placement=FLAGS.log_device_placement,
                                      gpu_options=gpu_options)
        with tf.Session(config=session_conf).as_default() as sess:
            lstm = LSTM_QA(FLAGS.batch_size, FLAGS.num_unroll_steps, embedding, FLAGS.embedding_size,
                           FLAGS.rnn_size, FLAGS.num_rnn_layers, FLAGS.max_grad_norm)
            global_step = tf.Variable(0, name="globle_step", trainable=False)
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(lstm.loss, tvars),
                                              FLAGS.max_grad_norm)

            # The learning rate comes from --lr (default 0.1, the value that
            # used to be hard-coded here).
            optimizer = tf.train.GradientDescentOptimizer(FLAGS.lr)
            # Build the update op once; an earlier duplicate apply_gradients
            # call whose result was discarded has been removed.
            train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

            sess.run(tf.global_variables_initializer())

            for epoch in range(FLAGS.epoches):
                # A per-epoch learning-rate decay (FLAGS.lr / (epoch + 1)) was
                # sketched here in commented-out code but is not enabled.
                for ori_train, cand_train, neg_train in batch_iter(ori_quests, cand_quests, FLAGS.batch_size,
                                                                   epoches=1):
                    run_step(sess, ori_train, cand_train, neg_train, lstm)
                    cur_step = tf.train.global_step(sess, global_step)

                    if cur_step % FLAGS.evaluate_every == 0 and cur_step != 0:
                        valid_model(sess, lstm, valid_ori_quests, valid_cand_quests, valid_labels, valid_results)

            # NOTE(review): the test-set evaluation is placed after the last
            # epoch; the original indentation was not recoverable -- confirm.
            valid_model(sess, lstm, test_ori_quests, test_cand_quests, labels, results)
# ---------------------------------- end train ----------------------------------- |
Oops, something went wrong.