forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
latticeforwardbackward.cpp
1437 lines (1297 loc) · 67.4 KB
/
latticeforwardbackward.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// latticearchive.cpp -- managing lattice archives
//
// F. Seide, V-hansu
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#endif
#include "Basics.h"
#include "simple_checked_arrays.h"
#include "latticearchive.h"
#include "simplesenonehmm.h" // the model
#include "ssematrix.h" // the matrices
#include "latticestorage.h"
#include <unordered_map>
#include <list>
#include <stdexcept>
using namespace std;
#define VIRGINLOGZERO (10 * LOGZERO) // used for printing statistics on unseen states
#undef CPU_VERIFICATION
#ifdef _WIN32
// NUMA node override for numahelpers.h; -1 means "no override" (use default NUMA placement)
int msra::numa::node_override = -1; // for numahelpers.h
#endif
namespace msra { namespace lattices {
// ---------------------------------------------------------------------------
// helper class for allocation lots of small matrices, no free
// ---------------------------------------------------------------------------
// Bump allocator for lots of small matrices; individual matrices are never freed.
// Raw float storage lives in chunks held by a std::list (so chunk addresses are
// stable), and matrix headers live in a vector whose capacity is fixed at
// construction time (it must never reallocate, since newmatrix() hands out
// references into it).
class littlematrixheap
{
    static const size_t CHUNKSIZE;
    typedef msra::math::ssematrixfrombuffer matrixfrombuffer;
    std::list<std::vector<float>> heap;     // raw float chunks; list keeps addresses stable across growth
    size_t allocatedinlast;                 // elements already handed out from the last chunk
    size_t totalallocated;                  // total elements handed out (statistics only)
    std::vector<matrixfrombuffer> matrices; // matrix headers; capacity is fixed at construction

public:
    // 'estimatednumentries' must be an upper bound on the number of newmatrix()
    // calls; the header vector is reserved once and must never grow afterwards.
    littlematrixheap(size_t estimatednumentries)
        : totalallocated(0), allocatedinlast(0)
    {
        matrices.reserve(estimatednumentries + 1);
    }
    // Allocate a (rows x cols) matrix from the heap and return a reference to it.
    // The returned reference stays valid for the lifetime of the heap.
    msra::math::ssematrixbase &newmatrix(size_t rows, size_t cols)
    {
        const size_t elementsneeded = matrixfrombuffer::elementsneeded(rows, cols);
        if (heap.empty() || (heap.back().size() - allocatedinlast) < elementsneeded)
        {
            const size_t nelem = max(CHUNKSIZE, elementsneeded + 3 /*+3 for SSE alignment*/);
            heap.push_back(std::vector<float>(nelem));
            allocatedinlast = 0;
            // make sure starting element is SSE-aligned (the constructor demands that)
            const size_t offelem = (((size_t) &heap.back()[allocatedinlast]) / sizeof(float)) % 4;
            if (offelem != 0)
                allocatedinlast += 4 - offelem;
        }
        auto &buffer = heap.back();
        if (elementsneeded > heap.back().size() - allocatedinlast)
            LogicError("newmatrix: allocation logic screwed up");
        // get our buffer into a handy vector-like thingy
        array_ref<float> vecbuffer(&buffer[allocatedinlast], elementsneeded);
        // allocate in the current heap location
        // Check BEFORE growing: growing past the reserved capacity would reallocate
        // 'matrices' and silently invalidate every reference handed out so far.
        // (Previously this was checked after resize(), i.e. after the damage was done.)
        if (matrices.size() + 1 > matrices.capacity())
            LogicError("newmatrix: littlematrixheap cannot grow but was constructed with too small number of elements");
        matrices.resize(matrices.size() + 1);
        auto &matrix = matrices.back();
        matrix = matrixfrombuffer(vecbuffer, rows, cols);
        allocatedinlast += elementsneeded;
        totalallocated += elementsneeded;
        return matrix;
    }
};
const size_t littlematrixheap::CHUNKSIZE = 256 * 1024; // 256K floats = 1 MB per chunk
// ---------------------------------------------------------------------------
// helpers for log-domain addition
// ---------------------------------------------------------------------------
#ifndef LOGZERO
#define LOGZERO -1e30f
#endif
// building block of logadd(): loga <- loga + log (1 + exp (diff)),
// where diff = logb - loga <= 0 is the (log) ratio of the smaller to the bigger term
static void logaddratio(float &loga, float diff)
{
    // contributions below ~2^-24 vanish in a 23-bit mantissa, so skip the transcendentals
    if (diff < -17.0f)
        return;
    loga += logf(1.0f + expf(diff));
}
// double flavor: contributions below ~2^-53 cannot change a 52-bit mantissa
static void logaddratio(double &loga, double diff)
{
    if (diff < -37.0)
        return;
    loga += log(1.0 + exp(diff));
}
// log-domain addition: loga <- log (exp (loga) + exp (logb))
// = loga + log (1 + exp (logb - loga)), evaluated on the larger of the two for stability
template <typename FLOAT>
static void logadd(FLOAT &loga, FLOAT logb)
{
    // make loga the larger operand; adding the smaller into the bigger is numerically stable
    if (logb > loga)
        ::swap(loga, logb);
    // if even the larger one represents 0, the sum is 0 as well
    if (loga <= LOGZERO)
        return;
    logaddratio(loga, logb - loga);
}
// cheap max-approximation of logadd(): loga <- max (loga, logb) (for testing)
template <typename FLOAT>
static void logmax(FLOAT &loga, FLOAT logb)
{
    loga = (logb > loga) ? logb : loga;
}
// exp(a) - exp(b), computed stably by factoring out the larger exponent (for testing)
// Both branches evaluate the same signed difference; the factoring avoids overflow.
template <typename FLOAT>
static FLOAT expdiff(FLOAT a, FLOAT b)
{
    if (a >= b)
        return exp(a) * (1 - exp(b - a));
    else
        return exp(b) * (exp(a - b) - 1);
}
// is this log-domain value to be treated as (log of) zero probability?
// Anything below LOGZERO/2 counts, which tolerates offsets accumulated on top of LOGZERO.
template <typename FLOAT>
static bool islogzero(FLOAT v)
{
    return v < LOGZERO / 2;
}
// ---------------------------------------------------------------------------
// other helpers go here
// ---------------------------------------------------------------------------
// reconstruct the phonetic transcript of an edge as a blank-separated string of HMM names
/*static*/ std::string lattice::gettranscript(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset)
{
    std::string trans;
    for (size_t k = 0; k < units.size(); k++) // units have fixed boundaries
    {
        if (k > 0)
            trans.push_back(' '); // blank separator between unit names
        trans.append(hset.gethmm(units[k].unit).getname());
    }
    return trans;
}
// ---------------------------------------------------------------------------
// forwardbackwardedge() -- perform state-level forward-backward on a single lattice edge
//
// Results:
// - gammas(j,t) for valid time ranges (remaining areas are not initialized)
// - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------
/*static*/ float lattice::forwardbackwardedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                              msra::math::ssematrixbase &loggammas, size_t edgeindex)
{
    // alphas and betas are stored in-place inside the loggammas matrix, shifted by one resp. two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logalphas(loggammas, 1, logLLs.cols()); // shifted views into gammas(,) for alphas and betas
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logbetas(loggammas, 2, logLLs.cols());
    // alphas(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
    // betas(j,t) store the sum of all paths exiting from state j at time t, not including logLL(j,t)
    // gammas(j,t) = alphas(j,t) * betas(j,t) / totalLL
    // backward pass --token passing
    // [ts,te) is the frame range and [js,je) the state-row range of the unit being processed
    size_t te = logbetas.cols();
    size_t je = logbetas.rows();
    float bwscore = 0.0f; // backward score passed across unit (phone) boundaries, right to left
    for (size_t k = units.size() - 1; k + 1 > 0; k--) // 'k + 1 > 0' avoids underflow of unsigned k
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t ts = te - units[k].frames; // start time of current unit
        const size_t js = je - n;               // first state row of current unit
        // pass in the transition score
        // t = ts: exit transition (last frame only or tee transition)
        float exitscore = 1e30f; // (something impossible)
        if (te == ts) // tee transition
        {
            exitscore = bwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from; // origin trellis node
                logbetas(i, te - 1) = bwscore + transP(from, n);
            }
        }
        // expand from states j at time t (not yet including LL) to time t-1
        for (size_t t = te - 1; t + 1 > ts /*note: cannot test t >= ts because t < 0 possible*/; t--)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to;             // source trellis node
                const size_t s = hmm.getsenoneid(to); // senone id for state at position 'to' in the HMM
                const float acLL = logLLs(s, t);
                if (islogzero(acLL))
                    fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) ac score(%d,%d) is zero (%d st, %d fr: %s)\n",
                            (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te,
                            (int) s, (int) t,
                            (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
                const float betajt = logbetas(j, t);   // sum over all paths exiting from (j,t) to end
                const float betajtpll = betajt + acLL; // incorporate acoustic score
                if (t > ts)                            // regular frame: propagate to all states of this unit at t-1
                    for (size_t from = 0 /*no transition from entry state*/; from < n; from++)
                    {
                        const size_t i = js + from; // target trellis node
                        const float pathscore = betajtpll + transP(from, to);
                        if (to == 0) // first 'to' initializes, subsequent ones log-accumulate
                            logbetas(i, t - 1 /*propagate into preceding frame*/) = pathscore;
                        else
                            logadd(logbetas(i, t - 1 /*propagate into preceding frame*/), pathscore);
                    }
                else // transition to entry state
                {
                    const float pathscore = betajtpll + transP(-1, to);
                    if (to == 0)
                        exitscore = pathscore;
                    else
                        logadd(exitscore, pathscore); // propagate into preceding unit
                }
            }
        }
        bwscore = exitscore;
        if (islogzero(bwscore))
            fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) bw score is zero (%d st, %d fr: %s)\n",
                    (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te, (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
        te = ts; // move on to the preceding unit
        je = js;
    }
    assert(te == 0 && je == 0);
    const float totalbwscore = bwscore;
    // forward pass --regular Viterbi
    // This also computes the gammas right away.
    size_t ts = 0;        // start frame for unit 'k'
    size_t js = 0;        // first row index of unit 'k'
    float fwscore = 0.0f; // score passed across phone boundaries
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te = ts + units[k].frames; // end time of current unit
        const size_t je = js + n;               // range of state indices
        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to; // target trellis node
                const size_t s = hmm.getsenoneid(to);
                const float acLL = logLLs(s, t);
                float alphajtnoll = LOGZERO; // alpha(j,t) before adding the acoustic LL
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    alphajtnoll = pathscore;
                }
                else
                    for (size_t from = 0 /*no entering possible*/; from < n; from++)
                    {
                        const size_t i = js + from; // origin trellis node
                        const float alphaitm1 = logalphas(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        logadd(alphajtnoll, pathscore);
                    }
                logalphas(j, t) = alphajtnoll + acLL;
            }
            // update the gammas --do it here because in next frame, betas get overwritten by alphas (they share memory)
            for (size_t j = js; j < je; j++)
            {
                if (!islogzero(totalbwscore))
                    loggammas(j, t) = logalphas(j, t) + logbetas(j, t) - totalbwscore;
                else // 0/0 problem, can occur if an ac score is so bad that it is 0 after going through softmax
                    loggammas(j, t) = LOGZERO;
            }
        }
        // t = te: exit transition (last frame only or tee transition)
        float exitscore;
        if (te == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            exitscore = LOGZERO;
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from;                   // origin trellis node
                const float alphaitm1 = logalphas(i, te - 1); // newly computed path score, transiting to t=te
                const float pathscore = alphaitm1 + transP(from, n);
                logadd(exitscore, pathscore);
            }
        }
        fwscore = exitscore; // score passed on to next unit
        js = je;
        ts = te;
    }
    assert(js == logalphas.rows() && ts == logalphas.cols());
    const float totalfwscore = fwscore;
    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(totalbwscore) ^ islogzero(totalfwscore)) // fw/bw disagreeing on zero-ness indicates a numeric problem
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw 0 score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
    if (islogzero(totalbwscore))
    {
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }
    if (fabsf(totalfwscore - totalbwscore) / ts > 1e-4f) // fw and bw scores should agree up to per-frame round-off
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
    // we return the full path score
    return totalfwscore;
}
// ---------------------------------------------------------------------------
// alignedge() -- perform Viterbi alignment on a single edge
//
// This is an alternative to forwardbackwardedge() that just uses the best path.
// Results:
// - if not returnsenoneids -> 'binary gammas(j,t)' for valid time ranges (remaining areas are not initialized); MMI-compatible
// - if returnsenoneids -> loggammas(0,t) will contain the senone ids directly instead (for sMBR mode)
// - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------
/*static*/ float lattice::alignedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                    msra::math::ssematrixbase &loggammas, size_t edgeindex /*for diagnostic messages*/, const bool returnsenoneids,
                                    array_ref<unsigned short> thisedgealignmentsj)
{
    // backpointers and pathscores are stored in-place inside the loggammas matrix, shifted by zero resp. two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> backpointers(loggammas, 0, logLLs.cols());
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> pathscores(loggammas, 2, logLLs.cols());
    // pathscores(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
    // backpointers(j,t) are the relative states that it came from (stored as floats inside the matrix)
    // gammas(j,t) <- 1 if on best path, 0 otherwise
    const int invalidbp = -2; // marker for "backpointer not set"
    // Viterbi alignment
    size_t ts = 0;          // start frame for unit 'k'
    size_t js = 0;          // first row index of unit 'k'
    float fwscore = 0.0f;   // score passed across phone boundaries
    int fwbackpointer = -1; // bp passed across phone boundaries, -1 means start of utterance
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te = ts + units[k].frames;    // end time of current unit
        const size_t je = js + hmm.getnumstates(); // range of state indices
        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t j = js; j < je; j++)
            {
                const size_t to = j - js; // relative state
                const size_t s = hmm.getsenoneid(to);
                pathscores(j, t) = LOGZERO;
                backpointers(j, t) = invalidbp;
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    pathscores(j, t) = pathscore;
                    backpointers(j, t) = (float) fwbackpointer;
                }
                else // Viterbi: take max over all predecessor states of this unit
                    for (size_t i = js; i < je; i++)
                    {
                        const size_t from = i - js;
                        const float alphaitm1 = pathscores(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        if (pathscore > pathscores(j, t))
                        {
                            pathscores(j, t) = pathscore;
                            backpointers(j, t) = (float) i;
                        }
                    }
                const float acLL = logLLs(s, t);
                pathscores(j, t) += acLL;
            }
        }
        // t = te: exit transition (last frame only or tee transition)
        float exitscore = LOGZERO;
        int exitbackpointer = invalidbp;
        if (te == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
            exitbackpointer = fwbackpointer;
        }
        else // not tee: expand all last states
        {
            for (size_t i = js; i < je; i++)
            {
                const size_t from = i - js;
                const float alphaitm1 = pathscores(i, te - 1); // newly computed path score, transiting to t=te
                const float pathscore = alphaitm1 + transP(from, n);
                if (pathscore > exitscore)
                {
                    exitscore = pathscore;
                    exitbackpointer = (int) i;
                }
            }
        }
        if (exitbackpointer == invalidbp)
            LogicError("exitbackpointer came up empty");
        fwscore = exitscore;             // score passed on to next unit
        fwbackpointer = exitbackpointer; // and accompanying backpointer
        js = je;
        ts = te;
    }
    assert(js == pathscores.rows() && ts == pathscores.cols());
    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(fwscore))
    {
        fprintf(stderr, "alignedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }
    // traceback & gamma update
    size_t te = backpointers.cols();
    size_t je = backpointers.rows();
    int j = fwbackpointer; // state index to trace back from, frame by frame
    for (size_t k = units.size() - 1; k + 1 > 0; k--) // go in units because we also need to clear out the column
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t ts = te - units[k].frames;    // start time of current unit
        const size_t js = je - hmm.getnumstates(); // range of state indices
        for (size_t t = te - 1; t + 1 > ts; t--)
        {
            if (j < (int) js || j >= (int) je)
                LogicError("invalid backpointer resulting in state index out of range");
            int bp = (int) backpointers(j, t); // save the backpointer before overwriting it (gammas and backpointers are aliases of each other)
            // thisedgealignmentsj[t] = (unsigned short)hmm.getsenoneid(j - js);
            if (!returnsenoneids) // return binary gammas (for MMI; this mode is compatible with softalignmode)
                for (size_t i = js; i < je; i++)
                    loggammas(i, t) = ((int) i == j) ? 0.0f : LOGZERO;
            else // return senone id (for sMBR; note: NOT compatible with softalignmode; calling code must know this)
                thisedgealignmentsj[t] = (unsigned short) hmm.getsenoneid(j - js);
            if (bp == invalidbp)
                LogicError("deltabackpointer not initialized");
            j = bp; // trace back one step
        }
        te = ts;
        je = js;
    }
    if (j != -1)
        LogicError("invalid backpointer resulting in not reaching start of utterance when tracing back");
    assert(je == 0 && te == 0);
    // we return the full path score
    return fwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlattice() -- lattice-level forward/backward
//
// This computes word posteriors, and also returns the per-node alphas and betas.
// Per-edge acoustic scores are passed in via a lambda, as this function is
// intended for use at multiple places with different scores.
// (Specifically, we also use it to determine a pruning threshold, based on
// the original lattice's ac. scores, before even bothering to compute the
// new ac. scores.)
// ---------------------------------------------------------------------------
// Returns the total forward log score of the lattice and fills logpps (per-edge log
// posteriors) and per-node logalphas/logbetas. In sMBR mode additionally computes
// expected frames-correct statistics (logEframescorrect, Eframescorrectbuf,
// logEframescorrecttotal). Returns LOGZERO if no path exists.
double lattice::forwardbackwardlattice(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<double> &logpps,
                                       std::vector<double> &logalphas, std::vector<double> &logbetas,
                                       const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode,
                                       const_array_ref<size_t> &uids, const edgealignments &thisedgealignments,
                                       std::vector<double> &logEframescorrect, std::vector<double> &Eframescorrectbuf, double &logEframescorrecttotal) const
{ // ^^ TODO: remove this
    // --- hand off to parallelized (CUDA) implementation if available
    if (parallelstate.enabled())
    {
        double totalfwscore = parallelforwardbackwardlattice(parallelstate, edgeacscores, thisedgealignments, lmf, wp, amf, boostingfactor, logpps, logalphas, logbetas, sMBRmode, uids, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
        return totalfwscore;
    }
    // if we get here, we have no CUDA, and do it the good ol' way
    // allocate return values
    logpps.resize(edges.size()); // this is our primary return value
    // TODO: these are return values as well, but really shouldn't anymore; only used in some older baseline code we some day may want to compare against
    logalphas.assign(nodes.size(), LOGZERO);
    logalphas.front() = 0.0f; // start node
    logbetas.assign(nodes.size(), LOGZERO);
    logbetas.back() = 0.0f; // end node
    // --- sMBR version
    if (sMBRmode)
    {
        logEframescorrect.resize(edges.size());
        Eframescorrectbuf.resize(edges.size());
        std::vector<double> logaccalphas(nodes.size(), LOGZERO); // [i] expected frames-correct count over all paths from start to node i
        std::vector<double> logaccbetas(nodes.size(), LOGZERO);  // [i] likewise
        std::vector<double> logframescorrectedge(edges.size());  // raw counts of correct frames in each edge
        // forward pass (edges are assumed sorted so that e.S precedes e.E)
        foreach_index (j, edges)
        {
            if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
                continue;
            const auto &e = edges[j];
            const double inscore = logalphas[e.S];
            const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // LM score + word penalty + ac. score, scaled
            const double pathscore = inscore + edgescore;
            logadd(logalphas[e.E], pathscore);
            size_t ts = nodes[e.S].t;
            size_t te = nodes[e.E].t;
            size_t framescorrect = 0; // count raw number of correct frames
            for (size_t t = ts; t < te; t++)
                framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
            logframescorrectedge[j] = (framescorrect > 0) ? log((double) framescorrect) : LOGZERO; // remember for backward pass
            double loginaccs = logaccalphas[e.S] - logalphas[e.S]; // frames-correct accumulator into e.S, normalized by its path score
            logadd(loginaccs, logframescorrectedge[j]);
            double logpathacc = loginaccs + logalphas[e.S] + edgescore;
            logadd(logaccalphas[e.E], logpathacc);
        }
        foreach_index (j, logaccalphas)
            logaccalphas[j] -= logalphas[j]; // normalize accumulators by path scores
        const double totalfwscore = logalphas.back();
        const double totalfwacc = logaccalphas.back();
        if (islogzero(totalfwscore))
        {
            fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
            return LOGZERO; // failed, do not use resulting matrix
        }
        // backward pass and computation of state-conditioned frames-correct count
        for (size_t j = edges.size() - 1; j + 1 > 0; j--)
        {
            if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
                continue;
            const auto &e = edges[j];
            const double inscore = logbetas[e.E];
            const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
            const double pathscore = inscore + edgescore;
            logadd(logbetas[e.S], pathscore);
            double loginaccs = logaccbetas[e.E] - logbetas[e.E];
            logadd(loginaccs, logframescorrectedge[j]);
            double logpathacc = loginaccs + logbetas[e.E] + edgescore;
            logadd(logaccbetas[e.S], logpathacc);
            // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
            double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
            if (logpp > 1e-2) // posterior > 1 signals numeric trouble; clamped to 0 below
                fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
            if (logpp > 0.0)
                logpp = 0.0;
            logpps[j] = logpp;
            double tmplogeframecorrect = logframescorrectedge[j];
            logadd(tmplogeframecorrect, logaccalphas[e.S]);
            logadd(tmplogeframecorrect, logaccbetas[e.E] - logbetas[e.E]);
            Eframescorrectbuf[j] = exp(tmplogeframecorrect);
        }
        foreach_index (j, logaccbetas)
            logaccbetas[j] -= logbetas[j];
        const double totalbwscore = logbetas.front();
        const double totalbwacc = logaccbetas.front();
        if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4) // fw/bw must agree up to per-frame round-off
            fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());
        if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
            fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
        logEframescorrecttotal = totalbwacc;
        return totalbwscore;
    }
    // --- MMI version
    // forward pass
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        const double inscore = logalphas[e.S];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        const double pathscore = inscore + edgescore;
        logadd(logalphas[e.E], pathscore);
    }
    const double totalfwscore = logalphas.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }
    // backward pass
    // this also computes the word posteriors on the fly, since we are at it
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        const auto &e = edges[j];
        const double inscore = logbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        const double pathscore = inscore + edgescore;
        logadd(logbetas[e.S], pathscore);
        // compute lattice posteriors on the fly since we are at it
        double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
        if (logpp > 1e-2)
            fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
        if (logpp > 0.0)
            logpp = 0.0;
        logpps[j] = logpp;
    }
    const double totalbwscore = logbetas.front();
    if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());
    return totalfwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlatticesMBR() -- compute expected frame-accuracy counts,
