This repository was archived by the owner on Dec 28, 2017. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathBeamSearchDecoder.lua
124 lines (105 loc) · 3.82 KB
/
BeamSearchDecoder.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
local ctcdecode = require 'ctcdecode.env'
local argcheck = require 'argcheck'
local ffi = require 'ffi'
-- C prototypes for the native BeamSearchDecoder entry points, made callable
-- through `C` below. THFloatTensor / THIntTensor are not declared here; they
-- are assumed to be declared already by torch's own FFI bindings
-- (NOTE(review): confirm that module is loaded before this one).
-- BeamSearchDecoder_new is declared as returning `void *`; LuaJIT FFI converts
-- a void pointer implicitly to `BeamSearchDecoder *` when it is passed to the
-- other functions declared below.
ffi.cdef[[
typedef struct DefaultBeamScorer DefaultBeamScorer;
typedef struct BeamSearchDecoder BeamSearchDecoder;
void *BeamSearchDecoder_new(int num_labels, int beam_width,
DefaultBeamScorer *scorer, int blank_label,
bool merge_repeated);
void BeamSearchDecoder_delete(BeamSearchDecoder *cdata);
void BeamSearchDecoder_Decode(BeamSearchDecoder *cdata,
THFloatTensor *inputs_tensor,
THIntTensor *seq_len_tensor,
THIntTensor *outputs_tensor,
THIntTensor *alignments_tensor,
THIntTensor *path_len_tensor,
THFloatTensor *scores_tensor);
void BeamSearchDecoder_SetLabelSelectionSize(BeamSearchDecoder *cdata,
int label_selection_size);
void BeamSearchDecoder_SetLabelSelectionMargin(BeamSearchDecoder *cdata,
float label_selection_margin);
]]
-- Native library namespace exposed by ctcdecode.env; presumably the result of
-- ffi.load on the compiled ctcdecode library (NOTE(review): confirm).
local C = ctcdecode.C
-- Register the torch class 'ctcdecode.BeamSearchDecoder', passing the
-- ctcdecode module table as torch.class's second argument. Its methods are
-- attached below.
local BeamSearchDecoder = torch.class(
'ctcdecode.BeamSearchDecoder', ctcdecode
)
-- TODO: Maybe register all of the C function calls so that we can reuse
-- the Lua-side implementations for these methods?
--- Construct a decoder backed by a freshly allocated native BeamSearchDecoder.
-- @tparam number numClasses forwarded to the native constructor as num_labels
-- @tparam number beamWidth forwarded to the native constructor as beam_width
-- @tparam ctcdecode.DefaultBeamScorer scorer scorer whose native handle the
--   decoder receives; a Lua reference is kept on the instance
-- @tparam number blankLabel 1-based index of the blank label; reduced by one
--   before being passed to C
-- @tparam[opt=false] boolean mergeRepeated forwarded as merge_repeated
BeamSearchDecoder.__init = argcheck{
  {name = 'self', type = 'ctcdecode.BeamSearchDecoder'},
  {name = 'numClasses', type = 'number'},
  {name = 'beamWidth', type = 'number'},
  {name = 'scorer', type = 'ctcdecode.DefaultBeamScorer'},
  {name = 'blankLabel', type = 'number'},
  {name = 'mergeRepeated', type = 'boolean', default = false},
  call = function(self, numClasses, beamWidth, scorer, blankLabel, mergeRepeated)
    -- The native decoder only receives the scorer's raw handle, so pin the Lua
    -- object here to stop it being garbage collected while still referenced.
    self._scorer = scorer

    -- Shift from the Lua-side 1-based label index to the native convention
    -- (presumably 0-based).
    local nativeBlank = blankLabel - 1
    local handle = C.BeamSearchDecoder_new(
      numClasses, beamWidth, self._scorer:cdata(), nativeBlank, mergeRepeated
    )

    -- Tie the native object's lifetime to this handle: the destructor runs
    -- when the handle is collected.
    self._cdata = ffi.gc(handle, C.BeamSearchDecoder_delete)
  end
}
--- Return the FFI handle of the underlying native decoder.
-- The handle carries the ffi.gc finalizer installed by __init, so callers
-- should not free it themselves.
-- @return cdata pointer to the native BeamSearchDecoder
BeamSearchDecoder.cdata = argcheck{
  {name = 'self', type = 'ctcdecode.BeamSearchDecoder'},
  call = function(self) return self._cdata end
}
--- Run CTC beam-search decoding over a batch of per-frame scores.
-- Tensor layout, as read below: dim 1 of `inputs` is time (maxSeqLen) and
-- dim 2 is batch. Dim 3 is presumably the class dimension
-- (NOTE(review): confirm against the native decoder).
-- @tparam torch.FloatTensor inputs 3-D tensor (maxSeqLen x batchSize x classes)
-- @tparam number topPaths number of decoded paths to return per batch element
-- @tparam[opt] torch.IntTensor seqLen one valid length per batch element, each
--   in [0, maxSeqLen]; defaults to maxSeqLen for every element
-- @treturn torch.IntTensor outputs (batchSize x topPaths x maxSeqLen), zeroed
--   before the native call fills it
-- @treturn torch.IntTensor alignments (batchSize x topPaths x maxSeqLen), zeroed
--   before the native call fills it
-- @treturn torch.IntTensor pathLen (batchSize x topPaths)
-- @treturn torch.FloatTensor scores (batchSize x topPaths)
-- NOTE(review): __init subtracts 1 from blankLabel before handing it to C,
-- but labels in `outputs` are returned exactly as the native code writes them.
-- Confirm whether callers expect 0- or 1-based labels here.
BeamSearchDecoder.decode = argcheck{
  {name = 'self', type = 'ctcdecode.BeamSearchDecoder'},
  {name = 'inputs', type = 'torch.FloatTensor'},
  {name = 'topPaths', type = 'number'},
  {name = 'seqLen', type = 'torch.IntTensor', opt = true},
  call = function(self, inputs, topPaths, seqLen)
    -- Everything below is handed straight to native code, so reject shapes
    -- that would let it index out of bounds; fail here with a clear message
    -- instead of crashing or reading garbage inside C.
    assert(inputs:dim() == 3,
      'inputs must be a 3-D tensor (maxSeqLen x batchSize x numClasses)')
    assert(topPaths >= 1 and topPaths == math.floor(topPaths),
      'topPaths must be a positive integer')

    local maxSeqLen = inputs:size(1)
    local batchSize = inputs:size(2)

    -- Separate name to avoid shadowing the `seqLen` parameter.
    local lengths = seqLen or torch.IntTensor(batchSize):fill(maxSeqLen)
    assert(lengths:nElement() == batchSize,
      'seqLen must contain exactly one length per batch element')
    if batchSize > 0 then
      assert(lengths:min() >= 0 and lengths:max() <= maxSeqLen,
        'every seqLen entry must lie in [0, maxSeqLen]')
    end

    -- contiguous() returns the tensor itself when it is already contiguous.
    -- For strided views (e.g. after a transpose) it makes a dense copy, in case
    -- the native side reads raw storage linearly (not verifiable from here).
    inputs = inputs:contiguous()
    lengths = lengths:contiguous()

    local outputs = torch.IntTensor(batchSize, topPaths, maxSeqLen):zero()
    local scores = torch.FloatTensor(batchSize, topPaths):zero()
    local alignments = torch.IntTensor(batchSize, topPaths, maxSeqLen):zero()
    local pathLen = torch.IntTensor(batchSize, topPaths):zero()

    C.BeamSearchDecoder_Decode(
      self:cdata(),
      inputs:cdata(),
      lengths:cdata(),
      outputs:cdata(),
      alignments:cdata(),
      pathLen:cdata(),
      scores:cdata()
    )
    return outputs, alignments, pathLen, scores
  end
}
--- Forward a label-selection size to the native decoder.
-- Presumably this limits how many candidate labels are expanded per time step.
-- The exact semantics (including what 0 means) are defined by the native
-- library (NOTE(review): confirm).
-- @tparam number labelSelectionSize value passed through as an int
BeamSearchDecoder.setLabelSelectionSize = argcheck{
  {name = 'self', type = 'ctcdecode.BeamSearchDecoder'},
  {name = 'labelSelectionSize', type = 'number'},
  call = function(self, labelSelectionSize)
    local handle = self:cdata()
    C.BeamSearchDecoder_SetLabelSelectionSize(handle, labelSelectionSize)
  end
}
--- Forward a label-selection margin to the native decoder.
-- Presumably this prunes candidate labels whose score falls too far below the
-- best one at each step. The exact semantics are defined by the native
-- library (NOTE(review): confirm).
-- @tparam number labelSelectionMargin value passed through as a float
BeamSearchDecoder.setLabelSelectionMargin = argcheck{
  {name = 'self', type = 'ctcdecode.BeamSearchDecoder'},
  {name = 'labelSelectionMargin', type = 'number'},
  call = function(self, labelSelectionMargin)
    local handle = self:cdata()
    C.BeamSearchDecoder_SetLabelSelectionMargin(handle, labelSelectionMargin)
  end
}