This repository was archived by the owner on Dec 28, 2017. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.lua
89 lines (70 loc) · 2.12 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
-- Dependencies: torch, the ctcdecode module under test, and paths (used by
-- the n-gram test to locate its fixture files via paths.thisfile).
require 'torch'
local ctcdecode = require 'ctcdecode'
local paths = require 'paths'
-- Make torch.Tensor{...} in the tests below build float tensors.
torch.setdefaulttensortype('torch.FloatTensor')
-- `test` holds the test functions; `tester` registers and runs them at the
-- bottom of the file.
local test = torch.TestSuite()
local tester = torch.Tester()
--- Checks the n-best output of ctcdecode.BeamSearchDecoder on a tiny batch.
-- The input is a 5 x 2 x 4 tensor of log-probabilities. Dimension 2 is used
-- as the number of sequences and dimension 1 as their length (see how
-- sequenceLength is built); dimension 3 is handed to the decoder constructor,
-- presumably as the number of labels (TODO confirm against ctcdecode).
-- For every sequence b and path p, the first pathLen[b][p] entries of
-- outputs[b][p] are compared with the hand-written expectation.
function test.BeamSearchDecoder()
  local logProbabilities = torch.Tensor{
    {{0, 0, 0.6, 0.4}, {0, 0, 0.4, 0.6}},
    {{0, 0, 0.5, 0.5}, {0, 0, 0.5, 0.5}},
    {{0, 0, 0.4, 0.6}, {0, 0, 0.6, 0.4}},
    {{0, 0, 0.4, 0.6}, {0, 0, 0.6, 0.4}},
    {{0, 0, 0.4, 0.6}, {0, 0, 0.6, 0.4}}
  }:log()

  local nBest = 3
  local decoder = ctcdecode.BeamSearchDecoder(
    logProbabilities:size(3),
    10 * nBest,
    ctcdecode.DefaultBeamScorer(),
    1,
    false
  )

  -- One length entry per sequence; every sequence spans all frames.
  local sequenceLength = torch.IntTensor(logProbabilities:size(2))
  sequenceLength:fill(logProbabilities:size(1))

  -- decode returns further values after pathLen (the second one is skipped
  -- with `_`); this test only inspects the label paths and their lengths.
  local outputs, _, pathLen = decoder:decode(
    logProbabilities,
    nBest,
    sequenceLength
  )

  local expected = {
    {
      torch.IntTensor{2, 3},
      torch.IntTensor{2, 3, 2},
      torch.IntTensor{3, 2, 3}
    },
    {
      torch.IntTensor{3, 2},
      torch.IntTensor{3, 2, 3},
      torch.IntTensor{2, 3, 2}
    }
  }

  -- The loops below are bounded by the decoder's own output sizes, so without
  -- these checks a decoder returning too few sequences or paths would make
  -- the test pass without comparing everything in `expected`.
  tester:eq(outputs:size(1), #expected, 'unexpected number of decoded sequences')
  tester:eq(outputs:size(2), nBest, 'unexpected number of paths per sequence')

  for b = 1, outputs:size(1) do
    for p = 1, outputs:size(2) do
      -- extract the path of labels, dropping anything past its length
      local path = outputs[{{b}, {p}, {1, pathLen[b][p]}}]:squeeze()
      tester:eq(path, expected[b][p])
    end
  end
end
--- Decodes a single three-frame sequence with an n-gram scorer and checks
-- the best path. The scorer is built from the label map and ARPA model
-- stored under data/ next to this file.
function test.NGramDecoder()
  local frames = {
    {{0, 0, 0.4, 0.6}},
    {{0, 1, 0, 0}}, -- force a space
    {{0, 0, 0.4, 0.6}},
  }
  local logProbabilities = torch.Tensor(frames):log()

  local lmScorer = ctcdecode.NGramBeamScorer(
    paths.thisfile('data/label_map'),
    paths.thisfile('data/test.arpa')
  )

  local numPaths = 1
  local decoder = ctcdecode.NGramDecoder(
    logProbabilities:size(3),
    10 * numPaths,
    lmScorer,
    1,
    false
  )

  -- Only the label paths and their lengths are inspected here.
  local outputs, _, pathLen = decoder:decode(logProbabilities, numPaths)

  -- Keep just the valid prefix of the first (and only) path.
  local bestLen = pathLen[1][1]
  local bestPath = outputs[{{1}, {1}, {1, bestLen}}]:squeeze()

  -- 'B' is not in the vocab, so we should prefer 'A A'
  tester:eq(bestPath, torch.IntTensor{2, 1, 2})
end
-- Register the functions collected on `test` with the tester, then run them.
tester:add(test)
tester:run()