forked from chenyuntc/pytorch-book
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add chapter11-LSTM-CTC (chenyuntc#25)
新增由 @Diamondfan 编写的chapter11-语音识别(LSTM-CTC)
- Loading branch information
Showing
18 changed files
with
1,837 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
#encoding=utf-8 | ||
|
||
import math | ||
|
||
LOG_ZERO = -99999999.0 | ||
LOG_ONE = 0.0 | ||
|
||
class BeamEntry: | ||
"information about one single beam at specific time-step" | ||
def __init__(self): | ||
self.prTotal=LOG_ZERO # blank and non-blank | ||
self.prNonBlank=LOG_ZERO # non-blank | ||
self.prBlank=LOG_ZERO # blank | ||
self.y=() # labelling at current time-step | ||
|
||
class BeamState: | ||
"information about beams at specific time-step" | ||
def __init__(self): | ||
self.entries={} | ||
|
||
def norm(self): | ||
"length-normalise probabilities to avoid penalising long labellings" | ||
for (k,v) in self.entries.items(): | ||
labellingLen=len(self.entries[k].y) | ||
self.entries[k].prTotal=self.entries[k].prTotal*(1.0/(labellingLen if labellingLen else 1)) | ||
|
||
def sort(self): | ||
"return beams sorted by probability" | ||
u=[v for (k,v) in self.entries.items()] | ||
s=sorted(u, reverse=True, key=lambda x:x.prTotal) | ||
return [x.y for x in s] | ||
|
||
class ctcBeamSearch(object): | ||
def __init__(self, classes, beam_width, blank_index=0): | ||
self.classes = classes | ||
self.beamWidth = beam_width | ||
self.blank_index = blank_index | ||
|
||
def log_add_prob(self, log_x, log_y): | ||
if log_x <= LOG_ZERO: | ||
return log_y | ||
if log_y <= LOG_ZERO: | ||
return log_x | ||
if (log_y - log_x) > 0.0: | ||
log_y, log_x = log_x, log_y | ||
return log_x + math.log(1 + math.exp(log_y - log_x)) | ||
|
||
def calcExtPr(self, k, y, t, mat, beamState): | ||
"probability for extending labelling y to y+k" | ||
# optical model (RNN) | ||
if len(y) and y[-1]==k and mat[t-1, self.blank_index] < 0.9: | ||
return math.log(mat[t, k]) + beamState.entries[y].prBlank | ||
else: | ||
return math.log(mat[t, k]) + beamState.entries[y].prTotal | ||
|
||
def addLabelling(self, beamState, y): | ||
"adds labelling if it does not exist yet" | ||
if y not in beamState.entries: | ||
beamState.entries[y]=BeamEntry() | ||
|
||
def decode(self, inputs, inputs_list): | ||
''' | ||
Args: | ||
inputs(FloatTesnor) : Output of CTC(batch * timesteps * class) | ||
inputs_list(list) : the frames of each sample | ||
Returns: | ||
res(list) : Result of beam search | ||
''' | ||
batches, maxT, maxC = inputs.size() | ||
res = [] | ||
|
||
for batch in range(batches): | ||
mat = inputs[batch].numpy() | ||
# Initialise beam state | ||
last=BeamState() | ||
y=() | ||
last.entries[y]=BeamEntry() | ||
last.entries[y].prBlank=LOG_ONE | ||
last.entries[y].prTotal=LOG_ONE | ||
|
||
# go over all time-steps | ||
for t in range(inputs_list[batch]): | ||
curr=BeamState() | ||
|
||
#跳过概率很接近1的blank帧,增加解码速度 | ||
if (1 - mat[t, self.blank_index]) < 0.1: | ||
continue | ||
|
||
#取前beam个最好的结果 | ||
BHat=last.sort()[0:self.beamWidth] | ||
# go over best labellings | ||
for y in BHat: | ||
prNonBlank=LOG_ZERO | ||
# if nonempty labelling | ||
if len(y)>0: | ||
#相同的y两种可能,加入重复或者加入空白,如果之前没有字符,在NonBlank概率为0 | ||
prNonBlank=last.entries[y].prNonBlank + math.log(mat[t, y[-1]]) | ||
|
||
# calc probabilities | ||
prBlank = (last.entries[y].prTotal) + math.log(mat[t, self.blank_index]) | ||
# save result | ||
self.addLabelling(curr, y) | ||
curr.entries[y].y=y | ||
curr.entries[y].prNonBlank = self.log_add_prob(curr.entries[y].prNonBlank, prNonBlank) | ||
curr.entries[y].prBlank = self.log_add_prob(curr.entries[y].prBlank, prBlank) | ||
prTotal = self.log_add_prob(prBlank, prNonBlank) | ||
curr.entries[y].prTotal = self.log_add_prob(curr.entries[y].prTotal, prTotal) | ||
|
||
#t时刻加入其它的label,此时Blank的概率为0,如果加入的label与最后一个相同,因为不能重复,所以上一个字符一定是blank | ||
for k in range(maxC): | ||
if k != self.blank_index: | ||
newY=y+(k,) | ||
prNonBlank=self.calcExtPr(k, y, t, mat, last) | ||
|
||
# save result | ||
self.addLabelling(curr, newY) | ||
curr.entries[newY].y=newY | ||
curr.entries[newY].prNonBlank = self.log_add_prob(curr.entries[newY].prNonBlank, prNonBlank) | ||
curr.entries[newY].prTotal = self.log_add_prob(curr.entries[newY].prTotal, prNonBlank) | ||
|
||
# set new beam state | ||
last=curr | ||
|
||
# normalise probabilities according to labelling length | ||
last.norm() | ||
|
||
# sort by probability | ||
bestLabelling=last.sort()[0] # get most probable labelling | ||
|
||
# map labels to chars | ||
res_b =''.join([self.classes[l] for l in bestLabelling]) | ||
res.append(res_b) | ||
return res | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
本章内容是通过pytorch搭建一个LSTM-CTC的语音识别声学模型。 | ||
|
||
本次实验的数据为TIMIT数据集(可点击[academictorrents](http://academictorrents.com/details/34e2b78745138186976cbc27939b1b34d18bd5b3/tech) 或者 [luojie1987/TIMIT](http://luojie1987.com/index.php/post/110.html) 下载数据集)。 | ||
|
||
还有很多其他公开的语音相关的数据库可以在这里下载[Open Speech and Language Resources](http://www.openslr.org/resources.php) | ||
|
||
本项目的内容大多参考项目[https://github.com/Diamondfan/CTC_pytorch/](https://github.com/Diamondfan/CTC_pytorch/) | ||
|
||
## 环境配置 | ||
|
||
- 安装[PyTorch](http://pytorch.org) | ||
- 安装百度的Warp-CTC并于pytorch绑定,具体参见 | ||
[https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding](https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding) | ||
- 安装pytorch audio: | ||
```Bash | ||
sudo apt-get install sox libsox-dev libsox-fmt-all | ||
git clone https://github.com/pytorch/audio.git | ||
cd audio | ||
pip install cffi | ||
python setup.py install | ||
``` | ||
- 安装第三方依赖 | ||
```Python | ||
pip install -r requirements.txt | ||
``` | ||
- 启动visdom | ||
``` | ||
python -m visdom.serber | ||
``` | ||
|
||
## 使用方法: | ||
1、打开顶层脚本run.sh,修改相应的文件路径(TIMIT_dir, CONF_FILE)。 | ||
2、打开conf目录下的ctc_model_setting.conf进行网络结构等各项设置。 | ||
3、运行顶层脚本,后面带有一个参数stage,0表示从数据开始运行,1表示从训练开始,2表示直接测试 | ||
``` | ||
- bash run.sh 0 数据处理 + 训练 + 测试 | ||
- bash run.sh 1 训练 + 测试 | ||
- bash run.sh 2 测试 | ||
``` | ||
|
||
## 说明 | ||
### TIMIT数据准备 | ||
|
||
conf目录下的test\_spk.list和dev\_spk.list是音素识别中的常用验证集和测试集,使用这两个文件选取数据集。 | ||
|
||
执行数据处理脚本获取数据路径和转写标签: | ||
```Bash | ||
bash timit_data_prep.sh timit_dir | ||
``` | ||
执行完成后,在datap_repare目录下会生成wav.scp文件和text文件分别为音频的路径和音频的转写即文本标签. | ||
- train_wav.scp train.text > 3696 sentences | ||
- dev_wav.scp dev.text > 400 sentences | ||
- test_wav.scp test.text > 192 snetences | ||
|
||
### 关于rnn_type | ||
ctc_model_setting.conf中的rnn_type可以选择的RNN类型为 | ||
- lstm : nn.LSTM | ||
- rnn : nn.RNN | ||
- gru : nn.GRU | ||
|
||
### 关于标签 | ||
本项目的输出建模单元选择的是字符,即"abcdefghijklmnopqrstuvwxyz'" + " ",空格space也当做一个输出标签。所以总共28 + 1 = 29类。 | ||
|
||
加的1类为CTC中的空白类,表示该时刻的语音输出为噪声或者没有意义。在model.py中已经加了1,所以配置文件中填入正常的标签类别即可。 | ||
|
||
选择字符作为标签在小数据集上并不能得到很好的结果,比如在timit上仅有62%左右的正确率。实验发现采用音素作为输出的建模单元更为有效。 | ||
|
||
### 关于学习率修改 | ||
默认修改8次学习率停止训练,每次学习率降低一半。可以根据需要修改train.py(line 274) | ||
|
||
### log目录内容 | ||
- *.pkl: 保存的模型数据 在model.py(line 132)查看save_package函数 | ||
- train.log: 训练时打印的内容都存在train.log中 | ||
|
||
### 实验结果 | ||
<p align="center"> | ||
<img src="png/train_loss.png" width="200"> | ||
<img src="png/dev_loss.png" width="200"> | ||
<img src="png/dev_acc.png" width="200"> | ||
</p> | ||
|
||
将字符作为标签训练CTC的声学模型在TIMIT上测试集的识别率为: | ||
- Greedy decoder: 61.4831% | ||
- Beam decoder : 62.1029% | ||
|
||
本章内容只是构建了一个简单的声学模型,能够真正识别相差甚远,相比于kaldi中复杂的流程。项目内容还是能够对语音识别任务有一个初步的认识。 | ||
|
||
### 参考文献: | ||
- [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](http://www.cs.toronto.edu/~graves/icml_2006.pdf) | ||
- [Supervised Sequence Labelling with Recurrent Neural Networks](https://link.springer.com/book/10.1007/978-3-642-24797-2) | ||
- [EESEN: END-TO-END SPEECH RECOGNITION USING DEEP RNN MODELS AND WFST-BASED DECODING](http://www.cs.cmu.edu/afs/cs/Web/People/fmetze/interACT/Publications_files/publications/eesenasru.pdf) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
[Data] | ||
data_dir = /home/pytorch-book/chapter11-语音识别(LSTM-CTC)/data_prepare | ||
log_dir = /home/pytorch-book/chapter11-语音识别(LSTM-CTC)/log/ | ||
log_file = train.log | ||
|
||
[Model] | ||
rnn_input_size = 201 | ||
rnn_hidden_size = 256 | ||
rnn_layers = 4 | ||
rnn_type = lstm | ||
bidirectional = True | ||
batch_norm = True | ||
num_class = 28 | ||
drop_out = 0 | ||
model_file = | ||
|
||
[Training] | ||
use_cuda = True | ||
init_lr = 0.001 | ||
num_epoches = 2 | ||
end_adjust_acc = 2 | ||
lr_decay = 0.5 | ||
batch_size = 16 | ||
weight_decay = 0.05 | ||
seed = 1234 | ||
|
||
[Decode] | ||
decoder_type = Beam | ||
beam_width = 20 | ||
eval_dataset = test | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
faks0 | ||
fdac1 | ||
fjem0 | ||
mgwt0 | ||
mjar0 | ||
mmdb1 | ||
mmdm2 | ||
mpdf0 | ||
fcmh0 | ||
fkms0 | ||
mbdg0 | ||
mbwm0 | ||
mcsh0 | ||
fadg0 | ||
fdms0 | ||
fedw0 | ||
mgjf0 | ||
mglb0 | ||
mrtk0 | ||
mtaa0 | ||
mtdt0 | ||
mthc0 | ||
mwjg0 | ||
fnmr0 | ||
frew0 | ||
fsem0 | ||
mbns0 | ||
mmjr0 | ||
mdls0 | ||
mdlf0 | ||
mdvc0 | ||
mers0 | ||
fmah0 | ||
fdrw0 | ||
mrcs0 | ||
mrjm4 | ||
fcal1 | ||
mmwh0 | ||
fjsj0 | ||
majc0 | ||
mjsw0 | ||
mreb0 | ||
fgjd0 | ||
fjmg0 | ||
mroa0 | ||
mteb0 | ||
mjfc0 | ||
mrjr0 | ||
fmml0 | ||
mrws1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
mdab0 | ||
mwbt0 | ||
felc0 | ||
mtas1 | ||
mwew0 | ||
fpas0 | ||
mjmp0 | ||
mlnt0 | ||
fpkt0 | ||
mlll0 | ||
mtls0 | ||
fjlm0 | ||
mbpm0 | ||
mklt0 | ||
fnlp0 | ||
mcmj0 | ||
mjdh0 | ||
fmgd0 | ||
mgrt0 | ||
mnjm0 | ||
fdhc0 | ||
mjln0 | ||
mpam0 | ||
fmld0 |
Oops, something went wrong.