
Commit

update
brightmart committed Nov 21, 2018
1 parent 3d8a490 commit 75982d3
Showing 2 changed files with 28 additions and 15 deletions.
40 changes: 26 additions & 14 deletions a00_Bert/train_bert_multi-label.py
@@ -9,7 +9,9 @@
import bert_modeling as modeling
import optimization
import tensorflow as tf
+import os
import numpy as np
+from sklearn.metrics import precision_score, recall_score, f1_score

from utils import load_data,init_label_dict,get_label_using_logits,get_target_label_short,compute_confuse_matrix,\
compute_micro_macro,compute_confuse_matrix_batch,get_label_using_logits_batch,get_target_label_short_batch
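
Note: the sklearn metrics imported here operate directly on multi-hot label arrays, which makes them a convenient cross-check for the hand-rolled F1 computation in do_eval. A toy illustration (data invented for the example):

    import numpy as np
    from sklearn.metrics import f1_score

    y_true = np.array([[1, 0, 1],   # each row: one example's multi-hot labels
                       [0, 1, 0]])
    y_pred = np.array([[1, 0, 0],
                       [0, 1, 0]])
    print(f1_score(y_true, y_pred, average="micro"))  # 0.8 (pooled TP=2, FP=0, FN=1)
    print(f1_score(y_true, y_pred, average="macro"))  # mean of the per-label F1 scores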
@@ -41,19 +43,20 @@ def main(_):
# 2. create model, define train operation
bert_config = modeling.BertConfig(vocab_size=len(word2index), hidden_size=FLAGS.hidden_size, num_hidden_layers=FLAGS.num_hidden_layers,
num_attention_heads=FLAGS.num_attention_heads,intermediate_size=FLAGS.intermediate_size)
-input_ids = tf.placeholder(tf.int32, [FLAGS.batch_size, FLAGS.max_seq_length], name="input_ids")
-input_mask = tf.placeholder(tf.int32, [FLAGS.batch_size, FLAGS.max_seq_length], name="input_mask")
-segment_ids = tf.placeholder(tf.int32, [FLAGS.batch_size,FLAGS.max_seq_length],name="segment_ids")
-label_ids = tf.placeholder(tf.float32, [FLAGS.batch_size,num_labels], name="label_ids")
+input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="input_ids") # FLAGS.batch_size
+input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="input_mask")
+segment_ids = tf.placeholder(tf.int32, [None,FLAGS.max_seq_length],name="segment_ids")
+label_ids = tf.placeholder(tf.float32, [None,num_labels], name="label_ids")
is_training = FLAGS.is_training #tf.placeholder(tf.bool, name="is_training")

use_one_hot_embeddings = False
loss, per_example_loss, logits, probabilities, model = create_model(bert_config, is_training, input_ids, input_mask,
segment_ids, label_ids, num_labels,use_one_hot_embeddings)
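
Note: changing the placeholders' first dimension from FLAGS.batch_size to None lets the same graph accept any batch size at feed time, so a final partial batch, evaluation, or single-example inference no longer has to match the training batch size. A minimal sketch of the behavior (TF 1.x, shapes invented for the example):

    import numpy as np
    import tensorflow as tf

    max_seq_length = 200  # stand-in for FLAGS.max_seq_length
    ids = tf.placeholder(tf.int32, [None, max_seq_length], name="ids")
    doubled = ids * 2     # any op; the batch dimension follows whatever is fed

    with tf.Session() as sess:
        a = sess.run(doubled, {ids: np.zeros((8, max_seq_length), np.int32)})
        b = sess.run(doubled, {ids: np.zeros((3, max_seq_length), np.int32)})
        print(a.shape, b.shape)  # (8, 200) (3, 200): two batch sizes, one graph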
# define train operation
-num_train_steps = int(float(num_examples) / float(FLAGS.batch_size * FLAGS.num_epochs));use_tpu=False
-num_warmup_steps = int(num_train_steps * 0.1)
-train_op = optimization.create_optimizer(loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, use_tpu)
+#num_train_steps = int(float(num_examples) / float(FLAGS.batch_size * FLAGS.num_epochs)); use_tpu=False; num_warmup_steps = int(num_train_steps * 0.1)
+#train_op = optimization.create_optimizer(loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, use_tpu)
+global_step = tf.Variable(0, trainable=False, name="Global_Step")
+train_op = tf.contrib.layers.optimize_loss(loss, global_step=global_step, learning_rate=FLAGS.learning_rate,optimizer="Adam", clip_gradients=3.0)
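
Note: this hunk swaps BERT's optimization.create_optimizer (AdamWeightDecay with linear warmup over the first 10% of steps) for tf.contrib.layers.optimize_loss with plain Adam and global-norm gradient clipping. Roughly what the new call does, written out by hand as a sketch with a toy loss:

    import tensorflow as tf

    w = tf.Variable(5.0)
    loss = tf.square(w)    # toy loss standing in for the BERT loss
    learning_rate = 1e-5   # stand-in for FLAGS.learning_rate

    global_step = tf.Variable(0, trainable=False)
    opt = tf.train.AdamOptimizer(learning_rate)
    grads, variables = zip(*opt.compute_gradients(loss))
    grads, _ = tf.clip_by_global_norm(grads, clip_norm=3.0)  # clip_gradients=3.0
    train_op = opt.apply_gradients(zip(grads, variables), global_step=global_step)

One consequence is that the warmup schedule and weight decay from the commented-out code are gone, so the effective learning-rate behavior differs from the BERT defaults.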

is_training_eval=False
loss_eval, per_example_loss_eval, logits_eval, probabilities_eval, model_eval = create_model(bert_config, is_training_eval, input_ids, input_mask,
@@ -63,11 +66,14 @@
gpu_config.gpu_options.allow_growth = True
sess = tf.Session(config=gpu_config)
sess.run(tf.global_variables_initializer())
+saver = tf.train.Saver()
+if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
+    print("Checkpoint Exists. Restoring Variables from Checkpoint.")
+    saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
number_of_training_data = len(trainX)
iteration = 0
curr_epoch = 0 #sess.run(textCNN.epoch_step)
batch_size = FLAGS.batch_size
-saver = tf.train.Saver()
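
Note: the Saver is now constructed before the checkpoint check, since saver.restore needs an existing Saver, and a tf.train.Saver() only covers variables that exist when it is created. The pattern in isolation (ckpt_dir invented for the example; the script uses FLAGS.ckpt_dir):

    import os
    import tensorflow as tf

    ckpt_dir = "checkpoint_bert/"
    step = tf.Variable(0, name="step")
    saver = tf.train.Saver()  # create after all variables, before restore()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(ckpt_dir + "checkpoint"):  # TF's checkpoint index file
            saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))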
for epoch in range(curr_epoch, FLAGS.num_epochs):
loss_total, counter = 0.0, 0
for start, end in zip(range(0, number_of_training_data, batch_size),range(batch_size, number_of_training_data, batch_size)):
@@ -79,15 +85,18 @@
loss_total, counter = loss_total + curr_loss, counter + 1
if counter % 20 == 0:
print(epoch,"\t",iteration,"\tloss:",loss_total/float(counter),"\tcurrent_loss:",curr_loss)
+if counter % 500==0:
+    print("trainX[",start,"]:",trainX[start])
+    print("trainY[",start,"]:",trainY[start])

# evaluation
-if start!=0 and start % (300 * FLAGS.batch_size) == 0:
-    eval_loss, f1_score, f1_micro, f1_macro = do_eval(sess,input_ids,input_mask,segment_ids,label_ids,is_training_eval,loss_eval,
-                                                      probabilities_eval,vaildX, vaildY, num_labels,batch_size)
+if start!=0 and start % (1000 * FLAGS.batch_size) == 0:
+    eval_loss, f1_score, f1_micro, f1_macro = do_eval(sess,input_ids,input_mask,segment_ids,label_ids,is_training_eval,loss,
+                                                      probabilities,vaildX, vaildY, num_labels,batch_size)
print("Epoch %d Validation Loss:%.3f\tF1 Score:%.3f\tF1_micro:%.3f\tF1_macro:%.3f" % (
epoch, eval_loss, f1_score, f1_micro, f1_macro))
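
Note: two things change here. Evaluation now runs every 1000 * FLAGS.batch_size training examples instead of every 300, and do_eval is fed the training graph's loss and probabilities rather than the separate loss_eval / probabilities_eval tensors. Because is_training is a plain Python flag rather than a placeholder, the training tensors may include dropout, so validation metrics can be slightly noisier than with the dedicated evaluation graph.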
# save model to checkpoint
-if start % (2000 * FLAGS.batch_size)==0:
+if start % (3000 * FLAGS.batch_size)==0:
save_path = FLAGS.ckpt_dir + "model.ckpt"
print("Going to save model..")
saver.save(sess, save_path, global_step=epoch)
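
Note: saver.save(sess, save_path, global_step=epoch) writes checkpoints named model.ckpt-<epoch>, so the multiple saves that can occur within one epoch (every 3000 * FLAGS.batch_size examples) overwrite the same numbered checkpoint rather than accumulating files.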
@@ -146,19 +155,22 @@ def do_eval(sess,input_ids,input_mask,segment_ids,label_ids,is_training,loss,pro
number_examples = len(vaildX)
eval_loss, eval_counter, eval_f1_score, eval_p, eval_r = 0.0, 0, 0.0, 0.0, 0.0
label_dict = init_label_dict(num_labels)
+f1_score_micro_sklearn_total=0.0
+# batch_size=1 # TODO
for start, end in zip(range(0, number_examples, batch_size), range(batch_size, number_examples, batch_size)):
input_mask_, segment_ids_ = get_input_mask_segment_ids(vaildX[start:end])
feed_dict = {input_ids: vaildX[start:end],input_mask:input_mask_,segment_ids:segment_ids_,
label_ids:vaildY[start:end]}
curr_eval_loss, prob = sess.run([loss, probabilities],feed_dict)
target_labels=get_target_label_short_batch(vaildY[start:end])
predict_labels=get_label_using_logits_batch(prob)
#print("predict_labels:",predict_labels)
label_dict=compute_confuse_matrix_batch(target_labels,predict_labels,label_dict,name='bert')
eval_loss, eval_counter = eval_loss + curr_eval_loss, eval_counter + 1
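
Note: this loop, like the training loop above, pairs range(0, n, batch_size) with range(batch_size, n, batch_size), which silently drops the final partial batch. A quick check:

    number_examples, batch_size = 10, 4
    print(list(zip(range(0, number_examples, batch_size),
                   range(batch_size, number_examples, batch_size))))
    # [(0, 4), (4, 8)] -- examples 8 and 9 are never visited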

f1_micro, f1_macro = compute_micro_macro(label_dict) # label_dict is a dict: key is an accusation label, value is (TP,FP,FN), where TP is the number of true positives
-f1_score = (f1_micro + f1_macro) / 2.0
-return eval_loss / float(eval_counter), f1_score, f1_micro, f1_macro
+f1_score_result = (f1_micro + f1_macro) / 2.0
+return eval_loss / float(eval_counter), f1_score_result, f1_micro, f1_macro
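
Note: compute_micro_macro aggregates the per-label (TP, FP, FN) counts two ways: micro-F1 pools the counts across labels before computing F1, while macro-F1 averages each label's own F1. A sketch consistent with that contract (the utils.py implementation may differ in detail); the local rename to f1_score_result presumably avoids shadowing the freshly imported sklearn f1_score:

    def micro_macro_f1(label_dict, epsilon=0.000001):
        def f1(tp, fp, fn):
            p = tp / (tp + fp + epsilon)   # precision
            r = tp / (tp + fn + epsilon)   # recall
            return 2.0 * p * r / (p + r + epsilon)
        tp_all = sum(v[0] for v in label_dict.values())
        fp_all = sum(v[1] for v in label_dict.values())
        fn_all = sum(v[2] for v in label_dict.values())
        f1_micro = f1(tp_all, fp_all, fn_all)  # pool counts across labels, then F1
        f1_macro = sum(f1(*v) for v in label_dict.values()) / len(label_dict)
        return f1_micro, f1_macro

    # micro weights frequent labels more heavily than macro:
    print(micro_macro_f1({"theft": (8, 2, 1), "fraud": (1, 0, 4)}))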

def get_input_mask_segment_ids(train_x_batch):
"""
3 changes: 2 additions & 1 deletion a00_Bert/utils.py
@@ -6,6 +6,8 @@
import numpy as np
import random

+random_number=500

def load_data(cache_file_h5py,cache_file_pickle):
"""
load data from h5py and pickle cache files, which are generated by running pre-processing.ipynb step by step
@@ -71,7 +73,6 @@ def compute_f1_score_removed(label_list_top5,eval_y):
f1_score=2.0*p_5*r_5/(p_5+r_5+0.000001)
return f1_score,p_5,r_5

-random_number=1000
def compute_confuse_matrix(target_y,predict_y,label_dict,name='default'):
"""
compute true positive (TP), false positive (FP), and false negative (FN) counts given target labels and predicted labels
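
Note: the body of compute_confuse_matrix is collapsed in this diff. Per the docstring it maintains one (TP, FP, FN) triple per label; a plausible sketch of that accumulation for a single example (an assumption, not the repository's exact code):

    def update_confuse_matrix(target_y, predict_y, label_dict):
        # target_y / predict_y: collections of label ids for one example
        for label, (tp, fp, fn) in label_dict.items():
            if label in predict_y and label in target_y:
                tp += 1   # predicted and correct
            elif label in predict_y:
                fp += 1   # predicted but wrong
            elif label in target_y:
                fn += 1   # missed
            label_dict[label] = (tp, fp, fn)
        return label_dict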
