Skip to content

Commit

Permalink
[new] add show-result-list for confusion_matrix (fastnlp#309)
Browse files Browse the repository at this point in the history
add ConfusionMatrix and allow choose part of the column
  • Loading branch information
ROGERDJQ authored Jul 10, 2020
1 parent c957ed6 commit 228cca4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
2 changes: 2 additions & 0 deletions fastNLP/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def __init__(self,
pred=None,
target=None,
seq_len=None,
show_result=None,
print_ratio=False
):
r"""
Expand All @@ -326,6 +327,7 @@ def __init__(self,
super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.confusion_matrix = ConfusionMatrix(
show_result=show_result,
vocab=vocab,
print_ratio=print_ratio,
)
Expand Down
26 changes: 21 additions & 5 deletions fastNLP/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@

class ConfusionMatrix:
r"""a dict can provide Confusion Matrix"""
def __init__(self, vocab=None, print_ratio=False):
def __init__(self, show_result=None,vocab=None, print_ratio=False):
r"""
:param show_result: list type, 数据类型需要和target保持一致
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。
:param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列
"""
Expand All @@ -48,6 +49,7 @@ def __init__(self, vocab=None, print_ratio=False):
self.confusiondict = {} # key: pred index, value:target word ocunt
self.predcount = {} # key:pred index, value:count
self.targetcount = {} # key:target index, value:count
self.show_result = show_result
self.vocab = vocab
self.print_ratio = print_ratio

Expand Down Expand Up @@ -153,16 +155,16 @@ def get_aligned_table(self, data, flag="result"):
set(self.targetcount.keys()).union(set(
self.predcount.keys()))))
lenth = len(totallabel)
# namedict key :idx value:word/idx
# namedict key :label idx value: str label name/label idx
namedict = dict([
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel
])
for label, idx in zip(totallabel, range(lenth)):
for label, lineidx in zip(totallabel, range(lenth)):
idx2row[
label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
label] = lineidx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
row2idx[
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
lineidx] = label # 建立一个临时字典,key: 行列index 0,1,2...->1,3,5,...,value:vocab的index,
# 这里打印东西
out = str()
output = []
Expand All @@ -183,13 +185,26 @@ def get_aligned_table(self, data, flag="result"):
for idx in range(len(col_lenths))
]
output.append(l)

tail = ["all"] + [[str(n) + "%", str(n)][flag == "result"]
for n in data[-1]]
col_lenths = [
max(col_lenths[idx], [len(i) for i in tail][idx])
for idx in range(len(col_lenths))
]
output.append(tail)

if self.show_result:
missing_item=[]
missing_item = [i for i in self.show_result if i not in idx2row]
self.show_result = [i for i in self.show_result if i in idx2row]
if missing_item:
print(f"Noticing label(s) which is/are not in target list appeared, final output string will not contain{str(missing_item)}")
if self.show_result:
show_col = [0] + [i + 1 for i in [idx2row[i] for i in self.show_result]]
show_row = [0]+[i+2 for i in [idx2row[i] for i in self.show_result]]
output = [[row[col] for col in show_col] for row in [output[row] for row in show_row]]
output.insert(1,["pred"])
for line in output:
for colidx in range(len(line)):
out += "%*s" % (col_lenths[colidx], line[colidx]) + "\t"
Expand Down Expand Up @@ -217,6 +232,7 @@ def __repr__(self):
return out



class Option(dict):
r"""a dict can treat keys as attributes"""

Expand Down

0 comments on commit 228cca4

Please sign in to comment.