// both the conditioned one (corresponding to c(q) in Dan Povey's thesis)
// and the global one (which is the sMBR criterion to optimize).
//
// Outputs:
// - Eframescorrect[j] == expected frames-correct count conditioned on a state of edge[j].
// We currently assume a hard state alignment. With that, the value turns out
// to be identical for all states of an edge, so we only store it once per edge.
// - return value: expected frames-correct count for entire lattice
//
// Call forwardbackwardlattices() first to compute logalphas/betas.
// ---------------------------------------------------------------------------
double lattice::forwardbackwardlatticesMBR(const std::vector<float> &edgeacscores, const msra::asr::simplesenonehmm &hset,
                                           const std::vector<double> &logalphas, const std::vector<double> &logbetas,
                                           const float lmf, const float wp, const float amf, const_array_ref<size_t> &uids,
                                           const edgealignments &thisedgealignments, std::vector<double> &Eframescorrect) const
{
    std::vector<double> accalphas(nodes.size(), 0);            // [i] expected frames-correct count over all paths from start to node i
    std::vector<double> accbetas(nodes.size(), 0);             // [i] likewise, from node i to end
    std::vector<size_t> maxcorrect(nodes.size(), 0);           // [i] max correct frames up to this node (oracle)
    std::vector<double> framescorrectedge(edges.size());       // raw counts of correct frames in each edge
    std::vector<int> backpointersformaxcorr(nodes.size(), -2); // keep track of backpointer for the max corr
    backpointersformaxcorr.front() = -1; // start node has no predecessor
    // forward pass (edges assumed sorted so that e.S precedes e.E)
    foreach_index (j, edges)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accalphas[e.S];
        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;
        size_t framescorrect = 0; // count raw number of correct frames
        for (size_t t = ts; t < te; t++)
            framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
        framescorrectedge[j] = (double) framescorrect; // remember for backward pass
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        // contribution to end node's path acc = start node's plus edge's correct count, weighted by LL, and divided by sum over LLs
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logalphas[e.S] + edgescore - logalphas[e.E]);
        accalphas[e.E] += pathacc;
        // also keep track of max accuracy, so we can find out whether the lattice contains the correct path
        size_t oracleframescorrect = maxcorrect[e.S] + framescorrect; // keep track of most correct path up to end of this edge
        if (oracleframescorrect > maxcorrect[e.E])
        {
            maxcorrect[e.E] = oracleframescorrect;
            backpointersformaxcorr[size_t(e.E)] = j;
        }
    }
    const double totalfwacc = accalphas.back();
    hset; // just for reference (suppresses unused-parameter warning)
    // report on ground-truth path
    // TODO: we will later have code that adds this path if needed
    size_t oracleframeacc = maxcorrect.back();
    if (oracleframeacc != info.numframes)
        fprintf(stderr, "forwardbackwardlatticesMBR: ground-truth path missing from lattice (most correct path: %d out of %d frames correct)\n", (unsigned int) oracleframeacc, (int) info.numframes);
    // backward pass and computation of state-conditioned frames-correct count
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logbetas[e.E] + edgescore - logbetas[e.S]);
        accbetas[e.S] += pathacc;
        // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
        Eframescorrect[j] = (float) (accalphas[e.S] + accbetas[e.E] + framescorrectedge[j]);
    }
    const double totalbwacc = accbetas.front();
    if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4) // fw/bw accs should agree up to round-off
        fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
    return totalbwacc;
}
// ---------------------------------------------------------------------------
// bestpathlattice() -- lattice-level "forward/backward" that only returns the
// best path, but in the form of word posteriors, which are 1 or 0, just like
// a real lattice-level forward/backward would do.
// We don't really use this; this was only for a contrast experiment.
// ---------------------------------------------------------------------------
double lattice::bestpathlattice(const std::vector<float> &edgeacscores, std::vector<double> &logpps,
                                const float lmf, const float wp, const float amf) const
{
    // forward pass --edges are stored in sorted (topological) order, so a single
    // linear sweep implements regular Viterbi
    std::vector<double> bestscore(nodes.size(), LOGZERO); // [node] best path score from start node
    std::vector<int> bestedge(nodes.size(), -2);          // [node] incoming edge of best path (-2 = unreached)
    bestscore.front() = 0.0f;
    bestedge.front() = -1; // start node terminates the traceback
    for (int j = 0; j < (int) edges.size(); j++)
    {
        const auto &e = edges[j];
        const double startscore = bestscore[e.S];
        // note: edgeacscores[j] == LOGZERO if edge was pruned
        const double transitionscore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        const double candidate = startscore + transitionscore;
        if (candidate > bestscore[e.E])
        {
            bestscore[e.E] = candidate;
            bestedge[e.E] = j;
        }
    }
    const double totalfwscore = bestscore.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "bestpathlattice: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }
    // traceback
    // We encode the result by storing log 1 in edges on the best path, and log 0 else;
    // this makes it naturally compatible with softalign mode
    logpps.assign(edges.size(), LOGZERO);
    int trace = bestedge[nodes.size() - 1];
    while (trace >= 0)
    {
        logpps[trace] = 0.0f; // edge is on best path -> PP = 1.0
        trace = bestedge[edges[trace].S];
    }
    assert(trace == -1); // must have ended at the start node, not at an unreached one (-2)
    return totalfwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardalign() -- compute the state-level gammas or Viterbi alignments;
// the first phase of lattice::forwardbackward
//
// Outputs: edgeacscores (per-edge acoustic scores), thisedgealignments (per-edge
// state alignments), abcs (per-edge alpha/beta/gamma matrices where allocated),
// and uids (frame-level reference state ids when 'bounds' is given)
// ---------------------------------------------------------------------------
// Compute per-edge acoustic scores and state-level alignments (or soft gammas)
// for every lattice edge, on CPU, GPU, or both (for verification).
// Fixes vs. previous revision: corrected the misspelled RuntimeError message
// ("unxpected" -> "unexpected") and initialized 'refabcs' to NULL for safety.
void lattice::forwardbackwardalign(parallelstate &parallelstate,
                                   const msra::asr::simplesenonehmm &hset, const bool softalignstates,
                                   const double minlogpp, const std::vector<double> &origlogpps,
                                   std::vector<msra::math::ssematrixbase *> &abcs, littlematrixheap &matrixheap,
                                   const bool returnsenoneids,
                                   std::vector<float> &edgeacscores, const msra::math::ssematrixbase &logLLs,
                                   edgealignments &thisedgealignments, backpointers &thisbackpointers, array_ref<size_t> &uids, const_array_ref<size_t> bounds) const
{ // NOTE: this will be removed and replaced by a proper representation of alignments someday
    // do forward-backward or alignment on a per-edge basis. This gives us:
    // - per-edge gamma[j,t] = P(s(t)==s_j|edge) if forwardbackward, per-edge alignment thisedgealignments[j] if alignment
    // - per-edge acoustic scores
    const size_t silunitid = hset.gethmmid("sil"); // shall be the same as parallelstate.getsilunitid()
    bool parallelsil = true;        // whether silence edges are processed by the parallel (GPU) path
    bool cpuverification = false;   // whether to recompute on CPU and compare against GPU results
#ifndef PARALLEL_SIL // we use a define to make this marked
    parallelsil = false;
#endif
#ifdef CPU_VERIFICATION
    cpuverification = true;
#endif
    // Phase 1: abcs allocate
    if (!parallelstate.enabled() || !parallelsil || cpuverification) // allocate abcs when 1.parallelstate not enabled (cpu mode); 2. enabled but not PARALLEL_SIL (silence need to be allocate); 3. cpuverfication
    {
        abcs.resize(edges.size(), NULL); // [edge index] -> alpha/beta/gamma matrices for each edge
        size_t countskip = 0; // if pruning: count how many edges are pruned
        foreach_index (j, edges)
        {
            // determine number of frames
            // TODO: this is not efficient--we only use a block-diagonal-like structure, rest is empty (exploiting the fixed boundaries)
            const size_t edgeframes = nodes[edges[j].E].t - nodes[edges[j].S].t;
            if (edgeframes == 0) // dummy !NULL edge at end of lattice
            {
                if ((size_t) j != edges.size() - 1)
                    RuntimeError("forwardbackwardalign: unexpected 0-frame edge (only allowed at very end)");
                // note: abcs[j] is already initialized to be NULL in this case, which protects us from accidentally using it
            }
            else
            {
                // determine the number of states in an edge
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                size_t edgestates = 0;
                bool edgehassil = false;
                foreach_index (i, aligntokens)
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;
                if (!cpuverification && !edgehassil && parallelstate.enabled()) // !cpuverification, parallel & is non sil, we do not allocate
                {
                    abcs[j] = NULL;
                    continue;
                }
                foreach_index (k, aligntokens)
                    edgestates += hset.gethmm(aligntokens[k].unit).getnumstates();
                // allocate the matrix
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    countskip++; // pruned edge: no matrix needed
                else
                    abcs[j] = &matrixheap.newmatrix(edgestates, edgeframes + 2); // +2 to have one extra column for betas and one for gammas
            }
        }
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: %d of %d edges pruned\n", (int) countskip, (int) edges.size());
    }
    // Phase 2: alignment on CPU
    // Only silence-containing edges are handled here; the GPU path (Phase 3) handles the rest.
    if (parallelstate.enabled() && !parallelsil) // silence edge shall be process separately if not cuda and not PARALLEL_SIL
    {
        if (softalignstates)
            LogicError("forwardbackwardalign: parallelized version currently only handles hard alignments");
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: pruning not supported (we won't need it!) :)\n");
        edgeacscores.resize(edges.size());
        for (size_t j = 0; j < edges.size(); j++)
        {
            const auto &aligntokens = getaligninfo(j); // get alignment tokens
            if (aligntokens.size() == 0)
                continue; // dummy edge with no alignment
            bool edgehassil = false;
            foreach_index (i, aligntokens)
            {
                if (aligntokens[i].unit == silunitid)
                    edgehassil = true;
            }
            if (!edgehassil) // only process sil
                continue;
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            // view into the frame range [ts,te) of the full log-likelihood matrix
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, true, thisedgealignments[j]);
        }
    }
    // Phase 3: alignment on GPU
    if (parallelstate.enabled())
        parallelforwardbackwardalign(parallelstate, hset, logLLs, edgeacscores, thisedgealignments, thisbackpointers);
    // zhaorui align to reference mlf
    // If phone boundaries are supplied, force-align each bounded phone segment and
    // write the resulting frame-level state ids into 'uids'.
    if (bounds.size() > 0)
    {
        size_t framenum = bounds.size();
        msra::math::ssematrixbase *refabcs = NULL; // assigned per segment below
        size_t ts, te, t;
        ts = te = 0;
        vector<aligninfo> refinfo(1);
        vector<unsigned short> refalign(framenum);
        array_ref<aligninfo> refunits(refinfo.data(), 1);
        array_ref<unsigned short> refedgealignmentsj(refalign.data(), framenum);
        while (te < framenum)
        {
            // found one phone's boundary (ts, te)
            // bounds[t] != 0 marks the start of a new phone; presumably value = phone id + 1 -- TODO confirm
            t = ts + 1;
            while (t < framenum && bounds[t] == 0)
                t++;
            te = t;
            // make one phone unit
            size_t phoneid = bounds[ts] - 1;
            refunits[0].unit = phoneid;
            refunits[0].frames = te - ts;
            size_t edgestates = hset.gethmm(phoneid).getnumstates();
            littlematrixheap refmatrixheap(1); // for abcs; scoped to this segment
            refabcs = &refmatrixheap.newmatrix(edgestates, te - ts + 2);
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            // do alignment
            alignedge((const_array_ref<aligninfo>) refunits, hset, edgeLLs, *refabcs, 0, true, refedgealignmentsj);
            for (t = ts; t < te; t++)
            {
                uids[t] = (size_t) refedgealignmentsj[t - ts];
            }
            ts = te;
        }
    }
    // Phase 4: alignment or forwardbackward on CPU for non parallel mode or verification
    if (!parallelstate.enabled() || cpuverification) // non parallel mode or verification
    {
        edgeacscores.resize(edges.size());
        std::vector<float> edgeacscoresgpu;
        edgealignments thisedgealignmentsgpu(thisedgealignments);
        if (cpuverification)
        {
            // pull the GPU results down so we can compare against the CPU recomputation
            parallelstate.getedgeacscores(edgeacscoresgpu);
            parallelstate.copyalignments(thisedgealignmentsgpu);
        }
        foreach_index (j, edges)
        {
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            if (ts == te) // dummy !NULL edge at end
                edgeacscores[j] = 0.0f;
            else
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    edgeacscores[j] = LOGZERO; // will kill word level forwardbackward hypothesis
                else if (softalignstates)
                    edgeacscores[j] = forwardbackwardedge(aligntokens, hset, edgeLLs, *abcs[j], j);
                else
                    edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, returnsenoneids, thisedgealignments[j]);
            }
            if (cpuverification)
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                bool edgehassil = false;
                foreach_index (i, aligntokens)
                {
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;
                }
                // report acoustic-score mismatches beyond a small tolerance
                if (fabs(edgeacscores[j] - edgeacscoresgpu[j]) > 1e-3)
                {
                    fprintf(stderr, "edge %d, sil ? %d, edgeacscores / edgeacscoresgpu MISMATCH %f v.s. %f, diff %e\n",
                            j, edgehassil ? 1 : 0, (float) edgeacscores[j], (float) edgeacscoresgpu[j],
                            (float) (edgeacscores[j] - edgeacscoresgpu[j]));
                    fprintf(stderr, "aligntokens: ");
                    foreach_index (i, aligntokens)
                        fprintf(stderr, "%d %d; ", i, aligntokens[i].unit);
                    fprintf(stderr, "\n");
                }
                // report frame-level alignment mismatches
                for (size_t t = ts; t < te; t++)
                {
                    if (thisedgealignments[j][t - ts] != thisedgealignmentsgpu[j][t - ts])
                        fprintf(stderr, "edge %d, sil ? %d, time %d, alignment / alignmentgpu MISMATCH %d v.s. %d\n", j, edgehassil ? 1 : 0, (int) (t - ts), thisedgealignments[j][t - ts], thisedgealignmentsgpu[j][t - ts]);
                }
            }
        }
    }
}