Skip to content

Commit

Permalink
🙊 fix macro f1
Browse files Browse the repository at this point in the history
  • Loading branch information
iofu728 committed Apr 11, 2019
1 parent 0134a06 commit a6715ea
Showing 1 changed file with 60 additions and 199 deletions.
259 changes: 60 additions & 199 deletions a02_TextCNN/p7_TextCNN_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import h5py
import os
import random
from numba import jit
#configuration
FLAGS=tf.app.flags.FLAGS

Expand Down Expand Up @@ -54,8 +55,7 @@ def main(_):
#print some message for debug purpose
print("trainX[0:10]:", trainX[0:10])
print("trainY[0]:", trainY[0:10])
train_y_short = get_target_label_short(trainY[0])
print("train_y_short:", train_y_short)
print("train_y_short:", trainY[0])

#2.create session.
config=tf.ConfigProto()
Expand Down Expand Up @@ -128,203 +128,64 @@ def main(_):


# 在验证集上做验证,报告损失、精确度
def do_eval(sess,textCNN,evalX,evalY,num_classes):
evalX=evalX[0:3000]
evalY=evalY[0:3000]
number_examples=len(evalX)
eval_loss,eval_counter,eval_f1_score,eval_p,eval_r=0.0,0,0.0,0.0,0.0
batch_size=1
label_dict_confuse_matrix=init_label_dict(num_classes)
for start,end in zip(range(0,number_examples,batch_size),range(batch_size,number_examples,batch_size)):
feed_dict = {textCNN.input_x: evalX[start:end], textCNN.input_y_multilabel:evalY[start:end],textCNN.dropout_keep_prob: 1.0,
def do_eval(sess, textCNN, evalX, evalY, num_classes):
evalX = evalX[0:3000]
evalY = evalY[0:3000]
number_examples = len(evalX)
eval_loss, eval_counter, eval_f1_score, eval_p, eval_r = 0.0, 0, 0.0, 0.0, 0.0
batch_size = FLAGS.batch_size
predict = []

for start, end in zip(range(0, number_examples, batch_size), range(batch_size, number_examples, batch_size)):
''' evaluation in one batch '''
feed_dict = {textCNN.input_x: evalX[start:end], textCNN.input_y_multilabel: evalY[start:end], textCNN.dropout_keep_prob: 1.0,
textCNN.is_training_flag: False}
curr_eval_loss, logits= sess.run([textCNN.loss_val,textCNN.logits],feed_dict)#curr_eval_acc--->textCNN.accuracy
predict_y = get_label_using_logits(logits[0])
target_y= get_target_label_short(evalY[start:end][0])
#f1_score,p,r=compute_f1_score(list(label_list_top5), evalY[start:end][0])
label_dict_confuse_matrix=compute_confuse_matrix(target_y, predict_y, label_dict_confuse_matrix)
eval_loss,eval_counter=eval_loss+curr_eval_loss,eval_counter+1

f1_micro,f1_macro=compute_micro_macro(label_dict_confuse_matrix) #label_dict_accusation is a dict, key is: accusation,value is: (TP,FP,FN). where TP is number of True Positive
f1_score=(f1_micro+f1_macro)/2.0
return eval_loss/float(eval_counter),f1_score,f1_micro,f1_macro

#######################################
def compute_f1_score(predict_y, eval_y):
    """
    Placeholder F1 computation — always reports zero scores.
    :param predict_y: [batch_size, label_size]
    :param eval_y: [batch_size, label_size]
    :return: (f1_score, p_5, r_5), all 0.0
    """
    return 0.0, 0.0, 0.0

def compute_f1_score_removed(label_list_top5, eval_y):
    """
    F1 / P@5 / R@5 for one example's top-5 predicted labels (legacy helper).
    :param label_list_top5: list of the top-5 predicted label indices
    :param eval_y: multi-hot ground-truth label vector
    :return: (f1_score, p_5, r_5)
    """
    eval_y_short = get_target_label_short(eval_y)
    # how many of the predicted labels are actually true labels
    num_correct_label = sum(1 for label_predict in label_list_top5
                            if label_predict in eval_y_short)
    # P@5 = Precision@5, R@5 = Recall@5
    p_5 = num_correct_label / len(label_list_top5)
    r_5 = num_correct_label / len(eval_y_short)
    f1_score = 2.0 * p_5 * r_5 / (p_5 + r_5 + 0.000001)
    return f1_score, p_5, r_5

random_number = 1000
def compute_confuse_matrix(target_y, predict_y, label_dict, name='default'):
    """
    Accumulate true positive (TP), false positive (FP) and false negative (FN)
    counts per label, given one example's target labels and predicted labels.
    :param target_y: list of true label indices, e.g. [22, 642]
    :param predict_y: list of predicted label indices
    :param label_dict: running counts, {label: (TP, FP, FN)}
    :param name: tag printed with the occasional debug sample
    :return: label_dict, updated in place and returned
    """
    # occasionally (~1/random_number) print a sample for debugging;
    # randrange avoids materializing a 1000-element list per call
    if random.randrange(random_number) == 1:
        print(name + ".target_y:", target_y, ";predict_y:", predict_y)

    # count TP/FP/FN only for labels present in either list; labels absent
    # from both are true negatives and leave their counts untouched
    for label in set(target_y) | set(predict_y):
        TP, FP, FN = label_dict[label]
        if label in predict_y and label in target_y:        # predict=1, truth=1 (TP)
            TP = TP + 1
        elif label in predict_y and label not in target_y:  # predict=1, truth=0 (FP)
            FP = FP + 1
        elif label not in predict_y and label in target_y:  # predict=0, truth=1 (FN)
            FN = FN + 1
        label_dict[label] = (TP, FP, FN)
    return label_dict

def compute_micro_macro(label_dict):
    """
    Derive micro- and macro-averaged F1 from accumulated per-label counts.
    :param label_dict: {label: (TP, FP, FN)}
    :return: (f1_micro, f1_macro) — two scalars
    """
    return (compute_f1_micro_use_TFFPFN(label_dict),
            compute_f1_macro_use_TFFPFN(label_dict))

def compute_TF_FP_FN_micro(label_dict):
    """
    Sum TP/FP/FN counts over all labels to get the micro totals.
    :param label_dict: {label: (TP, FP, FN)}
    :return: (TP_micro, FP_micro, FN_micro)
    """
    TP_micro, FP_micro, FN_micro = 0.0, 0.0, 0.0
    for TP, FP, FN in label_dict.values():
        TP_micro += TP
        FP_micro += FP
        FN_micro += FN
    return TP_micro, FP_micro, FN_micro
def compute_f1_micro_use_TFFPFN(label_dict):
    """
    Micro-averaged F1: pool TP/FP/FN over all labels, then compute one F1.
    :param label_dict: {label: (TP, FP, FN)}
    :return: f1_micro: a scalar
    """
    tp, fp, fn = compute_TF_FP_FN_micro(label_dict)
    return compute_f1(tp, fp, fn, 'micro')

def compute_f1_macro_use_TFFPFN(label_dict):
    """
    Macro-averaged F1: compute F1 per label, then average over all labels.
    :param label_dict: {label: (TP, FP, FN)}
    :return: f1_macro: a scalar
    """
    per_label_f1 = [compute_f1(TP, FP, FN, 'macro')
                    for TP, FP, FN in label_dict.values()]
    return sum(per_label_f1) / float(len(label_dict))

small_value = 0.00001
def compute_f1(TP, FP, FN, compute_type):
    """
    Compute an F1 score from raw counts.
    :param TP: number of true positives, e.g. 200
    :param FP: number of false positives
    :param FN: number of false negatives
    :param compute_type: tag ('micro'/'macro') used only in the debug print
    :return: f1_score: a scalar
    """
    # small_value guards every division against a zero denominator
    precison = TP / (TP + FP + small_value)
    recall = TP / (TP + FN + small_value)
    f1_score = (2 * precison * recall) / (precison + recall + small_value)

    # occasionally (~1/500) print a sample for debugging;
    # randrange avoids materializing a 500-element list per call
    if random.randrange(500) == 1:
        print(compute_type, "precison:", str(precison), ";recall:", str(recall), ";f1_score:", f1_score)

    return f1_score
def init_label_dict(num_classes):
    """
    Build the counter dict used to accumulate TP, FP, FN per class.
    :param num_classes: number of label classes
    :return: label_dict: {label_index: (0, 0, 0)}
    """
    return {label: (0, 0, 0) for label in range(num_classes)}

def get_target_label_short(eval_y):
    """Convert a multi-hot label vector into the list of active label indices, e.g. [22, 642, 1391]."""
    return [index for index, label in enumerate(eval_y) if label > 0]

#get predicted labels from logits (threshold-based, with argmax fallback)
def get_label_using_logits(logits, top_number=5, threshold=0.50):
    """
    Turn a per-label score vector into the list of predicted label indices.
    Every index whose score is >= `threshold` is predicted; if none clears
    the threshold, fall back to the single highest-scoring label so at least
    one label is always returned.
    :param logits: 1-D sequence of per-label scores
    :param top_number: unused; kept for backward compatibility with callers
    :param threshold: decision threshold (default 0.50, the original hard-coded value)
    :return: list of predicted label indices, e.g. [2, 12, 13]
    """
    y_predict_labels = [i for i, score in enumerate(logits) if score >= threshold]
    if not y_predict_labels:
        y_predict_labels = [np.argmax(logits)]
    return y_predict_labels

#compute prediction accuracy
def calculate_accuracy(labels_predicted, labels, eval_counter):
    """
    Fraction of predicted labels that hit a true label, divided by the full
    label-vector length.
    NOTE(review): the denominator is len(labels) (the vector length), not the
    number of true labels — confirm this is the intended metric.
    :param labels_predicted: list of predicted label indices
    :param labels: multi-hot ground-truth label vector
    :param eval_counter: batch counter; the first two batches print a debug sample
    :return: hit count / len(labels)
    """
    labels = list(labels)
    # indices of the active (non-zero) ground-truth labels
    label_nozero = [index for index, label in enumerate(labels) if label > 0]
    if eval_counter < 2:
        print("labels_predicted:", labels_predicted, " ;labels_nozero:", label_nozero)
    # a real set replaces the original dict-as-set ({x: x ...} + .get) lookup
    true_label_set = set(label_nozero)
    count = sum(1 for label_predict in labels_predicted
                if label_predict in true_label_set)
    return count / len(labels)

##################################################
current_eval_loss, logits = sess.run(
[textCNN.loss_val, textCNN.logits], feed_dict)
predict += logits[0]
eval_loss += current_eval_loss
eval_counter += 1

if not FLAGS.multi_label_flag:
predict = [int(ii > 0.5) for ii in predict]
_, _, f1_macro, f1_micro, _ = fastF1(predict, evalY)
f1_score = (f1_micro+f1_macro)/2.0
return eval_loss/float(eval_counter), f1_score, f1_micro, f1_macro

@jit
def fastF1(result, predict, num_classes=6):
    """
    Per-class P/R/F1 plus micro- and macro-averaged F1 for single-label output.
    :param result: list of predicted class ids, one per example
    :param predict: list of gold class ids, one per example
                    NOTE(review): the two names look swapped relative to their
                    roles (do_eval passes predictions first, gold second);
                    kept as-is to preserve the call sites.
    :param num_classes: number of classes (default 6, the original hard-coded value)
    :return: (macro_p, macro_r, macro_f1, micro_f1, per_class) where per_class
             is a list of [P, R, f1] per class
    """
    true_total, r_total, p_total, p, r = 0, 0, 0, 0, 0
    total_list = []
    for trueValue in range(num_classes):
        trueNum, recallNum, precisionNum = 0, 0, 0
        for index, values in enumerate(result):
            if values == trueValue:
                recallNum += 1
                if values == predict[index]:
                    trueNum += 1
            if predict[index] == trueValue:
                precisionNum += 1
        R = trueNum / recallNum if recallNum else 0
        P = trueNum / precisionNum if precisionNum else 0
        true_total += trueNum
        r_total += recallNum
        p_total += precisionNum
        p += P
        r += R
        f1 = (2 * P * R) / (P + R) if (P + R) else 0
        # NOTE(review): id2rela (class id -> name) is not defined in this diff —
        # confirm it is in scope at module level
        print(id2rela[trueValue], P, R, f1)
        total_list.append([P, R, f1])
    p /= num_classes
    r /= num_classes
    micro_r = true_total / r_total
    micro_p = true_total / p_total
    macro_f1 = (2 * p * r) / (p + r) if (p + r) else 0
    micro_f1 = (2 * micro_p * micro_r) / (micro_p +
                                          micro_r) if (micro_p + micro_r) else 0
    print('P: {:.2f}%, R: {:.2f}%, Micro_f1: {:.2f}%, Macro_f1: {:.2f}%'.format(
        p*100, r*100, micro_f1 * 100, macro_f1*100))
    # BUG FIX: original returned the undefined name `total_lists` (NameError);
    # the accumulated variable is `total_list`
    return p, r, macro_f1, micro_f1, total_list

def assign_pretrained_word_embedding(sess,vocabulary_index2word,vocab_size,textCNN,word2vec_model_path):
import word2vec # we put import here so that many people who do not use word2vec do not need to install this package. you can move import to the beginning of this file.
Expand Down Expand Up @@ -390,4 +251,4 @@ def load_data(cache_file_h5py,cache_file_pickle):
print("INFO. cache file load successful...")
return word2index, label2index,train_X,train_Y,vaild_X,valid_Y,test_X,test_Y
if __name__ == "__main__":
tf.app.run()
tf.app.run()

0 comments on commit a6715ea

Please sign in to comment.