
Commit

Add chapter11-LSTM-CTC (chenyuntc#25)
Add chapter11-语音识别(LSTM-CTC), written by @Diamondfan
范汝超 authored and chenyuntc committed Mar 14, 2018
1 parent a005f9f commit bcc8f48
Showing 18 changed files with 1,837 additions and 0 deletions.
134 changes: 134 additions & 0 deletions chapter11-语音识别(LSTM-CTC)/BeamSearch.py
@@ -0,0 +1,134 @@
#encoding=utf-8

import math

LOG_ZERO = -99999999.0
LOG_ONE = 0.0

class BeamEntry:
    "information about a single beam at a specific time-step"
    def __init__(self):
        self.prTotal=LOG_ZERO    # blank and non-blank
        self.prNonBlank=LOG_ZERO # non-blank
        self.prBlank=LOG_ZERO    # blank
        self.y=()                # labelling at current time-step

class BeamState:
    "information about the beams at a specific time-step"
    def __init__(self):
        self.entries={}

    def norm(self):
        "length-normalise probabilities to avoid penalising long labellings"
        for (k,v) in self.entries.items():
            labellingLen=len(self.entries[k].y)
            self.entries[k].prTotal=self.entries[k].prTotal*(1.0/(labellingLen if labellingLen else 1))

    def sort(self):
        "return beams sorted by probability"
        u=[v for (k,v) in self.entries.items()]
        s=sorted(u, reverse=True, key=lambda x: x.prTotal)
        return [x.y for x in s]

class ctcBeamSearch(object):
    def __init__(self, classes, beam_width, blank_index=0):
        self.classes = classes
        self.beamWidth = beam_width
        self.blank_index = blank_index

    def log_add_prob(self, log_x, log_y):
        "add two probabilities given in the log domain"
        if log_x <= LOG_ZERO:
            return log_y
        if log_y <= LOG_ZERO:
            return log_x
        if (log_y - log_x) > 0.0:
            log_y, log_x = log_x, log_y
        return log_x + math.log(1 + math.exp(log_y - log_x))

    def calcExtPr(self, k, y, t, mat, beamState):
        "probability for extending labelling y to y+k"
        # acoustic model (RNN) output
        if len(y) and y[-1]==k and mat[t-1, self.blank_index] < 0.9:
            return math.log(mat[t, k]) + beamState.entries[y].prBlank
        else:
            return math.log(mat[t, k]) + beamState.entries[y].prTotal

    def addLabelling(self, beamState, y):
        "adds labelling if it does not exist yet"
        if y not in beamState.entries:
            beamState.entries[y]=BeamEntry()

    def decode(self, inputs, inputs_list):
        '''
        Args:
            inputs(FloatTensor) : Output of CTC (batch * timesteps * class)
            inputs_list(list)   : the number of frames of each sample
        Returns:
            res(list)           : Result of beam search
        '''
        batches, maxT, maxC = inputs.size()
        res = []

        for batch in range(batches):
            mat = inputs[batch].numpy()
            # Initialise beam state
            last=BeamState()
            y=()
            last.entries[y]=BeamEntry()
            last.entries[y].prBlank=LOG_ONE
            last.entries[y].prTotal=LOG_ONE

            # go over all time-steps
            for t in range(inputs_list[batch]):
                curr=BeamState()

                # skip blank frames whose probability is very close to 1 to speed up decoding
                if (1 - mat[t, self.blank_index]) < 0.1:
                    continue

                # keep only the best beamWidth labellings
                BHat=last.sort()[0:self.beamWidth]
                # go over best labellings
                for y in BHat:
                    prNonBlank=LOG_ZERO
                    # if nonempty labelling
                    if len(y)>0:
                        # the same labelling y can be reached either by repeating its last label
                        # or by emitting a blank; if there is no previous character, the
                        # non-blank probability stays at log-zero
                        prNonBlank=last.entries[y].prNonBlank + math.log(mat[t, y[-1]])

                    # calc probabilities
                    prBlank = (last.entries[y].prTotal) + math.log(mat[t, self.blank_index])
                    # save result
                    self.addLabelling(curr, y)
                    curr.entries[y].y=y
                    curr.entries[y].prNonBlank = self.log_add_prob(curr.entries[y].prNonBlank, prNonBlank)
                    curr.entries[y].prBlank = self.log_add_prob(curr.entries[y].prBlank, prBlank)
                    prTotal = self.log_add_prob(prBlank, prNonBlank)
                    curr.entries[y].prTotal = self.log_add_prob(curr.entries[y].prTotal, prTotal)

                    # extend y with every other label at time t; the blank probability of the new
                    # labelling is zero, and if the appended label equals the last one the previous
                    # frame must have been a blank (otherwise the repetition would be collapsed)
                    for k in range(maxC):
                        if k != self.blank_index:
                            newY=y+(k,)
                            prNonBlank=self.calcExtPr(k, y, t, mat, last)

                            # save result
                            self.addLabelling(curr, newY)
                            curr.entries[newY].y=newY
                            curr.entries[newY].prNonBlank = self.log_add_prob(curr.entries[newY].prNonBlank, prNonBlank)
                            curr.entries[newY].prTotal = self.log_add_prob(curr.entries[newY].prTotal, prNonBlank)

                # set new beam state
                last=curr

            # normalise probabilities according to labelling length
            last.norm()

            # sort by probability
            bestLabelling=last.sort()[0] # get most probable labelling

            # map labels to chars
            res_b =''.join([self.classes[l] for l in bestLabelling])
            res.append(res_b)
        return res

93 changes: 93 additions & 0 deletions chapter11-语音识别(LSTM-CTC)/README.md
@@ -0,0 +1,93 @@
This chapter builds an LSTM-CTC acoustic model for speech recognition with PyTorch.

The experiments use the TIMIT dataset (it can be downloaded from [academictorrents](http://academictorrents.com/details/34e2b78745138186976cbc27939b1b34d18bd5b3/tech) or [luojie1987/TIMIT](http://luojie1987.com/index.php/post/110.html)).

Many other public speech corpora can be downloaded from [Open Speech and Language Resources](http://www.openslr.org/resources.php).

Most of this project follows [https://github.com/Diamondfan/CTC_pytorch/](https://github.com/Diamondfan/CTC_pytorch/).

## Environment setup

- Install [PyTorch](http://pytorch.org)
- Install Baidu's Warp-CTC and bind it to PyTorch; see
[https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding](https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding)
- Install pytorch audio:
```Bash
sudo apt-get install sox libsox-dev libsox-fmt-all
git clone https://github.com/pytorch/audio.git
cd audio
pip install cffi
python setup.py install
```
- Install the third-party dependencies:
```Bash
pip install -r requirements.txt
```
- Start visdom:
```Bash
python -m visdom.server
```

## Usage
1. Open the top-level script run.sh and set the relevant file paths (TIMIT_dir, CONF_FILE).
2. Edit ctc_model_setting.conf in the conf directory to configure the network structure and other settings.
3. Run the top-level script with a stage argument: 0 starts from data preparation, 1 starts from training, and 2 runs testing only.
```
- bash run.sh 0    # data preparation + training + testing
- bash run.sh 1    # training + testing
- bash run.sh 2    # testing
```

## Notes
### TIMIT data preparation

test\_spk.list and dev\_spk.list in the conf directory are the speaker lists commonly used as the validation and test sets for phone recognition; these two files are used to select the data subsets.

Run the data preparation script to collect the audio paths and the transcription labels:
```Bash
bash timit_data_prep.sh timit_dir
```
After it finishes, wav.scp and text files are generated in the data_prepare directory, containing the audio paths and the corresponding transcriptions (text labels):
- train_wav.scp train.text > 3696 sentences
- dev_wav.scp dev.text > 400 sentences
- test_wav.scp test.text > 192 sentences

### About rnn_type
rnn_type in ctc_model_setting.conf selects the RNN variant (a minimal sketch of the mapping follows the list):
- lstm : nn.LSTM
- rnn : nn.RNN
- gru : nn.GRU
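
A minimal sketch of how such a string-to-module mapping is typically written; model.py is not shown in this diff, so the dictionary and variable names below are assumptions, with the sizes taken from ctc_model_setting.conf.
```Python
import torch.nn as nn

# hypothetical mapping from the config string to the torch module class
RNN_TYPES = {'lstm': nn.LSTM, 'rnn': nn.RNN, 'gru': nn.GRU}

rnn_type = 'lstm'  # value read from ctc_model_setting.conf
rnn = RNN_TYPES[rnn_type](input_size=201, hidden_size=256, num_layers=4,
                          bidirectional=True, batch_first=True)
```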

### About the labels
The output modelling units in this project are characters, i.e. "abcdefghijklmnopqrstuvwxyz'" + " "; the space is also treated as an output label, giving 28 + 1 = 29 classes in total.

The extra class is the CTC blank, which means the speech at that frame is noise or carries no meaning. The +1 is already added in model.py, so only the number of regular label classes needs to be entered in the configuration file.
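
A sketch of the label set described above (illustrative, not the actual code from model.py); with the blank at index 0, as in BeamSearch.py, the 28 characters occupy indices 1 to 28:
```Python
labels = "abcdefghijklmnopqrstuvwxyz' "               # 26 letters + apostrophe + space = 28
int2char = {i + 1: c for i, c in enumerate(labels)}   # index 0 is reserved for the CTC blank
num_class = len(labels)                               # the value written into ctc_model_setting.conf
assert num_class + 1 == 29                            # 28 labels + 1 blank
```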

Characters as labels do not work well on small datasets; on TIMIT the accuracy is only about 62%. Experiments show that phones are a more effective choice of output modelling unit.

### About learning-rate adjustment
By default, training stops after the learning rate has been adjusted 8 times, halving it at each adjustment. This can be changed in train.py (line 274) as needed.
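
A short sketch of the resulting schedule, using init_lr and lr_decay from ctc_model_setting.conf (when exactly an adjustment is triggered is decided in train.py and is not reproduced here):
```Python
init_lr, lr_decay = 0.001, 0.5

# the learning rate after each of the 8 adjustments; training stops after the last one
for adjustment in range(1, 9):
    print(adjustment, init_lr * lr_decay ** adjustment)   # 1 0.0005 ... 8 3.90625e-06
```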

### Contents of the log directory
- *.pkl: saved model data; see the save_package function in model.py (line 132)
- train.log: everything printed during training is written to train.log

### Results
<p align="center">
<img src="png/train_loss.png" width="200">
<img src="png/dev_loss.png" width="200">
<img src="png/dev_acc.png" width="200">
</p>

With characters as labels, the CTC acoustic model reaches the following recognition accuracy on the TIMIT test set (a usage sketch of the Beam decoder follows the numbers):
- Greedy decoder: 61.4831%
- Beam decoder : 62.1029%
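
A hedged usage sketch of the Beam decoder from BeamSearch.py in this commit; the label string and the random softmax outputs below are purely illustrative, not taken from the trained model:
```Python
import torch
import torch.nn.functional as F
from BeamSearch import ctcBeamSearch

classes = "_abcdefghijklmnopqrstuvwxyz' "      # index 0 is the CTC blank placeholder
decoder = ctcBeamSearch(classes, beam_width=20, blank_index=0)

# batch * timesteps * class probabilities, e.g. one utterance of 50 frames
probs = F.softmax(torch.randn(1, 50, len(classes)), dim=2)
print(decoder.decode(probs, [50])[0])          # best labelling as a character string
```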

This chapter only builds a simple acoustic model; compared with the complex pipeline in Kaldi, it is still far from a real recognition system. Even so, the project gives an initial, hands-on understanding of the speech recognition task.

### References
- [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
- [Supervised Sequence Labelling with Recurrent Neural Networks](https://link.springer.com/book/10.1007/978-3-642-24797-2)
- [EESEN: END-TO-END SPEECH RECOGNITION USING DEEP RNN MODELS AND WFST-BASED DECODING](http://www.cs.cmu.edu/afs/cs/Web/People/fmetze/interACT/Publications_files/publications/eesenasru.pdf)


31 changes: 31 additions & 0 deletions chapter11-语音识别(LSTM-CTC)/conf/ctc_model_setting.conf
@@ -0,0 +1,31 @@
[Data]
data_dir = /home/pytorch-book/chapter11-语音识别(LSTM-CTC)/data_prepare
log_dir = /home/pytorch-book/chapter11-语音识别(LSTM-CTC)/log/
log_file = train.log

[Model]
rnn_input_size = 201
rnn_hidden_size = 256
rnn_layers = 4
rnn_type = lstm
bidirectional = True
batch_norm = True
num_class = 28
drop_out = 0
model_file =

[Training]
use_cuda = True
init_lr = 0.001
num_epoches = 2
end_adjust_acc = 2
lr_decay = 0.5
batch_size = 16
weight_decay = 0.05
seed = 1234

[Decode]
decoder_type = Beam
beam_width = 20
eval_dataset = test

50 changes: 50 additions & 0 deletions chapter11-语音识别(LSTM-CTC)/conf/dev_spk.list
@@ -0,0 +1,50 @@
faks0
fdac1
fjem0
mgwt0
mjar0
mmdb1
mmdm2
mpdf0
fcmh0
fkms0
mbdg0
mbwm0
mcsh0
fadg0
fdms0
fedw0
mgjf0
mglb0
mrtk0
mtaa0
mtdt0
mthc0
mwjg0
fnmr0
frew0
fsem0
mbns0
mmjr0
mdls0
mdlf0
mdvc0
mers0
fmah0
fdrw0
mrcs0
mrjm4
fcal1
mmwh0
fjsj0
majc0
mjsw0
mreb0
fgjd0
fjmg0
mroa0
mteb0
mjfc0
mrjr0
fmml0
mrws1
24 changes: 24 additions & 0 deletions chapter11-语音识别(LSTM-CTC)/conf/test_spk.list
@@ -0,0 +1,24 @@
mdab0
mwbt0
felc0
mtas1
mwew0
fpas0
mjmp0
mlnt0
fpkt0
mlll0
mtls0
fjlm0
mbpm0
mklt0
fnlp0
mcmj0
mjdh0
fmgd0
mgrt0
mnjm0
fdhc0
mjln0
mpam0
fmld0