Baidu-warpctc is a CTC implementation by Baidu that supports GPUs. CTC can be used with LSTM to solve label alignment problems in many areas, such as OCR and speech recognition.
cd ~/
git clone
cd warp-ctc
mkdir build
cd build
cmake ..
sudo make install
Uncomment (enable) the following lines in make/
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/
Then rebuild MXNet with:
make clean && make -j4
I implemented two examples. The first is a toy example that can be used to verify that the CTC integration is correct. The second is an OCR example using LSTM+CTC. You can run it by:
cd examples/warpctc
- Please modify contexts = [mx.context.gpu(0)] in this file according to your hardware.
- Please review the code. Copy your font to font/yourfont.ttf. You can get a free font from here.
- A checkpoint is saved automatically after each epoch. You can then use this checkpoint for prediction.
The OCR example is constructed as follows:
- I generate an 80x30 image for each 4-digit captcha using a Python captcha library
- The 80x30 image is fed to the LSTM as 80 inputs, where each input is one column of the image (a 30-dimensional vector)
- The output layer uses CTC loss
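The column-wise input layout described in the list above can be illustrated without MXNet. The sketch below uses plain Python lists. The names `image_to_columns`, `WIDTH`, and `HEIGHT` are illustrative assumptions and do not appear in the example code.

```python
# Illustrative only: a plain-Python stand-in for slicing an image into columns.
WIDTH, HEIGHT = 80, 30  # image size from the text: 80 columns, 30 rows


def image_to_columns(image):
    """Turn a HEIGHT x WIDTH image (a list of rows) into WIDTH column vectors."""
    assert len(image) == HEIGHT and all(len(row) == WIDTH for row in image)
    # zip(*image) transposes the rows into columns
    return [list(col) for col in zip(*image)]


# Dummy image in which every pixel stores its own column index
image = [[x for x in range(WIDTH)] for _ in range(HEIGHT)]
columns = image_to_columns(image)
print(len(columns), len(columns[0]))  # prints: 80 30
```

Each of the 80 column vectors corresponds to one LSTM time step, so the sequence length equals the image width and the input dimension equals the image height.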
The following code shows the detailed construction of the network:
import mxnet as mx

# LSTMState, LSTMParam and lstm() are assumed to be defined alongside this
# function in the example's LSTM helper code.

def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        # keep one initial state per layer so the assert below holds
        last_states.append(state)
    assert(len(last_states) == num_lstm_layer)

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')

    # every column of the image is an input, there are seq_len inputs
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            # prev_state/param keywords assume the lstm() helper's signature
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        # collect the top-layer output of every time step
        hidden_all.append(hidden)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    # 11 output classes: 10 digits plus the CTC blank
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)

    # here we do NOT need to transpose label as other lstm examples do
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    # label should be int type, so use cast
    label = mx.sym.Cast(data=label, dtype='int32')
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)
    return sm
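At prediction time, the per-step class scores still have to be turned into a label sequence. A common approach for CTC-trained networks is greedy (best-path) decoding: take the highest-scoring class at each time step, collapse consecutive repeats, then drop blanks. The following is a minimal plain-Python sketch of that idea, not the decoding code shipped with the example. `greedy_ctc_decode` and `one_hot` are hypothetical helpers, and the sketch relies on this document reserving class 0 for the blank label.

```python
BLANK = 0  # class 0 is reserved for the blank label


def greedy_ctc_decode(step_scores):
    """Best-path decoding: argmax per step, collapse repeats, drop blanks.

    step_scores holds one list of class scores per time step.
    """
    best_path = [max(range(len(s)), key=s.__getitem__) for s in step_scores]
    decoded = []
    prev = None
    for cls in best_path:
        # emit a class only when it differs from the previous step and is not blank
        if cls != prev and cls != BLANK:
            decoded.append(cls)
        prev = cls
    return decoded


def one_hot(cls, num_classes=11):
    """Fake score vector that puts all weight on a single class."""
    return [1.0 if i == cls else 0.0 for i in range(num_classes)]


# Best path: blank, 3, 3, blank, 3, 8, 8, blank
scores = [one_hot(c) for c in [0, 3, 3, 0, 3, 8, 8, 0]]
print(greedy_ctc_decode(scores))  # prints: [3, 3, 8]
```

The blank between the two runs of class 3 is what lets the decoder output the same class twice in a row.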
If your label lengths are smaller than or equal to b, you should provide every label with length b: for samples whose label is shorter than b, append 0 to the label data until it has length b.
Here, 0 is reserved for the blank label.
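The padding rule can be sketched as follows. `pad_label` is a hypothetical helper, and storing digit d as class d + 1 is an assumption made here so that real labels never collide with the blank label 0. The text above only states that 0 is reserved.

```python
BLANK = 0  # reserved for the blank label


def pad_label(digits, max_len):
    """Encode a digit string as class ids and right-pad with BLANK to max_len."""
    if len(digits) > max_len:
        raise ValueError("label is longer than max_len")
    # Assumption: digit d is stored as class d + 1 so that 0 stays free for blank
    ids = [int(ch) + 1 for ch in digits]
    return ids + [BLANK] * (max_len - len(ids))


print(pad_label("1234", 4))  # prints: [2, 3, 4, 5]
print(pad_label("07", 4))    # prints: [1, 8, 0, 0]
```

With this layout every sample in a batch has the same label length, which matches the fixed `label_length` passed to the loss layer.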
Please run:
- Change the code to match the names of your params and json files.
- You have to do a
in the amalgamation folder. (a will be created in the lib folder.)