Skip to content

Commit

Permalink
Merge pull request X-LANCE#48 from ddlBoJack/dev-zzasdf
Browse files Browse the repository at this point in the history
Better demonstrate the differences between ref and hyp
  • Loading branch information
ddlBoJack authored Mar 26, 2024
2 parents 2a72707 + 16e5efb commit 1d9f8bb
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/slam_llm/utils/compute_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,39 @@
import numpy as np
import sys

def _format_gap(ref_words, hyp_words):
    """Render one unaligned stretch as ``(ref->hyp)``; an empty side is shown as ``*``."""
    ref_text = ' '.join(ref_words) or "*"
    hyp_text = ' '.join(hyp_words) or "*"
    return f"({ref_text}->{hyp_text})"


def build_diff(ref, hyp, path):
    """Build a one-line, human-readable diff between a reference and a hypothesis.

    Tokens at the positions listed in ``path`` are emitted as-is (taken from
    the lower-cased reference). Every stretch of tokens lying between two
    consecutive ``path`` entries, before the first entry, or after the last
    one, is emitted as ``(ref tokens->hyp tokens)``; a side with no tokens is
    written as ``*``.

    Args:
        ref: Sequence of reference tokens (strings).
        hyp: Sequence of hypothesis tokens (strings).
        path: Iterable of ``(ref_index, hyp_index)`` pairs marking tokens that
            are treated as aligned. Assumed to be sorted in increasing order
            on both indices -- TODO confirm against the producer of ``path``.

    Returns:
        The diff as a single space-separated string. Empty when ``ref``,
        ``hyp`` and ``path`` are all empty.
    """
    # Compare and display case-insensitively.
    ref = [word.lower() for word in ref]
    hyp = [word.lower() for word in hyp]

    pieces = []
    # Index of the most recently consumed token on each side; -1 means
    # "nothing consumed yet", so a leading gap is detected like any other.
    prev_ref = -1
    prev_hyp = -1

    for ref_idx, hyp_idx in path:
        # If either side skipped tokens since the previous aligned pair,
        # report the skipped stretch before emitting the aligned token.
        if ref_idx != prev_ref + 1 or hyp_idx != prev_hyp + 1:
            pieces.append(
                _format_gap(ref[prev_ref + 1:ref_idx], hyp[prev_hyp + 1:hyp_idx])
            )
        pieces.append(ref[ref_idx])
        prev_ref, prev_hyp = ref_idx, hyp_idx

    # Tokens left over on either side after the last aligned pair.
    if prev_ref < len(ref) - 1 or prev_hyp < len(hyp) - 1:
        pieces.append(_format_gap(ref[prev_ref + 1:], hyp[prev_hyp + 1:]))

    return ' '.join(pieces)






def compute_wer(ref_file,
hyp_file,
cer_detail_file):
Expand Down Expand Up @@ -51,6 +84,7 @@ def compute_wer(ref_file,
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
cer_detail_writer.write("diff:" + '\t' + build_diff(ref_dict[hyp_key], hyp_dict[hyp_key], out_item['path']) + '\n')

if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
Expand Down Expand Up @@ -106,7 +140,8 @@ def compute_wer_by_line(hyp,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
'sub': 0,
'path': []
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
Expand Down Expand Up @@ -141,6 +176,7 @@ def compute_wer_by_line(hyp,
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
rst['path'] = match_idx

return rst

Expand Down

0 comments on commit 1d9f8bb

Please sign in to comment.