forked from moses-smt/mosesdecoder
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ChartKBestExtractor.cpp
281 lines (255 loc) · 10.5 KB
/
ChartKBestExtractor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
/***********************************************************************
Moses - statistical machine translation system
Copyright (C) 2006-2014 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "ChartKBestExtractor.h"
#include "ChartHypothesis.h"
#include "ScoreComponentCollection.h"
#include "StaticData.h"
#include <boost/scoped_ptr.hpp>
#include <vector>
using namespace std;
namespace Moses
{
// Extract the k-best list from the search graph.
//
// topLevelHypos: the top-level hypotheses.  The first element must have the
//                highest total score; UTIL_THROW_IF2 fires if any later
//                element scores higher.
// k:             the maximum number of derivations requested.
// kBestList:     output.  Cleared on entry, then filled with up to k
//                derivations in the order they appear in the target vertex's
//                k-best list.  Left empty if k is zero or topLevelHypos is
//                empty.
void ChartKBestExtractor::Extract(
  const std::vector<const ChartHypothesis*> &topLevelHypos, std::size_t k,
  KBestVec &kBestList)
{
  kBestList.clear();
  // Every vertex is created with its 1-best derivation already in its k-best
  // list, so without the k == 0 check a request for zero derivations would
  // still return one.
  if (k == 0 || topLevelHypos.empty()) {
    return;
  }

  // Create a new ChartHypothesis object, supremeHypo, that has the best
  // top-level hypothesis as its predecessor and has the same score.
  std::vector<const ChartHypothesis*>::const_iterator p = topLevelHypos.begin();
  const ChartHypothesis &bestTopLevelHypo = **p;
  boost::scoped_ptr<ChartHypothesis> supremeHypo(
    new ChartHypothesis(bestTopLevelHypo, *this));

  // Do the same for each alternative top-level hypothesis, but add the new
  // ChartHypothesis objects as arcs from supremeHypo, as if they had been
  // recombined.
  for (++p; p != topLevelHypos.end(); ++p) {
    // Check that the first item in topLevelHypos really was the best.
    UTIL_THROW_IF2((*p)->GetTotalScore() > bestTopLevelHypo.GetTotalScore(),
                   "top-level hypotheses are not correctly sorted");
    // Note: there's no need for a smart pointer here: supremeHypo will take
    // ownership of altHypo.
    ChartHypothesis *altHypo = new ChartHypothesis(**p, *this);
    supremeHypo->AddArc(altHypo);
  }

  // Create the target vertex then lazily fill its k-best list.
  boost::shared_ptr<Vertex> targetVertex = FindOrCreateVertex(*supremeHypo);
  LazyKthBest(*targetVertex, k, k);

  // Copy the k-best list from the target vertex, but drop the top edge from
  // each derivation: that edge belongs to the artificial supremeHypo, and
  // its single subderivation is the real top-level derivation.
  kBestList.reserve(targetVertex->kBestList.size());
  for (std::vector<boost::weak_ptr<Derivation> >::const_iterator
       q = targetVertex->kBestList.begin();
       q != targetVertex->kBestList.end(); ++q) {
    // Promote the weak pointer to a shared pointer for the duration of the
    // copy.
    const boost::shared_ptr<Derivation> d(*q);
    assert(d);
    assert(d->subderivations.size() == 1);
    kBestList.push_back(d->subderivations[0]);
  }
}
// Build the target-side yield of derivation d.
//
// Walks the target phrase of d's head hypothesis from left to right.
// Terminals are copied to the output; each non-terminal is replaced by the
// (recursively computed) yield of the corresponding subderivation.  When a
// placeholder factor is configured, a terminal aligned to exactly one source
// word takes that source word's placeholder factor, if it has one.
Phrase ChartKBestExtractor::GetOutputPhrase(const Derivation &d)
{
  const FactorType placeholderFactor =
    StaticData::Instance().GetPlaceholderFactor();

  Phrase yield(ARRAY_SIZE_INCR);

  const ChartHypothesis &hypo = d.edge.head->hypothesis;
  const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();

  for (std::size_t targetPos = 0; targetPos < targetPhrase.GetSize();
       ++targetPos) {
    const Word &word = targetPhrase.GetWord(targetPos);

    if (word.IsNonTerminal()) {
      // Splice in the yield of the subderivation this non-terminal maps to.
      const std::size_t subIndex = nonTermIndexMap[targetPos];
      const Derivation &sub = *d.subderivations[subIndex];
      Phrase subYield = GetOutputPhrase(sub);
      yield.Append(subYield);
      continue;
    }

    // Terminal: copy it, then optionally apply placeholder substitution.
    yield.AddWord(word);

    if (placeholderFactor != NOT_FOUND) {
      std::set<std::size_t> alignedSourcePositions =
        targetPhrase.GetAlignTerm().GetAlignmentsForTarget(targetPos);

      // Only substitute when the alignment is unambiguous.
      if (alignedSourcePositions.size() == 1) {
        const std::vector<const Word*> *ruleSourceFromInputPath =
          hypo.GetTranslationOption().GetSourceRuleFromInputPath();
        UTIL_THROW_IF2(ruleSourceFromInputPath == NULL,
                       "Source Words in of the rules hasn't been filled out");

        std::size_t sourcePos = *alignedSourcePositions.begin();
        const Word *sourceWord = ruleSourceFromInputPath->at(sourcePos);
        UTIL_THROW_IF2(sourceWord == NULL,
                       "Null source word at position " << sourcePos);

        const Factor *replacement = sourceWord->GetFactor(placeholderFactor);
        if (replacement) {
          // Overwrite factor 0 of the word that was just appended.
          yield.Back()[0] = replacement;
        }
      }
    }
  }

  return yield;
}
// Build the unweighted hyperarc for ChartHypothesis h.
//
// The head is the vertex for h itself; the tail holds the vertex for each of
// h's predecessor hypotheses, in the order returned by GetPrevHypos().
// Vertices are looked up, and created on demand, via FindOrCreateVertex.
ChartKBestExtractor::UnweightedHyperarc ChartKBestExtractor::CreateEdge(
  const ChartHypothesis &h)
{
  UnweightedHyperarc arc;
  arc.head = FindOrCreateVertex(h);

  const std::vector<const ChartHypothesis*> &predecessors = h.GetPrevHypos();
  arc.tail.reserve(predecessors.size());
  for (std::vector<const ChartHypothesis*>::const_iterator
       it = predecessors.begin(); it != predecessors.end(); ++it) {
    arc.tail.push_back(FindOrCreateVertex(**it));
  }

  return arc;
}
// Return the vertex associated with ChartHypothesis h, building it first if
// m_vertexMap has no entry for h yet.
//
// A newly built vertex is seeded with its 1-best derivation: the derivation
// over the edge formed by h and the vertices of h's predecessor hypotheses
// (which are themselves found or created recursively).
boost::shared_ptr<ChartKBestExtractor::Vertex>
ChartKBestExtractor::FindOrCreateVertex(const ChartHypothesis &h)
{
  // Attempt to insert an empty slot for h; if the key is already present the
  // insert is a no-op and we get the existing slot back.
  std::pair<VertexMap::iterator, bool> lookup = m_vertexMap.insert(
    VertexMap::value_type(&h, boost::shared_ptr<Vertex>()));
  boost::shared_ptr<Vertex> &slot = lookup.first->second;

  if (lookup.second) {
    // First time h has been seen.  Fill the slot before recursing on the
    // predecessors.
    slot.reset(new Vertex(h));

    // Assemble the edge for the 1-best derivation.
    UnweightedHyperarc bestEdge;
    bestEdge.head = slot;
    const std::vector<const ChartHypothesis*> &predecessors = h.GetPrevHypos();
    const std::size_t arity = predecessors.size();
    bestEdge.tail.resize(arity);
    for (std::size_t n = 0; n < arity; ++n) {
      bestEdge.tail[n] = FindOrCreateVertex(*predecessors[n]);
    }

    // Register the derivation in m_derivations and record it as the first
    // entry of the vertex's k-best list.
    boost::shared_ptr<Derivation> bestDerivation(new Derivation(bestEdge));
    std::pair<DerivationSet::iterator, bool> registered =
      m_derivations.insert(bestDerivation);
    assert(registered.second);
    slot->kBestList.push_back(bestDerivation);
  }

  return slot;
}
// Seed v's candidate queue with the 1-best derivation of every incoming edge
// of v other than the best one.
//
// The non-best incoming edges are the hypotheses in v.hypothesis.GetArcList();
// the edge defined by v.hypothesis itself is skipped because its 1-best
// derivation was already created along with the vertex.
//
// The parameter k is not used by this implementation.
void ChartKBestExtractor::GetCandidates(Vertex &v, std::size_t k)
{
  const ChartArcList *arcs = v.hypothesis.GetArcList();
  if (!arcs) {
    return;  // No recombined hypotheses, so no extra candidates.
  }

  for (std::size_t n = 0; n < arcs->size(); ++n) {
    const ChartHypothesis &arcHypo = *(*arcs)[n];
    boost::shared_ptr<Vertex> arcVertex = FindOrCreateVertex(arcHypo);
    // The vertex is expected to hold exactly its 1-best derivation here.
    assert(arcVertex->kBestList.size() == 1);
    v.candidates.push(arcVertex->kBestList[0]);
  }
}
// Grow v's k-best list on demand until it holds k derivations or no further
// candidates are available.
//
// v:       the vertex whose k-best list is being extended.
// k:       the number of derivations wanted in v.kBestList.
// globalK: forwarded unchanged to GetCandidates and LazyNext.
void ChartKBestExtractor::LazyKthBest(Vertex &v, std::size_t k,
                                      std::size_t globalK)
{
  // On the first visit, set up the candidate queue.
  if (!v.visited) {
    // The vertex was created with its 1-best derivation in place.
    assert(v.kBestList.size() == 1);
    GetCandidates(v, globalK);
    v.visited = true;
  }

  while (v.kBestList.size() < k) {
    assert(!v.kBestList.empty());

    // Push the successors of the most recently accepted derivation onto the
    // queue (LazyNext skips any that have been seen before).
    boost::shared_ptr<Derivation> lastAccepted(v.kBestList.back());
    LazyNext(v, *lastAccepted, globalK);

    // An empty queue means v has no more derivations to offer.
    if (v.candidates.empty()) {
      break;
    }

    // Move the head of the queue to the end of the k-best list.
    boost::weak_ptr<Derivation> best = v.candidates.top();
    v.candidates.pop();
    v.kBestList.push_back(best);
  }
}
// Generate the neighbours of derivation d and add the unseen ones to v's
// candidate queue.
//
// For each tail position of d's edge, the neighbour is d with the
// subderivation at that position advanced to the next entry of the
// predecessor's k-best list.  A neighbour is skipped when the predecessor has
// no such entry, or when an equal derivation is already in m_derivations.
void ChartKBestExtractor::LazyNext(Vertex &v, const Derivation &d,
                                   std::size_t globalK)
{
  for (std::size_t pos = 0; pos < d.edge.tail.size(); ++pos) {
    Vertex &pred = *d.edge.tail[pos];

    // The neighbour uses index backPointers[pos] + 1 of pred's k-best list,
    // so that list must hold at least backPointers[pos] + 2 entries.
    const std::size_t required = d.backPointers[pos] + 2;
    LazyKthBest(pred, required, globalK);

    if (pred.kBestList.size() >= required) {
      boost::shared_ptr<Derivation> neighbour(new Derivation(d, pos));
      // insert() reports whether the derivation is new to m_derivations.
      if (m_derivations.insert(neighbour).second) {
        v.candidates.push(neighbour);
      }
    }
    // Otherwise pred's derivations are exhausted at this position.
  }
}
// Construct the 1-best Derivation over edge e.
//
// Every back pointer is 0, i.e. each subderivation is the first entry of the
// corresponding tail vertex's k-best list.  The score and score breakdown are
// taken directly from the head vertex's hypothesis.
ChartKBestExtractor::Derivation::Derivation(const UnweightedHyperarc &e)
{
  edge = e;

  const std::size_t arity = edge.tail.size();
  backPointers.assign(arity, 0);

  subderivations.reserve(arity);
  for (std::size_t pos = 0; pos != arity; ++pos) {
    const Vertex &tailVertex = *edge.tail[pos];
    assert(pred_has_best:
           tailVertex.kBestList.size() >= 1);
    // Promote the stored weak pointer to a shared pointer.
    subderivations.push_back(
      boost::shared_ptr<Derivation>(tailVertex.kBestList[0]));
  }

  scoreBreakdown = edge.head->hypothesis.GetScoreBreakdown();
  score = edge.head->hypothesis.GetTotalScore();
}
// Construct the Derivation that neighbours d at tail position i.
//
// The result is a copy of d in which the i-th back pointer is advanced by
// one, so the i-th subderivation becomes the next entry of the i-th tail
// vertex's k-best list.  The score breakdown is adjusted by swapping the old
// subderivation's contribution for the new one, and the total score is
// recomputed from the adjusted breakdown.
ChartKBestExtractor::Derivation::Derivation(const Derivation &d, std::size_t i)
{
  // Start from a copy of d.
  edge.head = d.edge.head;
  edge.tail = d.edge.tail;
  backPointers = d.backPointers;
  subderivations = d.subderivations;
  scoreBreakdown = d.scoreBreakdown;

  // Advance the back pointer at position i.
  const std::size_t nextRank = ++backPointers[i];

  // Remove the contribution of the subderivation being replaced.
  scoreBreakdown.MinusEquals(subderivations[i]->scoreBreakdown);

  // Swap in the next-ranked subderivation of the i-th tail vertex.
  boost::shared_ptr<Derivation> replacement(edge.tail[i]->kBestList[nextRank]);
  subderivations[i] = replacement;

  // Add the contribution of the replacement and refresh the total.
  scoreBreakdown.PlusEquals(subderivations[i]->scoreBreakdown);
  score = scoreBreakdown.GetWeightedScore();
}
} // namespace Moses