///////////////////////////////////////////////////////////////////////
// File: ctc.cpp
// Description: Slightly improved standard CTC to compute the targets.
// Author: Ray Smith
// Created: Wed Jul 13 15:50:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "ctc.h"
#include <memory>
#include "genericvector.h"
#include "host.h"
#include "matrix.h"
#include "networkio.h"
#include "network.h"
#include "scrollview.h"
namespace tesseract {
// Magic constants that keep CTC stable.
// Minimum probability limit for softmax input to ctc_loss.
const float CTC::kMinProb_ = 1e-12;
// Maximum absolute argument to exp().
const double CTC::kMaxExpArg_ = 80.0;
// Minimum probability for total prob in time normalization.
const double CTC::kMinTotalTimeProb_ = 1e-8;
// Minimum probability for total prob in final normalization.
const double CTC::kMinTotalFinalProb_ = 1e-6;
// Builds a target using CTC. Slightly improved as follows:
// Includes normalizations and clipping for stability.
// labels should be pre-padded with nulls everywhere.
// labels can be longer than the time sequence, but the total number of
// essential labels (non-null plus nulls between equal labels) must not exceed
// the number of timesteps in outputs.
// outputs is the output of the network, and should have already been
// normalized with NormalizeProbs.
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
/* static */
bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs,
NetworkIO* targets) {
std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs));
if (!ctc->ComputeLabelLimits()) {
return false; // Not enough time.
}
// Generate simple targets purely from the truth labels by spreading them
// evenly over time.
GENERIC_2D_ARRAY<float> simple_targets;
ctc->ComputeSimpleTargets(&simple_targets);
// Add the simple targets as a starter bias to the network outputs.
float bias_fraction = ctc->CalculateBiasFraction();
simple_targets *= bias_fraction;
ctc->outputs_ += simple_targets;
NormalizeProbs(&ctc->outputs_);
// Run regular CTC on the biased outputs.
// Run forward and backward
GENERIC_2D_ARRAY<double> log_alphas, log_betas;
ctc->Forward(&log_alphas);
ctc->Backward(&log_betas);
// Normalize and come out of log space with a clipped softmax over time.
log_alphas += log_betas;
ctc->NormalizeSequence(&log_alphas);
ctc->LabelsToClasses(log_alphas, targets);
NormalizeProbs(targets);
return true;
}
CTC::CTC(const GenericVector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs)
: labels_(labels), outputs_(outputs), null_char_(null_char) {
num_timesteps_ = outputs.dim1();
num_classes_ = outputs.dim2();
num_labels_ = labels_.size();
}
// Computes vectors of min and max label index for each timestep, based on
// whether skippability of nulls makes it possible to complete a valid path.
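// A null can be skipped on a valid path only if it does not separate two
// identical labels, so the reachable range of label indices at each timestep
// depends on both the remaining time and where such mandatory nulls fall.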
bool CTC::ComputeLabelLimits() {
min_labels_.init_to_size(num_timesteps_, 0);
max_labels_.init_to_size(num_timesteps_, 0);
int min_u = num_labels_ - 1;
if (labels_[min_u] == null_char_) --min_u;
for (int t = num_timesteps_ - 1; t >= 0; --t) {
min_labels_[t] = min_u;
if (min_u > 0) {
--min_u;
if (labels_[min_u] == null_char_ && min_u > 0 &&
labels_[min_u + 1] != labels_[min_u - 1]) {
--min_u;
}
}
}
int max_u = labels_[0] == null_char_;
for (int t = 0; t < num_timesteps_; ++t) {
max_labels_[t] = max_u;
if (max_labels_[t] < min_labels_[t]) return false; // Not enough room.
if (max_u + 1 < num_labels_) {
++max_u;
if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
labels_[max_u + 1] != labels_[max_u - 1]) {
++max_u;
}
}
}
return true;
}
// Computes targets based purely on the labels by spreading the labels evenly
// over the available timesteps.
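// Each retained label gets a triangular profile: probability 1.0 at its mean
// timestep, falling linearly to zero over one half-width on each side.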
void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
// Initialize all targets to zero.
targets->Resize(num_timesteps_, num_classes_, 0.0f);
GenericVector<float> half_widths;
GenericVector<int> means;
ComputeWidthsAndMeans(&half_widths, &means);
for (int l = 0; l < num_labels_; ++l) {
int label = labels_[l];
float left_half_width = half_widths[l];
float right_half_width = left_half_width;
int mean = means[l];
if (label == null_char_) {
if (!NeededNull(l)) {
if ((l > 0 && mean == means[l - 1]) ||
(l + 1 < num_labels_ && mean == means[l + 1])) {
continue; // Drop overlapping null.
}
}
// Make sure that no space is left unoccupied and that non-nulls always
// peak at 1 by stretching nulls to meet their neighbors.
if (l > 0) left_half_width = mean - means[l - 1];
if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
}
if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f);
for (int offset = 1; offset < left_half_width && mean >= offset; ++offset) {
float prob = 1.0f - offset / left_half_width;
if (mean - offset < num_timesteps_ &&
prob > targets->get(mean - offset, label)) {
targets->put(mean - offset, label, prob);
}
}
for (int offset = 1;
offset < right_half_width && mean + offset < num_timesteps_;
++offset) {
float prob = 1.0f - offset / right_half_width;
if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) {
targets->put(mean + offset, label, prob);
}
}
}
}
// Computes mean positions and half widths of the simple targets by spreading
// the labels evenly over the available timesteps.
void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
GenericVector<int>* means) const {
// Count the number of labels of each type: in regexp terms, plus labels
// (non-null or necessary null, which must occur at least once) and star
// labels (optional nulls).
int num_plus = 0, num_star = 0;
for (int i = 0; i < num_labels_; ++i) {
if (labels_[i] != null_char_ || NeededNull(i))
++num_plus;
else
++num_star;
}
// Compute the size for each type. If there is enough space for everything
// to have size>=1, then all are equal, otherwise plus_size=1 and star gets
// whatever is left-over.
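// For example, with 10 timesteps, 4 plus labels and 8 star labels, plus_size
// stays at 1.0 and the 6 leftover timesteps are shared among the star labels
// (star_size = 0.75).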
float plus_size = 1.0f, star_size = 0.0f;
float total_floating = num_plus + num_star;
if (total_floating <= num_timesteps_) {
plus_size = star_size = num_timesteps_ / total_floating;
} else if (num_star > 0) {
star_size = static_cast<float>(num_timesteps_ - num_plus) / num_star;
}
// Set the width and compute the mean of each.
float mean_pos = 0.0f;
for (int i = 0; i < num_labels_; ++i) {
float half_width;
if (labels_[i] != null_char_ || NeededNull(i)) {
half_width = plus_size / 2.0f;
} else {
half_width = star_size / 2.0f;
}
mean_pos += half_width;
means->push_back(static_cast<int>(mean_pos));
mean_pos += half_width;
half_widths->push_back(half_width);
}
}
// Helper returns the index of the highest probability label at timestep t.
static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) {
int result = 0;
int num_classes = outputs.dim2();
const float* outputs_t = outputs[t];
for (int c = 1; c < num_classes; ++c) {
if (outputs_t[c] > outputs_t[result]) result = c;
}
return result;
}
// Calculates and returns a suitable fraction of the simple targets to add
// to the network outputs.
float CTC::CalculateBiasFraction() {
// Compute output labels via basic decoding.
GenericVector<int> output_labels;
for (int t = 0; t < num_timesteps_; ++t) {
int label = BestLabel(outputs_, t);
while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
if (label != null_char_) output_labels.push_back(label);
}
// Simple bag of labels error calculation.
GenericVector<int> truth_counts(num_classes_, 0);
GenericVector<int> output_counts(num_classes_, 0);
for (int l = 0; l < num_labels_; ++l) {
++truth_counts[labels_[l]];
}
for (int l = 0; l < output_labels.size(); ++l) {
++output_counts[output_labels[l]];
}
// Count the number of true and false positive non-nulls and truth labels.
int true_pos = 0, false_pos = 0, total_labels = 0;
for (int c = 0; c < num_classes_; ++c) {
if (c == null_char_) continue;
int truth_count = truth_counts[c];
int ocr_count = output_counts[c];
if (truth_count > 0) {
total_labels += truth_count;
if (ocr_count > truth_count) {
true_pos += truth_count;
false_pos += ocr_count - truth_count;
} else {
true_pos += ocr_count;
}
}
// We don't need to count classes that don't exist in the truth as
// false positives, because they don't affect CTC at all.
}
if (total_labels == 0) return 0.0f;
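// log(kMinProb_) is negative, so the better the bag-of-labels match between
// the decoded output and the truth (relative to the number of truth labels),
// the smaller the fraction returned, and a well-matching network gets almost
// no simple-target bias.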
return exp(MAX(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
}
// Given ln(x) and ln(y), returns ln(x + y), using:
// ln(x + y) = ln(x) + ln(1 + exp(ln(y) - ln(x))), ensuring that ln(x) is the
// bigger number to maximize precision.
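// For example, LogSumExp(log(0.6), log(0.3)) == log(0.9), up to
// floating-point rounding.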
static double LogSumExp(double ln_x, double ln_y) {
if (ln_x >= ln_y) {
return ln_x + log1p(exp(ln_y - ln_x));
} else {
return ln_y + log1p(exp(ln_x - ln_y));
}
}
// Runs the forward CTC pass, filling in log_probs.
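// In log space this implements the standard CTC alpha recurrence:
//   alpha(t, u) = (alpha(t-1, u) + alpha(t-1, u-1) [+ alpha(t-1, u-2)])
//                 * y_t(labels[u])
// where the u-2 term applies only when labels[u-1] is a skippable null, and
// u is restricted to the window [min_labels_[t], max_labels_[t]].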
void CTC::Forward(GENERIC_2D_ARRAY<double>* log_probs) const {
log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
log_probs->put(0, 0, log(outputs_(0, labels_[0])));
if (labels_[0] == null_char_)
log_probs->put(0, 1, log(outputs_(0, labels_[1])));
for (int t = 1; t < num_timesteps_; ++t) {
const float* outputs_t = outputs_[t];
for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
// Continuing the same label.
double log_sum = log_probs->get(t - 1, u);
// Change from previous label.
if (u > 0) {
log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1));
}
// Skip the null if allowed.
if (u >= 2 && labels_[u - 1] == null_char_ &&
labels_[u] != labels_[u - 2]) {
log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2));
}
// Add in the log prob of the current label.
double label_prob = outputs_t[labels_[u]];
log_sum += log(label_prob);
log_probs->put(t, u, log_sum);
}
}
}
// Runs the backward CTC pass, filling in log_probs.
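// Mirror image of Forward: beta(t, u) accumulates, over the allowed moves to
// u, u+1, or u+2 (null skip) at time t+1, the probability of completing the
// label sequence from (t, u).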
void CTC::Backward(GENERIC_2D_ARRAY<double>* log_probs) const {
log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
if (labels_[num_labels_ - 1] == null_char_)
log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
for (int t = num_timesteps_ - 2; t >= 0; --t) {
const float* outputs_tp1 = outputs_[t + 1];
for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
// Continuing the same label.
double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]);
// Change to the next label.
if (u + 1 < num_labels_) {
double prev_prob = outputs_tp1[labels_[u + 1]];
log_sum =
LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob));
}
// Skip the null if allowed.
if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
labels_[u] != labels_[u + 2]) {
double skip_prob = outputs_tp1[labels_[u + 2]];
log_sum =
LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob));
}
log_probs->put(t, u, log_sum);
}
}
}
// Normalizes and brings probs out of log space with a softmax over time.
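// Subtracting the global max log-prob makes every exponent non-positive, so
// ClippedExp cannot overflow; each label's distribution over time is then
// renormalized, with the total clipped to kMinTotalTimeProb_ so labels on
// impossible paths stay near zero instead of being amplified.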
void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const {
double max_logprob = probs->Max();
for (int u = 0; u < num_labels_; ++u) {
double total = 0.0;
for (int t = 0; t < num_timesteps_; ++t) {
// Separate impossible path from unlikely probs.
double prob = probs->get(t, u);
if (prob > -MAX_FLOAT32)
prob = ClippedExp(prob - max_logprob);
else
prob = 0.0;
total += prob;
probs->put(t, u, prob);
}
// Note that although this is a probability distribution over time and
// therefore should sum to 1, it is important to allow some labels to be
// all zero (or at least tiny), as it is necessary to skip some blanks.
if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
for (int t = 0; t < num_timesteps_; ++t)
probs->put(t, u, probs->get(t, u) / total);
}
}
// For each timestep computes the max prob for each class over all
// instances of the class in the labels_, and sets the targets to
// the max observed prob.
void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
NetworkIO* targets) const {
// For each timestep compute the max prob for each class over all
// instances of the class in the labels_.
GenericVector<double> class_probs;
for (int t = 0; t < num_timesteps_; ++t) {
float* targets_t = targets->f(t);
class_probs.init_to_size(num_classes_, 0.0);
for (int u = 0; u < num_labels_; ++u) {
double prob = probs(t, u);
// Note that although Graves specifies a sum over all labels of the same
// class, we need to allow skipped blanks to go to zero so they don't
// interfere with the non-blanks; max is therefore better than sum.
if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
// class_probs[labels_[u]] += prob;
}
int best_class = 0;
for (int c = 0; c < num_classes_; ++c) {
targets_t[c] = class_probs[c];
if (class_probs[c] > class_probs[best_class]) best_class = c;
}
}
}
// Normalizes the probabilities so that no target has a prob below min_prob.
// Provided the initial total is at least min_total_prob, the probs then sum
// to 1; otherwise they sum to the initial total / min_total_prob. The maximum
// output probability is thus 1 - (num_classes - 1) * min_prob.
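// The increment pre-compensates the divisor for the probability mass that
// clipping to kMinProb_ adds back, so the clipped result still sums to
// approximately 1.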
/* static */
void CTC::NormalizeProbs(GENERIC_2D_ARRAY<float>* probs) {
int num_timesteps = probs->dim1();
int num_classes = probs->dim2();
for (int t = 0; t < num_timesteps; ++t) {
float* probs_t = (*probs)[t];
// Compute the total and clip that to prevent amplification of noise.
double total = 0.0;
for (int c = 0; c < num_classes; ++c) total += probs_t[c];
if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
// Compute the increased total as a result of clipping.
double increment = 0.0;
for (int c = 0; c < num_classes; ++c) {
double prob = probs_t[c] / total;
if (prob < kMinProb_) increment += kMinProb_ - prob;
}
// Now normalize with clipping. Any additional clipping is negligible.
total += increment;
for (int c = 0; c < num_classes; ++c) {
float prob = probs_t[c] / total;
probs_t[c] = MAX(prob, kMinProb_);
}
}
}
// Returns true if the label at index is a needed null.
bool CTC::NeededNull(int index) const {
return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ &&
labels_[index + 1] == labels_[index - 1];
}
} // namespace tesseract