# eval.py (forked from yeyupiaoling/PaddlePaddle-DeepSpeech)
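"""Evaluate a trained DeepSpeech2 model on a test manifest.

Runs batched inference, decodes with either CTC greedy decoding or CTC beam
search (optionally combined with an external language model), and reports the
character error rate (cer) or word error rate (wer).
"""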
import argparse
import functools
import time

import paddle

from data_utils.data import DataGenerator
from decoders.ctc_greedy_decoder import greedy_decoder_batch
from model_utils.model import DeepSpeech2Model
from utils.error_rate import char_errors, word_errors
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use the GPU for evaluation")
add_arg('batch_size', int, 64, "Batch size for evaluation")
add_arg('num_conv_layers', int, 2, "Number of convolution layers")
add_arg('num_rnn_layers', int, 3, "Number of RNN layers")
add_arg('rnn_layer_size', int, 1024, "Hidden size of each RNN layer")
add_arg('beam_size', int, 100, "Beam search parameter: beam width, in the range [5, 500]")
add_arg('alpha', float, 1.2, "Beam search parameter: language model (LM) weight")
add_arg('num_proc_bsearch', int, 8, "Beam search parameter: number of CPU processes")
add_arg('beta', float, 0.35, "Beam search parameter: word count (WC) weight")
add_arg('cutoff_prob', float, 0.99, "Beam search parameter: cumulative probability cutoff for pruning")
add_arg('cutoff_top_n', int, 40, "Beam search parameter: keep at most the top N tokens when pruning")
add_arg('test_manifest', str, './dataset/manifest.train', "Manifest listing the data to evaluate")
add_arg('mean_std_path', str, './dataset/mean_std.npz', "Path to the dataset's mean/std .npz file")
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "Path to the dataset's vocabulary file")
add_arg('resume_model', str, './models/param/50.pdparams', "Path of the model checkpoint to restore")
add_arg('lang_model_path', str, './lm/zh_giga.no_cna_cmn.prune01244.klm', "Beam search parameter: path of the language model file")
add_arg('decoding_method', str, 'ctc_greedy', "Decoding method: beam search (ctc_beam_search) or greedy (ctc_greedy)", choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type', str, 'cer', "Error-rate metric: character error rate (cer) or word error rate (wer)", choices=['wer', 'cer'])
args = parser.parse_args()
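
# Example invocation (paths and values are illustrative; adjust to your setup):
#   python eval.py --decoding_method=ctc_greedy --batch_size=64 \
#                  --test_manifest=./dataset/manifest.test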


# Evaluate the model
def evaluate():
    # Choose the execution device
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    # Create the data generator
    data_generator = DataGenerator(vocab_filepath=args.vocab_path,
                                   mean_std_filepath=args.mean_std_path,
                                   keep_transcription_text=True,
                                   place=place,
                                   is_training=False)
    # Create the batch reader for the evaluation data
    batch_reader = data_generator.batch_reader_creator(manifest_path=args.test_manifest,
                                                       batch_size=args.batch_size,
                                                       shuffle_method=None)
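    # Note: shuffle_method=None presumably disables shuffling, so utterances
    # are read in manifest order and evaluation runs are repeatable.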
    # Build the DeepSpeech2 model and set it up for inference
    ds2_model = DeepSpeech2Model(vocab_size=data_generator.vocab_size,
                                 num_conv_layers=args.num_conv_layers,
                                 num_rnn_layers=args.num_rnn_layers,
                                 rnn_layer_size=args.rnn_layer_size,
                                 place=place,
                                 resume_model=args.resume_model)
    # Count the utterances in the test manifest
    with open(args.test_manifest, 'r', encoding='utf-8') as f_m:
        test_len = len(f_m.readlines())
    # Set up the beam search decoder if requested
    if args.decoding_method == "ctc_beam_search":
        try:
            from decoders.beam_search_decoder import BeamSearchDecoder
            beam_search_decoder = BeamSearchDecoder(args.alpha, args.beta, args.lang_model_path,
                                                    data_generator.vocab_list)
        except ModuleNotFoundError:
            raise Exception('The swig_decoders package is missing; install it as described in the '
                            'documentation. On Windows, use ctc_greedy instead.')
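    # Note: BeamSearchDecoder is assumed to wrap the compiled swig_decoders
    # extension and to score hypotheses with the KenLM model given by
    # --lang_model_path, which is why the import fails when that extension
    # has not been built.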
    # Pick the error-rate function: character error rate or word error rate
    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
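    # Both functions are assumed to return (edit_distance, reference_length)
    # for a single (target, result) pair; see the accumulation loop below.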
    errors_sum, len_refs, num_ins = 0.0, 0, 0
    ds2_model.logger.info("Starting evaluation ...")
    start = time.time()
    # Run the evaluation
    for infer_data in batch_reader():
        # Get the model's output probabilities for one batch
        probs_split = ds2_model.infer_batch_data(infer_data=infer_data)
        # Decode
        if args.decoding_method == 'ctc_greedy':
            # Greedy decoding
            result_transcripts = greedy_decoder_batch(probs_split=probs_split,
                                                      vocabulary=data_generator.vocab_list)
        else:
            # Beam search decoding
            result_transcripts = beam_search_decoder.decode_batch_beam_search(probs_split=probs_split,
                                                                              beam_alpha=args.alpha,
                                                                              beam_beta=args.beta,
                                                                              beam_size=args.beam_size,
                                                                              cutoff_prob=args.cutoff_prob,
                                                                              cutoff_top_n=args.cutoff_top_n,
                                                                              vocab_list=data_generator.vocab_list,
                                                                              num_processes=args.num_proc_bsearch)
        # Ground-truth transcripts
        target_transcripts = infer_data[1]
        # Accumulate the error counts; the running rate is total edit distance
        # divided by total reference length, not by the number of utterances
        for target, result in zip(target_transcripts, result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
        print("Error rate: [%s] (%d/%d) = %f" % (args.error_rate_type, num_ins, test_len, errors_sum / len_refs))
    end = time.time()
    print("Elapsed time: %ds, overall error rate: [%s] (%d/%d) = %f" %
          ((end - start), args.error_rate_type, num_ins, test_len, errors_sum / len_refs))
    ds2_model.logger.info("Evaluation finished!")


def main():
    print_arguments(args)
    evaluate()


if __name__ == '__main__':
    main()