-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtutorial.cpp
613 lines (525 loc) · 17.3 KB
/
tutorial.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cassert>
#include <sstream>
#include "gtn/gtn.h"
using namespace gtn;
// Symbol tables passed to draw(): `symbols` (and its copy `isymbols`) names
// the labels 0, 1, 2 as "a", "b", "c"; `osymbols` names the same labels as
// "x", "y", "z" for use as an output alphabet.
std::unordered_map<int, std::string> symbols{
    {0, "a"},
    {1, "b"},
    {2, "c"},
};
std::unordered_map<int, std::string> isymbols(symbols);
std::unordered_map<int, std::string> osymbols{
    {0, "x"},
    {1, "y"},
    {2, "z"},
};
// Build a graph by hand and by loading it from text, then compare the two
// with the simple graph utilities.
void simpleAcceptors() {
  // Create a graph, by default a weighted finite-state acceptor (WFSA)
  Graph graph;
  // Add a start node (node 0)
  graph.addNode(true);
  // Add an accept node (node 1)
  graph.addNode(false, true);
  // Add an internal node (node 2)
  graph.addNode();
  // Add an arc from node 0 to 2 with label 0
  graph.addArc(0, 2, 0);
  // Add an arc from node 0 to 2 with input label 1 and output label 1
  graph.addArc(0, 2, 1, 1);
  // Add an arc from node 2 to 1 with input label 0, output label 0 and weight 2
  graph.addArc(2, 1, 0, 0, 2);
  // Print the graph to std::cout
  std::cout << graph;
  // Draw the graph in dot format to a file
  draw(graph, "simple_fsa.dot", symbols);
  // Compile the dot file to pdf with
  //   dot -Tpdf simple_fsa.dot -o simple_fsa.pdf
  // One can load a graph from an istream
  std::stringstream in(
      // First line is space separated start states
      "0\n"
      // Second line is space separated accept states
      "1\n"
      // The remaining lines are a list of arcs:
      // <source node> <dest node> <ilabel> [olabel] [weight]
      // where the olabel defaults to the ilabel if it is not specified
      // and the weight defaults to 0.0 if it is not specified.
      "0 2 0\n" // olabel = 0, weight = 0.0
      "0 2 1 1\n" // olabel = 1, weight = 0.0
      "2 1 0 0 2\n"); // olabel = 0, weight = 2.0
  Graph other_graph = loadTxt(in);
  // Note: the checks below use assert, which is compiled out when NDEBUG is
  // defined; build without NDEBUG to actually run them.
  // Exact match the two graphs: the node indices, arc weights and arc labels
  // should all be identical.
  assert(equal(graph, other_graph));
  // Check that the graphs have the same structure but potentially different
  // node indices. In this case, only the arc labels and weights must be the
  // same.
  assert(isomorphic(graph, other_graph));
}
// A few more interesting graphs: multiple start/accept nodes, self-loops and
// cycles, and epsilon arcs.
void interestingAcceptors() {
  // Multiple start and multiple accept nodes.
  {
    Graph graph;
    graph.addNode(true);         // node 0: start
    graph.addNode(true);         // node 1: a second start node
    graph.addNode();             // node 2: internal
    graph.addNode(false, true);  // node 3: accept
    graph.addNode(false, true);  // node 4: a second accept node
    // Each row is {source, destination, label}. The arc 0 -> 1 enters start
    // node 1 and the arc 3 -> 4 leaves accept node 3: start nodes may have
    // incoming arcs and accept nodes may have outgoing arcs.
    const int arcs[][3] = {{0, 1, 1}, {0, 2, 0}, {1, 3, 0}, {2, 3, 1},
                           {2, 3, 0}, {2, 4, 2}, {3, 4, 1}};
    for (const auto& arc : arcs) {
      graph.addArc(arc[0], arc[1], arc[2]);
    }
    draw(graph, "multi_start_accept.dot", symbols);
  }
  // Self-loops and cycles are both allowed.
  {
    Graph graph;
    graph.addNode(true);
    graph.addNode();
    graph.addNode(false, true);
    const int arcs[][3] = {
        {0, 0, 0},  // self-loop on node 0
        {0, 1, 1},
        {0, 1, 2},
        {1, 2, 1},
        {2, 0, 1},  // closes the cycle 0 -> 1 -> 2 -> 0
    };
    for (const auto& arc : arcs) {
      graph.addArc(arc[0], arc[1], arc[2]);
    }
    draw(graph, "cycles.dot", symbols);
  }
  // Epsilon transitions.
  {
    Graph graph;
    graph.addNode(true);
    graph.addNode();
    graph.addNode(false, true);
    const int arcs[][3] = {{0, 1, 0}, {0, 1, epsilon}, {1, 2, 1}};
    for (const auto& arc : arcs) {
      graph.addArc(arc[0], arc[1], arc[2]);
    }
    draw(graph, "epsilons.dot", symbols);
  }
}
// Simple operations on WFSAs (and WFSTs): union, concatenation and closure.
void simpleOps() {
  // Builds the three-node chain acceptor start -> internal -> accept whose
  // two arcs carry the labels `first` and `second`.
  auto chain3 = [](int first, int second) {
    Graph g;
    g.addNode(true);
    g.addNode();
    g.addNode(false, true);
    g.addArc(0, 1, first);
    g.addArc(1, 2, second);
    return g;
  };

  // Union: the result accepts any sequence accepted by any of the inputs.
  {
    Graph g1 = chain3(0, 1);  // "ab" ...
    g1.addArc(2, 2, 0);       // ... then a self-loop on "a": recognizes "aba*"
    Graph g2 = chain3(1, 0);  // recognizes "ba"
    Graph g3 = chain3(0, 2);  // recognizes "ac"
    draw(g1, "union_g1.dot", symbols);
    draw(g2, "union_g2.dot", symbols);
    draw(g3, "union_g3.dot", symbols);
    auto unioned = union_({g1, g2, g3});
    draw(unioned, "union_graph.dot", symbols);
  }

  // Concatenation: the result accepts every sequence xy where x is accepted
  // by the first graph and y by the second.
  {
    Graph g1 = chain3(1, 0);  // recognizes "ba"
    Graph g2 = chain3(0, 2);  // recognizes "ac"
    draw(g1, "concat_g1.dot", symbols);
    draw(g2, "concat_g2.dot", symbols);
    auto concatenated = concat(g1, g2);
    draw(concatenated, "concat_graph.dot", symbols);
  }

  // Closure: the result accepts any sequence accepted by the original graph
  // repeated zero or more times (zero repeats is the empty sequence,
  // "epsilon").
  {
    // Recognizes "aba"
    Graph g;
    g.addNode(true);
    g.addNode();
    g.addNode();
    g.addNode(false, true);
    g.addArc(0, 1, 0);
    g.addArc(1, 2, 1);
    g.addArc(2, 3, 0);
    draw(g, "closure_input.dot", symbols);
    auto closed = closure(g);
    draw(closed, "closure_graph.dot", symbols);
  }
}
// Intersecting WFSAs
void intersectingAcceptors() {
  // intersect(g1, g2) represents the set of paths present in both acceptors;
  // a path's score in the result is the sum of its scores in the two inputs.

  // g1 accepts the label sequences 0* 1 2*.
  Graph g1;
  g1.addNode(true);
  g1.addNode(false, true);
  g1.addArc(0, 0, 0);
  g1.addArc(0, 1, 1);
  g1.addArc(1, 1, 2);

  // g2 is a four-node chain with one arc per label {0, 1, 2} between each
  // pair of consecutive nodes, i.e. it accepts every length-3 sequence.
  Graph g2;
  g2.addNode(true);
  g2.addNode();
  g2.addNode();
  g2.addNode(false, true);
  for (int src = 0; src < 3; ++src) {
    for (int label = 0; label < 3; ++label) {
      g2.addArc(src, src + 1, label);
    }
  }

  auto intersected = intersect(g1, g2);
  draw(g1, "simple_intersect_g1.dot", symbols);
  draw(g2, "simple_intersect_g2.dot", symbols);
  draw(intersected, "simple_intersect.dot", symbols);
}
// Forwarding WFSAs
void forwardingAcceptors() {
// The forward algorithm computes the log-sum-exp of the scores for all
// accepting paths in a graph. The graph must not have cycles.
Graph graph;
graph.addNode(true);
graph.addNode(true);
graph.addNode();
graph.addNode(false, true);
graph.addArc(0, 1, 0, 0, 1.1);
graph.addArc(0, 2, 1, 1, 3.2);
graph.addArc(1, 2, 2, 2, 1.4);
graph.addArc(2, 3, 0, 0, 2.1);
// The accepting paths are:
// 0 2 0 (nodes 0 -> 1 -> 2 -> 3 and score = 1.1 + 1.4 + 2.1)
// 1 0 (nodes 0 -> 2 -> 3 and score = 3.2 + 2.1)
// 2 0 (nodes 1 -> 2 -> 3 and score = 1.4 + 2.1)
// The final score is the logadd of the individual path scores.
auto forwarded = forwardScore(graph);
// Use Graph::item() to get the score out of a scalar graph:
float score = forwarded.item();
std::cout << "The forward score is: " << score << std::endl;
draw(graph, "simple_forward.dot", symbols);
// The Viterbi algorithm can be used to compute the highest scoring path and
// it's score in the graph.
auto vscore = viterbiScore(graph);
std::cout << "The Viterbi score is: " << vscore.item() << std::endl;
auto vpath = viterbiPath(graph);
draw(vpath, "simple_viterbi_path.dot", symbols);
}
// Differentiable WFSAs
void differentiableAcceptors() {
  // By default a graph will be included in the autograd tape.
  std::stringstream in(
      "0\n"
      "2\n"
      "0 1 0\n"
      "0 1 1\n"
      "1 2 0\n"
      "1 2 1");
  Graph g1 = loadTxt(in);
  // To disable gradient computation for and through a graph, set its
  // calcGrad value to false:
  Graph g2(false);
  g2.addNode(true);
  g2.addNode(false, true);
  g2.addArc(0, 0, 0);
  g2.addArc(0, 1, 1);
  auto a = forwardScore(compose(g1, g2));
  auto b = forwardScore(g1);
  auto loss = subtract(b, a);
  // Differentiate through the computation.
  backward(loss);
  // Access the graph gradient
  Graph grad = g1.grad();
  // The gradients with respect to the input graph arcs are the weights on the
  // arcs of the gradient graph. The index is named `i` so it does not shadow
  // the graph `a` above, and it uses numArcs()'s own type to avoid a
  // signed/unsigned comparison.
  for (decltype(grad.numArcs()) i = 0; i < grad.numArcs(); ++i) {
    std::cout << "Gradient of arc " << i << ": " << grad.weight(i)
              << std::endl;
  }
  // The intermediate graphs a and b also have gradients.
  std::cout << "Gradient of a: " << a.grad().weight(0) << std::endl;
  std::cout << "Gradient of b: " << b.grad().weight(0) << std::endl;
  // If gradient computation is disabled, accessing
  // the gradient throws.
  try {
    g2.grad();
  } catch (const std::logic_error& e) {
    std::cout << e.what() << std::endl;
  }
  // Zero the gradients before re-using the graphs in
  // a new computation, otherwise the gradients will
  // simply accumulate.
  g1.zeroGrad();
}
// An example: The Auto Segmentation Criterion
// https://arxiv.org/abs/1609.03193
void autoSegCriterion() {
  // Size of the token alphabet and number of input frames. The node counts
  // and loop bounds of the emissions and transitions graphs below are derived
  // from these, so they cannot drift out of sync with each other.
  const int numTokens = 3;
  const int numFrames = 4;

  // Consider the ASG alignment graph for the sequence
  // [0, 1, 2]
  Graph fal;
  fal.addNode(true);
  fal.addNode();
  fal.addNode();
  fal.addNode(false, true);
  fal.addArc(0, 1, 0);
  fal.addArc(1, 1, 0);
  fal.addArc(1, 2, 1);
  fal.addArc(2, 2, 1);
  fal.addArc(2, 3, 2);
  fal.addArc(3, 3, 2);
  // The fal graph represents all possible alignments of the sequence
  // where each token occurs one or more times.

  // Now suppose we have an emission graph for an input with numFrames frames:
  // a chain of numFrames + 1 nodes with one arc per token between each pair
  // of consecutive nodes.
  Graph emissions;
  emissions.addNode(true);
  for (int t = 1; t < numFrames; t++) {
    emissions.addNode();
  }
  emissions.addNode(false, true);
  // Loop over time-steps
  for (int t = 0; t < numFrames; t++) {
    // Loop over alphabet
    for (int i = 0; i < numTokens; i++) {
      emissions.addArc(t, t + 1, i);
    }
  }
  // To limit the alignments to length numFrames, we can compose the
  // alignments graph with the emissions graph.
  auto composed = compose(fal, emissions);
  draw(fal, "asg_alignments.dot", symbols);
  draw(emissions, "asg_emissions.dot", symbols);
  draw(composed, "asg_composed.dot", symbols);
  // Compute the asg loss, which is the negative log likelihood:
  //   asg = -(fal - fcc)
  // where fal (forwardScore(composed)) is the constrained score and
  // fcc (forwardScore(emissions)) is the unconstrained score (i.e. the
  // partition function).
  auto loss = subtract(forwardScore(emissions), forwardScore(composed));
  // To get gradients:
  backward(loss);

  // We can also add transitions by making a bigram transition graph. Node 0
  // is the start node; every arc entering node i + 1 carries label i.
  Graph transitions;
  transitions.addNode(true);
  for (int i = 0; i < numTokens; i++) {
    transitions.addNode(false, true);
  }
  for (int i = 1; i <= numTokens; i++) {
    transitions.addArc(0, i, i - 1); // p(i - 1 | <s>)
  }
  for (int i = 0; i < numTokens; i++) {
    for (int j = 0; j < numTokens; j++) {
      transitions.addArc(i + 1, j + 1, j); // p(j | i)
    }
  }
  draw(transitions, "asg_transitions.dot", symbols);
  // Computing the actual asg loss with transitions is as simple
  // as composing with the transition graph:
  auto num_graph = compose(compose(fal, transitions), emissions);
  auto denom_graph = compose(emissions, transitions);
  loss = subtract(forwardScore(denom_graph), forwardScore(num_graph));
  // The order of composition won't affect the results as it is an
  // associative operation. However, just like multiplying matrices,
  // the order of operations can make a big difference in run time.
  // For example:
  //   compose(compose(fal, transitions), emissions)
  // will be much faster than
  //   compose(fal, compose(transitions, emissions))
}
// An example: The CTC Criterion
// https://www.cs.toronto.edu/~graves/icml_2006.pdf
void ctcCriterion() {
  // Symbol table for this example, with the CTC blank "-" at label 0. It is
  // named ctcSymbols so that it does not shadow the file-level `symbols`.
  std::unordered_map<int, std::string> ctcSymbols = {
      {0, "-"}, {1, "a"}, {2, "b"}};
  // Alphabet size (including the blank) and number of input frames. The node
  // counts and loop bounds below are derived from these.
  const int numTokens = 3;
  const int numFrames = 4;

  // Consider the CTC alignment graph for the sequence
  // [1, 2] where the blank index is 0.
  Graph ctc;
  ctc.addNode(true);
  ctc.addNode();
  ctc.addNode();
  ctc.addNode(false, true);
  ctc.addNode(false, true);
  ctc.addArc(0, 0, 0);
  ctc.addArc(0, 1, 1);
  ctc.addArc(1, 1, 1);
  ctc.addArc(1, 2, 0);
  ctc.addArc(1, 3, 2);
  ctc.addArc(2, 2, 0);
  ctc.addArc(2, 3, 2);
  ctc.addArc(3, 3, 2);
  ctc.addArc(3, 4, 0);
  ctc.addArc(4, 4, 0);
  // The ctc graph represents all possible alignments of the sequence
  // where each token occurs one or more times, with zero or more blank
  // tokens before, between and after them.

  // Now suppose we have an emission graph for an input with numFrames frames:
  // a chain of numFrames + 1 nodes with one arc per token between each pair
  // of consecutive nodes.
  Graph emissions;
  emissions.addNode(true);
  for (int t = 1; t < numFrames; t++) {
    emissions.addNode();
  }
  emissions.addNode(false, true);
  // Loop over time-steps
  for (int t = 0; t < numFrames; t++) {
    // Loop over alphabet (including blank)
    for (int i = 0; i < numTokens; i++) {
      emissions.addArc(t, t + 1, i);
    }
  }
  // To limit the ctc graph to alignments of length numFrames, we can compose
  // it with the emissions graph.
  auto composed = compose(ctc, emissions);
  draw(ctc, "ctc_alignments.dot", ctcSymbols);
  draw(emissions, "ctc_emissions.dot", ctcSymbols);
  draw(composed, "ctc_composed.dot", ctcSymbols);
  // Compute the ctc loss
  auto loss = subtract(forwardScore(emissions), forwardScore(composed));
  // In practice, without transitions, we can
  // normalize per frame scores and only compute
  //   loss = negate(forwardScore(composed));

  // We can also add transitions to CTC just like in ASG! Node 0 is the start
  // node; every arc entering node i + 1 carries label i.
  Graph transitions;
  transitions.addNode(true);
  for (int i = 0; i < numTokens; i++) {
    transitions.addNode(false, true);
  }
  for (int i = 1; i <= numTokens; i++) {
    transitions.addArc(0, i, i - 1); // p(i - 1 | <s>)
  }
  for (int i = 0; i < numTokens; i++) {
    for (int j = 0; j < numTokens; j++) {
      transitions.addArc(i + 1, j + 1, j); // p(j | i)
    }
  }
  // Computing the ctc loss is identical to computing the asg loss,
  // the only difference is the alignment graph (ctc instead of asg).
  auto num_graph = compose(compose(ctc, transitions), emissions);
  auto denom_graph = compose(emissions, transitions);
  loss = subtract(forwardScore(denom_graph), forwardScore(num_graph));
}
// Build a small weighted finite-state transducer (WFST) and draw it with
// separate input and output symbol tables.
void simpleTransducers() {
  Graph graph;
  graph.addNode(true);
  graph.addNode();
  graph.addNode(false, true);
  // Adding an arc with just an input label: the output label defaults to the
  // same value as the input label
  graph.addArc(0, 1, 0);
  // Add an arc from node 0 to 1 with the same input and output label of 1
  graph.addArc(0, 1, 1, 1);
  // Arcs may also have different input and output labels: this arc goes from
  // node 1 to 2 with input label 1 and output label 2
  graph.addArc(1, 2, 1, 2);
  // Specify the input and output symbols
  draw(graph, "simple_fst.dot", isymbols, osymbols);
}
// Composing WFSTs
void composingTransducers() {
  // compose(g1, g2) represents every pair of paths in which the output
  // labelling of the path in g1 matches the input labelling of the path in
  // g2. A resulting path is labelled with the input labelling of the g1 path
  // and the output labelling of the g2 path, and its score is the sum of the
  // scores of the two paths.
  Graph g1;
  g1.addNode(true);
  g1.addNode(false, true);
  // Each row is {source, destination, input label, output label}.
  const int g1Arcs[][4] = {{0, 0, 0, 0}, {0, 1, 1, 1}, {1, 1, 2, 2}};
  for (const auto& arc : g1Arcs) {
    g1.addArc(arc[0], arc[1], arc[2], arc[3]);
  }

  Graph g2;
  g2.addNode(true);
  g2.addNode();
  g2.addNode();
  g2.addNode(false, true);
  const int g2Arcs[][4] = {
      {0, 1, 0, 0}, {0, 1, 0, 1}, {0, 1, 1, 2},
      {1, 2, 0, 0}, {1, 2, 1, 1}, {1, 2, 2, 2},
      {2, 3, 1, 0}, {2, 3, 2, 1}, {2, 3, 2, 2},
  };
  for (const auto& arc : g2Arcs) {
    g2.addArc(arc[0], arc[1], arc[2], arc[3]);
  }

  // The output alphabet of g1 is assumed to be the input alphabet of g2,
  // hence g1 is drawn with (isymbols, osymbols) and g2 with
  // (osymbols, isymbols). Composing or intersecting two acceptors commutes,
  // but composing a transducer with another transducer or with an acceptor
  // does not.
  auto composed = compose(g1, g2);
  draw(g1, "transducer_compose_g1.dot", isymbols, osymbols);
  draw(g2, "transducer_compose_g2.dot", osymbols, isymbols);
  draw(composed, "transducer_compose.dot", isymbols, isymbols);
}
// WFSTs with epsilons
void epsilonTransitions() {
// Transducers or acceptors can have epsilon transitions.
Graph g1;
g1.addNode(true);
g1.addNode();
g1.addNode(false, true);
// Use epsilon to denote an epsilon label (the integer value is -1,
// though you should avoid using that directly to make your code more future
// proof).
g1.addArc(0, 1, 1, epsilon, 1.1);
g1.addArc(1, 2, 0, 0, 2);
// We can forward graphs with epsilons (as long as they don't have any
// cycles).
forwardScore(g1);
g1.addArc(0, 0, 0, epsilon, 0.5);
// Drawing will use a special "ε" token to represent
// `epsilon` when symbols are specified.
draw(g1, "epsilon_graph1.dot", isymbols, osymbols);
Graph g2;
g2.addNode(true);
g2.addNode(false, true);
g2.addArc(0, 1, 0, 0, 1.3);
g2.addArc(1, 1, epsilon, 2, 2.5);
draw(g2, "epsilon_graph2.dot", osymbols, isymbols);
// We can compose graphs with epsilons
auto composed = compose(g1, g2);
draw(composed, "epsilon_composed.dot", isymbols, isymbols);
// For a detailed discussion on composition with epsilon transitions see
// "Weighted Automata Algorithms", Mehryar Mohri,
// https://cs.nyu.edu/~mohri/pub/hwa.pdf, Section 5.1
}
int main() {
  // Every tutorial section, in the order it is run.
  void (*const sections[])() = {
      simpleAcceptors,
      interestingAcceptors,
      simpleOps,
      intersectingAcceptors,
      forwardingAcceptors,
      differentiableAcceptors,
      autoSegCriterion,
      ctcCriterion,
      simpleTransducers,
      composingTransducers,
      epsilonTransitions,
  };
  for (auto section : sections) {
    section();
  }
}