From c1c1e426b32794d5e84134ee81bf895ff0228fe5 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Mon, 7 Nov 2016 15:38:07 -0800 Subject: [PATCH] Added new LSTM-based neural network line recognizer --- Makefile.am | 2 +- api/Makefile.am | 7 + api/baseapi.cpp | 63 +- api/baseapi.h | 26 +- api/pdfrenderer.cpp | 7 +- api/renderer.h | 4 +- arch/Makefile.am | 29 + arch/dotproductavx.cpp | 103 ++ arch/dotproductavx.h | 30 + arch/dotproductsse.cpp | 141 +++ arch/dotproductsse.h | 35 + ccmain/Makefile.am | 6 +- ccmain/control.cpp | 51 +- ccmain/linerec.cpp | 332 ++++++ ccmain/ltrresultiterator.cpp | 6 + ccmain/ltrresultiterator.h | 3 + ccmain/tessedit.cpp | 22 +- ccmain/tesseractclass.cpp | 16 +- ccmain/tesseractclass.h | 87 +- ccmain/thresholder.cpp | 20 +- ccstruct/imagedata.cpp | 392 +++++-- ccstruct/imagedata.h | 134 ++- ccstruct/matrix.h | 311 +++++- ccstruct/pageres.cpp | 8 +- ccstruct/pageres.h | 5 +- ccstruct/publictypes.h | 10 +- ccutil/Makefile.am | 4 +- ccutil/genericheap.h | 34 +- ccutil/genericvector.h | 82 +- ccutil/serialis.cpp | 2 +- ccutil/strngs.cpp | 8 + ccutil/strngs.h | 2 + ccutil/tessdatamanager.h | 86 +- ccutil/unicharcompress.cpp | 439 ++++++++ ccutil/unicharcompress.h | 258 +++++ ccutil/unicharset.cpp | 2 + ccutil/unicharset.h | 18 +- configure.ac | 12 +- cutil/oldlist.cpp | 1 - dict/dawg_cache.cpp | 3 + dict/dict.cpp | 67 +- dict/dict.h | 23 +- lstm/Makefile.am | 39 + lstm/convolve.cpp | 124 +++ lstm/convolve.h | 74 ++ lstm/ctc.cpp | 412 ++++++++ lstm/ctc.h | 130 +++ lstm/fullyconnected.cpp | 285 +++++ lstm/fullyconnected.h | 130 +++ lstm/functions.cpp | 26 + lstm/functions.h | 249 +++++ lstm/input.cpp | 154 +++ lstm/input.h | 107 ++ lstm/lstm.cpp | 710 +++++++++++++ lstm/lstm.h | 157 +++ lstm/lstmrecognizer.cpp | 816 +++++++++++++++ lstm/lstmrecognizer.h | 392 +++++++ lstm/lstmtrainer.cpp | 1331 ++++++++++++++++++++++++ lstm/lstmtrainer.h | 477 +++++++++ lstm/maxpool.cpp | 87 ++ lstm/maxpool.h | 71 ++ lstm/network.cpp | 309 ++++++ lstm/network.h | 292 ++++++ lstm/networkbuilder.cpp | 488 +++++++++ lstm/networkbuilder.h | 160 +++ lstm/networkio.cpp | 981 +++++++++++++++++ lstm/networkio.h | 341 ++++++ lstm/networkscratch.h | 257 +++++ lstm/parallel.cpp | 180 ++++ lstm/parallel.h | 87 ++ lstm/plumbing.cpp | 233 +++++ lstm/plumbing.h | 143 +++ lstm/recodebeam.cpp | 759 ++++++++++++++ lstm/recodebeam.h | 304 ++++++ lstm/reconfig.cpp | 128 +++ lstm/reconfig.h | 86 ++ lstm/reversed.cpp | 91 ++ lstm/reversed.h | 89 ++ lstm/series.cpp | 188 ++++ lstm/series.h | 91 ++ lstm/static_shape.h | 80 ++ lstm/stridemap.cpp | 173 +++ lstm/stridemap.h | 137 +++ lstm/tfnetwork.cpp | 146 +++ lstm/tfnetwork.h | 91 ++ lstm/tfnetwork.proto | 61 ++ lstm/weightmatrix.cpp | 443 ++++++++ lstm/weightmatrix.h | 183 ++++ textord/baselinedetect.cpp | 3 +- textord/colpartition.cpp | 4 + textord/colpartition.h | 19 + textord/tordmain.cpp | 2 + textord/tospace.cpp | 7 +- training/Makefile.am | 57 +- training/degradeimage.cpp | 163 +++ training/degradeimage.h | 28 +- training/language-specific.sh | 7 +- training/lstmtraining.cpp | 185 ++++ training/merge_unicharsets.cpp | 52 + training/mftraining.cpp | 3 + training/normstrngs.cpp | 8 +- training/normstrngs.h | 9 +- training/unicharset_training_utils.cpp | 7 +- training/unicharset_training_utils.h | 8 +- viewer/svutil.h | 11 + wordrec/chopper.cpp | 2 +- wordrec/tface.cpp | 6 +- 107 files changed, 15410 insertions(+), 354 deletions(-) create mode 100644 arch/Makefile.am create mode 100644 arch/dotproductavx.cpp create mode 100644 arch/dotproductavx.h create mode 
100644 arch/dotproductsse.cpp create mode 100644 arch/dotproductsse.h create mode 100644 ccmain/linerec.cpp create mode 100644 ccutil/unicharcompress.cpp create mode 100644 ccutil/unicharcompress.h create mode 100644 lstm/Makefile.am create mode 100644 lstm/convolve.cpp create mode 100644 lstm/convolve.h create mode 100644 lstm/ctc.cpp create mode 100644 lstm/ctc.h create mode 100644 lstm/fullyconnected.cpp create mode 100644 lstm/fullyconnected.h create mode 100644 lstm/functions.cpp create mode 100644 lstm/functions.h create mode 100644 lstm/input.cpp create mode 100644 lstm/input.h create mode 100644 lstm/lstm.cpp create mode 100644 lstm/lstm.h create mode 100644 lstm/lstmrecognizer.cpp create mode 100644 lstm/lstmrecognizer.h create mode 100644 lstm/lstmtrainer.cpp create mode 100644 lstm/lstmtrainer.h create mode 100644 lstm/maxpool.cpp create mode 100644 lstm/maxpool.h create mode 100644 lstm/network.cpp create mode 100644 lstm/network.h create mode 100644 lstm/networkbuilder.cpp create mode 100644 lstm/networkbuilder.h create mode 100644 lstm/networkio.cpp create mode 100644 lstm/networkio.h create mode 100644 lstm/networkscratch.h create mode 100644 lstm/parallel.cpp create mode 100644 lstm/parallel.h create mode 100644 lstm/plumbing.cpp create mode 100644 lstm/plumbing.h create mode 100644 lstm/recodebeam.cpp create mode 100644 lstm/recodebeam.h create mode 100644 lstm/reconfig.cpp create mode 100644 lstm/reconfig.h create mode 100644 lstm/reversed.cpp create mode 100644 lstm/reversed.h create mode 100644 lstm/series.cpp create mode 100644 lstm/series.h create mode 100644 lstm/static_shape.h create mode 100644 lstm/stridemap.cpp create mode 100644 lstm/stridemap.h create mode 100644 lstm/tfnetwork.cpp create mode 100644 lstm/tfnetwork.h create mode 100644 lstm/tfnetwork.proto create mode 100644 lstm/weightmatrix.cpp create mode 100644 lstm/weightmatrix.h create mode 100644 training/lstmtraining.cpp create mode 100644 training/merge_unicharsets.cpp diff --git a/Makefile.am b/Makefile.am index a4aa1dd915..3d3fc9b79e 100644 --- a/Makefile.am +++ b/Makefile.am @@ -16,7 +16,7 @@ endif .PHONY: install-langs ScrollView.jar install-jars training -SUBDIRS = ccutil viewer cutil opencl ccstruct dict classify wordrec textord +SUBDIRS = arch ccutil viewer cutil opencl ccstruct dict classify wordrec textord lstm if !NO_CUBE_BUILD SUBDIRS += neural_networks/runtime cube endif diff --git a/api/Makefile.am b/api/Makefile.am index 9d20919b2e..140279b5b2 100644 --- a/api/Makefile.am +++ b/api/Makefile.am @@ -1,5 +1,6 @@ AM_CPPFLAGS += -DLOCALEDIR=\"$(localedir)\"\ -DUSE_STD_NAMESPACE \ + -I$(top_srcdir)/arch -I$(top_srcdir)/lstm \ -I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct -I$(top_srcdir)/cube \ -I$(top_srcdir)/viewer \ -I$(top_srcdir)/textord -I$(top_srcdir)/dict \ @@ -27,6 +28,9 @@ libtesseract_api_la_LIBADD = \ ../wordrec/libtesseract_wordrec.la \ ../classify/libtesseract_classify.la \ ../dict/libtesseract_dict.la \ + ../arch/libtesseract_avx.la \ + ../arch/libtesseract_sse.la \ + ../lstm/libtesseract_lstm.la \ ../ccstruct/libtesseract_ccstruct.la \ ../cutil/libtesseract_cutil.la \ ../viewer/libtesseract_viewer.la \ @@ -57,6 +61,9 @@ libtesseract_la_LIBADD = \ ../wordrec/libtesseract_wordrec.la \ ../classify/libtesseract_classify.la \ ../dict/libtesseract_dict.la \ + ../arch/libtesseract_avx.la \ + ../arch/libtesseract_sse.la \ + ../lstm/libtesseract_lstm.la \ ../ccstruct/libtesseract_ccstruct.la \ ../cutil/libtesseract_cutil.la \ ../viewer/libtesseract_viewer.la \ diff --git 
a/api/baseapi.cpp b/api/baseapi.cpp index b79e969178..d4da5673e2 100644 --- a/api/baseapi.cpp +++ b/api/baseapi.cpp @@ -121,7 +121,6 @@ TessBaseAPI::TessBaseAPI() block_list_(NULL), page_res_(NULL), input_file_(NULL), - input_image_(NULL), output_file_(NULL), datapath_(NULL), language_(NULL), @@ -515,9 +514,7 @@ void TessBaseAPI::ClearAdaptiveClassifier() { /** * Provide an image for Tesseract to recognize. Format is as - * TesseractRect above. Does not copy the image buffer, or take - * ownership. The source image may be destroyed after Recognize is called, - * either explicitly or implicitly via one of the Get*Text functions. + * TesseractRect above. Copies the image buffer and converts to Pix. * SetImage clears all recognition results, and sets the rectangle to the * full image, so it may be followed immediately by a GetUTF8Text, and it * will automatically perform recognition. @@ -525,9 +522,11 @@ void TessBaseAPI::ClearAdaptiveClassifier() { void TessBaseAPI::SetImage(const unsigned char* imagedata, int width, int height, int bytes_per_pixel, int bytes_per_line) { - if (InternalSetImage()) + if (InternalSetImage()) { thresholder_->SetImage(imagedata, width, height, bytes_per_pixel, bytes_per_line); + SetInputImage(thresholder_->GetPixRect()); + } } void TessBaseAPI::SetSourceResolution(int ppi) { @@ -539,18 +538,17 @@ void TessBaseAPI::SetSourceResolution(int ppi) { /** * Provide an image for Tesseract to recognize. As with SetImage above, - * Tesseract doesn't take a copy or ownership or pixDestroy the image, so - * it must persist until after Recognize. + * Tesseract takes its own copy of the image, so it need not persist until + * after Recognize. * Pix vs raw, which to use? - * Use Pix where possible. A future version of Tesseract may choose to use Pix - * as its internal representation and discard IMAGE altogether. - * Because of that, an implementation that sources and targets Pix may end up - * with less copies than an implementation that does not. + * Use Pix where possible. Tesseract uses Pix as its internal representation + * and it is therefore more efficient to provide a Pix directly. */ void TessBaseAPI::SetImage(Pix* pix) { - if (InternalSetImage()) + if (InternalSetImage()) { thresholder_->SetImage(pix); - SetInputImage(pix); + SetInputImage(thresholder_->GetPixRect()); + } } /** @@ -693,8 +691,8 @@ Boxa* TessBaseAPI::GetComponentImages(PageIteratorLevel level, if (pixa != NULL) { Pix* pix = NULL; if (raw_image) { - pix = page_it->GetImage(level, raw_padding, input_image_, - &left, &top); + pix = page_it->GetImage(level, raw_padding, GetInputImage(), &left, + &top); } else { pix = page_it->GetBinaryImage(level); } @@ -849,13 +847,17 @@ int TessBaseAPI::Recognize(ETEXT_DESC* monitor) { } else if (tesseract_->tessedit_resegment_from_boxes) { page_res_ = tesseract_->ApplyBoxes(*input_file_, false, block_list_); } else { - // TODO(rays) LSTM here. 
- page_res_ = new PAGE_RES(false, + page_res_ = new PAGE_RES(tesseract_->AnyLSTMLang(), block_list_, &tesseract_->prev_word_best_choice_); } if (page_res_ == NULL) { return -1; } + if (tesseract_->tessedit_train_line_recognizer) { + tesseract_->TrainLineRecognizer(*input_file_, *output_file_, block_list_); + tesseract_->CorrectClassifyWords(page_res_); + return 0; + } if (tesseract_->tessedit_make_boxes_from_boxes) { tesseract_->CorrectClassifyWords(page_res_); return 0; @@ -938,17 +940,10 @@ int TessBaseAPI::RecognizeForChopTest(ETEXT_DESC* monitor) { return 0; } -void TessBaseAPI::SetInputImage(Pix *pix) { - if (input_image_) - pixDestroy(&input_image_); - input_image_ = NULL; - if (pix) - input_image_ = pixCopy(NULL, pix); -} +// Takes ownership of the input pix. +void TessBaseAPI::SetInputImage(Pix* pix) { tesseract_->set_pix_original(pix); } -Pix* TessBaseAPI::GetInputImage() { - return input_image_; -} +Pix* TessBaseAPI::GetInputImage() { return tesseract_->pix_original(); } const char * TessBaseAPI::GetInputName() { if (input_file_) @@ -992,8 +987,7 @@ bool TessBaseAPI::ProcessPagesFileList(FILE *flist, } // Begin producing output - const char* kUnknownTitle = ""; - if (renderer && !renderer->BeginDocument(kUnknownTitle)) { + if (renderer && !renderer->BeginDocument(unknown_title_)) { return false; } @@ -1105,7 +1099,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename, const char* retry_config, int timeout_millisec, TessResultRenderer* renderer) { -#ifndef ANDROID_BUILD PERF_COUNT_START("ProcessPages") bool stdInput = !strcmp(filename, "stdin") || !strcmp(filename, "-"); if (stdInput) { @@ -1162,8 +1155,7 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename, } // Begin the output - const char* kUnknownTitle = ""; - if (renderer && !renderer->BeginDocument(kUnknownTitle)) { + if (renderer && !renderer->BeginDocument(unknown_title_)) { pixDestroy(&pix); return false; } @@ -1185,9 +1177,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename, } PERF_COUNT_END return true; -#else - return false; -#endif } bool TessBaseAPI::ProcessPage(Pix* pix, int page_index, const char* filename, @@ -2107,10 +2096,6 @@ void TessBaseAPI::End() { delete input_file_; input_file_ = NULL; } - if (input_image_ != NULL) { - pixDestroy(&input_image_); - input_image_ = NULL; - } if (output_file_ != NULL) { delete output_file_; output_file_ = NULL; diff --git a/api/baseapi.h b/api/baseapi.h index d872689eec..d6e532ba81 100644 --- a/api/baseapi.h +++ b/api/baseapi.h @@ -20,8 +20,8 @@ #ifndef TESSERACT_API_BASEAPI_H__ #define TESSERACT_API_BASEAPI_H__ -#define TESSERACT_VERSION_STR "3.05.00dev" -#define TESSERACT_VERSION 0x030500 +#define TESSERACT_VERSION_STR "4.00.00alpha" +#define TESSERACT_VERSION 0x040000 #define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \ (patch)) @@ -142,6 +142,7 @@ class TESS_API TessBaseAPI { * is stored in the PDF so we need that as well. */ const char* GetInputName(); + // Takes ownership of the input pix. void SetInputImage(Pix *pix); Pix* GetInputImage(); int GetSourceYResolution(); @@ -333,9 +334,7 @@ class TESS_API TessBaseAPI { /** * Provide an image for Tesseract to recognize. Format is as - * TesseractRect above. Does not copy the image buffer, or take - * ownership. The source image may be destroyed after Recognize is called, - * either explicitly or implicitly via one of the Get*Text functions. + * TesseractRect above. Copies the image buffer and converts to Pix. 
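
The tessedit_train_line_recognizer early-exit above turns Recognize() into a training-data generator. A minimal driver sketch under stated assumptions: the variable name and the early-exit come from this patch, the file names are illustrative, and SetInputName/SetOutputName must both be called because the new code dereferences both input_file_ and output_file_.

```cpp
#include "allheaders.h"
#include "baseapi.h"

// Illustrative sketch: produce <output_base>.lstmf line-training data from an
// image plus its .box file, instead of running normal OCR.
int MakeLstmfData(const char* datapath, const char* image_name,
                  const char* output_base) {
  tesseract::TessBaseAPI api;
  if (api.Init(datapath, "eng") != 0) return 1;
  api.SetVariable("tessedit_train_line_recognizer", "T");
  api.SetInputName(image_name);    // TrainLineRecognizer reads this page's boxes.
  api.SetOutputName(output_base);  // Lines are appended to output_base.lstmf.
  Pix* pix = pixRead(image_name);
  if (pix == NULL) return 1;
  api.SetImage(pix);
  pixDestroy(&pix);     // Safe: SetImage now takes its own copy.
  api.Recognize(NULL);  // Takes the new early-exit path above.
  api.End();
  return 0;
}
```
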
* SetImage clears all recognition results, and sets the rectangle to the * full image, so it may be followed immediately by a GetUTF8Text, and it * will automatically perform recognition. @@ -345,13 +344,11 @@ class TESS_API TessBaseAPI { /** * Provide an image for Tesseract to recognize. As with SetImage above, - * Tesseract doesn't take a copy or ownership or pixDestroy the image, so - * it must persist until after Recognize. + * Tesseract takes its own copy of the image, so it need not persist until + * after Recognize. * Pix vs raw, which to use? - * Use Pix where possible. A future version of Tesseract may choose to use Pix - * as its internal representation and discard IMAGE altogether. - * Because of that, an implementation that sources and targets Pix may end up - * with less copies than an implementation that does not. + * Use Pix where possible. Tesseract uses Pix as its internal representation + * and it is therefore more efficient to provide a Pix directly. */ void SetImage(Pix* pix); @@ -866,7 +863,6 @@ class TESS_API TessBaseAPI { BLOCK_LIST* block_list_; ///< The page layout. PAGE_RES* page_res_; ///< The page-level data. STRING* input_file_; ///< Name used by training code. - Pix* input_image_; ///< Image used for searchable PDF STRING* output_file_; ///< Name used by debug code. STRING* datapath_; ///< Current location of tessdata. STRING* language_; ///< Last initialized language. @@ -902,6 +898,12 @@ class TESS_API TessBaseAPI { int timeout_millisec, TessResultRenderer* renderer, int tessedit_page_number); + // There's currently no way to pass a document title from the + // Tesseract command line, and we have multiple places that choose + // to set the title to an empty string. Using a single named + // variable will hopefully reduce confusion if the situation changes + // in the future. + const char *unknown_title_ = ""; }; // class TessBaseAPI. /** Escape a char string - remove &<>"' with HTML codes. 
*/ diff --git a/api/pdfrenderer.cpp b/api/pdfrenderer.cpp index 10c6564ff4..ac651195b4 100644 --- a/api/pdfrenderer.cpp +++ b/api/pdfrenderer.cpp @@ -620,7 +620,6 @@ bool TessPDFRenderer::BeginDocumentHandler() { AppendPDFObject(buf); // FONT DESCRIPTOR - const int kCharHeight = 2; // Effect: highlights are half height n = snprintf(buf, sizeof(buf), "7 0 obj\n" "<<\n" @@ -636,10 +635,10 @@ bool TessPDFRenderer::BeginDocumentHandler() { " /Type /FontDescriptor\n" ">>\n" "endobj\n", - 1000 / kCharHeight, - 1000 / kCharHeight, + 1000, + 1000, 1000 / kCharWidth, - 1000 / kCharHeight, + 1000, 8L // Font data ); if (n >= sizeof(buf)) return false; diff --git a/api/renderer.h b/api/renderer.h index ad9e4d03ba..d868f267fa 100644 --- a/api/renderer.h +++ b/api/renderer.h @@ -77,7 +77,7 @@ class TESS_API TessResultRenderer { bool EndDocument(); const char* file_extension() const { return file_extension_; } - const char* title() const { return title_; } + const char* title() const { return title_.c_str(); } /** * Returns the index of the last image given to AddImage @@ -126,7 +126,7 @@ class TESS_API TessResultRenderer { private: const char* file_extension_; // standard extension for generated output - const char* title_; // title of document being renderered + STRING title_; // title of document being rendered int imagenum_; // index of last image added FILE* fout_; // output file pointer diff --git a/arch/Makefile.am b/arch/Makefile.am new file mode 100644 index 0000000000..bce98e7f52 --- /dev/null +++ b/arch/Makefile.am @@ -0,0 +1,29 @@ +AM_CPPFLAGS += -I$(top_srcdir)/ccutil +AUTOMAKE_OPTIONS = subdir-objects +SUBDIRS = +AM_CXXFLAGS = + +if VISIBILITY +AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden +AM_CPPFLAGS += -DTESS_EXPORTS +endif + +include_HEADERS = \ + dotproductavx.h dotproductsse.h + +noinst_HEADERS = + +if !USING_MULTIPLELIBS +noinst_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la +else +lib_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la +libtesseract_avx_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION) +libtesseract_sse_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION) +endif +libtesseract_avx_la_CXXFLAGS = -mavx +libtesseract_sse_la_CXXFLAGS = -msse4.1 + +libtesseract_avx_la_SOURCES = dotproductavx.cpp + +libtesseract_sse_la_SOURCES = dotproductsse.cpp + diff --git a/arch/dotproductavx.cpp b/arch/dotproductavx.cpp new file mode 100644 index 0000000000..94a806fc65 --- /dev/null +++ b/arch/dotproductavx.cpp @@ -0,0 +1,103 @@ +/////////////////////////////////////////////////////////////////////// +// File: dotproductavx.cpp +// Description: Architecture-specific dot-product function. +// Author: Ray Smith +// Created: Wed Jul 22 10:48:05 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#if !defined(__AVX__) +// Implementation for non-avx archs.
+ +#include "dotproductavx.h" +#include +#include + +namespace tesseract { +double DotProductAVX(const double* u, const double* v, int n) { + fprintf(stderr, "DotProductAVX can't be used on Android\n"); + abort(); +} +} // namespace tesseract + +#else // !defined(__AVX__) +// Implementation for avx capable archs. +#include +#include +#include "dotproductavx.h" +#include "host.h" + +namespace tesseract { + +// Computes and returns the dot product of the n-vectors u and v. +// Uses Intel AVX intrinsics to access the SIMD instruction set. +double DotProductAVX(const double* u, const double* v, int n) { + int max_offset = n - 4; + int offset = 0; + // Accumulate a set of 4 sums in sum, by loading pairs of 4 values from u and + // v, and multiplying them together in parallel. + __m256d sum = _mm256_setzero_pd(); + if (offset <= max_offset) { + offset = 4; + // Aligned load is reputedly faster but requires 32 byte aligned input. + if ((reinterpret_cast(u) & 31) == 0 && + (reinterpret_cast(v) & 31) == 0) { + // Use aligned load. + __m256d floats1 = _mm256_load_pd(u); + __m256d floats2 = _mm256_load_pd(v); + // Multiply. + sum = _mm256_mul_pd(floats1, floats2); + while (offset <= max_offset) { + floats1 = _mm256_load_pd(u + offset); + floats2 = _mm256_load_pd(v + offset); + offset += 4; + __m256d product = _mm256_mul_pd(floats1, floats2); + sum = _mm256_add_pd(sum, product); + } + } else { + // Use unaligned load. + __m256d floats1 = _mm256_loadu_pd(u); + __m256d floats2 = _mm256_loadu_pd(v); + // Multiply. + sum = _mm256_mul_pd(floats1, floats2); + while (offset <= max_offset) { + floats1 = _mm256_loadu_pd(u + offset); + floats2 = _mm256_loadu_pd(v + offset); + offset += 4; + __m256d product = _mm256_mul_pd(floats1, floats2); + sum = _mm256_add_pd(sum, product); + } + } + } + // Add the 4 product sums together horizontally. Not so easy as with sse, as + // there is no add across the upper/lower 128 bit boundary, so permute to + // move the upper 128 bits to lower in another register. + __m256d sum2 = _mm256_permute2f128_pd(sum, sum, 1); + sum = _mm256_hadd_pd(sum, sum2); + sum = _mm256_hadd_pd(sum, sum); + double result; + // _mm256_extract_f64 doesn't exist, but resist the temptation to use an sse + // instruction, as that introduces a 70 cycle delay. All this casting is to + // fool the instrinsics into thinking we are extracting the bottom int64. + *(reinterpret_cast(&result)) = + _mm256_extract_epi64(_mm256_castpd_si256(sum), 0); + while (offset < n) { + result += u[offset] * v[offset]; + ++offset; + } + return result; +} + +} // namespace tesseract. + +#endif // ANDROID_BUILD diff --git a/arch/dotproductavx.h b/arch/dotproductavx.h new file mode 100644 index 0000000000..ef00cdfb11 --- /dev/null +++ b/arch/dotproductavx.h @@ -0,0 +1,30 @@ +/////////////////////////////////////////////////////////////////////// +// File: dotproductavx.h +// Description: Architecture-specific dot-product function. +// Author: Ray Smith +// Created: Wed Jul 22 10:51:05 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_ARCH_DOTPRODUCTAVX_H_ +#define TESSERACT_ARCH_DOTPRODUCTAVX_H_ + +namespace tesseract { + +// Computes and returns the dot product of the n-vectors u and v. +// Uses Intel AVX intrinsics to access the SIMD instruction set. +double DotProductAVX(const double* u, const double* v, int n); + +} // namespace tesseract. + +#endif // TESSERACT_ARCH_DOTPRODUCTAVX_H_ diff --git a/arch/dotproductsse.cpp b/arch/dotproductsse.cpp new file mode 100644 index 0000000000..cc5c245522 --- /dev/null +++ b/arch/dotproductsse.cpp @@ -0,0 +1,141 @@ +/////////////////////////////////////////////////////////////////////// +// File: dotproductsse.cpp +// Description: Architecture-specific dot-product function. +// Author: Ray Smith +// Created: Wed Jul 22 10:57:45 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#if !defined(__SSE4_1__) +// This code can't compile with "-msse4.1", so use dummy stubs. + +#include "dotproductsse.h" +#include <stdio.h> +#include <stdlib.h> + +namespace tesseract { +double DotProductSSE(const double* u, const double* v, int n) { + fprintf(stderr, "DotProductSSE can't be used on Android\n"); + abort(); +} +inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) { + fprintf(stderr, "IntDotProductSSE can't be used on Android\n"); + abort(); +} +} // namespace tesseract + +#else // !defined(__SSE4_1__) +// Non-Android code here + +#include <emmintrin.h> +#include <smmintrin.h> +#include <stdint.h> +#include "dotproductsse.h" +#include "host.h" + +namespace tesseract { + +// Computes and returns the dot product of the n-vectors u and v. +// Uses Intel SSE intrinsics to access the SIMD instruction set. +double DotProductSSE(const double* u, const double* v, int n) { + int max_offset = n - 2; + int offset = 0; + // Accumulate a set of 2 sums in sum, by loading pairs of 2 values from u and + // v, and multiplying them together in parallel. + __m128d sum = _mm_setzero_pd(); + if (offset <= max_offset) { + offset = 2; + // Aligned load is reputedly faster but requires 16 byte aligned input. + if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 && + (reinterpret_cast<uintptr_t>(v) & 15) == 0) { + // Use aligned load. + sum = _mm_load_pd(u); + __m128d floats2 = _mm_load_pd(v); + // Multiply. + sum = _mm_mul_pd(sum, floats2); + while (offset <= max_offset) { + __m128d floats1 = _mm_load_pd(u + offset); + floats2 = _mm_load_pd(v + offset); + offset += 2; + floats1 = _mm_mul_pd(floats1, floats2); + sum = _mm_add_pd(sum, floats1); + } + } else { + // Use unaligned load. + sum = _mm_loadu_pd(u); + __m128d floats2 = _mm_loadu_pd(v); + // Multiply.
+ sum = _mm_mul_pd(sum, floats2); + while (offset <= max_offset) { + __m128d floats1 = _mm_loadu_pd(u + offset); + floats2 = _mm_loadu_pd(v + offset); + offset += 2; + floats1 = _mm_mul_pd(floats1, floats2); + sum = _mm_add_pd(sum, floats1); + } + } + } + // Add the 2 sums in sum horizontally. + sum = _mm_hadd_pd(sum, sum); + // Extract the low result. + double result = _mm_cvtsd_f64(sum); + // Add on any left-over products. + while (offset < n) { + result += u[offset] * v[offset]; + ++offset; + } + return result; +} + +// Computes and returns the dot product of the n-vectors u and v. +// Uses Intel SSE intrinsics to access the SIMD instruction set. +inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) { + int max_offset = n - 8; + int offset = 0; + // Accumulate a set of 4 32-bit sums in sum, by loading 8 pairs of 8-bit + // values, extending to 16 bit, multiplying to make 32 bit results. + __m128i sum = _mm_setzero_si128(); + if (offset <= max_offset) { + offset = 8; + __m128i packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u)); + __m128i packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v)); + sum = _mm_cvtepi8_epi16(packed1); + packed2 = _mm_cvtepi8_epi16(packed2); + // The magic _mm_madd_epi16 is perfect here. It multiplies 8 pairs of 16 bit + // ints to make 32 bit results, which are then horizontally added in pairs + // to make 4 32 bit results that still fit in a 128 bit register. + sum = _mm_madd_epi16(sum, packed2); + while (offset <= max_offset) { + packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u + offset)); + packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v + offset)); + offset += 8; + packed1 = _mm_cvtepi8_epi16(packed1); + packed2 = _mm_cvtepi8_epi16(packed2); + packed1 = _mm_madd_epi16(packed1, packed2); + sum = _mm_add_epi32(sum, packed1); + } + } + // Sum the 4 packed 32 bit sums and extract the low result. + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + inT32 result = _mm_cvtsi128_si32(sum); + while (offset < n) { + result += u[offset] * v[offset]; + ++offset; + } + return result; +} + +} // namespace tesseract. + +#endif // !defined(__SSE4_1__) diff --git a/arch/dotproductsse.h b/arch/dotproductsse.h new file mode 100644 index 0000000000..fa0a744fca --- /dev/null +++ b/arch/dotproductsse.h @@ -0,0 +1,35 @@ +/////////////////////////////////////////////////////////////////////// +// File: dotproductsse.h +// Description: Architecture-specific dot-product function. +// Author: Ray Smith +// Created: Wed Jul 22 10:57:05 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_ARCH_DOTPRODUCTSSE_H_ +#define TESSERACT_ARCH_DOTPRODUCTSSE_H_ + +#include "host.h" + +namespace tesseract { + +// Computes and returns the dot product of the n-vectors u and v. +// Uses Intel SSE intrinsics to access the SIMD instruction set. +double DotProductSSE(const double* u, const double* v, int n); +// Computes and returns the dot product of the n-vectors u and v.
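
Because arch/Makefile.am compiles these translation units with -mavx and -msse4.1 respectively, generic code must pick an implementation at run time. A sketch of such a dispatcher, assuming GCC/Clang's __builtin_cpu_supports; the scalar DotProductNative fallback and the PickDotProduct helper are illustrative, not part of this patch:

```cpp
#include "dotproductavx.h"
#include "dotproductsse.h"

namespace tesseract {

// Plain scalar fallback for CPUs with neither AVX nor SSE4.1 (illustrative).
static double DotProductNative(const double* u, const double* v, int n) {
  double total = 0.0;
  for (int i = 0; i < n; ++i) total += u[i] * v[i];
  return total;
}

typedef double (*DotProductFn)(const double*, const double*, int);

// Picks the fastest available implementation once, at startup.
inline DotProductFn PickDotProduct() {
  if (__builtin_cpu_supports("avx")) return DotProductAVX;
  if (__builtin_cpu_supports("sse4.1")) return DotProductSSE;
  return DotProductNative;
}

}  // namespace tesseract
```
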
+// Uses Intel SSE intrinsics to access the SIMD instruction set. +inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n); + +} // namespace tesseract. + +#endif // TESSERACT_ARCH_DOTPRODUCTSSE_H_ diff --git a/ccmain/Makefile.am b/ccmain/Makefile.am index e82c0031a1..38edb2fd18 100644 --- a/ccmain/Makefile.am +++ b/ccmain/Makefile.am @@ -1,6 +1,7 @@ AM_CPPFLAGS += \ -DUSE_STD_NAMESPACE \ -I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \ + -I$(top_srcdir)/arch -I$(top_srcdir)/lstm \ -I$(top_srcdir)/viewer \ -I$(top_srcdir)/classify -I$(top_srcdir)/dict \ -I$(top_srcdir)/wordrec -I$(top_srcdir)/cutil \ @@ -33,6 +34,9 @@ libtesseract_main_la_LIBADD = \ ../ccstruct/libtesseract_ccstruct.la \ ../viewer/libtesseract_viewer.la \ ../dict/libtesseract_dict.la \ + ../arch/libtesseract_avx.la \ + ../arch/libtesseract_sse.la \ + ../lstm/libtesseract_lstm.la \ ../classify/libtesseract_classify.la \ ../cutil/libtesseract_cutil.la \ ../opencl/libtesseract_opencl.la @@ -44,7 +48,7 @@ endif libtesseract_main_la_SOURCES = \ adaptions.cpp applybox.cpp control.cpp \ docqual.cpp equationdetect.cpp fixspace.cpp fixxht.cpp \ - ltrresultiterator.cpp \ + linerec.cpp ltrresultiterator.cpp \ osdetect.cpp output.cpp pageiterator.cpp pagesegmain.cpp \ pagewalk.cpp par_control.cpp paragraphs.cpp paramsd.cpp pgedit.cpp recogtraining.cpp \ reject.cpp resultiterator.cpp superscript.cpp \ diff --git a/ccmain/control.cpp b/ccmain/control.cpp index 79a0c27f37..006d34d7be 100644 --- a/ccmain/control.cpp +++ b/ccmain/control.cpp @@ -84,7 +84,12 @@ BOOL8 Tesseract::recog_interactive(PAGE_RES_IT* pr_it) { WordData word_data(*pr_it); SetupWordPassN(2, &word_data); - classify_word_and_language(2, pr_it, &word_data); + // LSTM doesn't run on pass2, but we want to run pass2 for tesseract. + if (lstm_recognizer_ == NULL) { + classify_word_and_language(2, pr_it, &word_data); + } else { + classify_word_and_language(1, pr_it, &word_data); + } if (tessedit_debug_quality_metrics) { WERD_RES* word_res = pr_it->word(); word_char_quality(word_res, pr_it->row()->row, &char_qual, &good_char_qual); @@ -218,16 +223,14 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor, if (pass_n == 1) { monitor->progress = 70 * w / words->size(); if (monitor->progress_callback != NULL) { - TBOX box = pr_it->word()->word->bounding_box(); - (*monitor->progress_callback)(monitor->progress, - box.left(), box.right(), - box.top(), box.bottom()); + TBOX box = pr_it->word()->word->bounding_box(); + (*monitor->progress_callback)(monitor->progress, box.left(), + box.right(), box.top(), box.bottom()); } } else { monitor->progress = 70 + 30 * w / words->size(); - if (monitor->progress_callback!=NULL) { - (*monitor->progress_callback)(monitor->progress, - 0, 0, 0, 0); + if (monitor->progress_callback != NULL) { + (*monitor->progress_callback)(monitor->progress, 0, 0, 0, 0); } } if (monitor->deadline_exceeded() || @@ -252,7 +255,8 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor, pr_it->forward(); ASSERT_HOST(pr_it->word() != NULL); bool make_next_word_fuzzy = false; - if (ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) { + if (!AnyLSTMLang() && + ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) { // Needs to be setup again to see the new outlines in the chopped_word. 
SetupWordPassN(pass_n, word); } @@ -297,6 +301,16 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res, const TBOX* target_word_box, const char* word_config, int dopasses) { + // PSM_RAW_LINE is a special-case mode in which the layout analysis is + // completely ignored and LSTM is run on the raw image. There is no hope + // of running normal tesseract in this situation or of integrating output. +#ifndef ANDROID_BUILD + if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY && + tessedit_pageseg_mode == PSM_RAW_LINE) { + RecogRawLine(page_res); + return true; + } +#endif PAGE_RES_IT page_res_it(page_res); if (tessedit_minimal_rej_pass1) { @@ -385,7 +399,7 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res, // The next passes can only be run if tesseract has been used, as cube // doesn't set all the necessary outputs in WERD_RES. - if (AnyTessLang()) { + if (AnyTessLang() && !AnyLSTMLang()) { // ****************** Pass 3 ******************* // Fix fuzzy spaces. set_global_loc_code(LOC_FUZZY_SPACE); @@ -1362,6 +1376,19 @@ void Tesseract::classify_word_pass1(const WordData& word_data, cube_word_pass1(block, row, *in_word); return; } + if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) { + if (!(*in_word)->odd_size) { + LSTMRecognizeWord(*block, row, *in_word, out_words); + if (!out_words->empty()) + return; // Successful lstm recognition. + } + // Fall back to tesseract for failed words or odd words. + (*in_word)->SetupForRecognition(unicharset, this, BestPix(), + OEM_TESSERACT_ONLY, NULL, + classify_bln_numeric_mode, + textord_use_cjk_fp_model, + poly_allow_detailed_fx, row, block); + } #endif WERD_RES* word = *in_word; match_word_pass_n(1, word, row, block); @@ -1496,10 +1523,6 @@ void Tesseract::classify_word_pass2(const WordData& word_data, WERD_RES** in_word, PointerVector<WERD_RES>* out_words) { // Return if we do not want to run Tesseract. - if (tessedit_ocr_engine_mode != OEM_TESSERACT_ONLY && - tessedit_ocr_engine_mode != OEM_TESSERACT_CUBE_COMBINED && - word_data.word->best_choice != NULL) - return; if (tessedit_ocr_engine_mode == OEM_CUBE_ONLY) { return; } diff --git a/ccmain/linerec.cpp b/ccmain/linerec.cpp new file mode 100644 index 0000000000..0a9e8fb89b --- /dev/null +++ b/ccmain/linerec.cpp @@ -0,0 +1,332 @@ +/////////////////////////////////////////////////////////////////////// +// File: linerec.cpp +// Description: Top-level line-based recognition module for Tesseract. +// Author: Ray Smith +// Created: Thu May 02 09:47:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "tesseractclass.h" + +#include "allheaders.h" +#include "boxread.h" +#include "imagedata.h" +#ifndef ANDROID_BUILD +#include "lstmrecognizer.h" +#include "recodebeam.h" +#endif +#include "ndminx.h" +#include "pageres.h" +#include "tprintf.h" + +namespace tesseract { + +// Arbitrary penalty for non-dictionary words. +// TODO(rays) How to learn this?
+const float kNonDictionaryPenalty = 5.0f; +// Scale factor to make certainty more comparable to Tesseract. +const float kCertaintyScale = 7.0f; +// Worst acceptable certainty for a dictionary word. +const float kWorstDictCertainty = -25.0f; + +// Generates training data for training a line recognizer, eg LSTM. +// Breaks the page into lines, according to the boxes, and writes them to a +// serialized DocumentData based on output_basename. +void Tesseract::TrainLineRecognizer(const STRING& input_imagename, + const STRING& output_basename, + BLOCK_LIST *block_list) { + STRING lstmf_name = output_basename + ".lstmf"; + DocumentData images(lstmf_name); + if (applybox_page > 0) { + // Load existing document for the previous pages. + if (!images.LoadDocument(lstmf_name.string(), "eng", 0, 0, NULL)) { + tprintf("Failed to read training data from %s!\n", lstmf_name.string()); + return; + } + } + GenericVector<TBOX> boxes; + GenericVector<STRING> texts; + // Get the boxes for this page, if there are any. + if (!ReadAllBoxes(applybox_page, false, input_imagename, &boxes, &texts, NULL, + NULL) || + boxes.empty()) { + tprintf("Failed to read boxes from %s\n", input_imagename.string()); + return; + } + TrainFromBoxes(boxes, texts, block_list, &images); + if (!images.SaveDocument(lstmf_name.string(), NULL)) { + tprintf("Failed to write training data to %s!\n", lstmf_name.string()); + } +} + +// Generates training data for training a line recognizer, eg LSTM. +// Breaks the boxes into lines, normalizes them, converts to ImageData and +// appends them to the given training_data. +void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes, + const GenericVector<STRING>& texts, + BLOCK_LIST *block_list, + DocumentData* training_data) { + int box_count = boxes.size(); + // Process all the text lines in this page, as defined by the boxes. + int end_box = 0; + for (int start_box = 0; start_box < box_count; start_box = end_box) { + // Find the textline of boxes starting at start and their bounding box. + TBOX line_box = boxes[start_box]; + STRING line_str = texts[start_box]; + for (end_box = start_box + 1; end_box < box_count && texts[end_box] != "\t"; + ++end_box) { + line_box += boxes[end_box]; + line_str += texts[end_box]; + } + // Find the most overlapping block. + BLOCK* best_block = NULL; + int best_overlap = 0; + BLOCK_IT b_it(block_list); + for (b_it.mark_cycle_pt(); !b_it.cycled_list(); b_it.forward()) { + BLOCK* block = b_it.data(); + if (block->poly_block() != NULL && !block->poly_block()->IsText()) + continue; // Not a text block. + TBOX block_box = block->bounding_box(); + block_box.rotate(block->re_rotation()); + if (block_box.major_overlap(line_box)) { + TBOX overlap_box = line_box.intersection(block_box); + if (overlap_box.area() > best_overlap) { + best_overlap = overlap_box.area(); + best_block = block; + } + } + } + ImageData* imagedata = NULL; + if (best_block == NULL) { + tprintf("No block overlapping textline: %s\n", line_str.string()); + } else { + imagedata = GetLineData(line_box, boxes, texts, start_box, end_box, + *best_block); + } + if (imagedata != NULL) + training_data->AddPageToDocument(imagedata); + if (end_box < texts.size() && texts[end_box] == "\t") ++end_box; + } +} + +// Returns an Imagedata containing the image of the given box, +// and ground truth boxes/truth text if available in the input. +// The image is not normalized in any way.
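
TrainFromBoxes above leans on one convention: a box whose text is "\t" marks the end of a text line, so everything between tabs is concatenated into one training line. A distilled sketch of that grouping, using std::vector/std::string in place of GenericVector/STRING:

```cpp
#include <string>
#include <vector>

// Groups per-character box texts into line transcriptions; "\t" entries are
// the end-of-line markers used by the box file format.
std::vector<std::string> GroupBoxTextsIntoLines(
    const std::vector<std::string>& texts) {
  std::vector<std::string> lines;
  std::string line;
  for (const std::string& t : texts) {
    if (t == "\t") {  // End of line: flush what we have.
      lines.push_back(line);
      line.clear();
    } else {
      line += t;
    }
  }
  if (!line.empty()) lines.push_back(line);  // Trailing line without a tab.
  return lines;
}
```
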
+ImageData* Tesseract::GetLineData(const TBOX& line_box, + const GenericVector<TBOX>& boxes, + const GenericVector<STRING>& texts, + int start_box, int end_box, + const BLOCK& block) { + TBOX revised_box; + ImageData* image_data = GetRectImage(line_box, block, kImagePadding, + &revised_box); + if (image_data == NULL) return NULL; + image_data->set_page_number(applybox_page); + // Copy the boxes and shift them so they are relative to the image. + FCOORD block_rotation(block.re_rotation().x(), -block.re_rotation().y()); + ICOORD shift = -revised_box.botleft(); + GenericVector<TBOX> line_boxes; + GenericVector<STRING> line_texts; + for (int b = start_box; b < end_box; ++b) { + TBOX box = boxes[b]; + box.rotate(block_rotation); + box.move(shift); + line_boxes.push_back(box); + line_texts.push_back(texts[b]); + } + GenericVector<int> page_numbers; + page_numbers.init_to_size(line_boxes.size(), applybox_page); + image_data->AddBoxes(line_boxes, line_texts, page_numbers); + return image_data; +} + +// Helper gets the image of a rectangle, using the block.re_rotation() if +// needed to get to the image, and rotating the result back to horizontal +// layout. (CJK characters will be on their left sides) The vertical text flag +// is set in the returned ImageData if the text was originally vertical, which +// can be used to invoke a different CJK recognition engine. The revised_box +// is also returned to enable calculation of output bounding boxes. +ImageData* Tesseract::GetRectImage(const TBOX& box, const BLOCK& block, + int padding, TBOX* revised_box) const { + TBOX wbox = box; + wbox.pad(padding, padding); + *revised_box = wbox; + // Number of clockwise 90 degree rotations needed to get back to tesseract + // coords from the clipped image. + int num_rotations = 0; + if (block.re_rotation().y() > 0.0f) + num_rotations = 1; + else if (block.re_rotation().x() < 0.0f) + num_rotations = 2; + else if (block.re_rotation().y() < 0.0f) + num_rotations = 3; + // Handle two cases automatically: 1 the box came from the block, 2 the box + // came from a box file, and refers to the image, which the block may not. + if (block.bounding_box().major_overlap(*revised_box)) + revised_box->rotate(block.re_rotation()); + // Now revised_box always refers to the image. + // BestPix is never colormapped, but may be of any depth. + Pix* pix = BestPix(); + int width = pixGetWidth(pix); + int height = pixGetHeight(pix); + TBOX image_box(0, 0, width, height); + // Clip to image bounds. + *revised_box &= image_box; + if (revised_box->null_box()) return NULL; + Box* clip_box = boxCreate(revised_box->left(), height - revised_box->top(), + revised_box->width(), revised_box->height()); + Pix* box_pix = pixClipRectangle(pix, clip_box, NULL); + if (box_pix == NULL) return NULL; + boxDestroy(&clip_box); + if (num_rotations > 0) { + Pix* rot_pix = pixRotateOrth(box_pix, num_rotations); + pixDestroy(&box_pix); + box_pix = rot_pix; + } + // Convert sub-8-bit images to 8 bit. + int depth = pixGetDepth(box_pix); + if (depth < 8) { + Pix* grey; + grey = pixConvertTo8(box_pix, false); + pixDestroy(&box_pix); + box_pix = grey; + } + bool vertical_text = false; + if (num_rotations > 0) { + // Rotate the clipped revised box back to internal coordinates. + FCOORD rotation(block.re_rotation().x(), -block.re_rotation().y()); + revised_box->rotate(rotation); + if (num_rotations != 2) + vertical_text = true; + } + return new ImageData(vertical_text, box_pix); +} + +#ifndef ANDROID_BUILD +// Top-level function recognizes a single raw line.
+void Tesseract::RecogRawLine(PAGE_RES* page_res) { + PAGE_RES_IT it(page_res); + PointerVector<WERD_RES> words; + LSTMRecognizeWord(*it.block()->block, it.row()->row, it.word(), &words); + if (getDict().stopper_debug_level >= 1) { + for (int w = 0; w < words.size(); ++w) { + words[w]->DebugWordChoices(true, NULL); + } + } + it.ReplaceCurrentWord(&words); +} + +// Recognizes a word or group of words, converting to WERD_RES in *words. +// Analogous to classify_word_pass1, but can handle a group of words as well. +void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word, + PointerVector<WERD_RES>* words) { + TBOX word_box = word->word->bounding_box(); + // Get the word image - no frills. + if (tessedit_pageseg_mode == PSM_SINGLE_WORD || + tessedit_pageseg_mode == PSM_RAW_LINE) { + // In single word mode, use the whole image without any other row/word + // interpretation. + word_box = TBOX(0, 0, ImageWidth(), ImageHeight()); + } else { + float baseline = row->base_line((word_box.left() + word_box.right()) / 2); + if (baseline + row->descenders() < word_box.bottom()) + word_box.set_bottom(baseline + row->descenders()); + if (baseline + row->x_height() + row->ascenders() > word_box.top()) + word_box.set_top(baseline + row->x_height() + row->ascenders()); + } + ImageData* im_data = GetRectImage(word_box, block, kImagePadding, &word_box); + if (im_data == NULL) return; + lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0, + kWorstDictCertainty / kCertaintyScale, + lstm_use_matrix, &unicharset, word_box, 2.0, + false, words); + delete im_data; + SearchWords(words); +} + +// Apply segmentation search to the given set of words, within the constraints +// of the existing ratings matrix. If there is already a best_choice on a word +// leaves it untouched and just sets the done/accepted etc flags. +void Tesseract::SearchWords(PointerVector<WERD_RES>* words) { + // Run the segmentation search on the network outputs and make a BoxWord + // for each of the output words. + // If we drop a word as junk, then there is always a space in front of the + // next. + bool deleted_prev = false; + for (int w = 0; w < words->size(); ++w) { + WERD_RES* word = (*words)[w]; + if (word->best_choice == NULL) { + // If we are using the beam search, the unicharset had better match! + word->SetupWordScript(unicharset); + WordSearch(word); + } else if (word->best_choice->unicharset() == &unicharset && + !lstm_recognizer_->IsRecoding()) { + // We set up the word without using the dictionary, so set the permuter + // now, but we can only do it because the unicharsets match. + word->best_choice->set_permuter( + getDict().valid_word(*word->best_choice, true)); + } + if (word->best_choice == NULL) { + // It is a dud. + words->remove(w); + --w; + deleted_prev = true; + } else { + // Set the best state. + for (int i = 0; i < word->best_choice->length(); ++i) { + int length = word->best_choice->state(i); + word->best_state.push_back(length); + } + word->tess_failed = false; + word->tess_accepted = true; + word->tess_would_adapt = false; + word->done = true; + word->tesseract = this; + float word_certainty = MIN(word->space_certainty, + word->best_choice->certainty()); + word_certainty *= kCertaintyScale; + // Arbitrary ding factor for non-dictionary words.
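
The scoring that starts here and concludes a few lines below reduces to a compact rule: take the worse of word and space certainty, scale by kCertaintyScale, subtract kNonDictionaryPenalty for non-dictionary words, then accept against two thresholds. A hypothetical restatement (the kMinCertainty value is an assumed stand-in for RecodeBeamSearch::kMinCertainty, and the IsRecoding() exemption from the penalty is omitted for brevity):

```cpp
#include <algorithm>

// Sketch of the word-acceptance rule in SearchWords, not the actual code.
bool KeepWord(float word_certainty, float space_certainty, bool dict_word) {
  const float kCertaintyScale = 7.0f;        // From linerec.cpp above.
  const float kNonDictionaryPenalty = 5.0f;  // From linerec.cpp above.
  const float kWorstDictCertainty = -25.0f;  // From linerec.cpp above.
  const float kMinCertainty = -20.0f;  // Assumed; see RecodeBeamSearch.
  float c = std::min(word_certainty, space_certainty) * kCertaintyScale;
  if (!dict_word) c -= kNonDictionaryPenalty;
  // Impossibly bad words are dropped; dictionary words get more slack.
  return c >= kMinCertainty || (dict_word && c >= kWorstDictCertainty);
}
```
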
+ if (!lstm_recognizer_->IsRecoding() && + !Dict::valid_word_permuter(word->best_choice->permuter(), true)) + word_certainty -= kNonDictionaryPenalty; + if (getDict().stopper_debug_level >= 1) { + tprintf("Best choice certainty=%g, space=%g, scaled=%g, final=%g\n", + word->best_choice->certainty(), word->space_certainty, + MIN(word->space_certainty, word->best_choice->certainty()) * + kCertaintyScale, + word_certainty); + word->best_choice->print(); + } + // Discard words that are impossibly bad, but allow a bit more for + // dictionary words. + if (word_certainty >= RecodeBeamSearch::kMinCertainty || + (word_certainty >= kWorstDictCertainty && + Dict::valid_word_permuter(word->best_choice->permuter(), true))) { + word->best_choice->set_certainty(word_certainty); + if (deleted_prev) word->word->set_blanks(1); + } else { + if (getDict().stopper_debug_level >= 1) { + tprintf("Deleting word with certainty %g\n", word_certainty); + word->best_choice->print(); + } + // It is a dud. + words->remove(w); + --w; + deleted_prev = true; + } + } + } +} +#endif // ANDROID_BUILD + +} // namespace tesseract. diff --git a/ccmain/ltrresultiterator.cpp b/ccmain/ltrresultiterator.cpp index f80e594518..ae582b30c0 100644 --- a/ccmain/ltrresultiterator.cpp +++ b/ccmain/ltrresultiterator.cpp @@ -220,6 +220,12 @@ bool LTRResultIterator::WordIsFromDictionary() const { permuter == USER_DAWG_PERM; } +// Returns the number of blanks before the current word. +int LTRResultIterator::BlanksBeforeWord() const { + if (it_->word() == NULL) return 1; + return it_->word()->word->space(); +} + // Returns true if the current word is numeric. bool LTRResultIterator::WordIsNumeric() const { if (it_->word() == NULL) return false; // Already at the end! diff --git a/ccmain/ltrresultiterator.h b/ccmain/ltrresultiterator.h index f2605b52d2..3ab70c4655 100644 --- a/ccmain/ltrresultiterator.h +++ b/ccmain/ltrresultiterator.h @@ -124,6 +124,9 @@ class TESS_API LTRResultIterator : public PageIterator { // Returns true if the current word was found in a dictionary. bool WordIsFromDictionary() const; + // Returns the number of blanks before the current word. + int BlanksBeforeWord() const; + // Returns true if the current word is numeric. bool WordIsNumeric() const; diff --git a/ccmain/tessedit.cpp b/ccmain/tessedit.cpp index 8c1fb80837..cf6b8b67a4 100644 --- a/ccmain/tessedit.cpp +++ b/ccmain/tessedit.cpp @@ -40,6 +40,9 @@ #include "efio.h" #include "danerror.h" #include "globals.h" +#ifndef ANDROID_BUILD +#include "lstmrecognizer.h" +#endif #include "tesseractclass.h" #include "params.h" @@ -214,6 +217,18 @@ bool Tesseract::init_tesseract_lang_data( ASSERT_HOST(init_cube_objects(true, &tessdata_manager)); if (tessdata_manager_debug_level) tprintf("Loaded Cube with combiner\n"); + } else if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) { + if (tessdata_manager.SeekToStart(TESSDATA_LSTM)) { + lstm_recognizer_ = new LSTMRecognizer; + TFile fp; + fp.Open(tessdata_manager.GetDataFilePtr(), -1); + ASSERT_HOST(lstm_recognizer_->DeSerialize(tessdata_manager.swap(), &fp)); + if (lstm_use_matrix) + lstm_recognizer_->LoadDictionary(tessdata_path.string(), language); + } else { + tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n"); + tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY); + } } #endif // Init ParamsModel. @@ -409,8 +424,7 @@ int Tesseract::init_tesseract_internal( // If only Cube will be used, skip loading Tesseract classifier's // pre-trained templates. 
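
The TESSDATA_LSTM branch above is reached whenever a client initializes with the new engine mode. A sketch of opting in from client code; the three-argument Init overload is existing API, and OEM_LSTM_ONLY is the mode this patch introduces in publictypes.h:

```cpp
#include "baseapi.h"

// Requests the LSTM line recognizer; init_tesseract_lang_data falls back to
// the classic engine if the traineddata has no TESSDATA_LSTM component.
bool InitLstmEngine(tesseract::TessBaseAPI* api, const char* datapath) {
  return api->Init(datapath, "eng", tesseract::OEM_LSTM_ONLY) == 0;
}
```
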
bool init_tesseract_classifier = - (tessedit_ocr_engine_mode == OEM_TESSERACT_ONLY || - tessedit_ocr_engine_mode == OEM_TESSERACT_CUBE_COMBINED); + tessedit_ocr_engine_mode != OEM_CUBE_ONLY; // If only Cube will be used and if it has its own Unicharset, // skip initializing permuter and loading Tesseract Dawgs. bool init_dict = @@ -468,7 +482,9 @@ int Tesseract::init_tesseract_lm(const char *arg0, if (!init_tesseract_lang_data(arg0, textbase, language, OEM_TESSERACT_ONLY, NULL, 0, NULL, NULL, false)) return -1; - getDict().Load(Dict::GlobalDawgCache()); + getDict().SetupForLoad(Dict::GlobalDawgCache()); + getDict().Load(tessdata_manager.GetDataFileName().string(), lang); + getDict().FinishLoad(); tessdata_manager.End(); return 0; } diff --git a/ccmain/tesseractclass.cpp b/ccmain/tesseractclass.cpp index f0cc1bfffe..78e4095f47 100644 --- a/ccmain/tesseractclass.cpp +++ b/ccmain/tesseractclass.cpp @@ -49,6 +49,7 @@ #include "equationdetect.h" #include "globals.h" #ifndef NO_CUBE_BUILD +#include "lstmrecognizer.h" #include "tesseract_cube_combiner.h" #endif @@ -65,6 +66,9 @@ Tesseract::Tesseract() "Generate training data from boxed chars", this->params()), BOOL_MEMBER(tessedit_make_boxes_from_boxes, false, "Generate more boxes from boxed chars", this->params()), + BOOL_MEMBER(tessedit_train_line_recognizer, false, + "Break input into lines and remap boxes if present", + this->params()), BOOL_MEMBER(tessedit_dump_pageseg_images, false, "Dump intermediate images made during page segmentation", this->params()), @@ -222,6 +226,8 @@ Tesseract::Tesseract() "(more accurate)", this->params()), INT_MEMBER(cube_debug_level, 0, "Print cube debug info.", this->params()), + BOOL_MEMBER(lstm_use_matrix, 1, + "Use ratings matrix/beam search with lstm", this->params()), STRING_MEMBER(outlines_odd, "%| ", "Non standard number of outlines", this->params()), STRING_MEMBER(outlines_2, "ij!?%\":;", "Non standard number of outlines", @@ -605,6 +611,7 @@ Tesseract::Tesseract() pix_binary_(NULL), cube_binary_(NULL), pix_grey_(NULL), + pix_original_(NULL), pix_thresholds_(NULL), source_resolution_(0), textord_(this), @@ -619,11 +626,16 @@ Tesseract::Tesseract() cube_cntxt_(NULL), tess_cube_combiner_(NULL), #endif - equ_detect_(NULL) { + equ_detect_(NULL), +#ifndef ANDROID_BUILD + lstm_recognizer_(NULL), +#endif + train_line_page_num_(0) { } Tesseract::~Tesseract() { Clear(); + pixDestroy(&pix_original_); end_tesseract(); sub_langs_.delete_data_pointers(); #ifndef NO_CUBE_BUILD @@ -636,6 +648,8 @@ Tesseract::~Tesseract() { delete tess_cube_combiner_; tess_cube_combiner_ = NULL; } + delete lstm_recognizer_; + lstm_recognizer_ = NULL; #endif } diff --git a/ccmain/tesseractclass.h b/ccmain/tesseractclass.h index e05eac5026..8a3b0f9213 100644 --- a/ccmain/tesseractclass.h +++ b/ccmain/tesseractclass.h @@ -102,7 +102,10 @@ class CubeLineObject; class CubeObject; class CubeRecoContext; #endif +class DocumentData; class EquationDetect; +class ImageData; +class LSTMRecognizer; class Tesseract; #ifndef NO_CUBE_BUILD class TesseractCubeCombiner; @@ -189,7 +192,7 @@ class Tesseract : public Wordrec { } // Destroy any existing pix and return a pointer to the pointer. Pix** mutable_pix_binary() { - Clear(); + pixDestroy(&pix_binary_); return &pix_binary_; } Pix* pix_binary() const { @@ -202,16 +205,20 @@ class Tesseract : public Wordrec { pixDestroy(&pix_grey_); pix_grey_ = grey_pix; } - // Returns a pointer to a Pix representing the best available image of the - // page. 
The image will be 8-bit grey if the input was grey or color. Note - that in grey 0 is black and 255 is white. If the input was binary, then - the returned Pix will be binary. Note that here black is 1 and white is 0. - To tell the difference pixGetDepth() will return 8 or 1. - In either case, the return value is a borrowed Pix, and should not be - deleted or pixDestroyed. - Pix* BestPix() const { - return pix_grey_ != NULL ? pix_grey_ : pix_binary_; + Pix* pix_original() const { return pix_original_; } + // Takes ownership of the given original_pix. + void set_pix_original(Pix* original_pix) { + pixDestroy(&pix_original_); + pix_original_ = original_pix; } + // Returns a pointer to a Pix representing the best available (original) image + // of the page. Can be of any bit depth, but never color-mapped, as that has + // always been dealt with. Note that in grey and color, 0 is black and 255 is + // white. If the input was binary, then black is 1 and white is 0. + // To tell the difference pixGetDepth() will return 32, 8 or 1. + // In any case, the return value is a borrowed Pix, and should not be + // deleted or pixDestroyed. + Pix* BestPix() const { return pix_original_; } void set_pix_thresholds(Pix* thresholds) { pixDestroy(&pix_thresholds_); pix_thresholds_ = thresholds; @@ -263,6 +270,15 @@ class Tesseract : public Wordrec { } return false; } + // Returns true if any language uses the LSTM. + bool AnyLSTMLang() const { + if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) return true; + for (int i = 0; i < sub_langs_.size(); ++i) { + if (sub_langs_[i]->tessedit_ocr_engine_mode == OEM_LSTM_ONLY) + return true; + } + return false; + } void SetBlackAndWhitelist(); @@ -293,6 +309,48 @@ class Tesseract : public Wordrec { // par_control.cpp void PrerecAllWordsPar(const GenericVector<WordData>& words); + //// linerec.cpp + // Generates training data for training a line recognizer, eg LSTM. + // Breaks the page into lines, according to the boxes, and writes them to a + // serialized DocumentData based on output_basename. + void TrainLineRecognizer(const STRING& input_imagename, + const STRING& output_basename, + BLOCK_LIST *block_list); + // Generates training data for training a line recognizer, eg LSTM. + // Breaks the boxes into lines, normalizes them, converts to ImageData and + // appends them to the given training_data. + void TrainFromBoxes(const GenericVector<TBOX>& boxes, + const GenericVector<STRING>& texts, + BLOCK_LIST *block_list, + DocumentData* training_data); + + // Returns an Imagedata containing the image of the given textline, + // and ground truth boxes/truth text if available in the input. + // The image is not normalized in any way. + ImageData* GetLineData(const TBOX& line_box, + const GenericVector<TBOX>& boxes, + const GenericVector<STRING>& texts, + int start_box, int end_box, + const BLOCK& block); + // Helper gets the image of a rectangle, using the block.re_rotation() if + // needed to get to the image, and rotating the result back to horizontal + // layout. (CJK characters will be on their left sides) The vertical text flag + // is set in the returned ImageData if the text was originally vertical, which + // can be used to invoke a different CJK recognition engine. The revised_box + // is also returned to enable calculation of output bounding boxes. + ImageData* GetRectImage(const TBOX& box, const BLOCK& block, int padding, + TBOX* revised_box) const; + // Top-level function recognizes a single raw line.
+ void RecogRawLine(PAGE_RES* page_res); + // Recognizes a word or group of words, converting to WERD_RES in *words. + // Analogous to classify_word_pass1, but can handle a group of words as well. + void LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word, + PointerVector<WERD_RES>* words); + // Apply segmentation search to the given set of words, within the constraints + // of the existing ratings matrix. If there is already a best_choice on a word + // leaves it untouched and just sets the done/accepted etc flags. + void SearchWords(PointerVector<WERD_RES>* words); + //// control.h ///////////////////////////////////////////////////////// bool ProcessTargetWord(const TBOX& word_box, const TBOX& target_word_box, const char* word_config, int pass); @@ -783,6 +841,8 @@ class Tesseract : public Wordrec { "Generate training data from boxed chars"); BOOL_VAR_H(tessedit_make_boxes_from_boxes, false, "Generate more boxes from boxed chars"); + BOOL_VAR_H(tessedit_train_line_recognizer, false, + "Break input into lines and remap boxes if present"); BOOL_VAR_H(tessedit_dump_pageseg_images, false, "Dump intermediate images made during page segmentation"); INT_VAR_H(tessedit_pageseg_mode, PSM_SINGLE_BLOCK, @@ -891,6 +951,7 @@ class Tesseract : public Wordrec { "Run paragraph detection on the post-text-recognition " "(more accurate)"); INT_VAR_H(cube_debug_level, 1, "Print cube debug info."); + BOOL_VAR_H(lstm_use_matrix, 1, "Use ratings matrix/beam search with lstm"); STRING_VAR_H(outlines_odd, "%| ", "Non standard number of outlines"); STRING_VAR_H(outlines_2, "ij!?%\":;", "Non standard number of outlines"); BOOL_VAR_H(docqual_excuse_outline_errs, false, @@ -1174,6 +1235,8 @@ class Tesseract : public Wordrec { Pix* cube_binary_; // Grey-level input image if the input was not binary, otherwise NULL. Pix* pix_grey_; + // Original input image. Color if the input was color. + Pix* pix_original_; // Thresholds that were used to generate the thresholded image from grey. Pix* pix_thresholds_; // Input image resolution after any scaling. The resolution is not well @@ -1205,6 +1268,10 @@ class Tesseract : public Wordrec { #endif // Equation detector. Note: this pointer is NOT owned by the class. EquationDetect* equ_detect_; + // LSTM recognizer, if available. + LSTMRecognizer* lstm_recognizer_; + // Output "page" number (actually line number) using TrainLineRecognizer. + int train_line_page_num_; }; } // namespace tesseract diff --git a/ccmain/thresholder.cpp b/ccmain/thresholder.cpp index df6abd01eb..a9a127eb3b 100644 --- a/ccmain/thresholder.cpp +++ b/ccmain/thresholder.cpp @@ -152,19 +152,27 @@ void ImageThresholder::SetImage(const Pix* pix) { int depth; pixGetDimensions(src, &image_width_, &image_height_, &depth); // Convert the image as necessary so it is one of binary, plain RGB, or - // 8 bit with no colormap. - if (depth > 1 && depth < 8) { + // 8 bit with no colormap. Guarantee that we always end up with our own copy, + // not just a clone of the input.
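
This thresholder rewrite is what backs the revised SetImage contract documented earlier: pix_ is now always a deep copy (pixRemoveColormap/pixConvertTo8/pixCopy, never pixClone), so the caller's Pix can be released immediately. Usage sketch:

```cpp
#include "allheaders.h"
#include "baseapi.h"

// The source Pix may be destroyed as soon as SetImage returns, because the
// thresholder now always keeps its own copy rather than a clone.
char* OcrAndRelease(tesseract::TessBaseAPI* api, Pix* pix) {
  api->SetImage(pix);
  pixDestroy(&pix);           // Safe under the new copy semantics.
  return api->GetUTF8Text();  // Caller frees the result with delete [].
}
```
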
+ if (pixGetColormap(src)) { + Pix* tmp = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC); + depth = pixGetDepth(tmp); + if (depth > 1 && depth < 8) { + pix_ = pixConvertTo8(tmp, false); + pixDestroy(&tmp); + } else { + pix_ = tmp; + } + } else if (depth > 1 && depth < 8) { pix_ = pixConvertTo8(src, false); - } else if (pixGetColormap(src)) { - pix_ = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC); } else { - pix_ = pixClone(src); + pix_ = pixCopy(NULL, src); } depth = pixGetDepth(pix_); pix_channels_ = depth / 8; pix_wpl_ = pixGetWpl(pix_); scale_ = 1; - estimated_res_ = yres_ = pixGetYRes(src); + estimated_res_ = yres_ = pixGetYRes(pix_); Init(); } diff --git a/ccstruct/imagedata.cpp b/ccstruct/imagedata.cpp index 3c244c7724..77e4969354 100644 --- a/ccstruct/imagedata.cpp +++ b/ccstruct/imagedata.cpp @@ -24,12 +24,18 @@ #include "imagedata.h" +#include <unistd.h> + #include "allheaders.h" #include "boxread.h" #include "callcpp.h" #include "helpers.h" #include "tprintf.h" +// Number of documents to read ahead while training. Doesn't need to be very +// large. +const int kMaxReadAhead = 8; + namespace tesseract { WordFeature::WordFeature() : x_(0), y_(0), dir_(0) { @@ -182,6 +188,19 @@ bool ImageData::DeSerialize(bool swap, TFile* fp) { return true; } +// As DeSerialize, but only seeks past the data - hence a static method. +bool ImageData::SkipDeSerialize(bool swap, TFile* fp) { + if (!STRING::SkipDeSerialize(swap, fp)) return false; + inT32 page_number; + if (fp->FRead(&page_number, sizeof(page_number), 1) != 1) return false; + if (!GenericVector<char>::SkipDeSerialize(swap, fp)) return false; + if (!STRING::SkipDeSerialize(swap, fp)) return false; + if (!GenericVector<TBOX>::SkipDeSerialize(swap, fp)) return false; + if (!GenericVector<STRING>::SkipDeSerializeClasses(swap, fp)) return false; + inT8 vertical = 0; + return fp->FRead(&vertical, sizeof(vertical), 1) == 1; +} + // Saves the given Pix as a PNG-encoded string and destroys it. void ImageData::SetPix(Pix* pix) { SetPixInternal(pix, &image_data_); @@ -195,11 +214,12 @@ Pix* ImageData::GetPix() const { // Gets anything and everything with a non-NULL pointer, prescaled to a // given target_height (if 0, then the original image height), and aligned. // Also returns (if not NULL) the width and height of the scaled image. -// The return value is the scale factor that was applied to the image to -// achieve the target_height. -float ImageData::PreScale(int target_height, Pix** pix, - int* scaled_width, int* scaled_height, - GenericVector<TBOX>* boxes) const { +// The return value is the scaled Pix, which must be pixDestroyed after use, +// and scale_factor (if not NULL) is set to the scale factor that was applied +// to the image to achieve the target_height. +Pix* ImageData::PreScale(int target_height, float* scale_factor, + int* scaled_width, int* scaled_height, + GenericVector<TBOX>* boxes) const { int input_width = 0; int input_height = 0; Pix* src_pix = GetPix(); @@ -213,19 +233,14 @@ float ImageData::PreScale(int target_height, Pix** pix, *scaled_width = IntCastRounded(im_factor * input_width); if (scaled_height != NULL) *scaled_height = target_height; - if (pix != NULL) { - // Get the scaled image. - pixDestroy(pix); - *pix = pixScale(src_pix, im_factor, im_factor); - if (*pix == NULL) { - tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n", - input_width, input_height, im_factor); - } - if (scaled_width != NULL) - *scaled_width = pixGetWidth(*pix); - if (scaled_height != NULL) - *scaled_height = pixGetHeight(*pix); + // Get the scaled image.
+ Pix* pix = pixScale(src_pix, im_factor, im_factor); + if (pix == NULL) { + tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n", + input_width, input_height, im_factor); } + if (scaled_width != NULL) *scaled_width = pixGetWidth(pix); + if (scaled_height != NULL) *scaled_height = pixGetHeight(pix); pixDestroy(&src_pix); if (boxes != NULL) { // Get the boxes. @@ -241,7 +256,8 @@ float ImageData::PreScale(int target_height, Pix** pix, boxes->push_back(box); } } - return im_factor; + if (scale_factor != NULL) *scale_factor = im_factor; + return pix; } int ImageData::MemoryUsed() const { @@ -266,19 +282,20 @@ void ImageData::Display() const { // Draw the boxes. win->Pen(ScrollView::RED); win->Brush(ScrollView::NONE); - win->TextAttributes("Arial", kTextSize, false, false, false); - for (int b = 0; b < boxes_.size(); ++b) { - boxes_[b].plot(win); - win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string()); - TBOX scaled(boxes_[b]); - scaled.scale(256.0 / height); - scaled.plot(win); + int text_size = kTextSize; + if (!boxes_.empty() && boxes_[0].height() * 2 < text_size) + text_size = boxes_[0].height() * 2; + win->TextAttributes("Arial", text_size, false, false, false); + if (!boxes_.empty()) { + for (int b = 0; b < boxes_.size(); ++b) { + boxes_[b].plot(win); + win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string()); + } + } else { + // The full transcription. + win->Pen(ScrollView::CYAN); + win->Text(0, height + kTextSize * 2, transcription_.string()); } - // The full transcription. - win->Pen(ScrollView::CYAN); - win->Text(0, height + kTextSize * 2, transcription_.string()); - // Add the features. - win->Pen(ScrollView::GREEN); win->Update(); window_wait(win); #endif @@ -340,27 +357,51 @@ bool ImageData::AddBoxes(const char* box_text) { return false; } +// Thread function to call ReCachePages. +void* ReCachePagesFunc(void* data) { + DocumentData* document_data = reinterpret_cast<DocumentData*>(data); + document_data->ReCachePages(); + return NULL; +} + DocumentData::DocumentData(const STRING& name) - : document_name_(name), pages_offset_(0), total_pages_(0), - memory_used_(0), max_memory_(0), reader_(NULL) {} + : document_name_(name), + pages_offset_(-1), + total_pages_(-1), + memory_used_(0), + max_memory_(0), + reader_(NULL) {} -DocumentData::~DocumentData() {} +DocumentData::~DocumentData() { + SVAutoLock lock_p(&pages_mutex_); + SVAutoLock lock_g(&general_mutex_); +} // Reads all the pages in the given lstmf filename to the cache. The reader // is used to read the file. bool DocumentData::LoadDocument(const char* filename, const char* lang, int start_page, inT64 max_memory, FileReader reader) { + SetDocument(filename, lang, max_memory, reader); + pages_offset_ = start_page; + return ReCachePages(); +} + +// Sets up the document, without actually loading it. +void DocumentData::SetDocument(const char* filename, const char* lang, + inT64 max_memory, FileReader reader) { + SVAutoLock lock_p(&pages_mutex_); + SVAutoLock lock(&general_mutex_); document_name_ = filename; lang_ = lang; - pages_offset_ = start_page; + pages_offset_ = -1; max_memory_ = max_memory; reader_ = reader; - return ReCachePages(); } // Writes all the pages to the given filename. Returns false on error.
bool DocumentData::SaveDocument(const char* filename, FileWriter writer) { + SVAutoLock lock(&pages_mutex_); TFile fp; fp.OpenWrite(NULL); if (!pages_.Serialize(&fp) || !fp.CloseWrite(filename, writer)) { @@ -370,112 +411,166 @@ bool DocumentData::SaveDocument(const char* filename, FileWriter writer) { return true; } bool DocumentData::SaveToBuffer(GenericVector<char>* buffer) { + SVAutoLock lock(&pages_mutex_); TFile fp; fp.OpenWrite(buffer); return pages_.Serialize(&fp); } +// Adds the given page data to this document, counting up memory. +void DocumentData::AddPageToDocument(ImageData* page) { + SVAutoLock lock(&pages_mutex_); + pages_.push_back(page); + set_memory_used(memory_used() + page->MemoryUsed()); +} + +// If the given index is not currently loaded, loads it using a separate +// thread. +void DocumentData::LoadPageInBackground(int index) { + ImageData* page = NULL; + if (IsPageAvailable(index, &page)) return; + SVAutoLock lock(&pages_mutex_); + if (pages_offset_ == index) return; + pages_offset_ = index; + pages_.clear(); + SVSync::StartThread(ReCachePagesFunc, this); +} + // Returns a pointer to the page with the given index, modulo the total -// number of pages, recaching if needed. +// number of pages. Blocks until the background load is completed. const ImageData* DocumentData::GetPage(int index) { - index = Modulo(index, total_pages_); - if (index < pages_offset_ || index >= pages_offset_ + pages_.size()) { - pages_offset_ = index; - if (!ReCachePages()) return NULL; + ImageData* page = NULL; + while (!IsPageAvailable(index, &page)) { + // If there is no background load scheduled, schedule one now. + pages_mutex_.Lock(); + bool needs_loading = pages_offset_ != index; + pages_mutex_.Unlock(); + if (needs_loading) LoadPageInBackground(index); + // We can't directly load the page, or the background load will delete it + // while the caller is using it, so give it a chance to work. + sleep(1); } - return pages_[index - pages_offset_]; + return page; +} + +// Returns true if the requested page is available, and provides a pointer, +// which may be NULL if the document is empty. May block, even though it +// doesn't guarantee to return true. +bool DocumentData::IsPageAvailable(int index, ImageData** page) { + SVAutoLock lock(&pages_mutex_); + int num_pages = NumPages(); + if (num_pages == 0 || index < 0) { + *page = NULL; // Empty Document. + return true; + } + if (num_pages > 0) { + index = Modulo(index, num_pages); + if (pages_offset_ <= index && index < pages_offset_ + pages_.size()) { + *page = pages_[index - pages_offset_]; // Page is available already. + return true; + } + } + return false; } -// Loads as many pages can fit in max_memory_ starting at index pages_offset_. +// Removes all pages from memory and frees the memory, but does not forget +// the document metadata. +inT64 DocumentData::UnCache() { + SVAutoLock lock(&pages_mutex_); + inT64 memory_saved = memory_used(); + pages_.clear(); + pages_offset_ = -1; + set_total_pages(-1); + set_memory_used(0); + tprintf("Unloaded document %s, saving %d memory\n", document_name_.string(), + memory_saved); + return memory_saved; +} + +// Locks the pages_mutex_ and loads as many pages as can fit in max_memory_ +// starting at index pages_offset_. bool DocumentData::ReCachePages() { + SVAutoLock lock(&pages_mutex_); // Read the file.
+ set_total_pages(0); + set_memory_used(0); + int loaded_pages = 0; + pages_.truncate(0); TFile fp; - if (!fp.Open(document_name_, reader_)) return false; - memory_used_ = 0; - if (!pages_.DeSerialize(false, &fp)) { - tprintf("Deserialize failed: %s\n", document_name_.string()); - pages_.truncate(0); + if (!fp.Open(document_name_, reader_) || + !PointerVector<ImageData>::DeSerializeSize(false, &fp, &loaded_pages) || + loaded_pages <= 0) { + tprintf("Deserialize header failed: %s\n", document_name_.string()); return false; } - total_pages_ = pages_.size(); - pages_offset_ %= total_pages_; - // Delete pages before the first one we want, and relocate the rest. + pages_offset_ %= loaded_pages; + // Skip pages before the first one we want, load the rest until max memory + // is reached, and skip any pages after that. int page; - for (page = 0; page < pages_.size(); ++page) { - if (page < pages_offset_) { - delete pages_[page]; - pages_[page] = NULL; + for (page = 0; page < loaded_pages; ++page) { + if (page < pages_offset_ || + (max_memory_ > 0 && memory_used() > max_memory_)) { + if (!PointerVector<ImageData>::DeSerializeSkip(false, &fp)) break; } else { - ImageData* image_data = pages_[page]; - if (max_memory_ > 0 && page > pages_offset_ && - memory_used_ + image_data->MemoryUsed() > max_memory_) - break; // Don't go over memory quota unless the first image. + if (!pages_.DeSerializeElement(false, &fp)) break; + ImageData* image_data = pages_.back(); if (image_data->imagefilename().length() == 0) { image_data->set_imagefilename(document_name_); image_data->set_page_number(page); } image_data->set_language(lang_); - memory_used_ += image_data->MemoryUsed(); - if (pages_offset_ != 0) { - pages_[page - pages_offset_] = image_data; - pages_[page] = NULL; - } + set_memory_used(memory_used() + image_data->MemoryUsed()); } } - pages_.truncate(page - pages_offset_); - tprintf("Loaded %d/%d pages (%d-%d) of document %s\n", - pages_.size(), total_pages_, pages_offset_, - pages_offset_ + pages_.size(), document_name_.string()); + if (page < loaded_pages) { + tprintf("Deserialize failed: %s read %d/%d pages\n", + document_name_.string(), page, loaded_pages); + pages_.truncate(0); + } else { + tprintf("Loaded %d/%d pages (%d-%d) of document %s\n", pages_.size(), + loaded_pages, pages_offset_, pages_offset_ + pages_.size(), + document_name_.string()); + } + set_total_pages(loaded_pages); return !pages_.empty(); } -// Adds the given page data to this document, counting up memory. -void DocumentData::AddPageToDocument(ImageData* page) { - pages_.push_back(page); - memory_used_ += page->MemoryUsed(); -} - // A collection of DocumentData that knows roughly how much memory it is using. DocumentCache::DocumentCache(inT64 max_memory) - : total_pages_(0), memory_used_(0), max_memory_(max_memory) {} + : num_pages_per_doc_(0), max_memory_(max_memory) {} DocumentCache::~DocumentCache() {} // Adds all the documents in the list of filenames, counting memory. // The reader is used to read the files. bool DocumentCache::LoadDocuments(const GenericVector<STRING>& filenames, - const char* lang, FileReader reader) { - inT64 fair_share_memory = max_memory_ / filenames.size(); + const char* lang, + CachingStrategy cache_strategy, + FileReader reader) { + cache_strategy_ = cache_strategy; + inT64 fair_share_memory = 0; + // In the round-robin case, each DocumentData handles restricting its content + // to its fair share of memory. In the sequential case, DocumentCache + // determines which DocumentDatas are held entirely in memory.
+ if (cache_strategy_ == CS_ROUND_ROBIN) + fair_share_memory = max_memory_ / filenames.size(); for (int arg = 0; arg < filenames.size(); ++arg) { STRING filename = filenames[arg]; DocumentData* document = new DocumentData(filename); - if (document->LoadDocument(filename.string(), lang, 0, - fair_share_memory, reader)) { - AddToCache(document); - } else { - tprintf("Failed to load image %s!\n", filename.string()); - delete document; - } + document->SetDocument(filename.string(), lang, fair_share_memory, reader); + AddToCache(document); + } + if (!documents_.empty()) { + // Try to get the first page now to verify the list of filenames. + if (GetPageBySerial(0) != NULL) return true; + tprintf("Load of page 0 failed!\n"); } - tprintf("Loaded %d pages, total %gMB\n", - total_pages_, memory_used_ / 1048576.0); - return total_pages_ > 0; + return false; } -// Adds document to the cache, throwing out other documents if needed. +// Adds document to the cache. bool DocumentCache::AddToCache(DocumentData* data) { inT64 new_memory = data->memory_used(); - memory_used_ += new_memory; documents_.push_back(data); - total_pages_ += data->NumPages(); - // Delete the first item in the array, and other pages of the same name - // while memory is full. - while (memory_used_ >= max_memory_ && max_memory_ > 0) { - tprintf("Memory used=%lld vs max=%lld, discarding doc of size %lld\n", - memory_used_ , max_memory_, documents_[0]->memory_used()); - memory_used_ -= documents_[0]->memory_used(); - total_pages_ -= documents_[0]->NumPages(); - documents_.remove(0); - } return true; } @@ -488,11 +583,104 @@ DocumentData* DocumentCache::FindDocument(const STRING& document_name) const { return NULL; } +// Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache +// strategy, could take a long time. +int DocumentCache::TotalPages() { + if (cache_strategy_ == CS_SEQUENTIAL) { + // In sequential mode, we assume each doc has the same number of pages + // whether it is true or not. + if (num_pages_per_doc_ == 0) GetPageSequential(0); + return num_pages_per_doc_ * documents_.size(); + } + int total_pages = 0; + int num_docs = documents_.size(); + for (int d = 0; d < num_docs; ++d) { + // We have to load a page to make NumPages() valid. + documents_[d]->GetPage(0); + total_pages += documents_[d]->NumPages(); + } + return total_pages; +} + // Returns a page by serial number, selecting them in a round-robin fashion -// from all the documents. -const ImageData* DocumentCache::GetPageBySerial(int serial) { - int document_index = serial % documents_.size(); - return documents_[document_index]->GetPage(serial / documents_.size()); +// from all the documents. Highly disk-intensive, but doesn't need samples +// to be shuffled between files to begin with. +const ImageData* DocumentCache::GetPageRoundRobin(int serial) { + int num_docs = documents_.size(); + int doc_index = serial % num_docs; + const ImageData* doc = documents_[doc_index]->GetPage(serial / num_docs); + for (int offset = 1; offset <= kMaxReadAhead && offset < num_docs; ++offset) { + doc_index = (serial + offset) % num_docs; + int page = (serial + offset) / num_docs; + documents_[doc_index]->LoadPageInBackground(page); + } + return doc; +} + +// Returns a page by serial number, selecting them in sequence from each file. +// Requires the samples to be shuffled between the files to give a random or +// uniform distribution of data. Less disk-intensive than GetPageRoundRobin. 
+const ImageData* DocumentCache::GetPageSequential(int serial) { + int num_docs = documents_.size(); + ASSERT_HOST(num_docs > 0); + if (num_pages_per_doc_ == 0) { + // Use the pages in the first doc as the number of pages in each doc. + documents_[0]->GetPage(0); + num_pages_per_doc_ = documents_[0]->NumPages(); + if (num_pages_per_doc_ == 0) { + tprintf("First document cannot be empty!!\n"); + ASSERT_HOST(num_pages_per_doc_ > 0); + } + // Get rid of zero now if we don't need it. + if (serial / num_pages_per_doc_ % num_docs > 0) documents_[0]->UnCache(); + } + int doc_index = serial / num_pages_per_doc_ % num_docs; + const ImageData* doc = + documents_[doc_index]->GetPage(serial % num_pages_per_doc_); + // Count up total memory. Background loading makes it more complicated to + // keep a running count. + inT64 total_memory = 0; + for (int d = 0; d < num_docs; ++d) { + total_memory += documents_[d]->memory_used(); + } + if (total_memory >= max_memory_) { + // Find something to un-cache. + // If there are more than 3 in front, then serial is from the back reader + // of a pair of readers. If we un-cache from in-front-2 to 2-ahead, then + // we create a hole between them and then un-caching the backmost occupied + // will work for both. + int num_in_front = CountNeighbourDocs(doc_index, 1); + for (int offset = num_in_front - 2; + offset > 1 && total_memory >= max_memory_; --offset) { + int next_index = (doc_index + offset) % num_docs; + total_memory -= documents_[next_index]->UnCache(); + } + // If that didn't work, the best solution is to un-cache from the back. If + // we take away the document that a 2nd reader is using, it will put it + // back and make a hole between. + int num_behind = CountNeighbourDocs(doc_index, -1); + for (int offset = num_behind; offset < 0 && total_memory >= max_memory_; + ++offset) { + int next_index = (doc_index + offset + num_docs) % num_docs; + total_memory -= documents_[next_index]->UnCache(); + } + } + int next_index = (doc_index + 1) % num_docs; + if (!documents_[next_index]->IsCached() && total_memory < max_memory_) { + documents_[next_index]->LoadPageInBackground(0); + } + return doc; +} + +// Helper counts the number of adjacent cached neighbours of index looking in +// direction dir, ie index+dir, index+2*dir etc. +int DocumentCache::CountNeighbourDocs(int index, int dir) { + int num_docs = documents_.size(); + for (int offset = dir; abs(offset) < num_docs; offset += dir) { + int offset_index = (index + offset + num_docs) % num_docs; + if (!documents_[offset_index]->IsCached()) return offset - dir; + } + return num_docs; } } // namespace tesseract. diff --git a/ccstruct/imagedata.h b/ccstruct/imagedata.h index 6321f121b1..7ffca76f83 100644 --- a/ccstruct/imagedata.h +++ b/ccstruct/imagedata.h @@ -25,6 +25,7 @@ #include "normalis.h" #include "rect.h" #include "strngs.h" +#include "svutil.h" struct Pix; @@ -34,8 +35,22 @@ namespace tesseract { const int kFeaturePadding = 2; // Number of pixels to pad around text boxes. const int kImagePadding = 4; -// Number of training images to combine into a mini-batch for training. -const int kNumPagesPerMiniBatch = 100; + +// Enum to determine the caching and data sequencing strategy. +enum CachingStrategy { + // Reads all of one file before moving on to the next. Requires samples to be + // shuffled across files. Uses the count of samples in the first file as + // the count in all the files to achieve high-speed random access. 
As a + // consequence, if subsequent files are smaller, they get entries used more + // than once, and if subsequent files are larger, some entries are not used. + // Best for larger data sets that don't fit in memory. + CS_SEQUENTIAL, + // Reads one sample from each file in rotation. Does not require shuffled + // samples, but is extremely disk-intensive. Samples in smaller files also + // get used more often than samples in larger files. + // Best for smaller data sets that mostly fit in memory. + CS_ROUND_ROBIN, +}; class WordFeature { public: @@ -103,6 +118,8 @@ class ImageData { // Reads from the given file. Returns false in case of error. // If swap is true, assumes a big/little-endian swap is needed. bool DeSerialize(bool swap, TFile* fp); + // As DeSerialize, but only seeks past the data - hence a static method. + static bool SkipDeSerialize(bool swap, tesseract::TFile* fp); // Other accessors. const STRING& imagefilename() const { @@ -145,11 +162,11 @@ class ImageData { // Gets anything and everything with a non-NULL pointer, prescaled to a // given target_height (if 0, then the original image height), and aligned. // Also returns (if not NULL) the width and height of the scaled image. - // The return value is the scale factor that was applied to the image to - // achieve the target_height. - float PreScale(int target_height, Pix** pix, - int* scaled_width, int* scaled_height, - GenericVector<TBOX>* boxes) const; + // The return value is the scaled Pix, which must be pixDestroyed after use, + // and scale_factor (if not NULL) is set to the scale factor that was applied + // to the image to achieve the target_height. + Pix* PreScale(int target_height, float* scale_factor, int* scaled_width, + int* scaled_height, GenericVector<TBOX>* boxes) const; int MemoryUsed() const; @@ -184,6 +201,8 @@ class ImageData { // A collection of ImageData that knows roughly how much memory it is using. class DocumentData { + friend void* ReCachePagesFunc(void* data); + public: explicit DocumentData(const STRING& name); ~DocumentData(); @@ -192,6 +211,9 @@ class DocumentData { // is used to read the file. bool LoadDocument(const char* filename, const char* lang, int start_page, inT64 max_memory, FileReader reader); + // Sets up the document, without actually loading it. + void SetDocument(const char* filename, const char* lang, inT64 max_memory, + FileReader reader); // Writes all the pages to the given filename. Returns false on error. bool SaveDocument(const char* filename, FileWriter writer); bool SaveToBuffer(GenericVector<char>* buffer); @@ -200,26 +222,62 @@ class DocumentData { void AddPageToDocument(ImageData* page); const STRING& document_name() const { + SVAutoLock lock(&general_mutex_); return document_name_; } int NumPages() const { + SVAutoLock lock(&general_mutex_); return total_pages_; } inT64 memory_used() const { + SVAutoLock lock(&general_mutex_); return memory_used_; } + // If the given index is not currently loaded, loads it using a separate + // thread. Note: there are 4 cases: + // Document uncached: IsCached() returns false, total_pages_ < 0. + // Required page is available: IsPageAvailable returns true. In this case, + // total_pages_ > 0 and + // pages_offset_ <= index%total_pages_ <= pages_offset_+pages_.size() + // Pages are loaded, but the required one is not. + // The requested page is being loaded by LoadPageInBackground. In this case, + // index == pages_offset_.
Once the loading starts, the pages lock is held + // until it completes, at which point IsPageAvailable will unblock and return + // true. + void LoadPageInBackground(int index); // Returns a pointer to the page with the given index, modulo the total - // number of pages, recaching if needed. + // number of pages. Blocks until the background load is completed. const ImageData* GetPage(int index); + // Returns true if the requested page is available, and provides a pointer, + // which may be NULL if the document is empty. May block, even though it + // doesn't guarantee to return true. + bool IsPageAvailable(int index, ImageData** page); // Takes ownership of the given page index. The page is made NULL in *this. ImageData* TakePage(int index) { + SVAutoLock lock(&pages_mutex_); ImageData* page = pages_[index]; pages_[index] = NULL; return page; } + // Returns true if the document is currently loaded or in the process of + // loading. + bool IsCached() const { return NumPages() >= 0; } + // Removes all pages from memory and frees the memory, but does not forget + // the document metadata. Returns the memory saved. + inT64 UnCache(); private: - // Loads as many pages can fit in max_memory_ starting at index pages_offset_. + // Sets the value of total_pages_ behind a mutex. + void set_total_pages(int total) { + SVAutoLock lock(&general_mutex_); + total_pages_ = total; + } + void set_memory_used(inT64 memory_used) { + SVAutoLock lock(&general_mutex_); + memory_used_ = memory_used; + } + // Locks the pages_mutex_ and loads as many pages as can fit in max_memory_ + // starting at index pages_offset_. bool ReCachePages(); private: @@ -239,43 +297,77 @@ class DocumentData { inT64 max_memory_; // Saved reader from LoadDocument to allow re-caching. FileReader reader_; + // Mutex that protects pages_ and pages_offset_ against multiple parallel + // loads, and provides a wait for page. + SVMutex pages_mutex_; + // Mutex that protects other data members that callers want to access without + // waiting for a load operation. + mutable SVMutex general_mutex_; }; // A collection of DocumentData that knows roughly how much memory it is using. +// Note that while it supports background read-ahead, it assumes that a single +// thread is accessing documents, ie it is not safe for multiple threads to +// access different documents in parallel, as one may de-cache the other's +// content. class DocumentCache { public: explicit DocumentCache(inT64 max_memory); ~DocumentCache(); + // Deletes all existing documents from the cache. + void Clear() { + documents_.clear(); + num_pages_per_doc_ = 0; + } // Adds all the documents in the list of filenames, counting memory. // The reader is used to read the files. bool LoadDocuments(const GenericVector<STRING>& filenames, const char* lang, - FileReader reader); + CachingStrategy cache_strategy, FileReader reader); - // Adds document to the cache, throwing out other documents if needed. + // Adds document to the cache. bool AddToCache(DocumentData* data); // Finds and returns a document by name. DocumentData* FindDocument(const STRING& document_name) const; - // Returns a page by serial number, selecting them in a round-robin fashion - // from all the documents. - const ImageData* GetPageBySerial(int serial); + // Returns a page by serial number using the current cache_strategy_ to + // determine the mapping from serial number to page.
+ const ImageData* GetPageBySerial(int serial) { + if (cache_strategy_ == CS_SEQUENTIAL) + return GetPageSequential(serial); + else + return GetPageRoundRobin(serial); + } const PointerVector<DocumentData>& documents() const { return documents_; } - int total_pages() const { - return total_pages_; - } + // Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache + // strategy, could take a long time. + int TotalPages(); private: + // Returns a page by serial number, selecting them in a round-robin fashion + // from all the documents. Highly disk-intensive, but doesn't need samples + // to be shuffled between files to begin with. + const ImageData* GetPageRoundRobin(int serial); + // Returns a page by serial number, selecting them in sequence from each file. + // Requires the samples to be shuffled between the files to give a random or + // uniform distribution of data. Less disk-intensive than GetPageRoundRobin. + const ImageData* GetPageSequential(int serial); + + // Helper counts the number of adjacent cached neighbour documents_ of index + // looking in direction dir, ie index+dir, index+2*dir etc. + int CountNeighbourDocs(int index, int dir); + // A group of pages that corresponds in some loose way to a document. PointerVector<DocumentData> documents_; - // Total of all pages. - int total_pages_; - // Total of all memory used by the cache. - inT64 memory_used_; + // Strategy to use for caching and serializing data samples. + CachingStrategy cache_strategy_; + // Number of pages in the first document, used as a divisor in + // GetPageSequential to determine the document index. + int num_pages_per_doc_; // Max memory allowed in this cache. inT64 max_memory_; }; diff --git a/ccstruct/matrix.h b/ccstruct/matrix.h index e13ef31899..4b5b242a43 100644 --- a/ccstruct/matrix.h +++ b/ccstruct/matrix.h @@ -1,8 +1,12 @@ /* -*-C-*- ****************************************************************************** + * File: matrix.h (Formerly matrix.h) + * Description: Generic 2-d array/matrix and banded triangular matrix class. + * Author: Ray Smith + * TODO(rays) Separate from ratings matrix, which it also contains: * - * File: matrix.h (Formerly matrix.h) - * Description: Ratings matrix code. (Used by associator) + * Description: Ratings matrix class (specialization of banded matrix). + * Segmentation search matrix of lists of BLOB_CHOICE. * Author: Mark Seaman, OCR Technology * Created: Wed May 16 13:22:06 1990 * Modified: Tue Mar 19 16:00:20 1991 (Mark Seaman) marks@hpgrlt @@ -25,9 +29,13 @@ #ifndef TESSERACT_CCSTRUCT_MATRIX_H__ #define TESSERACT_CCSTRUCT_MATRIX_H__ +#include <math.h> #include "kdpair.h" +#include "points.h" +#include "serialis.h" #include "unicharset.h" +class BLOB_CHOICE; class BLOB_CHOICE_LIST; #define NOT_CLASSIFIED reinterpret_cast<BLOB_CHOICE_LIST*>(0) @@ -44,34 +52,60 @@ class GENERIC_2D_ARRAY { // either pass the memory in, or allocate after by calling Resize(). GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty, T* array) : empty_(empty), dim1_(dim1), dim2_(dim2), array_(array) { + size_allocated_ = dim1 * dim2; } // Original constructor for a full rectangular matrix DOES allocate memory // and initialize it to empty.
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty) : empty_(empty), dim1_(dim1), dim2_(dim2) { - array_ = new T[dim1_ * dim2_]; - for (int x = 0; x < dim1_; x++) - for (int y = 0; y < dim2_; y++) - this->put(x, y, empty_); + int new_size = dim1 * dim2; + array_ = new T[new_size]; + size_allocated_ = new_size; + for (int i = 0; i < size_allocated_; ++i) + array_[i] = empty_; + } + // Default constructor for array allocation. Use Resize to set the size. + GENERIC_2D_ARRAY() + : array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0), + size_allocated_(0) { + } + GENERIC_2D_ARRAY(const GENERIC_2D_ARRAY& src) + : array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0), + size_allocated_(0) { + *this = src; } virtual ~GENERIC_2D_ARRAY() { delete[] array_; } + void operator=(const GENERIC_2D_ARRAY& src) { + ResizeNoInit(src.dim1(), src.dim2()); + memcpy(array_, src.array_, num_elements() * sizeof(array_[0])); + } + + // Reallocate the array to the given size. Does not keep old data, but does + // not initialize the array either. + void ResizeNoInit(int size1, int size2) { + int new_size = size1 * size2; + if (new_size > size_allocated_) { + delete [] array_; + array_ = new T[new_size]; + size_allocated_ = new_size; + } + dim1_ = size1; + dim2_ = size2; + } + // Reallocate the array to the given size. Does not keep old data. void Resize(int size1, int size2, const T& empty) { empty_ = empty; - if (size1 != dim1_ || size2 != dim2_) { - dim1_ = size1; - dim2_ = size2; - delete [] array_; - array_ = new T[dim1_ * dim2_]; - } + ResizeNoInit(size1, size2); Clear(); } // Reallocate the array to the given size, keeping old data. void ResizeWithCopy(int size1, int size2) { if (size1 != dim1_ || size2 != dim2_) { - T* new_array = new T[size1 * size2]; + int new_size = size1 * size2; + T* new_array = new T[new_size]; for (int col = 0; col < size1; ++col) { for (int row = 0; row < size2; ++row) { int old_index = col * dim2() + row; @@ -87,6 +121,7 @@ class GENERIC_2D_ARRAY { array_ = new_array; dim1_ = size1; dim2_ = size2; + size_allocated_ = new_size; } } @@ -106,9 +141,16 @@ class GENERIC_2D_ARRAY { if (fwrite(array_, sizeof(*array_), size, fp) != size) return false; return true; } + bool Serialize(tesseract::TFile* fp) const { + if (!SerializeSize(fp)) return false; + if (fp->FWrite(&empty_, sizeof(empty_), 1) != 1) return false; + int size = num_elements(); + if (fp->FWrite(array_, sizeof(*array_), size) != size) return false; + return true; + } // Reads from the given file. Returns false in case of error. - // Only works with bitwise-serializeable typ + // Only works with bitwise-serializeable types! // If swap is true, assumes a big/little-endian swap is needed. bool DeSerialize(bool swap, FILE* fp) { if (!DeSerializeSize(swap, fp)) return false; @@ -122,6 +164,18 @@ class GENERIC_2D_ARRAY { } return true; } + bool DeSerialize(bool swap, tesseract::TFile* fp) { + if (!DeSerializeSize(swap, fp)) return false; + if (fp->FRead(&empty_, sizeof(empty_), 1) != 1) return false; + if (swap) ReverseN(&empty_, sizeof(empty_)); + int size = num_elements(); + if (fp->FRead(array_, sizeof(*array_), size) != size) return false; + if (swap) { + for (int i = 0; i < size; ++i) + ReverseN(&array_[i], sizeof(array_[i])); + } + return true; + } // Writes to the given file. Returns false in case of error. // Assumes a T::Serialize(FILE*) const function. @@ -163,11 +217,17 @@ class GENERIC_2D_ARRAY { } // Put a list element into the matrix at a specific location.
+ void put(ICOORD pos, const T& thing) { + array_[this->index(pos.x(), pos.y())] = thing; + } void put(int column, int row, const T& thing) { array_[this->index(column, row)] = thing; } // Get the item at a specified location from the matrix. + T get(ICOORD pos) const { + return array_[this->index(pos.x(), pos.y())]; + } T get(int column, int row) const { return array_[this->index(column, row)]; } @@ -187,6 +247,207 @@ class GENERIC_2D_ARRAY { return &array_[this->index(column, 0)]; } + // Adds addend to *this, element-by-element. + void operator+=(const GENERIC_2D_ARRAY& addend) { + if (dim2_ == addend.dim2_) { + // Faster if equal size in the major dimension. + int size = MIN(num_elements(), addend.num_elements()); + for (int i = 0; i < size; ++i) { + array_[i] += addend.array_[i]; + } + } else { + for (int x = 0; x < dim1_; x++) { + for (int y = 0; y < dim2_; y++) { + (*this)(x, y) += addend(x, y); + } + } + } + } + // Subtracts minuend from *this, element-by-element. + void operator-=(const GENERIC_2D_ARRAY& minuend) { + if (dim2_ == minuend.dim2_) { + // Faster if equal size in the major dimension. + int size = MIN(num_elements(), minuend.num_elements()); + for (int i = 0; i < size; ++i) { + array_[i] -= minuend.array_[i]; + } + } else { + for (int x = 0; x < dim1_; x++) { + for (int y = 0; y < dim2_; y++) { + (*this)(x, y) -= minuend(x, y); + } + } + } + } + // Adds addend to all elements. + void operator+=(const T& addend) { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + array_[i] += addend; + } + } + // Multiplies *this by factor, element-by-element. + void operator*=(const T& factor) { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + array_[i] *= factor; + } + } + // Clips *this to the given range. + void Clip(const T& rangemin, const T& rangemax) { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + array_[i] = ClipToRange(array_[i], rangemin, rangemax); + } + } + // Returns true if all elements of *this are within the given range. + // Only uses operator< + bool WithinBounds(const T& rangemin, const T& rangemax) const { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + const T& value = array_[i]; + if (value < rangemin || rangemax < value) + return false; + } + return true; + } + // Normalize the whole array. + double Normalize() { + int size = num_elements(); + if (size <= 0) return 0.0; + // Compute the mean. + double mean = 0.0; + for (int i = 0; i < size; ++i) { + mean += array_[i]; + } + mean /= size; + // Subtract the mean and compute the standard deviation. + double sd = 0.0; + for (int i = 0; i < size; ++i) { + double normed = array_[i] - mean; + array_[i] = normed; + sd += normed * normed; + } + sd = sqrt(sd / size); + if (sd > 0.0) { + // Divide by the sd. + for (int i = 0; i < size; ++i) { + array_[i] /= sd; + } + } + return sd; + } + + // Returns the maximum value of the array. + T Max() const { + int size = num_elements(); + if (size <= 0) return empty_; + // Compute the max. + T max_value = array_[0]; + for (int i = 1; i < size; ++i) { + const T& value = array_[i]; + if (value > max_value) max_value = value; + } + return max_value; + } + + // Returns the maximum absolute value of the array. + T MaxAbs() const { + int size = num_elements(); + if (size <= 0) return empty_; + // Compute the max. 
+ T max_abs = static_cast<T>(0); + for (int i = 0; i < size; ++i) { + T value = static_cast<T>(fabs(array_[i])); + if (value > max_abs) max_abs = value; + } + return max_abs; + } + + // Accumulates the element-wise sums of squares of src into *this. + void SumSquares(const GENERIC_2D_ARRAY& src) { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + array_[i] += src.array_[i] * src.array_[i]; + } + } + + // Scales each element using the ada-grad algorithm, ie array_[i] by + // sqrt(num_samples/max(1,sqsum[i])). + void AdaGradScaling(const GENERIC_2D_ARRAY& sqsum, int num_samples) { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + array_[i] *= sqrt(num_samples / MAX(1.0, sqsum.array_[i])); + } + } + + void AssertFinite() const { + int size = num_elements(); + for (int i = 0; i < size; ++i) { + ASSERT_HOST(isfinite(array_[i])); + } + } + + // REGARDLESS OF THE CURRENT DIMENSIONS, treats the data as a + // num_dims-dimensional array/tensor with dimensions given by dims, (ordered + // from most significant to least significant, the same as standard C arrays) + // and moves src_dim to dest_dim, with the initial dest_dim and any dimensions + // in between shifted towards the hole left by src_dim. Example: + // Current data content: array_=[0, 1, 2, ....119] + // perhaps *this may be of dim[40, 3], with values [[0, 1, 2][3, 4, 5]... + // but the current dimensions are irrelevant. + // num_dims = 4, dims=[5, 4, 3, 2] + // src_dim=3, dest_dim=1 + // tensor=[[[[0, 1][2, 3][4, 5]] + // [[6, 7][8, 9][10, 11]] + // [[12, 13][14, 15][16, 17]] + // [[18, 19][20, 21][22, 23]]] + // [[[24, 25]... + // output dims =[5, 2, 4, 3] + // output tensor=[[[[0, 2, 4][6, 8, 10][12, 14, 16][18, 20, 22]] + // [[1, 3, 5][7, 9, 11][13, 15, 17][19, 21, 23]]] + // [[[24, 26, 28]... + // which is stored in the array_ as: + // [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 1, 3, 5, 7, 9, 11, 13...] + // NOTE: the 2 stored matrix dimensions are simply copied from *this. To + // change the dimensions after the transpose, use ResizeNoInit. + // Higher dimensions above 2 are strictly the responsibility of the caller. + void RotatingTranspose(const int* dims, int num_dims, int src_dim, + int dest_dim, GENERIC_2D_ARRAY* result) const { + int max_d = MAX(src_dim, dest_dim); + int min_d = MIN(src_dim, dest_dim); + // In a tensor of shape [d0, d1... min_d, ... max_d, ... dn-2, dn-1], the + // ends outside of min_d and max_d are unaffected, with [max_d +1, dn-1] + // being contiguous blocks of data that will move together, and + // [d0, min_d -1] being replicas of the transpose operation. + // num_replicas represents the large dimensions unchanged by the operation. + // move_size represents the small dimensions unchanged by the operation. + // src_step represents the stride in the src between each adjacent group + // in the destination. + int num_replicas = 1, move_size = 1, src_step = 1; + for (int d = 0; d < min_d; ++d) num_replicas *= dims[d]; + for (int d = max_d + 1; d < num_dims; ++d) move_size *= dims[d]; + for (int d = src_dim + 1; d < num_dims; ++d) src_step *= dims[d]; + if (src_dim > dest_dim) src_step *= dims[src_dim]; + // wrap_size is the size of a single replica, being the amount that is + // handled num_replicas times.
+ int wrap_size = move_size; + for (int d = min_d; d <= max_d; ++d) wrap_size *= dims[d]; + result->ResizeNoInit(dim1_, dim2_); + result->empty_ = empty_; + const T* src = array_; + T* dest = result->array_; + for (int replica = 0; replica < num_replicas; ++replica) { + for (int start = 0; start < src_step; start += move_size) { + for (int pos = start; pos < wrap_size; pos += src_step) { + memcpy(dest, src + pos, sizeof(*dest) * move_size); + dest += move_size; + } + } + src += wrap_size; + } + } + // Delete objects pointed to by array_[i]. void delete_matrix_pointers() { int size = num_elements(); @@ -206,6 +467,13 @@ class GENERIC_2D_ARRAY { if (fwrite(&size, sizeof(size), 1, fp) != 1) return false; return true; } + bool SerializeSize(tesseract::TFile* fp) const { + inT32 size = dim1_; + if (fp->FWrite(&size, sizeof(size), 1) != 1) return false; + size = dim2_; + if (fp->FWrite(&size, sizeof(size), 1) != 1) return false; + return true; + } // Factored helper to deserialize the size. // If swap is true, assumes a big/little-endian swap is needed. bool DeSerializeSize(bool swap, FILE* fp) { @@ -219,11 +487,26 @@ class GENERIC_2D_ARRAY { Resize(size1, size2, empty_); return true; } + bool DeSerializeSize(bool swap, tesseract::TFile* fp) { + inT32 size1, size2; + if (fp->FRead(&size1, sizeof(size1), 1) != 1) return false; + if (fp->FRead(&size2, sizeof(size2), 1) != 1) return false; + if (swap) { + ReverseN(&size1, sizeof(size1)); + ReverseN(&size2, sizeof(size2)); + } + Resize(size1, size2, empty_); + return true; + } T* array_; T empty_; // The unused cell. int dim1_; // Size of the 1st dimension in indexing functions. int dim2_; // Size of the 2nd dimension in indexing functions. + // The total size to which the array can be expanded before a realloc is + // needed. If Resize is used, memory is retained so it can be re-expanded + // without a further alloc, and this stores the allocated size. + int size_allocated_; }; // A generic class to store a banded triangular matrix with entries of type T. diff --git a/ccstruct/pageres.cpp b/ccstruct/pageres.cpp index b66e5636ff..330dd22915 100644 --- a/ccstruct/pageres.cpp +++ b/ccstruct/pageres.cpp @@ -304,6 +304,7 @@ bool WERD_RES::SetupForRecognition(const UNICHARSET& unicharset_in, tesseract = tess; POLY_BLOCK* pb = block != NULL ? block->poly_block() : NULL; if ((norm_mode_hint != tesseract::OEM_CUBE_ONLY && + norm_mode_hint != tesseract::OEM_LSTM_ONLY && word->cblob_list()->empty()) || (pb != NULL && !pb->IsText())) { // Empty words occur when all the blobs have been moved to the rej_blobs // list, which seems to occur frequently in junk. @@ -882,17 +883,17 @@ void WERD_RES::FakeClassifyWord(int blob_count, BLOB_CHOICE** choices) { choice_it.add_after_then_move(choices[c]); ratings->put(c, c, choice_list); } - FakeWordFromRatings(); + FakeWordFromRatings(TOP_CHOICE_PERM); reject_map.initialise(blob_count); done = true; } // Creates a WERD_CHOICE for the word using the top choices from the leading // diagonal of the ratings matrix. 
-void WERD_RES::FakeWordFromRatings() { +void WERD_RES::FakeWordFromRatings(PermuterType permuter) { int num_blobs = ratings->dimension(); WERD_CHOICE* word_choice = new WERD_CHOICE(uch_set, num_blobs); - word_choice->set_permuter(TOP_CHOICE_PERM); + word_choice->set_permuter(permuter); for (int b = 0; b < num_blobs; ++b) { UNICHAR_ID unichar_id = UNICHAR_SPACE; float rating = MAX_INT32; @@ -1105,6 +1106,7 @@ void WERD_RES::InitNonPointers() { x_height = 0.0; caps_height = 0.0; baseline_shift = 0.0f; + space_certainty = 0.0f; guessed_x_ht = TRUE; guessed_caps_ht = TRUE; combination = FALSE; diff --git a/ccstruct/pageres.h b/ccstruct/pageres.h index 22c5ccb774..33b9f4cb35 100644 --- a/ccstruct/pageres.h +++ b/ccstruct/pageres.h @@ -295,6 +295,9 @@ class WERD_RES : public ELIST_LINK { float x_height; // post match estimate float caps_height; // post match estimate float baseline_shift; // post match estimate. + // Certainty score for the spaces either side of this word (LSTM mode). + // MIN this value with the actual word certainty. + float space_certainty; /* To deal with fuzzy spaces we need to be able to combine "words" to form @@ -590,7 +593,7 @@ class WERD_RES : public ELIST_LINK { // Creates a WERD_CHOICE for the word using the top choices from the leading // diagonal of the ratings matrix. - void FakeWordFromRatings(); + void FakeWordFromRatings(PermuterType permuter); // Copies the best_choice strings to the correct_text for adaption/training. void BestChoiceToCorrectText(); diff --git a/ccstruct/publictypes.h b/ccstruct/publictypes.h index 6cb9f3ba13..97c9ee8b9d 100644 --- a/ccstruct/publictypes.h +++ b/ccstruct/publictypes.h @@ -257,13 +257,21 @@ enum OcrEngineMode { OEM_TESSERACT_ONLY, // Run Tesseract only - fastest OEM_CUBE_ONLY, // Run Cube only - better accuracy, but slower OEM_TESSERACT_CUBE_COMBINED, // Run both and combine results - best accuracy - OEM_DEFAULT // Specify this mode when calling init_*(), + OEM_DEFAULT, // Specify this mode when calling init_*(), // to indicate that any of the above modes // should be automatically inferred from the // variables in the language-specific config, // command-line configs, or if not specified // in any of the above should be set to the // default OEM_TESSERACT_ONLY. + // OEM_LSTM_ONLY will fall back (with a warning) to OEM_TESSERACT_ONLY where + // there is no network model available. This allows use of a mix of languages, + // some of which contain a network model, and some of which do not. Since the + // tesseract model is required for the LSTM to fall back to for "difficult" + // words anyway, this seems like a reasonable approach, but leaves the danger + // of not noticing that it is using the wrong engine if the warning is + // ignored. + OEM_LSTM_ONLY, // Run just the LSTM line recognizer. }; } // namespace tesseract. 
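The new OEM_LSTM_ONLY value above is selected through the same engine-mode argument used by the other OcrEngineMode values in TessBaseAPI::Init. A minimal usage sketch, not part of this patch; it assumes an LSTM-capable traineddata for "eng" is installed, and "page.png" is a hypothetical input file:

    #include <cstdio>
    #include "allheaders.h"  // Leptonica, for pixRead/pixDestroy.
    #include "baseapi.h"

    int main() {
      tesseract::TessBaseAPI api;
      // Request the LSTM line recognizer. Per the comment above, this falls
      // back to OEM_TESSERACT_ONLY (with a warning) for languages that ship
      // without a network model.
      if (api.Init(NULL, "eng", tesseract::OEM_LSTM_ONLY) != 0) return 1;
      Pix* image = pixRead("page.png");
      api.SetImage(image);
      char* text = api.GetUTF8Text();
      printf("%s", text);
      delete[] text;
      pixDestroy(&image);
      api.End();
      return 0;
    }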
diff --git a/ccutil/Makefile.am b/ccutil/Makefile.am index 76012006c4..53980e9313 100644 --- a/ccutil/Makefile.am +++ b/ccutil/Makefile.am @@ -14,7 +14,7 @@ endif include_HEADERS = \ basedir.h errcode.h fileerr.h genericvector.h helpers.h host.h memry.h \ ndminx.h params.h ocrclass.h platform.h serialis.h strngs.h \ - tesscallback.h unichar.h unicharmap.h unicharset.h + tesscallback.h unichar.h unicharcompress.h unicharmap.h unicharset.h noinst_HEADERS = \ ambigs.h bits16.h bitvector.h ccutil.h clst.h doubleptr.h elst2.h \ @@ -38,7 +38,7 @@ libtesseract_ccutil_la_SOURCES = \ mainblk.cpp memry.cpp \ serialis.cpp strngs.cpp scanutils.cpp \ tessdatamanager.cpp tprintf.cpp \ - unichar.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \ + unichar.cpp unicharcompress.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \ params.cpp universalambigs.cpp if T_WIN diff --git a/ccutil/genericheap.h b/ccutil/genericheap.h index bb5f8ddc79..ccf273b33a 100644 --- a/ccutil/genericheap.h +++ b/ccutil/genericheap.h @@ -108,6 +108,8 @@ class GenericHeap { const Pair& PeekTop() const { return heap_[0]; } + // Get the value of the worst (largest, defined by operator< ) element. + const Pair& PeekWorst() const { return heap_[IndexOfWorst()]; } // Removes the top element of the heap. If entry is not NULL, the element // is copied into *entry, otherwise it is discarded. @@ -136,22 +138,12 @@ class GenericHeap { // not NULL, the element is copied into *entry, otherwise it is discarded. // Time = O(n). Returns false if the heap was already empty. bool PopWorst(Pair* entry) { - int heap_size = heap_.size(); - if (heap_size == 0) return false; // It cannot be empty! - - // Find the maximum element. Its index is guaranteed to be greater than - // the index of the parent of the last element, since by the heap invariant - // the parent must be less than or equal to the children. - int worst_index = heap_size - 1; - int end_parent = ParentNode(worst_index); - for (int i = worst_index - 1; i > end_parent; --i) { - if (heap_[worst_index] < heap_[i]) - worst_index = i; - } + int worst_index = IndexOfWorst(); + if (worst_index < 0) return false; // It cannot be empty! // Extract the worst element from the heap, leaving a hole at worst_index. if (entry != NULL) *entry = heap_[worst_index]; - --heap_size; + int heap_size = heap_.size() - 1; if (heap_size > 0) { // Sift the hole upwards to match the last element of the heap_ Pair hole_pair = heap_[heap_size]; @@ -162,6 +154,22 @@ class GenericHeap { return true; } + // Returns the index of the worst element. Time = O(n/2). + int IndexOfWorst() const { + int heap_size = heap_.size(); + if (heap_size == 0) return -1; // It cannot be empty! + + // Find the maximum element. Its index is guaranteed to be greater than + // the index of the parent of the last element, since by the heap invariant + // the parent must be less than or equal to the children. + int worst_index = heap_size - 1; + int end_parent = ParentNode(worst_index); + for (int i = worst_index - 1; i > end_parent; --i) { + if (heap_[worst_index] < heap_[i]) worst_index = i; + } + return worst_index; + } + // The pointed-to Pair has changed its key value, so the location of pair // is reshuffled to maintain the heap invariant. // Must be a valid pointer to an element of the heap_! 
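The PeekWorst/PopWorst refactor above rests on the binary min-heap invariant: every parent compares less than or equal to its children, so the largest (worst) element must be a leaf, and the leaves are exactly the array slots after the parent of the last element. That is what the loop bounds in IndexOfWorst encode. A standalone sketch of the same scan, using std::vector rather than GenericHeap (illustrative only):

    #include <algorithm>
    #include <cassert>
    #include <functional>
    #include <vector>

    // Returns the index of the maximum element of a min-heap, scanning only
    // the leaves, i.e. the indices after the parent of the last element.
    int IndexOfWorst(const std::vector<int>& heap) {
      int n = heap.size();
      if (n == 0) return -1;
      int worst = n - 1;
      int end_parent = (worst - 1) / 2;  // Parent of the last element.
      for (int i = worst - 1; i > end_parent; --i)
        if (heap[worst] < heap[i]) worst = i;
      return worst;
    }

    int main() {
      std::vector<int> v = {5, 3, 8, 1, 9, 2};
      std::make_heap(v.begin(), v.end(), std::greater<int>());  // Min-heap.
      // The maximum (9) must live in a leaf, so the tail-only scan finds it.
      assert(v[IndexOfWorst(v)] == 9);
      return 0;
    }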
diff --git a/ccutil/genericvector.h b/ccutil/genericvector.h index d867d8929b..3a70e21ce0 100644 --- a/ccutil/genericvector.h +++ b/ccutil/genericvector.h @@ -174,6 +174,8 @@ class GenericVector { // If swap is true, assumes a big/little-endian swap is needed. bool DeSerialize(bool swap, FILE* fp); bool DeSerialize(bool swap, tesseract::TFile* fp); + // Skips the deserialization of the vector. + static bool SkipDeSerialize(bool swap, tesseract::TFile* fp); // Writes a vector of classes to the given file. Assumes the existence of // bool T::Serialize(FILE* fp) const that returns false in case of error. // Returns false in case of error. @@ -186,6 +188,8 @@ class GenericVector { // If swap is true, assumes a big/little-endian swap is needed. bool DeSerializeClasses(bool swap, FILE* fp); bool DeSerializeClasses(bool swap, tesseract::TFile* fp); + // Calls SkipDeSerialize on the elements of the vector. + static bool SkipDeSerializeClasses(bool swap, tesseract::TFile* fp); // Allocates a new array of double the current_size, copies over the // information from data to the new location, deletes data and returns @@ -238,14 +242,13 @@ class GenericVector { int binary_search(const T& target) const { int bottom = 0; int top = size_used_; - do { + while (top - bottom > 1) { int middle = (bottom + top) / 2; if (data_[middle] > target) top = middle; else bottom = middle; } - while (top - bottom > 1); return bottom; } @@ -361,7 +364,7 @@ inline bool LoadDataFromFile(const STRING& filename, size_t size = ftell(fp); fseek(fp, 0, SEEK_SET); // Pad with a 0, just in case we treat the result as a string. - data->init_to_size((int)size + 1, 0); + data->init_to_size(static_cast<int>(size) + 1, 0); bool result = fread(&(*data)[0], 1, size, fp) == size; fclose(fp); return result; @@ -556,34 +559,54 @@ class PointerVector : public GenericVector<T*> { } bool DeSerialize(bool swap, TFile* fp) { inT32 reserved; - if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false; - if (swap) Reverse32(&reserved); + if (!DeSerializeSize(swap, fp, &reserved)) return false; GenericVector<T*>::reserve(reserved); truncate(0); for (int i = 0; i < reserved; ++i) { - inT8 non_null; - if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false; - T* item = NULL; - if (non_null) { - item = new T; - if (!item->DeSerialize(swap, fp)) { - delete item; - return false; - } - this->push_back(item); - } else { - // Null elements should keep their place in the vector. - this->push_back(NULL); + if (!DeSerializeElement(swap, fp)) return false; + } + return true; + } + // Enables deserialization of a selection of elements. Note that in order to + // retain the integrity of the stream, the caller must call some combination + // of DeSerializeElement and DeSerializeSkip of the exact number returned in + // *size, assuming a true return. + static bool DeSerializeSize(bool swap, TFile* fp, inT32* size) { + if (fp->FRead(size, sizeof(*size), 1) != 1) return false; + if (swap) Reverse32(size); + return true; + } + // Reads and appends to the vector the next element of the serialization. + bool DeSerializeElement(bool swap, TFile* fp) { + inT8 non_null; + if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false; + T* item = NULL; + if (non_null) { + item = new T; + if (!item->DeSerialize(swap, fp)) { + delete item; + return false; } + this->push_back(item); + } else { + // Null elements should keep their place in the vector. + this->push_back(NULL); + } + return true; + } + // Skips the next element of the serialization.
+ static bool DeSerializeSkip(bool swap, TFile* fp) { + inT8 non_null; + if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false; + if (non_null) { + if (!T::SkipDeSerialize(swap, fp)) return false; } return true; } // Sorts the items pointed to by the members of this vector using // t::operator<(). - void sort() { - sort(&sort_ptr_cmp); - } + void sort() { this->GenericVector<T*>::sort(&sort_ptr_cmp); } }; } // namespace tesseract @@ -926,6 +949,13 @@ bool GenericVector<T>::DeSerialize(bool swap, tesseract::TFile* fp) { } return true; } +template <typename T> +bool GenericVector<T>::SkipDeSerialize(bool swap, tesseract::TFile* fp) { + inT32 reserved; + if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false; + if (swap) Reverse32(&reserved); + return fp->FRead(NULL, sizeof(T), reserved) == reserved; +} // Writes a vector of classes to the given file. Assumes the existence of // bool T::Serialize(FILE* fp) const that returns false in case of error. @@ -976,6 +1006,16 @@ bool GenericVector<T>::DeSerializeClasses(bool swap, tesseract::TFile* fp) { } return true; } +template <typename T> +bool GenericVector<T>::SkipDeSerializeClasses(bool swap, tesseract::TFile* fp) { + uinT32 reserved; + if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false; + if (swap) Reverse32(&reserved); + for (int i = 0; i < reserved; ++i) { + if (!T::SkipDeSerialize(swap, fp)) return false; + } + return true; +} // This method clear the current object, then, does a shallow copy of // its argument, and finally invalidates its argument. diff --git a/ccutil/serialis.cpp b/ccutil/serialis.cpp index ff3b278a7e..d1eed58465 100644 --- a/ccutil/serialis.cpp +++ b/ccutil/serialis.cpp @@ -95,7 +95,7 @@ int TFile::FRead(void* buffer, int size, int count) { char* char_buffer = reinterpret_cast<char*>(buffer); if (data_->size() - offset_ < required_size) required_size = data_->size() - offset_; - if (required_size > 0) + if (required_size > 0 && char_buffer != NULL) memcpy(char_buffer, &(*data_)[offset_], required_size); offset_ += required_size; return required_size / size; diff --git a/ccutil/strngs.cpp b/ccutil/strngs.cpp index 0760852e90..5a9cfd0d48 100644 --- a/ccutil/strngs.cpp +++ b/ccutil/strngs.cpp @@ -181,6 +181,14 @@ bool STRING::DeSerialize(bool swap, TFile* fp) { return true; } +// As DeSerialize, but only seeks past the data - hence a static method. +bool STRING::SkipDeSerialize(bool swap, tesseract::TFile* fp) { + inT32 len; + if (fp->FRead(&len, sizeof(len), 1) != 1) return false; + if (swap) ReverseN(&len, sizeof(len)); + return fp->FRead(NULL, 1, len) == len; +} + BOOL8 STRING::contains(const char c) const { return (c != '\0') && (strchr (GetCStr(), c) != NULL); } diff --git a/ccutil/strngs.h b/ccutil/strngs.h index 9308cc67c8..1fe42b6076 100644 --- a/ccutil/strngs.h +++ b/ccutil/strngs.h @@ -60,6 +60,8 @@ class TESS_API STRING // Reads from the given file. Returns false in case of error. // If swap is true, assumes a big/little-endian swap is needed. bool DeSerialize(bool swap, tesseract::TFile* fp); + // As DeSerialize, but only seeks past the data - hence a static method.
+ static bool SkipDeSerialize(bool swap, tesseract::TFile* fp); BOOL8 contains(const char c) const; inT32 length() const; diff --git a/ccutil/tessdatamanager.h b/ccutil/tessdatamanager.h index fd2685a1d8..9ff9440de2 100644 --- a/ccutil/tessdatamanager.h +++ b/ccutil/tessdatamanager.h @@ -47,6 +47,10 @@ static const char kShapeTableFileSuffix[] = "shapetable"; static const char kBigramDawgFileSuffix[] = "bigram-dawg"; static const char kUnambigDawgFileSuffix[] = "unambig-dawg"; static const char kParamsModelFileSuffix[] = "params-model"; +static const char kLSTMModelFileSuffix[] = "lstm"; +static const char kLSTMPuncDawgFileSuffix[] = "lstm-punc-dawg"; +static const char kLSTMSystemDawgFileSuffix[] = "lstm-word-dawg"; +static const char kLSTMNumberDawgFileSuffix[] = "lstm-number-dawg"; namespace tesseract { @@ -68,6 +72,10 @@ enum TessdataType { TESSDATA_BIGRAM_DAWG, // 14 TESSDATA_UNAMBIG_DAWG, // 15 TESSDATA_PARAMS_MODEL, // 16 + TESSDATA_LSTM, // 17 + TESSDATA_LSTM_PUNC_DAWG, // 18 + TESSDATA_LSTM_SYSTEM_DAWG, // 19 + TESSDATA_LSTM_NUMBER_DAWG, // 20 TESSDATA_NUM_ENTRIES }; @@ -76,24 +84,28 @@ enum TessdataType { * kTessdataFileSuffixes[i] indicates the file suffix for * tessdata of type i (from TessdataType enum). */ -static const char * const kTessdataFileSuffixes[] = { - kLangConfigFileSuffix, // 0 - kUnicharsetFileSuffix, // 1 - kAmbigsFileSuffix, // 2 - kBuiltInTemplatesFileSuffix, // 3 - kBuiltInCutoffsFileSuffix, // 4 - kNormProtoFileSuffix, // 5 - kPuncDawgFileSuffix, // 6 - kSystemDawgFileSuffix, // 7 - kNumberDawgFileSuffix, // 8 - kFreqDawgFileSuffix, // 9 - kFixedLengthDawgsFileSuffix, // 10 // deprecated - kCubeUnicharsetFileSuffix, // 11 - kCubeSystemDawgFileSuffix, // 12 - kShapeTableFileSuffix, // 13 - kBigramDawgFileSuffix, // 14 - kUnambigDawgFileSuffix, // 15 - kParamsModelFileSuffix, // 16 +static const char *const kTessdataFileSuffixes[] = { + kLangConfigFileSuffix, // 0 + kUnicharsetFileSuffix, // 1 + kAmbigsFileSuffix, // 2 + kBuiltInTemplatesFileSuffix, // 3 + kBuiltInCutoffsFileSuffix, // 4 + kNormProtoFileSuffix, // 5 + kPuncDawgFileSuffix, // 6 + kSystemDawgFileSuffix, // 7 + kNumberDawgFileSuffix, // 8 + kFreqDawgFileSuffix, // 9 + kFixedLengthDawgsFileSuffix, // 10 // deprecated + kCubeUnicharsetFileSuffix, // 11 + kCubeSystemDawgFileSuffix, // 12 + kShapeTableFileSuffix, // 13 + kBigramDawgFileSuffix, // 14 + kUnambigDawgFileSuffix, // 15 + kParamsModelFileSuffix, // 16 + kLSTMModelFileSuffix, // 17 + kLSTMPuncDawgFileSuffix, // 18 + kLSTMSystemDawgFileSuffix, // 19 + kLSTMNumberDawgFileSuffix, // 20 }; /** @@ -101,23 +113,27 @@ static const char * const kTessdataFileSuffixes[] = { * of type i (from TessdataType enum) is text, and is binary otherwise. 
*/ static const bool kTessdataFileIsText[] = { - true, // 0 - true, // 1 - true, // 2 - false, // 3 - true, // 4 - true, // 5 - false, // 6 - false, // 7 - false, // 8 - false, // 9 - false, // 10 // deprecated - true, // 11 - false, // 12 - false, // 13 - false, // 14 - false, // 15 - true, // 16 + true, // 0 + true, // 1 + true, // 2 + false, // 3 + true, // 4 + true, // 5 + false, // 6 + false, // 7 + false, // 8 + false, // 9 + false, // 10 // deprecated + true, // 11 + false, // 12 + false, // 13 + false, // 14 + false, // 15 + true, // 16 + false, // 17 + false, // 18 + false, // 19 + false, // 20 }; /** diff --git a/ccutil/unicharcompress.cpp b/ccutil/unicharcompress.cpp new file mode 100644 index 0000000000..a9437ed4cf --- /dev/null +++ b/ccutil/unicharcompress.cpp @@ -0,0 +1,439 @@ +/////////////////////////////////////////////////////////////////////// +// File: unicharcompress.cpp +// Description: Unicode re-encoding using a sequence of smaller numbers in +// place of a single large code for CJK, similarly for Indic, +// and dissection of ligatures for other scripts. +// Author: Ray Smith +// Created: Wed Mar 04 14:45:01 PST 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "unicharcompress.h" +#include "tprintf.h" + +namespace tesseract { + +// String used to represent the null_id in direct_set. +const char* kNullChar = "<nul>"; + +// Local struct used only for processing the radical-stroke table. +struct RadicalStroke { + RadicalStroke() : num_strokes(0) {} + RadicalStroke(const STRING& r, int s) : radical(r), num_strokes(s) {} + + bool operator==(const RadicalStroke& other) const { + return radical == other.radical && num_strokes == other.num_strokes; + } + + // The radical is encoded as a string because its format is an int with + // an optional ' mark to indicate a simplified shape. To treat these as + // distinct, we use a string and a UNICHARSET to do the integer mapping. + STRING radical; + // We treat the number of strokes as dense and just take the face value from + // the table. + int num_strokes; +}; + +// Hash functor for RadicalStroke. +struct RadicalStrokedHash { + size_t operator()(const RadicalStroke& rs) const { + size_t result = rs.num_strokes; + for (int i = 0; i < rs.radical.length(); ++i) { + result ^= rs.radical[i] << (6 * i + 8); + } + return result; + } +}; + +// A hash map to convert unicodes to a radical,stroke pair. +typedef TessHashMap<int, RadicalStroke> RSMap; +// A hash map to count occurrences of each radical,stroke pair. +typedef TessHashMap<RadicalStroke, int, RadicalStrokedHash> RSCounts; + +// Helper function builds the RSMap from the radical-stroke file, which has +// already been read into a STRING. Returns false on error. +// The radical_stroke_table is non-const because it gets split and the caller +// is unlikely to want to use it again.
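+// For illustration, assuming the radical-stroke.txt format parsed below: a +// line such as "4E00\t1.0" maps U+4E00 to radical "1" with 0 strokes, while a +// hypothetical "4E9B\t7'.5" would map to simplified radical 7 (stored with a +// leading ' as "'7") with 5 strokes.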
+static bool DecodeRadicalStrokeTable(STRING* radical_stroke_table, + RSMap* radical_map) { + GenericVector<STRING> lines; + radical_stroke_table->split('\n', &lines); + for (int i = 0; i < lines.size(); ++i) { + if (lines[i].length() == 0 || lines[i][0] == '#') continue; + int unicode, radical, strokes; + STRING str_radical; + if (sscanf(lines[i].string(), "%x\t%d.%d", &unicode, &radical, &strokes) == + 3) { + str_radical.add_str_int("", radical); + } else if (sscanf(lines[i].string(), "%x\t%d'.%d", &unicode, &radical, + &strokes) == 3) { + str_radical.add_str_int("'", radical); + } else { + tprintf("Invalid format in radical stroke table at line %d: %s\n", i, + lines[i].string()); + return false; + } + (*radical_map)[unicode] = RadicalStroke(str_radical, strokes); + } + return true; +} + +UnicharCompress::UnicharCompress() : code_range_(0) {} +UnicharCompress::UnicharCompress(const UnicharCompress& src) { *this = src; } +UnicharCompress::~UnicharCompress() { Cleanup(); } +UnicharCompress& UnicharCompress::operator=(const UnicharCompress& src) { + Cleanup(); + encoder_ = src.encoder_; + code_range_ = src.code_range_; + SetupDecoder(); + return *this; +} + +// Computes the encoding for the given unicharset. It is a requirement that +// the file training/langdata/radical-stroke.txt have been read into the +// input string radical_stroke_table. +// Returns false if the encoding cannot be constructed. +bool UnicharCompress::ComputeEncoding(const UNICHARSET& unicharset, int null_id, + STRING* radical_stroke_table) { + RSMap radical_map; + if (!DecodeRadicalStrokeTable(radical_stroke_table, &radical_map)) + return false; + encoder_.clear(); + UNICHARSET direct_set; + UNICHARSET radicals; + // To avoid unused codes, clear the special codes from the unicharsets. + direct_set.clear(); + radicals.clear(); + // Always keep space as 0. + direct_set.unichar_insert(" "); + // Null char is next if we have one. + if (null_id >= 0) { + direct_set.unichar_insert(kNullChar); + } + RSCounts radical_counts; + // In the initial map, codes [0, unicharset.size()) are + // reserved for non-han/hangul sequences of 1 or more unicodes. + int hangul_offset = unicharset.size(); + // Hangul takes the next range [hangul_offset, hangul_offset + kTotalJamos). + const int kTotalJamos = kLCount + kVCount + kTCount; + // Han takes the codes beyond hangul_offset + kTotalJamos. Since it is hard + // to measure the number of radicals and strokes, initially we use the same + // code range for all 3 Han code positions, and fix them after. + int han_offset = hangul_offset + kTotalJamos; + int max_num_strokes = -1; + for (int u = 0; u <= unicharset.size(); ++u) { + bool self_normalized = false; + // We special-case allow null_id to be equal to unicharset.size() in case + // there is no space in unicharset for it. + if (u == unicharset.size()) { + if (u == null_id) { + self_normalized = true; + } else { + break; // Finished. + } + } else { + self_normalized = strcmp(unicharset.id_to_unichar(u), + unicharset.get_normed_unichar(u)) == 0; + } + RecodedCharID code; + // Convert to unicodes. + GenericVector<int> unicodes; + if (u < unicharset.size() && + UNICHAR::UTF8ToUnicode(unicharset.get_normed_unichar(u), &unicodes) && + unicodes.size() == 1) { + // Check single unicodes for Hangul/Han and encode if so. + int unicode = unicodes[0]; + int leading, vowel, trailing; + auto it = radical_map.find(unicode); + if (it != radical_map.end()) { + // This is Han. Convert to radical, stroke, index.
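+ // (Illustration with made-up numbers: if this radical is unichar 5 in + // radicals, the character has 9 strokes, and 3 characters with the same + // radical/stroke pair were seen before it, the code becomes + // Set3(5 + han_offset, 9 + han_offset, 3 + han_offset); the three + // positions are renumbered into disjoint ranges below and in + // DefragmentCodeValues.)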
+ if (!radicals.contains_unichar(it->second.radical.string())) { + radicals.unichar_insert(it->second.radical.string()); + } + int radical = radicals.unichar_to_id(it->second.radical.string()); + int num_strokes = it->second.num_strokes; + int num_samples = radical_counts[it->second]++; + if (num_strokes > max_num_strokes) max_num_strokes = num_strokes; + code.Set3(radical + han_offset, num_strokes + han_offset, + num_samples + han_offset); + } else if (DecomposeHangul(unicode, &leading, &vowel, &trailing)) { + // This is Hangul. Since we know the exact size of each part at compile + // time, it gets the bottom set of codes. + code.Set3(leading + hangul_offset, vowel + kLCount + hangul_offset, + trailing + kLCount + kVCount + hangul_offset); + } + } + // If the code is still empty, it wasn't Han or Hangul. + if (code.length() == 0) { + // Special cases. + if (u == UNICHAR_SPACE) { + code.Set(0, 0); // Space. + } else if (u == null_id || (unicharset.has_special_codes() && + u < SPECIAL_UNICHAR_CODES_COUNT)) { + code.Set(0, direct_set.unichar_to_id(kNullChar)); + } else { + // Add the direct_set unichar-ids of the unicodes in sequence to the + // code. + for (int i = 0; i < unicodes.size(); ++i) { + int position = code.length(); + if (position >= RecodedCharID::kMaxCodeLen) { + tprintf("Unichar %d=%s->%s is too long to encode!!\n", u, + unicharset.id_to_unichar(u), + unicharset.get_normed_unichar(u)); + return false; + } + int uni = unicodes[i]; + UNICHAR unichar(uni); + char* utf8 = unichar.utf8_str(); + if (!direct_set.contains_unichar(utf8)) + direct_set.unichar_insert(utf8); + code.Set(position, direct_set.unichar_to_id(utf8)); + delete[] utf8; + if (direct_set.size() > unicharset.size()) { + // Code space got bigger! + tprintf("Code space expanded from original unicharset!!\n"); + return false; + } + } + } + } + code.set_self_normalized(self_normalized); + encoder_.push_back(code); + } + // Now renumber Han to make all codes unique. We already added han_offset to + // all Han. Now separate out the radical, stroke, and count codes for Han. + // In the uniqued Han encoding, the 1st code uses the next radical_map.size() + // values, the 2nd code uses the next max_num_strokes+1 values, and the 3rd + // code uses the rest for the max number of duplicated radical/stroke combos. + int num_radicals = radicals.size(); + for (int u = 0; u < unicharset.size(); ++u) { + RecodedCharID* code = &encoder_[u]; + if ((*code)(0) >= han_offset) { + code->Set(1, (*code)(1) + num_radicals); + code->Set(2, (*code)(2) + num_radicals + max_num_strokes + 1); + } + } + DefragmentCodeValues(null_id >= 0 ? 1 : -1); + SetupDecoder(); + return true; +} + +// Sets up an encoder that doesn't change the unichars at all, so it just +// passes them through unchanged. +void UnicharCompress::SetupPassThrough(const UNICHARSET& unicharset) { + GenericVector<RecodedCharID> codes; + for (int u = 0; u < unicharset.size(); ++u) { + RecodedCharID code; + code.Set(0, u); + codes.push_back(code); + } + SetupDirect(codes); +} + +// Sets up an encoder directly using the given encoding vector, which maps +// unichar_ids to the given codes. +void UnicharCompress::SetupDirect(const GenericVector<RecodedCharID>& codes) { + encoder_ = codes; + ComputeCodeRange(); + SetupDecoder(); +} + +// Renumbers codes to eliminate unused values. +void UnicharCompress::DefragmentCodeValues(int encoded_null) { + // There may not be any Hangul, but even if there is, it is possible that not
Likewise with the Han encoding, it is possible that not + // all numbers of strokes are used. + ComputeCodeRange(); + GenericVector offsets; + offsets.init_to_size(code_range_, 0); + // Find which codes are used + for (int c = 0; c < encoder_.size(); ++c) { + const RecodedCharID& code = encoder_[c]; + for (int i = 0; i < code.length(); ++i) { + offsets[code(i)] = 1; + } + } + // Compute offsets based on code use. + int offset = 0; + for (int i = 0; i < offsets.size(); ++i) { + // If not used, decrement everything above here. + // We are moving encoded_null to the end, so it is not "used". + if (offsets[i] == 0 || i == encoded_null) { + --offset; + } else { + offsets[i] = offset; + } + } + if (encoded_null >= 0) { + // The encoded_null is moving to the end, for the benefit of TensorFlow, + // which is offsets.size() + offsets.back(). + offsets[encoded_null] = offsets.size() + offsets.back() - encoded_null; + } + // Now apply the offsets. + for (int c = 0; c < encoder_.size(); ++c) { + RecodedCharID* code = &encoder_[c]; + for (int i = 0; i < code->length(); ++i) { + int value = (*code)(i); + code->Set(i, value + offsets[value]); + } + } + ComputeCodeRange(); +} + +// Encodes a single unichar_id. Returns the length of the code, or zero if +// invalid input, and the encoding itself +int UnicharCompress::EncodeUnichar(int unichar_id, RecodedCharID* code) const { + if (unichar_id < 0 || unichar_id >= encoder_.size()) return 0; + *code = encoder_[unichar_id]; + return code->length(); +} + +// Decodes code, returning the original unichar-id, or +// INVALID_UNICHAR_ID if the input is invalid. +int UnicharCompress::DecodeUnichar(const RecodedCharID& code) const { + int len = code.length(); + if (len <= 0 || len > RecodedCharID::kMaxCodeLen) return INVALID_UNICHAR_ID; + auto it = decoder_.find(code); + if (it == decoder_.end()) return INVALID_UNICHAR_ID; + return it->second; +} + +// Writes to the given file. Returns false in case of error. +bool UnicharCompress::Serialize(TFile* fp) const { + return encoder_.SerializeClasses(fp); +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool UnicharCompress::DeSerialize(bool swap, TFile* fp) { + if (!encoder_.DeSerializeClasses(swap, fp)) return false; + ComputeCodeRange(); + SetupDecoder(); + return true; +} + +// Returns a STRING containing a text file that describes the encoding thus: +// [,]* +// In words, a comma-separated list of one or more indices, followed by a tab +// and the UTF-8 string that the code represents per line. Most simple scripts +// will encode a single index to a UTF8-string, but Chinese, Japanese, Korean +// and the Indic scripts will contain a many-to-many mapping. +// See the class comment above for details. +STRING UnicharCompress::GetEncodingAsString( + const UNICHARSET& unicharset) const { + STRING encoding; + for (int c = 0; c < encoder_.size(); ++c) { + const RecodedCharID& code = encoder_[c]; + if (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT && code == encoder_[c - 1]) { + // Don't show the duplicate entry. 
+ continue; + } + encoding.add_str_int("", code(0)); + for (int i = 1; i < code.length(); ++i) { + encoding.add_str_int(",", code(i)); + } + encoding += "\t"; + if (c >= unicharset.size() || (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT && + unicharset.has_special_codes())) { + encoding += kNullChar; + } else { + encoding += unicharset.id_to_unichar(c); + } + encoding += "\n"; + } + return encoding; +} + +// Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing. +// Note that the returned values are 0-based indices, NOT unicode Jamo. +// Returns false if the input is not in the Hangul unicode range. +/* static */ +bool UnicharCompress::DecomposeHangul(int unicode, int* leading, int* vowel, + int* trailing) { + if (unicode < kFirstHangul) return false; + int offset = unicode - kFirstHangul; + if (offset >= kNumHangul) return false; + const int kNCount = kVCount * kTCount; + *leading = offset / kNCount; + *vowel = (offset % kNCount) / kTCount; + *trailing = offset % kTCount; + return true; +} + +// Computes the value of code_range_ from the encoder_. +void UnicharCompress::ComputeCodeRange() { + code_range_ = -1; + for (int c = 0; c < encoder_.size(); ++c) { + const RecodedCharID& code = encoder_[c]; + for (int i = 0; i < code.length(); ++i) { + if (code(i) > code_range_) code_range_ = code(i); + } + } + ++code_range_; +} + +// Initializes the decoding hash_map from the encoding array. +void UnicharCompress::SetupDecoder() { + Cleanup(); + is_valid_start_.init_to_size(code_range_, false); + for (int c = 0; c < encoder_.size(); ++c) { + const RecodedCharID& code = encoder_[c]; + if (code.self_normalized() || decoder_.find(code) == decoder_.end()) + decoder_[code] = c; + is_valid_start_[code(0)] = true; + RecodedCharID prefix = code; + int len = code.length() - 1; + prefix.Truncate(len); + auto final_it = final_codes_.find(prefix); + if (final_it == final_codes_.end()) { + GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>; + code_list->push_back(code(len)); + final_codes_[prefix] = code_list; + while (--len >= 0) { + prefix.Truncate(len); + auto next_it = next_codes_.find(prefix); + if (next_it == next_codes_.end()) { + GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>; + code_list->push_back(code(len)); + next_codes_[prefix] = code_list; + } else { + // We still have to search the list as we may get here via multiple + // lengths of code. + if (!next_it->second->contains(code(len))) + next_it->second->push_back(code(len)); + break; // This prefix has been processed. + } + } + } else { + if (!final_it->second->contains(code(len))) + final_it->second->push_back(code(len)); + } + } +} + +// Frees allocated memory. +void UnicharCompress::Cleanup() { + decoder_.clear(); + is_valid_start_.clear(); + for (auto it = next_codes_.begin(); it != next_codes_.end(); ++it) { + delete it->second; + } + for (auto it = final_codes_.begin(); it != final_codes_.end(); ++it) { + delete it->second; + } + next_codes_.clear(); + final_codes_.clear(); +} + +} // namespace tesseract. diff --git a/ccutil/unicharcompress.h b/ccutil/unicharcompress.h new file mode 100644 index 0000000000..6efc46fdc7 --- /dev/null +++ b/ccutil/unicharcompress.h @@ -0,0 +1,258 @@ +/////////////////////////////////////////////////////////////////////// +// File: unicharcompress.h +// Description: Unicode re-encoding using a sequence of smaller numbers in +// place of a single large code for CJK, similarly for Indic, +// and dissection of ligatures for other scripts.
+// Author: Ray Smith +// Created: Wed Mar 04 14:45:01 PST 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_CCUTIL_UNICHARCOMPRESS_H_ +#define TESSERACT_CCUTIL_UNICHARCOMPRESS_H_ + +#include "hashfn.h" +#include "serialis.h" +#include "strngs.h" +#include "unicharset.h" + +namespace tesseract { + +// Trivial class to hold the code for a recoded unichar-id. +class RecodedCharID { + public: + // The maximum length of a code. + static const int kMaxCodeLen = 9; + + RecodedCharID() : self_normalized_(0), length_(0) { + memset(code_, 0, sizeof(code_)); + } + void Truncate(int length) { length_ = length; } + // Sets the code value at the given index in the code. + void Set(int index, int value) { + code_[index] = value; + if (length_ <= index) length_ = index + 1; + } + // Shorthand for setting codes of length 3, as all Hangul and Han codes are + // length 3. + void Set3(int code0, int code1, int code2) { + length_ = 3; + code_[0] = code0; + code_[1] = code1; + code_[2] = code2; + } + // Accessors + bool self_normalized() const { return self_normalized_ != 0; } + void set_self_normalized(bool value) { self_normalized_ = value; } + int length() const { return length_; } + int operator()(int index) const { return code_[index]; } + + // Writes to the given file. Returns false in case of error. + bool Serialize(TFile* fp) const { + if (fp->FWrite(&self_normalized_, sizeof(self_normalized_), 1) != 1) + return false; + if (fp->FWrite(&length_, sizeof(length_), 1) != 1) return false; + if (fp->FWrite(code_, sizeof(code_[0]), length_) != length_) return false; + return true; + } + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, TFile* fp) { + if (fp->FRead(&self_normalized_, sizeof(self_normalized_), 1) != 1) + return false; + if (fp->FRead(&length_, sizeof(length_), 1) != 1) return false; + if (swap) ReverseN(&length_, sizeof(length_)); + if (fp->FRead(code_, sizeof(code_[0]), length_) != length_) return false; + if (swap) { + for (int i = 0; i < length_; ++i) { + ReverseN(&code_[i], sizeof(code_[i])); + } + } + return true; + } + bool operator==(const RecodedCharID& other) const { + if (length_ != other.length_) return false; + for (int i = 0; i < length_; ++i) { + if (code_[i] != other.code_[i]) return false; + } + return true; + } + // Hash functor for RecodedCharID. + struct RecodedCharIDHash { + size_t operator()(const RecodedCharID& code) const { + size_t result = 0; + for (int i = 0; i < code.length_; ++i) { + result ^= code(i) << (7 * i); + } + return result; + } + }; + + private: + // True if this code is self-normalizing, ie is the master entry for indices + // that map to the same code. Has boolean value, but inT8 for serialization. 
+ inT8 self_normalized_; + // The number of elements in use in code_. + inT32 length_; + // The re-encoded form of the unichar-id to which this RecodedCharID relates. + inT32 code_[kMaxCodeLen]; +}; + +// Class holds a "compression" of a unicharset to simplify the learning problem +// for a neural-network-based classifier. +// Objectives: +// 1 (CJK): Ids of a unicharset with a large number of classes are expressed as +// a sequence of 3 codes with far fewer values. +// This is achieved using the Jamo coding for Hangul and the Unicode +// Radical-Stroke-index for Han. +// 2 (Indic): Instead of thousands of codes with one for each grapheme, re-code +// as the unicode sequence (but coded in a more compact space). +// 3 (the rest): Eliminate multi-path problems with ligatures and fold confusing +// and not significantly distinct shapes (quotes) together, ie +// represent the fi ligature as the f-i pair, and fold u+2019 and +// friends all onto ascii single ' +// 4 The null character and mapping to target activations: +// To save horizontal coding space, the compressed codes are generally mapped +// to target network activations without intervening null characters, BUT +// in the case of ligatures, such as ff, null characters have to be included +// so existence of repeated codes is detected at codebook-building time, and +// null characters are embedded directly into the codes, so the rest of the +// system doesn't need to worry about the problem (much). There is still an +// effect on the range of ways in which the target activations can be +// generated. +// +// The computed code values are compact (no unused values), and, for CJK, +// unique (each code position uses a disjoint set of values from each other code +// position). For non-CJK, the same code value CAN be used in multiple +// positions, eg the ff ligature is converted to <f,f>, where <f> +// is the same code as is used for the single f. +// NOTE that an intended consequence of using the normalized text from the +// unicharset is that the fancy quotes all map to a single code, so round-trip +// conversion doesn't work for all unichar-ids. +class UnicharCompress { + public: + UnicharCompress(); + UnicharCompress(const UnicharCompress& src); + ~UnicharCompress(); + UnicharCompress& operator=(const UnicharCompress& src); + + // The 1st Hangul unicode. + static const int kFirstHangul = 0xac00; + // The number of Hangul unicodes. + static const int kNumHangul = 11172; + // The number of Jamos for each of the 3 parts of a Hangul character, being + // the Leading consonant, Vowel and Trailing consonant. + static const int kLCount = 19; + static const int kVCount = 21; + static const int kTCount = 28; + + // Computes the encoding for the given unicharset. It is a requirement that + // the file training/langdata/radical-stroke.txt have been read into the + // input string radical_stroke_table. + // Returns false if the encoding cannot be constructed. + bool ComputeEncoding(const UNICHARSET& unicharset, int null_id, + STRING* radical_stroke_table); + // Sets up an encoder that doesn't change the unichars at all, so it just + // passes them through unchanged. + void SetupPassThrough(const UNICHARSET& unicharset); + // Sets up an encoder directly using the given encoding vector, which maps + // unichar_ids to the given codes.
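+ // (For example, filling a GenericVector<RecodedCharID> with one Set(0, u) + // entry per unichar-id u and passing it here reproduces exactly what + // SetupPassThrough does.)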
+ void SetupDirect(const GenericVector<RecodedCharID>& codes); + + // Returns the number of different values that can be used in a code, ie + // 1 + the maximum value that will ever be used by a RecodedCharID code in + // any position in its array. + int code_range() const { return code_range_; } + + // Encodes a single unichar_id. Returns the length of the code (or zero if + // invalid input), and the encoding itself in code. + int EncodeUnichar(int unichar_id, RecodedCharID* code) const; + // Decodes code, returning the original unichar-id, or + // INVALID_UNICHAR_ID if the input is invalid. Note that this is not a perfect + // inverse of EncodeUnichar, since the unichar-id of U+2019 (curly single + // quote), for example, will have the same encoding as the unichar-id of + // U+0027 (ascii '). The foldings are obtained from the input unicharset, + // which in turn obtains them from NormalizeUTF8String in normstrngs.cpp, + // and include NFKC normalization plus others like quote and dash folding. + int DecodeUnichar(const RecodedCharID& code) const; + // Returns true if the given code is a valid start or single code. + bool IsValidFirstCode(int code) const { return is_valid_start_[code]; } + // Returns a list of valid non-final next codes for a given prefix code, + // which may be empty. + const GenericVector<int>* GetNextCodes(const RecodedCharID& code) const { + auto it = next_codes_.find(code); + return it == next_codes_.end() ? NULL : it->second; + } + // Returns a list of valid final codes for a given prefix code, which may + // be empty. + const GenericVector<int>* GetFinalCodes(const RecodedCharID& code) const { + auto it = final_codes_.find(code); + return it == final_codes_.end() ? NULL : it->second; + } + + // Writes to the given file. Returns false in case of error. + bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, TFile* fp); + + // Returns a STRING containing a text file that describes the encoding thus: + // <index>[,<index>]*<tab><UTF8-str><newline> + // In words, a comma-separated list of one or more indices, followed by a tab + // and the UTF-8 string that the code represents per line. Most simple scripts + // will encode a single index to a UTF8-string, but Chinese, Japanese, Korean + // and the Indic scripts will contain a many-to-many mapping. + // See the class comment above for details. + STRING GetEncodingAsString(const UNICHARSET& unicharset) const; + + // Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing. + // Note that the returned values are 0-based indices, NOT unicode Jamo. + // Returns false if the input is not in the Hangul unicode range. + static bool DecomposeHangul(int unicode, int* leading, int* vowel, + int* trailing); + + private: + // Renumbers codes to eliminate unused values. + void DefragmentCodeValues(int encoded_null); + // Computes the value of code_range_ from the encoder_. + void ComputeCodeRange(); + // Initializes the decoding hash_map from the encoder_ array. + void SetupDecoder(); + // Frees allocated memory. + void Cleanup(); + + // The encoder that maps a unichar-id to a sequence of small codes. + // encoder_ is the only part that is serialized. The rest is computed on load. + GenericVector<RecodedCharID> encoder_; + // Decoder converts the output of encoder back to a unichar-id. + TessHashMap<RecodedCharID, int, RecodedCharID::RecodedCharIDHash> decoder_; + // True if the index is a valid single or start code. + GenericVector<bool> is_valid_start_; + // Maps a prefix code to a list of valid next codes.
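+ // For example, given only the codes (1,2,3) and (1,2,4), prefix (1) maps + // to {2} here, while prefix (1,2) maps to {3,4} in final_codes_ below.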
+ // The map owns the vectors. + TessHashMap<RecodedCharID, GenericVectorEqEq<int>*, + RecodedCharID::RecodedCharIDHash> + next_codes_; + // Maps a prefix code to a list of valid final codes. + // The map owns the vectors. + TessHashMap<RecodedCharID, GenericVectorEqEq<int>*, + RecodedCharID::RecodedCharIDHash> + final_codes_; + // Max of any value in encoder_ + 1. + int code_range_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_CCUTIL_UNICHARCOMPRESS_H_ diff --git a/ccutil/unicharset.cpp b/ccutil/unicharset.cpp index f7e4842162..380c74101f 100644 --- a/ccutil/unicharset.cpp +++ b/ccutil/unicharset.cpp @@ -906,6 +906,8 @@ void UNICHARSET::post_load_setup() { han_sid_ = get_script_id_from_name("Han"); hiragana_sid_ = get_script_id_from_name("Hiragana"); katakana_sid_ = get_script_id_from_name("Katakana"); + thai_sid_ = get_script_id_from_name("Thai"); + hangul_sid_ = get_script_id_from_name("Hangul"); // Compute default script. Use the highest-counting alpha script, that is // not the common script, as that still contains some "alphas". diff --git a/ccutil/unicharset.h b/ccutil/unicharset.h index 023e84d5b6..eb1b4e5f4b 100644 --- a/ccutil/unicharset.h +++ b/ccutil/unicharset.h @@ -290,6 +290,8 @@ class UNICHARSET { han_sid_ = 0; hiragana_sid_ = 0; katakana_sid_ = 0; + thai_sid_ = 0; + hangul_sid_ = 0; } // Return the size of the set (the number of different UNICHAR it holds). @@ -604,6 +606,16 @@ class UNICHARSET { return unichars[unichar_id].properties.AnyRangeEmpty(); } + // Returns true if the script of the given id is space delimited. + // Returns false for Han, Hangul, Hiragana, Katakana and Thai scripts. + bool IsSpaceDelimited(UNICHAR_ID unichar_id) const { + if (INVALID_UNICHAR_ID == unichar_id) return true; + int script_id = get_script(unichar_id); + return script_id != han_sid_ && script_id != thai_sid_ && + script_id != hangul_sid_ && script_id != hiragana_sid_ && + script_id != katakana_sid_; + } + // Return the script name of the given unichar. + // The returned pointer will always be the same for the same script, it's + // managed by unicharset and thus MUST NOT be deleted @@ -773,7 +785,7 @@ class UNICHARSET { // Returns normalized version of unichar with the given unichar_id. const char *get_normed_unichar(UNICHAR_ID unichar_id) const { - if (unichar_id == UNICHAR_SPACE && has_special_codes()) return " "; + if (unichar_id == UNICHAR_SPACE) return " "; return unichars[unichar_id].properties.normed.string(); } // Returns a vector of UNICHAR_IDs that represent the ids of the normalized @@ -835,6 +847,8 @@ class UNICHARSET { int han_sid() const { return han_sid_; } int hiragana_sid() const { return hiragana_sid_; } int katakana_sid() const { return katakana_sid_; } + int thai_sid() const { return thai_sid_; } + int hangul_sid() const { return hangul_sid_; } int default_sid() const { return default_sid_; } // Returns true if the unicharset has the concept of upper/lower case. @@ -977,6 +991,8 @@ class UNICHARSET { int han_sid_; int hiragana_sid_; int katakana_sid_; + int thai_sid_; + int hangul_sid_; // The most frequently occurring script in the charset.
int default_sid_; }; diff --git a/configure.ac b/configure.ac index a775e4fc28..34cafadc00 100644 --- a/configure.ac +++ b/configure.ac @@ -6,7 +6,7 @@ # Initialization # ---------------------------------------- AC_PREREQ([2.50]) -AC_INIT([tesseract], [3.05.00dev], [https://github.com/tesseract-ocr/tesseract/issues]) +AC_INIT([tesseract], [4.00.00dev], [https://github.com/tesseract-ocr/tesseract/issues]) AC_PROG_CXX([g++ clang++]) AC_LANG([C++]) AC_LANG_COMPILER_REQUIRE @@ -18,8 +18,8 @@ AC_PREFIX_DEFAULT([/usr/local]) # Define date of package, etc. Could be useful in auto-generated # documentation. -PACKAGE_YEAR=2015 -PACKAGE_DATE="07/11" +PACKAGE_YEAR=2016 +PACKAGE_DATE="11/11" abs_top_srcdir=`AS_DIRNAME([$0])` gitrev="`git --git-dir=${abs_top_srcdir}/.git --work-tree=${abs_top_srcdir} describe --always --tags`" @@ -42,8 +42,8 @@ AC_SUBST([PACKAGE_DATE]) GENERIC_LIBRARY_NAME=tesseract # Release versioning -GENERIC_MAJOR_VERSION=3 -GENERIC_MINOR_VERSION=4 +GENERIC_MAJOR_VERSION=4 +GENERIC_MINOR_VERSION=0 GENERIC_MICRO_VERSION=0 # API version (often = GENERIC_MAJOR_VERSION.GENERIC_MINOR_VERSION) @@ -520,6 +520,7 @@ fi # Output files AC_CONFIG_FILES([Makefile tesseract.pc]) AC_CONFIG_FILES([api/Makefile]) +AC_CONFIG_FILES([arch/Makefile]) AC_CONFIG_FILES([ccmain/Makefile]) AC_CONFIG_FILES([opencl/Makefile]) AC_CONFIG_FILES([ccstruct/Makefile]) @@ -528,6 +529,7 @@ AC_CONFIG_FILES([classify/Makefile]) AC_CONFIG_FILES([cube/Makefile]) AC_CONFIG_FILES([cutil/Makefile]) AC_CONFIG_FILES([dict/Makefile]) +AC_CONFIG_FILES([lstm/Makefile]) AC_CONFIG_FILES([neural_networks/runtime/Makefile]) AC_CONFIG_FILES([textord/Makefile]) AC_CONFIG_FILES([viewer/Makefile]) diff --git a/cutil/oldlist.cpp b/cutil/oldlist.cpp index d966deec3b..9e3f6f4c06 100644 --- a/cutil/oldlist.cpp +++ b/cutil/oldlist.cpp @@ -401,7 +401,6 @@ LIST s_adjoin(LIST var_list, void *variable, int_compare compare) { return (push_last (var_list, variable)); } - /********************************************************************** * s e a r c h * diff --git a/dict/dawg_cache.cpp b/dict/dawg_cache.cpp index 2d21b01809..aea500a132 100644 --- a/dict/dawg_cache.cpp +++ b/dict/dawg_cache.cpp @@ -69,14 +69,17 @@ Dawg *DawgLoader::Load() { PermuterType perm_type; switch (tessdata_dawg_type_) { case TESSDATA_PUNC_DAWG: + case TESSDATA_LSTM_PUNC_DAWG: dawg_type = DAWG_TYPE_PUNCTUATION; perm_type = PUNC_PERM; break; case TESSDATA_SYSTEM_DAWG: + case TESSDATA_LSTM_SYSTEM_DAWG: dawg_type = DAWG_TYPE_WORD; perm_type = SYSTEM_DAWG_PERM; break; case TESSDATA_NUMBER_DAWG: + case TESSDATA_LSTM_NUMBER_DAWG: dawg_type = DAWG_TYPE_NUMBER; perm_type = NUMBER_PERM; break; diff --git a/dict/dict.cpp b/dict/dict.cpp index 918e34aeae..fec9fcce19 100644 --- a/dict/dict.cpp +++ b/dict/dict.cpp @@ -202,10 +202,8 @@ DawgCache *Dict::GlobalDawgCache() { return &cache; } -void Dict::Load(DawgCache *dawg_cache) { - STRING name; - STRING &lang = getCCUtil()->lang; - +// Sets up ready for a Load or LoadLSTM. +void Dict::SetupForLoad(DawgCache *dawg_cache) { if (dawgs_.length() != 0) this->End(); apostrophe_unichar_id_ = getUnicharset().unichar_to_id(kApostropheSymbol); @@ -220,10 +218,10 @@ void Dict::Load(DawgCache *dawg_cache) { dawg_cache_ = new DawgCache(); dawg_cache_is_ours_ = true; } +} - TessdataManager &tessdata_manager = getCCUtil()->tessdata_manager; - const char *data_file_name = tessdata_manager.GetDataFileName().string(); - +// Loads the dawgs needed by Tesseract. Call FinishLoad() after. 
+void Dict::Load(const char *data_file_name, const STRING &lang) { // Load dawgs_. if (load_punc_dawg) { punc_dawg_ = dawg_cache_->GetSquishedDawg( @@ -255,6 +253,7 @@ void Dict::Load(DawgCache *dawg_cache) { if (unambig_dawg_) dawgs_ += unambig_dawg_; } + STRING name; if (((STRING &)user_words_suffix).length() > 0 || ((STRING &)user_words_file).length() > 0) { Trie *trie_ptr = new Trie(DAWG_TYPE_WORD, lang, USER_DAWG_PERM, @@ -300,8 +299,33 @@ void Dict::Load(DawgCache *dawg_cache) { // This dawg is temporary and should not be searched by letter_is_ok. pending_words_ = new Trie(DAWG_TYPE_WORD, lang, NO_PERM, getUnicharset().size(), dawg_debug_level); +} - // Construct a list of corresponding successors for each dawg. Each entry i +// Loads the dawgs needed by the LSTM model. Call FinishLoad() after. +void Dict::LoadLSTM(const char *data_file_name, const STRING &lang) { + // Load dawgs_. + if (load_punc_dawg) { + punc_dawg_ = dawg_cache_->GetSquishedDawg( + lang, data_file_name, TESSDATA_LSTM_PUNC_DAWG, dawg_debug_level); + if (punc_dawg_) dawgs_ += punc_dawg_; + } + if (load_system_dawg) { + Dawg *system_dawg = dawg_cache_->GetSquishedDawg( + lang, data_file_name, TESSDATA_LSTM_SYSTEM_DAWG, dawg_debug_level); + if (system_dawg) dawgs_ += system_dawg; + } + if (load_number_dawg) { + Dawg *number_dawg = dawg_cache_->GetSquishedDawg( + lang, data_file_name, TESSDATA_LSTM_NUMBER_DAWG, dawg_debug_level); + if (number_dawg) dawgs_ += number_dawg; + } +} + +// Completes the loading process after Load() and/or LoadLSTM(). +// Returns false if no dictionaries were loaded. +bool Dict::FinishLoad() { + if (dawgs_.empty()) return false; + // Construct a list of corresponding successors for each dawg. Each entry, i, // in the successors_ vector is a vector of integers that represent the // indices into the dawgs_ vector of the successors for dawg i. successors_.reserve(dawgs_.length()); @@ -316,6 +340,7 @@ void Dict::Load(DawgCache *dawg_cache) { } successors_ += lst; } + return true; } void Dict::End() { @@ -368,6 +393,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, // Initialization. PermuterType curr_perm = NO_PERM; dawg_args->updated_dawgs->clear(); + dawg_args->valid_end = false; // Go over the active_dawgs vector and insert DawgPosition records // with the updated ref (an edge with the corresponding unichar id) into @@ -405,6 +431,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args, dawg_debug_level > 0, "Append transition from punc dawg to current dawgs: "); if (sdawg->permuter() > curr_perm) curr_perm = sdawg->permuter(); + if (sdawg->end_of_word(dawg_edge) && + punc_dawg->end_of_word(punc_transition_edge)) + dawg_args->valid_end = true; } } } @@ -419,6 +448,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, dawg_debug_level > 0, "Extend punctuation dawg: "); if (PUNC_PERM > curr_perm) curr_perm = PUNC_PERM; + if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true; } continue; } @@ -436,6 +466,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, dawg_debug_level > 0, "Return to punctuation dawg: "); if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter(); + if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true; } } @@ -445,8 +476,8 @@ int Dict::def_letter_is_okay(void* void_dawg_args, // possible edges, not only for the exact unichar_id, but also // for all its character classes (alpha, digit, etc). 
if (dawg->type() == DAWG_TYPE_PATTERN) { - ProcessPatternEdges(dawg, pos, unichar_id, word_end, - dawg_args->updated_dawgs, &curr_perm); + ProcessPatternEdges(dawg, pos, unichar_id, word_end, dawg_args, + &curr_perm); // There can't be any successors to dawg that is of type // DAWG_TYPE_PATTERN, so we are done examining this DawgPosition. continue; } @@ -473,6 +504,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args, continue; } if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter(); + if (dawg->end_of_word(edge) && + (punc_dawg == NULL || punc_dawg->end_of_word(pos.punc_ref))) + dawg_args->valid_end = true; dawg_args->updated_dawgs->add_unique( DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref, false), @@ -497,7 +531,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args, void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos, UNICHAR_ID unichar_id, bool word_end, - DawgPositionVector *updated_dawgs, + DawgArgs *dawg_args, PermuterType *curr_perm) const { NODE_REF node = GetStartingNode(dawg, pos.dawg_ref); // Try to find the edge corresponding to the exact unichar_id and to all the @@ -520,7 +554,8 @@ void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos, tprintf("Letter found in pattern dawg %d\n", pos.dawg_index); } if (dawg->permuter() > *curr_perm) *curr_perm = dawg->permuter(); - updated_dawgs->add_unique( + if (dawg->end_of_word(edge)) dawg_args->valid_end = true; + dawg_args->updated_dawgs->add_unique( DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref, pos.back_to_punc), dawg_debug_level > 0, @@ -816,5 +851,13 @@ bool Dict::valid_punctuation(const WERD_CHOICE &word) { return false; } +/// Returns true if the language is space-delimited (not Chinese, Japanese, +/// or Thai). +bool Dict::IsSpaceDelimitedLang() const { + const UNICHARSET &u_set = getUnicharset(); + if (u_set.han_sid() > 0) return false; + if (u_set.katakana_sid() > 0) return false; + if (u_set.thai_sid() > 0) return false; + return true; +} } // namespace tesseract diff --git a/dict/dict.h b/dict/dict.h index 326f1235d5..9b6ca19e54 100644 --- a/dict/dict.h +++ b/dict/dict.h @@ -23,7 +23,6 @@ #include "dawg.h" #include "dawg_cache.h" #include "host.h" -#include "oldlist.h" #include "ratngs.h" #include "stopper.h" #include "trie.h" @@ -76,11 +75,13 @@ enum XHeightConsistencyEnum {XH_GOOD, XH_SUBNORMAL, XH_INCONSISTENT}; struct DawgArgs { DawgArgs(DawgPositionVector *d, DawgPositionVector *up, PermuterType p) - : active_dawgs(d), updated_dawgs(up), permuter(p) {} + : active_dawgs(d), updated_dawgs(up), permuter(p), valid_end(false) {} DawgPositionVector *active_dawgs; DawgPositionVector *updated_dawgs; PermuterType permuter; + // True if the current position is a valid word end. + bool valid_end; }; class Dict { @@ -294,7 +295,15 @@ class Dict { /// Initialize Dict class - load dawgs from [lang].traineddata and /// user-specified wordlist and pattern list. static DawgCache *GlobalDawgCache(); - void Load(DawgCache *dawg_cache); + // Sets up ready for a Load or LoadLSTM. + void SetupForLoad(DawgCache *dawg_cache); + // Loads the dawgs needed by Tesseract. Call FinishLoad() after. + void Load(const char *data_file_name, const STRING &lang); + // Loads the dawgs needed by the LSTM model. Call FinishLoad() after. + void LoadLSTM(const char *data_file_name, const STRING &lang); + // Completes the loading process after Load() and/or LoadLSTM(). + // Returns false if no dictionaries were loaded.
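+ // A typical sequence (sketch): SetupForLoad(GlobalDawgCache()); then + // Load(data_file_name, lang) and/or LoadLSTM(data_file_name, lang); then + // FinishLoad() to build the successor lists.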
+ bool FinishLoad(); void End(); // Resets the document dictionary analogous to ResetAdaptiveClassifier. @@ -397,9 +406,7 @@ class Dict { } inline void SetWildcardID(UNICHAR_ID id) { wildcard_unichar_id_ = id; } - inline UNICHAR_ID WildcardID() const { - return wildcard_unichar_id_; - } + inline UNICHAR_ID WildcardID() const { return wildcard_unichar_id_; } /// Return the number of dawgs in the dawgs_ vector. inline int NumDawgs() const { return dawgs_.size(); } /// Return i-th dawg pointer recorded in the dawgs_ vector. @@ -436,7 +443,7 @@ class Dict { /// edges were found. void ProcessPatternEdges(const Dawg *dawg, const DawgPosition &info, UNICHAR_ID unichar_id, bool word_end, - DawgPositionVector *updated_dawgs, + DawgArgs *dawg_args, PermuterType *current_permuter) const; /// Read/Write/Access special purpose dawgs which contain words @@ -483,6 +490,8 @@ class Dict { inline void SetWordsegRatingAdjustFactor(float f) { wordseg_rating_adjust_factor_ = f; } + /// Returns true if the language is space-delimited (not Chinese, Japanese, + /// or Thai). + bool IsSpaceDelimitedLang() const; private: /** Private member variables. */ diff --git a/lstm/Makefile.am b/lstm/Makefile.am new file mode 100644 index 0000000000..fddd6230ba --- /dev/null +++ b/lstm/Makefile.am @@ -0,0 +1,39 @@ +AM_CPPFLAGS += \ + -I$(top_srcdir)/ccutil -I$(top_srcdir)/cutil -I$(top_srcdir)/ccstruct \ + -I$(top_srcdir)/arch -I$(top_srcdir)/viewer -I$(top_srcdir)/classify \ + -I$(top_srcdir)/dict -I$(top_srcdir)/lstm +AUTOMAKE_OPTIONS = subdir-objects +SUBDIRS = +AM_CXXFLAGS = -fopenmp + +if !NO_TESSDATA_PREFIX +AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@/ +endif + +if VISIBILITY +AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden +AM_CPPFLAGS += -DTESS_EXPORTS +endif + +include_HEADERS = \ + convolve.h ctc.h fullyconnected.h functions.h input.h \ + lstm.h lstmrecognizer.h lstmtrainer.h maxpool.h \ + networkbuilder.h network.h networkio.h networkscratch.h \ + parallel.h plumbing.h recodebeam.h reconfig.h reversed.h \ + series.h static_shape.h stridemap.h tfnetwork.h weightmatrix.h + +noinst_HEADERS = + +if !USING_MULTIPLELIBS +noinst_LTLIBRARIES = libtesseract_lstm.la +else +lib_LTLIBRARIES = libtesseract_lstm.la +libtesseract_lstm_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION) +endif + +libtesseract_lstm_la_SOURCES = \ + convolve.cpp ctc.cpp fullyconnected.cpp functions.cpp input.cpp \ + lstm.cpp lstmrecognizer.cpp lstmtrainer.cpp maxpool.cpp \ + networkbuilder.cpp network.cpp networkio.cpp \ + parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \ + series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp diff --git a/lstm/convolve.cpp b/lstm/convolve.cpp new file mode 100644 index 0000000000..f89ca3bae3 --- /dev/null +++ b/lstm/convolve.cpp @@ -0,0 +1,124 @@ +/////////////////////////////////////////////////////////////////////// +// File: convolve.cpp +// Description: Convolutional layer that stacks the inputs over its rectangle +// and pulls in random data to fill out-of-input inputs. +// Output is therefore same size as its input, but deeper. +// Author: Ray Smith +// Created: Tue Mar 18 16:56:06 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "convolve.h" + +#include "networkscratch.h" +#include "serialis.h" + +namespace tesseract { + +Convolve::Convolve(const STRING& name, int ni, int half_x, int half_y) + : Network(NT_CONVOLVE, name, ni, ni * (2*half_x + 1) * (2*half_y + 1)), + half_x_(half_x), half_y_(half_y) { +} + +Convolve::~Convolve() { +} + +// Writes to the given file. Returns false in case of error. +bool Convolve::Serialize(TFile* fp) const { + if (!Network::Serialize(fp)) return false; + if (fp->FWrite(&half_x_, sizeof(half_x_), 1) != 1) return false; + if (fp->FWrite(&half_y_, sizeof(half_y_), 1) != 1) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool Convolve::DeSerialize(bool swap, TFile* fp) { + if (fp->FRead(&half_x_, sizeof(half_x_), 1) != 1) return false; + if (fp->FRead(&half_y_, sizeof(half_y_), 1) != 1) return false; + if (swap) { + ReverseN(&half_x_, sizeof(half_x_)); + ReverseN(&half_y_, sizeof(half_y_)); + } + no_ = ni_ * (2*half_x_ + 1) * (2*half_y_ + 1); + return true; +} + +// Runs forward propagation of activations on the input line. +// See Network for a detailed discussion of the arguments. +void Convolve::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + output->Resize(input, no_); + int y_scale = 2 * half_y_ + 1; + StrideMap::Index dest_index(output->stride_map()); + do { + // Stack x_scale groups of y_scale * ni_ inputs together. + int t = dest_index.t(); + int out_ix = 0; + for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) { + StrideMap::Index x_index(dest_index); + if (!x_index.AddOffset(x, FD_WIDTH)) { + // This x is outside the image. + output->Randomize(t, out_ix, y_scale * ni_, randomizer_); + } else { + int out_iy = out_ix; + for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) { + StrideMap::Index y_index(x_index); + if (!y_index.AddOffset(y, FD_HEIGHT)) { + // This y is outside the image. + output->Randomize(t, out_iy, ni_, randomizer_); + } else { + output->CopyTimeStepGeneral(t, out_iy, ni_, input, y_index.t(), 0); + } + } + } + } + } while (dest_index.Increment()); + if (debug) DisplayForward(*output); +} + +// Runs backward propagation of errors on the deltas line. +// See Network for a detailed discussion of the arguments. +bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + back_deltas->Resize(fwd_deltas, ni_); + NetworkScratch::IO delta_sum; + delta_sum.ResizeFloat(fwd_deltas, ni_, scratch); + delta_sum->Zero(); + int y_scale = 2 * half_y_ + 1; + StrideMap::Index src_index(fwd_deltas.stride_map()); + do { + // Reverse the stacking: push the deltas for each window position back to + // the input timestep that supplied the data in Forward.
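+ // (Each output timestep stacked (2*half_x_ + 1) * (2*half_y_ + 1) input + // copies in Forward, so each input timestep accumulates, via delta_sum, + // the deltas from every output whose window included it.)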
+ int t = src_index.t(); + int out_ix = 0; + for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) { + StrideMap::Index x_index(src_index); + if (x_index.AddOffset(x, FD_WIDTH)) { + int out_iy = out_ix; + for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) { + StrideMap::Index y_index(x_index); + if (y_index.AddOffset(y, FD_HEIGHT)) { + fwd_deltas.AddTimeStepPart(t, out_iy, ni_, + delta_sum->f(y_index.t())); + } + } + } + } + } while (src_index.Increment()); + back_deltas->CopyWithNormalization(*delta_sum, fwd_deltas); + return true; +} + +} // namespace tesseract. diff --git a/lstm/convolve.h b/lstm/convolve.h new file mode 100644 index 0000000000..a05dc1d850 --- /dev/null +++ b/lstm/convolve.h @@ -0,0 +1,74 @@ +/////////////////////////////////////////////////////////////////////// +// File: convolve.h +// Description: Convolutional layer that stacks the inputs over its rectangle +// and pulls in random data to fill out-of-input inputs. +// Output is therefore same size as its input, but deeper. +// Author: Ray Smith +// Created: Tue Mar 18 16:45:34 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_CONVOLVE_H_ +#define TESSERACT_LSTM_CONVOLVE_H_ + +#include "genericvector.h" +#include "matrix.h" +#include "network.h" + +namespace tesseract { + +// Makes each time-step deeper by stacking inputs over its rectangle. Does not +// affect the size of its input. Achieves this by bringing in random values in +// out-of-input areas. +class Convolve : public Network { + public: + // The area of convolution is 2*half_x + 1 by 2*half_y + 1, forcing it to + // always be odd, so the center is the current pixel. + Convolve(const STRING& name, int ni, int half_x, int half_y); + virtual ~Convolve(); + + virtual STRING spec() const { + STRING spec; + spec.add_str_int("C", half_x_ * 2 + 1); + spec.add_str_int(",", half_y_ * 2 + 1); + return spec; + } + + // Writes to the given file. Returns false in case of error. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + + protected: + // Serialized data. + inT32 half_x_; + inT32 half_y_; +}; + +} // namespace tesseract. 
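+// (For reference: a Convolve built with half_x = half_y = 1 prints as "C3,3" +// in network-spec strings via spec() above, and its output depth no_ is +// ni * 9.)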
+ + +#endif // TESSERACT_LSTM_CONVOLVE_H_ diff --git a/lstm/ctc.cpp b/lstm/ctc.cpp new file mode 100644 index 0000000000..7a84108822 --- /dev/null +++ b/lstm/ctc.cpp @@ -0,0 +1,412 @@ +/////////////////////////////////////////////////////////////////////// +// File: ctc.cpp +// Description: Slightly improved standard CTC to compute the targets. +// Author: Ray Smith +// Created: Wed Jul 13 15:50:06 PDT 2016 +// +// (C) Copyright 2016, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// +#include "ctc.h" + +#include <memory> + +#include "genericvector.h" +#include "host.h" +#include "matrix.h" +#include "networkio.h" + +#include "network.h" +#include "scrollview.h" + +namespace tesseract { + +// Magic constants that keep CTC stable. +// Minimum probability limit for softmax input to ctc_loss. +const float CTC::kMinProb_ = 1e-12; +// Maximum absolute argument to exp(). +const double CTC::kMaxExpArg_ = 80.0; +// Minimum probability for total prob in time normalization. +const double CTC::kMinTotalTimeProb_ = 1e-8; +// Minimum probability for total prob in final normalization. +const double CTC::kMinTotalFinalProb_ = 1e-6; + +// Builds a target using CTC. Slightly improved as follows: +// Includes normalizations and clipping for stability. +// labels should be pre-padded with nulls everywhere. +// labels can be longer than the time sequence, but the total number of +// essential labels (non-null plus nulls between equal labels) must not exceed +// the number of timesteps in outputs. +// outputs is the output of the network, and should have already been +// normalized with NormalizeProbs. +// On return targets is filled with the computed targets. +// Returns false if there is insufficient time for the labels. +/* static */ +bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char, + const GENERIC_2D_ARRAY<float>& outputs, + NetworkIO* targets) { + std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs)); + if (!ctc->ComputeLabelLimits()) { + return false; // Not enough time. + } + // Generate simple targets purely from the truth labels by spreading them + // evenly over time. + GENERIC_2D_ARRAY<float> simple_targets; + ctc->ComputeSimpleTargets(&simple_targets); + // Add the simple targets as a starter bias to the network outputs. + float bias_fraction = ctc->CalculateBiasFraction(); + simple_targets *= bias_fraction; + ctc->outputs_ += simple_targets; + NormalizeProbs(&ctc->outputs_); + // Run regular CTC on the biased outputs. + // Run forward and backward. + GENERIC_2D_ARRAY<double> log_alphas, log_betas; + ctc->Forward(&log_alphas); + ctc->Backward(&log_betas); + // Normalize and come out of log space with a clipped softmax over time.
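+ // (As in standard CTC, log_alphas(t, u) + log_betas(t, u) is the log + // probability of all complete paths through label u at time t; the + // normalization below turns that into a per-label distribution over time.)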
+ log_alphas += log_betas; + ctc->NormalizeSequence(&log_alphas); + ctc->LabelsToClasses(log_alphas, targets); + NormalizeProbs(targets); + return true; +} + +CTC::CTC(const GenericVector<int>& labels, int null_char, + const GENERIC_2D_ARRAY<float>& outputs) + : labels_(labels), outputs_(outputs), null_char_(null_char) { + num_timesteps_ = outputs.dim1(); + num_classes_ = outputs.dim2(); + num_labels_ = labels_.size(); +} + +// Computes vectors of min and max label index for each timestep, based on +// whether skippability of nulls makes it possible to complete a valid path. +bool CTC::ComputeLabelLimits() { + min_labels_.init_to_size(num_timesteps_, 0); + max_labels_.init_to_size(num_timesteps_, 0); + int min_u = num_labels_ - 1; + if (labels_[min_u] == null_char_) --min_u; + for (int t = num_timesteps_ - 1; t >= 0; --t) { + min_labels_[t] = min_u; + if (min_u > 0) { + --min_u; + if (labels_[min_u] == null_char_ && min_u > 0 && + labels_[min_u + 1] != labels_[min_u - 1]) { + --min_u; + } + } + } + int max_u = labels_[0] == null_char_; + for (int t = 0; t < num_timesteps_; ++t) { + max_labels_[t] = max_u; + if (max_labels_[t] < min_labels_[t]) return false; // Not enough room. + if (max_u + 1 < num_labels_) { + ++max_u; + if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ && + labels_[max_u + 1] != labels_[max_u - 1]) { + ++max_u; + } + } + } + return true; +} + +// Computes targets based purely on the labels by spreading the labels evenly +// over the available timesteps. +void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const { + // Initialize all targets to zero. + targets->Resize(num_timesteps_, num_classes_, 0.0f); + GenericVector<float> half_widths; + GenericVector<int> means; + ComputeWidthsAndMeans(&half_widths, &means); + for (int l = 0; l < num_labels_; ++l) { + int label = labels_[l]; + float left_half_width = half_widths[l]; + float right_half_width = left_half_width; + int mean = means[l]; + if (label == null_char_) { + if (!NeededNull(l)) { + if ((l > 0 && mean == means[l - 1]) || + (l + 1 < num_labels_ && mean == means[l + 1])) { + continue; // Drop overlapping null. + } + } + // Make sure that no space is left unoccupied and that non-nulls always + // peak at 1 by stretching nulls to meet their neighbors. + if (l > 0) left_half_width = mean - means[l - 1]; + if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean; + } + if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f); + for (int offset = 1; offset < left_half_width && mean >= offset; ++offset) { + float prob = 1.0f - offset / left_half_width; + if (mean - offset < num_timesteps_ && + prob > targets->get(mean - offset, label)) { + targets->put(mean - offset, label, prob); + } + } + for (int offset = 1; + offset < right_half_width && mean + offset < num_timesteps_; + ++offset) { + float prob = 1.0f - offset / right_half_width; + if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) { + targets->put(mean + offset, label, prob); + } + } + } +} + +// Computes mean positions and half widths of the simple targets by spreading +// the labels evenly over the available timesteps. +void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths, + GenericVector<int>* means) const { + // Count the number of labels of each type, in regexp terms, counts plus + // (non-null or necessary null, which must occur at least once) and star + // (optional null).
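+ // (Illustration, assuming NeededNull is true only for a null between two + // equal labels: the truth "ab", pre-padded as [null, a, null, b, null], + // gives num_plus = 2 and num_star = 3.)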
+ int num_plus = 0, num_star = 0; + for (int i = 0; i < num_labels_; ++i) { + if (labels_[i] != null_char_ || NeededNull(i)) + ++num_plus; + else + ++num_star; + } + // Compute the size for each type. If there is enough space for everything + // to have size>=1, then all are equal, otherwise plus_size=1 and star gets + // whatever is left-over. + float plus_size = 1.0f, star_size = 0.0f; + float total_floating = num_plus + num_star; + if (total_floating <= num_timesteps_) { + plus_size = star_size = num_timesteps_ / total_floating; + } else if (num_star > 0) { + star_size = static_cast<float>(num_timesteps_ - num_plus) / num_star; + } + // Set the width and compute the mean of each. + float mean_pos = 0.0f; + for (int i = 0; i < num_labels_; ++i) { + float half_width; + if (labels_[i] != null_char_ || NeededNull(i)) { + half_width = plus_size / 2.0f; + } else { + half_width = star_size / 2.0f; + } + mean_pos += half_width; + means->push_back(static_cast<int>(mean_pos)); + mean_pos += half_width; + half_widths->push_back(half_width); + } +} + +// Helper returns the index of the highest probability label at timestep t. +static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) { + int result = 0; + int num_classes = outputs.dim2(); + const float* outputs_t = outputs[t]; + for (int c = 1; c < num_classes; ++c) { + if (outputs_t[c] > outputs_t[result]) result = c; + } + return result; +} + +// Calculates and returns a suitable fraction of the simple targets to add +// to the network outputs. +float CTC::CalculateBiasFraction() { + // Compute output labels via basic decoding. + GenericVector<int> output_labels; + for (int t = 0; t < num_timesteps_; ++t) { + int label = BestLabel(outputs_, t); + while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t; + if (label != null_char_) output_labels.push_back(label); + } + // Simple bag of labels error calculation. + GenericVector<int> truth_counts(num_classes_, 0); + GenericVector<int> output_counts(num_classes_, 0); + for (int l = 0; l < num_labels_; ++l) { + ++truth_counts[labels_[l]]; + } + for (int l = 0; l < output_labels.size(); ++l) { + ++output_counts[output_labels[l]]; + } + // Count the number of true and false positive non-nulls and truth labels. + int true_pos = 0, false_pos = 0, total_labels = 0; + for (int c = 0; c < num_classes_; ++c) { + if (c == null_char_) continue; + int truth_count = truth_counts[c]; + int ocr_count = output_counts[c]; + if (truth_count > 0) { + total_labels += truth_count; + if (ocr_count > truth_count) { + true_pos += truth_count; + false_pos += ocr_count - truth_count; + } else { + true_pos += ocr_count; + } + } + // We don't need to count classes that don't exist in the truth as + // false positives, because they don't affect CTC at all. + } + if (total_labels == 0) return 0.0f; + return exp(MAX(true_pos - false_pos, 1) * log(kMinProb_) / total_labels); +} + +// Given ln(x) and ln(y), returns ln(x + y), using: +// ln(x + y) = ln(x) + ln(1 + exp(ln(y) - ln(x))), ensuring that ln(x) is the +// bigger number to maximize precision. +static double LogSumExp(double ln_x, double ln_y) { + if (ln_x >= ln_y) { + return ln_x + log1p(exp(ln_y - ln_x)); + } else { + return ln_y + log1p(exp(ln_x - ln_y)); + } +} + +// Runs the forward CTC pass, filling in log_probs.
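The LogSumExp helper above is what keeps the forward/backward passes stable: by factoring out the larger logarithm, the remaining exponent is always <= 0, so nothing overflows no matter how small the path probabilities get. A self-contained check of the identity (not part of the patch):

#include <assert.h>
#include <math.h>

// ln(x + y) = ln_x + log1p(exp(ln_y - ln_x)), with ln_x the larger argument.
static double LogSumExpRef(double ln_x, double ln_y) {
  if (ln_x < ln_y) { double tmp = ln_x; ln_x = ln_y; ln_y = tmp; }
  return ln_x + log1p(exp(ln_y - ln_x));
}

int main() {
  // ln(e^2 + e^1) computed directly agrees with the identity.
  double direct = log(exp(2.0) + exp(1.0));
  assert(fabs(LogSumExpRef(2.0, 1.0) - direct) < 1e-12);
  // Extreme spread: a naive exp(1000) would overflow; the identity does not.
  assert(fabs(LogSumExpRef(1000.0, -1000.0) - 1000.0) < 1e-12);
  return 0;
}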
+void CTC::Forward(GENERIC_2D_ARRAY* log_probs) const { + log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32); + log_probs->put(0, 0, log(outputs_(0, labels_[0]))); + if (labels_[0] == null_char_) + log_probs->put(0, 1, log(outputs_(0, labels_[1]))); + for (int t = 1; t < num_timesteps_; ++t) { + const float* outputs_t = outputs_[t]; + for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) { + // Continuing the same label. + double log_sum = log_probs->get(t - 1, u); + // Change from previous label. + if (u > 0) { + log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1)); + } + // Skip the null if allowed. + if (u >= 2 && labels_[u - 1] == null_char_ && + labels_[u] != labels_[u - 2]) { + log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2)); + } + // Add in the log prob of the current label. + double label_prob = outputs_t[labels_[u]]; + log_sum += log(label_prob); + log_probs->put(t, u, log_sum); + } + } +} + +// Runs the backward CTC pass, filling in log_probs. +void CTC::Backward(GENERIC_2D_ARRAY* log_probs) const { + log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32); + log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0); + if (labels_[num_labels_ - 1] == null_char_) + log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0); + for (int t = num_timesteps_ - 2; t >= 0; --t) { + const float* outputs_tp1 = outputs_[t + 1]; + for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) { + // Continuing the same label. + double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]); + // Change from previous label. + if (u + 1 < num_labels_) { + double prev_prob = outputs_tp1[labels_[u + 1]]; + log_sum = + LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob)); + } + // Skip the null if allowed. + if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ && + labels_[u] != labels_[u + 2]) { + double skip_prob = outputs_tp1[labels_[u + 2]]; + log_sum = + LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob)); + } + log_probs->put(t, u, log_sum); + } + } +} + +// Normalizes and brings probs out of log space with a softmax over time. +void CTC::NormalizeSequence(GENERIC_2D_ARRAY* probs) const { + double max_logprob = probs->Max(); + for (int u = 0; u < num_labels_; ++u) { + double total = 0.0; + for (int t = 0; t < num_timesteps_; ++t) { + // Separate impossible path from unlikely probs. + double prob = probs->get(t, u); + if (prob > -MAX_FLOAT32) + prob = ClippedExp(prob - max_logprob); + else + prob = 0.0; + total += prob; + probs->put(t, u, prob); + } + // Note that although this is a probability distribution over time and + // therefore should sum to 1, it is important to allow some labels to be + // all zero, (or at least tiny) as it is necessary to skip some blanks. + if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_; + for (int t = 0; t < num_timesteps_; ++t) + probs->put(t, u, probs->get(t, u) / total); + } +} + +// For each timestep computes the max prob for each class over all +// instances of the class in the labels_, and sets the targets to +// the max observed prob. +void CTC::LabelsToClasses(const GENERIC_2D_ARRAY& probs, + NetworkIO* targets) const { + // For each timestep compute the max prob for each class over all + // instances of the class in the labels_. 
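For reference, with p_t(c) the normalized network output for class c at time t, l the null-padded label string, ∅ the null, and [.] the Iverson bracket, Forward and Backward above implement the standard Graves-style recurrences (the code keeps everything in log space and combines terms with LogSumExp):

\begin{aligned}
\alpha_t(u) &= p_t(l_u)\bigl(\alpha_{t-1}(u) + \alpha_{t-1}(u{-}1)
              + [\,l_{u-1}{=}\varnothing \wedge l_u \neq l_{u-2}\,]\,\alpha_{t-1}(u{-}2)\bigr) \\
\beta_t(u)  &= \beta_{t+1}(u)\,p_{t+1}(l_u) + \beta_{t+1}(u{+}1)\,p_{t+1}(l_{u+1})
              + [\,l_{u+1}{=}\varnothing \wedge l_u \neq l_{u+2}\,]\,\beta_{t+1}(u{+}2)\,p_{t+1}(l_{u+2})
\end{aligned}

The log_alphas += log_betas line in ComputeCTCTargets then forms the unnormalized posterior occupancy \gamma_t(u) \propto \alpha_t(u)\,\beta_t(u), which NormalizeSequence converts out of log space with a clipped softmax over t for each u.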
+ GenericVector class_probs; + for (int t = 0; t < num_timesteps_; ++t) { + float* targets_t = targets->f(t); + class_probs.init_to_size(num_classes_, 0.0); + for (int u = 0; u < num_labels_; ++u) { + double prob = probs(t, u); + // Note that although Graves specifies sum over all labels of the same + // class, we need to allow skipped blanks to go to zero, so they don't + // interfere with the non-blanks, so max is better than sum. + if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob; + // class_probs[labels_[u]] += prob; + } + int best_class = 0; + for (int c = 0; c < num_classes_; ++c) { + targets_t[c] = class_probs[c]; + if (class_probs[c] > class_probs[best_class]) best_class = c; + } + } +} + +// Normalizes the probabilities such that no target has a prob below min_prob, +// and, provided that the initial total is at least min_total_prob, then all +// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output +// probability is thus 1 - (num_classes-1)*min_prob. +/* static */ +void CTC::NormalizeProbs(GENERIC_2D_ARRAY* probs) { + int num_timesteps = probs->dim1(); + int num_classes = probs->dim2(); + for (int t = 0; t < num_timesteps; ++t) { + float* probs_t = (*probs)[t]; + // Compute the total and clip that to prevent amplification of noise. + double total = 0.0; + for (int c = 0; c < num_classes; ++c) total += probs_t[c]; + if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_; + // Compute the increased total as a result of clipping. + double increment = 0.0; + for (int c = 0; c < num_classes; ++c) { + double prob = probs_t[c] / total; + if (prob < kMinProb_) increment += kMinProb_ - prob; + } + // Now normalize with clipping. Any additional clipping is negligible. + total += increment; + for (int c = 0; c < num_classes; ++c) { + float prob = probs_t[c] / total; + probs_t[c] = MAX(prob, kMinProb_); + } + } +} + +// Returns true if the label at index is a needed null. +bool CTC::NeededNull(int index) const { + return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ && + labels_[index + 1] == labels_[index - 1]; +} + +} // namespace tesseract diff --git a/lstm/ctc.h b/lstm/ctc.h new file mode 100644 index 0000000000..47fba67479 --- /dev/null +++ b/lstm/ctc.h @@ -0,0 +1,130 @@ +/////////////////////////////////////////////////////////////////////// +// File: ctc.h +// Description: Slightly improved standard CTC to compute the targets. +// Author: Ray Smith +// Created: Wed Jul 13 15:17:06 PDT 2016 +// +// (C) Copyright 2016, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_CTC_H_ +#define TESSERACT_LSTM_CTC_H_ + +#include "genericvector.h" +#include "network.h" +#include "networkio.h" +#include "scrollview.h" + +namespace tesseract { + +// Class to encapsulate CTC and simple target generation. 
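The clipped renormalization in NormalizeProbs above guarantees a floor of min_prob per class, so the maximum possible output is 1 - (num_classes-1)*min_prob. A numeric sketch with an exaggerated floor (the real kMinProb_ is 1e-12, far too small to see), illustrative only:

#include <stdio.h>

int main() {
  const double kMinProb = 0.05;  // Hypothetical, enlarged for visibility.
  double probs[3] = {1.0, 0.0, 0.0};
  double total = 1.0;
  // Measure how much lifting sub-threshold probs to the floor would add.
  double increment = 0.0;
  for (double p : probs)
    if (p / total < kMinProb) increment += kMinProb - p / total;
  total += increment;  // total = 1.1
  for (double& p : probs) {
    p = p / total;
    if (p < kMinProb) p = kMinProb;
  }
  // Prints 0.909 0.050 0.050: the peak is capped near 1 - 2*0.05, and the
  // small residual in the sum is the "negligible" extra clipping the comment
  // above refers to (exaggerated here by the large floor).
  printf("%.3f %.3f %.3f\n", probs[0], probs[1], probs[2]);
  return 0;
}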
+class CTC { + public: + // Normalizes the probabilities such that no target has a prob below min_prob, + // and, provided that the initial total is at least min_total_prob, then all + // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output + // probability is thus 1 - (num_classes-1)*min_prob. + static void NormalizeProbs(NetworkIO* probs) { + NormalizeProbs(probs->mutable_float_array()); + } + + // Builds a target using CTC. Slightly improved as follows: + // Includes normalizations and clipping for stability. + // labels should be pre-padded with nulls wherever desired, but they don't + // have to be between all labels. Allows for multi-label codes with no + // nulls between. + // labels can be longer than the time sequence, but the total number of + // essential labels (non-null plus nulls between equal labels) must not exceed + // the number of timesteps in outputs. + // outputs is the output of the network, and should have already been + // normalized with NormalizeProbs. + // On return targets is filled with the computed targets. + // Returns false if there is insufficient time for the labels. + static bool ComputeCTCTargets(const GenericVector<int>& truth_labels, + int null_char, + const GENERIC_2D_ARRAY<float>& outputs, + NetworkIO* targets); + + private: + // Constructor is private as the instance only holds information specific to + // the current labels, outputs etc, and is built by the static function. + CTC(const GenericVector<int>& labels, int null_char, + const GENERIC_2D_ARRAY<float>& outputs); + + // Computes vectors of min and max label index for each timestep, based on + // whether skippability of nulls makes it possible to complete a valid path. + bool ComputeLabelLimits(); + // Computes targets based purely on the labels by spreading the labels evenly + // over the available timesteps. + void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const; + // Computes mean positions and half widths of the simple targets by spreading + // the labels evenly over the available timesteps. + void ComputeWidthsAndMeans(GenericVector<float>* half_widths, + GenericVector<int>* means) const; + // Calculates and returns a suitable fraction of the simple targets to add + // to the network outputs. + float CalculateBiasFraction(); + // Runs the forward CTC pass, filling in log_probs. + void Forward(GENERIC_2D_ARRAY<double>* log_probs) const; + // Runs the backward CTC pass, filling in log_probs. + void Backward(GENERIC_2D_ARRAY<double>* log_probs) const; + // Normalizes and brings probs out of log space with a softmax over time. + void NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const; + // For each timestep computes the max prob for each class over all + // instances of the class in the labels_, and sets the targets to + // the max observed prob. + void LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs, + NetworkIO* targets) const; + // Normalizes the probabilities such that no target has a prob below min_prob, + // and, provided that the initial total is at least min_total_prob, then all + // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output + // probability is thus 1 - (num_classes-1)*min_prob. + static void NormalizeProbs(GENERIC_2D_ARRAY<float>* probs); + // Returns true if the label at index is a needed null. + bool NeededNull(int index) const; + // Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/ + // underflow.
+ static double ClippedExp(double x) { + if (x < -kMaxExpArg_) return exp(-kMaxExpArg_); + if (x > kMaxExpArg_) return exp(kMaxExpArg_); + return exp(x); + } + + // Minimum probability limit for softmax input to ctc_loss. + static const float kMinProb_; + // Maximum absolute argument to exp(). + static const double kMaxExpArg_; + // Minimum probability for total prob in time normalization. + static const double kMinTotalTimeProb_; + // Minimum probability for total prob in final normalization. + static const double kMinTotalFinalProb_; + + // The truth label indices that are to be matched to outputs_. + const GenericVector& labels_; + // The network outputs. + GENERIC_2D_ARRAY outputs_; + // The null or "blank" label. + int null_char_; + // Number of timesteps in outputs_. + int num_timesteps_; + // Number of classes in outputs_. + int num_classes_; + // Number of labels in labels_. + int num_labels_; + // Min and max valid label indices for each timestep. + GenericVector min_labels_; + GenericVector max_labels_; +}; + +} // namespace tesseract + +#endif // TESSERACT_LSTM_CTC_H_ diff --git a/lstm/fullyconnected.cpp b/lstm/fullyconnected.cpp new file mode 100644 index 0000000000..77406b6208 --- /dev/null +++ b/lstm/fullyconnected.cpp @@ -0,0 +1,285 @@ +/////////////////////////////////////////////////////////////////////// +// File: fullyconnected.cpp +// Description: Simple feed-forward layer with various non-linearities. +// Author: Ray Smith +// Created: Wed Feb 26 14:49:15 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "fullyconnected.h" + +#ifdef _OPENMP +#include +#endif +#include +#include + +#include "functions.h" +#include "networkscratch.h" + +// Number of threads to use for parallel calculation of Forward and Backward. +const int kNumThreads = 4; + +namespace tesseract { + +FullyConnected::FullyConnected(const STRING& name, int ni, int no, + NetworkType type) + : Network(type, name, ni, no), external_source_(NULL), int_mode_(false) { +} + +FullyConnected::~FullyConnected() { +} + +// Returns the shape output from the network given an input shape (which may +// be partially unknown ie zero). +StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const { + LossType loss_type = LT_NONE; + if (type_ == NT_SOFTMAX) + loss_type = LT_CTC; + else if (type_ == NT_SOFTMAX_NO_CTC) + loss_type = LT_SOFTMAX; + else if (type_ == NT_LOGISTIC) + loss_type = LT_LOGISTIC; + StaticShape result(input_shape); + result.set_depth(no_); + result.set_loss_type(loss_type); + return result; +} + +// Sets up the network for training. Initializes weights using weights of +// scale `range` picked according to the random number generator `randomizer`. 
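A hedged sketch of the initialization the comment above describes: WeightMatrix::InitWeightsFloat is assumed here to fill the [no_ x (ni_ + 1)] matrix (the extra column is the bias) with uniform values in [-range, range]; the sketch substitutes std::mt19937 for Tesseract's TRand, so it illustrates the shape and distribution, not the exact random stream:

#include <random>
#include <vector>

// Illustrative stand-in for weights_.InitWeightsFloat(no, ni + 1, ..., range).
std::vector<std::vector<float>> InitUniformWeights(int no, int ni, float range,
                                                   std::mt19937* rng) {
  std::uniform_real_distribution<float> dist(-range, range);
  // no rows of ni inputs plus one bias weight each.
  std::vector<std::vector<float>> w(no, std::vector<float>(ni + 1));
  for (auto& row : w)
    for (auto& v : row) v = dist(*rng);
  return w;
}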
+int FullyConnected::InitWeights(float range, TRand* randomizer) { + Network::SetRandomizer(randomizer); + num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADA_GRAD), + range, randomizer); + return num_weights_; +} + +// Converts a float network to an int network. +void FullyConnected::ConvertToInt() { + weights_.ConvertToInt(); +} + +// Provides debug output on the weights. +void FullyConnected::DebugWeights() { + weights_.Debug2D(name_.string()); +} + +// Writes to the given file. Returns false in case of error. +bool FullyConnected::Serialize(TFile* fp) const { + if (!Network::Serialize(fp)) return false; + if (!weights_.Serialize(training_, fp)) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool FullyConnected::DeSerialize(bool swap, TFile* fp) { + if (!weights_.DeSerialize(training_, swap, fp)) return false; + return true; +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void FullyConnected::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + int width = input.Width(); + if (type_ == NT_SOFTMAX) + output->ResizeFloat(input, no_); + else + output->Resize(input, no_); + SetupForward(input, input_transpose); + GenericVector temp_lines; + temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec()); + GenericVector curr_input; + curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec()); + for (int i = 0; i < temp_lines.size(); ++i) { + temp_lines[i].Init(no_, scratch); + curr_input[i].Init(ni_, scratch); + } +#ifdef _OPENMP +#pragma omp parallel for num_threads(kNumThreads) + for (int t = 0; t < width; ++t) { + // Thread-local pointer to temporary storage. + int thread_id = omp_get_thread_num(); +#else + for (int t = 0; t < width; ++t) { + // Thread-local pointer to temporary storage. + int thread_id = 0; +#endif + double* temp_line = temp_lines[thread_id]; + const double* d_input = NULL; + const inT8* i_input = NULL; + if (input.int_mode()) { + i_input = input.i(t); + } else { + input.ReadTimeStep(t, curr_input[thread_id]); + d_input = curr_input[thread_id]; + } + ForwardTimeStep(d_input, i_input, t, temp_line); + output->WriteTimeStep(t, temp_line); + if (training() && type_ != NT_SOFTMAX) { + acts_.CopyTimeStepFrom(t, *output, t); + } + } + // Zero all the elements that are in the padding around images that allows + // multiple different-sized images to exist in a single array. + // acts_ is only used if this is not a softmax op. + if (training() && type_ != NT_SOFTMAX) { + acts_.ZeroInvalidElements(); + } + output->ZeroInvalidElements(); +#if DEBUG_DETAIL > 0 + tprintf("F Output:%s\n", name_.string()); + output->Print(10); +#endif + if (debug) DisplayForward(*output); +} + +// Components of Forward so FullyConnected can be reused inside LSTM. +void FullyConnected::SetupForward(const NetworkIO& input, + const TransposedArray* input_transpose) { + // Softmax output is always float, so save the input type. + int_mode_ = input.int_mode(); + if (training()) { + acts_.Resize(input, no_); + // Source_ is a transposed copy of input. It isn't needed if provided. 
+ external_source_ = input_transpose; + if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width()); + } +} + +void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input, + int t, double* output_line) { + // input is copied to source_ line-by-line for cache coherency. + if (training() && external_source_ == NULL && d_input != NULL) + source_t_.WriteStrided(t, d_input); + if (d_input != NULL) + weights_.MatrixDotVector(d_input, output_line); + else + weights_.MatrixDotVector(i_input, output_line); + if (type_ == NT_TANH) { + FuncInplace(no_, output_line); + } else if (type_ == NT_LOGISTIC) { + FuncInplace(no_, output_line); + } else if (type_ == NT_POSCLIP) { + FuncInplace(no_, output_line); + } else if (type_ == NT_SYMCLIP) { + FuncInplace(no_, output_line); + } else if (type_ == NT_RELU) { + FuncInplace(no_, output_line); + } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) { + SoftmaxInPlace(no_, output_line); + } else if (type_ != NT_LINEAR) { + ASSERT_HOST("Invalid fully-connected type!" == NULL); + } +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + if (debug) DisplayBackward(fwd_deltas); + back_deltas->Resize(fwd_deltas, ni_); + GenericVector errors; + errors.init_to_size(kNumThreads, NetworkScratch::FloatVec()); + for (int i = 0; i < errors.size(); ++i) errors[i].Init(no_, scratch); + GenericVector temp_backprops; + if (needs_to_backprop_) { + temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec()); + for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch); + } + int width = fwd_deltas.Width(); + NetworkScratch::GradientStore errors_t; + errors_t.Init(no_, width, scratch); +#ifdef _OPENMP +#pragma omp parallel for num_threads(kNumThreads) + for (int t = 0; t < width; ++t) { + int thread_id = omp_get_thread_num(); +#else + for (int t = 0; t < width; ++t) { + int thread_id = 0; +#endif + double* backprop = NULL; + if (needs_to_backprop_) backprop = temp_backprops[thread_id]; + double* curr_errors = errors[thread_id]; + BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop); + if (backprop != NULL) { + back_deltas->WriteTimeStep(t, backprop); + } + } + FinishBackward(*errors_t.get()); + if (needs_to_backprop_) { + back_deltas->ZeroInvalidElements(); + back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas); +#if DEBUG_DETAIL > 0 + tprintf("F Backprop:%s\n", name_.string()); + back_deltas->Print(10); +#endif + return true; + } + return false; // No point going further back. +} + +void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t, + double* curr_errors, + TransposedArray* errors_t, + double* backprop) { + if (type_ == NT_TANH) + acts_.FuncMultiply(fwd_deltas, t, curr_errors); + else if (type_ == NT_LOGISTIC) + acts_.FuncMultiply(fwd_deltas, t, curr_errors); + else if (type_ == NT_POSCLIP) + acts_.FuncMultiply(fwd_deltas, t, curr_errors); + else if (type_ == NT_SYMCLIP) + acts_.FuncMultiply(fwd_deltas, t, curr_errors); + else if (type_ == NT_RELU) + acts_.FuncMultiply(fwd_deltas, t, curr_errors); + else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC || + type_ == NT_LINEAR) + fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors. + else + ASSERT_HOST("Invalid fully-connected type!" 
== NULL); + // Generate backprop only if needed by the lower layer. + if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop); + errors_t->WriteStrided(t, curr_errors); +} + +void FullyConnected::FinishBackward(const TransposedArray& errors_t) { + if (external_source_ == NULL) + weights_.SumOuterTransposed(errors_t, source_t_, true); + else + weights_.SumOuterTransposed(errors_t, *external_source_, true); +} + +// Updates the weights using the given learning rate and momentum. +// num_samples is the quotient to be used in the adagrad computation iff +// use_ada_grad_ is true. +void FullyConnected::Update(float learning_rate, float momentum, + int num_samples) { + weights_.Update(learning_rate, momentum, num_samples); +} + +// Sums the products of weight updates in *this and other, splitting into +// positive (same direction) in *same and negative (different direction) in +// *changed. +void FullyConnected::CountAlternators(const Network& other, double* same, + double* changed) const { + ASSERT_HOST(other.type() == type_); + const FullyConnected* fc = reinterpret_cast(&other); + weights_.CountAlternators(fc->weights_, same, changed); +} + +} // namespace tesseract. diff --git a/lstm/fullyconnected.h b/lstm/fullyconnected.h new file mode 100644 index 0000000000..d2d2b73ae8 --- /dev/null +++ b/lstm/fullyconnected.h @@ -0,0 +1,130 @@ +/////////////////////////////////////////////////////////////////////// +// File: fullyconnected.h +// Description: Simple feed-forward layer with various non-linearities. +// Author: Ray Smith +// Created: Wed Feb 26 14:46:06 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_FULLYCONNECTED_H_ +#define TESSERACT_LSTM_FULLYCONNECTED_H_ + +#include "network.h" +#include "networkscratch.h" + +namespace tesseract { + +// C++ Implementation of the Softmax (output) class from lstm.py. +class FullyConnected : public Network { + public: + FullyConnected(const STRING& name, int ni, int no, NetworkType type); + virtual ~FullyConnected(); + + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const; + + virtual STRING spec() const { + STRING spec; + if (type_ == NT_TANH) + spec.add_str_int("Ft", no_); + else if (type_ == NT_LOGISTIC) + spec.add_str_int("Fs", no_); + else if (type_ == NT_RELU) + spec.add_str_int("Fr", no_); + else if (type_ == NT_LINEAR) + spec.add_str_int("Fl", no_); + else if (type_ == NT_POSCLIP) + spec.add_str_int("Fp", no_); + else if (type_ == NT_SYMCLIP) + spec.add_str_int("Fs", no_); + else if (type_ == NT_SOFTMAX) + spec.add_str_int("Fc", no_); + else + spec.add_str_int("Fm", no_); + return spec; + } + + // Changes the type to the given type. Used to commute a softmax to a + // non-output type for adding on other networks. 
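Both Forward and Backward in fullyconnected.cpp above follow the same OpenMP pattern: scratch buffers are allocated once per thread before the loop, and each iteration selects its buffer by omp_get_thread_num(), so timesteps run in parallel without locks and without allocating in the hot loop. A minimal self-contained sketch of the pattern (illustrative, not from the patch):

#ifdef _OPENMP
#include <omp.h>
#endif
#include <vector>

void ParallelOverTimesteps(int width, int num_threads) {
  // One scratch buffer per thread, allocated outside the hot loop.
  std::vector<std::vector<double>> scratch(num_threads,
                                           std::vector<double>(16));
#ifdef _OPENMP
#pragma omp parallel for num_threads(num_threads)
#endif
  for (int t = 0; t < width; ++t) {
#ifdef _OPENMP
    const int thread_id = omp_get_thread_num();
#else
    const int thread_id = 0;
#endif
    // Each thread writes only to its own buffer: no contention, no locks.
    scratch[thread_id][0] = t;
  }
}

Note how the original code duplicates the for-statement inside the #ifdef so that the pragma can sit immediately before it when OpenMP is enabled, while the non-OpenMP build compiles the identical loop body with thread_id fixed at 0.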
+ void ChangeType(NetworkType type) { + type_ = type; + } + + // Sets up the network for training. Initializes weights using weights of + // scale `range` picked according to the random number generator `randomizer`. + virtual int InitWeights(float range, TRand* randomizer); + + // Converts a float network to an int network. + virtual void ConvertToInt(); + + // Provides debug output on the weights. + virtual void DebugWeights(); + + // Writes to the given file. Returns false in case of error. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + // Components of Forward so FullyConnected can be reused inside LSTM. + void SetupForward(const NetworkIO& input, + const TransposedArray* input_transpose); + void ForwardTimeStep(const double* d_input, const inT8* i_input, int t, + double* output_line); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + // Components of Backward so FullyConnected can be reused inside LSTM. + void BackwardTimeStep(const NetworkIO& fwd_deltas, int t, double* curr_errors, + TransposedArray* errors_t, double* backprop); + void FinishBackward(const TransposedArray& errors_t); + + // Updates the weights using the given learning rate and momentum. + // num_samples is the quotient to be used in the adagrad computation iff + // use_ada_grad_ is true. + virtual void Update(float learning_rate, float momentum, int num_samples); + // Sums the products of weight updates in *this and other, splitting into + // positive (same direction) in *same and negative (different direction) in + // *changed. + virtual void CountAlternators(const Network& other, double* same, + double* changed) const; + + protected: + // Weight arrays of size [no, ni + 1]. + WeightMatrix weights_; + // Transposed copy of input used during training of size [ni, width]. + TransposedArray source_t_; + // Pointer to transposed input stored elsewhere. If not null, this is used + // in preference to calculating the transpose and storing it in source_t_. + const TransposedArray* external_source_; + // Activations from forward pass of size [width, no]. + NetworkIO acts_; + // Memory of the integer mode input to forward as softmax always outputs + // float, so the information is otherwise lost. + bool int_mode_; +}; + +} // namespace tesseract. + + + +#endif // TESSERACT_LSTM_FULLYCONNECTED_H_ diff --git a/lstm/functions.cpp b/lstm/functions.cpp new file mode 100644 index 0000000000..644530c340 --- /dev/null +++ b/lstm/functions.cpp @@ -0,0 +1,26 @@ +/////////////////////////////////////////////////////////////////////// +// File: functions.cpp +// Description: Static initialize-on-first-use non-linearity functions. +// Author: Ray Smith +// Created: Tue Jul 17 14:02:59 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "functions.h" + +namespace tesseract { + +double TanhTable[kTableSize]; +double LogisticTable[kTableSize]; + +} // namespace tesseract. diff --git a/lstm/functions.h b/lstm/functions.h new file mode 100644 index 0000000000..d633e6bf7a --- /dev/null +++ b/lstm/functions.h @@ -0,0 +1,249 @@ +/////////////////////////////////////////////////////////////////////// +// File: functions.h +// Description: Collection of function-objects used by the network layers. +// Author: Ray Smith +// Created: Fri Jun 20 10:45:37 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_FUNCTIONS_H_ +#define TESSERACT_LSTM_FUNCTIONS_H_ + +#include +#include "helpers.h" +#include "tprintf.h" + +// Setting this to 1 or more causes massive dumps of debug data: weights, +// updates, internal calculations etc, and reduces the number of test iterations +// to a small number, so outputs can be diffed. +#define DEBUG_DETAIL 0 +#if DEBUG_DETAIL > 0 +#undef _OPENMP // Disable open mp to get the outputs in sync. +#endif + +namespace tesseract { + +// Size of static tables. +const int kTableSize = 4096; +// Scale factor for float arg to int index. +const double kScaleFactor = 256.0; + +extern double TanhTable[]; +extern double LogisticTable[]; + +// Non-linearity (sigmoid) functions with cache tables and clipping. +inline double Tanh(double x) { + if (x < 0.0) return -Tanh(-x); + if (x >= (kTableSize - 1) / kScaleFactor) return 1.0; + x *= kScaleFactor; + int index = static_cast(floor(x)); + if (TanhTable[index] == 0.0 && index > 0) { + // Generate the entry. + TanhTable[index] = tanh(index / kScaleFactor); + } + if (index == kTableSize - 1) return TanhTable[kTableSize - 1]; + if (TanhTable[index + 1] == 0.0) { + // Generate the entry. + TanhTable[index + 1] = tanh((index + 1) / kScaleFactor); + } + double offset = x - index; + return TanhTable[index] * (1.0 - offset) + TanhTable[index + 1] * offset; +} + +inline double Logistic(double x) { + if (x < 0.0) return 1.0 - Logistic(-x); + if (x >= (kTableSize - 1) / kScaleFactor) return 1.0; + x *= kScaleFactor; + int index = static_cast(floor(x)); + if (LogisticTable[index] == 0.0) { + // Generate the entry. + LogisticTable[index] = 1.0 / (1.0 + exp(-index / kScaleFactor)); + } + if (index == kTableSize - 1) return LogisticTable[kTableSize - 1]; + if (LogisticTable[index + 1] == 0.0) { + // Generate the entry. 
+ LogisticTable[index + 1] = 1.0 / (1.0 + exp(-(index + 1) / kScaleFactor)); + } + double offset = x - index; + return LogisticTable[index] * (1.0 - offset) + + LogisticTable[index + 1] * offset; +} + +// Non-linearity (sigmoid) functions and their derivatives. +struct FFunc { + inline double operator()(double x) const { return Logistic(x); } +}; +struct FPrime { + inline double operator()(double y) const { return y * (1.0 - y); } +}; +struct ClipFFunc { + inline double operator()(double x) const { + if (x <= 0.0) return 0.0; + if (x >= 1.0) return 1.0; + return x; + } +}; +struct ClipFPrime { + inline double operator()(double y) const { + return 0.0 < y && y < 1.0 ? 1.0 : 0.0; + } +}; +struct Relu { + inline double operator()(double x) const { + if (x <= 0.0) return 0.0; + return x; + } +}; +struct ReluPrime { + inline double operator()(double y) const { return 0.0 < y ? 1.0 : 0.0; } +}; +struct GFunc { + inline double operator()(double x) const { return Tanh(x); } +}; +struct GPrime { + inline double operator()(double y) const { return 1.0 - y * y; } +}; +struct ClipGFunc { + inline double operator()(double x) const { + if (x <= -1.0) return -1.0; + if (x >= 1.0) return 1.0; + return x; + } +}; +struct ClipGPrime { + inline double operator()(double y) const { + return -1.0 < y && y < 1.0 ? 1.0 : 0.0; + } +}; +struct HFunc { + inline double operator()(double x) const { return Tanh(x); } +}; +struct HPrime { + inline double operator()(double y) const { + double u = Tanh(y); + return 1.0 - u * u; + } +}; +struct UnityFunc { + inline double operator()(double x) const { return 1.0; } +}; +struct IdentityFunc { + inline double operator()(double x) const { return x; } +}; + +// Applies Func in-place to inout, of size n. +template +inline void FuncInplace(int n, double* inout) { + Func f; + for (int i = 0; i < n; ++i) { + inout[i] = f(inout[i]); + } +} +// Applies Func to u and multiplies the result by v component-wise, +// putting the product in out, all of size n. +template +inline void FuncMultiply(const double* u, const double* v, int n, double* out) { + Func f; + for (int i = 0; i < n; ++i) { + out[i] = f(u[i]) * v[i]; + } +} +// Applies the Softmax function in-place to inout, of size n. +template +inline void SoftmaxInPlace(int n, T* inout) { + if (n <= 0) return; + // A limit on the negative range input to exp to guarantee non-zero output. + const T kMaxSoftmaxActivation = 86.0f; + + T max_output = inout[0]; + for (int i = 1; i < n; i++) { + T output = inout[i]; + if (output > max_output) max_output = output; + } + T prob_total = 0.0; + for (int i = 0; i < n; i++) { + T prob = inout[i] - max_output; + prob = exp(ClipToRange(prob, -kMaxSoftmaxActivation, static_cast(0))); + prob_total += prob; + inout[i] = prob; + } + if (prob_total > 0.0) { + for (int i = 0; i < n; i++) inout[i] /= prob_total; + } +} + +// Copies n values of the given src vector to dest. +inline void CopyVector(int n, const double* src, double* dest) { + memcpy(dest, src, n * sizeof(dest[0])); +} + +// Adds n values of the given src vector to dest. +inline void AccumulateVector(int n, const double* src, double* dest) { + for (int i = 0; i < n; ++i) dest[i] += src[i]; +} + +// Multiplies n values of inout in-place element-wise by the given src vector. +inline void MultiplyVectorsInPlace(int n, const double* src, double* inout) { + for (int i = 0; i < n; ++i) inout[i] *= src[i]; +} + +// Multiplies n values of u by v, element-wise, accumulating to out. 
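One detail worth calling out in SoftmaxInPlace above: subtracting the max before exponentiating leaves every exponent <= 0, so the result is identical mathematically but immune to overflow, and the clip at -kMaxSoftmaxActivation keeps every exp() strictly positive in float. A numeric illustration (not part of the patch):

#include <math.h>
#include <stdio.h>

// softmax(1000, 999) == softmax(0, -1), but expf(1000.0f) overflows to inf.
int main() {
  float v[2] = {1000.0f, 999.0f};
  float max_v = v[0] > v[1] ? v[0] : v[1];
  float total = 0.0f;
  for (int i = 0; i < 2; ++i) {
    v[i] = expf(v[i] - max_v);  // Exponents are now <= 0: no overflow.
    total += v[i];
  }
  for (int i = 0; i < 2; ++i) v[i] /= total;
  printf("%.3f %.3f\n", v[0], v[1]);  // Prints 0.731 0.269.
  return 0;
}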
+inline void MultiplyAccumulate(int n, const double* u, const double* v, + double* out) { + for (int i = 0; i < n; i++) { + out[i] += u[i] * v[i]; + } +} + +// Sums the given 5 n-vectors putting the result into sum. +inline void SumVectors(int n, const double* v1, const double* v2, + const double* v3, const double* v4, const double* v5, + double* sum) { + for (int i = 0; i < n; ++i) { + sum[i] = v1[i] + v2[i] + v3[i] + v4[i] + v5[i]; + } +} + +// Sets the given n-vector vec to 0. +template +inline void ZeroVector(int n, T* vec) { + memset(vec, 0, n * sizeof(*vec)); +} + +// Clips the given vector vec, of size n to [lower, upper]. +template +inline void ClipVector(int n, T lower, T upper, T* vec) { + for (int i = 0; i < n; ++i) vec[i] = ClipToRange(vec[i], lower, upper); +} + +// Converts the given n-vector to a binary encoding of the maximum value, +// encoded as vector of nf binary values. +inline void CodeInBinary(int n, int nf, double* vec) { + if (nf <= 0 || n < nf) return; + int index = 0; + double best_score = vec[0]; + for (int i = 1; i < n; ++i) { + if (vec[i] > best_score) { + best_score = vec[i]; + index = i; + } + } + int mask = 1; + for (int i = 0; i < nf; ++i, mask *= 2) { + vec[i] = (index & mask) ? 1.0 : 0.0; + } +} + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_FUNCTIONS_H_ diff --git a/lstm/input.cpp b/lstm/input.cpp new file mode 100644 index 0000000000..c0f61781ea --- /dev/null +++ b/lstm/input.cpp @@ -0,0 +1,154 @@ +/////////////////////////////////////////////////////////////////////// +// File: input.cpp +// Description: Input layer class for neural network implementations. +// Author: Ray Smith +// Created: Thu Mar 13 09:10:34 PDT 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "input.h" + +#include "allheaders.h" +#include "imagedata.h" +#include "pageres.h" +#include "scrollview.h" + +namespace tesseract { + +Input::Input(const STRING& name, int ni, int no) + : Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {} +Input::Input(const STRING& name, const StaticShape& shape) + : Network(NT_INPUT, name, shape.height(), shape.depth()), + shape_(shape), + cached_x_scale_(1) { + if (shape.height() == 1) ni_ = shape.depth(); +} + +Input::~Input() { +} + +// Writes to the given file. Returns false in case of error. +bool Input::Serialize(TFile* fp) const { + if (!Network::Serialize(fp)) return false; + if (fp->FWrite(&shape_, sizeof(shape_), 1) != 1) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool Input::DeSerialize(bool swap, TFile* fp) { + if (fp->FRead(&shape_, sizeof(shape_), 1) != 1) return false; + // TODO(rays) swaps! + return true; +} + +// Returns an integer reduction factor that the network applies to the +// time sequence. Assumes that any 2-d is already eliminated. 
Used for +// scaling bounding boxes of truth data. +int Input::XScaleFactor() const { + return 1; +} + +// Provides the (minimum) x scale factor to the network (of interest only to +// input units) so they can determine how to scale bounding boxes. +void Input::CacheXScaleFactor(int factor) { + cached_x_scale_ = factor; +} + +// Runs forward propagation of activations on the input line. +// See Network for a detailed discussion of the arguments. +void Input::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + *output = input; +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool Input::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + tprintf("Input::Backward should not be called!!\n"); + return false; +} + +// Creates and returns a Pix of appropriate size for the network from the +// image_data. If non-null, *image_scale returns the image scale factor used. +// Returns nullptr on error. +/* static */ +Pix* Input::PrepareLSTMInputs(const ImageData& image_data, + const Network* network, int min_width, + TRand* randomizer, float* image_scale) { + // Note that NumInputs() is defined as input image height. + int target_height = network->NumInputs(); + int width, height; + Pix* pix = + image_data.PreScale(target_height, image_scale, &width, &height, nullptr); + if (pix == nullptr) { + tprintf("Bad pix from ImageData!\n"); + return nullptr; + } + if (width <= min_width) { + tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width, + height, min_width); + pixDestroy(&pix); + return nullptr; + } + return pix; +} + +// Converts the given pix to a NetworkIO of height and depth appropriate to the +// given StaticShape: +// If depth == 3, convert to 24 bit color, otherwise normalized grey. +// Scale to target height, if the shape's height is > 1, or its depth if the +// height == 1. If height == 0 then no scaling. +// NOTE: It isn't safe for multiple threads to call this on the same pix. +/* static */ +void Input::PreparePixInput(const StaticShape& shape, const Pix* pix, + TRand* randomizer, NetworkIO* input) { + bool color = shape.depth() == 3; + Pix* var_pix = const_cast(pix); + int depth = pixGetDepth(var_pix); + Pix* normed_pix = nullptr; + // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without + // colormap, so we just have to deal with depth conversion here. + if (color) { + // Force RGB. + if (depth == 32) + normed_pix = pixClone(var_pix); + else + normed_pix = pixConvertTo32(var_pix); + } else { + // Convert non-8-bit images to 8 bit. + if (depth == 8) + normed_pix = pixClone(var_pix); + else + normed_pix = pixConvertTo8(var_pix, false); + } + int width = pixGetWidth(normed_pix); + int height = pixGetHeight(normed_pix); + int target_height = shape.height(); + if (target_height == 1) target_height = shape.depth(); + if (target_height == 0) target_height = height; + float im_factor = static_cast(target_height) / height; + if (im_factor != 1.0f) { + // Get the scaled image. + Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor); + pixDestroy(&normed_pix); + normed_pix = scaled_pix; + } + input->FromPix(shape, normed_pix, randomizer); + pixDestroy(&normed_pix); +} + +} // namespace tesseract. 
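Input::PreparePixInput above reduces every image to one of two canonical forms (32-bit RGB or 8-bit grey) and rescales it to the network's fixed input height. A hedged usage sketch of the same Leptonica steps in isolation; the file name and target height are illustrative, not from the patch:

#include "allheaders.h"  // Leptonica

// Normalize an arbitrary-depth image to 8-bit grey at a fixed line height,
// mirroring the greyscale branch of Input::PreparePixInput.
Pix* NormalizeForNetwork(const char* filename, int target_height) {
  Pix* pix = pixRead(filename);
  if (pix == nullptr) return nullptr;
  Pix* grey =
      pixGetDepth(pix) == 8 ? pixClone(pix) : pixConvertTo8(pix, false);
  pixDestroy(&pix);
  float factor = static_cast<float>(target_height) / pixGetHeight(grey);
  Pix* scaled = pixScale(grey, factor, factor);
  pixDestroy(&grey);
  return scaled;
}

The "not thread-safe" warning in the original stems from pixClone, which bumps a reference count on the shared Pix rather than copying it.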
diff --git a/lstm/input.h b/lstm/input.h new file mode 100644 index 0000000000..7a750a562a --- /dev/null +++ b/lstm/input.h @@ -0,0 +1,107 @@ +/////////////////////////////////////////////////////////////////////// +// File: input.h +// Description: Input layer class for neural network implementations. +// Author: Ray Smith +// Created: Thu Mar 13 08:56:26 PDT 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_INPUT_H_ +#define TESSERACT_LSTM_INPUT_H_ + +#include "network.h" + +class ScrollView; + +namespace tesseract { + +class Input : public Network { + public: + Input(const STRING& name, int ni, int no); + Input(const STRING& name, const StaticShape& shape); + virtual ~Input(); + + virtual STRING spec() const { + STRING spec; + spec.add_str_int("", shape_.batch()); + spec.add_str_int(",", shape_.height()); + spec.add_str_int(",", shape_.width()); + spec.add_str_int(",", shape_.depth()); + return spec; + } + + // Returns the required shape input to the network. + virtual StaticShape InputShape() const { return shape_; } + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const { + return shape_; + } + // Writes to the given file. Returns false in case of error. + // Should be overridden by subclasses, but called by their Serialize. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + // Should be overridden by subclasses, but NOT called by their DeSerialize. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Returns an integer reduction factor that the network applies to the + // time sequence. Assumes that any 2-d is already eliminated. Used for + // scaling bounding boxes of truth data. + // WARNING: if GlobalMinimax is used to vary the scale, this will return + // the last used scale factor. Call it before any forward, and it will return + // the minimum scale factor of the paths through the GlobalMinimax. + virtual int XScaleFactor() const; + + // Provides the (minimum) x scale factor to the network (of interest only to + // input units) so they can determine how to scale bounding boxes. + virtual void CacheXScaleFactor(int factor); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. 
+ virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + // Creates and returns a Pix of appropriate size for the network from the + // image_data. If non-null, *image_scale returns the image scale factor used. + // Returns nullptr on error. + /* static */ + static Pix* PrepareLSTMInputs(const ImageData& image_data, + const Network* network, int min_width, + TRand* randomizer, float* image_scale); + // Converts the given pix to a NetworkIO of height and depth appropriate to + // the given StaticShape: + // If depth == 3, convert to 24 bit color, otherwise normalized grey. + // Scale to target height, if the shape's height is > 1, or its depth if the + // height == 1. If height == 0 then no scaling. + // NOTE: It isn't safe for multiple threads to call this on the same pix. + static void PreparePixInput(const StaticShape& shape, const Pix* pix, + TRand* randomizer, NetworkIO* input); + + private: + // Input shape determines how images are dealt with. + StaticShape shape_; + // Cached total network x scale factor for scaling bounding boxes. + int cached_x_scale_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_INPUT_H_ + diff --git a/lstm/lstm.cpp b/lstm/lstm.cpp new file mode 100644 index 0000000000..cac5f64c93 --- /dev/null +++ b/lstm/lstm.cpp @@ -0,0 +1,710 @@ +/////////////////////////////////////////////////////////////////////// +// File: lstm.cpp +// Description: Long-term-short-term-memory Recurrent neural network. +// Author: Ray Smith +// Created: Wed May 01 17:43:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "lstm.h" + +#ifndef ANDROID_BUILD +#include +#endif +#include +#include + +#include "fullyconnected.h" +#include "functions.h" +#include "networkscratch.h" +#include "tprintf.h" + +// Macros for openmp code if it is available, otherwise empty macros. +#ifdef _OPENMP +#define PARALLEL_IF_OPENMP(__num_threads) \ + PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \ + PRAGMA(omp sections nowait) { \ + PRAGMA(omp section) { +#define SECTION_IF_OPENMP \ + } \ + PRAGMA(omp section) \ + { + +#define END_PARALLEL_IF_OPENMP \ + } \ + } /* end of sections */ \ + } /* end of parallel section */ + +// Define the portable PRAGMA macro. +#ifdef _MSC_VER // Different _Pragma +#define PRAGMA(x) __pragma(x) +#else +#define PRAGMA(x) _Pragma(#x) +#endif // _MSC_VER + +#else // _OPENMP +#define PARALLEL_IF_OPENMP(__num_threads) +#define SECTION_IF_OPENMP +#define END_PARALLEL_IF_OPENMP +#endif // _OPENMP + + +namespace tesseract { + +// Max absolute value of state_. It is reasonably high to enable the state +// to count things. +const double kStateClip = 100.0; +// Max absolute value of gate_errors (the gradients). 
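The PARALLEL_IF_OPENMP / SECTION_IF_OPENMP / END_PARALLEL_IF_OPENMP trio above wraps each gate's matrix multiply in an OpenMP section so the gates run concurrently; without _OPENMP all three macros expand to nothing and the same code runs sequentially. A compilable sketch of the expanded form, assuming the gate-index constant GFS used as the thread count is 4:

#ifdef _OPENMP
#include <omp.h>
#endif

// Reconstructed expansion of PARALLEL_IF_OPENMP(4) ... SECTION_IF_OPENMP ...
// END_PARALLEL_IF_OPENMP, for readability only.
void GateSectionsSketch() {
#ifdef _OPENMP
#pragma omp parallel if (4 > 1) num_threads(4)
  {
#pragma omp sections nowait
    {
#pragma omp section
      { /* cell input: W_CI . x, then tanh */ }
#pragma omp section
      { /* input gate: W_GI . x, then logistic */ }
#pragma omp section
      { /* forget gate(s): W_GF1 (and GFS in 2-D), then logistic */ }
#pragma omp section
      { /* output gate: W_GO . x, then logistic */ }
    }
  }
#endif
}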
+const double kErrClip = 1.0f; + +LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional, + NetworkType type) + : Network(type, name, ni, no), + na_(ni + ns), + ns_(ns), + nf_(0), + is_2d_(two_dimensional), + softmax_(NULL) { + if (two_dimensional) na_ += ns_; + if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) { + nf_ = 0; + // networkbuilder ensures this is always true. + ASSERT_HOST(no == ns); + } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) { + nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_))); + softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX); + } else { + tprintf("%d is invalid type of LSTM!\n", type); + ASSERT_HOST(false); + } + na_ += nf_; +} + +LSTM::~LSTM() { delete softmax_; } + +// Returns the shape output from the network given an input shape (which may +// be partially unknown ie zero). +StaticShape LSTM::OutputShape(const StaticShape& input_shape) const { + StaticShape result = input_shape; + result.set_depth(no_); + if (type_ == NT_LSTM_SUMMARY) result.set_width(1); + if (softmax_ != NULL) return softmax_->OutputShape(result); + return result; +} + +// Sets up the network for training. Initializes weights using weights of +// scale `range` picked according to the random number generator `randomizer`. +int LSTM::InitWeights(float range, TRand* randomizer) { + Network::SetRandomizer(randomizer); + num_weights_ = 0; + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + num_weights_ += gate_weights_[w].InitWeightsFloat( + ns_, na_ + 1, TestFlag(NF_ADA_GRAD), range, randomizer); + } + if (softmax_ != NULL) { + num_weights_ += softmax_->InitWeights(range, randomizer); + } + return num_weights_; +} + +// Converts a float network to an int network. +void LSTM::ConvertToInt() { + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + gate_weights_[w].ConvertToInt(); + } + if (softmax_ != NULL) { + softmax_->ConvertToInt(); + } +} + +// Sets up the network for training using the given weight_range. +void LSTM::DebugWeights() { + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + STRING msg = name_; + msg.add_str_int(" Gate weights ", w); + gate_weights_[w].Debug2D(msg.string()); + } + if (softmax_ != NULL) { + softmax_->DebugWeights(); + } +} + +// Writes to the given file. Returns false in case of error. +bool LSTM::Serialize(TFile* fp) const { + if (!Network::Serialize(fp)) return false; + if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false; + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + if (!gate_weights_[w].Serialize(training_, fp)) return false; + } + if (softmax_ != NULL && !softmax_->Serialize(fp)) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. 
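Before reading the forward pass below, it may help to state the per-timestep computation the four (five in 2-D mode) weight matrices implement. With \tilde{x}_t the concatenated input (the ni_ external inputs, any softmax feedback, the previous output h_{t-1}, the cross-dimension output in 2-D mode, plus an implicit bias 1), \sigma the logistic function, and \odot the elementwise product:

\begin{aligned}
g_t &= \tanh(W_{CI}\,\tilde{x}_t)  && \text{(cell input)} \\
i_t &= \sigma(W_{GI}\,\tilde{x}_t) && \text{(input gate)} \\
f_t &= \sigma(W_{GF1}\,\tilde{x}_t) && \text{(forget gate)} \\
o_t &= \sigma(W_{GO}\,\tilde{x}_t) && \text{(output gate)} \\
s_t &= \operatorname{clip}\bigl(f_t \odot s_{t-1} + i_t \odot g_t,\ \pm\,\text{kStateClip}\bigr) \\
h_t &= o_t \odot \tanh(s_t)
\end{aligned}

In 2-D mode a second forget gate GFS gates the state from the other dimension, and as the code notes, the two forget contributions are max-pooled per cell rather than summed.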
+bool LSTM::DeSerialize(bool swap, TFile* fp) { + if (fp->FRead(&na_, sizeof(na_), 1) != 1) return false; + if (swap) ReverseN(&na_, sizeof(na_)); + if (type_ == NT_LSTM_SOFTMAX) { + nf_ = no_; + } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) { + nf_ = IntCastRounded(ceil(log2(no_))); + } else { + nf_ = 0; + } + is_2d_ = false; + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false; + if (w == CI) { + ns_ = gate_weights_[CI].NumOutputs(); + is_2d_ = na_ - nf_ == ni_ + 2 * ns_; + } + } + delete softmax_; + if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) { + softmax_ = + reinterpret_cast(Network::CreateFromFile(swap, fp)); + if (softmax_ == NULL) return false; + } else { + softmax_ = NULL; + } + return true; +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void LSTM::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + input_map_ = input.stride_map(); + input_width_ = input.Width(); + if (softmax_ != NULL) + output->ResizeFloat(input, no_); + else if (type_ == NT_LSTM_SUMMARY) + output->ResizeXTo1(input, no_); + else + output->Resize(input, no_); + ResizeForward(input); + // Temporary storage of forward computation for each gate. + NetworkScratch::FloatVec temp_lines[WT_COUNT]; + for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch); + // Single timestep buffers for the current/recurrent output and state. + NetworkScratch::FloatVec curr_state, curr_output; + curr_state.Init(ns_, scratch); + ZeroVector(ns_, curr_state); + curr_output.Init(ns_, scratch); + ZeroVector(ns_, curr_output); + // Rotating buffers of width buf_width allow storage of the state and output + // for the other dimension, used only when working in true 2D mode. The width + // is enough to hold an entire strip of the major direction. + int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1; + GenericVector states, outputs; + if (Is2D()) { + states.init_to_size(buf_width, NetworkScratch::FloatVec()); + outputs.init_to_size(buf_width, NetworkScratch::FloatVec()); + for (int i = 0; i < buf_width; ++i) { + states[i].Init(ns_, scratch); + ZeroVector(ns_, states[i]); + outputs[i].Init(ns_, scratch); + ZeroVector(ns_, outputs[i]); + } + } + // Used only if a softmax LSTM. + NetworkScratch::FloatVec softmax_output; + NetworkScratch::IO int_output; + if (softmax_ != NULL) { + softmax_output.Init(no_, scratch); + ZeroVector(no_, softmax_output); + if (input.int_mode()) int_output.Resize2d(true, 1, ns_, scratch); + softmax_->SetupForward(input, NULL); + } + NetworkScratch::FloatVec curr_input; + curr_input.Init(na_, scratch); + StrideMap::Index src_index(input_map_); + // Used only by NT_LSTM_SUMMARY. + StrideMap::Index dest_index(output->stride_map()); + do { + int t = src_index.t(); + // True if there is a valid old state for the 2nd dimension. + bool valid_2d = Is2D(); + if (valid_2d) { + StrideMap::Index dim_index(src_index); + if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false; + } + // Index of the 2-D revolving buffers (outputs, states). + int mod_t = Modulo(t, buf_width); // Current timestep. + // Setup the padded input in source. 
+    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
+    if (softmax_ != NULL) {
+      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
+    }
+    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
+    if (Is2D())
+      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
+    if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
+    // Matrix-multiply the padded source input by each gate's weights.
+    PARALLEL_IF_OPENMP(GFS)
+    // It looks inefficient to create the threads on each t iteration, but the
+    // alternative of putting the parallel outside the t loop, an OpenMP
+    // single around the t-loop and then tasks in place of the sections is a
+    // *lot* slower.
+    // Cell inputs.
+    if (source_.int_mode())
+      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
+    else
+      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
+    FuncInplace<GFunc>(ns_, temp_lines[CI]);
+
+    SECTION_IF_OPENMP
+    // Input Gates.
+    if (source_.int_mode())
+      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
+    else
+      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
+    FuncInplace<FFunc>(ns_, temp_lines[GI]);
+
+    SECTION_IF_OPENMP
+    // 1-D forget gates.
+    if (source_.int_mode())
+      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
+    else
+      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
+    FuncInplace<FFunc>(ns_, temp_lines[GF1]);
+
+    // 2-D forget gates.
+    if (Is2D()) {
+      if (source_.int_mode())
+        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
+      else
+        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
+      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
+    }
+
+    SECTION_IF_OPENMP
+    // Output gates.
+    if (source_.int_mode())
+      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
+    else
+      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
+    FuncInplace<FFunc>(ns_, temp_lines[GO]);
+    END_PARALLEL_IF_OPENMP
+
+    // Apply forget gate to state.
+    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
+    if (Is2D()) {
+      // Max-pool the forget gates (in 2-d) instead of blindly adding.
+      inT8* which_fg_col = which_fg_[t];
+      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
+      if (valid_2d) {
+        const double* stepped_state = states[mod_t];
+        for (int i = 0; i < ns_; ++i) {
+          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
+            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
+            which_fg_col[i] = 2;
+          }
+        }
+      }
+    }
+    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
+    // Clip curr_state to a sane range.
+    ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
+    if (training_) {
+      // Save the gate node values.
+      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
+      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
+      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
+      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
+      if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
+    }
+    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
+    if (training_) state_.WriteTimeStep(t, curr_state);
+    if (softmax_ != NULL) {
+      if (input.int_mode()) {
+        int_output->WriteTimeStep(0, curr_output);
+        softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
+      } else {
+        softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
+      }
+      output->WriteTimeStep(t, softmax_output);
+      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
+        CodeInBinary(no_, nf_, softmax_output);
+      }
+    } else if (type_ == NT_LSTM_SUMMARY) {
+      // Output only at the end of a row.
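+      // (For NT_LSTM_SUMMARY the per-timestep outputs are discarded and only
+      // the final state of each row is written to dest_index, so the output
+      // width collapses to 1; see ResizeXTo1 above.)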
+ if (src_index.IsLast(FD_WIDTH)) { + output->WriteTimeStep(dest_index.t(), curr_output); + dest_index.Increment(); + } + } else { + output->WriteTimeStep(t, curr_output); + } + // Save states for use by the 2nd dimension only if needed. + if (Is2D()) { + CopyVector(ns_, curr_state, states[mod_t]); + CopyVector(ns_, curr_output, outputs[mod_t]); + } + // Always zero the states at the end of every row, but only for the major + // direction. The 2-D state remains intact. + if (src_index.IsLast(FD_WIDTH)) { + ZeroVector(ns_, curr_state); + ZeroVector(ns_, curr_output); + } + } while (src_index.Increment()); +#if DEBUG_DETAIL > 0 + tprintf("Source:%s\n", name_.string()); + source_.Print(10); + tprintf("State:%s\n", name_.string()); + state_.Print(10); + tprintf("Output:%s\n", name_.string()); + output->Print(10); +#endif + if (debug) DisplayForward(*output); +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + if (debug) DisplayBackward(fwd_deltas); + back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_); + // ======Scratch space.====== + // Output errors from deltas with recurrence from sourceerr. + NetworkScratch::FloatVec outputerr; + outputerr.Init(ns_, scratch); + // Recurrent error in the state/source. + NetworkScratch::FloatVec curr_stateerr, curr_sourceerr; + curr_stateerr.Init(ns_, scratch); + curr_sourceerr.Init(na_, scratch); + ZeroVector(ns_, curr_stateerr); + ZeroVector(na_, curr_sourceerr); + // Errors in the gates. + NetworkScratch::FloatVec gate_errors[WT_COUNT]; + for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch); + // Rotating buffers of width buf_width allow storage of the recurrent time- + // steps used only for true 2-D. Stores one full strip of the major direction. + int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1; + GenericVector stateerr, sourceerr; + if (Is2D()) { + stateerr.init_to_size(buf_width, NetworkScratch::FloatVec()); + sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec()); + for (int t = 0; t < buf_width; ++t) { + stateerr[t].Init(ns_, scratch); + sourceerr[t].Init(na_, scratch); + ZeroVector(ns_, stateerr[t]); + ZeroVector(na_, sourceerr[t]); + } + } + // Parallel-generated sourceerr from each of the gates. + NetworkScratch::FloatVec sourceerr_temps[WT_COUNT]; + for (int w = 0; w < WT_COUNT; ++w) + sourceerr_temps[w].Init(na_, scratch); + int width = input_width_; + // Transposed gate errors stored over all timesteps for sum outer. + NetworkScratch::GradientStore gate_errors_t[WT_COUNT]; + for (int w = 0; w < WT_COUNT; ++w) { + gate_errors_t[w].Init(ns_, width, scratch); + } + // Used only if softmax_ != NULL. + NetworkScratch::FloatVec softmax_errors; + NetworkScratch::GradientStore softmax_errors_t; + if (softmax_ != NULL) { + softmax_errors.Init(no_, scratch); + softmax_errors_t.Init(no_, width, scratch); + } + double state_clip = Is2D() ? 9.0 : 4.0; +#if DEBUG_DETAIL > 1 + tprintf("fwd_deltas:%s\n", name_.string()); + fwd_deltas.Print(10); +#endif + StrideMap::Index dest_index(input_map_); + dest_index.InitToLast(); + // Used only by NT_LSTM_SUMMARY. 
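+  // Orientation (a sketch): the forward recurrence being differentiated is
+  //   state_t = f_GF1 (*) state_{t-1} + g_CI (*) f_GI   (max-pooled with
+  //             f_GFS (*) state_up in 2-D),  out_t = h(state_t) (*) f_GO,
+  // with f = logistic, g and h = tanh, and (*) elementwise multiplication.
+  // Correspondingly, outputerr holds dE/d(out_t), curr_stateerr carries
+  // dE/d(state_t) backwards along the row, stateerr/sourceerr are the
+  // rotating copies for the 2-D recurrence, and gate_errors_t stores
+  // transposed per-gate deltas for one SumOuterTransposed per gate after
+  // the loop.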
+ StrideMap::Index src_index(fwd_deltas.stride_map()); + src_index.InitToLast(); + do { + int t = dest_index.t(); + bool at_last_x = dest_index.IsLast(FD_WIDTH); + // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only + // valid if >= 0, which is true if 2d and not on the top/bottom. + int up_pos = -1; + int down_pos = -1; + if (Is2D()) { + if (dest_index.index(FD_HEIGHT) > 0) { + StrideMap::Index up_index(dest_index); + if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t(); + } + if (!dest_index.IsLast(FD_HEIGHT)) { + StrideMap::Index down_index(dest_index); + if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t(); + } + } + // Index of the 2-D revolving buffers (sourceerr, stateerr). + int mod_t = Modulo(t, buf_width); // Current timestep. + // Zero the state in the major direction only at the end of every row. + if (at_last_x) { + ZeroVector(na_, curr_sourceerr); + ZeroVector(ns_, curr_stateerr); + } + // Setup the outputerr. + if (type_ == NT_LSTM_SUMMARY) { + if (dest_index.IsLast(FD_WIDTH)) { + fwd_deltas.ReadTimeStep(src_index.t(), outputerr); + src_index.Decrement(); + } else { + ZeroVector(ns_, outputerr); + } + } else if (softmax_ == NULL) { + fwd_deltas.ReadTimeStep(t, outputerr); + } else { + softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, + softmax_errors_t.get(), outputerr); + } + if (!at_last_x) + AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr); + if (down_pos >= 0) + AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr); + // Apply the 1-d forget gates. + if (!at_last_x) { + const float* next_node_gf1 = node_values_[GF1].f(t + 1); + for (int i = 0; i < ns_; ++i) { + curr_stateerr[i] *= next_node_gf1[i]; + } + } + if (Is2D() && t + 1 < width) { + for (int i = 0; i < ns_; ++i) { + if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0; + } + if (down_pos >= 0) { + const float* right_node_gfs = node_values_[GFS].f(down_pos); + const double* right_stateerr = stateerr[mod_t]; + for (int i = 0; i < ns_; ++i) { + if (which_fg_[down_pos][i] == 2) { + curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i]; + } + } + } + } + state_.FuncMultiply3Add(node_values_[GO], t, outputerr, + curr_stateerr); + // Clip stateerr_ to a sane range. + ClipVector(ns_, -state_clip, state_clip, curr_stateerr); +#if DEBUG_DETAIL > 1 + if (t + 10 > width) { + tprintf("t=%d, stateerr=", t); + for (int i = 0; i < ns_; ++i) + tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], + curr_sourceerr[ni_ + nf_ + i]); + tprintf("\n"); + } +#endif + // Matrix multiply to get the source errors. + PARALLEL_IF_OPENMP(GFS) + + // Cell inputs. + node_values_[CI].FuncMultiply3(t, node_values_[GI], t, + curr_stateerr, gate_errors[CI]); + ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get()); + gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]); + gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]); + + SECTION_IF_OPENMP + // Input Gates. + node_values_[GI].FuncMultiply3(t, node_values_[CI], t, + curr_stateerr, gate_errors[GI]); + ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get()); + gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]); + gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]); + + SECTION_IF_OPENMP + // 1-D forget Gates. 
+ if (t > 0) { + node_values_[GF1].FuncMultiply3(t, state_, t - 1, curr_stateerr, + gate_errors[GF1]); + ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get()); + gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], + sourceerr_temps[GF1]); + } else { + memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0])); + memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1])); + } + gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]); + + // 2-D forget Gates. + if (up_pos >= 0) { + node_values_[GFS].FuncMultiply3(t, state_, up_pos, curr_stateerr, + gate_errors[GFS]); + ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get()); + gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], + sourceerr_temps[GFS]); + } else { + memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0])); + memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS])); + } + if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]); + + SECTION_IF_OPENMP + // Output gates. + state_.Func2Multiply3(node_values_[GO], t, outputerr, + gate_errors[GO]); + ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get()); + gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]); + gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]); + END_PARALLEL_IF_OPENMP + + SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], + sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS], + curr_sourceerr); + back_deltas->WriteTimeStep(t, curr_sourceerr); + // Save states for use by the 2nd dimension only if needed. + if (Is2D()) { + CopyVector(ns_, curr_stateerr, stateerr[mod_t]); + CopyVector(na_, curr_sourceerr, sourceerr[mod_t]); + } + } while (dest_index.Decrement()); +#if DEBUG_DETAIL > 2 + for (int w = 0; w < WT_COUNT; ++w) { + tprintf("%s gate errors[%d]\n", name_.string(), w); + gate_errors_t[w].get()->PrintUnTransposed(10); + } +#endif + // Transposed source_ used to speed-up SumOuter. + NetworkScratch::GradientStore source_t, state_t; + source_t.Init(na_, width, scratch); + source_.Transpose(source_t.get()); + state_t.Init(ns_, width, scratch); + state_.Transpose(state_t.get()); +#ifdef _OPENMP +#pragma omp parallel for num_threads(GFS) if (!Is2D()) +#endif + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false); + } + if (softmax_ != NULL) { + softmax_->FinishBackward(*softmax_errors_t); + } + if (needs_to_backprop_) { + // Normalize the inputerr in back_deltas. + back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas); + return true; + } + return false; +} + +// Updates the weights using the given learning rate and momentum. +// num_samples is the quotient to be used in the adagrad computation iff +// use_ada_grad_ is true. +void LSTM::Update(float learning_rate, float momentum, int num_samples) { +#if DEBUG_DETAIL > 3 + PrintW(); +#endif + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + gate_weights_[w].Update(learning_rate, momentum, num_samples); + } + if (softmax_ != NULL) { + softmax_->Update(learning_rate, momentum, num_samples); + } +#if DEBUG_DETAIL > 3 + PrintDW(); +#endif +} + +// Sums the products of weight updates in *this and other, splitting into +// positive (same direction) in *same and negative (different direction) in +// *changed. 
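+// Worked example with illustrative numbers: if a weight's update was +0.2
+// in *this and +0.1 in other, their product 0.02 is added to *same; had the
+// second update been -0.1, the 0.02 would instead go to *changed. The ratio
+// changed/(same + changed) thus gives a rough measure of how often two
+// training runs disagree on update direction.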
+void LSTM::CountAlternators(const Network& other, double* same, + double* changed) const { + ASSERT_HOST(other.type() == type_); + const LSTM* lstm = reinterpret_cast(&other); + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed); + } + if (softmax_ != NULL) { + softmax_->CountAlternators(*lstm->softmax_, same, changed); + } +} + +// Prints the weights for debug purposes. +void LSTM::PrintW() { + tprintf("Weight state:%s\n", name_.string()); + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + tprintf("Gate %d, inputs\n", w); + for (int i = 0; i < ni_; ++i) { + tprintf("Row %d:", i); + for (int s = 0; s < ns_; ++s) + tprintf(" %g", gate_weights_[w].GetWeights(s)[i]); + tprintf("\n"); + } + tprintf("Gate %d, outputs\n", w); + for (int i = ni_; i < ni_ + ns_; ++i) { + tprintf("Row %d:", i - ni_); + for (int s = 0; s < ns_; ++s) + tprintf(" %g", gate_weights_[w].GetWeights(s)[i]); + tprintf("\n"); + } + tprintf("Gate %d, bias\n", w); + for (int s = 0; s < ns_; ++s) + tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]); + tprintf("\n"); + } +} + +// Prints the weight deltas for debug purposes. +void LSTM::PrintDW() { + tprintf("Delta state:%s\n", name_.string()); + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + tprintf("Gate %d, inputs\n", w); + for (int i = 0; i < ni_; ++i) { + tprintf("Row %d:", i); + for (int s = 0; s < ns_; ++s) + tprintf(" %g", gate_weights_[w].GetDW(s, i)); + tprintf("\n"); + } + tprintf("Gate %d, outputs\n", w); + for (int i = ni_; i < ni_ + ns_; ++i) { + tprintf("Row %d:", i - ni_); + for (int s = 0; s < ns_; ++s) + tprintf(" %g", gate_weights_[w].GetDW(s, i)); + tprintf("\n"); + } + tprintf("Gate %d, bias\n", w); + for (int s = 0; s < ns_; ++s) + tprintf(" %g", gate_weights_[w].GetDW(s, na_)); + tprintf("\n"); + } +} + +// Resizes forward data to cope with an input image of the given width. +void LSTM::ResizeForward(const NetworkIO& input) { + source_.Resize(input, na_); + which_fg_.ResizeNoInit(input.Width(), ns_); + if (training_) { + state_.ResizeFloat(input, ns_); + for (int w = 0; w < WT_COUNT; ++w) { + if (w == GFS && !Is2D()) continue; + node_values_[w].ResizeFloat(input, ns_); + } + } +} + + +} // namespace tesseract. diff --git a/lstm/lstm.h b/lstm/lstm.h new file mode 100644 index 0000000000..c62a846013 --- /dev/null +++ b/lstm/lstm.h @@ -0,0 +1,157 @@ +/////////////////////////////////////////////////////////////////////// +// File: lstm.h +// Description: Long-term-short-term-memory Recurrent neural network. +// Author: Ray Smith +// Created: Wed May 01 17:33:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_LSTM_H_ +#define TESSERACT_LSTM_LSTM_H_ + +#include "network.h" +#include "fullyconnected.h" + +namespace tesseract { + +// C++ Implementation of the LSTM class from lstm.py. +class LSTM : public Network { + public: + // Enum for the different weights in LSTM, to reduce some of the I/O and + // setup code to loops. The elements of the enum correspond to elements of an + // array of WeightMatrix or a corresponding array of NetworkIO. + enum WeightType { + CI, // Cell Inputs. + GI, // Gate at the input. + GF1, // Forget gate at the memory (1-d or looking back 1 timestep). + GO, // Gate at the output. + GFS, // Forget gate at the memory, looking back in the other dimension. + + WT_COUNT // Number of WeightTypes. + }; + + // Constructor for NT_LSTM (regular 1 or 2-d LSTM), NT_LSTM_SOFTMAX (LSTM with + // additional softmax layer included and fed back into the input at the next + // timestep), or NT_LSTM_SOFTMAX_ENCODED (as LSTM_SOFTMAX, but the feedback + // is binary encoded instead of categorical) only. + // 2-d and bidi softmax LSTMs are not rejected, but are impossible to build + // in the conventional way because the output feedback both forwards and + // backwards in time does become impossible. + LSTM(const STRING& name, int num_inputs, int num_states, int num_outputs, + bool two_dimensional, NetworkType type); + virtual ~LSTM(); + + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const; + + virtual STRING spec() const { + STRING spec; + if (type_ == NT_LSTM) + spec.add_str_int("Lfx", ns_); + else if (type_ == NT_LSTM_SUMMARY) + spec.add_str_int("Lfxs", ns_); + else if (type_ == NT_LSTM_SOFTMAX) + spec.add_str_int("LS", ns_); + else if (type_ == NT_LSTM_SOFTMAX_ENCODED) + spec.add_str_int("LE", ns_); + if (softmax_ != NULL) spec += softmax_->spec(); + return spec; + } + + // Sets up the network for training. Initializes weights using weights of + // scale `range` picked according to the random number generator `randomizer`. + virtual int InitWeights(float range, TRand* randomizer); + + // Converts a float network to an int network. + virtual void ConvertToInt(); + + // Provides debug output on the weights. + virtual void DebugWeights(); + + // Writes to the given file. Returns false in case of error. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + // Updates the weights using the given learning rate and momentum. + // num_samples is the quotient to be used in the adagrad computation iff + // use_ada_grad_ is true. 
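+  // For orientation only (a sketch; the exact expression lives in
+  // WeightMatrix::Update): with adagrad enabled, each weight's step is
+  // roughly learning_rate * dw / sqrt(accumulated dw^2 / num_samples), so
+  // num_samples rescales the squared-gradient history.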
+  virtual void Update(float learning_rate, float momentum, int num_samples);
+  // Sums the products of weight updates in *this and other, splitting into
+  // positive (same direction) in *same and negative (different direction) in
+  // *changed.
+  virtual void CountAlternators(const Network& other, double* same,
+                                double* changed) const;
+  // Prints the weights for debug purposes.
+  void PrintW();
+  // Prints the weight deltas for debug purposes.
+  void PrintDW();
+
+  // Returns true if this is a 2-d lstm.
+  bool Is2D() const {
+    return is_2d_;
+  }
+
+ private:
+  // Resizes forward data to cope with an input image of the given width.
+  void ResizeForward(const NetworkIO& input);
+
+ private:
+  // Size of padded input to weight matrices = ni_ + ns_ for 1-D operation
+  // and ni_ + 2 * ns_ for 2-D operation (plus nf_ of softmax feedback for
+  // the softmax types). Note that there is a phantom 1 input for the bias
+  // that makes the weight matrices of size [na + 1][ns].
+  inT32 na_;
+  // Number of internal states. Equal to no_ except for a softmax LSTM.
+  // ns_ is NOT serialized, but is calculated from gate_weights_.
+  inT32 ns_;
+  // Number of additional feedback states. The softmax types feed back
+  // additional output information on top of the ns_ internal states.
+  // In the case of a binary-coded (EMBEDDED) softmax, nf_ < no_.
+  inT32 nf_;
+  // Flag indicating 2-D operation.
+  bool is_2d_;
+
+  // Gate weight arrays of size [na + 1, ns].
+  WeightMatrix gate_weights_[WT_COUNT];
+  // Used only if this is a softmax LSTM.
+  FullyConnected* softmax_;
+  // Input padded with previous output of size [width, na].
+  NetworkIO source_;
+  // Internal state used during forward operation, of size [width, ns].
+  NetworkIO state_;
+  // State of the 2-d maxpool, generated during forward, used during backward.
+  GENERIC_2D_ARRAY<inT8> which_fg_;
+  // Internal state saved from forward, but used only during backward.
+  NetworkIO node_values_[WT_COUNT];
+  // Preserved input stride_map used for Backward when NT_LSTM_SUMMARY.
+  StrideMap input_map_;
+  int input_width_;
+};
+
+}  // namespace tesseract.
+
+
+#endif  // TESSERACT_LSTM_LSTM_H_
diff --git a/lstm/lstmrecognizer.cpp b/lstm/lstmrecognizer.cpp
new file mode 100644
index 0000000000..f648e160da
--- /dev/null
+++ b/lstm/lstmrecognizer.cpp
@@ -0,0 +1,816 @@
+///////////////////////////////////////////////////////////////////////
+// File: lstmrecognizer.cpp
+// Description: Top-level line recognizer class for LSTM-based networks.
+// Author: Ray Smith
+// Created: Thu May 02 10:59:06 PST 2013
+//
+// (C) Copyright 2013, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/////////////////////////////////////////////////////////////////////// + +#include "lstmrecognizer.h" + +#include "allheaders.h" +#include "callcpp.h" +#include "dict.h" +#include "genericheap.h" +#include "helpers.h" +#include "imagedata.h" +#include "input.h" +#include "lstm.h" +#include "normalis.h" +#include "pageres.h" +#include "ratngs.h" +#include "recodebeam.h" +#include "scrollview.h" +#include "shapetable.h" +#include "statistc.h" +#include "tprintf.h" + +namespace tesseract { + +// Max number of blob choices to return in any given position. +const int kMaxChoices = 4; +// Default ratio between dict and non-dict words. +const double kDictRatio = 2.25; +// Default certainty offset to give the dictionary a chance. +const double kCertOffset = -0.085; + +LSTMRecognizer::LSTMRecognizer() + : network_(NULL), + training_flags_(0), + training_iteration_(0), + sample_iteration_(0), + null_char_(UNICHAR_BROKEN), + weight_range_(0.0f), + learning_rate_(0.0f), + momentum_(0.0f), + dict_(NULL), + search_(NULL), + debug_win_(NULL) {} + +LSTMRecognizer::~LSTMRecognizer() { + delete network_; + delete dict_; + delete search_; +} + +// Writes to the given file. Returns false in case of error. +bool LSTMRecognizer::Serialize(TFile* fp) const { + if (!network_->Serialize(fp)) return false; + if (!GetUnicharset().save_to_file(fp)) return false; + if (!network_str_.Serialize(fp)) return false; + if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1) + return false; + if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1) + return false; + if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) + return false; + if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false; + if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false; + if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; + if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false; + if (IsRecoding() && !recoder_.Serialize(fp)) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool LSTMRecognizer::DeSerialize(bool swap, TFile* fp) { + delete network_; + network_ = Network::CreateFromFile(swap, fp); + if (network_ == NULL) return false; + if (!ccutil_.unicharset.load_from_file(fp, false)) return false; + if (!network_str_.DeSerialize(swap, fp)) return false; + if (fp->FRead(&training_flags_, sizeof(training_flags_), 1) != 1) + return false; + if (fp->FRead(&training_iteration_, sizeof(training_iteration_), 1) != 1) + return false; + if (fp->FRead(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) + return false; + if (fp->FRead(&null_char_, sizeof(null_char_), 1) != 1) return false; + if (fp->FRead(&weight_range_, sizeof(weight_range_), 1) != 1) return false; + if (fp->FRead(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; + if (fp->FRead(&momentum_, sizeof(momentum_), 1) != 1) return false; + if (IsRecoding()) { + if (!recoder_.DeSerialize(swap, fp)) return false; + RecodedCharID code; + recoder_.EncodeUnichar(UNICHAR_SPACE, &code); + if (code(0) != UNICHAR_SPACE) { + tprintf("Space was garbled in recoding!!\n"); + return false; + } + } + // TODO(rays) swaps! + network_->SetRandomizer(&randomizer_); + network_->CacheXScaleFactor(network_->XScaleFactor()); + return true; +} + +// Loads the dictionary if possible from the traineddata file. 
+// Prints a warning message, and returns false but otherwise fails silently +// and continues to work without it if loading fails. +// Note that dictionary load is independent from DeSerialize, but dependent +// on the unicharset matching. This enables training to deserialize a model +// from checkpoint or restore without having to go back and reload the +// dictionary. +bool LSTMRecognizer::LoadDictionary(const char* data_file_name, + const char* lang) { + delete dict_; + dict_ = new Dict(&ccutil_); + dict_->SetupForLoad(Dict::GlobalDawgCache()); + dict_->LoadLSTM(data_file_name, lang); + if (dict_->FinishLoad()) return true; // Success. + tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", + lang); + delete dict_; + dict_ = NULL; + return false; +} + +// Recognizes the line image, contained within image_data, returning the +// ratings matrix and matching box_word for each WERD_RES in the output. +void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert, + bool debug, double worst_dict_cert, + bool use_alternates, + const UNICHARSET* target_unicharset, + const TBOX& line_box, float score_ratio, + bool one_word, + PointerVector* words) { + NetworkIO outputs; + float label_threshold = use_alternates ? 0.75f : 0.0f; + float scale_factor; + NetworkIO inputs; + if (!RecognizeLine(image_data, invert, debug, false, label_threshold, + &scale_factor, &inputs, &outputs)) + return; + if (IsRecoding()) { + if (search_ == NULL) { + search_ = + new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); + } + search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL); + search_->ExtractBestPathAsWords(line_box, scale_factor, debug, + &GetUnicharset(), words); + } else { + GenericVector label_coords; + GenericVector labels; + LabelsFromOutputs(outputs, label_threshold, &labels, &label_coords); + WordsFromOutputs(outputs, labels, label_coords, line_box, debug, + use_alternates, one_word, score_ratio, scale_factor, + target_unicharset, words); + } +} + +// Builds a set of tesseract-compatible WERD_RESs aligned to line_box, +// corresponding to the network output in outputs, labels, label_coords. +// one_word generates a single word output, that may include spaces inside. +// use_alternates generates alternative BLOB_CHOICEs and segmentation paths. +// If not NULL, we attempt to translate the output to target_unicharset, but do +// not guarantee success, due to mismatches. In that case the output words are +// marked with our UNICHARSET, not the caller's. +void LSTMRecognizer::WordsFromOutputs( + const NetworkIO& outputs, const GenericVector& labels, + const GenericVector label_coords, const TBOX& line_box, bool debug, + bool use_alternates, bool one_word, float score_ratio, float scale_factor, + const UNICHARSET* target_unicharset, PointerVector* words) { + // Convert labels to unichar-ids. + int word_end = 0; + float prev_space_cert = 0.0f; + for (int i = 0; i < labels.size(); i = word_end) { + word_end = i + 1; + if (labels[i] == null_char_ || labels[i] == UNICHAR_SPACE) { + continue; + } + float space_cert = 0.0f; + if (one_word) { + word_end = labels.size(); + } else { + // Find the end of the word at the first null_char_ that leads to the + // first UNICHAR_SPACE. 
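+      // Worked example with hypothetical labels (0 standing for null_char_):
+      // given [c, a, t, 0, 0, SPACE, d, o, g] and i == 0, the first scan
+      // stops at the SPACE (index 5) and the second trims the trailing
+      // nulls back, leaving word_end == 3 and the word span [0, 3) = "cat".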
+ while (word_end < labels.size() && labels[word_end] != UNICHAR_SPACE) + ++word_end; + if (word_end < labels.size()) { + float rating; + outputs.ScoresOverRange(label_coords[word_end], + label_coords[word_end] + 1, UNICHAR_SPACE, + null_char_, &rating, &space_cert); + } + while (word_end > i && labels[word_end - 1] == null_char_) --word_end; + } + ASSERT_HOST(word_end > i); + // Create a WERD_RES for the output word. + if (debug) + tprintf("Creating word from outputs over [%d,%d)\n", i, word_end); + WERD_RES* word = + WordFromOutput(line_box, outputs, i, word_end, score_ratio, + MIN(prev_space_cert, space_cert), debug, + use_alternates && !SimpleTextOutput(), target_unicharset, + labels, label_coords, scale_factor); + if (word == NULL && target_unicharset != NULL) { + // Unicharset translation failed - use decoder_ instead, and disable + // the segmentation search on output, as it won't understand the encoding. + word = WordFromOutput(line_box, outputs, i, word_end, score_ratio, + MIN(prev_space_cert, space_cert), debug, false, + NULL, labels, label_coords, scale_factor); + } + prev_space_cert = space_cert; + words->push_back(word); + } +} + +// Helper computes min and mean best results in the output. +void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output, + float* mean_output, float* sd) { + const int kOutputScale = MAX_INT8; + STATS stats(0, kOutputScale + 1); + for (int t = 0; t < outputs.Width(); ++t) { + int best_label = outputs.BestLabel(t, NULL); + if (best_label != null_char_ || t == 0) { + float best_output = outputs.f(t)[best_label]; + stats.add(static_cast(kOutputScale * best_output), 1); + } + } + *min_output = static_cast(stats.min_bucket()) / kOutputScale; + *mean_output = stats.mean() / kOutputScale; + *sd = stats.sd() / kOutputScale; +} + +// Recognizes the image_data, returning the labels, +// scores, and corresponding pairs of start, end x-coords in coords. +// If label_threshold is positive, uses it for making the labels, otherwise +// uses standard ctc. +bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert, + bool debug, bool re_invert, + float label_threshold, float* scale_factor, + NetworkIO* inputs, NetworkIO* outputs) { + // Maximum width of image to train on. + const int kMaxImageWidth = 2048; + // This ensures consistent recognition results. + SetRandomSeed(); + int min_width = network_->XScaleFactor(); + Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width, + &randomizer_, scale_factor); + if (pix == NULL) { + tprintf("Line cannot be recognized!!\n"); + return false; + } + if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) { + tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), + pixGetHeight(pix)); + pixDestroy(&pix); + return false; + } + // Reduction factor from image to coords. + *scale_factor = min_width / *scale_factor; + inputs->set_int_mode(IsIntMode()); + SetRandomSeed(); + Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs); + network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs); + // Check for auto inversion. + float pos_min, pos_mean, pos_sd; + OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd); + if (invert && pos_min < 0.5) { + // Run again inverted and see if it is any better. 
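+    // Illustrative decision with hypothetical numbers: if the normal pass
+    // gave min/mean/sd of 0.31/0.74/0.21 and the inverted pass gives
+    // 0.55/0.88/0.12, the inverted outputs win all three comparisons below
+    // (higher min, higher mean, lower sd) and are kept.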
+ float inv_scale; + NetworkIO inv_inputs, inv_outputs; + inv_inputs.set_int_mode(IsIntMode()); + SetRandomSeed(); + pixInvert(pix, pix); + Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, + &inv_inputs); + network_->Forward(debug, inv_inputs, NULL, &scratch_space_, &inv_outputs); + float inv_min, inv_mean, inv_sd; + OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd); + if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) { + // Inverted did better. Use inverted data. + if (debug) { + tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", + pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd); + } + *outputs = inv_outputs; + *inputs = inv_inputs; + } else if (re_invert) { + // Inverting was not an improvement, so undo and run again, so the + // outputs match the best forward result. + SetRandomSeed(); + network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs); + } + } + pixDestroy(&pix); + if (debug) { + GenericVector labels, coords; + LabelsFromOutputs(*outputs, label_threshold, &labels, &coords); + DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_); + DebugActivationPath(*outputs, labels, coords); + } + return true; +} + +// Returns a tesseract-compatible WERD_RES from the line recognizer outputs. +// line_box should be the bounding box of the line image in the main image, +// outputs the output of the network, +// [word_start, word_end) the interval over which to convert, +// score_ratio for choosing alternate classifier choices, +// use_alternates to control generation of alternative segmentations, +// labels, label_coords, scale_factor from RecognizeLine above. +// If target_unicharset is not NULL, attempts to translate the internal +// unichar_ids to the target_unicharset, but falls back to untranslated ids +// if the translation should fail. +WERD_RES* LSTMRecognizer::WordFromOutput( + const TBOX& line_box, const NetworkIO& outputs, int word_start, + int word_end, float score_ratio, float space_certainty, bool debug, + bool use_alternates, const UNICHARSET* target_unicharset, + const GenericVector& labels, const GenericVector& label_coords, + float scale_factor) { + WERD_RES* word_res = InitializeWord( + line_box, word_start, word_end, space_certainty, use_alternates, + target_unicharset, labels, label_coords, scale_factor); + int max_blob_run = word_res->ratings->bandwidth(); + for (int width = 1; width <= max_blob_run; ++width) { + int col = 0; + for (int i = word_start; i + width <= word_end; ++i) { + if (labels[i] != null_char_) { + // Starting at i, use width labels, but stop at the next null_char_. + // This forms all combinations of blobs between regions of null_char_. + int j = i + 1; + while (j - i < width && labels[j] != null_char_) ++j; + if (j - i == width) { + // Make the blob choices. + int end_coord = label_coords[j]; + if (j < word_end && labels[j] == null_char_) + end_coord = label_coords[j + 1]; + BLOB_CHOICE_LIST* choices = GetBlobChoices( + col, col + width - 1, debug, outputs, target_unicharset, + label_coords[i], end_coord, score_ratio); + if (choices == NULL) { + delete word_res; + return NULL; + } + word_res->ratings->put(col, col + width - 1, choices); + } + ++col; + } + } + } + if (use_alternates) { + // Merge adjacent single results over null_char boundaries. 
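+    // Illustrative pattern (0 standing for null_char_): in a run
+    // ... 0, X, 0, Y, 0 ..., X and Y are single-label results separated by
+    // exactly one null, so the loop below also offers the combined span as a
+    // width-2 ratings entry, letting the segmentation search consider the
+    // merged reading.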
+ int col = 0; + for (int i = word_start; i + 2 < word_end; ++i) { + if (labels[i] != null_char_ && labels[i + 1] == null_char_ && + labels[i + 2] != null_char_ && + (i == word_start || labels[i - 1] == null_char_) && + (i + 3 == word_end || labels[i + 3] == null_char_)) { + int end_coord = label_coords[i + 3]; + if (i + 3 < word_end && labels[i + 3] == null_char_) + end_coord = label_coords[i + 4]; + BLOB_CHOICE_LIST* choices = + GetBlobChoices(col, col + 1, debug, outputs, target_unicharset, + label_coords[i], end_coord, score_ratio); + if (choices == NULL) { + delete word_res; + return NULL; + } + word_res->ratings->put(col, col + 1, choices); + } + if (labels[i] != null_char_) ++col; + } + } else { + word_res->FakeWordFromRatings(TOP_CHOICE_PERM); + } + return word_res; +} + +// Sets up a word with the ratings matrix and fake blobs with boxes in the +// right places. +WERD_RES* LSTMRecognizer::InitializeWord(const TBOX& line_box, int word_start, + int word_end, float space_certainty, + bool use_alternates, + const UNICHARSET* target_unicharset, + const GenericVector& labels, + const GenericVector& label_coords, + float scale_factor) { + // Make a fake blob for each non-zero label. + C_BLOB_LIST blobs; + C_BLOB_IT b_it(&blobs); + // num_blobs is the length of the diagonal of the ratings matrix. + int num_blobs = 0; + // max_blob_run is the diagonal width of the ratings matrix + int max_blob_run = 0; + int blob_run = 0; + for (int i = word_start; i < word_end; ++i) { + if (IsRecoding() && !recoder_.IsValidFirstCode(labels[i])) continue; + if (labels[i] != null_char_) { + // Make a fake blob. + TBOX box(label_coords[i], 0, label_coords[i + 1], line_box.height()); + box.scale(scale_factor); + box.move(ICOORD(line_box.left(), line_box.bottom())); + box.set_top(line_box.top()); + b_it.add_after_then_move(C_BLOB::FakeBlob(box)); + ++num_blobs; + ++blob_run; + } + if (labels[i] == null_char_ || i + 1 == word_end) { + if (blob_run > max_blob_run) + max_blob_run = blob_run; + } + } + if (!use_alternates) max_blob_run = 1; + ASSERT_HOST(label_coords.size() >= word_end); + // Make a fake word from the blobs. + WERD* word = new WERD(&blobs, word_start > 1 ? 1 : 0, NULL); + // Make a WERD_RES from the word. + WERD_RES* word_res = new WERD_RES(word); + word_res->uch_set = + target_unicharset != NULL ? target_unicharset : &GetUnicharset(); + word_res->combination = true; // Give it ownership of the word. + word_res->space_certainty = space_certainty; + word_res->ratings = new MATRIX(num_blobs, max_blob_run); + return word_res; +} + +// Converts an array of labels to utf-8, whether or not the labels are +// augmented with character boundaries. +STRING LSTMRecognizer::DecodeLabels(const GenericVector& labels) { + STRING result; + int end = 1; + for (int start = 0; start < labels.size(); start = end) { + if (labels[start] == null_char_) { + end = start + 1; + } else { + result += DecodeLabel(labels, start, &end, NULL); + } + } + return result; +} + +// Displays the forward results in a window with the characters and +// boundaries as determined by the labels and label_coords. 
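+// Worked example tying the two together (hypothetical labels): with labels
+// [0, c, a, t] (0 = null_char_), DecodeLabels above returns "cat", and the
+// display below draws a red cut line at the null and each character in
+// green at the midpoint of its x-range.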
+void LSTMRecognizer::DisplayForward(const NetworkIO& inputs, + const GenericVector& labels, + const GenericVector& label_coords, + const char* window_name, + ScrollView** window) { +#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics + int x_scale = network_->XScaleFactor(); + Pix* input_pix = inputs.ToPix(); + Network::ClearWindow(false, window_name, pixGetWidth(input_pix), + pixGetHeight(input_pix), window); + int line_height = Network::DisplayImage(input_pix, *window); + DisplayLSTMOutput(labels, label_coords, line_height, *window); +#endif // GRAPHICS_DISABLED +} + +// Displays the labels and cuts at the corresponding xcoords. +// Size of labels should match xcoords. +void LSTMRecognizer::DisplayLSTMOutput(const GenericVector& labels, + const GenericVector& xcoords, + int height, ScrollView* window) { +#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics + int x_scale = network_->XScaleFactor(); + window->TextAttributes("Arial", height / 4, false, false, false); + int end = 1; + for (int start = 0; start < labels.size(); start = end) { + int xpos = xcoords[start] * x_scale; + if (labels[start] == null_char_) { + end = start + 1; + window->Pen(ScrollView::RED); + } else { + window->Pen(ScrollView::GREEN); + const char* str = DecodeLabel(labels, start, &end, NULL); + if (*str == '\\') str = "\\\\"; + xpos = xcoords[(start + end) / 2] * x_scale; + window->Text(xpos, height, str); + } + window->Line(xpos, 0, xpos, height * 3 / 2); + } + window->Update(); +#endif // GRAPHICS_DISABLED +} + +// Prints debug output detailing the activation path that is implied by the +// label_coords. +void LSTMRecognizer::DebugActivationPath(const NetworkIO& outputs, + const GenericVector& labels, + const GenericVector& xcoords) { + if (xcoords[0] > 0) + DebugActivationRange(outputs, "", null_char_, 0, xcoords[0]); + int end = 1; + for (int start = 0; start < labels.size(); start = end) { + if (labels[start] == null_char_) { + end = start + 1; + DebugActivationRange(outputs, "", null_char_, xcoords[start], + xcoords[end]); + continue; + } else { + int decoded; + const char* label = DecodeLabel(labels, start, &end, &decoded); + DebugActivationRange(outputs, label, labels[start], xcoords[start], + xcoords[start + 1]); + for (int i = start + 1; i < end; ++i) { + DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], + xcoords[i], xcoords[i + 1]); + } + } + } +} + +// Prints debug output detailing activations and 2nd choice over a range +// of positions. +void LSTMRecognizer::DebugActivationRange(const NetworkIO& outputs, + const char* label, int best_choice, + int x_start, int x_end) { + tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end); + double max_score = 0.0; + double mean_score = 0.0; + int width = x_end - x_start; + for (int x = x_start; x < x_end; ++x) { + const float* line = outputs.f(x); + double score = line[best_choice] * 100.0; + if (score > max_score) max_score = score; + mean_score += score / width; + int best_c = 0; + double best_score = 0.0; + for (int c = 0; c < outputs.NumFeatures(); ++c) { + if (c != best_choice && line[c] > best_score) { + best_c = c; + best_score = line[c]; + } + } + tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, + best_score * 100.0); + } + tprintf(", Mean=%g, max=%g\n", mean_score, max_score); +} + +// Helper returns true if the null_char is the winner at t, and it beats the +// null_threshold, or the next choice is space, in which case we will use the +// null anyway. 
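+// Example with hypothetical scores and null_thr == 0.6: a timestep with
+// f(t)[null_char] == 0.7 is null outright; one with f(t)[null_char] == 0.4
+// whose best other label is UNICHAR_SPACE at 0.3 is still null, because the
+// null outscores the space.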
+static bool NullIsBest(const NetworkIO& output, float null_thr,
+                       int null_char, int t) {
+  if (output.f(t)[null_char] >= null_thr) return true;
+  if (output.BestLabel(t, null_char, null_char, NULL) != UNICHAR_SPACE)
+    return false;
+  return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
+}
+
+// Converts the network output to a sequence of labels. Outputs labels, scores
+// and start xcoords of each char, and each null_char_, with an additional
+// final xcoord for the end of the output.
+// The conversion method is determined by internal state.
+void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs,
+                                       float null_thr,
+                                       GenericVector<int>* labels,
+                                       GenericVector<int>* xcoords) {
+  if (SimpleTextOutput()) {
+    LabelsViaSimpleText(outputs, labels, xcoords);
+  } else if (IsRecoding()) {
+    LabelsViaReEncode(outputs, labels, xcoords);
+  } else if (null_thr <= 0.0) {
+    LabelsViaCTC(outputs, labels, xcoords);
+  } else {
+    LabelsViaThreshold(outputs, null_thr, labels, xcoords);
+  }
+}
+
+// Converts the network output to a sequence of labels, using a threshold
+// on the null_char_ to determine character boundaries. Outputs labels, scores
+// and start xcoords of each char, and each null_char_, with an additional
+// final xcoord for the end of the output.
+// The label output is the one with the highest score in the interval between
+// null_chars_.
+void LSTMRecognizer::LabelsViaThreshold(const NetworkIO& output,
+                                        float null_thr,
+                                        GenericVector<int>* labels,
+                                        GenericVector<int>* xcoords) {
+  labels->truncate(0);
+  xcoords->truncate(0);
+  int width = output.Width();
+  int t = 0;
+  // Skip any initial non-char.
+  while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
+    ++t;
+  }
+  while (t < width) {
+    ASSERT_HOST(!isnan(output.f(t)[null_char_]));
+    int label = output.BestLabel(t, null_char_, null_char_, NULL);
+    int char_start = t++;
+    while (t < width && !NullIsBest(output, null_thr, null_char_, t) &&
+           label == output.BestLabel(t, null_char_, null_char_, NULL)) {
+      ++t;
+    }
+    int char_end = t;
+    labels->push_back(label);
+    xcoords->push_back(char_start);
+    // Find the end of the non-char, and compute its score.
+    while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
+      ++t;
+    }
+    if (t > char_end) {
+      labels->push_back(null_char_);
+      xcoords->push_back(char_end);
+    }
+  }
+  xcoords->push_back(width);
+}
+
+// Converts the network output to a sequence of labels, with scores and
+// start x-coords of the character labels. Retains the null_char_ as the
+// end x-coord, where already present, otherwise the start of the next
+// character is the end.
+// The number of labels, scores, and xcoords is always matched, except that
+// there is always an additional xcoord for the last end position.
+void LSTMRecognizer::LabelsViaCTC(const NetworkIO& output,
+                                  GenericVector<int>* labels,
+                                  GenericVector<int>* xcoords) {
+  labels->truncate(0);
+  xcoords->truncate(0);
+  int width = output.Width();
+  int t = 0;
+  while (t < width) {
+    float score = 0.0f;
+    int label = output.BestLabel(t, &score);
+    labels->push_back(label);
+    xcoords->push_back(t);
+    while (++t < width && output.BestLabel(t, NULL) == label) {
+    }
+  }
+  xcoords->push_back(width);
+}
+
+// As LabelsViaCTC except that this function constructs the best path that
+// contains only legal sequences of subcodes for CJK.
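+// (A sketch of why this matters: under the recoder_, one unichar may be a
+// sequence of up to RecodedCharID::kMaxCodeLen subcodes, so taking the raw
+// best label per timestep could produce an undecodable subcode sequence;
+// the beam search is assumed here to extend only paths that remain valid
+// prefixes of some unichar's code.)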
+void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output, + GenericVector* labels, + GenericVector* xcoords) { + if (search_ == NULL) { + search_ = + new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); + } + search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, NULL); + search_->ExtractBestPathAsLabels(labels, xcoords); +} + +// Converts the network output to a sequence of labels, with scores, using +// the simple character model (each position is a char, and the null_char_ is +// mainly intended for tail padding.) +void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output, + GenericVector* labels, + GenericVector* xcoords) { + labels->truncate(0); + xcoords->truncate(0); + int width = output.Width(); + for (int t = 0; t < width; ++t) { + float score = 0.0f; + int label = output.BestLabel(t, &score); + if (label != null_char_) { + labels->push_back(label); + xcoords->push_back(t); + } + } + xcoords->push_back(width); +} + +// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range. +// Handles either LSTM labels or direct unichar-ids. +// Score ratio determines the worst ratio between top choice and remainder. +// If target_unicharset is not NULL, attempts to translate to the target +// unicharset, returning NULL on failure. +BLOB_CHOICE_LIST* LSTMRecognizer::GetBlobChoices( + int col, int row, bool debug, const NetworkIO& output, + const UNICHARSET* target_unicharset, int x_start, int x_end, + float score_ratio) { + int width = x_end - x_start; + float rating = 0.0f, certainty = 0.0f; + int label = output.BestChoiceOverRange(x_start, x_end, UNICHAR_SPACE, + null_char_, &rating, &certainty); + int unichar_id = label == null_char_ ? UNICHAR_SPACE : label; + if (debug) { + tprintf("Best choice over range %d,%d=unichar%d=%s r = %g, cert=%g\n", + x_start, x_end, unichar_id, DecodeSingleLabel(label), rating, + certainty); + } + BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST; + BLOB_CHOICE_IT bc_it(choices); + if (!AddBlobChoices(unichar_id, rating, certainty, col, row, + target_unicharset, &bc_it)) { + delete choices; + return NULL; + } + // Get the other choices. + double best_cert = certainty; + for (int c = 0; c < output.NumFeatures(); ++c) { + if (c == label || c == UNICHAR_SPACE || c == null_char_) continue; + // Compute the score over the range. + output.ScoresOverRange(x_start, x_end, c, null_char_, &rating, &certainty); + int unichar_id = c == null_char_ ? UNICHAR_SPACE : c; + if (certainty >= best_cert - score_ratio && + !AddBlobChoices(unichar_id, rating, certainty, col, row, + target_unicharset, &bc_it)) { + delete choices; + return NULL; + } + } + choices->sort(&BLOB_CHOICE::SortByRating); + if (bc_it.length() > kMaxChoices) { + bc_it.move_to_first(); + for (int i = 0; i < kMaxChoices; ++i) + bc_it.forward(); + while (!bc_it.at_first()) { + delete bc_it.extract(); + bc_it.forward(); + } + } + return choices; +} + +// Adds to the given iterator, the blob choices for the target_unicharset +// that correspond to the given LSTM unichar_id. +// Returns false if unicharset translation failed. 
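+// Example of the score_ratio cut in GetBlobChoices above (hypothetical
+// certainties): with best_cert == -0.2 and score_ratio == 1.5, an
+// alternative label is kept only if its certainty over the same x-range is
+// at least -1.7; survivors are sorted by rating and capped at kMaxChoices.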
+bool LSTMRecognizer::AddBlobChoices(int unichar_id, float rating, + float certainty, int col, int row, + const UNICHARSET* target_unicharset, + BLOB_CHOICE_IT* bc_it) { + int target_id = unichar_id; + if (target_unicharset != NULL) { + const char* utf8 = GetUnicharset().id_to_unichar(unichar_id); + if (target_unicharset->contains_unichar(utf8)) { + target_id = target_unicharset->unichar_to_id(utf8); + } else { + return false; + } + } + BLOB_CHOICE* choice = new BLOB_CHOICE(target_id, rating, certainty, -1, 1.0f, + static_cast(MAX_INT16), 0.0f, + BCC_STATIC_CLASSIFIER); + choice->set_matrix_cell(col, row); + bc_it->add_after_then_move(choice); + return true; +} + +// Returns a string corresponding to the label starting at start. Sets *end +// to the next start and if non-null, *decoded to the unichar id. +const char* LSTMRecognizer::DecodeLabel(const GenericVector& labels, + int start, int* end, int* decoded) { + *end = start + 1; + if (IsRecoding()) { + // Decode labels via recoder_. + RecodedCharID code; + if (labels[start] == null_char_) { + if (decoded != NULL) { + code.Set(0, null_char_); + *decoded = recoder_.DecodeUnichar(code); + } + return ""; + } + int index = start; + while (index < labels.size() && + code.length() < RecodedCharID::kMaxCodeLen) { + code.Set(code.length(), labels[index++]); + while (index < labels.size() && labels[index] == null_char_) ++index; + int uni_id = recoder_.DecodeUnichar(code); + // If the next label isn't a valid first code, then we need to continue + // extending even if we have a valid uni_id from this prefix. + if (uni_id != INVALID_UNICHAR_ID && + (index == labels.size() || + code.length() == RecodedCharID::kMaxCodeLen || + recoder_.IsValidFirstCode(labels[index]))) { + *end = index; + if (decoded != NULL) *decoded = uni_id; + if (uni_id == UNICHAR_SPACE) return " "; + return GetUnicharset().get_normed_unichar(uni_id); + } + } + return ""; + } else { + if (decoded != NULL) *decoded = labels[start]; + if (labels[start] == null_char_) return ""; + if (labels[start] == UNICHAR_SPACE) return " "; + return GetUnicharset().get_normed_unichar(labels[start]); + } +} + +// Returns a string corresponding to a given single label id, falling back to +// a default of ".." for part of a multi-label unichar-id. +const char* LSTMRecognizer::DecodeSingleLabel(int label) { + if (label == null_char_) return ""; + if (IsRecoding()) { + // Decode label via recoder_. + RecodedCharID code; + code.Set(0, label); + label = recoder_.DecodeUnichar(code); + if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code. + } + if (label == UNICHAR_SPACE) return " "; + return GetUnicharset().get_normed_unichar(label); +} + +} // namespace tesseract. diff --git a/lstm/lstmrecognizer.h b/lstm/lstmrecognizer.h new file mode 100644 index 0000000000..d99439dbdb --- /dev/null +++ b/lstm/lstmrecognizer.h @@ -0,0 +1,392 @@ +/////////////////////////////////////////////////////////////////////// +// File: lstmrecognizer.h +// Description: Top-level line recognizer class for LSTM-based networks. +// Author: Ray Smith +// Created: Thu May 02 08:57:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_ +#define TESSERACT_LSTM_LSTMRECOGNIZER_H_ + +#include "ccutil.h" +#include "helpers.h" +#include "imagedata.h" +#include "matrix.h" +#include "network.h" +#include "networkscratch.h" +#include "recodebeam.h" +#include "series.h" +#include "strngs.h" +#include "unicharcompress.h" + +class BLOB_CHOICE_IT; +struct Pix; +class ROW_RES; +class ScrollView; +class TBOX; +class WERD_RES; + +namespace tesseract { + +class Dict; +class ImageData; + +// Enum indicating training mode control flags. +enum TrainingFlags { + TF_INT_MODE = 1, + TF_AUTO_HARDEN = 2, + TF_ROUND_ROBIN_TRAINING = 16, + TF_COMPRESS_UNICHARSET = 64, +}; + +// Top-level line recognizer class for LSTM-based networks. +// Note that a sub-class, LSTMTrainer is used for training. +class LSTMRecognizer { + public: + LSTMRecognizer(); + ~LSTMRecognizer(); + + int NumOutputs() const { + return network_->NumOutputs(); + } + int training_iteration() const { + return training_iteration_; + } + int sample_iteration() const { + return sample_iteration_; + } + double learning_rate() const { + return learning_rate_; + } + bool IsHardening() const { + return (training_flags_ & TF_AUTO_HARDEN) != 0; + } + LossType OutputLossType() const { + if (network_ == nullptr) return LT_NONE; + StaticShape shape; + shape = network_->OutputShape(shape); + return shape.loss_type(); + } + bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; } + bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; } + // True if recoder_ is active to re-encode text to a smaller space. + bool IsRecoding() const { + return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0; + } + // Returns the cache strategy for the DocumentCache. + CachingStrategy CacheStrategy() const { + return training_flags_ & TF_ROUND_ROBIN_TRAINING ? CS_ROUND_ROBIN + : CS_SEQUENTIAL; + } + // Returns true if the network is a TensorFlow network. + bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; } + // Returns a vector of layer ids that can be passed to other layer functions + // to access a specific layer. + GenericVector EnumerateLayers() const { + ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); + Series* series = reinterpret_cast(network_); + GenericVector layers; + series->EnumerateLayers(NULL, &layers); + return layers; + } + // Returns a specific layer from its id (from EnumerateLayers). + Network* GetLayer(const STRING& id) const { + ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); + ASSERT_HOST(id.length() > 1 && id[0] == ':'); + Series* series = reinterpret_cast(network_); + return series->GetLayer(&id[1]); + } + // Returns the learning rate of the layer from its id. 
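+  // Usage sketch for the layer APIs above and below (assuming an NT_SERIES
+  // network with NF_LAYER_SPECIFIC_LR set):
+  //   GenericVector<STRING> layers = recognizer.EnumerateLayers();
+  //   for (int i = 0; i < layers.size(); ++i)
+  //     if (recognizer.GetLayerLearningRate(layers[i]) > 1e-4f)
+  //       recognizer.ScaleLayerLearningRate(layers[i], 0.5);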
+ float GetLayerLearningRate(const STRING& id) const { + ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); + if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { + ASSERT_HOST(id.length() > 1 && id[0] == ':'); + Series* series = reinterpret_cast(network_); + return series->LayerLearningRate(&id[1]); + } else { + return learning_rate_; + } + } + // Multiplies the all the learning rate(s) by the given factor. + void ScaleLearningRate(double factor) { + ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); + learning_rate_ *= factor; + if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { + GenericVector layers = EnumerateLayers(); + for (int i = 0; i < layers.size(); ++i) { + ScaleLayerLearningRate(layers[i], factor); + } + } + } + // Multiplies the learning rate of the layer with id, by the given factor. + void ScaleLayerLearningRate(const STRING& id, double factor) { + ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); + ASSERT_HOST(id.length() > 1 && id[0] == ':'); + Series* series = reinterpret_cast(network_); + series->ScaleLayerLearningRate(&id[1], factor); + } + + // True if the network is using adagrad to train. + bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); } + // Provides access to the UNICHARSET that this classifier works with. + const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; } + // Sets the sample iteration to the given value. The sample_iteration_ + // determines the seed for the random number generator. The training + // iteration is incremented only by a successful training iteration. + void SetIteration(int iteration) { + sample_iteration_ = iteration; + } + // Accessors for textline image normalization. + int NumInputs() const { + return network_->NumInputs(); + } + int null_char() const { return null_char_; } + + // Writes to the given file. Returns false in case of error. + bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, TFile* fp); + // Loads the dictionary if possible from the traineddata file. + // Prints a warning message, and returns false but otherwise fails silently + // and continues to work without it if loading fails. + // Note that dictionary load is independent from DeSerialize, but dependent + // on the unicharset matching. This enables training to deserialize a model + // from checkpoint or restore without having to go back and reload the + // dictionary. + bool LoadDictionary(const char* data_file_name, const char* lang); + + // Recognizes the line image, contained within image_data, returning the + // ratings matrix and matching box_word for each WERD_RES in the output. + // If invert, tries inverted as well if the normal interpretation doesn't + // produce a good enough result. If use_alternates, the ratings matrix is + // filled with segmentation and classifier alternatives that may be searched + // using the standard beam search, otherwise, just a diagonal and prebuilt + // best_choice. The line_box is used for computing the box_word in the + // output words. Score_ratio is used to determine the classifier alternates. + // If one_word, then a single WERD_RES is formed, regardless of the spaces + // found during recognition. + // If not NULL, we attempt to translate the output to target_unicharset, but + // do not guarantee success, due to mismatches. In that case the output words + // are marked with our UNICHARSET, not the caller's. 
+  // Recognizes the line image, contained within image_data, returning the
+  // ratings matrix and matching box_word for each WERD_RES in the output.
+  // If invert, tries inverted as well if the normal interpretation doesn't
+  // produce a good enough result. If use_alternates, the ratings matrix is
+  // filled with segmentation and classifier alternatives that may be searched
+  // using the standard beam search, otherwise, just a diagonal and prebuilt
+  // best_choice. The line_box is used for computing the box_word in the
+  // output words. Score_ratio is used to determine the classifier alternates.
+  // If one_word, then a single WERD_RES is formed, regardless of the spaces
+  // found during recognition.
+  // If not NULL, we attempt to translate the output to target_unicharset, but
+  // do not guarantee success, due to mismatches. In that case the output words
+  // are marked with our UNICHARSET, not the caller's.
+  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
+                     double worst_dict_cert, bool use_alternates,
+                     const UNICHARSET* target_unicharset, const TBOX& line_box,
+                     float score_ratio, bool one_word,
+                     PointerVector<WERD_RES>* words);
+  // Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
+  // corresponding to the network output in outputs, labels, label_coords.
+  // one_word generates a single word output, that may include spaces inside.
+  // use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
+  // with cut-offs determined by scale_factor.
+  // If not NULL, we attempt to translate the output to target_unicharset, but
+  // do not guarantee success, due to mismatches. In that case the output words
+  // are marked with our UNICHARSET, not the caller's.
+  void WordsFromOutputs(const NetworkIO& outputs,
+                        const GenericVector<int>& labels,
+                        const GenericVector<int>& label_coords,
+                        const TBOX& line_box, bool debug, bool use_alternates,
+                        bool one_word, float score_ratio, float scale_factor,
+                        const UNICHARSET* target_unicharset,
+                        PointerVector<WERD_RES>* words);
+
+  // Helper computes min and mean best results in the output.
+  void OutputStats(const NetworkIO& outputs,
+                   float* min_output, float* mean_output, float* sd);
+  // Recognizes the image_data, returning the labels,
+  // scores, and corresponding pairs of start, end x-coords in coords.
+  // If label_threshold is positive, uses it for making the labels, otherwise
+  // uses standard ctc. Returned in scale_factor is the reduction factor
+  // between the image and the output coords, for computing bounding boxes.
+  // If re_invert is true, the input is inverted back to its original
+  // photometric interpretation if inversion is attempted but fails to
+  // improve the results. This ensures that outputs contains the correct
+  // forward outputs for the best photometric interpretation.
+  // inputs is filled with the used inputs to the network, and if not null,
+  // target boxes is filled with scaled truth boxes if present in image_data.
+  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
+                     bool re_invert, float label_threshold, float* scale_factor,
+                     NetworkIO* inputs, NetworkIO* outputs);
+  // Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
+  // line_box should be the bounding box of the line image in the main image,
+  // outputs the output of the network,
+  // [word_start, word_end) the interval over which to convert,
+  // score_ratio for choosing alternate classifier choices,
+  // use_alternates to control generation of alternative segmentations,
+  // labels, label_coords, scale_factor from RecognizeLine above.
+  // If target_unicharset is not NULL, attempts to translate the internal
+  // unichar_ids to the target_unicharset, but falls back to untranslated ids
+  // if the translation fails.
+  WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
+                           int word_start, int word_end, float score_ratio,
+                           float space_certainty, bool debug,
+                           bool use_alternates,
+                           const UNICHARSET* target_unicharset,
+                           const GenericVector<int>& labels,
+                           const GenericVector<int>& label_coords,
+                           float scale_factor);
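[Editor's note: an annotation, not part of the patch. Reading the comments above together: labels and label_coords pair up as one start x-coordinate per label plus one trailing end coordinate, so label i spans [label_coords[i], label_coords[i + 1]) in network-output coordinates; scaling by the scale_factor returned from RecognizeLine should recover image coordinates for bounding boxes.]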
+  // Sets up a word with the ratings matrix and fake blobs with boxes in the
+  // right places.
+  WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
+                           float space_certainty, bool use_alternates,
+                           const UNICHARSET* target_unicharset,
+                           const GenericVector<int>& labels,
+                           const GenericVector<int>& label_coords,
+                           float scale_factor);
+
+  // Converts an array of labels to utf-8, whether or not the labels are
+  // augmented with character boundaries.
+  STRING DecodeLabels(const GenericVector<int>& labels);
+
+  // Displays the forward results in a window with the characters and
+  // boundaries as determined by the labels and label_coords.
+  void DisplayForward(const NetworkIO& inputs,
+                      const GenericVector<int>& labels,
+                      const GenericVector<int>& label_coords,
+                      const char* window_name,
+                      ScrollView** window);
+
+ protected:
+  // Sets the random seed from the sample_iteration_;
+  void SetRandomSeed() {
+    inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
+    randomizer_.set_seed(seed);
+    randomizer_.IntRand();
+  }
+
+  // Displays the labels and cuts at the corresponding xcoords.
+  // Size of labels should match xcoords.
+  void DisplayLSTMOutput(const GenericVector<int>& labels,
+                         const GenericVector<int>& xcoords,
+                         int height, ScrollView* window);
+
+  // Prints debug output detailing the activation path that is implied by the
+  // xcoords.
+  void DebugActivationPath(const NetworkIO& outputs,
+                           const GenericVector<int>& labels,
+                           const GenericVector<int>& xcoords);
+
+  // Prints debug output detailing activations and 2nd choice over a range
+  // of positions.
+  void DebugActivationRange(const NetworkIO& outputs, const char* label,
+                            int best_choice, int x_start, int x_end);
+
+  // Converts the network output to a sequence of labels. Outputs labels, scores
+  // and start xcoords of each char, and each null_char_, with an additional
+  // final xcoord for the end of the output.
+  // The conversion method is determined by internal state.
+  void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
+                         GenericVector<int>* labels,
+                         GenericVector<int>* xcoords);
+  // Converts the network output to a sequence of labels, using a threshold
+  // on the null_char_ to determine character boundaries. Outputs labels, scores
+  // and start xcoords of each char, and each null_char_, with an additional
+  // final xcoord for the end of the output.
+  // The label output is the one with the highest score in the interval between
+  // null_chars_.
+  void LabelsViaThreshold(const NetworkIO& output,
+                          float null_threshold,
+                          GenericVector<int>* labels,
+                          GenericVector<int>* xcoords);
+  // Converts the network output to a sequence of labels, with scores and
+  // start x-coords of the character labels. Retains the null_char_ character as
+  // the end x-coord, where already present, otherwise the start of the next
+  // character is the end.
+  // The number of labels, scores, and xcoords is always matched, except that
+  // there is always an additional xcoord for the last end position.
+  void LabelsViaCTC(const NetworkIO& output,
+                    GenericVector<int>* labels,
+                    GenericVector<int>* xcoords);
+  // As LabelsViaCTC except that this function constructs the best path that
+  // contains only legal sequences of subcodes for recoder_.
+  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
+                         GenericVector<int>* xcoords);
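[Editor's note: a sketch of how the label-extraction helpers above combine; an annotation, not part of the patch. The same pattern appears in the trainer's debug code later in this patch; here `outputs` is assumed to be a forward-pass NetworkIO, and 0.0f selects the non-threshold conversion path.]

    GenericVector<int> labels, xcoords;
    LabelsFromOutputs(outputs, 0.0f, &labels, &xcoords);
    STRING text = DecodeLabels(labels);
    tprintf("OCR: %s\n", text.string());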
+  // Converts the network output to a sequence of labels, with scores, using
+  // the simple character model (each position is a char, and the null_char_ is
+  // mainly intended for tail padding.)
+  void LabelsViaSimpleText(const NetworkIO& output,
+                           GenericVector<int>* labels,
+                           GenericVector<int>* xcoords);
+
+  // Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
+  // Handles either LSTM labels or direct unichar-ids.
+  // Score ratio determines the worst ratio between top choice and remainder.
+  // If target_unicharset is not NULL, attempts to translate to the target
+  // unicharset, returning NULL on failure.
+  BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
+                                   const NetworkIO& output,
+                                   const UNICHARSET* target_unicharset,
+                                   int x_start, int x_end, float score_ratio);
+
+  // Adds to the given iterator, the blob choices for the target_unicharset
+  // that correspond to the given LSTM unichar_id.
+  // Returns false if unicharset translation failed.
+  bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
+                      int row, const UNICHARSET* target_unicharset,
+                      BLOB_CHOICE_IT* bc_it);
+
+  // Returns a string corresponding to the label starting at start. Sets *end
+  // to the next start and if non-null, *decoded to the unichar id.
+  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
+                          int* decoded);
+
+  // Returns a string corresponding to a given single label id, falling back to
+  // a default of ".." for part of a multi-label unichar-id.
+  const char* DecodeSingleLabel(int label);
+
+ protected:
+  // The network hierarchy.
+  Network* network_;
+  // The unicharset. Only the unicharset element is serialized.
+  // Has to be a CCUtil, so Dict can point to it.
+  CCUtil ccutil_;
+  // For backward compatibility, recoder_ is serialized iff
+  // training_flags_ & TF_COMPRESS_UNICHARSET.
+  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
+  UnicharCompress recoder_;
+
+  // ==Training parameters that are serialized to provide a record of them.==
+  STRING network_str_;
+  // Flags used to determine the training method of the network.
+  // See enum TrainingFlags above.
+  inT32 training_flags_;
+  // Number of actual backward training steps used.
+  inT32 training_iteration_;
+  // Index into training sample set. sample_iteration >= training_iteration_.
+  inT32 sample_iteration_;
+  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
+  // ccutil_.unicharset.size().
+  inT32 null_char_;
+  // Range used for the initial random numbers in the weights.
+  float weight_range_;
+  // Learning rate and momentum multipliers of deltas in backprop.
+  float learning_rate_;
+  float momentum_;
+
+  // === NOT SERIALIZED.
+  TRand randomizer_;
+  NetworkScratch scratch_space_;
+  // Language model (optional) to use with the beam search.
+  Dict* dict_;
+  // Beam search held between uses to optimize memory allocation/use.
+  RecodeBeamSearch* search_;
+
+  // == Debugging parameters.==
+  // Recognition debug display window.
+  ScrollView* debug_win_;
+};
+
+}  // namespace tesseract.
+
+#endif  // TESSERACT_LSTM_LSTMRECOGNIZER_H_
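[Editor's note: an end-to-end sketch of the public recognition API declared above; an annotation, not part of the patch. It assumes an already-deserialized recognizer and an ImageData holding one text line; the box dimensions and all numeric arguments are illustrative only.]

    PointerVector<WERD_RES> words;
    TBOX line_box(0, 0, 1024, 48);  // Line bounds within the page image.
    recognizer.RecognizeLine(image_data, /*invert=*/true, /*debug=*/false,
                             /*worst_dict_cert=*/-2.5, /*use_alternates=*/true,
                             /*target_unicharset=*/NULL, line_box,
                             /*score_ratio=*/1.5, /*one_word=*/false, &words);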
diff --git a/lstm/lstmtrainer.cpp b/lstm/lstmtrainer.cpp
new file mode 100644
index 0000000000..009aa41363
--- /dev/null
+++ b/lstm/lstmtrainer.cpp
@@ -0,0 +1,1331 @@
+///////////////////////////////////////////////////////////////////////
+// File: lstmtrainer.cpp
+// Description: Top-level line trainer class for LSTM-based networks.
+// Author: Ray Smith
+// Created: Fri May 03 09:14:06 PST 2013
+//
+// (C) Copyright 2013, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#include "lstmtrainer.h"
+#include <string>
+
+#include "allheaders.h"
+#include "boxread.h"
+#include "ctc.h"
+#include "imagedata.h"
+#include "input.h"
+#include "networkbuilder.h"
+#include "ratngs.h"
+#include "recodebeam.h"
+#ifdef INCLUDE_TENSORFLOW
+#include "tfnetwork.h"
+#endif
+#include "tprintf.h"
+
+#include "callcpp.h"
+
+using std::string;
+
+namespace tesseract {
+
+// Min actual error rate increase to constitute divergence.
+const double kMinDivergenceRate = 50.0;
+// Min iterations since last best before acting on a stall.
+const int kMinStallIterations = 10000;
+// Fraction of current char error rate that sub_trainer_ has to be ahead
+// before we declare the sub_trainer_ a success and switch to it.
+const double kSubTrainerMarginFraction = 3.0 / 128;
+// Factor to reduce learning rate on divergence.
+const double kLearningRateDecay = sqrt(0.5);
+// LR adjustment iterations.
+const int kNumAdjustmentIterations = 100;
+// How often to add data to the error_graph_.
+const int kErrorGraphInterval = 1000;
+// Number of training images to train between calls to MaintainCheckpoints.
+const int kNumPagesPerBatch = 100;
+// Min percent error rate to consider start-up phase over.
+const int kMinStartedErrorRate = 75;
+// Error rate at which to transition to stage 1.
+const double kStageTransitionThreshold = 10.0;
+// How often to test for flipping.
+const int kFlipTestRate = 20;
+// Confidence beyond which the truth is more likely wrong than the recognizer.
+const double kHighConfidence = 0.9375;  // 15/16.
+// Fraction of weight sign-changing total to constitute a definite improvement.
+const double kImprovementFraction = 15.0 / 16.0;
+// Fraction of last written best to make it worth writing another.
+const double kBestCheckpointFraction = 31.0 / 32.0;
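[Editor's note: an annotation, not part of the patch. kLearningRateDecay is sqrt(0.5), so each divergence-triggered reduction multiplies the learning rate by ~0.707 and two successive reductions halve it exactly:]

    double lr = 1.0e-3;
    lr *= kLearningRateDecay;  // ~7.07e-4 after one divergence.
    lr *= kLearningRateDecay;  // Exactly 5.0e-4 after two.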
+// Scale factor for display of target activations of CTC.
+const int kTargetXScale = 5;
+const int kTargetYScale = 100;
+
+LSTMTrainer::LSTMTrainer()
+    : training_data_(0),
+      file_reader_(LoadDataFromFile),
+      file_writer_(SaveDataToFile),
+      checkpoint_reader_(
+          NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)),
+      checkpoint_writer_(
+          NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)),
+      sub_trainer_(NULL) {
+  EmptyConstructor();
+  debug_interval_ = 0;
+}
+
+LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer,
+                         CheckPointReader checkpoint_reader,
+                         CheckPointWriter checkpoint_writer,
+                         const char* model_base, const char* checkpoint_name,
+                         int debug_interval, inT64 max_memory)
+    : training_data_(max_memory),
+      file_reader_(file_reader),
+      file_writer_(file_writer),
+      checkpoint_reader_(checkpoint_reader),
+      checkpoint_writer_(checkpoint_writer),
+      sub_trainer_(NULL) {
+  EmptyConstructor();
+  if (file_reader_ == NULL) file_reader_ = LoadDataFromFile;
+  if (file_writer_ == NULL) file_writer_ = SaveDataToFile;
+  if (checkpoint_reader_ == NULL) {
+    checkpoint_reader_ =
+        NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump);
+  }
+  if (checkpoint_writer_ == NULL) {
+    checkpoint_writer_ =
+        NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump);
+  }
+  debug_interval_ = debug_interval;
+  model_base_ = model_base;
+  checkpoint_name_ = checkpoint_name;
+}
+
+LSTMTrainer::~LSTMTrainer() {
+  delete align_win_;
+  delete target_win_;
+  delete ctc_win_;
+  delete recon_win_;
+  delete checkpoint_reader_;
+  delete checkpoint_writer_;
+  delete sub_trainer_;
+}
+
+// Tries to deserialize a trainer from the given file and silently returns
+// false in case of failure.
+bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
+  GenericVector<char> data;
+  if (!(*file_reader_)(filename, &data)) return false;
+  tprintf("Loaded file %s, unpacking...\n", filename);
+  return checkpoint_reader_->Run(data, this);
+}
+
+// Initializes the character set encode/decode mechanism.
+// train_flags control training behavior according to the TrainingFlags
+// enum, including character set encoding.
+// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
+// fully initializes the unicharset from the universal unicharsets.
+// Note: Call before InitNetwork!
+void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
+                              const STRING& script_dir, int train_flags) {
+  // Call before InitNetwork.
+  ASSERT_HOST(network_ == NULL);
+  EmptyConstructor();
+  training_flags_ = train_flags;
+  ccutil_.unicharset.CopyFrom(unicharset);
+  null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
+                                                   : GetUnicharset().size();
+  SetUnicharsetProperties(script_dir);
+}
+
+// Initializes the character set encode/decode mechanism directly from a
+// previously setup UNICHARSET and UnicharCompress.
+// ctc_mode controls how the truth text is mapped to the network targets.
+// Note: Call before InitNetwork!
+void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
+                              const UnicharCompress recoder) {
+  // Call before InitNetwork.
+  ASSERT_HOST(network_ == NULL);
+  EmptyConstructor();
+  int flags = TF_COMPRESS_UNICHARSET;
+  training_flags_ = static_cast<inT32>(flags);
+  ccutil_.unicharset.CopyFrom(unicharset);
+  recoder_ = recoder;
+  null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
+                                                   : GetUnicharset().size();
+  RecodedCharID code;
+  recoder_.EncodeUnichar(null_char_, &code);
+  null_char_ = code(0);
+  // Space should encode as itself.
+ recoder_.EncodeUnichar(UNICHAR_SPACE, &code); + ASSERT_HOST(code(0) == UNICHAR_SPACE); +} + +// Initializes the trainer with a network_spec in the network description +// net_flags control network behavior according to the NetworkFlags enum. +// There isn't really much difference between them - only where the effects +// are implemented. +// For other args see NetworkBuilder::InitNetwork. +// Note: Be sure to call InitCharSet before InitNetwork! +bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index, + int net_flags, float weight_range, + float learning_rate, float momentum) { + // Call after InitCharSet. + ASSERT_HOST(GetUnicharset().size() > SPECIAL_UNICHAR_CODES_COUNT); + weight_range_ = weight_range; + learning_rate_ = learning_rate; + momentum_ = momentum; + int num_outputs = null_char_ == GetUnicharset().size() + ? null_char_ + 1 + : GetUnicharset().size(); + if (IsRecoding()) num_outputs = recoder_.code_range(); + if (!NetworkBuilder::InitNetwork(num_outputs, network_spec, append_index, + net_flags, weight_range, &randomizer_, + &network_)) { + return false; + } + network_str_ += network_spec; + tprintf("Built network:%s from request %s\n", + network_->spec().string(), network_spec.string()); + tprintf("Training parameters:\n Debug interval = %d," + " weights = %g, learning rate = %g, momentum=%g\n", + debug_interval_, weight_range_, learning_rate_, momentum_); + return true; +} + +// Initializes a trainer from a serialized TFNetworkModel proto. +// Returns the global step of TensorFlow graph or 0 if failed. +int LSTMTrainer::InitTensorFlowNetwork(const string& tf_proto) { +#ifdef INCLUDE_TENSORFLOW + delete network_; + TFNetwork* tf_net = new TFNetwork("TensorFlow"); + training_iteration_ = tf_net->InitFromProtoStr(tf_proto); + if (training_iteration_ == 0) { + tprintf("InitFromProtoStr failed!!\n"); + return 0; + } + network_ = tf_net; + ASSERT_HOST(recoder_.code_range() == tf_net->num_classes()); + return training_iteration_; +#else + tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n"); + return 0; +#endif +} + +// If the training sample is usable, grid searches for the optimal +// dict_ratio/cert_offset, and returns the results in a string of space- +// separated triplets of ratio,offset=worderr. +Trainability LSTMTrainer::GridSearchDictParams( + const ImageData* trainingdata, int iteration, double min_dict_ratio, + double dict_ratio_step, double max_dict_ratio, double min_cert_offset, + double cert_offset_step, double max_cert_offset, STRING* results) { + sample_iteration_ = iteration; + NetworkIO fwd_outputs, targets; + Trainability result = + PrepareForBackward(trainingdata, &fwd_outputs, &targets); + if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == NULL) + return result; + + // Encode/decode the truth to get the normalization. + GenericVector truth_labels, ocr_labels, xcoords; + ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels)); + // NO-dict error. 
+ RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), NULL); + base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, + NULL); + base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords); + STRING truth_text = DecodeLabels(truth_labels); + STRING ocr_text = DecodeLabels(ocr_labels); + double baseline_error = ComputeWordError(&truth_text, &ocr_text); + results->add_str_double("0,0=", baseline_error); + + RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_); + for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) { + for (double c = min_cert_offset; c < max_cert_offset; + c += cert_offset_step) { + search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, NULL); + search.ExtractBestPathAsLabels(&ocr_labels, &xcoords); + truth_text = DecodeLabels(truth_labels); + ocr_text = DecodeLabels(ocr_labels); + // This is destructive on both strings. + double word_error = ComputeWordError(&truth_text, &ocr_text); + if ((r == min_dict_ratio && c == min_cert_offset) || + !std::isfinite(word_error)) { + STRING t = DecodeLabels(truth_labels); + STRING o = DecodeLabels(ocr_labels); + tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c, + t.string(), o.string(), word_error, truth_labels[0]); + } + results->add_str_double(" ", r); + results->add_str_double(",", c); + results->add_str_double("=", word_error); + } + } + return result; +} + +// Provides output on the distribution of weight values. +void LSTMTrainer::DebugNetwork() { + network_->DebugWeights(); +} + +// Loads a set of lstmf files that were created using the lstm.train config to +// tesseract into memory ready for training. Returns false if nothing was +// loaded. +bool LSTMTrainer::LoadAllTrainingData(const GenericVector& filenames) { + training_data_.Clear(); + return training_data_.LoadDocuments(filenames, "eng", CacheStrategy(), + file_reader_); +} + +// Keeps track of best and locally worst char error_rate and launches tests +// using tester, when a new min or max is reached. +// Writes checkpoints at appropriate times and builds and returns a log message +// to indicate progress. Returns false if nothing interesting happened. +bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) { + PrepareLogMsg(log_msg); + double error_rate = CharError(); + int iteration = learning_iteration(); + if (iteration >= stall_iteration_ && + error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) && + best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) { + // It hasn't got any better in a long while, and is a margin worse than the + // best, so go back to the best model and try a different learning rate. + StartSubtrainer(log_msg); + } + SubTrainerResult sub_trainer_result = STR_NONE; + if (sub_trainer_ != NULL) { + sub_trainer_result = UpdateSubtrainer(log_msg); + if (sub_trainer_result == STR_REPLACED) { + // Reset the inputs, as we have overwritten *this. + error_rate = CharError(); + iteration = learning_iteration(); + PrepareLogMsg(log_msg); + } + } + bool result = true; // Something interesting happened. + GenericVector rec_model_data; + if (error_rate < best_error_rate_) { + SaveRecognitionDump(&rec_model_data); + log_msg->add_str_double(" New best char error = ", error_rate); + *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester); + // If sub_trainer_ is not NULL, either *this beat it to a new best, or it + // just overwrote *this. In either case, we have finished with it. 
+ delete sub_trainer_; + sub_trainer_ = NULL; + stall_iteration_ = learning_iteration() + kMinStallIterations; + if (TransitionTrainingStage(kStageTransitionThreshold)) { + log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage()); + } + checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_); + if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) { + STRING best_model_name = DumpFilename(); + if (!(*file_writer_)(best_trainer_, best_model_name)) { + *log_msg += " failed to write best model:"; + } else { + *log_msg += " wrote best model:"; + error_rate_of_last_saved_best_ = best_error_rate_; + } + *log_msg += best_model_name; + } + } else if (error_rate > worst_error_rate_) { + SaveRecognitionDump(&rec_model_data); + log_msg->add_str_double(" New worst char error = ", error_rate); + *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester); + if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate && + best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) { + // Error rate has ballooned. Go back to the best model. + *log_msg += "\nDivergence! "; + // Copy best_trainer_ before reading it, as it will get overwritten. + GenericVector revert_data(best_trainer_); + if (checkpoint_reader_->Run(revert_data, this)) { + LogIterations("Reverted to", log_msg); + ReduceLearningRates(this, log_msg); + } else { + LogIterations("Failed to Revert at", log_msg); + } + // If it fails again, we will wait twice as long before reverting again. + stall_iteration_ = iteration + 2 * (iteration - learning_iteration()); + // Re-save the best trainer with the new learning rates and stall + // iteration. + checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_); + } + } else { + // Something interesting happened only if the sub_trainer_ was trained. + result = sub_trainer_result != STR_NONE; + } + if (checkpoint_writer_ != NULL && file_writer_ != NULL && + checkpoint_name_.length() > 0) { + // Write a current checkpoint. + GenericVector checkpoint; + if (!checkpoint_writer_->Run(FULL, this, &checkpoint) || + !(*file_writer_)(checkpoint, checkpoint_name_)) { + *log_msg += " failed to write checkpoint."; + } else { + *log_msg += " wrote checkpoint."; + } + } + *log_msg += "\n"; + return result; +} + +// Builds a string containing a progress message with current error rates. +void LSTMTrainer::PrepareLogMsg(STRING* log_msg) const { + LogIterations("At", log_msg); + log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]); + log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]); + log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]); + log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]); + log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]); + *log_msg += "%, "; +} + +// Appends iteration learning_iteration()/training_iteration()/ +// sample_iteration() to the log_msg. +void LSTMTrainer::LogIterations(const char* intro_str, STRING* log_msg) const { + *log_msg += intro_str; + log_msg->add_str_int(" iteration ", learning_iteration()); + log_msg->add_str_int("/", training_iteration()); + log_msg->add_str_int("/", sample_iteration()); +} + +// Returns true and increments the training_stage_ if the error rate has just +// passed through the given threshold for the first time. 
+bool LSTMTrainer::TransitionTrainingStage(float error_threshold) { + if (best_error_rate_ < error_threshold && + training_stage_ + 1 < num_training_stages_) { + ++training_stage_; + return true; + } + return false; +} + +// Writes to the given file. Returns false in case of error. +bool LSTMTrainer::Serialize(TFile* fp) const { + if (!LSTMRecognizer::Serialize(fp)) return false; + if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) + return false; + if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) != + 1) + return false; + if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false; + if (fp->FWrite(&last_perfect_training_iteration_, + sizeof(last_perfect_training_iteration_), 1) != 1) + return false; + for (int i = 0; i < ET_COUNT; ++i) { + if (!error_buffers_[i].Serialize(fp)) return false; + } + if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false; + if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1) + return false; + uinT8 amount = serialize_amount_; + if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false; + if (amount == LIGHT) return true; // We are done. + if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1) + return false; + if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1) + return false; + if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1) + return false; + if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1) + return false; + if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1) + return false; + if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1) + return false; + if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1) + return false; + if (!best_model_data_.Serialize(fp)) return false; + if (!worst_model_data_.Serialize(fp)) return false; + if (amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp)) return false; + GenericVector sub_data; + if (sub_trainer_ != NULL && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data)) + return false; + if (!sub_data.Serialize(fp)) return false; + if (!best_error_history_.Serialize(fp)) return false; + if (!best_error_iterations_.Serialize(fp)) return false; + if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1) + return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool LSTMTrainer::DeSerialize(bool swap, TFile* fp) { + if (!LSTMRecognizer::DeSerialize(swap, fp)) return false; + if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) + return false; + if (fp->FRead(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) != + 1) + return false; + if (fp->FRead(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false; + if (fp->FRead(&last_perfect_training_iteration_, + sizeof(last_perfect_training_iteration_), 1) != 1) + return false; + for (int i = 0; i < ET_COUNT; ++i) { + if (!error_buffers_[i].DeSerialize(swap, fp)) return false; + } + if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false; + if (fp->FRead(&training_stage_, sizeof(training_stage_), 1) != 1) + return false; + uinT8 amount; + if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false; + if (amount == LIGHT) return true; // Don't read the rest. 
+ if (fp->FRead(&best_error_rate_, sizeof(best_error_rate_), 1) != 1) + return false; + if (fp->FRead(&best_error_rates_, sizeof(best_error_rates_), 1) != 1) + return false; + if (fp->FRead(&best_iteration_, sizeof(best_iteration_), 1) != 1) + return false; + if (fp->FRead(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1) + return false; + if (fp->FRead(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1) + return false; + if (fp->FRead(&worst_iteration_, sizeof(worst_iteration_), 1) != 1) + return false; + if (fp->FRead(&stall_iteration_, sizeof(stall_iteration_), 1) != 1) + return false; + if (!best_model_data_.DeSerialize(swap, fp)) return false; + if (!worst_model_data_.DeSerialize(swap, fp)) return false; + if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(swap, fp)) + return false; + GenericVector sub_data; + if (!sub_data.DeSerialize(swap, fp)) return false; + delete sub_trainer_; + if (sub_data.empty()) { + sub_trainer_ = NULL; + } else { + sub_trainer_ = new LSTMTrainer(); + if (!ReadTrainingDump(sub_data, sub_trainer_)) return false; + } + if (!best_error_history_.DeSerialize(swap, fp)) return false; + if (!best_error_iterations_.DeSerialize(swap, fp)) return false; + if (fp->FRead(&improvement_steps_, sizeof(improvement_steps_), 1) != 1) + return false; + return true; +} + +// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the +// learning rates (by scaling reduction, or layer specific, according to +// NF_LAYER_SPECIFIC_LR). +void LSTMTrainer::StartSubtrainer(STRING* log_msg) { + delete sub_trainer_; + sub_trainer_ = new LSTMTrainer(); + if (!checkpoint_reader_->Run(best_trainer_, sub_trainer_)) { + *log_msg += " Failed to revert to previous best for trial!"; + delete sub_trainer_; + sub_trainer_ = NULL; + } + log_msg->add_str_int(" Trial sub_trainer_ from iteration ", + sub_trainer_->training_iteration()); + // Reduce learning rate so it doesn't diverge this time. + sub_trainer_->ReduceLearningRates(this, log_msg); + // If it fails again, we will wait twice as long before reverting again. + int stall_offset = learning_iteration() - sub_trainer_->learning_iteration(); + stall_iteration_ = learning_iteration() + 2 * stall_offset; + sub_trainer_->stall_iteration_ = stall_iteration_; + // Re-save the best trainer with the new learning rates and stall iteration. + checkpoint_writer_->Run(NO_BEST_TRAINER, sub_trainer_, &best_trainer_); +} + +// While the sub_trainer_ is behind the current training iteration and its +// training error is at least kSubTrainerMarginFraction better than the +// current training error, trains the sub_trainer_, and returns STR_UPDATED if +// it did anything. If it catches up, and has a better error rate than the +// current best, as well as a margin over the current error rate, then the +// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is +// returned. STR_NONE is returned if the subtrainer wasn't good enough to +// receive any training iterations. +SubTrainerResult LSTMTrainer::UpdateSubtrainer(STRING* log_msg) { + double training_error = CharError(); + double sub_error = sub_trainer_->CharError(); + double sub_margin = (training_error - sub_error) / sub_error; + if (sub_margin >= kSubTrainerMarginFraction) { + log_msg->add_str_double(" sub_trainer=", sub_error); + log_msg->add_str_double(" margin=", 100.0 * sub_margin); + *log_msg += "\n"; + // Catch up to current iteration. 
+ int end_iteration = training_iteration(); + while (sub_trainer_->training_iteration() < end_iteration && + sub_margin >= kSubTrainerMarginFraction) { + int target_iteration = + sub_trainer_->training_iteration() + kNumPagesPerBatch; + while (sub_trainer_->training_iteration() < target_iteration) { + sub_trainer_->TrainOnLine(this, false); + } + STRING batch_log = "Sub:"; + sub_trainer_->PrepareLogMsg(&batch_log); + batch_log += "\n"; + tprintf("UpdateSubtrainer:%s", batch_log.string()); + *log_msg += batch_log; + sub_error = sub_trainer_->CharError(); + sub_margin = (training_error - sub_error) / sub_error; + } + if (sub_error < best_error_rate_ && + sub_margin >= kSubTrainerMarginFraction) { + // The sub_trainer_ has won the race to a new best. Switch to it. + GenericVector updated_trainer; + SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer); + ReadTrainingDump(updated_trainer, this); + log_msg->add_str_int(" Sub trainer wins at iteration ", + training_iteration()); + *log_msg += "\n"; + return STR_REPLACED; + } + return STR_UPDATED; + } + return STR_NONE; +} + +// Reduces network learning rates, either for everything, or for layers +// independently, according to NF_LAYER_SPECIFIC_LR. +void LSTMTrainer::ReduceLearningRates(LSTMTrainer* samples_trainer, + STRING* log_msg) { + if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { + int num_reduced = ReduceLayerLearningRates( + kLearningRateDecay, kNumAdjustmentIterations, samples_trainer); + log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced); + } else { + ScaleLearningRate(kLearningRateDecay); + log_msg->add_str_double("\nReduced learning rate to :", learning_rate_); + } + *log_msg += "\n"; +} + +// Considers reducing the learning rate independently for each layer down by +// factor(<1), or leaving it the same, by double-training the given number of +// samples and minimizing the amount of changing of sign of weight updates. +// Even if it looks like all weights should remain the same, an adjustment +// will be made to guarantee a different result when reverting to an old best. +// Returns the number of layer learning rates that were reduced. +int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, + LSTMTrainer* samples_trainer) { + enum WhichWay { + LR_DOWN, // Learning rate will go down by factor. + LR_SAME, // Learning rate will stay the same. + LR_COUNT // Size of arrays. + }; + // Epsilon is so small that it may as well be zero, but still positive. + const double kEpsilon = 1.0e-30; + GenericVector layers = EnumerateLayers(); + int num_layers = layers.size(); + GenericVector num_weights; + num_weights.init_to_size(num_layers, 0); + GenericVector bad_sums[LR_COUNT]; + GenericVector ok_sums[LR_COUNT]; + for (int i = 0; i < LR_COUNT; ++i) { + bad_sums[i].init_to_size(num_layers, 0.0); + ok_sums[i].init_to_size(num_layers, 0.0); + } + double momentum_factor = 1.0 / (1.0 - momentum_); + GenericVector orig_trainer; + SaveTrainingDump(LIGHT, this, &orig_trainer); + for (int i = 0; i < num_layers; ++i) { + Network* layer = GetLayer(layers[i]); + num_weights[i] = layer->training() ? layer->num_weights() : 0; + } + int iteration = sample_iteration(); + for (int s = 0; s < num_samples; ++s) { + // Which way will we modify the learning rate? + for (int ww = 0; ww < LR_COUNT; ++ww) { + // Transfer momentum to learning rate and adjust by the ww factor. + float ww_factor = momentum_factor; + if (ww == LR_DOWN) ww_factor *= factor; + // Make a copy of *this, so we can mess about without damaging anything. 
+      LSTMTrainer copy_trainer;
+      copy_trainer.ReadTrainingDump(orig_trainer, &copy_trainer);
+      // Clear the updates, doing nothing else.
+      copy_trainer.network_->Update(0.0, 0.0, 0);
+      // Adjust the learning rate in each layer.
+      for (int i = 0; i < num_layers; ++i) {
+        if (num_weights[i] == 0) continue;
+        copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
+      }
+      copy_trainer.SetIteration(iteration);
+      // Train on the sample, but keep the update in updates_ instead of
+      // applying to the weights.
+      const ImageData* trainingdata =
+          copy_trainer.TrainOnLine(samples_trainer, true);
+      if (trainingdata == NULL) continue;
+      // We'll now use this trainer again for each layer.
+      GenericVector<char> updated_trainer;
+      SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
+      for (int i = 0; i < num_layers; ++i) {
+        if (num_weights[i] == 0) continue;
+        LSTMTrainer layer_trainer;
+        layer_trainer.ReadTrainingDump(updated_trainer, &layer_trainer);
+        Network* layer = layer_trainer.GetLayer(layers[i]);
+        // Update the weights in just the layer, and also zero the updates
+        // matrix (to epsilon).
+        layer->Update(0.0, kEpsilon, 0);
+        // Train again on the same sample, again holding back the updates.
+        layer_trainer.TrainOnLine(trainingdata, true);
+        // Count the sign changes in the updates in layer vs in copy_trainer.
+        float before_bad = bad_sums[ww][i];
+        float before_ok = ok_sums[ww][i];
+        layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
+                                &ok_sums[ww][i], &bad_sums[ww][i]);
+        float bad_frac =
+            bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
+        if (bad_frac > 0.0f)
+          bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
+      }
+    }
+    ++iteration;
+  }
+  int num_lowered = 0;
+  for (int i = 0; i < num_layers; ++i) {
+    if (num_weights[i] == 0) continue;
+    Network* layer = GetLayer(layers[i]);
+    float lr = GetLayerLearningRate(layers[i]);
+    double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
+    double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
+    double frac_down = bad_sums[LR_DOWN][i] / total_down;
+    double frac_same = bad_sums[LR_SAME][i] / total_same;
+    tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
+            lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
+    if (frac_down < frac_same * kImprovementFraction) {
+      tprintf(" REDUCED\n");
+      ScaleLayerLearningRate(layers[i], factor);
+      ++num_lowered;
+    } else {
+      tprintf(" SAME\n");
+    }
+  }
+  if (num_lowered == 0) {
+    // Just lower everything to make sure.
+    for (int i = 0; i < num_layers; ++i) {
+      if (num_weights[i] > 0) {
+        ScaleLayerLearningRate(layers[i], factor);
+        ++num_lowered;
+      }
+    }
+  }
+  return num_lowered;
+}
+
+// Converts the string to integer class labels, with appropriate null_char_s
+// in between if not in SimpleTextOutput mode. Returns false on failure.
+/* static */
+bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
+                               const UnicharCompress* recoder, bool simple_text,
+                               int null_char, GenericVector<int>* labels) {
+  if (str.string() == NULL || str.length() <= 0) {
+    tprintf("Empty truth string!\n");
+    return false;
+  }
+  int err_index;
+  GenericVector<int> internal_labels;
+  labels->truncate(0);
+  if (!simple_text) labels->push_back(null_char);
+  if (unicharset.encode_string(str.string(), true, &internal_labels, NULL,
+                               &err_index)) {
+    bool success = true;
+    for (int i = 0; i < internal_labels.size(); ++i) {
+      if (recoder != NULL) {
+        // Re-encode labels via recoder.
+ RecodedCharID code; + int len = recoder->EncodeUnichar(internal_labels[i], &code); + if (len > 0) { + for (int j = 0; j < len; ++j) { + labels->push_back(code(j)); + if (!simple_text) labels->push_back(null_char); + } + } else { + success = false; + err_index = 0; + break; + } + } else { + labels->push_back(internal_labels[i]); + if (!simple_text) labels->push_back(null_char); + } + } + if (success) return true; + } + tprintf("Encoding of string failed! Failure bytes:"); + while (err_index < str.length()) { + tprintf(" %x", str[err_index++]); + } + tprintf("\n"); + return false; +} + +// Performs forward-backward on the given trainingdata. +// Returns a Trainability enum to indicate the suitability of the sample. +Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata, + bool batch) { + NetworkIO fwd_outputs, targets; + Trainability trainable = + PrepareForBackward(trainingdata, &fwd_outputs, &targets); + ++sample_iteration_; + if (trainable == UNENCODABLE || trainable == NOT_BOXED) { + return trainable; // Sample was unusable. + } + bool debug = debug_interval_ > 0 && + training_iteration() % debug_interval_ == 0; + // Run backprop on the output. + NetworkIO bp_deltas; + if (network_->training() && + (trainable != PERFECT || + training_iteration() > + last_perfect_training_iteration_ + perfect_delay_)) { + network_->Backward(debug, targets, &scratch_space_, &bp_deltas); + network_->Update(learning_rate_, batch ? -1.0f : momentum_, + training_iteration_ + 1); + } +#ifndef GRAPHICS_DISABLED + if (debug_interval_ == 1 && debug_win_ != NULL) { + delete debug_win_->AwaitEvent(SVET_CLICK); + } +#endif // GRAPHICS_DISABLED + // Roll the memory of past means. + RollErrorBuffers(); + return trainable; +} + +// Prepares the ground truth, runs forward, and prepares the targets. +// Returns a Trainability enum to indicate the suitability of the sample. +Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata, + NetworkIO* fwd_outputs, + NetworkIO* targets) { + if (trainingdata == NULL) { + tprintf("Null trainingdata.\n"); + return UNENCODABLE; + } + // Ensure repeatability of random elements even across checkpoints. 
+ bool debug = debug_interval_ > 0 && + training_iteration() % debug_interval_ == 0; + GenericVector truth_labels; + if (!EncodeString(trainingdata->transcription(), &truth_labels)) { + tprintf("Can't encode transcription: %s\n", + trainingdata->transcription().string()); + return UNENCODABLE; + } + int w = 0; + while (w < truth_labels.size() && + (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) + ++w; + if (w == truth_labels.size()) { + tprintf("Blank transcription: %s\n", + trainingdata->transcription().string()); + return UNENCODABLE; + } + float image_scale; + NetworkIO inputs; + bool invert = trainingdata->boxes().empty(); + if (!RecognizeLine(*trainingdata, invert, debug, invert, 0.0f, &image_scale, + &inputs, fwd_outputs)) { + tprintf("Image not trainable\n"); + return UNENCODABLE; + } + targets->Resize(*fwd_outputs, network_->NumOutputs()); + double text_error = 100.0; + LossType loss_type = OutputLossType(); + if (loss_type == LT_SOFTMAX) { + if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) { + tprintf("Compute simple targets failed!\n"); + return UNENCODABLE; + } + } else if (loss_type == LT_CTC) { + if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) { + tprintf("Compute CTC targets failed!\n"); + return UNENCODABLE; + } + } else { + tprintf("Logistic outputs not implemented yet!\n"); + return UNENCODABLE; + } + GenericVector ocr_labels; + GenericVector xcoords; + LabelsFromOutputs(*fwd_outputs, 0.0f, &ocr_labels, &xcoords); + // CTC does not produce correct target labels to begin with. + if (loss_type != LT_CTC) { + LabelsFromOutputs(*targets, 0.0f, &truth_labels, &xcoords); + } + if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels, + *targets)) { + tprintf("Input width was %d\n", inputs.Width()); + return UNENCODABLE; + } + STRING ocr_text = DecodeLabels(ocr_labels); + STRING truth_text = DecodeLabels(truth_labels); + targets->SubtractAllFromFloat(*fwd_outputs); + if (debug_interval_ != 0) { + tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(), + ocr_text.string()); + } + double char_error = ComputeCharError(truth_labels, ocr_labels); + double word_error = ComputeWordError(&truth_text, &ocr_text); + double delta_error = ComputeErrorRates(*targets, char_error, word_error); + if (debug_interval_ != 0) { + tprintf("File %s page %d %s:\n", trainingdata->imagefilename().string(), + trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : ""); + } + if (delta_error == 0.0) return PERFECT; + if (targets->AnySuspiciousTruth(kHighConfidence)) return HI_PRECISION_ERR; + return TRAINABLE; +} + +// Writes the trainer to memory, so that the current training state can be +// restored. +bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount, + const LSTMTrainer* trainer, + GenericVector* data) const { + TFile fp; + fp.OpenWrite(data); + trainer->serialize_amount_ = serialize_amount; + return trainer->Serialize(&fp); +} + +// Reads previously saved trainer from memory. +bool LSTMTrainer::ReadTrainingDump(const GenericVector& data, + LSTMTrainer* trainer) { + return trainer->ReadSizedTrainingDump(&data[0], data.size()); +} + +bool LSTMTrainer::ReadSizedTrainingDump(const char* data, int size) { + TFile fp; + fp.Open(data, size); + return DeSerialize(false, &fp); +} + +// Writes the recognizer to memory, so that it can be used for testing later. 
+void LSTMTrainer::SaveRecognitionDump(GenericVector* data) const { + TFile fp; + fp.OpenWrite(data); + network_->SetEnableTraining(false); + ASSERT_HOST(LSTMRecognizer::Serialize(&fp)); + network_->SetEnableTraining(true); +} + +// Reads and returns a previously saved recognizer from memory. +LSTMRecognizer* LSTMTrainer::ReadRecognitionDump( + const GenericVector& data) { + TFile fp; + fp.Open(&data[0], data.size()); + LSTMRecognizer* recognizer = new LSTMRecognizer; + ASSERT_HOST(recognizer->DeSerialize(false, &fp)); + return recognizer; +} + +// Returns a suitable filename for a training dump, based on the model_base_, +// the iteration and the error rates. +STRING LSTMTrainer::DumpFilename() const { + STRING filename; + filename.add_str_double(model_base_.string(), best_error_rate_); + filename.add_str_int("_", best_iteration_); + filename += ".lstm"; + return filename; +} + +// Fills the whole error buffer of the given type with the given value. +void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) { + for (int i = 0; i < kRollingBufferSize_; ++i) + error_buffers_[type][i] = new_error; + error_rates_[type] = 100.0 * new_error; +} + +// Factored sub-constructor sets up reasonable default values. +void LSTMTrainer::EmptyConstructor() { + align_win_ = NULL; + target_win_ = NULL; + ctc_win_ = NULL; + recon_win_ = NULL; + checkpoint_iteration_ = 0; + serialize_amount_ = FULL; + training_stage_ = 0; + num_training_stages_ = 2; + prev_sample_iteration_ = 0; + best_error_rate_ = 100.0; + best_iteration_ = 0; + worst_error_rate_ = 0.0; + worst_iteration_ = 0; + stall_iteration_ = kMinStallIterations; + learning_iteration_ = 0; + improvement_steps_ = kMinStallIterations; + perfect_delay_ = 0; + last_perfect_training_iteration_ = 0; + for (int i = 0; i < ET_COUNT; ++i) { + best_error_rates_[i] = 100.0; + worst_error_rates_[i] = 0.0; + error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0); + error_rates_[i] = 100.0; + } + sample_iteration_ = 0; + training_iteration_ = 0; + error_rate_of_last_saved_best_ = kMinStartedErrorRate; +} + +// Sets the unicharset properties using the given script_dir as a source of +// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets +// up the recoder_ to simplify the unicharset. +void LSTMTrainer::SetUnicharsetProperties(const STRING& script_dir) { + tprintf("Setting unichar properties\n"); + for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) { + if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0) + continue; + // Load the unicharset for the script if available. + STRING filename = script_dir + "/" + + GetUnicharset().get_script_from_script_id(s) + + ".unicharset"; + UNICHARSET script_set; + GenericVector data; + if ((*file_reader_)(filename, &data) && + script_set.load_from_inmemory_file(&data[0], data.size())) { + tprintf("Setting properties for script %s\n", + GetUnicharset().get_script_from_script_id(s)); + ccutil_.unicharset.SetPropertiesFromOther(script_set); + } + } + if (IsRecoding()) { + STRING filename = script_dir + "/radical-stroke.txt"; + GenericVector data; + if ((*file_reader_)(filename, &data)) { + data += '\0'; + STRING stroke_table = &data[0]; + if (recoder_.ComputeEncoding(GetUnicharset(), null_char_, + &stroke_table)) { + RecodedCharID code; + recoder_.EncodeUnichar(null_char_, &code); + null_char_ = code(0); + // Space should encode as itself. 
+        recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
+        ASSERT_HOST(code(0) == UNICHAR_SPACE);
+        return;
+      }
+    } else {
+      tprintf("Failed to load radical-stroke info from: %s\n",
+              filename.string());
+    }
+    training_flags_ &= ~TF_COMPRESS_UNICHARSET;
+  }
+}
+
+// Outputs the string and periodically displays the given network inputs
+// as an image in the given window, and the corresponding labels at the
+// corresponding x_starts.
+// Returns false if the truth string is empty.
+bool LSTMTrainer::DebugLSTMTraining(const NetworkIO& inputs,
+                                    const ImageData& trainingdata,
+                                    const NetworkIO& fwd_outputs,
+                                    const GenericVector<int>& truth_labels,
+                                    const NetworkIO& outputs) {
+  const STRING& truth_text = DecodeLabels(truth_labels);
+  if (truth_text.string() == NULL || truth_text.length() <= 0) {
+    tprintf("Empty truth string at decode time!\n");
+    return false;
+  }
+  if (debug_interval_ != 0) {
+    // Get class labels, xcoords and string.
+    GenericVector<int> labels;
+    GenericVector<int> xcoords;
+    LabelsFromOutputs(outputs, 0.0f, &labels, &xcoords);
+    STRING text = DecodeLabels(labels);
+    tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
+            training_iteration(), text.string());
+    if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
+      tprintf("TRAINING activation path for truth string %s\n",
+              truth_text.string());
+      DebugActivationPath(outputs, labels, xcoords);
+      DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
+      if (OutputLossType() == LT_CTC) {
+        DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
+        DisplayTargets(outputs, "CTC Targets", &target_win_);
+      }
+    }
+  }
+  return true;
+}
+
+// Displays the network targets as a line graph.
+void LSTMTrainer::DisplayTargets(const NetworkIO& targets,
+                                 const char* window_name, ScrollView** window) {
+#ifndef GRAPHICS_DISABLED  // do nothing if there's no graphics.
+  int width = targets.Width();
+  int num_features = targets.NumFeatures();
+  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
+                       window);
+  for (int c = 0; c < num_features; ++c) {
+    int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
+    (*window)->Pen(static_cast<ScrollView::Color>(color));
+    int start_t = -1;
+    for (int t = 0; t < width; ++t) {
+      double target = targets.f(t)[c];
+      target *= kTargetYScale;
+      if (target >= 1) {
+        if (start_t < 0) {
+          (*window)->SetCursor(t - 1, 0);
+          start_t = t;
+        }
+        (*window)->DrawTo(t, target);
+      } else if (start_t >= 0) {
+        (*window)->DrawTo(t, 0);
+        (*window)->DrawTo(start_t - 1, 0);
+        start_t = -1;
+      }
+    }
+    if (start_t >= 0) {
+      (*window)->DrawTo(width, 0);
+      (*window)->DrawTo(start_t - 1, 0);
+    }
+  }
+  (*window)->Update();
+#endif  // GRAPHICS_DISABLED
+}
+
+// Builds a no-compromises target where the first positions should be the
+// truth labels and the rest is padded with the null_char_.
+bool LSTMTrainer::ComputeTextTargets(const NetworkIO& outputs,
+                                     const GenericVector<int>& truth_labels,
+                                     NetworkIO* targets) {
+  if (truth_labels.size() > targets->Width()) {
+    tprintf("Error: transcription %s too long to fit into target of width %d\n",
+            DecodeLabels(truth_labels).string(), targets->Width());
+    return false;
+  }
+  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
+    targets->SetActivations(i, truth_labels[i], 1.0);
+  }
+  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
+    targets->SetActivations(i, null_char_, 1.0);
+  }
+  return true;
+}
+
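[Editor's note: a worked example of the layout ComputeTextTargets builds; an annotation, not part of the patch. With truth labels "a", "b" and a target width of 6, each timestep gets a single 1.0 activation:]

    t:      0  1  2  3  4  5
    target: a  b  .  .  .  .   (. = null_char_ padding)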
+// Builds a target using standard CTC. truth_labels should be pre-padded with
+// nulls wherever desired. They don't have to be between all labels.
+// outputs is input-output, as it gets clipped to minimum probability.
+bool LSTMTrainer::ComputeCTCTargets(const GenericVector<int>& truth_labels,
+                                    NetworkIO* outputs, NetworkIO* targets) {
+  // Bottom-clip outputs to a minimum probability.
+  CTC::NormalizeProbs(outputs);
+  return CTC::ComputeCTCTargets(truth_labels, null_char_,
+                                outputs->float_array(), targets);
+}
+
+// Computes network errors, and stores the results in the rolling buffers,
+// along with the supplied text_error.
+// Returns the delta error of the current sample (not running average.)
+double LSTMTrainer::ComputeErrorRates(const NetworkIO& deltas,
+                                      double char_error, double word_error) {
+  UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
+  // Delta error is the fraction of timesteps with >0.5 error in the top choice
+  // score. If zero, then the top choice characters are guaranteed correct,
+  // even when there is residue in the RMS error.
+  double delta_error = ComputeWinnerError(deltas);
+  UpdateErrorBuffer(delta_error, ET_DELTA);
+  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
+  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
+  // Skip ratio measures the difference between sample_iteration_ and
+  // training_iteration_, which reflects the number of unusable samples,
+  // usually due to unencodable truth text, or the text not fitting in the
+  // space for the output.
+  double skip_count = sample_iteration_ - prev_sample_iteration_;
+  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
+  return delta_error;
+}
+
+// Computes the network activation RMS error rate.
+double LSTMTrainer::ComputeRMSError(const NetworkIO& deltas) {
+  double total_error = 0.0;
+  int width = deltas.Width();
+  int num_classes = deltas.NumFeatures();
+  for (int t = 0; t < width; ++t) {
+    const float* class_errs = deltas.f(t);
+    for (int c = 0; c < num_classes; ++c) {
+      double error = class_errs[c];
+      total_error += error * error;
+    }
+  }
+  return sqrt(total_error / (width * num_classes));
+}
+
+// Computes network activation winner error rate. (Number of values that are
+// in error by >= 0.5 divided by number of time-steps.) More closely related
+// to final character error than RMS, but still directly calculable from
+// just the deltas. Because of the binary nature of the targets, zero winner
+// error is a sufficient but not necessary condition for zero char error.
+double LSTMTrainer::ComputeWinnerError(const NetworkIO& deltas) {
+  int num_errors = 0;
+  int width = deltas.Width();
+  int num_classes = deltas.NumFeatures();
+  for (int t = 0; t < width; ++t) {
+    const float* class_errs = deltas.f(t);
+    for (int c = 0; c < num_classes; ++c) {
+      float abs_delta = fabs(class_errs[c]);
+      // TODO(rays) Filtering cases where the delta is very large to cut out
+      // GT errors doesn't work. Find a better way or get better truth.
+      if (0.5 <= abs_delta)
+        ++num_errors;
+    }
+  }
+  return static_cast<double>(num_errors) / width;
+}
+
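[Editor's note: an annotation, not part of the patch. ComputeCharError below is a bag-of-chars measure, so it ignores ordering: for truth "abc" vs OCR "acb" every label count cancels, giving 0/3, while for truth "aab" vs OCR "ab" one 'a' is left uncancelled, giving 1/3.]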
+// Computes a very simple bag of chars char error rate.
+double LSTMTrainer::ComputeCharError(const GenericVector<int>& truth_str,
+                                     const GenericVector<int>& ocr_str) {
+  GenericVector<int> label_counts;
+  label_counts.init_to_size(NumOutputs(), 0);
+  int truth_size = 0;
+  for (int i = 0; i < truth_str.size(); ++i) {
+    if (truth_str[i] != null_char_) {
+      ++label_counts[truth_str[i]];
+      ++truth_size;
+    }
+  }
+  for (int i = 0; i < ocr_str.size(); ++i) {
+    if (ocr_str[i] != null_char_) {
+      --label_counts[ocr_str[i]];
+    }
+  }
+  int char_errors = 0;
+  for (int i = 0; i < label_counts.size(); ++i) {
+    char_errors += abs(label_counts[i]);
+  }
+  return static_cast<double>(char_errors) / truth_size;
+}
+
+// Computes a very simple bag of words word recall error rate.
+// NOTE that this is destructive on both input strings.
+double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) {
+  typedef TessHashMap<string, int, std::hash<string> > StrMap;
+  GenericVector<STRING> truth_words, ocr_words;
+  truth_str->split(' ', &truth_words);
+  if (truth_words.empty()) return 0.0;
+  ocr_str->split(' ', &ocr_words);
+  StrMap word_counts;
+  for (int i = 0; i < truth_words.size(); ++i) {
+    string truth_word(truth_words[i].string());
+    StrMap::iterator it = word_counts.find(truth_word);
+    if (it == word_counts.end())
+      word_counts.insert(make_pair(truth_word, 1));
+    else
+      ++it->second;
+  }
+  for (int i = 0; i < ocr_words.size(); ++i) {
+    string ocr_word(ocr_words[i].string());
+    StrMap::iterator it = word_counts.find(ocr_word);
+    if (it == word_counts.end())
+      word_counts.insert(make_pair(ocr_word, -1));
+    else
+      --it->second;
+  }
+  int word_recall_errs = 0;
+  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
+       ++it) {
+    if (it->second > 0) word_recall_errs += it->second;
+  }
+  return static_cast<double>(word_recall_errs) / truth_words.size();
+}
+
+// Updates the error buffer and corresponding mean of the given type with
+// the new_error.
+void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
+  int index = training_iteration_ % kRollingBufferSize_;
+  error_buffers_[type][index] = new_error;
+  // Compute the mean error.
+  int mean_count = MIN(training_iteration_ + 1, error_buffers_[type].size());
+  double buffer_sum = 0.0;
+  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
+  double mean = buffer_sum / mean_count;
+  // Trim precision to 1/1000 of 1%.
+  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
+}
+
+// Rolls error buffers and reports the current means.
+void LSTMTrainer::RollErrorBuffers() {
+  prev_sample_iteration_ = sample_iteration_;
+  if (NewSingleError(ET_DELTA) > 0.0)
+    ++learning_iteration_;
+  else
+    last_perfect_training_iteration_ = training_iteration_;
+  ++training_iteration_;
+  if (debug_interval_ != 0) {
+    tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
+            error_rates_[ET_RMS], error_rates_[ET_DELTA],
+            error_rates_[ET_CHAR_ERROR], error_rates_[ET_WORD_RECERR],
+            error_rates_[ET_SKIP_RATIO]);
+  }
+}
+
+// Given that error_rate is either a new min or max, updates the best/worst
+// error rates, and record of progress.
+// Tester is an externally supplied callback function that tests on some
+// data set with a given model and records the error rates in a graph.
+STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
+                                     const GenericVector<char>& model_data,
+                                     TestCallback tester) {
+  if (error_rate > best_error_rate_
+      && iteration < best_iteration_ + kErrorGraphInterval) {
+    // Too soon to record a new point.
+ if (tester != NULL) + return tester->Run(worst_iteration_, NULL, worst_model_data_, + CurrentTrainingStage()); + else + return ""; + } + STRING result; + // NOTE: there are 2 asymmetries here: + // 1. We are computing the global minimum, but the local maximum in between. + // 2. If the tester returns an empty string, indicating that it is busy, + // call it repeatedly on new local maxima to test the previous min, but + // not the other way around, as there is little point testing the maxima + // between very frequent minima. + if (error_rate < best_error_rate_) { + // This is a new (global) minimum. + if (tester != NULL) { + result = tester->Run(worst_iteration_, worst_error_rates_, + worst_model_data_, CurrentTrainingStage()); + worst_model_data_.truncate(0); + best_model_data_ = model_data; + } + best_error_rate_ = error_rate; + memcpy(best_error_rates_, error_rates_, sizeof(error_rates_)); + best_iteration_ = iteration; + best_error_history_.push_back(error_rate); + best_error_iterations_.push_back(iteration); + // Compute 2% decay time. + double two_percent_more = error_rate + 2.0; + int i; + for (i = best_error_history_.size() - 1; + i >= 0 && best_error_history_[i] < two_percent_more; --i) { + } + int old_iteration = i >= 0 ? best_error_iterations_[i] : 0; + improvement_steps_ = iteration - old_iteration; + tprintf("2 Percent improvement time=%d, best error was %g @ %d\n", + improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0, + old_iteration); + } else if (error_rate > best_error_rate_) { + // This is a new (local) maximum. + if (tester != NULL) { + if (best_model_data_.empty()) { + // Allow for multiple data points with "worst" error rate. + result = tester->Run(worst_iteration_, worst_error_rates_, + worst_model_data_, CurrentTrainingStage()); + } else { + result = tester->Run(best_iteration_, best_error_rates_, + best_model_data_, CurrentTrainingStage()); + } + if (result.length() > 0) + best_model_data_.truncate(0); + worst_model_data_ = model_data; + } + } + worst_error_rate_ = error_rate; + memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_)); + worst_iteration_ = iteration; + return result; +} + +} // namespace tesseract. diff --git a/lstm/lstmtrainer.h b/lstm/lstmtrainer.h new file mode 100644 index 0000000000..e6a7c43f2e --- /dev/null +++ b/lstm/lstmtrainer.h @@ -0,0 +1,477 @@ +/////////////////////////////////////////////////////////////////////// +// File: lstmtrainer.h +// Description: Top-level line trainer class for LSTM-based networks. +// Author: Ray Smith +// Created: Fri May 03 09:07:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_LSTMTRAINER_H_ +#define TESSERACT_LSTM_LSTMTRAINER_H_ + +#include "imagedata.h" +#include "lstmrecognizer.h" +#include "rect.h" +#include "tesscallback.h" + +namespace tesseract { + +class LSTM; +class LSTMTrainer; +class Parallel; +class Reversed; +class Softmax; +class Series; + +// Enum for the types of errors that are counted. +enum ErrorTypes { + ET_RMS, // RMS activation error. + ET_DELTA, // Number of big errors in deltas. + ET_WORD_RECERR, // Output text string word recall error. + ET_CHAR_ERROR, // Output text string total char error. + ET_SKIP_RATIO, // Fraction of samples skipped. + ET_COUNT // For array sizing. +}; + +// Enum for the trainability_ flags. +enum Trainability { + TRAINABLE, // Non-zero delta error. + PERFECT, // Zero delta error. + UNENCODABLE, // Not trainable due to coding/alignment trouble. + HI_PRECISION_ERR, // Hi confidence disagreement. + NOT_BOXED, // Early in training and has no character boxes. +}; + +// Enum to define the amount of data to get serialized. +enum SerializeAmount { + LIGHT, // Minimal data for remote training. + NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_. + FULL, // All data including best_trainer_. +}; + +// Enum to indicate how the sub_trainer_ training went. +enum SubTrainerResult { + STR_NONE, // Did nothing as not good enough. + STR_UPDATED, // Subtrainer was updated, but didn't replace *this. + STR_REPLACED // Subtrainer replaced *this. +}; + +class LSTMTrainer; +// Function to restore the trainer state from a given checkpoint. +// Returns false on failure. +typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>* + CheckPointReader; +// Function to save a checkpoint of the current trainer state. +// Returns false on failure. SerializeAmount determines the amount of the +// trainer to serialize, typically used for saving the best state. +typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*, GenericVector<char>*>* CheckPointWriter; +// Function to compute and record error rates on some external test set(s). +// Args are: iteration, mean errors, model, training stage. +// Returns a STRING containing logging information about the tests. +typedef TessResultCallback4<STRING, int, const double*, const GenericVector<char>&, int>* TestCallback; + +// Trainer class for LSTM networks. Most of the effort is in creating the +// ideal target outputs from the transcription. A box file is used if it is +// available, otherwise estimates of the char widths from the unicharset are +// used to guide a DP search for the best fit to the transcription. +class LSTMTrainer : public LSTMRecognizer { + public: + LSTMTrainer(); + // Callbacks may be null, in which case defaults are used. + LSTMTrainer(FileReader file_reader, FileWriter file_writer, + CheckPointReader checkpoint_reader, + CheckPointWriter checkpoint_writer, + const char* model_base, const char* checkpoint_name, + int debug_interval, inT64 max_memory); + virtual ~LSTMTrainer(); + + // Tries to deserialize a trainer from the given file and silently returns + // false in case of failure. + bool TryLoadingCheckpoint(const char* filename); + + // Initializes the character set encode/decode mechanism. + // train_flags control training behavior according to the TrainingFlags + // enum, including character set encoding. + // script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided, + // fully initializes the unicharset from the universal unicharsets. + // Note: Call before InitNetwork!
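A usage sketch for the InitCharSet/InitNetwork pair declared next. Every concrete value (paths, the VGSL-style spec string, flags, rates) is illustrative, not prescribed by the patch, and a loaded UNICHARSET unicharset and STRING script_dir are assumed:

LSTMTrainer trainer(nullptr, nullptr, nullptr, nullptr,
                    "output/base", "output/base_checkpoint",
                    /*debug_interval=*/0, /*max_memory=*/0);
// InitCharSet must come first, per the note above.
trainer.InitCharSet(unicharset, script_dir, TF_COMPRESS_UNICHARSET);
trainer.InitNetwork("[1,36,0,1 Ct5,5,16 Mp3,3 Lfx128 O1c105]",
                    /*append_index=*/-1, NF_LAYER_SPECIFIC_LR,
                    /*weight_range=*/0.1f, /*learning_rate=*/1e-4f,
                    /*momentum=*/0.9f);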
+ void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir, + int train_flags); + // Initializes the character set encode/decode mechanism directly from a + // previously setup UNICHARSET and UnicharCompress. + // ctc_mode controls how the truth text is mapped to the network targets. + // Note: Call before InitNetwork! + void InitCharSet(const UNICHARSET& unicharset, const UnicharCompress recoder); + + // Initializes the trainer with a network_spec in the network description + // language. net_flags control network behavior according to the NetworkFlags + // enum. There isn't really much difference between them - only where the + // effects are implemented. + // For other args see NetworkBuilder::InitNetwork. + // Note: Be sure to call InitCharSet before InitNetwork! + bool InitNetwork(const STRING& network_spec, int append_index, int net_flags, + float weight_range, float learning_rate, float momentum); + // Initializes a trainer from a serialized TFNetworkModel proto. + // Returns the global step of TensorFlow graph or 0 if failed. + // Building a compatible TF graph: See tfnetwork.proto. + int InitTensorFlowNetwork(const std::string& tf_proto); + + // Accessors. + double ActivationError() const { + return error_rates_[ET_DELTA]; + } + double CharError() const { return error_rates_[ET_CHAR_ERROR]; } + const double* error_rates() const { + return error_rates_; + } + double best_error_rate() const { + return best_error_rate_; + } + int best_iteration() const { + return best_iteration_; + } + int learning_iteration() const { return learning_iteration_; } + int improvement_steps() const { return improvement_steps_; } + void set_perfect_delay(int delay) { perfect_delay_ = delay; } + const GenericVector<char>& best_trainer() const { return best_trainer_; } + // Returns the error that was just calculated by PrepareForBackward. + double NewSingleError(ErrorTypes type) const { + return error_buffers_[type][training_iteration() % kRollingBufferSize_]; + } + // Returns the error that was just calculated by TrainOnLine. Since + // TrainOnLine rolls the error buffers, this is one further back than + // NewSingleError. + double LastSingleError(ErrorTypes type) const { + return error_buffers_[type] + [(training_iteration() + kRollingBufferSize_ - 1) % + kRollingBufferSize_]; + } + const DocumentCache& training_data() const { + return training_data_; + } + DocumentCache* mutable_training_data() { return &training_data_; } + + // If the training sample is usable, grid searches for the optimal + // dict_ratio/cert_offset, and returns the results in a string of space- + // separated triplets of ratio,offset=worderr. + Trainability GridSearchDictParams( + const ImageData* trainingdata, int iteration, double min_dict_ratio, + double dict_ratio_step, double max_dict_ratio, double min_cert_offset, + double cert_offset_step, double max_cert_offset, STRING* results); + + void SetSerializeMode(SerializeAmount serialize_amount) const { + serialize_amount_ = serialize_amount; + } + + // Provides output on the distribution of weight values. + void DebugNetwork(); + + // Loads a set of lstmf files, created by running tesseract with the + // lstm.train config, into memory ready for training. Returns false if + // nothing was loaded. + bool LoadAllTrainingData(const GenericVector<STRING>& filenames); + + // Keeps track of best and locally worst error rate, using internally computed + // values. See MaintainCheckpointsSpecific for more detail.
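MaintainCheckpoints (declared next) takes one of the TestCallback hooks defined earlier. A minimal sketch of a conforming tester; the function is hypothetical, and it assumes tesscallback.h provides a NewPermanentTessCallback overload for free functions of this arity:

// Logs each evaluation point; a real tester would run the dumped model on
// held-out data and report its error rates.
STRING LogEvalPoint(int iteration, const double* mean_errors,
                    const GenericVector<char>& model_data, int training_stage) {
  STRING msg;
  msg.add_str_int("Eval point at iteration ", iteration);
  msg.add_str_int(", stage ", training_stage);
  return msg;
}
// e.g.: trainer.MaintainCheckpoints(NewPermanentTessCallback(&LogEvalPoint),
//                                   &log_msg);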
+ bool MaintainCheckpoints(TestCallback tester, STRING* log_msg); + // Keeps track of best and locally worst error_rate (whatever it is) and + // launches tests using rec_model, when a new min or max is reached. + // Writes checkpoints using train_model at appropriate times and builds and + // returns a log message to indicate progress. Returns false if nothing + // interesting happened. + bool MaintainCheckpointsSpecific(int iteration, + const GenericVector<char>* train_model, + const GenericVector<char>* rec_model, + TestCallback tester, STRING* log_msg); + // Builds a string containing a progress message with current error rates. + void PrepareLogMsg(STRING* log_msg) const; + // Appends iteration learning_iteration()/training_iteration()/ + // sample_iteration() to the log_msg. + void LogIterations(const char* intro_str, STRING* log_msg) const; + + // TODO(rays) Add curriculum learning. + // Returns true and increments the training_stage_ if the error rate has just + // passed through the given threshold for the first time. + bool TransitionTrainingStage(float error_threshold); + // Returns the current training stage. + int CurrentTrainingStage() const { return training_stage_; } + + // Writes to the given file. Returns false in case of error. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the + // learning rates (by scaling reduction, or layer specific, according to + // NF_LAYER_SPECIFIC_LR). + void StartSubtrainer(STRING* log_msg); + // While the sub_trainer_ is behind the current training iteration and its + // training error is at least kSubTrainerMarginFraction better than the + // current training error, trains the sub_trainer_, and returns STR_UPDATED if + // it did anything. If it catches up, and has a better error rate than the + // current best, as well as a margin over the current error rate, then the + // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is + // returned. STR_NONE is returned if the subtrainer wasn't good enough to + // receive any training iterations. + SubTrainerResult UpdateSubtrainer(STRING* log_msg); + // Reduces network learning rates, either for everything, or for layers + // independently, according to NF_LAYER_SPECIFIC_LR. + void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg); + // Considers reducing the learning rate independently for each layer down by + // factor(<1), or leaving it the same, by double-training the given number of + // samples and minimizing the amount of changing of sign of weight updates. + // Even if it looks like all weights should remain the same, an adjustment + // will be made to guarantee a different result when reverting to an old best. + // Returns the number of layer learning rates that were reduced. + int ReduceLayerLearningRates(double factor, int num_samples, + LSTMTrainer* samples_trainer); + + // Converts the string to integer class labels, with appropriate null_char_s + // in between if not in SimpleTextOutput mode. Returns false on failure. + bool EncodeString(const STRING& str, GenericVector<int>* labels) const { + return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : NULL, + SimpleTextOutput(), null_char_, labels); + } + // Static version operates on supplied unicharset, encoder, simple_text.
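The inline EncodeString above delegates to the static version declared next. Its assumed effect, sketched with hypothetical label ids (the exact padding pattern is an inference from the comments, so treat it as an assumption):

GenericVector<int> labels;
trainer.EncodeString("ab", &labels);
// Per the comment above, with null_char_ == 110 and SimpleTextOutput() false
// the result carries nulls as separators, e.g. {110, 96, 110, 97, 110};
// simple-text mode would yield just {96, 97}.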
+ static bool EncodeString(const STRING& str, const UNICHARSET& unicharset, + const UnicharCompress* recoder, bool simple_text, + int null_char, GenericVector<int>* labels); + + // Converts the network to int if not already. + void ConvertToInt() { + if ((training_flags_ & TF_INT_MODE) == 0) { + network_->ConvertToInt(); + training_flags_ |= TF_INT_MODE; + } + } + + // Performs forward-backward on the given trainingdata. + // Returns the sample that was used or NULL if the next sample was deemed + // unusable. samples_trainer could be this or an alternative trainer that + // holds the training samples. + const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) { + int sample_index = sample_iteration(); + const ImageData* image = + samples_trainer->training_data_.GetPageBySerial(sample_index); + if (image != NULL) { + Trainability trainable = TrainOnLine(image, batch); + if (trainable == UNENCODABLE || trainable == NOT_BOXED) { + return NULL; // Sample was unusable. + } + } else { + ++sample_iteration_; + } + return image; + } + Trainability TrainOnLine(const ImageData* trainingdata, bool batch); + + // Prepares the ground truth, runs forward, and prepares the targets. + // Returns a Trainability enum to indicate the suitability of the sample. + Trainability PrepareForBackward(const ImageData* trainingdata, + NetworkIO* fwd_outputs, NetworkIO* targets); + + // Writes the trainer to memory, so that the current training state can be + // restored. + bool SaveTrainingDump(SerializeAmount serialize_amount, + const LSTMTrainer* trainer, + GenericVector<char>* data) const; + + // Reads previously saved trainer from memory. + bool ReadTrainingDump(const GenericVector<char>& data, LSTMTrainer* trainer); + bool ReadSizedTrainingDump(const char* data, int size); + + // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump. + void SetupCheckpointInfo(); + + // Writes the recognizer to memory, so that it can be used for testing later. + void SaveRecognitionDump(GenericVector<char>* data) const; + + // Reads and returns a previously saved recognizer from memory. + static LSTMRecognizer* ReadRecognitionDump(const GenericVector<char>& data); + + // Writes current best model to a file, unless it has already been written. + bool SaveBestModel(FileWriter writer) const; + + // Returns a suitable filename for a training dump, based on the model_base_, + // the iteration and the error rates. + STRING DumpFilename() const; + + // Fills the whole error buffer of the given type with the given value. + void FillErrorBuffer(double new_error, ErrorTypes type); + + protected: + // Factored sub-constructor sets up reasonable default values. + void EmptyConstructor(); + + // Sets the unicharset properties using the given script_dir as a source of + // script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets + // up the recoder_ to simplify the unicharset. + void SetUnicharsetProperties(const STRING& script_dir); + + // Outputs the string and periodically displays the given network inputs + // as an image in the given window, and the corresponding labels at the + // corresponding x_starts. + // Returns false if the truth string is empty. + bool DebugLSTMTraining(const NetworkIO& inputs, + const ImageData& trainingdata, + const NetworkIO& fwd_outputs, + const GenericVector<int>& truth_labels, + const NetworkIO& outputs); + // Displays the network targets as a line graph.
+ void DisplayTargets(const NetworkIO& targets, const char* window_name, + ScrollView** window); + + // Builds a no-compromises target where the first positions should be the + // truth labels and the rest is padded with the null_char_. + bool ComputeTextTargets(const NetworkIO& outputs, + const GenericVector<int>& truth_labels, + NetworkIO* targets); + + // Builds a target using standard CTC. truth_labels should be pre-padded with + // nulls wherever desired. They don't have to be between all labels. + // outputs is input-output, as it gets clipped to minimum probability. + bool ComputeCTCTargets(const GenericVector<int>& truth_labels, + NetworkIO* outputs, NetworkIO* targets); + + // Computes network errors, and stores the results in the rolling buffers, + // along with the supplied text_error. + // Returns the delta error of the current sample (not running average.) + double ComputeErrorRates(const NetworkIO& deltas, double char_error, + double word_error); + + // Computes the network activation RMS error rate. + double ComputeRMSError(const NetworkIO& deltas); + + // Computes network activation winner error rate. (Number of values that are + // in error by >= 0.5 divided by number of time-steps.) More closely related + // to final character error than RMS, but still directly calculable from + // just the deltas. Because of the binary nature of the targets, zero winner + // error is a sufficient but not necessary condition for zero char error. + double ComputeWinnerError(const NetworkIO& deltas); + + // Computes a very simple bag of chars char error rate. + double ComputeCharError(const GenericVector<int>& truth_str, + const GenericVector<int>& ocr_str); + // Computes a very simple bag of words word recall error rate. + // NOTE that this is destructive on both input strings. + double ComputeWordError(STRING* truth_str, STRING* ocr_str); + + // Updates the error buffer and corresponding mean of the given type with + // the new_error. + void UpdateErrorBuffer(double new_error, ErrorTypes type); + + // Rolls error buffers and reports the current means. + void RollErrorBuffers(); + + // Given that error_rate is either a new min or max, updates the best/worst + // error rates, and record of progress. + STRING UpdateErrorGraph(int iteration, double error_rate, + const GenericVector<char>& model_data, + TestCallback tester); + + protected: + // Alignment display window. + ScrollView* align_win_; + // CTC target display window. + ScrollView* target_win_; + // CTC output display window. + ScrollView* ctc_win_; + // Reconstructed image window. + ScrollView* recon_win_; + // How often to display a debug image. + int debug_interval_; + // Iteration at which the last checkpoint was dumped. + int checkpoint_iteration_; + // Basename of files to save best models to. + STRING model_base_; + // Checkpoint filename. + STRING checkpoint_name_; + // Training data. + DocumentCache training_data_; + // A hack to serialize less data for batch training and record file version. + mutable SerializeAmount serialize_amount_; + // Name to use when saving best_trainer_. + STRING best_model_name_; + // Number of available training stages. + int num_training_stages_; + // Checkpointing callbacks. + FileReader file_reader_; + FileWriter file_writer_; + // TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr + // when we can commit to c++11.
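A standalone toy of the rolling-buffer bookkeeping behind UpdateErrorBuffer and RollErrorBuffers, declared above, with the buffer shrunk from 1000 to 4 entries; all names are local to the sketch:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int kBufferSize = 4;  // Stands in for kRollingBufferSize_ == 1000.
  std::vector<double> buffer(kBufferSize, 0.0);
  const double errors[] = {50.0, 30.0, 20.0, 10.0, 8.0, 6.0};
  for (int iteration = 0; iteration < 6; ++iteration) {
    buffer[iteration % kBufferSize] = errors[iteration];
    // Mean over the filled portion only, as in UpdateErrorBuffer.
    int mean_count = std::min(iteration + 1, kBufferSize);
    double sum = 0.0;
    for (int i = 0; i < mean_count; ++i) sum += buffer[i];
    double mean = sum / mean_count;
    // Trim precision to 1/1000 of 1%, mirroring IntCastRounded(1e5*mean)/1e3.
    double trimmed = static_cast<int>(100000.0 * mean + 0.5) / 1000.0;
    printf("iteration %d: rolling mean %.3f%%\n", iteration, trimmed);
  }
  return 0;
}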
+ CheckPointReader checkpoint_reader_; + CheckPointWriter checkpoint_writer_; + + // ===Serialized data to ensure that a restart produces the same results.=== + // These members are only serialized when serialize_amount_ != LIGHT. + // Best error rate so far. + double best_error_rate_; + // Snapshot of all error rates at best_iteration_. + double best_error_rates_[ET_COUNT]; + // Iteration of best_error_rate_. + int best_iteration_; + // Worst error rate since best_error_rate_. + double worst_error_rate_; + // Snapshot of all error rates at worst_iteration_. + double worst_error_rates_[ET_COUNT]; + // Iteration of worst_error_rate_. + int worst_iteration_; + // Iteration at which the process will be thought stalled. + int stall_iteration_; + // Saved recognition models for computing test error for graph points. + GenericVector<char> best_model_data_; + GenericVector<char> worst_model_data_; + // Saved trainer for reverting back to last known best. + GenericVector<char> best_trainer_; + // A subsidiary trainer running with a different learning rate until either + // *this or sub_trainer_ hits a new best. + LSTMTrainer* sub_trainer_; + // Error rate at which last best model was dumped. + float error_rate_of_last_saved_best_; + // Current stage of training. + int training_stage_; + // History of best error rate against iteration. Used for computing the + // number of steps to each 2% improvement. + GenericVector<double> best_error_history_; + GenericVector<int> best_error_iterations_; + // Number of iterations since the best_error_rate_ was 2% more than it is now. + int improvement_steps_; + // Number of iterations that yielded a non-zero delta error and thus provided + // significant learning. learning_iteration_ <= training_iteration_. + // learning_iteration_ is used to measure rate of learning progress. + int learning_iteration_; + // Saved value of sample_iteration_ before looking for the next sample. + int prev_sample_iteration_; + // How often to include a PERFECT training sample in backprop. + // A PERFECT training sample is used if the current + // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_, + // so with perfect_delay_ == 0, all samples are used, and with + // perfect_delay_ == 4, at most 1 in 5 samples will be perfect. + int perfect_delay_; + // Value of training_iteration_ at which the last PERFECT training sample + // was used in back prop. + int last_perfect_training_iteration_; + // Rolling buffers storing recent training errors are indexed by + // training_iteration % kRollingBufferSize_. + static const int kRollingBufferSize_ = 1000; + GenericVector<double> error_buffers_[ET_COUNT]; + // Rounded mean percent trailing training errors in the buffers. + double error_rates_[ET_COUNT]; // RMS training error. +}; + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_LSTMTRAINER_H_ diff --git a/lstm/maxpool.cpp b/lstm/maxpool.cpp new file mode 100644 index 0000000000..2164aaf5e3 --- /dev/null +++ b/lstm/maxpool.cpp @@ -0,0 +1,87 @@ +/////////////////////////////////////////////////////////////////////// +// File: maxpool.cpp +// Description: Standard Max-Pooling layer. +// Author: Ray Smith +// Created: Tue Mar 18 16:28:18 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "maxpool.h" +#include "tprintf.h" + +namespace tesseract { + +Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale) + : Reconfig(name, ni, x_scale, y_scale) { + type_ = NT_MAXPOOL; + no_ = ni; +} + +Maxpool::~Maxpool() { +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool Maxpool::DeSerialize(bool swap, TFile* fp) { + bool result = Reconfig::DeSerialize(swap, fp); + no_ = ni_; + return result; +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void Maxpool::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + output->ResizeScaled(input, x_scale_, y_scale_, no_); + maxes_.ResizeNoInit(output->Width(), ni_); + back_map_ = input.stride_map(); + + StrideMap::Index dest_index(output->stride_map()); + do { + int out_t = dest_index.t(); + StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH), + dest_index.index(FD_HEIGHT) * y_scale_, + dest_index.index(FD_WIDTH) * x_scale_); + // Find the max input out of x_scale_ groups of y_scale_ inputs. + // Do it independently for each input dimension. + int* max_line = maxes_[out_t]; + int in_t = src_index.t(); + output->CopyTimeStepFrom(out_t, input, in_t); + for (int i = 0; i < ni_; ++i) { + max_line[i] = in_t; + } + for (int x = 0; x < x_scale_; ++x) { + for (int y = 0; y < y_scale_; ++y) { + StrideMap::Index src_xy(src_index); + if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) { + output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line); + } + } + } + } while (dest_index.Increment()); +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_); + back_deltas->MaxpoolBackward(fwd_deltas, maxes_); + return true; +} + + +} // namespace tesseract. + diff --git a/lstm/maxpool.h b/lstm/maxpool.h new file mode 100644 index 0000000000..1f742a9d3c --- /dev/null +++ b/lstm/maxpool.h @@ -0,0 +1,71 @@ +/////////////////////////////////////////////////////////////////////// +// File: maxpool.h +// Description: Standard Max-Pooling layer. +// Author: Ray Smith +// Created: Tue Mar 18 16:28:18 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_MAXPOOL_H_ +#define TESSERACT_LSTM_MAXPOOL_H_ + +#include "reconfig.h" + +namespace tesseract { + +// Maxpooling reduction. Independently for each input, selects the location +// in the rectangle that contains the max value. +// Backprop propagates only to the position that was the max. +class Maxpool : public Reconfig { + public: + Maxpool(const STRING& name, int ni, int x_scale, int y_scale); + virtual ~Maxpool(); + + // Accessors. + virtual STRING spec() const { + STRING spec; + spec.add_str_int("Mp", y_scale_); + spec.add_str_int(",", x_scale_); + return spec; + } + + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + + private: + // Memory of which input was the max. + GENERIC_2D_ARRAY<int> maxes_; +}; + + +} // namespace tesseract. + + + + + +#endif // TESSERACT_LSTM_MAXPOOL_H_ + diff --git a/lstm/network.cpp b/lstm/network.cpp new file mode 100644 index 0000000000..3120a3f70a --- /dev/null +++ b/lstm/network.cpp @@ -0,0 +1,309 @@ +/////////////////////////////////////////////////////////////////////// +// File: network.cpp +// Description: Base class for neural network implementations. +// Author: Ray Smith +// Created: Wed May 01 17:25:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "network.h" + +#include <stdlib.h> + +// This base class needs to know about all its sub-classes because of the +// factory deserializing method: CreateFromFile. +#include "allheaders.h" +#include "convolve.h" +#include "fullyconnected.h" +#include "input.h" +#include "lstm.h" +#include "maxpool.h" +#include "parallel.h" +#include "reconfig.h" +#include "reversed.h" +#include "scrollview.h" +#include "series.h" +#include "statistc.h" +#ifdef INCLUDE_TENSORFLOW +#include "tfnetwork.h" +#endif +#include "tprintf.h" + +namespace tesseract { + +// Min and max window sizes. +const int kMinWinSize = 500; +const int kMaxWinSize = 2000; +// Window frame sizes need adding on to make the content fit. +const int kXWinFrameSize = 30; +const int kYWinFrameSize = 80; + +// String names corresponding to the NetworkType enum. Keep in sync.
+// Names used in Serialization to allow re-ordering/addition/deletion of +// layer types in NetworkType without invalidating existing network files. +char const* const Network::kTypeNames[NT_COUNT] = { + "Invalid", "Input", + "Convolve", "Maxpool", + "Parallel", "Replicated", + "ParBidiLSTM", "DepParUDLSTM", + "Par2dLSTM", "Series", + "Reconfig", "RTLReversed", + "TTBReversed", "XYTranspose", + "LSTM", "SummLSTM", + "Logistic", "LinLogistic", + "LinTanh", "Tanh", + "Relu", "Linear", + "Softmax", "SoftmaxNoCTC", + "LSTMSoftmax", "LSTMBinarySoftmax", + "TensorFlow", +}; + +Network::Network() + : type_(NT_NONE), training_(true), needs_to_backprop_(true), + network_flags_(0), ni_(0), no_(0), num_weights_(0), + forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) { +} +Network::Network(NetworkType type, const STRING& name, int ni, int no) + : type_(type), training_(true), needs_to_backprop_(true), + network_flags_(0), ni_(ni), no_(no), num_weights_(0), + name_(name), forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) { +} + +Network::~Network() { +} + +// Suspends/Enables training by setting the training_ flag. When training is +// disabled, Serialize and DeSerialize only operate on the run-time data. +void Network::SetEnableTraining(bool state) { + training_ = state; +} + +// Sets flags that control the action of the network. See NetworkFlags enum +// for bit values. +void Network::SetNetworkFlags(uinT32 flags) { + network_flags_ = flags; +} + +// Sets up the network for training. Initializes weights using weights of +// scale `range` picked according to the random number generator `randomizer`. +int Network::InitWeights(float range, TRand* randomizer) { + randomizer_ = randomizer; + return 0; +} + +// Provides a pointer to a TRand for any networks that care to use it. +// Note that randomizer is a borrowed pointer that should outlive the network +// and should not be deleted by any of the networks. +void Network::SetRandomizer(TRand* randomizer) { + randomizer_ = randomizer; +} + +// Sets needs_to_backprop_ to needs_backprop and returns true if +// needs_backprop || any weights in this network so the next layer forward +// can be told to produce backprop for this layer if needed. +bool Network::SetupNeedsBackprop(bool needs_backprop) { + needs_to_backprop_ = needs_backprop; + return needs_backprop || num_weights_ > 0; +} + +// Writes to the given file. Returns false in case of error. +bool Network::Serialize(TFile* fp) const { + inT8 data = NT_NONE; + if (fp->FWrite(&data, sizeof(data), 1) != 1) return false; + STRING type_name = kTypeNames[type_]; + if (!type_name.Serialize(fp)) return false; + data = training_; + if (fp->FWrite(&data, sizeof(data), 1) != 1) return false; + data = needs_to_backprop_; + if (fp->FWrite(&data, sizeof(data), 1) != 1) return false; + if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false; + if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false; + if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false; + if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false; + if (!name_.Serialize(fp)) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +// Should be overridden by subclasses, but NOT called by their DeSerialize.
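For reference, the record the Serialize implementation above writes, field by field (a reading of the code, not a separate spec):

// inT8    NT_NONE sentinel (tells DeSerialize that a type name follows).
// STRING  type name, one of kTypeNames.
// inT8    training_
// inT8    needs_to_backprop_
// inT32   network_flags_
// inT32   ni_
// inT32   no_
// inT32   num_weights_
// STRING  name_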
+bool Network::DeSerialize(bool swap, TFile* fp) { + inT8 data = 0; + if (fp->FRead(&data, sizeof(data), 1) != 1) return false; + if (data == NT_NONE) { + STRING type_name; + if (!type_name.DeSerialize(swap, fp)) return false; + for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) { + } + if (data == NT_COUNT) { + tprintf("Invalid network layer type:%s\n", type_name.string()); + return false; + } + } + type_ = static_cast<NetworkType>(data); + if (fp->FRead(&data, sizeof(data), 1) != 1) return false; + training_ = data != 0; + if (fp->FRead(&data, sizeof(data), 1) != 1) return false; + needs_to_backprop_ = data != 0; + if (fp->FRead(&network_flags_, sizeof(network_flags_), 1) != 1) return false; + if (fp->FRead(&ni_, sizeof(ni_), 1) != 1) return false; + if (fp->FRead(&no_, sizeof(no_), 1) != 1) return false; + if (fp->FRead(&num_weights_, sizeof(num_weights_), 1) != 1) return false; + if (!name_.DeSerialize(swap, fp)) return false; + if (swap) { + ReverseN(&network_flags_, sizeof(network_flags_)); + ReverseN(&ni_, sizeof(ni_)); + ReverseN(&no_, sizeof(no_)); + ReverseN(&num_weights_, sizeof(num_weights_)); + } + return true; +} + +// Reads from the given file. Returns NULL in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +// Determines the type of the serialized class and calls its DeSerialize +// on a new object of the appropriate type, which is returned. +Network* Network::CreateFromFile(bool swap, TFile* fp) { + Network stub; + if (!stub.DeSerialize(swap, fp)) return NULL; + Network* network = NULL; + switch (stub.type_) { + case NT_CONVOLVE: + network = new Convolve(stub.name_, stub.ni_, 0, 0); + break; + case NT_INPUT: + network = new Input(stub.name_, stub.ni_, stub.no_); + break; + case NT_LSTM: + case NT_LSTM_SOFTMAX: + case NT_LSTM_SOFTMAX_ENCODED: + case NT_LSTM_SUMMARY: + network = + new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_); + break; + case NT_MAXPOOL: + network = new Maxpool(stub.name_, stub.ni_, 0, 0); + break; + // All variants of Parallel. + case NT_PARALLEL: + case NT_REPLICATED: + case NT_PAR_RL_LSTM: + case NT_PAR_UD_LSTM: + case NT_PAR_2D_LSTM: + network = new Parallel(stub.name_, stub.type_); + break; + case NT_RECONFIG: + network = new Reconfig(stub.name_, stub.ni_, 0, 0); + break; + // All variants of reversed. + case NT_XREVERSED: + case NT_YREVERSED: + case NT_XYTRANSPOSE: + network = new Reversed(stub.name_, stub.type_); + break; + case NT_SERIES: + network = new Series(stub.name_); + break; + case NT_TENSORFLOW: +#ifdef INCLUDE_TENSORFLOW + network = new TFNetwork(stub.name_); +#else + tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n"); + return NULL; +#endif + break; + // All variants of FullyConnected. + case NT_SOFTMAX: + case NT_SOFTMAX_NO_CTC: + case NT_RELU: + case NT_TANH: + case NT_LINEAR: + case NT_LOGISTIC: + case NT_POSCLIP: + case NT_SYMCLIP: + network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_); + break; + default: + return NULL; + } + network->training_ = stub.training_; + network->needs_to_backprop_ = stub.needs_to_backprop_; + network->network_flags_ = stub.network_flags_; + network->num_weights_ = stub.num_weights_; + if (!network->DeSerialize(swap, fp)) { + delete network; + return NULL; + } + return network; +} + +// Returns a random number in [-range, range]. +double Network::Random(double range) { + ASSERT_HOST(randomizer_ != NULL); + return randomizer_->SignedRand(range); +} + +#ifndef GRAPHICS_DISABLED +// === Debug image display methods.
=== +// Displays the image of the matrix to the forward window. +void Network::DisplayForward(const NetworkIO& matrix) { + Pix* image = matrix.ToPix(); + ClearWindow(false, name_.string(), pixGetWidth(image), + pixGetHeight(image), &forward_win_); + DisplayImage(image, forward_win_); + forward_win_->Update(); +} + +// Displays the image of the matrix to the backward window. +void Network::DisplayBackward(const NetworkIO& matrix) { + Pix* image = matrix.ToPix(); + STRING window_name = name_ + "-back"; + ClearWindow(false, window_name.string(), pixGetWidth(image), + pixGetHeight(image), &backward_win_); + DisplayImage(image, backward_win_); + backward_win_->Update(); +} + +// Creates the window if needed, otherwise clears it. +void Network::ClearWindow(bool tess_coords, const char* window_name, + int width, int height, ScrollView** window) { + if (*window == NULL) { + int min_size = MIN(width, height); + if (min_size < kMinWinSize) { + if (min_size < 1) min_size = 1; + width = width * kMinWinSize / min_size; + height = height * kMinWinSize / min_size; + } + width += kXWinFrameSize; + height += kYWinFrameSize; + if (width > kMaxWinSize) width = kMaxWinSize; + if (height > kMaxWinSize) height = kMaxWinSize; + *window = new ScrollView(window_name, 80, 100, width, height, width, height, + tess_coords); + tprintf("Created window %s of size %d, %d\n", window_name, width, height); + } else { + (*window)->Clear(); + } +} + +// Displays the pix in the given window, and returns the height of the pix. +// The pix is pixDestroyed. +int Network::DisplayImage(Pix* pix, ScrollView* window) { + int height = pixGetHeight(pix); + window->Image(pix, 0, 0); + pixDestroy(&pix); + return height; +} +#endif // GRAPHICS_DISABLED + +} // namespace tesseract. diff --git a/lstm/network.h b/lstm/network.h new file mode 100644 index 0000000000..edd04b4f6d --- /dev/null +++ b/lstm/network.h @@ -0,0 +1,292 @@ +/////////////////////////////////////////////////////////////////////// +// File: network.h +// Description: Base class for neural network implementations. +// Author: Ray Smith +// Created: Wed May 01 16:38:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_NETWORK_H_ +#define TESSERACT_LSTM_NETWORK_H_ + +#include <stdio.h> +#include <cmath> + +#include "genericvector.h" +#include "helpers.h" +#include "matrix.h" +#include "networkio.h" +#include "serialis.h" +#include "static_shape.h" +#include "tprintf.h" + +struct Pix; +class ScrollView; +class TBOX; + +namespace tesseract { + +class ImageData; +class NetworkScratch; + +// Enum to store the run-time type of a Network. Keep in sync with kTypeNames. +enum NetworkType { + NT_NONE, // The naked base class. + NT_INPUT, // Inputs from an image. + // Plumbing networks combine other networks or rearrange the inputs. + NT_CONVOLVE, // Duplicates inputs in a sliding window neighborhood. + NT_MAXPOOL, // Chooses the max result from a rectangle.
+ NT_PARALLEL, // Runs networks in parallel. + NT_REPLICATED, // Runs identical networks in parallel. + NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel. + NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel. + NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel. + NT_SERIES, // Executes a sequence of layers. + NT_RECONFIG, // Scales the time/y size but makes the output deeper. + NT_XREVERSED, // Reverses the x direction of the inputs/outputs. + NT_YREVERSED, // Reverses the y-direction of the inputs/outputs. + NT_XYTRANSPOSE, // Transposes x and y (for just a single op). + // Functional networks actually calculate stuff. + NT_LSTM, // Long-Short-Term-Memory block. + NT_LSTM_SUMMARY, // LSTM that only keeps its last output. + NT_LOGISTIC, // Fully connected logistic nonlinearity. + NT_POSCLIP, // Fully connected rect lin version of logistic. + NT_SYMCLIP, // Fully connected rect lin version of tanh. + NT_TANH, // Fully connected with tanh nonlinearity. + NT_RELU, // Fully connected with rectifier nonlinearity. + NT_LINEAR, // Fully connected with no nonlinearity. + NT_SOFTMAX, // Softmax uses exponential normalization, with CTC. + NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC. + // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with + // the outputs fed back to the input of the LSTM at the next timestep. + // The ENCODED version binary encodes the softmax outputs, providing log2 of + // the number of outputs as additional inputs, and the other version just + // provides all the softmax outputs as additional inputs. + NT_LSTM_SOFTMAX, // 1-d LSTM with built-in fully connected softmax. + NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax. + // A TensorFlow graph encapsulated as a Tesseract network. + NT_TENSORFLOW, + + NT_COUNT // Array size. +}; + +// Enum of Network behavior flags. Can in theory be set for each individual +// network element. +enum NetworkFlags { + // Network forward/backprop behavior. + NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer. + NF_ADA_GRAD = 128, // Weight-specific learning rate. +}; + +// Base class for network types. Not quite an abstract base class, but almost. +// Most of the time no isolated Network exists, except prior to +// deserialization. +class Network { + public: + Network(); + Network(NetworkType type, const STRING& name, int ni, int no); + virtual ~Network(); + + // Accessors. + NetworkType type() const { + return type_; + } + bool training() const { + return training_; + } + bool needs_to_backprop() const { + return needs_to_backprop_; + } + int num_weights() const { return num_weights_; } + int NumInputs() const { + return ni_; + } + int NumOutputs() const { + return no_; + } + // Returns the required shape input to the network. + virtual StaticShape InputShape() const { + StaticShape result; + return result; + } + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const { + StaticShape result(input_shape); + result.set_depth(no_); + return result; + } + const STRING& name() const { + return name_; + } + virtual STRING spec() const { + return "?"; + } + bool TestFlag(NetworkFlags flag) const { + return (network_flags_ & flag) != 0; + } + + // Initialization and administrative functions that are mostly provided + // by Plumbing. 
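Looking back at the accessors just above, a quick trace of the default OutputShape; the layer pointer and numbers are illustrative:

StaticShape in;
in.SetShape(/*batch=*/1, /*height=*/36, /*width=*/0, /*depth=*/4);
// For a layer with no_ == 16 that keeps the default implementation, only the
// depth changes:
StaticShape out = layer->OutputShape(in);  // (1, 36, 0, 16)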
+ // Returns true if the given type is derived from Plumbing, and thus contains + // multiple sub-networks that can have their own learning rate. + virtual bool IsPlumbingType() const { return false; } + + // Suspends/Enables training by setting the training_ flag. Serialize and + // DeSerialize only operate on the run-time data if state is false. + virtual void SetEnableTraining(bool state); + + // Sets flags that control the action of the network. See NetworkFlags enum + // for bit values. + virtual void SetNetworkFlags(uinT32 flags); + + // Sets up the network for training. Initializes weights using weights of + // scale `range` picked according to the random number generator `randomizer`. + // Note that randomizer is a borrowed pointer that should outlive the network + // and should not be deleted by any of the networks. + // Returns the number of weights initialized. + virtual int InitWeights(float range, TRand* randomizer); + + // Converts a float network to an int network. + virtual void ConvertToInt() {} + + // Provides a pointer to a TRand for any networks that care to use it. + // Note that randomizer is a borrowed pointer that should outlive the network + // and should not be deleted by any of the networks. + virtual void SetRandomizer(TRand* randomizer); + + // Sets needs_to_backprop_ to needs_backprop and returns true if + // needs_backprop || any weights in this network so the next layer forward + // can be told to produce backprop for this layer if needed. + virtual bool SetupNeedsBackprop(bool needs_backprop); + + // Returns the most recent reduction factor that the network applied to the + // time sequence. Assumes that any 2-d is already eliminated. Used for + // scaling bounding boxes of truth data and calculating result bounding boxes. + // WARNING: if GlobalMinimax is used to vary the scale, this will return + // the last used scale factor. Call it before any forward, and it will return + // the minimum scale factor of the paths through the GlobalMinimax. + virtual int XScaleFactor() const { + return 1; + } + + // Provides the (minimum) x scale factor to the network (of interest only to + // input units) so they can determine how to scale bounding boxes. + virtual void CacheXScaleFactor(int factor) {} + + // Provides debug output on the weights. + virtual void DebugWeights() { + tprintf("Must override Network::DebugWeights for type %d\n", type_); + } + + // Writes to the given file. Returns false in case of error. + // Should be overridden by subclasses, but called by their Serialize. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + // Should be overridden by subclasses, but NOT called by their DeSerialize. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Updates the weights using the given learning rate and momentum. + // num_samples is the quotient to be used in the adagrad computation iff + // use_ada_grad_ is true. + virtual void Update(float learning_rate, float momentum, int num_samples) {} + // Sums the products of weight updates in *this and other, splitting into + // positive (same direction) in *same and negative (different direction) in + // *changed. + virtual void CountAlternators(const Network& other, double* same, + double* changed) const {} + + // Reads from the given file. Returns NULL in case of error. + // If swap is true, assumes a big/little-endian swap is needed. 
+ // Determines the type of the serialized class and calls its DeSerialize + // on a new object of the appropriate type, which is returned. + static Network* CreateFromFile(bool swap, TFile* fp); + + // Runs forward propagation of activations on the input line. + // Note that input and output are both 2-d arrays. + // The 1st index is the time element. In a 1-d network, it might be the pixel + // position on the textline. In a 2-d network, the linearization is defined + // by the stride_map. (See networkio.h). + // The 2nd index of input is the network inputs/outputs, and the dimension + // of the input must match NumInputs() of this network. + // The output array will be resized as needed so that its 1st dimension is + // always equal to the number of output values, and its second dimension is + // always NumOutputs(). Note that all this detail is encapsulated away inside + // NetworkIO, as are the internals of the scratch memory space used by the + // network. See networkscratch.h for that. + // If input_transpose is not NULL, then it contains the transpose of input, + // and the caller guarantees that it will still be valid on the next call to + // backward. The callee is therefore at liberty to save the pointer and + // reference it on a call to backward. This is a bit ugly, but it makes it + // possible for a replicating parallel to calculate the input transpose once + // instead of all the replicated networks having to do it. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + tprintf("Must override Network::Forward for type %d\n", type_); + } + + // Runs backward propagation of errors on fwd_deltas. + // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward. + // Returns false if back_deltas was not set, due to there being no point in + // propagating further backwards. Thus most complete networks will always + // return false from Backward! + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + tprintf("Must override Network::Backward for type %d\n", type_); + return false; + } + + // === Debug image display methods. === + // Displays the image of the matrix to the forward window. + void DisplayForward(const NetworkIO& matrix); + // Displays the image of the matrix to the backward window. + void DisplayBackward(const NetworkIO& matrix); + + // Creates the window if needed, otherwise clears it. + static void ClearWindow(bool tess_coords, const char* window_name, + int width, int height, ScrollView** window); + + // Displays the pix in the given window, and returns the height of the pix. + // The pix is pixDestroyed. + static int DisplayImage(Pix* pix, ScrollView* window); + + protected: + // Returns a random number in [-range, range]. + double Random(double range); + + protected: + NetworkType type_; // Type of the derived network class. + bool training_; // Are we currently training? + bool needs_to_backprop_; // This network needs to output back_deltas. + inT32 network_flags_; // Behavior control flags in NetworkFlags. + inT32 ni_; // Number of input values. + inT32 no_; // Number of output values. + inT32 num_weights_; // Number of weights in this and sub-network. + STRING name_; // A unique name for this layer. + + // NOT-serialized debug data. + ScrollView* forward_win_; // Recognition debug display window. + ScrollView* backward_win_; // Training debug display window.
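To make the Forward/Backward contract documented above concrete, here is a minimal pass-through layer built only from NetworkIO calls that appear elsewhere in this patch; the subclass itself is hypothetical, not part of the commit:

class Identity : public Network {
 public:
  Identity(const STRING& name, int ni) : Network(NT_NONE, name, ni, ni) {}
  virtual STRING spec() const { return "Id"; }
  // Copies input to output timestep by timestep; the shape is unchanged.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output) {
    output->ResizeToMap(input.int_mode(), input.stride_map(), no_);
    for (int t = 0; t < input.Width(); ++t)
      output->CopyTimeStepFrom(t, input, t);
  }
  // Deltas pass straight through; per the contract, returns false when no
  // earlier layer needs them.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch, NetworkIO* back_deltas) {
    back_deltas->ResizeToMap(fwd_deltas.int_mode(), fwd_deltas.stride_map(),
                             ni_);
    for (int t = 0; t < fwd_deltas.Width(); ++t)
      back_deltas->CopyTimeStepFrom(t, fwd_deltas, t);
    return needs_to_backprop_;
  }
};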
+ TRand* randomizer_; // Random number generator. + + // Static serialized name/type_ mapping. Keep in sync with NetworkType. + static char const* const kTypeNames[NT_COUNT]; +}; + + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_NETWORK_H_ diff --git a/lstm/networkbuilder.cpp b/lstm/networkbuilder.cpp new file mode 100644 index 0000000000..053e092ba6 --- /dev/null +++ b/lstm/networkbuilder.cpp @@ -0,0 +1,488 @@ +/////////////////////////////////////////////////////////////////////// +// File: networkbuilder.cpp +// Description: Class to parse the network description language and +// build a corresponding network. +// Author: Ray Smith +// Created: Wed Jul 16 18:35:38 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "networkbuilder.h" +#include "convolve.h" +#include "fullyconnected.h" +#include "input.h" +#include "lstm.h" +#include "maxpool.h" +#include "network.h" +#include "parallel.h" +#include "reconfig.h" +#include "reversed.h" +#include "series.h" +#include "unicharset.h" + +namespace tesseract { + +// Builds a network with a network_spec in the network description +// language, to recognize a character set of num_outputs size. +// If append_index is non-negative, then *network must be non-null and the +// given network_spec will be appended to *network AFTER append_index, with +// the top of the input *network discarded. +// Note that network_spec is passed by value to allow a non-const char* pointer +// into the string for BuildFromString. +// net_flags control network behavior according to the NetworkFlags enum. +// The resulting network is returned via **network. +// Returns false if something failed. +bool NetworkBuilder::InitNetwork(int num_outputs, STRING network_spec, + int append_index, int net_flags, + float weight_range, TRand* randomizer, + Network** network) { + NetworkBuilder builder(num_outputs); + Series* bottom_series = NULL; + StaticShape input_shape; + if (append_index >= 0) { + // Split the current network after the given append_index. + ASSERT_HOST(*network != NULL && (*network)->type() == NT_SERIES); + Series* series = reinterpret_cast<Series*>(*network); + Series* top_series = NULL; + series->SplitAt(append_index, &bottom_series, &top_series); + if (bottom_series == NULL || top_series == NULL) { + tprintf("Yikes! Splitting current network failed!!\n"); + return false; + } + input_shape = bottom_series->OutputShape(input_shape); + delete top_series; + } + char* str_ptr = &network_spec[0]; + *network = builder.BuildFromString(input_shape, &str_ptr); + if (*network == NULL) return false; + (*network)->SetNetworkFlags(net_flags); + (*network)->InitWeights(weight_range, randomizer); + (*network)->SetupNeedsBackprop(false); + if (bottom_series != NULL) { + bottom_series->AppendSeries(*network); + *network = bottom_series; + } + (*network)->CacheXScaleFactor((*network)->XScaleFactor()); + return true; +} + +// Helper skips whitespace.
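A hypothetical append call against the InitNetwork contract above, discarding everything above layer 3 of an existing series and grafting on a new tail; the spec string and numbers are illustrative:

Network* net = old_net;  // Must be an NT_SERIES from a previous build.
TRand randomizer;
if (!NetworkBuilder::InitNetwork(/*num_outputs=*/105, "Lfx192 O1c105",
                                 /*append_index=*/3, /*net_flags=*/0,
                                 /*weight_range=*/0.1f, &randomizer, &net)) {
  tprintf("Append failed!\n");
}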
+static void SkipWhitespace(char** str) { + while (**str == ' ' || **str == '\t' || **str == '\n') ++*str; +} + +// Parses the given string and returns a network according to the network +// description language in networkbuilder.h +Network* NetworkBuilder::BuildFromString(const StaticShape& input_shape, + char** str) { + SkipWhitespace(str); + char code_ch = **str; + if (code_ch == '[') { + return ParseSeries(input_shape, nullptr, str); + } + if (input_shape.depth() == 0) { + // There must be an input at this point. + return ParseInput(str); + } + switch (code_ch) { + case '(': + return ParseParallel(input_shape, str); + case 'R': + return ParseR(input_shape, str); + case 'S': + return ParseS(input_shape, str); + case 'C': + return ParseC(input_shape, str); + case 'M': + return ParseM(input_shape, str); + case 'L': + return ParseLSTM(input_shape, str); + case 'F': + return ParseFullyConnected(input_shape, str); + case 'O': + return ParseOutput(input_shape, str); + default: + tprintf("Invalid network spec:%s\n", *str); + return nullptr; + } + return nullptr; +} + +// Parses an input specification and returns the result, which may include a +// series. +Network* NetworkBuilder::ParseInput(char** str) { + // There must be an input at this point. + int length = 0; + int batch, height, width, depth; + int num_converted = + sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length); + StaticShape shape; + shape.SetShape(batch, height, width, depth); + // num_converted may or may not include the length. + if (num_converted != 4 && num_converted != 5) { + tprintf("Must specify an input layer as the first layer, not %s!!\n", *str); + return nullptr; + } + *str += length; + Input* input = new Input("Input", shape); + // We want to allow [<input>rest of net... or <input>[rest of net... so we + // have to check explicitly for '[' here. + SkipWhitespace(str); + if (**str == '[') return ParseSeries(shape, input, str); + return input; +} + +// Parses a sequential series of networks, defined by [...]. +Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape, + Input* input_layer, char** str) { + StaticShape shape = input_shape; + Series* series = new Series("Series"); + ++*str; + if (input_layer != nullptr) { + series->AddToStack(input_layer); + shape = input_layer->OutputShape(shape); + } + Network* network = NULL; + while (**str != '\0' && **str != ']' && + (network = BuildFromString(shape, str)) != NULL) { + shape = network->OutputShape(shape); + series->AddToStack(network); + } + if (**str != ']') { + tprintf("Missing ] at end of [Series]!\n"); + delete series; + return NULL; + } + ++*str; + return series; +} + +// Parses a parallel set of networks, defined by (...). +Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape, + char** str) { + Parallel* parallel = new Parallel("Parallel", NT_PARALLEL); + ++*str; + Network* network = NULL; + while (**str != '\0' && **str != ')' && + (network = BuildFromString(input_shape, str)) != NULL) { + parallel->AddToStack(network); + } + if (**str != ')') { + tprintf("Missing ) at end of (Parallel)!\n"); + delete parallel; + return nullptr; + } + ++*str; + return parallel; +} + +// Parses a network that begins with 'R'.
+Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) { + char dir = (*str)[1]; + if (dir == 'x' || dir == 'y') { + STRING name = "Reverse"; + name += dir; + *str += 2; + Network* network = BuildFromString(input_shape, str); + if (network == nullptr) return nullptr; + Reversed* rev = + new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED); + rev->SetNetwork(network); + return rev; + } + int replicas = strtol(*str + 1, str, 10); + if (replicas <= 0) { + tprintf("Invalid R spec!:%s\n", *str); + return nullptr; + } + Parallel* parallel = new Parallel("Replicated", NT_REPLICATED); + char* str_copy = *str; + for (int i = 0; i < replicas; ++i) { + str_copy = *str; + Network* network = BuildFromString(input_shape, &str_copy); + if (network == NULL) { + tprintf("Invalid replicated network!\n"); + delete parallel; + return nullptr; + } + parallel->AddToStack(network); + } + *str = str_copy; + return parallel; +} + +// Parses a network that begins with 'S'. +Network* NetworkBuilder::ParseS(const StaticShape& input_shape, char** str) { + int y = strtol(*str + 1, str, 10); + if (**str == ',') { + int x = strtol(*str + 1, str, 10); + if (y <= 0 || x <= 0) { + tprintf("Invalid S spec!:%s\n", *str); + return nullptr; + } + return new Reconfig("Reconfig", input_shape.depth(), x, y); + } else if (**str == '(') { + // TODO(rays) Add Generic reshape. + tprintf("Generic reshape not yet implemented!!\n"); + return nullptr; + } + tprintf("Invalid S spec!:%s\n", *str); + return nullptr; +} + +// Helper returns the fully-connected type for the character code. +static NetworkType NonLinearity(char func) { + switch (func) { + case 's': + return NT_LOGISTIC; + case 't': + return NT_TANH; + case 'r': + return NT_RELU; + case 'l': + return NT_LINEAR; + case 'm': + return NT_SOFTMAX; + case 'p': + return NT_POSCLIP; + case 'n': + return NT_SYMCLIP; + default: + return NT_NONE; + } +} + +// Parses a network that begins with 'C'. +Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) { + NetworkType type = NonLinearity((*str)[1]); + if (type == NT_NONE) { + tprintf("Invalid nonlinearity on C-spec!: %s\n", *str); + return nullptr; + } + int y = 0, x = 0, d = 0; + if ((y = strtol(*str + 2, str, 10)) <= 0 || **str != ',' || + (x = strtol(*str + 1, str, 10)) <= 0 || **str != ',' || + (d = strtol(*str + 1, str, 10)) <= 0) { + tprintf("Invalid C spec!:%s\n", *str); + return nullptr; + } + if (x == 1 && y == 1) { + // No actual convolution. Just a FullyConnected on the current depth, to + // be slid over all batch,y,x. + return new FullyConnected("Conv1x1", input_shape.depth(), d, type); + } + Series* series = new Series("ConvSeries"); + Convolve* convolve = + new Convolve("Convolve", input_shape.depth(), x / 2, y / 2); + series->AddToStack(convolve); + StaticShape fc_input = convolve->OutputShape(input_shape); + series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type)); + return series; +} + +// Parses a network that begins with 'M'. +Network* NetworkBuilder::ParseM(const StaticShape& input_shape, char** str) { + int y = 0, x = 0; + if ((*str)[1] != 'p' || (y = strtol(*str + 2, str, 10)) <= 0 || + **str != ',' || (x = strtol(*str + 1, str, 10)) <= 0) { + tprintf("Invalid Mp spec!:%s\n", *str); + return nullptr; + } + return new Maxpool("Maxpool", input_shape.depth(), x, y); +} + +// Parses an LSTM network, either individual, bi- or quad-directional. 
+Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) {
+  bool two_d = false;
+  NetworkType type = NT_LSTM;
+  char* spec_start = *str;
+  int chars_consumed = 1;
+  int num_outputs = 0;
+  char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
+  if (key == 'S') {
+    type = NT_LSTM_SOFTMAX;
+    num_outputs = num_softmax_outputs_;
+    ++chars_consumed;
+  } else if (key == 'E') {
+    type = NT_LSTM_SOFTMAX_ENCODED;
+    num_outputs = num_softmax_outputs_;
+    ++chars_consumed;
+  } else if (key == '2' && (((*str)[2] == 'x' && (*str)[3] == 'y') ||
+                            ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
+    chars_consumed = 4;
+    dim = (*str)[3];
+    two_d = true;
+  } else if (key == 'f' || key == 'r' || key == 'b') {
+    dir = key;
+    dim = (*str)[2];
+    if (dim != 'x' && dim != 'y') {
+      tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
+      return nullptr;
+    }
+    chars_consumed = 3;
+    if ((*str)[chars_consumed] == 's') {
+      ++chars_consumed;
+      type = NT_LSTM_SUMMARY;
+    }
+  } else {
+    tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
+    return nullptr;
+  }
+  int num_states = strtol(*str + chars_consumed, str, 10);
+  if (num_states <= 0) {
+    tprintf("Invalid number of states in L Spec!:%s\n", *str);
+    return nullptr;
+  }
+  Network* lstm = nullptr;
+  if (two_d) {
+    lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
+  } else {
+    if (num_outputs == 0) num_outputs = num_states;
+    STRING name(spec_start, *str - spec_start);
+    lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false,
+                    type);
+    if (dir != 'f') {
+      Reversed* rev = new Reversed("RevLSTM", NT_XREVERSED);
+      rev->SetNetwork(lstm);
+      lstm = rev;
+    }
+    if (dir == 'b') {
+      name += "LTR";
+      Parallel* parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
+      parallel->AddToStack(new LSTM(name, input_shape.depth(), num_states,
+                                    num_outputs, false, type));
+      parallel->AddToStack(lstm);
+      lstm = parallel;
+    }
+  }
+  if (dim == 'y') {
+    Reversed* rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
+    rev->SetNetwork(lstm);
+    lstm = rev;
+  }
+  return lstm;
+}
+
+// Builds a set of 4 lstms with x and y reversal, running in true parallel.
+Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
+  Parallel* parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
+  parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states,
+                                num_states, true, NT_LSTM));
+  Reversed* rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
+  rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states,
+                           true, NT_LSTM));
+  parallel->AddToStack(rev);
+  rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
+  rev->SetNetwork(
+      new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
+  Reversed* rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
+  rev2->SetNetwork(rev);
+  parallel->AddToStack(rev2);
+  rev = new Reversed("L2DXRevY", NT_YREVERSED);
+  rev->SetNetwork(new LSTM("L2DLTRUp", num_inputs, num_states, num_states,
+                           true, NT_LSTM));
+  parallel->AddToStack(rev);
+  return parallel;
+}
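+// For example, "Lbx100" yields a Parallel(NT_PAR_RL_LSTM) holding a forward
+// 100-state LSTM and a Reversed(NT_XREVERSED) wrapper around a second one,
+// while "Lfys64" yields a single summarizing LSTM wrapped in an
+// NT_XYTRANSPOSE Reversed so that it runs along y and collapses the height.
+
+// Helper builds a truly (0-d) fully connected layer of the given type.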
+static Network* BuildFullyConnected(const StaticShape& input_shape, + NetworkType type, const STRING& name, + int depth) { + if (input_shape.height() == 0 || input_shape.width() == 0) { + tprintf("Fully connected requires positive height and width, had %d,%d\n", + input_shape.height(), input_shape.width()); + return nullptr; + } + int input_size = input_shape.height() * input_shape.width(); + int input_depth = input_size * input_shape.depth(); + Network* fc = new FullyConnected(name, input_depth, depth, type); + if (input_size > 1) { + Series* series = new Series("FCSeries"); + series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), + input_shape.width(), input_shape.height())); + series->AddToStack(fc); + fc = series; + } + return fc; +} + +// Parses a Fully connected network. +Network* NetworkBuilder::ParseFullyConnected(const StaticShape& input_shape, + char** str) { + char* spec_start = *str; + NetworkType type = NonLinearity((*str)[1]); + if (type == NT_NONE) { + tprintf("Invalid nonlinearity on F-spec!: %s\n", *str); + return nullptr; + } + int depth = strtol(*str + 1, str, 10); + if (depth <= 0) { + tprintf("Invalid F spec!:%s\n", *str); + return nullptr; + } + STRING name(spec_start, *str - spec_start); + return BuildFullyConnected(input_shape, type, name, depth); +} + +// Parses an Output spec. +Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape, + char** str) { + char dims_ch = (*str)[1]; + if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') { + tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str); + return nullptr; + } + char type_ch = (*str)[2]; + if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') { + tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str); + return nullptr; + } + int depth = strtol(*str + 3, str, 10); + if (depth != num_softmax_outputs_) { + tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth, + num_softmax_outputs_); + depth = num_softmax_outputs_; + } + NetworkType type = NT_SOFTMAX; + if (type_ch == 'l') + type = NT_LOGISTIC; + else if (type_ch == 's') + type = NT_SOFTMAX_NO_CTC; + if (dims_ch == '0') { + // Same as standard fully connected. + return BuildFullyConnected(input_shape, type, "Output", depth); + } else if (dims_ch == '2') { + // We don't care if x and/or y are variable. + return new FullyConnected("Output2d", input_shape.depth(), depth, type); + } + // For 1-d y has to be fixed, and if not 1, moved to depth. + if (input_shape.height() == 0) { + tprintf("Fully connected requires fixed height!\n"); + return nullptr; + } + int input_size = input_shape.height(); + int input_depth = input_size * input_shape.depth(); + Network* fc = new FullyConnected("Output", input_depth, depth, type); + if (input_size > 1) { + Series* series = new Series("FCSeries"); + series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1, + input_shape.height())); + series->AddToStack(fc); + fc = series; + } + return fc; +} + +} // namespace tesseract. + diff --git a/lstm/networkbuilder.h b/lstm/networkbuilder.h new file mode 100644 index 0000000000..a405fc52b7 --- /dev/null +++ b/lstm/networkbuilder.h @@ -0,0 +1,160 @@ +/////////////////////////////////////////////////////////////////////// +// File: networkbuilder.h +// Description: Class to parse the network description language and +// build a corresponding network. +// Author: Ray Smith +// Created: Wed Jul 16 18:35:38 PST 2014 +// +// (C) Copyright 2014, Google Inc. 
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef TESSERACT_LSTM_NETWORKBUILDER_H_
+#define TESSERACT_LSTM_NETWORKBUILDER_H_
+
+#include "static_shape.h"
+#include "stridemap.h"
+
+class STRING;
+class UNICHARSET;
+
+namespace tesseract {
+
+class Input;
+class Network;
+class Parallel;
+class TRand;
+
+class NetworkBuilder {
+ public:
+  explicit NetworkBuilder(int num_softmax_outputs)
+      : num_softmax_outputs_(num_softmax_outputs) {}
+
+  // Builds a network with a network_spec in the network description
+  // language, to recognize a character set of num_outputs size.
+  // If append_index is non-negative, then *network must be non-null and the
+  // given network_spec will be appended to *network AFTER append_index, with
+  // the top of the input *network discarded.
+  // Note that network_spec is call by value to allow a non-const char* pointer
+  // into the string for BuildFromString.
+  // net_flags control network behavior according to the NetworkFlags enum.
+  // The resulting network is returned via **network.
+  // Returns false if something failed.
+  static bool InitNetwork(int num_outputs, STRING network_spec,
+                          int append_index, int net_flags, float weight_range,
+                          TRand* randomizer, Network** network);
+
+  // Parses the given string and returns a network according to the following
+  // language:
+  // ============ Syntax of description below: ============
+  // <d> represents a number.
+  // <net> represents any single network element, including (recursively) a
+  //   [...] series or (...) parallel construct.
+  // (s|t|r|l|m) (regex notation) represents a single required letter.
+  // NOTE THAT THROUGHOUT, x and y are REVERSED from conventional mathematics,
+  // to use the same convention as Tensor Flow. The reason TF adopts this
+  // convention is to eliminate the need to transpose images on input, since
+  // adjacent memory locations in images increase x and then y, while adjacent
+  // memory locations in tensors in TF, and NetworkIO in tesseract increase the
+  // rightmost index first, then the next-left and so-on, like C arrays.
+  // ============ INPUTS ============
+  // <b>,<h>,<w>,<d> A batch of b images with height h, width w, and depth d.
+  //   b, h and/or w may be zero, to indicate variable size. Some network layer
+  //   (summarizing LSTM) must be used to make a variable h known.
+  //   d may be 1 for greyscale, 3 for color.
+  // NOTE that throughout the constructed network, the inputs/outputs are all
+  //   of the same [batch,height,width,depth] dimensions, even if a different
+  //   size.
+  // ============ PLUMBING ============
+  // [...] Execute ... networks in series (layers).
+  // (...) Execute ... networks in parallel, with their output depths added.
+  // R<d><net> Execute d replicas of <net> in parallel, with their output
+  //   depths added.
+  // Rx<net> Execute <net> with x-dimension reversal.
+  // Ry<net> Execute <net> with y-dimension reversal.
+  // S<y>,<x> Rescale 2-D input by shrink factor x,y, rearranging the data by
+  //   increasing the depth of the input by factor xy.
+  // Mp<y>,<x> Maxpool the input, reducing the size by an (x,y) rectangle.
+  // ============ FUNCTIONAL UNITS ============
+  // C(s|t|r|l|m)<y>,<x>,<d> Convolves using a (x,y) window, with no shrinkage,
+  //   random infill, producing d outputs, then applies a non-linearity:
+  //   s: Sigmoid, t: Tanh, r: Relu, l: Linear, m: Softmax.
+  // F(s|t|r|l|m)<d> Truly fully-connected with s|t|r|l|m non-linearity and d
+  //   outputs. Connects to every x,y,depth position of the input, reducing
+  //   height, width to 1, producing a single <d> vector as the output.
+  //   Input height and width must be constant.
+  //   For a sliding-window linear or non-linear map that connects just to the
+  //   input depth, and leaves the input image size as-is, use a 1x1 convolution
+  //   eg. Cr1,1,64 instead of Fr64.
+  // L(f|r|b)(x|y)[s]<n> LSTM cell with n states/outputs.
+  //   The LSTM must have one of:
+  //     f runs the LSTM forward only.
+  //     r runs the LSTM reversed only.
+  //     b runs the LSTM bidirectionally.
+  //   It will operate on either the x- or y-dimension, treating the other
+  //   dimension independently (as if part of the batch).
+  //   s (optional) summarizes the output in the requested dimension,
+  //   outputting only the final step, collapsing the dimension to a
+  //   single element.
+  // LS<n> Forward-only LSTM cell in the x-direction, with built-in Softmax.
+  // LE<n> Forward-only LSTM cell in the x-direction, with built-in softmax,
+  //   with binary Encoding.
+  // L2xy<n> Full 2-d LSTM operating in quad-directions (bidi in x and y) and
+  //   all the output depths added.
+  // ============ OUTPUTS ============
+  // The network description must finish with an output specification:
+  // O(2|1|0)(l|s|c)<n> output layer with n classes.
+  //   2 (heatmap) Output is a 2-d vector map of the input (possibly at
+  //     different scale).
+  //   1 (sequence) Output is a 1-d sequence of vector values.
+  //   0 (category) Output is a 0-d single vector value.
+  //   l uses a logistic non-linearity on the output, allowing multiple
+  //     hot elements in any output vector value.
+  //   s uses a softmax non-linearity, with one-hot output in each value.
+  //   c uses a softmax with CTC. Can only be used with s (sequence).
+  //   NOTE1: Only O1s and O1c are currently supported.
+  //   NOTE2: n is totally ignored, and for compatibility purposes only. The
+  //     output number of classes is obtained automatically from the
+  //     unicharset.
+  Network* BuildFromString(const StaticShape& input_shape, char** str);
+
+ private:
+  // Parses an input specification and returns the result, which may include a
+  // series.
+  Network* ParseInput(char** str);
+  // Parses a sequential series of networks, defined by [...].
+  Network* ParseSeries(const StaticShape& input_shape, Input* input_layer,
+                       char** str);
+  // Parses a parallel set of networks, defined by (...).
+  Network* ParseParallel(const StaticShape& input_shape, char** str);
+  // Parses a network that begins with 'R'.
+  Network* ParseR(const StaticShape& input_shape, char** str);
+  // Parses a network that begins with 'S'.
+  Network* ParseS(const StaticShape& input_shape, char** str);
+  // Parses a network that begins with 'C'.
+  Network* ParseC(const StaticShape& input_shape, char** str);
+  // Parses a network that begins with 'M'.
+  Network* ParseM(const StaticShape& input_shape, char** str);
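+  // For example (illustrative only, using the grammar above), a plausible
+  // line-recognizer spec is:
+  //   [1,36,0,1 Ct5,5,16 Mp3,3 Lfys64 Lbx128 O1c105]
+  // i.e. greyscale input of fixed height 36 and variable width, a 5x5 tanh
+  // convolution producing 16 features, 3x3 maxpooling, a summarizing y-LSTM
+  // of 64 states that collapses the height to 1, a bidirectional x-LSTM of
+  // 128 states, and a softmax+CTC sequence output (the trailing 105 is
+  // nominal; the real class count comes from the unicharset).
+
+  // Parses an LSTM network, either individual, bi- or quad-directional.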
+  Network* ParseLSTM(const StaticShape& input_shape, char** str);
+  // Builds a set of 4 lstms with x and y reversal, running in true parallel.
+  static Network* BuildLSTMXYQuad(int num_inputs, int num_states);
+  // Parses a Fully connected network.
+  Network* ParseFullyConnected(const StaticShape& input_shape, char** str);
+  // Parses an Output spec.
+  Network* ParseOutput(const StaticShape& input_shape, char** str);
+
+ private:
+  int num_softmax_outputs_;
+};
+
+}  // namespace tesseract.
+
+#endif  // TESSERACT_LSTM_NETWORKBUILDER_H_
diff --git a/lstm/networkio.cpp b/lstm/networkio.cpp
new file mode 100644
index 0000000000..1f793fd0d9
--- /dev/null
+++ b/lstm/networkio.cpp
@@ -0,0 +1,981 @@
+///////////////////////////////////////////////////////////////////////
+// File: networkio.cpp
+// Description: Network input/output data, allowing float/int implementations.
+// Author: Ray Smith
+// Created: Thu Jun 19 13:01:31 PST 2014
+//
+// (C) Copyright 2014, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#include "networkio.h"
+
+#include "allheaders.h"
+#include "functions.h"
+#include "statistc.h"
+#include "tprintf.h"
+
+namespace tesseract {
+
+// Minimum value to output for certainty.
+const float kMinCertainty = -20.0f;
+// Probability corresponding to kMinCertainty.
+const float kMinProb = exp(kMinCertainty);
+
+// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
+void NetworkIO::Resize2d(bool int_mode, int width, int num_features) {
+  stride_map_ = StrideMap();
+  int_mode_ = int_mode;
+  if (int_mode_) {
+    i_.ResizeNoInit(width, num_features);
+  } else {
+    f_.ResizeNoInit(width, num_features);
+  }
+}
+
+// Resizes to a specific stride_map.
+void NetworkIO::ResizeToMap(bool int_mode, const StrideMap& stride_map,
+                            int num_features) {
+  // If this assert fails, it most likely got here through an uninitialized
+  // scratch element, ie call NetworkScratch::IO::Resizexxx() not
+  // NetworkIO::Resizexxx()!!
+  ASSERT_HOST(this != NULL);
+  stride_map_ = stride_map;
+  int_mode_ = int_mode;
+  if (int_mode_) {
+    i_.ResizeNoInit(stride_map.Width(), num_features);
+  } else {
+    f_.ResizeNoInit(stride_map.Width(), num_features);
+  }
+  ZeroInvalidElements();
+}
+
+// Shrinks image size by x_scale,y_scale, and use given number of features.
+void NetworkIO::ResizeScaled(const NetworkIO& src,
+                             int x_scale, int y_scale, int num_features) {
+  StrideMap stride_map = src.stride_map_;
+  stride_map.ScaleXY(x_scale, y_scale);
+  ResizeToMap(src.int_mode_, stride_map, num_features);
+}
+
+// Resizes to just 1 x-coord, whatever the input.
+void NetworkIO::ResizeXTo1(const NetworkIO& src, int num_features) {
+  StrideMap stride_map = src.stride_map_;
+  stride_map.ReduceWidthTo1();
+  ResizeToMap(src.int_mode_, stride_map, num_features);
+}
+
+// Initialize all the array to zero.
+void NetworkIO::Zero() {
+  int width = Width();
+  // Zero out everything, column by column in case it is aligned.
+  for (int t = 0; t < width; ++t) {
+    ZeroTimeStep(t);
+  }
+}
+
+// Initializes to zero all elements of the array that do not correspond to
+// valid image positions. (If a batch of different-sized images are packed
+// together, then there will be padding pixels.)
+void NetworkIO::ZeroInvalidElements() {
+  int num_features = NumFeatures();
+  int full_width = stride_map_.Size(FD_WIDTH);
+  int full_height = stride_map_.Size(FD_HEIGHT);
+  StrideMap::Index b_index(stride_map_);
+  do {
+    int end_x = b_index.MaxIndexOfDim(FD_WIDTH) + 1;
+    if (end_x < full_width) {
+      // The width is small, so fill for every valid y.
+      StrideMap::Index y_index(b_index);
+      int fill_size = num_features * (full_width - end_x);
+      do {
+        StrideMap::Index z_index(y_index);
+        z_index.AddOffset(end_x, FD_WIDTH);
+        if (int_mode_) {
+          ZeroVector<inT8>(fill_size, i_[z_index.t()]);
+        } else {
+          ZeroVector<float>(fill_size, f_[z_index.t()]);
+        }
+      } while (y_index.AddOffset(1, FD_HEIGHT));
+    }
+    int end_y = b_index.MaxIndexOfDim(FD_HEIGHT) + 1;
+    if (end_y < full_height) {
+      // The height is small, so fill in the space in one go.
+      StrideMap::Index y_index(b_index);
+      y_index.AddOffset(end_y, FD_HEIGHT);
+      int fill_size = num_features * full_width * (full_height - end_y);
+      if (int_mode_) {
+        ZeroVector<inT8>(fill_size, i_[y_index.t()]);
+      } else {
+        ZeroVector<float>(fill_size, f_[y_index.t()]);
+      }
+    }
+  } while (b_index.AddOffset(1, FD_BATCH));
+}
+
+// Helper computes a black point and white point to contrast-enhance an image.
+// The computation is based on the assumption that the image is of a single
+// line of text, so a horizontal line through the middle of the image passes
+// through at least some of it, so local minima and maxima are a good proxy
+// for black and white pixel samples.
+static void ComputeBlackWhite(Pix* pix, float* black, float* white) {
+  int width = pixGetWidth(pix);
+  int height = pixGetHeight(pix);
+  STATS mins(0, 256), maxes(0, 256);
+  if (width >= 3) {
+    int y = height / 2;
+    const l_uint32* line = pixGetData(pix) + pixGetWpl(pix) * y;
+    int prev = GET_DATA_BYTE(line, 0);
+    int curr = GET_DATA_BYTE(line, 1);
+    for (int x = 1; x + 1 < width; ++x) {
+      int next = GET_DATA_BYTE(line, x + 1);
+      if ((curr < prev && curr <= next) || (curr <= prev && curr < next)) {
+        // Local minimum.
+        mins.add(curr, 1);
+      }
+      if ((curr > prev && curr >= next) || (curr >= prev && curr > next)) {
+        // Local maximum.
+        maxes.add(curr, 1);
+      }
+      prev = curr;
+      curr = next;
+    }
+  }
+  if (mins.get_total() == 0) mins.add(0, 1);
+  if (maxes.get_total() == 0) maxes.add(255, 1);
+  *black = mins.ile(0.25);
+  *white = maxes.ile(0.75);
+}
+
+// Sets up the array from the given image, using the currently set int_mode_.
+// If the image width doesn't match the shape, the image is truncated or padded
+// with noise to match.
+void NetworkIO::FromPix(const StaticShape& shape, const Pix* pix,
+                        TRand* randomizer) {
+  std::vector<const Pix*> pixes(1, pix);
+  FromPixes(shape, pixes, randomizer);
+}
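+// For example, if ComputeBlackWhite() finds black=50 and white=250, the
+// callers below use contrast = (250 - 50) / 2 = 100, so SetPixel() maps a
+// pixel of 200 to (200 - 50) / 100 - 1.0 = 0.5 in float mode, or to
+// round(128 * 0.5) = 64 in int mode.
+
+// Sets up the array from the given set of images, using the currently set
+// int_mode_. If the image width doesn't match the shape, the images are
+// truncated or padded with noise to match.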
+void NetworkIO::FromPixes(const StaticShape& shape,
+                          const std::vector<const Pix*>& pixes,
+                          TRand* randomizer) {
+  int target_height = shape.height();
+  int target_width = shape.width();
+  std::vector<std::pair<int, int>> h_w_pairs;
+  for (auto pix : pixes) {
+    Pix* var_pix = const_cast<Pix*>(pix);
+    int width = pixGetWidth(var_pix);
+    if (target_width != 0) width = target_width;
+    int height = pixGetHeight(var_pix);
+    if (target_height != 0) height = target_height;
+    h_w_pairs.emplace_back(height, width);
+  }
+  stride_map_.SetStride(h_w_pairs);
+  ResizeToMap(int_mode(), stride_map_, shape.depth());
+  // Iterate over the images again to copy the data.
+  for (int b = 0; b < pixes.size(); ++b) {
+    Pix* pix = const_cast<Pix*>(pixes[b]);
+    float black = 0.0f, white = 255.0f;
+    if (shape.depth() != 3) ComputeBlackWhite(pix, &black, &white);
+    float contrast = (white - black) / 2.0f;
+    if (contrast <= 0.0f) contrast = 1.0f;
+    if (shape.height() == 1) {
+      Copy1DGreyImage(b, pix, black, contrast, randomizer);
+    } else {
+      Copy2DImage(b, pix, black, contrast, randomizer);
+    }
+  }
+}
+
+// Copies the given pix to *this at the given batch index, stretching and
+// clipping the pixel values so that [black, black + 2*contrast] maps to the
+// dynamic range of *this, ie [-1,1] for a float and (-127,127) for int.
+// This is a 2-d operation in the sense that the output depth is the number
+// of input channels, the height is the height of the image, and the width
+// is the width of the image, or truncated/padded with noise if the width
+// is a fixed size.
+void NetworkIO::Copy2DImage(int batch, Pix* pix, float black, float contrast,
+                            TRand* randomizer) {
+  int width = pixGetWidth(pix);
+  int height = pixGetHeight(pix);
+  int wpl = pixGetWpl(pix);
+  StrideMap::Index index(stride_map_);
+  index.AddOffset(batch, FD_BATCH);
+  int t = index.t();
+  int target_height = stride_map_.Size(FD_HEIGHT);
+  int target_width = stride_map_.Size(FD_WIDTH);
+  int num_features = NumFeatures();
+  bool color = num_features == 3;
+  if (width > target_width) width = target_width;
+  const uinT32* line = pixGetData(pix);
+  for (int y = 0; y < target_height; ++y, line += wpl) {
+    int x = 0;
+    if (y < height) {
+      for (x = 0; x < width; ++x, ++t) {
+        if (color) {
+          int f = 0;
+          for (int c = COLOR_RED; c <= COLOR_BLUE; ++c) {
+            int pixel = GET_DATA_BYTE(line + x, c);
+            SetPixel(t, f++, pixel, black, contrast);
+          }
+        } else {
+          int pixel = GET_DATA_BYTE(line, x);
+          SetPixel(t, 0, pixel, black, contrast);
+        }
+      }
+    }
+    for (; x < target_width; ++x) Randomize(t++, 0, num_features, randomizer);
+  }
+}
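+// Illustrative call (line_pix1/line_pix2 are hypothetical Pix pointers):
+// pack two text-line images into one batch for a greyscale input of fixed
+// height 36:
+//
+//   StaticShape shape;
+//   shape.SetShape(2, 36, 0, 1);  // batch=2, variable width, depth=1.
+//   NetworkIO inputs;
+//   inputs.FromPixes(shape, {line_pix1, line_pix2}, &randomizer);
+
+// Copies the given pix to *this at the given batch index, as Copy2DImage
+// above, except that the output depth is the height of the input image, the
+// output height is 1, and the output width as for Copy2DImage.
+// The image is thus treated as a 1-d set of vertical pixel strips.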
+void NetworkIO::Copy1DGreyImage(int batch, Pix* pix, float black, + float contrast, TRand* randomizer) { + int width = pixGetWidth(pix); + int height = pixGetHeight(pix); + ASSERT_HOST(height == NumFeatures()); + int wpl = pixGetWpl(pix); + StrideMap::Index index(stride_map_); + index.AddOffset(batch, FD_BATCH); + int t = index.t(); + int target_width = stride_map_.Size(FD_WIDTH); + if (width > target_width) width = target_width; + int x; + for (x = 0; x < width; ++x, ++t) { + for (int y = 0; y < height; ++y) { + const uinT32* line = pixGetData(pix) + wpl * y; + int pixel = GET_DATA_BYTE(line, x); + SetPixel(t, y, pixel, black, contrast); + } + } + for (; x < target_width; ++x) Randomize(t++, 0, height, randomizer); +} + +// Helper stores the pixel value in i_ or f_ according to int_mode_. +// t: is the index from the StrideMap corresponding to the current +// [batch,y,x] position +// f: is the index into the depth/channel +// pixel: the value of the pixel from the image (in one channel) +// black: the pixel value to map to the lowest of the range of *this +// contrast: the range of pixel values to stretch to half the range of *this. +void NetworkIO::SetPixel(int t, int f, int pixel, float black, float contrast) { + float float_pixel = (pixel - black) / contrast - 1.0f; + if (int_mode_) { + i_[t][f] = ClipToRange(IntCastRounded((MAX_INT8 + 1) * float_pixel), + -MAX_INT8, MAX_INT8); + } else { + f_[t][f] = float_pixel; + } +} + +// Converts the array to a Pix. Must be pixDestroyed after use. +Pix* NetworkIO::ToPix() const { + // Count the width of the image, and find the max multiplication factor. + int im_width = stride_map_.Size(FD_WIDTH); + int im_height = stride_map_.Size(FD_HEIGHT); + int num_features = NumFeatures(); + int feature_factor = 1; + if (num_features == 3) { + // Special hack for color. + num_features = 1; + feature_factor = 3; + } + Pix* pix = pixCreate(im_width, im_height * num_features, 32); + StrideMap::Index index(stride_map_); + do { + int im_x = index.index(FD_WIDTH); + int top_im_y = index.index(FD_HEIGHT); + int im_y = top_im_y; + int t = index.t(); + if (int_mode_) { + const inT8* features = i_[t]; + for (int y = 0; y < num_features; ++y, im_y += im_height) { + int pixel = features[y * feature_factor]; + // 1 or 2 features use greyscale. + int red = ClipToRange(pixel + 128, 0, 255); + int green = red, blue = red; + if (feature_factor == 3) { + // With 3 features assume RGB color. + green = ClipToRange(features[y * feature_factor + 1] + 128, 0, 255); + blue = ClipToRange(features[y * feature_factor + 2] + 128, 0, 255); + } else if (num_features > 3) { + // More than 3 features use false yellow/blue color, assuming a signed + // input in the range [-1,1]. + red = abs(pixel) * 2; + if (pixel >= 0) { + green = red; + blue = 0; + } else { + blue = red; + green = red = 0; + } + } + pixSetPixel(pix, im_x, im_y, (red << L_RED_SHIFT) | + (green << L_GREEN_SHIFT) | + (blue << L_BLUE_SHIFT)); + } + } else { + const float* features = f_[t]; + for (int y = 0; y < num_features; ++y, im_y += im_height) { + float pixel = features[y * feature_factor]; + // 1 or 2 features use greyscale. + int red = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255); + int green = red, blue = red; + if (feature_factor == 3) { + // With 3 features assume RGB color. 
+          pixel = features[y * feature_factor + 1];
+          green = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
+          pixel = features[y * feature_factor + 2];
+          blue = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
+        } else if (num_features > 3) {
+          // More than 3 features use false yellow/blue color, assuming a
+          // signed input in the range [-1,1].
+          red = ClipToRange(IntCastRounded(fabs(pixel) * 255), 0, 255);
+          if (pixel >= 0) {
+            green = red;
+            blue = 0;
+          } else {
+            blue = red;
+            green = red = 0;
+          }
+        }
+        pixSetPixel(pix, im_x, im_y, (red << L_RED_SHIFT) |
+                                         (green << L_GREEN_SHIFT) |
+                                         (blue << L_BLUE_SHIFT));
+      }
+    }
+  } while (index.Increment());
+  return pix;
+}
+
+// Prints the first and last num timesteps of the array for each feature.
+void NetworkIO::Print(int num) const {
+  int num_features = NumFeatures();
+  for (int y = 0; y < num_features; ++y) {
+    for (int t = 0; t < Width(); ++t) {
+      if (num == 0 || t < num || t + num >= Width()) {
+        if (int_mode_) {
+          tprintf(" %g", static_cast<float>(i_[t][y]) / MAX_INT8);
+        } else {
+          tprintf(" %g", f_[t][y]);
+        }
+      }
+    }
+    tprintf("\n");
+  }
+}
+
+// Copies a single time step from src.
+void NetworkIO::CopyTimeStepFrom(int dest_t, const NetworkIO& src, int src_t) {
+  ASSERT_HOST(int_mode_ == src.int_mode_);
+  if (int_mode_) {
+    memcpy(i_[dest_t], src.i_[src_t], i_.dim2() * sizeof(i_[0][0]));
+  } else {
+    memcpy(f_[dest_t], src.f_[src_t], f_.dim2() * sizeof(f_[0][0]));
+  }
+}
+
+// Copies a part of single time step from src.
+void NetworkIO::CopyTimeStepGeneral(int dest_t, int dest_offset,
+                                    int num_features, const NetworkIO& src,
+                                    int src_t, int src_offset) {
+  ASSERT_HOST(int_mode_ == src.int_mode_);
+  if (int_mode_) {
+    memcpy(i_[dest_t] + dest_offset, src.i_[src_t] + src_offset,
+           num_features * sizeof(i_[0][0]));
+  } else {
+    memcpy(f_[dest_t] + dest_offset, src.f_[src_t] + src_offset,
+           num_features * sizeof(f_[0][0]));
+  }
+}
+
+// Zeroes a single time step.
+void NetworkIO::ZeroTimeStepGeneral(int t, int offset, int num_features) {
+  if (int_mode_) {
+    ZeroVector<inT8>(num_features, i_[t] + offset);
+  } else {
+    ZeroVector<float>(num_features, f_[t] + offset);
+  }
+}
+
+// Sets the given range to random values.
+void NetworkIO::Randomize(int t, int offset, int num_features,
+                          TRand* randomizer) {
+  if (int_mode_) {
+    inT8* line = i_[t] + offset;
+    for (int i = 0; i < num_features; ++i)
+      line[i] = IntCastRounded(randomizer->SignedRand(MAX_INT8));
+  } else {
+    // float mode.
+    float* line = f_[t] + offset;
+    for (int i = 0; i < num_features; ++i)
+      line[i] = randomizer->SignedRand(1.0);
+  }
+}
+
+// Helper returns the label and score of the best choice over a range.
+int NetworkIO::BestChoiceOverRange(int t_start, int t_end, int not_this,
+                                   int null_ch, float* rating,
+                                   float* certainty) const {
+  if (t_end <= t_start) return -1;
+  int max_char = -1;
+  float min_score = 0.0f;
+  for (int c = 0; c < NumFeatures(); ++c) {
+    if (c == not_this || c == null_ch) continue;
+    ScoresOverRange(t_start, t_end, c, null_ch, rating, certainty);
+    if (max_char < 0 || *rating < min_score) {
+      min_score = *rating;
+      max_char = c;
+    }
+  }
+  ScoresOverRange(t_start, t_end, max_char, null_ch, rating, certainty);
+  return max_char;
+}
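+// For example, with the null character at index 0, the best non-null label
+// over timesteps [3,7) of a float-mode output can be read with:
+//
+//   float rating, certainty;
+//   int label = output.BestChoiceOverRange(3, 7, -1, 0, &rating, &certainty);
+//
+// where not_this = -1 excludes nothing.
+
+// Helper returns the rating and certainty of the choice over a range in
+// output.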
+void NetworkIO::ScoresOverRange(int t_start, int t_end, int choice,
+                                int null_ch, float* rating,
+                                float* certainty) const {
+  ASSERT_HOST(!int_mode_);
+  *rating = 0.0f;
+  *certainty = 0.0f;
+  if (t_end <= t_start || t_end <= 0) return;
+  float ratings[3] = {0.0f, 0.0f, 0.0f};
+  float certs[3] = {0.0f, 0.0f, 0.0f};
+  for (int t = t_start; t < t_end; ++t) {
+    const float* line = f_[t];
+    float score = ProbToCertainty(line[choice]);
+    float zero = ProbToCertainty(line[null_ch]);
+    if (t == t_start) {
+      ratings[2] = MAX_FLOAT32;
+      ratings[1] = -score;
+      certs[1] = score;
+    } else {
+      for (int i = 2; i >= 1; --i) {
+        if (ratings[i] > ratings[i - 1]) {
+          ratings[i] = ratings[i - 1];
+          certs[i] = certs[i - 1];
+        }
+      }
+      ratings[2] -= zero;
+      if (zero < certs[2]) certs[2] = zero;
+      ratings[1] -= score;
+      if (score < certs[1]) certs[1] = score;
+    }
+    ratings[0] -= zero;
+    if (zero < certs[0]) certs[0] = zero;
+  }
+  int best_i = ratings[2] < ratings[1] ? 2 : 1;
+  *rating = ratings[best_i] + t_end - t_start;
+  *certainty = certs[best_i];
+}
+
+// Returns the index (label) of the best value at the given timestep,
+// excluding not_this and not_that, and if not null, sets the score to the
+// log of the corresponding value.
+int NetworkIO::BestLabel(int t, int not_this, int not_that,
+                         float* score) const {
+  ASSERT_HOST(!int_mode_);
+  int best_index = -1;
+  float best_score = -MAX_FLOAT32;
+  const float* line = f_[t];
+  for (int i = 0; i < f_.dim2(); ++i) {
+    if (line[i] > best_score && i != not_this && i != not_that) {
+      best_score = line[i];
+      best_index = i;
+    }
+  }
+  if (score != NULL) *score = ProbToCertainty(best_score);
+  return best_index;
+}
+
+// Returns the best start position out of [start, end) (into which all labels
+// must fit) to obtain the highest cumulative score for the given labels.
+int NetworkIO::PositionOfBestMatch(const GenericVector<int>& labels, int start,
+                                   int end) const {
+  int length = labels.size();
+  int last_start = end - length;
+  int best_start = -1;
+  double best_score = 0.0;
+  for (int s = start; s <= last_start; ++s) {
+    double score = ScoreOfLabels(labels, s);
+    if (score > best_score || best_start < 0) {
+      best_score = score;
+      best_start = s;
+    }
+  }
+  return best_start;
+}
+
+// Returns the cumulative score of the given labels starting at start, and
+// using one label per time-step.
+double NetworkIO::ScoreOfLabels(const GenericVector<int>& labels,
+                                int start) const {
+  int length = labels.size();
+  double score = 0.0;
+  for (int i = 0; i < length; ++i) {
+    score += f_(start + i, labels[i]);
+  }
+  return score;
+}
+
+// Helper function sets all the outputs for a single timestep, such that
+// label has value ok_score, and the other labels share 1 - ok_score.
+void NetworkIO::SetActivations(int t, int label, float ok_score) {
+  ASSERT_HOST(!int_mode_);
+  int num_classes = NumFeatures();
+  float bad_score = (1.0f - ok_score) / (num_classes - 1);
+  float* targets = f_[t];
+  for (int i = 0; i < num_classes; ++i)
+    targets[i] = bad_score;
+  targets[label] = ok_score;
+}
+
+// Modifies the values, only if needed, so that the given label is
+// the winner at the given time step t.
+void NetworkIO::EnsureBestLabel(int t, int label) {
+  ASSERT_HOST(!int_mode_);
+  if (BestLabel(t, NULL) != label) {
+    // Output value needs enhancing. Third all the other elements and add the
+    // remainder to best_label.
+    int num_classes = NumFeatures();
+    float* targets = f_[t];
+    for (int c = 0; c < num_classes; ++c) {
+      if (c == label) {
+        targets[c] += (1.0 - targets[c]) * (2 / 3.0);
+      } else {
+        targets[c] /= 3.0;
+      }
+    }
+  }
+}
+
+// Helper function converts prob to certainty taking the minimum into account.
+/* static */
+float NetworkIO::ProbToCertainty(float prob) {
+  return prob > kMinProb ? log(prob) : kMinCertainty;
+}
+
+// Returns true if there is any bad value that is suspiciously like a GT
+// error. Assuming that *this is the difference(gradient) between target
+// and forward output, returns true if there is a large negative value
+// (correcting a very confident output) for which there is no corresponding
+// positive value in an adjacent timestep for the same feature index. This
+// allows the box-truthed samples to make fine adjustments to position while
+// stopping other disagreements of confident output with ground truth.
+bool NetworkIO::AnySuspiciousTruth(float confidence_thr) const {
+  int num_features = NumFeatures();
+  for (int t = 0; t < Width(); ++t) {
+    const float* features = f_[t];
+    for (int y = 0; y < num_features; ++y) {
+      float grad = features[y];
+      if (grad < -confidence_thr) {
+        // Correcting strong output. Check for movement.
+        if ((t == 0 || f_[t - 1][y] < confidence_thr / 2) &&
+            (t + 1 == Width() || f_[t + 1][y] < confidence_thr / 2)) {
+          return true;  // No strong positive on either side.
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Reads a single timestep to floats in the range [-1, 1].
+void NetworkIO::ReadTimeStep(int t, double* output) const {
+  if (int_mode_) {
+    const inT8* line = i_[t];
+    for (int i = 0; i < i_.dim2(); ++i) {
+      output[i] = static_cast<double>(line[i]) / MAX_INT8;
+    }
+  } else {
+    const float* line = f_[t];
+    for (int i = 0; i < f_.dim2(); ++i) {
+      output[i] = static_cast<double>(line[i]);
+    }
+  }
+}
+
+// Adds a single timestep to floats.
+void NetworkIO::AddTimeStep(int t, double* inout) const {
+  int num_features = NumFeatures();
+  if (int_mode_) {
+    const inT8* line = i_[t];
+    for (int i = 0; i < num_features; ++i) {
+      inout[i] += static_cast<double>(line[i]) / MAX_INT8;
+    }
+  } else {
+    const float* line = f_[t];
+    for (int i = 0; i < num_features; ++i) {
+      inout[i] += line[i];
+    }
+  }
+}
+
+// Adds part of a single timestep to floats.
+void NetworkIO::AddTimeStepPart(int t, int offset, int num_features,
+                                float* inout) const {
+  if (int_mode_) {
+    const inT8* line = i_[t] + offset;
+    for (int i = 0; i < num_features; ++i) {
+      inout[i] += static_cast<float>(line[i]) / MAX_INT8;
+    }
+  } else {
+    const float* line = f_[t] + offset;
+    for (int i = 0; i < num_features; ++i) {
+      inout[i] += line[i];
+    }
+  }
+}
+
+// Writes a single timestep from floats in the range [-1, 1].
+void NetworkIO::WriteTimeStep(int t, const double* input) {
+  WriteTimeStepPart(t, 0, NumFeatures(), input);
+}
+
+// Writes a single timestep from floats in the range [-1, 1] writing only
+// num_features elements of input to (*this)[t], starting at offset.
+void NetworkIO::WriteTimeStepPart(int t, int offset, int num_features,
+                                  const double* input) {
+  if (int_mode_) {
+    inT8* line = i_[t] + offset;
+    for (int i = 0; i < num_features; ++i) {
+      line[i] = ClipToRange(IntCastRounded(input[i] * MAX_INT8),
+                            -MAX_INT8, MAX_INT8);
+    }
+  } else {
+    float* line = f_[t] + offset;
+    for (int i = 0; i < num_features; ++i) {
+      line[i] = static_cast<float>(input[i]);
+    }
+  }
+}
+
+// Maxpools a single time step from src.
+void NetworkIO::MaxpoolTimeStep(int dest_t, const NetworkIO& src, int src_t,
+                                int* max_line) {
+  ASSERT_HOST(int_mode_ == src.int_mode_);
+  if (int_mode_) {
+    int dim = i_.dim2();
+    inT8* dest_line = i_[dest_t];
+    const inT8* src_line = src.i_[src_t];
+    for (int i = 0; i < dim; ++i) {
+      if (dest_line[i] < src_line[i]) {
+        dest_line[i] = src_line[i];
+        max_line[i] = src_t;
+      }
+    }
+  } else {
+    int dim = f_.dim2();
+    float* dest_line = f_[dest_t];
+    const float* src_line = src.f_[src_t];
+    for (int i = 0; i < dim; ++i) {
+      if (dest_line[i] < src_line[i]) {
+        dest_line[i] = src_line[i];
+        max_line[i] = src_t;
+      }
+    }
+  }
+}
+
+// Runs maxpool backward, using maxes to index timesteps in *this.
+void NetworkIO::MaxpoolBackward(const NetworkIO& fwd,
+                                const GENERIC_2D_ARRAY<int>& maxes) {
+  ASSERT_HOST(!int_mode_);
+  Zero();
+  StrideMap::Index index(fwd.stride_map_);
+  do {
+    int t = index.t();
+    const int* max_line = maxes[t];
+    const float* fwd_line = fwd.f_[t];
+    int num_features = fwd.f_.dim2();
+    for (int i = 0; i < num_features; ++i) {
+      f_[max_line[i]][i] = fwd_line[i];
+    }
+  } while (index.Increment());
+}
+
+// Returns the min over time of the maxes over features of the outputs.
+float NetworkIO::MinOfMaxes() const {
+  float min_max = 0.0f;
+  int width = Width();
+  int num_features = NumFeatures();
+  for (int t = 0; t < width; ++t) {
+    float max_value = -MAX_FLOAT32;
+    if (int_mode_) {
+      const inT8* column = i_[t];
+      for (int i = 0; i < num_features; ++i) {
+        if (column[i] > max_value) max_value = column[i];
+      }
+    } else {
+      const float* column = f_[t];
+      for (int i = 0; i < num_features; ++i) {
+        if (column[i] > max_value) max_value = column[i];
+      }
+    }
+    if (t == 0 || max_value < min_max) min_max = max_value;
+  }
+  return min_max;
+}
+
+// Computes combined results for a combiner that chooses between an existing
+// input and itself, with an additional output to indicate the choice.
+void NetworkIO::CombineOutputs(const NetworkIO& base_output,
+                               const NetworkIO& combiner_output) {
+  int no = base_output.NumFeatures();
+  ASSERT_HOST(combiner_output.NumFeatures() == no + 1);
+  Resize(base_output, no);
+  int width = Width();
+  if (int_mode_) {
+    // Number of outputs from base and final result.
+    for (int t = 0; t < width; ++t) {
+      inT8* out_line = i_[t];
+      const inT8* base_line = base_output.i_[t];
+      const inT8* comb_line = combiner_output.i_[t];
+      float base_weight = static_cast<float>(comb_line[no]) / MAX_INT8;
+      float boost_weight = 1.0f - base_weight;
+      for (int i = 0; i < no; ++i) {
+        out_line[i] = IntCastRounded(base_line[i] * base_weight +
+                                     comb_line[i] * boost_weight);
+      }
+    }
+  } else {
+    for (int t = 0; t < width; ++t) {
+      float* out_line = f_[t];
+      const float* base_line = base_output.f_[t];
+      const float* comb_line = combiner_output.f_[t];
+      float base_weight = comb_line[no];
+      float boost_weight = 1.0f - base_weight;
+      for (int i = 0; i < no; ++i) {
+        out_line[i] = base_line[i] * base_weight + comb_line[i] * boost_weight;
+      }
+    }
+  }
+}
+
+// Computes deltas for a combiner that chooses between 2 sets of inputs.
+void NetworkIO::ComputeCombinerDeltas(const NetworkIO& fwd_deltas,
+                                      const NetworkIO& base_output) {
+  ASSERT_HOST(!int_mode_);
+  // Compute the deltas for the combiner.
+  int width = Width();
+  int no = NumFeatures() - 1;
+  ASSERT_HOST(fwd_deltas.NumFeatures() == no);
+  ASSERT_HOST(base_output.NumFeatures() == no);
+  // Number of outputs from base and final result.
+ for (int t = 0; t < width; ++t) { + const float* delta_line = fwd_deltas.f_[t]; + const float* base_line = base_output.f_[t]; + float* comb_line = f_[t]; + float base_weight = comb_line[no]; + float boost_weight = 1.0f - base_weight; + float max_base_delta = 0.0; + for (int i = 0; i < no; ++i) { + // What did the combiner actually produce? + float output = base_line[i] * base_weight + comb_line[i] * boost_weight; + // Reconstruct the target from the delta. + float comb_target = delta_line[i] + output; + comb_line[i] = comb_target - comb_line[i]; + float base_delta = fabs(comb_target - base_line[i]); + if (base_delta > max_base_delta) max_base_delta = base_delta; + } + if (max_base_delta >= 0.5) { + // The base network got it wrong. The combiner should output the right + // answer and 0 for the base network. + comb_line[no] = 0.0 - base_weight; + } else { + // The base network was right. The combiner should flag that. + for (int i = 0; i < no; ++i) { + // All other targets are 0. + if (comb_line[i] > 0.0) comb_line[i] -= 1.0; + } + comb_line[no] = 1.0 - base_weight; + } + } +} + +// Copies the array checking that the types match. +void NetworkIO::CopyAll(const NetworkIO& src) { + ASSERT_HOST(src.int_mode_ == int_mode_); + f_ = src.f_; +} + +// Checks that both are floats and adds the src array to *this. +void NetworkIO::AddAllToFloat(const NetworkIO& src) { + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!src.int_mode_); + f_ += src.f_; +} + +// Subtracts the array from a float array. src must also be float. +void NetworkIO::SubtractAllFromFloat(const NetworkIO& src) { + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!src.int_mode_); + f_ -= src.f_; +} + +// Copies src to *this, with maxabs normalization to match scale. +void NetworkIO::CopyWithNormalization(const NetworkIO& src, + const NetworkIO& scale) { + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!src.int_mode_); + ASSERT_HOST(!scale.int_mode_); + float src_max = src.f_.MaxAbs(); + ASSERT_HOST(std::isfinite(src_max)); + float scale_max = scale.f_.MaxAbs(); + ASSERT_HOST(std::isfinite(scale_max)); + if (src_max > 0.0f) { + float factor = scale_max / src_max; + for (int t = 0; t < src.Width(); ++t) { + const float* src_ptr = src.f_[t]; + float* dest_ptr = f_[t]; + for (int i = 0; i < src.f_.dim2(); ++i) dest_ptr[i] = src_ptr[i] * factor; + } + } else { + f_.Clear(); + } +} + +// Copies src to *this with independent reversal of the y dimension. +void NetworkIO::CopyWithYReversal(const NetworkIO& src) { + int num_features = src.NumFeatures(); + Resize(src, num_features); + StrideMap::Index b_index(src.stride_map_); + do { + int width = b_index.MaxIndexOfDim(FD_WIDTH) + 1; + StrideMap::Index fwd_index(b_index); + StrideMap::Index rev_index(b_index); + rev_index.AddOffset(rev_index.MaxIndexOfDim(FD_HEIGHT), FD_HEIGHT); + do { + int fwd_t = fwd_index.t(); + int rev_t = rev_index.t(); + for (int x = 0; x < width; ++x) CopyTimeStepFrom(rev_t++, src, fwd_t++); + } while (fwd_index.AddOffset(1, FD_HEIGHT) && + rev_index.AddOffset(-1, FD_HEIGHT)); + } while (b_index.AddOffset(1, FD_BATCH)); +} + +// Copies src to *this with independent reversal of the x dimension. 
+void NetworkIO::CopyWithXReversal(const NetworkIO& src) { + int num_features = src.NumFeatures(); + Resize(src, num_features); + StrideMap::Index b_index(src.stride_map_); + do { + StrideMap::Index y_index(b_index); + do { + StrideMap::Index fwd_index(y_index); + StrideMap::Index rev_index(y_index); + rev_index.AddOffset(rev_index.MaxIndexOfDim(FD_WIDTH), FD_WIDTH); + do { + CopyTimeStepFrom(rev_index.t(), src, fwd_index.t()); + } while (fwd_index.AddOffset(1, FD_WIDTH) && + rev_index.AddOffset(-1, FD_WIDTH)); + } while (y_index.AddOffset(1, FD_HEIGHT)); + } while (b_index.AddOffset(1, FD_BATCH)); +} + +// Copies src to *this with independent transpose of the x and y dimensions. +void NetworkIO::CopyWithXYTranspose(const NetworkIO& src) { + int num_features = src.NumFeatures(); + stride_map_ = src.stride_map_; + stride_map_.TransposeXY(); + ResizeToMap(src.int_mode(), stride_map_, num_features); + StrideMap::Index src_b_index(src.stride_map_); + StrideMap::Index dest_b_index(stride_map_); + do { + StrideMap::Index src_y_index(src_b_index); + StrideMap::Index dest_x_index(dest_b_index); + do { + StrideMap::Index src_x_index(src_y_index); + StrideMap::Index dest_y_index(dest_x_index); + do { + CopyTimeStepFrom(dest_y_index.t(), src, src_x_index.t()); + } while (src_x_index.AddOffset(1, FD_WIDTH) && + dest_y_index.AddOffset(1, FD_HEIGHT)); + } while (src_y_index.AddOffset(1, FD_HEIGHT) && + dest_x_index.AddOffset(1, FD_WIDTH)); + } while (src_b_index.AddOffset(1, FD_BATCH) && + dest_b_index.AddOffset(1, FD_BATCH)); +} + +// Copies src to *this, at the given feature_offset, returning the total +// feature offset after the copy. Multiple calls will stack outputs from +// multiple sources in feature space. +int NetworkIO::CopyPacking(const NetworkIO& src, int feature_offset) { + ASSERT_HOST(int_mode_ == src.int_mode_); + int width = src.Width(); + ASSERT_HOST(width <= Width()); + int num_features = src.NumFeatures(); + ASSERT_HOST(num_features + feature_offset <= NumFeatures()); + if (int_mode_) { + for (int t = 0; t < width; ++t) { + memcpy(i_[t] + feature_offset, src.i_[t], + num_features * sizeof(i_[t][0])); + } + for (int t = width; t < i_.dim1(); ++t) { + memset(i_[t], 0, num_features * sizeof(i_[t][0])); + } + } else { + for (int t = 0; t < width; ++t) { + memcpy(f_[t] + feature_offset, src.f_[t], + num_features * sizeof(f_[t][0])); + } + for (int t = width; t < f_.dim1(); ++t) { + memset(f_[t], 0, num_features * sizeof(f_[t][0])); + } + } + return num_features + feature_offset; +} + +// Opposite of CopyPacking, fills *this with a part of src, starting at +// feature_offset, and picking num_features. +void NetworkIO::CopyUnpacking(const NetworkIO& src, int feature_offset, + int num_features) { + Resize(src, num_features); + int width = src.Width(); + ASSERT_HOST(num_features + feature_offset <= src.NumFeatures()); + if (int_mode_) { + for (int t = 0; t < width; ++t) { + memcpy(i_[t], src.i_[t] + feature_offset, + num_features * sizeof(i_[t][0])); + } + } else { + for (int t = 0; t < width; ++t) { + memcpy(f_[t], src.f_[t] + feature_offset, + num_features * sizeof(f_[t][0])); + } + } +} + +// Transposes the float part of *this into dest. +void NetworkIO::Transpose(TransposedArray* dest) const { + int width = Width(); + dest->ResizeNoInit(NumFeatures(), width); + for (int t = 0; t < width; ++t) dest->WriteStrided(t, f_[t]); +} + +// Clips the content of a single time-step to +/-range. 
+void NetworkIO::ClipVector(int t, float range) {
+  ASSERT_HOST(!int_mode_);
+  float* v = f_[t];
+  int dim = f_.dim2();
+  for (int i = 0; i < dim; ++i)
+    v[i] = ClipToRange(v[i], -range, range);
+}
+
+}  // namespace tesseract.
diff --git a/lstm/networkio.h b/lstm/networkio.h
new file mode 100644
index 0000000000..5082269917
--- /dev/null
+++ b/lstm/networkio.h
@@ -0,0 +1,341 @@
+///////////////////////////////////////////////////////////////////////
+// File: networkio.h
+// Description: Network input/output data, allowing float/int implementations.
+// Author: Ray Smith
+// Created: Tue Jun 17 08:43:11 PST 2014
+//
+// (C) Copyright 2014, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef TESSERACT_LSTM_NETWORKIO_H_
+#define TESSERACT_LSTM_NETWORKIO_H_
+
+#include <math.h>
+#include <stdio.h>
+#include <vector>
+
+#include "genericvector.h"
+#include "helpers.h"
+#include "static_shape.h"
+#include "stridemap.h"
+#include "weightmatrix.h"
+
+struct Pix;
+
+namespace tesseract {
+
+// Class to contain all the input/output of a network, allowing for fixed or
+// variable-strided 2d to 1d mapping, and float or inT8 values. Provides
+// enough calculating functions to hide the detail of the implementation.
+class NetworkIO {
+ public:
+  NetworkIO() : int_mode_(false) {}
+  // Resizes the array (and stride), avoiding realloc if possible, to the given
+  // size from various size specs:
+  // Same stride size, but given number of features.
+  void Resize(const NetworkIO& src, int num_features) {
+    ResizeToMap(src.int_mode(), src.stride_map(), num_features);
+  }
+  // Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
+  void Resize2d(bool int_mode, int width, int num_features);
+  // Resizes forcing a float representation with the stridemap of src and the
+  // given number of features.
+  void ResizeFloat(const NetworkIO& src, int num_features) {
+    ResizeToMap(false, src.stride_map(), num_features);
+  }
+  // Resizes to a specific stride_map.
+  void ResizeToMap(bool int_mode, const StrideMap& stride_map,
+                   int num_features);
+  // Shrinks image size by x_scale,y_scale, and use given number of features.
+  void ResizeScaled(const NetworkIO& src, int x_scale, int y_scale,
+                    int num_features);
+  // Resizes to just 1 x-coord, whatever the input.
+  void ResizeXTo1(const NetworkIO& src, int num_features);
+  // Initialize all the array to zero.
+  void Zero();
+  // Initializes to zero all elements of the array that do not correspond to
+  // valid image positions. (If a batch of different-sized images are packed
+  // together, then there will be padding pixels.)
+  void ZeroInvalidElements();
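+  // Typical use (an assumed flow, for illustration): size the buffer to a
+  // stride map in float mode, then write and read timesteps:
+  //
+  //   NetworkIO buffer;
+  //   buffer.ResizeToMap(false, stride_map, 96);  // float mode, 96 features.
+  //   buffer.WriteTimeStep(0, activations);       // activations: 96 doubles.
+  //   double readback[96];
+  //   buffer.ReadTimeStep(0, readback);
+
+  // Sets up the array from the given image, using the currently set int_mode_.
+  // If the image width doesn't match the shape, the image is truncated or
+  // padded with noise to match.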
+ void FromPix(const StaticShape& shape, const Pix* pix, TRand* randomizer); + // Sets up the array from the given set of images, using the currently set + // int_mode_. If the image width doesn't match the shape, the images are + // truncated or padded with noise to match. + void FromPixes(const StaticShape& shape, const std::vector<const Pix*>& pixes, + TRand* randomizer); + // Copies the given pix to *this at the given batch index, stretching and + // clipping the pixel values so that [black, black + 2*contrast] maps to the + // dynamic range of *this, ie [-1,1] for a float and (-127,127) for int. + // This is a 2-d operation in the sense that the output depth is the number + // of input channels, the height is the height of the image, and the width + // is the width of the image, or truncated/padded with noise if the width + // is a fixed size. + void Copy2DImage(int batch, Pix* pix, float black, float contrast, + TRand* randomizer); + // Copies the given pix to *this at the given batch index, as Copy2DImage + // above, except that the output depth is the height of the input image, the + // output height is 1, and the output width as for Copy2DImage. + // The image is thus treated as a 1-d set of vertical pixel strips. + void Copy1DGreyImage(int batch, Pix* pix, float black, float contrast, + TRand* randomizer); + // Helper stores the pixel value in i_ or f_ according to int_mode_. + // t: is the index from the StrideMap corresponding to the current + // [batch,y,x] position + // f: is the index into the depth/channel + // pixel: the value of the pixel from the image (in one channel) + // black: the pixel value to map to the lowest of the range of *this + // contrast: the range of pixel values to stretch to half the range of *this. + void SetPixel(int t, int f, int pixel, float black, float contrast); + // Converts the array to a Pix. Must be pixDestroyed after use. + Pix* ToPix() const; + // Prints the first and last num timesteps of the array for each feature. + void Print(int num) const; + + // Returns the timestep width. + int Width() const { + return int_mode_ ? i_.dim1() : f_.dim1(); + } + // Returns the number of features. + int NumFeatures() const { + return int_mode_ ? i_.dim2() : f_.dim2(); + } + // Accessor to a timestep of the float matrix. + float* f(int t) { + ASSERT_HOST(!int_mode_); + return f_[t]; + } + const float* f(int t) const { + ASSERT_HOST(!int_mode_); + return f_[t]; + } + const inT8* i(int t) const { + ASSERT_HOST(int_mode_); + return i_[t]; + } + bool int_mode() const { + return int_mode_; + } + void set_int_mode(bool is_quantized) { + int_mode_ = is_quantized; + } + const StrideMap& stride_map() const { + return stride_map_; + } + void set_stride_map(const StrideMap& map) { + stride_map_ = map; + } + const GENERIC_2D_ARRAY<float>& float_array() const { return f_; } + GENERIC_2D_ARRAY<float>* mutable_float_array() { return &f_; } + + // Copies a single time step from src. + void CopyTimeStepFrom(int dest_t, const NetworkIO& src, int src_t); + // Copies a part of single time step from src. + void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features, + const NetworkIO& src, int src_t, int src_offset); + // Zeroes a single time step. + void ZeroTimeStep(int t) { ZeroTimeStepGeneral(t, 0, NumFeatures()); } + void ZeroTimeStepGeneral(int t, int offset, int num_features); + // Sets the given range to random values. + void Randomize(int t, int offset, int num_features, TRand* randomizer); + + // Helper returns the label and score of the best choice over a range.
+ int BestChoiceOverRange(int t_start, int t_end, int not_this, int null_ch, + float* rating, float* certainty) const; + // Helper returns the rating and certainty of the choice over a range in t. + void ScoresOverRange(int t_start, int t_end, int choice, int null_ch, + float* rating, float* certainty) const; + // Returns the index (label) of the best value at the given timestep, + // and if not null, sets the score to the log of the corresponding value. + int BestLabel(int t, float* score) const { + return BestLabel(t, -1, -1, score); + } + // Returns the index (label) of the best value at the given timestep, + // excluding not_this and not_that, and if not null, sets the score to the + // log of the corresponding value. + int BestLabel(int t, int not_this, int not_that, float* score) const; + // Returns the best start position out of range (into which both start and end + // must fit) to obtain the highest cumulative score for the given labels. + int PositionOfBestMatch(const GenericVector<int>& labels, int start, + int end) const; + // Returns the cumulative score of the given labels starting at start, and + // using one label per time-step. + double ScoreOfLabels(const GenericVector<int>& labels, int start) const; + // Helper function sets all the outputs for a single timestep, such that + // label has value ok_score, and the other labels share 1 - ok_score. + // Assumes float mode. + void SetActivations(int t, int label, float ok_score); + // Modifies the values, only if needed, so that the given label is + // the winner at the given time step t. + // Assumes float mode. + void EnsureBestLabel(int t, int label); + // Helper function converts prob to certainty taking the minimum into account. + static float ProbToCertainty(float prob); + // Returns true if there is any bad value that is suspiciously like a GT + // error. Assuming that *this is the difference(gradient) between target + // and forward output, returns true if there is a large negative value + // (correcting a very confident output) for which there is no corresponding + // positive value in an adjacent timestep for the same feature index. This + // allows the box-truthed samples to make fine adjustments to position while + // stopping other disagreements of confident output with ground truth. + bool AnySuspiciousTruth(float confidence_thr) const; + + // Reads a single timestep to floats in the range [-1, 1]. + void ReadTimeStep(int t, double* output) const; + // Adds a single timestep to floats. + void AddTimeStep(int t, double* inout) const; + // Adds part of a single timestep to floats. + void AddTimeStepPart(int t, int offset, int num_features, float* inout) const; + // Writes a single timestep from floats in the range [-1, 1]. + void WriteTimeStep(int t, const double* input); + // Writes a single timestep from floats in the range [-1, 1] writing only + // num_features elements of input to (*this)[t], starting at offset. + void WriteTimeStepPart(int t, int offset, int num_features, + const double* input); + // Maxpools a single time step from src. + void MaxpoolTimeStep(int dest_t, const NetworkIO& src, int src_t, + int* max_line); + // Runs maxpool backward, using maxes to index timesteps in *this. + void MaxpoolBackward(const NetworkIO& fwd, + const GENERIC_2D_ARRAY<int>& maxes); + // Returns the min over time of the maxes over features of the outputs. + float MinOfMaxes() const; + // Returns the max over time. + float Max() const { return int_mode_ ?
i_.Max() : f_.Max(); } + // Computes combined results for a combiner that chooses between an existing + // input and itself, with an additional output to indicate the choice. + void CombineOutputs(const NetworkIO& base_output, + const NetworkIO& combiner_output); + // Computes deltas for a combiner that chooses between 2 sets of inputs. + void ComputeCombinerDeltas(const NetworkIO& fwd_deltas, + const NetworkIO& base_output); + + // Copies the array checking that the types match. + void CopyAll(const NetworkIO& src); + // Adds the array to a float array, with scaling to [-1, 1] if the src is int. + void AddAllToFloat(const NetworkIO& src); + // Subtracts the array from a float array. src must also be float. + void SubtractAllFromFloat(const NetworkIO& src); + + // Copies src to *this, with maxabs normalization to match scale. + void CopyWithNormalization(const NetworkIO& src, const NetworkIO& scale); + // Multiplies the float data by the given factor. + void ScaleFloatBy(float factor) { f_ *= factor; } + // Copies src to *this with independent reversal of the y dimension. + void CopyWithYReversal(const NetworkIO& src); + // Copies src to *this with independent reversal of the x dimension. + void CopyWithXReversal(const NetworkIO& src); + // Copies src to *this with independent transpose of the x and y dimensions. + void CopyWithXYTranspose(const NetworkIO& src); + // Copies src to *this, at the given feature_offset, returning the total + // feature offset after the copy. Multiple calls will stack outputs from + // multiple sources in feature space. + int CopyPacking(const NetworkIO& src, int feature_offset); + // Opposite of CopyPacking, fills *this with a part of src, starting at + // feature_offset, and picking num_features. Resizes *this to match. + void CopyUnpacking(const NetworkIO& src, int feature_offset, + int num_features); + // Transposes the float part of *this into dest. + void Transpose(TransposedArray* dest) const; + + // Clips the content of a single time-step to +/-range. + void ClipVector(int t, float range); + + // Applies Func to timestep t of *this (u) and multiplies the result by v + // component-wise, putting the product in *product. + // *this and v may be int or float, but must match. The outputs are double. + template <class Func> + void FuncMultiply(const NetworkIO& v_io, int t, double* product) { + Func f; + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!v_io.int_mode_); + int dim = f_.dim2(); + if (int_mode_) { + const inT8* u = i_[t]; + const inT8* v = v_io.i_[t]; + for (int i = 0; i < dim; ++i) { + product[i] = f(u[i] / static_cast<double>(MAX_INT8)) * v[i] / + static_cast<double>(MAX_INT8); + } + } else { + const float* u = f_[t]; + const float* v = v_io.f_[t]; + for (int i = 0; i < dim; ++i) { + product[i] = f(u[i]) * v[i]; + } + } + } + // Applies Func to *this (u) at u_t, and multiplies the result by v[v_t] * w, + // component-wise, putting the product in *product. + // All NetworkIOs are assumed to be float. + template <class Func> + void FuncMultiply3(int u_t, const NetworkIO& v_io, int v_t, const double* w, + double* product) const { + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!v_io.int_mode_); + Func f; + const float* u = f_[u_t]; + const float* v = v_io.f_[v_t]; + int dim = f_.dim2(); + for (int i = 0; i < dim; ++i) { + product[i] = f(u[i]) * v[i] * w[i]; + } + } + // Applies Func to *this (u) at u_t, and multiplies the result by v[v_t] * w, + // component-wise, adding the product to *product. + // All NetworkIOs are assumed to be float.
+ template <class Func> + void FuncMultiply3Add(const NetworkIO& v_io, int t, const double* w, + double* product) const { + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!v_io.int_mode_); + Func f; + const float* u = f_[t]; + const float* v = v_io.f_[t]; + int dim = f_.dim2(); + for (int i = 0; i < dim; ++i) { + product[i] += f(u[i]) * v[i] * w[i]; + } + } + // Applies Func1 to *this (u), Func2 to v, and multiplies the result by w, + // component-wise, putting the product in product, all at timestep t, except + // w, which is a simple array. All NetworkIOs are assumed to be float. + template <class Func1, class Func2> + void Func2Multiply3(const NetworkIO& v_io, int t, const double* w, + double* product) const { + ASSERT_HOST(!int_mode_); + ASSERT_HOST(!v_io.int_mode_); + Func1 f; + Func2 g; + const float* u = f_[t]; + const float* v = v_io.f_[t]; + int dim = f_.dim2(); + for (int i = 0; i < dim; ++i) { + product[i] = f(u[i]) * g(v[i]) * w[i]; + } + } + + private: + // Choice of float vs 8 bit int for data. + GENERIC_2D_ARRAY<float> f_; + GENERIC_2D_ARRAY<inT8> i_; + // Which of f_ and i_ are we actually using. + bool int_mode_; + // Stride for 2d input data. + StrideMap stride_map_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_NETWORKIO_H_ diff --git a/lstm/networkscratch.h b/lstm/networkscratch.h new file mode 100644 index 0000000000..2818550684 --- /dev/null +++ b/lstm/networkscratch.h @@ -0,0 +1,257 @@ +/////////////////////////////////////////////////////////////////////// +// File: networkscratch.h +// Description: Scratch space for Network layers that hides distinction +// between float/int implementations. +// Author: Ray Smith +// Created: Thu Jun 19 10:50:29 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_ +#define TESSERACT_LSTM_NETWORKSCRATCH_H_ + +#include "genericvector.h" +#include "matrix.h" +#include "networkio.h" +#include "svutil.h" +#include "tprintf.h" + +namespace tesseract { + +// Generic scratch space for network layers. Provides NetworkIO that can store +// a complete set (over time) of intermediates, and GenericVector<double> +// scratch space that auto-frees after use. The aim here is to provide a set +// of temporary buffers to network layers that can be reused between layers +// and don't have to be reallocated on each call. +class NetworkScratch { + public: + NetworkScratch() : int_mode_(false) {} + ~NetworkScratch() {} + + // Sets the network representation. If the representation is integer, then + // default (integer) NetworkIOs are separated from the always-float variety. + // This saves memory by having separate int-specific and float-specific + // stacks. If the network representation is float, then all NetworkIOs go + // to the float stack.
+ void set_int_mode(bool int_mode) { + int_mode_ = int_mode; + } + + // Class that acts like a NetworkIO (by having an implicit cast operator), + // yet actually holds a pointer to NetworkIOs in the source NetworkScratch, + // and knows how to unstack the borrowed pointers on destruction. + class IO { + public: + // The NetworkIO should be sized after construction. + IO(const NetworkIO& src, NetworkScratch* scratch) + : int_mode_(scratch->int_mode_ && src.int_mode()), + scratch_space_(scratch) { + network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow() + : scratch_space_->float_stack_.Borrow(); + } + // Default constructor for arrays. Use one of the Resize functions + // below to initialize and size. + IO() : int_mode_(false), network_io_(NULL), scratch_space_(NULL) {} + + ~IO() { + if (scratch_space_ == NULL) { + ASSERT_HOST(network_io_ == NULL); + } else if (int_mode_) { + scratch_space_->int_stack_.Return(network_io_); + } else { + scratch_space_->float_stack_.Return(network_io_); + } + } + // Resizes the array (and stride), avoiding realloc if possible, to the + // size from various size specs: + // Same time size, given number of features. + void Resize(const NetworkIO& src, int num_features, + NetworkScratch* scratch) { + if (scratch_space_ == NULL) { + int_mode_ = scratch->int_mode_ && src.int_mode(); + scratch_space_ = scratch; + network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow() + : scratch_space_->float_stack_.Borrow(); + } + network_io_->Resize(src, num_features); + } + // Resizes to a specific size as a temp buffer. No batches, no y-dim. + void Resize2d(bool int_mode, int width, int num_features, + NetworkScratch* scratch) { + if (scratch_space_ == NULL) { + int_mode_ = scratch->int_mode_ && int_mode; + scratch_space_ = scratch; + network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow() + : scratch_space_->float_stack_.Borrow(); + } + network_io_->Resize2d(int_mode, width, num_features); + } + // Resize forcing a float representation with the width of src and the given + // number of features. + void ResizeFloat(const NetworkIO& src, int num_features, + NetworkScratch* scratch) { + if (scratch_space_ == NULL) { + int_mode_ = false; + scratch_space_ = scratch; + network_io_ = scratch_space_->float_stack_.Borrow(); + } + network_io_->ResizeFloat(src, num_features); + } + + // Returns a ref to a NetworkIO that enables *this to be treated as if + // it were just a NetworkIO*. + NetworkIO& operator*() { + return *network_io_; + } + NetworkIO* operator->() { + return network_io_; + } + operator NetworkIO*() { + return network_io_; + } + + private: + // True if this is from the always-float stack, otherwise the default stack. + bool int_mode_; + // The NetworkIO that we have borrowed from the scratch_space_. + NetworkIO* network_io_; + // The source scratch_space_. Borrowed pointer, used to free the + // NetworkIO. Don't delete! + NetworkScratch* scratch_space_; + }; // class IO. + + // Class that acts like a fixed array of float, yet actually uses space + // from a GenericVector<double> in the source NetworkScratch, and knows how + // to unstack the borrowed vector on destruction. + class FloatVec { + public: + // The array will have size elements in it, uninitialized. + FloatVec(int size, NetworkScratch* scratch) + : vec_(NULL), scratch_space_(scratch) { + Init(size, scratch); + } + // Default constructor is for arrays. Use Init to setup.
+ FloatVec() : vec_(NULL), data_(NULL), scratch_space_(NULL) {} + ~FloatVec() { + if (scratch_space_ != NULL) scratch_space_->vec_stack_.Return(vec_); + } + + void Init(int size, NetworkScratch* scratch) { + if (scratch_space_ != NULL && vec_ != NULL) + scratch_space_->vec_stack_.Return(vec_); + scratch_space_ = scratch; + vec_ = scratch_space_->vec_stack_.Borrow(); + vec_->resize_no_init(size); + data_ = &(*vec_)[0]; + } + + // Use the cast operator instead of operator[] so the FloatVec can be used + // as a double* argument to a function call. + operator double*() const { return data_; } + double* get() { return data_; } + + private: + // Vector borrowed from the scratch space. Use Return to free it. + GenericVector<double>* vec_; + // Short-cut pointer to the underlying array. + double* data_; + // The source scratch_space_. Borrowed pointer, used to free the + // vector. Don't delete! + NetworkScratch* scratch_space_; + }; // class FloatVec + + // Class that acts like a 2-D array of double, yet actually uses space + // from the source NetworkScratch, and knows how to unstack the borrowed + // array on destruction. + class GradientStore { + public: + // Default constructor is for arrays. Use Init to setup. + GradientStore() : array_(NULL), scratch_space_(NULL) {} + ~GradientStore() { + if (scratch_space_ != NULL) scratch_space_->array_stack_.Return(array_); + } + + void Init(int size1, int size2, NetworkScratch* scratch) { + if (scratch_space_ != NULL && array_ != NULL) + scratch_space_->array_stack_.Return(array_); + scratch_space_ = scratch; + array_ = scratch_space_->array_stack_.Borrow(); + array_->Resize(size1, size2, 0.0); + } + + // Accessors to get to the underlying TransposedArray. + TransposedArray* get() const { return array_; } + const TransposedArray& operator*() const { return *array_; } + + private: + // Array borrowed from the scratch space. Use Return to free it. + TransposedArray* array_; + // The source scratch_space_. Borrowed pointer, used to free the + // vector. Don't delete! + NetworkScratch* scratch_space_; + }; // class GradientStore + + // Class that does the work of holding a stack of objects, a stack pointer + // and a vector of in-use flags, so objects can be returned out of order. + // It is safe to attempt to Borrow/Return in multiple threads. + template <typename T> class Stack { + public: + Stack() : stack_top_(0) { + } + + // Lends out the next free item, creating one if none available, sets + // the used flags and increments the stack top. + T* Borrow() { + SVAutoLock lock(&mutex_); + if (stack_top_ == stack_.size()) { + stack_.push_back(new T); + flags_.push_back(false); + } + flags_[stack_top_] = true; + return stack_[stack_top_++]; + } + // Takes back the given item, and marks it free. Item does not have to be + // the most recently lent out, but free slots don't get re-used until the + // blocking item is returned. The assumption is that there will only be + // small, temporary variations from true stack use. (Determined by the order + // of destructors within a local scope.) + void Return(T* item) { + SVAutoLock lock(&mutex_); + // Linear search will do. + int index = stack_top_ - 1; + while (index >= 0 && stack_[index] != item) --index; + if (index >= 0) flags_[index] = false; + while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_; + } + + private: + PointerVector<T> stack_; + GenericVector<bool> flags_; + int stack_top_; + SVMutex mutex_; + }; // class Stack. + + private: + // If true, the network weights are inT8, if false, float.
+ bool int_mode_; + // Stacks of NetworkIO and GenericVector<double>. Once allocated, they are not + // deleted until the NetworkScratch is deleted. + Stack<NetworkIO> int_stack_; + Stack<NetworkIO> float_stack_; + Stack<GenericVector<double> > vec_stack_; + Stack<TransposedArray> array_stack_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_NETWORKSCRATCH_H_ diff --git a/lstm/parallel.cpp b/lstm/parallel.cpp new file mode 100644 index 0000000000..adb9d84f15 --- /dev/null +++ b/lstm/parallel.cpp @@ -0,0 +1,180 @@ +///////////////////////////////////////////////////////////////////////// +// File: parallel.cpp +// Description: Runs networks in parallel on the same input. +// Author: Ray Smith +// Created: Thu May 02 08:06:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "parallel.h" + +#include + +#include "functions.h" // For conditional undef of _OPENMP. +#include "networkscratch.h" + +namespace tesseract { + +// ni_ and no_ will be set by AddToStack. +Parallel::Parallel(const STRING& name, NetworkType type) : Plumbing(name) { + type_ = type; +} + +Parallel::~Parallel() { +} + +// Returns the shape output from the network given an input shape (which may +// be partially unknown ie zero). +StaticShape Parallel::OutputShape(const StaticShape& input_shape) const { + StaticShape result = stack_[0]->OutputShape(input_shape); + int stack_size = stack_.size(); + for (int i = 1; i < stack_size; ++i) { + StaticShape shape = stack_[i]->OutputShape(input_shape); + result.set_depth(result.depth() + shape.depth()); + } + return result; +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void Parallel::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + bool parallel_debug = false; + // If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair, + // or a 2-d LSTM quad, do debug locally, and don't pass the flag on. + if (debug && type_ != NT_PARALLEL) { + parallel_debug = true; + debug = false; + } + int stack_size = stack_.size(); + if (type_ == NT_PAR_2D_LSTM) { + // Special case, run parallel in parallel. + GenericVector<NetworkScratch::IO> results; + results.init_to_size(stack_size, NetworkScratch::IO()); + for (int i = 0; i < stack_size; ++i) { + results[i].Resize(input, stack_[i]->NumOutputs(), scratch); + } +#ifdef _OPENMP +#pragma omp parallel for num_threads(stack_size) +#endif + for (int i = 0; i < stack_size; ++i) { + stack_[i]->Forward(debug, input, NULL, scratch, results[i]); + } + // Now pack all the results (serially) into the output. + int out_offset = 0; + output->Resize(*results[0], NumOutputs()); + for (int i = 0; i < stack_size; ++i) { + out_offset = output->CopyPacking(*results[i], out_offset); + } + } else { + // Revolving intermediate result. + NetworkScratch::IO result(input, scratch); + // Source for divided replicated.
+ NetworkScratch::IO source_part; + TransposedArray* src_transpose = NULL; + if (training() && type_ == NT_REPLICATED) { + // Make a transposed copy of the input. + input.Transpose(&transposed_input_); + src_transpose = &transposed_input_; + } + // Run each network, putting the outputs into result. + int input_offset = 0; + int out_offset = 0; + for (int i = 0; i < stack_size; ++i) { + stack_[i]->Forward(debug, input, src_transpose, scratch, result); + // All networks must have the same output width + if (i == 0) { + output->Resize(*result, NumOutputs()); + } else { + ASSERT_HOST(result->Width() == output->Width()); + } + out_offset = output->CopyPacking(*result, out_offset); + } + } + if (parallel_debug) { + DisplayForward(*output); + } +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool Parallel::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + // If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair, + // or a 2-d LSTM quad, do debug locally, and don't pass the flag on. + if (debug && type_ != NT_PARALLEL) { + DisplayBackward(fwd_deltas); + debug = false; + } + int stack_size = stack_.size(); + if (type_ == NT_PAR_2D_LSTM) { + // Special case, run parallel in parallel. + GenericVector<NetworkScratch::IO> in_deltas, out_deltas; + in_deltas.init_to_size(stack_size, NetworkScratch::IO()); + out_deltas.init_to_size(stack_size, NetworkScratch::IO()); + // Split the forward deltas for each stack element. + int feature_offset = 0; + int out_offset = 0; + for (int i = 0; i < stack_.size(); ++i) { + int num_features = stack_[i]->NumOutputs(); + in_deltas[i].Resize(fwd_deltas, num_features, scratch); + out_deltas[i].Resize(fwd_deltas, stack_[i]->NumInputs(), scratch); + in_deltas[i]->CopyUnpacking(fwd_deltas, feature_offset, num_features); + feature_offset += num_features; + } +#ifdef _OPENMP +#pragma omp parallel for num_threads(stack_size) +#endif + for (int i = 0; i < stack_size; ++i) { + stack_[i]->Backward(debug, *in_deltas[i], scratch, + i == 0 ? back_deltas : out_deltas[i]); + } + if (needs_to_backprop_) { + for (int i = 1; i < stack_size; ++i) { + back_deltas->AddAllToFloat(*out_deltas[i]); + } + } + } else { + // Revolving partial deltas. + NetworkScratch::IO in_deltas(fwd_deltas, scratch); + // The sum of deltas from different sources, which will eventually go into + // back_deltas. + NetworkScratch::IO out_deltas; + int feature_offset = 0; + int out_offset = 0; + for (int i = 0; i < stack_.size(); ++i) { + int num_features = stack_[i]->NumOutputs(); + in_deltas->CopyUnpacking(fwd_deltas, feature_offset, num_features); + feature_offset += num_features; + if (stack_[i]->Backward(debug, *in_deltas, scratch, back_deltas)) { + if (i == 0) { + out_deltas.ResizeFloat(*back_deltas, back_deltas->NumFeatures(), + scratch); + out_deltas->CopyAll(*back_deltas); + } else if (back_deltas->NumFeatures() == out_deltas->NumFeatures()) { + // Widths are allowed to be different going back, as we may have + // input nets, so only accumulate the deltas if the widths are the + // same. + out_deltas->AddAllToFloat(*back_deltas); + } + } + } + if (needs_to_backprop_) back_deltas->CopyAll(*out_deltas); + } + if (needs_to_backprop_) back_deltas->ScaleFloatBy(1.0f / stack_size); + return needs_to_backprop_; +} + +} // namespace tesseract.
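Forward and Backward above define the packing contract shared by every Plumbing container in this patch: going forward, each sub-network's output is interlaced into a single NetworkIO in feature space via CopyPacking; going backward, CopyUnpacking splits the incoming deltas apart again and the partial results are summed. A minimal usage sketch, assuming left and right are two already-constructed Network instances with equal input sizes and input is a NetworkIO filled elsewhere (e.g. via NetworkIO::FromPix); names are illustrative, error handling omitted:

    // Sketch: build a 1-d bidirectional pair and run one forward pass.
    Parallel* pair = new Parallel("LbxPair", NT_PAR_RL_LSTM);
    pair->AddToStack(left);       // ni_ is taken from the first sub-network.
    pair->AddToStack(right);      // no_ grows to the sum of sub-network outputs.
    NetworkScratch scratch;
    scratch.set_int_mode(false);  // Keep the float representation throughout.
    NetworkIO output;
    pair->Forward(false, input, NULL, &scratch, &output);
    // output now holds both sub-networks' activations interlaced in depth.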
diff --git a/lstm/parallel.h b/lstm/parallel.h new file mode 100644 index 0000000000..ad290a7ec1 --- /dev/null +++ b/lstm/parallel.h @@ -0,0 +1,87 @@ +/////////////////////////////////////////////////////////////////////// +// File: parallel.h +// Description: Runs networks in parallel on the same input. +// Author: Ray Smith +// Created: Thu May 02 08:02:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_PARALLEL_H_ +#define TESSERACT_LSTM_PARALLEL_H_ + +#include "plumbing.h" + +namespace tesseract { + +// Runs multiple networks in parallel, interlacing their outputs. +class Parallel : public Plumbing { + public: + // ni_ and no_ will be set by AddToStack. + Parallel(const STRING& name, NetworkType type); + virtual ~Parallel(); + + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const; + + virtual STRING spec() const { + STRING spec; + if (type_ == NT_PAR_2D_LSTM) { + // We have 4 LSTMs operating in parallel here, so the size of each is + // the number of outputs/4. + spec.add_str_int("L2xy", no_ / 4); + } else if (type_ == NT_PAR_RL_LSTM) { + // We have 2 LSTMs operating in parallel here, so the size of each is + // the number of outputs/2. + if (stack_[0]->type() == NT_LSTM_SUMMARY) + spec.add_str_int("Lbxs", no_ / 2); + else + spec.add_str_int("Lbx", no_ / 2); + } else { + if (type_ == NT_REPLICATED) { + spec.add_str_int("R", stack_.size()); + spec += "("; + spec += stack_[0]->spec(); + } else { + spec = "("; + for (int i = 0; i < stack_.size(); ++i) spec += stack_[i]->spec(); + } + spec += ")"; + } + return spec; + } + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + + private: + // If *this is a NT_REPLICATED, then it feeds a replicated network with + // identical inputs, and it would be extremely wasteful for them to each + // calculate and store the same transpose of the inputs, so Parallel does it + // and passes a pointer to the replicated network, allowing it to use the + // transpose on the next call to Backward. + TransposedArray transposed_input_; +}; + +} // namespace tesseract. 
+ +#endif // TESSERACT_LSTM_PARALLEL_H_ diff --git a/lstm/plumbing.cpp b/lstm/plumbing.cpp new file mode 100644 index 0000000000..01abdb91f9 --- /dev/null +++ b/lstm/plumbing.cpp @@ -0,0 +1,233 @@ +/////////////////////////////////////////////////////////////////////// +// File: plumbing.cpp +// Description: Base class for networks that organize other networks +// eg series or parallel. +// Author: Ray Smith +// Created: Mon May 12 08:17:34 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#include "plumbing.h" + +namespace tesseract { + +// ni_ and no_ will be set by AddToStack. +Plumbing::Plumbing(const STRING& name) + : Network(NT_PARALLEL, name, 0, 0) { +} + +Plumbing::~Plumbing() { +} + +// Suspends/Enables training by setting the training_ flag. Serialize and +// DeSerialize only operate on the run-time data if state is false. +void Plumbing::SetEnableTraining(bool state) { + Network::SetEnableTraining(state); + for (int i = 0; i < stack_.size(); ++i) + stack_[i]->SetEnableTraining(state); +} + +// Sets flags that control the action of the network. See NetworkFlags enum +// for bit values. +void Plumbing::SetNetworkFlags(uinT32 flags) { + Network::SetNetworkFlags(flags); + for (int i = 0; i < stack_.size(); ++i) + stack_[i]->SetNetworkFlags(flags); +} + +// Sets up the network for training. Initializes weights using weights of +// scale `range` picked according to the random number generator `randomizer`. +// Note that randomizer is a borrowed pointer that should outlive the network +// and should not be deleted by any of the networks. +// Returns the number of weights initialized. +int Plumbing::InitWeights(float range, TRand* randomizer) { + num_weights_ = 0; + for (int i = 0; i < stack_.size(); ++i) + num_weights_ += stack_[i]->InitWeights(range, randomizer); + return num_weights_; +} + +// Converts a float network to an int network. +void Plumbing::ConvertToInt() { + for (int i = 0; i < stack_.size(); ++i) + stack_[i]->ConvertToInt(); +} + +// Provides a pointer to a TRand for any networks that care to use it. +// Note that randomizer is a borrowed pointer that should outlive the network +// and should not be deleted by any of the networks. +void Plumbing::SetRandomizer(TRand* randomizer) { + for (int i = 0; i < stack_.size(); ++i) + stack_[i]->SetRandomizer(randomizer); +} + +// Adds the given network to the stack. +void Plumbing::AddToStack(Network* network) { + if (stack_.empty()) { + ni_ = network->NumInputs(); + no_ = network->NumOutputs(); + } else if (type_ == NT_SERIES) { + // ni is input of first, no output of last, others match output to input. + ASSERT_HOST(no_ == network->NumInputs()); + no_ = network->NumOutputs(); + } else { + // All parallel types. Output is sum of outputs, inputs all match. 
+ ASSERT_HOST(ni_ == network->NumInputs()); + no_ += network->NumOutputs(); + } + stack_.push_back(network); +} + +// Sets needs_to_backprop_ to needs_backprop and calls on sub-network +// according to needs_backprop || any weights in this network. +bool Plumbing::SetupNeedsBackprop(bool needs_backprop) { + needs_to_backprop_ = needs_backprop; + bool retval = needs_backprop; + for (int i = 0; i < stack_.size(); ++i) { + if (stack_[i]->SetupNeedsBackprop(needs_backprop)) + retval = true; + } + return retval; +} + +// Returns an integer reduction factor that the network applies to the +// time sequence. Assumes that any 2-d is already eliminated. Used for +// scaling bounding boxes of truth data. +// WARNING: if GlobalMinimax is used to vary the scale, this will return +// the last used scale factor. Call it before any forward, and it will return +// the minimum scale factor of the paths through the GlobalMinimax. +int Plumbing::XScaleFactor() const { + return stack_[0]->XScaleFactor(); +} + +// Provides the (minimum) x scale factor to the network (of interest only to +// input units) so they can determine how to scale bounding boxes. +void Plumbing::CacheXScaleFactor(int factor) { + for (int i = 0; i < stack_.size(); ++i) { + stack_[i]->CacheXScaleFactor(factor); + } +} + +// Provides debug output on the weights. +void Plumbing::DebugWeights() { + for (int i = 0; i < stack_.size(); ++i) + stack_[i]->DebugWeights(); +} + +// Returns a set of strings representing the layer-ids of all layers below. +void Plumbing::EnumerateLayers(const STRING* prefix, + GenericVector<STRING>* layers) const { + for (int i = 0; i < stack_.size(); ++i) { + STRING layer_name; + if (prefix) layer_name = *prefix; + layer_name.add_str_int(":", i); + if (stack_[i]->IsPlumbingType()) { + Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[i]); + plumbing->EnumerateLayers(&layer_name, layers); + } else { + layers->push_back(layer_name); + } + } +} + +// Returns a pointer to the network layer corresponding to the given id. +Network* Plumbing::GetLayer(const char* id) const { + char* next_id; + int index = strtol(id, &next_id, 10); + if (index < 0 || index >= stack_.size()) return NULL; + if (stack_[index]->IsPlumbingType()) { + Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[index]); + ASSERT_HOST(*next_id == ':'); + return plumbing->GetLayer(next_id + 1); + } + return stack_[index]; +} + +// Returns a pointer to the learning rate for the given layer id. +float* Plumbing::LayerLearningRatePtr(const char* id) const { + char* next_id; + int index = strtol(id, &next_id, 10); + if (index < 0 || index >= stack_.size()) return NULL; + if (stack_[index]->IsPlumbingType()) { + Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[index]); + ASSERT_HOST(*next_id == ':'); + return plumbing->LayerLearningRatePtr(next_id + 1); + } + if (index < 0 || index >= learning_rates_.size()) return NULL; + return &learning_rates_[index]; +} + +// Writes to the given file. Returns false in case of error. +bool Plumbing::Serialize(TFile* fp) const { + if (!Network::Serialize(fp)) return false; + inT32 size = stack_.size(); + // Can't use PointerVector::Serialize here as we need a special DeSerialize. + if (fp->FWrite(&size, sizeof(size), 1) != 1) return false; + for (int i = 0; i < size; ++i) + if (!stack_[i]->Serialize(fp)) return false; + if ((network_flags_ & NF_LAYER_SPECIFIC_LR) && + !learning_rates_.Serialize(fp)) { + return false; + } + return true; +} + +// Reads from the given file. Returns false in case of error.
+// If swap is true, assumes a big/little-endian swap is needed. +bool Plumbing::DeSerialize(bool swap, TFile* fp) { + stack_.truncate(0); + no_ = 0; // We will be modifying this as we AddToStack. + inT32 size; + if (fp->FRead(&size, sizeof(size), 1) != 1) return false; + for (int i = 0; i < size; ++i) { + Network* network = CreateFromFile(swap, fp); + if (network == NULL) return false; + AddToStack(network); + } + if ((network_flags_ & NF_LAYER_SPECIFIC_LR) && + !learning_rates_.DeSerialize(swap, fp)) { + return false; + } + return true; +} + +// Updates the weights using the given learning rate and momentum. +// num_samples is the quotient to be used in the adagrad computation iff +// use_ada_grad_ is true. +void Plumbing::Update(float learning_rate, float momentum, int num_samples) { + for (int i = 0; i < stack_.size(); ++i) { + if (network_flags_ & NF_LAYER_SPECIFIC_LR) { + if (i < learning_rates_.size()) + learning_rate = learning_rates_[i]; + else + learning_rates_.push_back(learning_rate); + } + if (stack_[i]->training()) + stack_[i]->Update(learning_rate, momentum, num_samples); + } +} + +// Sums the products of weight updates in *this and other, splitting into +// positive (same direction) in *same and negative (different direction) in +// *changed. +void Plumbing::CountAlternators(const Network& other, double* same, + double* changed) const { + ASSERT_HOST(other.type() == type_); + const Plumbing* plumbing = reinterpret_cast<const Plumbing*>(&other); + ASSERT_HOST(plumbing->stack_.size() == stack_.size()); + for (int i = 0; i < stack_.size(); ++i) + stack_[i]->CountAlternators(*plumbing->stack_[i], same, changed); +} + +} // namespace tesseract. + diff --git a/lstm/plumbing.h b/lstm/plumbing.h new file mode 100644 index 0000000000..1a2185c333 --- /dev/null +++ b/lstm/plumbing.h @@ -0,0 +1,143 @@ +/////////////////////////////////////////////////////////////////////// +// File: plumbing.h +// Description: Base class for networks that organize other networks +// eg series or parallel. +// Author: Ray Smith +// Created: Mon May 12 08:11:36 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_PLUMBING_H_ +#define TESSERACT_LSTM_PLUMBING_H_ + +#include "genericvector.h" +#include "matrix.h" +#include "network.h" + +namespace tesseract { + +// Holds a collection of other networks and forwards calls to each of them. +class Plumbing : public Network { + public: + // ni_ and no_ will be set by AddToStack. + explicit Plumbing(const STRING& name); + virtual ~Plumbing(); + + // Returns the required shape input to the network. + virtual StaticShape InputShape() const { return stack_[0]->InputShape(); } + virtual STRING spec() const { + return "Sub-classes of Plumbing must implement spec()!"; + } + + // Returns true if the given type is derived from Plumbing, and thus contains + // multiple sub-networks that can have their own learning rate.
+ virtual bool IsPlumbingType() const { return true; } + + // Suspends/Enables training by setting the training_ flag. Serialize and + // DeSerialize only operate on the run-time data if state is false. + virtual void SetEnableTraining(bool state); + + // Sets flags that control the action of the network. See NetworkFlags enum + // for bit values. + virtual void SetNetworkFlags(uinT32 flags); + + // Sets up the network for training. Initializes weights using weights of + // scale `range` picked according to the random number generator `randomizer`. + // Note that randomizer is a borrowed pointer that should outlive the network + // and should not be deleted by any of the networks. + // Returns the number of weights initialized. + virtual int InitWeights(float range, TRand* randomizer); + + // Converts a float network to an int network. + virtual void ConvertToInt(); + + // Provides a pointer to a TRand for any networks that care to use it. + // Note that randomizer is a borrowed pointer that should outlive the network + // and should not be deleted by any of the networks. + virtual void SetRandomizer(TRand* randomizer); + + // Adds the given network to the stack. + virtual void AddToStack(Network* network); + + // Sets needs_to_backprop_ to needs_backprop and returns true if + // needs_backprop || any weights in this network so the next layer forward + // can be told to produce backprop for this layer if needed. + virtual bool SetupNeedsBackprop(bool needs_backprop); + + // Returns an integer reduction factor that the network applies to the + // time sequence. Assumes that any 2-d is already eliminated. Used for + // scaling bounding boxes of truth data. + // WARNING: if GlobalMinimax is used to vary the scale, this will return + // the last used scale factor. Call it before any forward, and it will return + // the minimum scale factor of the paths through the GlobalMinimax. + virtual int XScaleFactor() const; + + // Provides the (minimum) x scale factor to the network (of interest only to + // input units) so they can determine how to scale bounding boxes. + virtual void CacheXScaleFactor(int factor); + + // Provides debug output on the weights. + virtual void DebugWeights(); + + // Returns the current stack. + const PointerVector<Network>& stack() const { + return stack_; + } + // Returns a set of strings representing the layer-ids of all layers below. + void EnumerateLayers(const STRING* prefix, + GenericVector<STRING>* layers) const; + // Returns a pointer to the network layer corresponding to the given id. + Network* GetLayer(const char* id) const; + // Returns the learning rate for a specific layer of the stack. + float LayerLearningRate(const char* id) const { + const float* lr_ptr = LayerLearningRatePtr(id); + ASSERT_HOST(lr_ptr != NULL); + return *lr_ptr; + } + // Scales the learning rate for a specific layer of the stack. + void ScaleLayerLearningRate(const char* id, double factor) { + float* lr_ptr = LayerLearningRatePtr(id); + ASSERT_HOST(lr_ptr != NULL); + *lr_ptr *= factor; + } + // Returns a pointer to the learning rate for the given layer id. + float* LayerLearningRatePtr(const char* id) const; + + // Writes to the given file. Returns false in case of error. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Updates the weights using the given learning rate and momentum.
+ // num_samples is the quotient to be used in the adagrad computation iff + // use_ada_grad_ is true. + virtual void Update(float learning_rate, float momentum, int num_samples); + // Sums the products of weight updates in *this and other, splitting into + // positive (same direction) in *same and negative (different direction) in + // *changed. + virtual void CountAlternators(const Network& other, double* same, + double* changed) const; + + protected: + // The networks. + PointerVector<Network> stack_; + // Layer-specific learning rate iff network_flags_ & NF_LAYER_SPECIFIC_LR. + // One element for each element of stack_. + GenericVector<float> learning_rates_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_PLUMBING_H_ + diff --git a/lstm/recodebeam.cpp b/lstm/recodebeam.cpp new file mode 100644 index 0000000000..1b28f16fff --- /dev/null +++ b/lstm/recodebeam.cpp @@ -0,0 +1,759 @@ +/////////////////////////////////////////////////////////////////////// +// File: recodebeam.cpp +// Description: Beam search to decode from the re-encoded CJK as a sequence of +// smaller numbers in place of a single large code. +// Author: Ray Smith +// Created: Fri Mar 13 09:39:01 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "recodebeam.h" +#include "networkio.h" +#include "pageres.h" +#include "unicharcompress.h" + +namespace tesseract { + +// Clipping value for certainty inside Tesseract. Reflects the minimum value +// of certainty that will be returned by ExtractBestPathAsUnicharIds. +// Supposedly on a uniform scale that can be compared across languages and +// engines. +const float RecodeBeamSearch::kMinCertainty = -20.0f; + +// The beam width at each code position. +const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = { + 5, 10, 16, 16, 16, 16, 16, 16, 16, 16, +}; + +// Borrows the pointer, which is expected to survive until *this is deleted. +RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress& recoder, + int null_char, bool simple_text, Dict* dict) + : recoder_(recoder), + dict_(dict), + space_delimited_(true), + is_simple_text_(simple_text), + null_char_(null_char) { + if (dict_ != NULL && !dict_->IsSpaceDelimitedLang()) space_delimited_ = false; +} + +// Decodes the set of network outputs, storing the lattice internally.
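+// A typical calling sequence, as a sketch (variable names are illustrative,
+// not part of this patch): construct one searcher per recognizer, then per
+// text line:
+//   RecodeBeamSearch search(recoder, null_char, false, &dict);
+//   search.Decode(outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, NULL);
+//   GenericVector<int> labels, xcoords;
+//   search.ExtractBestPathAsLabels(&labels, &xcoords);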
+void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio, + double cert_offset, double worst_dict_cert, + const UNICHARSET* charset) { + beam_size_ = 0; + int width = output.Width(); + for (int t = 0; t < width; ++t) { + ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]); + DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert, + charset); + } +} +void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output, + double dict_ratio, double cert_offset, + double worst_dict_cert, + const UNICHARSET* charset) { + beam_size_ = 0; + int width = output.dim1(); + for (int t = 0; t < width; ++t) { + ComputeTopN(output[t], output.dim2(), kBeamWidths[0]); + DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset); + } +} + +// Returns the best path as labels/scores/xcoords similar to simple CTC. +void RecodeBeamSearch::ExtractBestPathAsLabels( + GenericVector<int>* labels, GenericVector<int>* xcoords) const { + labels->truncate(0); + xcoords->truncate(0); + GenericVector<const RecodeNode*> best_nodes; + ExtractBestPaths(&best_nodes, NULL); + // Now just run CTC on the best nodes. + int t = 0; + int width = best_nodes.size(); + while (t < width) { + int label = best_nodes[t]->code; + if (label != null_char_) { + labels->push_back(label); + xcoords->push_back(t); + } + while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) { + } + } + xcoords->push_back(width); +} + +// Returns the best path as unichar-ids/certs/ratings/xcoords skipping +// duplicates, nulls and intermediate parts. +void RecodeBeamSearch::ExtractBestPathAsUnicharIds( + bool debug, const UNICHARSET* unicharset, GenericVector<int>* unichar_ids, + GenericVector<float>* certs, GenericVector<float>* ratings, + GenericVector<int>* xcoords) const { + GenericVector<const RecodeNode*> best_nodes; + ExtractBestPaths(&best_nodes, NULL); + ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords); + if (debug) { + DebugPath(unicharset, best_nodes); + DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings, + *xcoords); + } +} + +// Returns the best path as a set of WERD_RES. +void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, + float scale_factor, bool debug, + const UNICHARSET* unicharset, + PointerVector<WERD_RES>* words) { + words->truncate(0); + GenericVector<int> unichar_ids; + GenericVector<float> certs; + GenericVector<float> ratings; + GenericVector<int> xcoords; + GenericVector<const RecodeNode*> best_nodes; + GenericVector<const RecodeNode*> second_nodes; + ExtractBestPaths(&best_nodes, &second_nodes); + if (debug) { + DebugPath(unicharset, best_nodes); + ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings, + &xcoords); + tprintf("\nSecond choice path:\n"); + DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings, + xcoords); + } + ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords); + int num_ids = unichar_ids.size(); + if (debug) { + DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings, + xcoords); + } + // Convert labels to unichar-ids. + int word_end = 0; + float prev_space_cert = 0.0f; + for (int word_start = 0; word_start < num_ids; word_start = word_end) { + for (word_end = word_start + 1; word_end < num_ids; ++word_end) { + // A word is terminated when a space character or start_of_word flag is + // hit. We also want to force a separate word for every non + // space-delimited character when not in a dictionary context.
+ if (unichar_ids[word_end] == UNICHAR_SPACE) break; + int index = xcoords[word_end]; + if (best_nodes[index]->start_of_word) break; + if (best_nodes[index]->permuter == TOP_CHOICE_PERM && + (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) || + !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) + break; + } + float space_cert = 0.0f; + if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) + space_cert = certs[word_end]; + bool leading_space = + word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE; + // Create a WERD_RES for the output word. + WERD_RES* word_res = InitializeWord( + leading_space, line_box, word_start, word_end, + MIN(space_cert, prev_space_cert), unicharset, xcoords, scale_factor); + for (int i = word_start; i < word_end; ++i) { + BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST; + BLOB_CHOICE_IT bc_it(choices); + BLOB_CHOICE* choice = new BLOB_CHOICE( + unichar_ids[i], ratings[i], certs[i], -1, 1.0f, + static_cast<float>(MAX_INT16), 0.0f, BCC_STATIC_CLASSIFIER); + int col = i - word_start; + choice->set_matrix_cell(col, col); + bc_it.add_after_then_move(choice); + word_res->ratings->put(col, col, choices); + } + int index = xcoords[word_end - 1]; + word_res->FakeWordFromRatings(best_nodes[index]->permuter); + words->push_back(word_res); + prev_space_cert = space_cert; + if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) + ++word_end; + } +} + +// Generates debug output of the content of the beams after a Decode. +void RecodeBeamSearch::DebugBeams(const UNICHARSET& unicharset) const { + for (int p = 0; p < beam_size_; ++p) { + // Print all the best scoring nodes for each unichar found. + tprintf("Position %d: Nondict beam\n", p); + DebugBeamPos(unicharset, beam_[p]->beams_[0]); + tprintf("Position %d: Dict beam\n", p); + DebugBeamPos(unicharset, beam_[p]->dawg_beams_[0]); + } +} + +// Generates debug output of the content of a single beam position. +void RecodeBeamSearch::DebugBeamPos(const UNICHARSET& unicharset, + const RecodeHeap& heap) const { + GenericVector<const RecodeNode*> unichar_bests; + unichar_bests.init_to_size(unicharset.size(), NULL); + const RecodeNode* null_best = NULL; + int heap_size = heap.size(); + for (int i = 0; i < heap_size; ++i) { + const RecodeNode* node = &heap.get(i).data; + if (node->unichar_id == INVALID_UNICHAR_ID) { + if (null_best == NULL || null_best->score < node->score) null_best = node; + } else { + if (unichar_bests[node->unichar_id] == NULL || + unichar_bests[node->unichar_id]->score < node->score) { + unichar_bests[node->unichar_id] = node; + } + } + } + for (int u = 0; u < unichar_bests.size(); ++u) { + if (unichar_bests[u] != NULL) { + const RecodeNode& node = *unichar_bests[u]; + tprintf("label=%d, uid=%d=%s score=%g, c=%g, s=%d, e=%d, perm=%d\n", + node.code, node.unichar_id, + unicharset.debug_str(node.unichar_id).string(), node.score, + node.certainty, node.start_of_word, node.end_of_word, + node.permuter); + } + } + if (null_best != NULL) { + tprintf("null_char score=%g, c=%g, s=%d, e=%d, perm=%d\n", null_best->score, + null_best->certainty, null_best->start_of_word, + null_best->end_of_word, null_best->permuter); + } +} + +// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping +// duplicates, nulls and intermediate parts.
+/* static */ +void RecodeBeamSearch::ExtractPathAsUnicharIds( + const GenericVector<const RecodeNode*>& best_nodes, + GenericVector<int>* unichar_ids, GenericVector<float>* certs, + GenericVector<float>* ratings, GenericVector<int>* xcoords) { + unichar_ids->truncate(0); + certs->truncate(0); + ratings->truncate(0); + xcoords->truncate(0); + // Backtrack extracting only valid, non-duplicate unichar-ids. + int t = 0; + int width = best_nodes.size(); + while (t < width) { + double certainty = 0.0; + double rating = 0.0; + while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) { + double cert = best_nodes[t++]->certainty; + if (cert < certainty) certainty = cert; + rating -= cert; + } + if (t < width) { + int unichar_id = best_nodes[t]->unichar_id; + unichar_ids->push_back(unichar_id); + xcoords->push_back(t); + do { + double cert = best_nodes[t++]->certainty; + // Special-case NO-PERM space to forget the certainty of the previous + // nulls. See long comment in ContinueContext. + if (cert < certainty || (unichar_id == UNICHAR_SPACE && + best_nodes[t - 1]->permuter == NO_PERM)) { + certainty = cert; + } + rating -= cert; + } while (t < width && best_nodes[t]->duplicate); + certs->push_back(certainty); + ratings->push_back(rating); + } else if (!certs->empty()) { + if (certainty < certs->back()) certs->back() = certainty; + ratings->back() += rating; + } + } + xcoords->push_back(width); +} + +// Sets up a word with the ratings matrix and fake blobs with boxes in the +// right places. +WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space, + const TBOX& line_box, int word_start, + int word_end, float space_certainty, + const UNICHARSET* unicharset, + const GenericVector<int>& xcoords, + float scale_factor) { + // Make a fake blob for each non-zero label. + C_BLOB_LIST blobs; + C_BLOB_IT b_it(&blobs); + for (int i = word_start; i < word_end; ++i) { + int min_half_width = xcoords[i + 1] - xcoords[i]; + if (i > 0 && xcoords[i] - xcoords[i - 1] < min_half_width) + min_half_width = xcoords[i] - xcoords[i - 1]; + if (min_half_width < 1) min_half_width = 1; + // Make a fake blob. + TBOX box(xcoords[i] - min_half_width, 0, xcoords[i] + min_half_width, + line_box.height()); + box.scale(scale_factor); + box.move(ICOORD(line_box.left(), line_box.bottom())); + box.set_top(line_box.top()); + b_it.add_after_then_move(C_BLOB::FakeBlob(box)); + } + // Make a fake word from the blobs. + WERD* word = new WERD(&blobs, leading_space, NULL); + // Make a WERD_RES from the word. + WERD_RES* word_res = new WERD_RES(word); + word_res->uch_set = unicharset; + word_res->combination = true; // Give it ownership of the word. + word_res->space_certainty = space_certainty; + word_res->ratings = new MATRIX(word_end - word_start, 1); + return word_res; +} + +// Fills top_n_flags_ with bools that are true iff the corresponding output +// is one of the top_n. +void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs, + int top_n) { + top_n_flags_.init_to_size(num_outputs, false); + top_heap_.clear(); + for (int i = 0; i < num_outputs; ++i) { + if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) { + TopPair entry(outputs[i], i); + top_heap_.Push(&entry); + if (top_heap_.size() > top_n) top_heap_.Pop(&entry); + } + } + while (!top_heap_.empty()) { + TopPair entry; + top_heap_.Pop(&entry); + top_n_flags_[entry.data] = true; + } +} + +// Adds the computation for the current time-step to the beam. Call at each +// time-step in sequence from left to right. outputs is the activation vector +// for the current timestep.
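+// For example, with kBeamWidths[0] == 5, only codes among the 5 highest
+// activations at step t (per top_n_flags_ from ComputeTopN above) can extend
+// the beam on the first pass; the remaining codes are tried only if that
+// pass leaves the beam empty. (Illustrative summary of the flag loop below.)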
+void RecodeBeamSearch::DecodeStep(const float* outputs, int t, + double dict_ratio, double cert_offset, + double worst_dict_cert, + const UNICHARSET* charset) { + if (t == beam_.size()) beam_.push_back(new RecodeBeam); + RecodeBeam* step = beam_[t]; + beam_size_ = t + 1; + step->Clear(); + if (t == 0) { + // The first step can only use singles and initials. + ContinueContext(NULL, 0, outputs, false, true, dict_ratio, cert_offset, + worst_dict_cert, step); + if (dict_ != NULL) + ContinueContext(NULL, 0, outputs, true, true, dict_ratio, cert_offset, + worst_dict_cert, step); + } else { + RecodeBeam* prev = beam_[t - 1]; + if (charset != NULL) { + for (int i = prev->dawg_beams_[0].size() - 1; i >= 0; --i) { + GenericVector path; + ExtractPath(&prev->dawg_beams_[0].get(i).data, &path); + tprintf("Step %d: Dawg beam %d:\n", t, i); + DebugPath(charset, path); + } + } + int total_beam = 0; + // Try true and then false only if the beam is empty. This enables extending + // the context using only the top-n results first, which may have an empty + // intersection with the valid codes, so we fall back to the rest if the + // beam is empty. + for (int flag = 1; flag >= 0 && total_beam == 0; --flag) { + for (int length = 0; length <= RecodedCharID::kMaxCodeLen; ++length) { + // Working backwards through the heaps doesn't guarantee that we see the + // best first, but it comes before a lot of the worst, so it is slightly + // more efficient than going forwards. + for (int i = prev->dawg_beams_[length].size() - 1; i >= 0; --i) { + ContinueContext(&prev->dawg_beams_[length].get(i).data, length, + outputs, true, flag, dict_ratio, cert_offset, + worst_dict_cert, step); + } + for (int i = prev->beams_[length].size() - 1; i >= 0; --i) { + ContinueContext(&prev->beams_[length].get(i).data, length, outputs, + false, flag, dict_ratio, cert_offset, worst_dict_cert, + step); + } + } + for (int length = 0; length <= RecodedCharID::kMaxCodeLen; ++length) { + total_beam += step->beams_[length].size(); + total_beam += step->dawg_beams_[length].size(); + } + } + // Special case for the best initial dawg. Push it on the heap if good + // enough, but there is only one, so it doesn't blow up the beam. + RecodeHeap* dawg_heap = &step->dawg_beams_[0]; + if (step->best_initial_dawg_.code >= 0 && + (dawg_heap->size() < kBeamWidths[0] || + step->best_initial_dawg_.score > dawg_heap->PeekTop().data.score)) { + RecodePair entry(step->best_initial_dawg_.score, + step->best_initial_dawg_); + dawg_heap->Push(&entry); + if (dawg_heap->size() > kBeamWidths[0]) dawg_heap->Pop(&entry); + } + } +} + +// Adds to the appropriate beams the legal (according to recoder) +// continuations of context prev, which is of the given length, using the +// given network outputs to provide scores to the choices. Uses only those +// choices for which top_n_flags[index] == top_n_flag. 
+void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int length, + const float* outputs, bool use_dawgs, + bool top_n_flag, double dict_ratio, + double cert_offset, + double worst_dict_cert, + RecodeBeam* step) { + RecodedCharID prefix; + RecodedCharID full_code; + const RecodeNode* previous = prev; + for (int p = length - 1; p >= 0; --p, previous = previous->prev) { + while (previous != NULL && + (previous->duplicate || previous->code == null_char_)) { + previous = previous->prev; + } + prefix.Set(p, previous->code); + full_code.Set(p, previous->code); + } + if (prev != NULL && !is_simple_text_) { + float cert = NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset; + if ((cert >= kMinCertainty || prev->code == null_char_) && + top_n_flags_[prev->code] == top_n_flag) { + if (use_dawgs) { + if (cert > worst_dict_cert) { + PushDupIfBetter(kBeamWidths[length], cert, prev, + &step->dawg_beams_[length]); + } + } else { + PushDupIfBetter(kBeamWidths[length], cert * dict_ratio, prev, + &step->beams_[length]); + } + } + if (prev->code != null_char_ && length > 0 && + top_n_flags_[null_char_] == top_n_flag) { + // Allow nulls within multi code sequences, as the nulls within are not + // explicitly included in the code sequence. + cert = NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset; + if (cert >= kMinCertainty && (!use_dawgs || cert > worst_dict_cert)) { + if (use_dawgs) { + PushNoDawgIfBetter(kBeamWidths[length], null_char_, + INVALID_UNICHAR_ID, NO_PERM, cert, prev, + &step->dawg_beams_[length]); + } else { + PushNoDawgIfBetter(kBeamWidths[length], null_char_, + INVALID_UNICHAR_ID, TOP_CHOICE_PERM, + cert * dict_ratio, prev, &step->beams_[length]); + } + } + } + } + const GenericVector* final_codes = recoder_.GetFinalCodes(prefix); + if (final_codes != NULL) { + for (int i = 0; i < final_codes->size(); ++i) { + int code = (*final_codes)[i]; + if (top_n_flags_[code] != top_n_flag) continue; + if (prev != NULL && prev->code == code && !is_simple_text_) continue; + float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset; + if (cert < kMinCertainty && code != null_char_) continue; + full_code.Set(length, code); + int unichar_id = recoder_.DecodeUnichar(full_code); + // Map the null char to INVALID. + if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID; + if (use_dawgs) { + if (cert > worst_dict_cert) { + ContinueDawg(kBeamWidths[0], code, unichar_id, cert, prev, + &step->dawg_beams_[0], step); + } + } else { + PushNoDawgIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, + cert * dict_ratio, prev, &step->beams_[0]); + if (dict_ != NULL && + ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) || + !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) { + // Any top choice position that can start a new word, ie a space or + // any non-space-delimited character, should also be considered + // by the dawg search, so push initial dawg to the dawg heap. + float dawg_cert = cert; + PermuterType permuter = TOP_CHOICE_PERM; + // Since we use the space either side of a dictionary word in the + // certainty of the word, (to properly handle weak spaces) and the + // space is coming from a non-dict word, we need special conditions + // to avoid degrading the certainty of the dict word that follows. 
+ // With a space we don't multiply the certainty by dict_ratio, and we + // flag the space with NO_PERM to indicate that we should not use the + // predecessor nulls to generate the confidence for the space, as they + // have already been multiplied by dict_ratio, and we can't go back to + // insert more entries in any previous heaps. + if (unichar_id == UNICHAR_SPACE) + permuter = NO_PERM; + else + dawg_cert *= dict_ratio; + PushInitialDawgIfBetter(code, unichar_id, permuter, false, false, + dawg_cert, prev, &step->best_initial_dawg_); + } + } + } + } + const GenericVector<int>* next_codes = recoder_.GetNextCodes(prefix); + if (next_codes != NULL) { + for (int i = 0; i < next_codes->size(); ++i) { + int code = (*next_codes)[i]; + if (top_n_flags_[code] != top_n_flag) continue; + if (prev != NULL && prev->code == code && !is_simple_text_) continue; + float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset; + if (cert < kMinCertainty && code != null_char_) continue; + if (use_dawgs) { + if (cert > worst_dict_cert) { + ContinueDawg(kBeamWidths[length + 1], code, INVALID_UNICHAR_ID, cert, + prev, &step->dawg_beams_[length + 1], step); + } + } else { + PushNoDawgIfBetter(kBeamWidths[length + 1], code, INVALID_UNICHAR_ID, + TOP_CHOICE_PERM, cert * dict_ratio, prev, + &step->beams_[length + 1]); + } + } + } +} + +// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev, +// appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id +// is a valid continuation of whatever is in prev. +void RecodeBeamSearch::ContinueDawg(int max_size, int code, int unichar_id, + float cert, const RecodeNode* prev, + RecodeHeap* heap, RecodeBeam* step) { + if (unichar_id == INVALID_UNICHAR_ID) { + PushNoDawgIfBetter(max_size, code, unichar_id, NO_PERM, cert, prev, heap); + return; + } + // Avoid the dictionary probe if the score is a total loss. + float score = cert; + if (prev != NULL) score += prev->score; + if (heap->size() >= max_size && score <= heap->PeekTop().data.score) return; + const RecodeNode* uni_prev = prev; + // Prev may be a partial code, null_char, or duplicate, so scan back to the + // last valid unichar_id. + while (uni_prev != NULL && + (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) + uni_prev = uni_prev->prev; + if (unichar_id == UNICHAR_SPACE) { + if (uni_prev != NULL && uni_prev->end_of_word) { + // Space is good. Push the initial state to the dawg beam, and a regular + // space to the top choice beam. + PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false, + false, cert, prev, &step->best_initial_dawg_); + PushNoDawgIfBetter(max_size, code, unichar_id, uni_prev->permuter, cert, + prev, &step->beams_[0]); + } + return; + } else if (uni_prev != NULL && uni_prev->start_of_dawg && + uni_prev->unichar_id != UNICHAR_SPACE && + dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) && + dict_->getUnicharset().IsSpaceDelimited(unichar_id)) { + return; // Can't break words between space delimited chars. + } + DawgPositionVector initial_dawgs; + DawgPositionVector* updated_dawgs = new DawgPositionVector; + DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM); + bool word_start = false; + if (uni_prev == NULL) { + // Starting from beginning of line. + dict_->default_dawgs(&initial_dawgs, false); + word_start = true; + } else if (uni_prev->dawgs != NULL) { + // Continuing a previous dict word.
+ dawg_args.active_dawgs = uni_prev->dawgs; + word_start = uni_prev->start_of_dawg; + } else { + return; // Can't continue if not a dict word. + } + PermuterType permuter = static_cast<PermuterType>( + dict_->def_letter_is_okay(&dawg_args, unichar_id, false)); + if (permuter != NO_PERM) { + PushHeapIfBetter(max_size, code, unichar_id, permuter, false, word_start, + dawg_args.valid_end, false, cert, prev, + dawg_args.updated_dawgs, heap); + if (dawg_args.valid_end && !space_delimited_) { + // We can start another word right away, so push initial state as well, + // to the dawg beam, and the regular character to the top choice beam, + // since non-dict words can start here too. + PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true, + cert, prev, &step->best_initial_dawg_); + PushHeapIfBetter(max_size, code, unichar_id, permuter, false, word_start, + true, false, cert, prev, NULL, &step->beams_[0]); + } + } else { + delete updated_dawgs; + } +} + +// Adds a RecodeNode composed of the tuple (code, unichar_id, +// initial-dawg-state, prev, cert) to the given heap if there is room or if +// better than the current worst element if already full. +void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id, + PermuterType permuter, + bool start, bool end, float cert, + const RecodeNode* prev, + RecodeNode* best_initial_dawg) { + float score = cert; + if (prev != NULL) score += prev->score; + if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) { + DawgPositionVector* initial_dawgs = new DawgPositionVector; + dict_->default_dawgs(initial_dawgs, false); + RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert, + score, prev, initial_dawgs); + *best_initial_dawg = node; + } +} + +// Adds a copy of the given prev as a duplicate of and successor to prev, if +// there is room or if better than the current worst element if already full. +/* static */ +void RecodeBeamSearch::PushDupIfBetter(int max_size, float cert, + const RecodeNode* prev, + RecodeHeap* heap) { + PushHeapIfBetter(max_size, prev->code, prev->unichar_id, prev->permuter, + false, false, false, true, cert, prev, NULL, heap); +} + +// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, +// false, false, false, false, cert, prev, NULL) to heap if there is room +// or if better than the current worst element if already full. +/* static */ +void RecodeBeamSearch::PushNoDawgIfBetter(int max_size, int code, + int unichar_id, PermuterType permuter, + float cert, const RecodeNode* prev, + RecodeHeap* heap) { + float score = cert; + if (prev != NULL) score += prev->score; + if (heap->size() < max_size || score > heap->PeekTop().data.score) { + RecodeNode node(code, unichar_id, permuter, false, false, false, false, + cert, score, prev, NULL); + RecodePair entry(score, node); + heap->Push(&entry); + if (heap->size() > max_size) heap->Pop(&entry); + } +} + +// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, +// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room +// or if better than the current worst element if already full.
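PushNoDawgIfBetter above and PushHeapIfBetter in the next hunk share one bounded-beam idiom: accept an entry if there is room or it beats the current worst, then evict the worst to restore the bound. Reduced to its essentials over a worst-on-top heap (a sketch using std::priority_queue, not the patch's GenericHeap):

#include <functional>
#include <queue>
#include <vector>

// Keeps at most max_size scores; the smallest is always on top and is the
// one evicted, so the heap always holds the best max_size scores seen.
void PushIfBetter(int max_size, float score,
                  std::priority_queue<float, std::vector<float>,
                                      std::greater<float> >* heap) {
  if (static_cast<int>(heap->size()) < max_size || score > heap->top()) {
    heap->push(score);
    if (static_cast<int>(heap->size()) > max_size) heap->pop();
  }
}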
+/* static */ +void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id, + PermuterType permuter, bool dawg_start, + bool word_start, bool end, bool dup, + float cert, const RecodeNode* prev, + DawgPositionVector* d, + RecodeHeap* heap) { + float score = cert; + if (prev != NULL) score += prev->score; + if (heap->size() < max_size || score > heap->PeekTop().data.score) { + RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end, + dup, cert, score, prev, d); + RecodePair entry(score, node); + heap->Push(&entry); + ASSERT_HOST(entry.data.dawgs == NULL); + if (heap->size() > max_size) heap->Pop(&entry); + } else { + delete d; + } +} + +// Backtracks to extract the best path through the lattice that was built +// during Decode. On return the best_nodes vector essentially contains the set +// of code, score pairs that make the optimal path with the constraint that +// the recoder can decode the code sequence back to a sequence of unichar-ids. +void RecodeBeamSearch::ExtractBestPaths( + GenericVector* best_nodes, + GenericVector* second_nodes) const { + // Scan both beams to extract the best and second best paths. + const RecodeNode* best_node = NULL; + const RecodeNode* second_best_node = NULL; + const RecodeBeam* last_beam = beam_[beam_size_ - 1]; + int heap_size = last_beam->beams_[0].size(); + for (int i = 0; i < heap_size; ++i) { + const RecodeNode* node = &last_beam->beams_[0].get(i).data; + if (best_node == NULL || node->score > best_node->score) { + second_best_node = best_node; + best_node = node; + } else if (second_best_node == NULL || + node->score > second_best_node->score) { + second_best_node = node; + } + } + // Scan the entire dawg heap for the best *valid* nodes, if any. + int dawg_size = last_beam->dawg_beams_[0].size(); + for (int i = 0; i < dawg_size; ++i) { + const RecodeNode* dawg_node = &last_beam->dawg_beams_[0].get(i).data; + // dawg_node may be a null_char, or duplicate, so scan back to the last + // valid unichar_id. + const RecodeNode* back_dawg_node = dawg_node; + while (back_dawg_node != NULL && + (back_dawg_node->unichar_id == INVALID_UNICHAR_ID || + back_dawg_node->duplicate)) + back_dawg_node = back_dawg_node->prev; + if (back_dawg_node != NULL && + (back_dawg_node->end_of_word || + back_dawg_node->unichar_id == UNICHAR_SPACE)) { + // Dawg node is valid. Use it in preference to back_dawg_node, as the + // score comparison is fair that way. + if (best_node == NULL || dawg_node->score > best_node->score) { + second_best_node = best_node; + best_node = dawg_node; + } else if (second_best_node == NULL || + dawg_node->score > second_best_node->score) { + second_best_node = dawg_node; + } + } + } + if (second_nodes != NULL) ExtractPath(second_best_node, second_nodes); + ExtractPath(best_node, best_nodes); +} + +// Helper backtracks through the lattice from the given node, storing the +// path and reversing it. +void RecodeBeamSearch::ExtractPath( + const RecodeNode* node, GenericVector* path) const { + path->truncate(0); + while (node != NULL) { + path->push_back(node); + node = node->prev; + } + path->reverse(); +} + +// Helper prints debug information on the given lattice path. 
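ExtractBestPaths and ExtractPath recover the winning sequence by chasing prev pointers from the best terminal node and reversing once, which is the usual way to read a path out of a lattice that only stores backward links. The core of it in a self-contained form (hypothetical Node type):

#include <algorithm>
#include <vector>

struct Node {
  int code;
  const Node* prev;  // Backward link only; no forward links are stored.
};

// Walks prev pointers from the terminal node, then reverses in place so the
// path reads left to right.
std::vector<const Node*> ExtractPath(const Node* node) {
  std::vector<const Node*> path;
  for (; node != nullptr; node = node->prev) path.push_back(node);
  std::reverse(path.begin(), path.end());
  return path;
}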
+void RecodeBeamSearch::DebugPath( + const UNICHARSET* unicharset, + const GenericVector& path) const { + for (int c = 0; c < path.size(); ++c) { + const RecodeNode& node = *path[c]; + tprintf("%d %d=%s score=%g, c=%g, s=%d, e=%d, perm=%d\n", c, + node.unichar_id, unicharset->debug_str(node.unichar_id).string(), + node.score, node.certainty, node.start_of_word, node.end_of_word, + node.permuter); + } +} + +// Helper prints debug information on the given unichar path. +void RecodeBeamSearch::DebugUnicharPath( + const UNICHARSET* unicharset, const GenericVector& path, + const GenericVector& unichar_ids, const GenericVector& certs, + const GenericVector& ratings, + const GenericVector& xcoords) const { + int num_ids = unichar_ids.size(); + double total_rating = 0.0; + for (int c = 0; c < num_ids; ++c) { + int coord = xcoords[c]; + tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c], + unicharset->debug_str(unichar_ids[c]).string(), ratings[c], + certs[c], path[coord]->start_of_word, path[coord]->end_of_word, + path[coord]->permuter); + total_rating += ratings[c]; + } + tprintf("Path total rating = %g\n", total_rating); +} + +} // namespace tesseract. diff --git a/lstm/recodebeam.h b/lstm/recodebeam.h new file mode 100644 index 0000000000..df56ae0714 --- /dev/null +++ b/lstm/recodebeam.h @@ -0,0 +1,304 @@ +/////////////////////////////////////////////////////////////////////// +// File: recodebeam.h +// Description: Beam search to decode from the re-encoded CJK as a sequence of +// smaller numbers in place of a single large code. +// Author: Ray Smith +// Created: Fri Mar 13 09:12:01 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_ +#define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_ + +#include "dawg.h" +#include "dict.h" +#include "genericheap.h" +#include "kdpair.h" +#include "networkio.h" +#include "ratngs.h" +#include "unicharcompress.h" + +namespace tesseract { + +// Lattice element for Re-encode beam search. +struct RecodeNode { + RecodeNode() + : code(-1), + unichar_id(INVALID_UNICHAR_ID), + permuter(TOP_CHOICE_PERM), + start_of_dawg(false), + start_of_word(false), + end_of_word(false), + duplicate(false), + certainty(0.0f), + score(0.0f), + prev(NULL), + dawgs(NULL) {} + RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start, + bool word_start, bool end, bool dup, float cert, float s, + const RecodeNode* p, DawgPositionVector* d) + : code(c), + unichar_id(uni_id), + permuter(perm), + start_of_dawg(dawg_start), + start_of_word(word_start), + end_of_word(end), + duplicate(dup), + certainty(cert), + score(s), + prev(p), + dawgs(d) {} + // NOTE: If we could use C++11, then this would be a move constructor. + // Instead we have copy constructor that does a move!! 
This is because we + // don't want to copy the whole DawgPositionVector each time, and true + // copying isn't necessary for this struct. It does get moved around a lot + // though inside the heap and during heap push, hence the move semantics. + RecodeNode(RecodeNode& src) : dawgs(NULL) { + *this = src; + ASSERT_HOST(src.dawgs == NULL); + } + RecodeNode& operator=(RecodeNode& src) { + delete dawgs; + memcpy(this, &src, sizeof(src)); + src.dawgs = NULL; + return *this; + } + ~RecodeNode() { delete dawgs; } + + // The re-encoded code here = index to network output. + int code; + // The decoded unichar_id is only valid for the final code of a sequence. + int unichar_id; + // The type of permuter active at this point. Intervals between start_of_word + // and end_of_word make valid words of type given by permuter where + // end_of_word is true. These aren't necessarily delimited by spaces. + PermuterType permuter; + // True if this is the initial dawg state. May be attached to a space or, + // in a non-space-delimited lang, the end of the previous word. + bool start_of_dawg; + // True if this is the first node in a dictionary word. + bool start_of_word; + // True if this represents a valid candidate end of word position. Does not + // necessarily mark the end of a word, since a word can be extended beyond a + // candidate end by a continuation, e.g. 'the' continues to 'these'. + bool end_of_word; + // True if this is a duplicate of prev in all respects. Some training modes + // allow the network to output duplicate characters and crush them with CTC, + // but that would mess up the decoding, so we just smash them together on the + // fly using the duplicate flag. + bool duplicate; + // Certainty (log prob) of (just) this position. + float certainty; + // Total certainty of the path to this position. + float score; + // The previous node in this chain. Borrowed pointer. + const RecodeNode* prev; + // The currently active dawgs at this position. Owned pointer. + DawgPositionVector* dawgs; +}; + +typedef KDPairInc<float, RecodeNode> RecodePair; +typedef GenericHeap<RecodePair> RecodeHeap; + +// Class that holds the entire beam search for recognition of a text line. +class RecodeBeamSearch { + public: + // Borrows the pointer, which is expected to survive until *this is deleted. + RecodeBeamSearch(const UnicharCompress& recoder, int null_char, + bool simple_text, Dict* dict); + + // Decodes the set of network outputs, storing the lattice internally. + // If charset is not null, it enables detailed debugging of the beam search. + void Decode(const NetworkIO& output, double dict_ratio, double cert_offset, + double worst_dict_cert, const UNICHARSET* charset); + void Decode(const GENERIC_2D_ARRAY<float>& output, double dict_ratio, + double cert_offset, double worst_dict_cert, + const UNICHARSET* charset); + + // Returns the best path as labels/scores/xcoords similar to simple CTC. + void ExtractBestPathAsLabels(GenericVector<int>* labels, + GenericVector<int>* xcoords) const; + // Returns the best path as unichar-ids/certs/ratings/xcoords skipping + // duplicates, nulls and intermediate parts. + void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET* unicharset, + GenericVector<int>* unichar_ids, + GenericVector<float>* certs, + GenericVector<float>* ratings, + GenericVector<int>* xcoords) const; + + // Returns the best path as a set of WERD_RES. + void ExtractBestPathAsWords(const TBOX& line_box, float scale_factor, + bool debug, const UNICHARSET* unicharset, + PointerVector<WERD_RES>* words); +
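The NOTE in RecodeNode above apologizes for a copy constructor that really moves. With C++11 available, the same ownership transfer would be a genuine move constructor and move assignment, with no memcpy and no mutation through a copy signature; a sketch of the equivalent (DawgPositionVector reduced to an empty stand-in):

#include <utility>

struct DawgPositionVector {};  // Stand-in for the real owned payload.

struct Node {
  Node() : dawgs(nullptr) {}
  Node(Node&& src) noexcept : dawgs(src.dawgs) { src.dawgs = nullptr; }
  Node& operator=(Node&& src) noexcept {
    if (this != &src) {
      delete dawgs;
      dawgs = src.dawgs;
      src.dawgs = nullptr;
    }
    return *this;
  }
  ~Node() { delete dawgs; }
  DawgPositionVector* dawgs;  // Owned, exactly as in RecodeNode.
};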
+ // Generates debug output of the content of the beams after a Decode. + void DebugBeams(const UNICHARSET& unicharset) const; + + // Clipping value for certainty inside Tesseract. Reflects the minimum value + // of certainty that will be returned by ExtractBestPathAsUnicharIds. + // Supposedly on a uniform scale that can be compared across languages and + // engines. + static const float kMinCertainty; + + private: + // Struct for the Re-encode beam search. This struct holds the data for + // a single time-step position of the output. Use a PointerVector<RecodeBeam> + // to hold all the timesteps and prevent reallocation of the individual heaps. + struct RecodeBeam { + // Resets to the initial state without deleting all the memory. + void Clear() { + for (int i = 0; i <= RecodedCharID::kMaxCodeLen; ++i) { + beams_[i].clear(); + dawg_beams_[i].clear(); + } + RecodeNode empty; + best_initial_dawg_ = empty; + } + // A separate beam for each code position. Since there aren't that many + // code positions, this allows the beam to be quite narrow, and yet still + // have a low chance of losing the best path. + // Each heap is stored with the WORST result at the top, so we can quickly + // get the top-n values. + RecodeHeap beams_[RecodedCharID::kMaxCodeLen + 1]; + // Although we can only use complete codes in the dawg, we have to separate + // partial code paths that lead back to a mid-dawg word from paths that are + // not part of a dawg word, as they have a different score. Since a dawg + // word can dead-end at any point, we need to keep the non-dawg path going, + // so the dawg beams_ are a totally separate set, with a heap for each + // length, just like the non-dawg beams. + RecodeHeap dawg_beams_[RecodedCharID::kMaxCodeLen + 1]; + // While the language model is only a single word dictionary, we can use + // word starts as a choke point in the beam, and keep only a single dict + // start node at each step, so we find the best one here and push it on + // the heap, if it qualifies, after processing all of the step. + RecodeNode best_initial_dawg_; + }; + typedef KDPairInc<float, int> TopPair; + + // Generates debug output of the content of a single beam position. + void DebugBeamPos(const UNICHARSET& unicharset, const RecodeHeap& heap) const; + + // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping + // duplicates, nulls and intermediate parts. + static void ExtractPathAsUnicharIds( + const GenericVector<const RecodeNode*>& best_nodes, + GenericVector<int>* unichar_ids, GenericVector<float>* certs, + GenericVector<float>* ratings, GenericVector<int>* xcoords); + + // Sets up a word with the ratings matrix and fake blobs with boxes in the + // right places. + WERD_RES* InitializeWord(bool leading_space, const TBOX& line_box, + int word_start, int word_end, float space_certainty, + const UNICHARSET* unicharset, + const GenericVector<int>& xcoords, + float scale_factor); + + // Fills top_n_flags_ with bools that are true iff the corresponding output + // is one of the top_n. + void ComputeTopN(const float* outputs, int num_outputs, int top_n); + + // Adds the computation for the current time-step to the beam. Call at each + // time-step in sequence from left to right. outputs is the activation vector + // for the current timestep. + void DecodeStep(const float* outputs, int t, double dict_ratio, + double cert_offset, double worst_dict_cert, + const UNICHARSET* charset); + + // Adds to the appropriate beams the legal (according to recoder) + // continuations of context prev, which is of the given length, using the + // given network outputs to provide scores to the choices.
Uses only those + // choices for which top_n_flags[index] == top_n_flag. + void ContinueContext(const RecodeNode* prev, int length, const float* outputs, + bool use_dawgs, bool top_n_flag, double dict_ratio, + double cert_offset, double worst_dict_cert, + RecodeBeam* step); + // Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev, + // appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id + // is a valid continuation of whatever is in prev. + void ContinueDawg(int max_size, int code, int unichar_id, float cert, + const RecodeNode* prev, RecodeHeap* heap, RecodeBeam* step); + // Adds a RecodeNode composed of the tuple (code, unichar_id, + // initial-dawg-state, prev, cert) to the given heap if there is room or if + // better than the current worst element if already full. + void PushInitialDawgIfBetter(int code, int unichar_id, PermuterType permuter, + bool start, bool end, float cert, + const RecodeNode* prev, + RecodeNode* best_initial_dawg); + // Adds a copy of the given prev as a duplicate of and successor to prev, if + // there is room or if better than the current worst element if already full. + static void PushDupIfBetter(int max_size, float cert, const RecodeNode* prev, + RecodeHeap* heap); + // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, + // false, false, false, false, cert, prev, NULL) to heap if there is room + // or if better than the current worst element if already full. + static void PushNoDawgIfBetter(int max_size, int code, int unichar_id, + PermuterType permuter, float cert, + const RecodeNode* prev, RecodeHeap* heap); + // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, + // dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room + // or if better than the current worst element if already full. + static void PushHeapIfBetter(int max_size, int code, int unichar_id, + PermuterType permuter, bool dawg_start, + bool word_start, bool end, bool dup, float cert, + const RecodeNode* prev, DawgPositionVector* d, + RecodeHeap* heap); + // Backtracks to extract the best path through the lattice that was built + // during Decode. On return the best_nodes vector essentially contains the set + // of code, score pairs that make the optimal path with the constraint that + // the recoder can decode the code sequence back to a sequence of unichar-ids. + void ExtractBestPaths(GenericVector<const RecodeNode*>* best_nodes, + GenericVector<const RecodeNode*>* second_nodes) const; + // Helper backtracks through the lattice from the given node, storing the + // path and reversing it. + void ExtractPath(const RecodeNode* node, + GenericVector<const RecodeNode*>* path) const; + // Helper prints debug information on the given lattice path. + void DebugPath(const UNICHARSET* unicharset, + const GenericVector<const RecodeNode*>& path) const; + // Helper prints debug information on the given unichar path. + void DebugUnicharPath(const UNICHARSET* unicharset, + const GenericVector<const RecodeNode*>& path, + const GenericVector<int>& unichar_ids, + const GenericVector<float>& certs, + const GenericVector<float>& ratings, + const GenericVector<int>& xcoords) const; + + static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1]; + + // The encoder/decoder that we will be using. + const UnicharCompress& recoder_; + // The beam for each timestep in the output. + PointerVector<RecodeBeam> beam_; + // The number of timesteps valid in beam_. + int beam_size_; + // A flag to indicate which outputs are the top-n choices. Current timestep + // only. + GenericVector<bool> top_n_flags_; + // Heap used to compute the top_n_flags_.
+ GenericHeap top_heap_; + // Borrowed pointer to the dictionary to use in the search. + Dict* dict_; + // True if the language is space-delimited, which is true for most languages + // except chi*, jpn, tha. + bool space_delimited_; + // True if the input is simple text, ie adjacent equal chars are not to be + // eliminated. + bool is_simple_text_; + // The encoded (class label) of the null/reject character. + int null_char_; +}; + +} // namespace tesseract. + +#endif // THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_ diff --git a/lstm/reconfig.cpp b/lstm/reconfig.cpp new file mode 100644 index 0000000000..aa5e01b92f --- /dev/null +++ b/lstm/reconfig.cpp @@ -0,0 +1,128 @@ +/////////////////////////////////////////////////////////////////////// +// File: reconfig.cpp +// Description: Network layer that reconfigures the scaling vs feature +// depth. +// Author: Ray Smith +// Created: Wed Feb 26 15:42:25 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// +#include "reconfig.h" +#include "tprintf.h" + +namespace tesseract { + +Reconfig::Reconfig(const STRING& name, int ni, int x_scale, int y_scale) + : Network(NT_RECONFIG, name, ni, ni * x_scale * y_scale), + x_scale_(x_scale), y_scale_(y_scale) { +} + +Reconfig::~Reconfig() { +} + +// Returns the shape output from the network given an input shape (which may +// be partially unknown ie zero). +StaticShape Reconfig::OutputShape(const StaticShape& input_shape) const { + StaticShape result = input_shape; + result.set_height(result.height() / y_scale_); + result.set_width(result.width() / x_scale_); + if (type_ != NT_MAXPOOL) + result.set_depth(result.depth() * y_scale_ * x_scale_); + return result; +} + +// Returns an integer reduction factor that the network applies to the +// time sequence. Assumes that any 2-d is already eliminated. Used for +// scaling bounding boxes of truth data. +// WARNING: if GlobalMinimax is used to vary the scale, this will return +// the last used scale factor. Call it before any forward, and it will return +// the minimum scale factor of the paths through the GlobalMinimax. +int Reconfig::XScaleFactor() const { + return x_scale_; +} + +// Writes to the given file. Returns false in case of error. +bool Reconfig::Serialize(TFile* fp) const { + if (!Network::Serialize(fp)) return false; + if (fp->FWrite(&x_scale_, sizeof(x_scale_), 1) != 1) return false; + if (fp->FWrite(&y_scale_, sizeof(y_scale_), 1) != 1) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. 
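Reconfig::OutputShape above is the whole contract of the layer: each x_scale by y_scale tile of input positions is folded into one output position whose depth grows by the same factor, so the total number of values is preserved. A hand-checked instance (numbers chosen for illustration only):

#include <cassert>

int main() {
  const int x_scale = 3, y_scale = 3;
  const int h = 36, w = 12, d = 16;         // Input height, width, depth.
  const int out_h = h / y_scale;            // 12
  const int out_w = w / x_scale;            // 4
  const int out_d = d * x_scale * y_scale;  // 144
  // Reconfig trades spatial resolution for feature depth, value for value.
  assert(h * w * d == out_h * out_w * out_d);
  return 0;
}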
+bool Reconfig::DeSerialize(bool swap, TFile* fp) { + if (fp->FRead(&x_scale_, sizeof(x_scale_), 1) != 1) return false; + if (fp->FRead(&y_scale_, sizeof(y_scale_), 1) != 1) return false; + if (swap) { + ReverseN(&x_scale_, sizeof(x_scale_)); + ReverseN(&y_scale_, sizeof(y_scale_)); + } + no_ = ni_ * x_scale_ * y_scale_; + return true; +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void Reconfig::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + output->ResizeScaled(input, x_scale_, y_scale_, no_); + back_map_ = input.stride_map(); + StrideMap::Index dest_index(output->stride_map()); + do { + int out_t = dest_index.t(); + StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH), + dest_index.index(FD_HEIGHT) * y_scale_, + dest_index.index(FD_WIDTH) * x_scale_); + // Stack x_scale_ groups of y_scale_ inputs together. + for (int x = 0; x < x_scale_; ++x) { + for (int y = 0; y < y_scale_; ++y) { + StrideMap::Index src_xy(src_index); + if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) { + output->CopyTimeStepGeneral(out_t, (x * y_scale_ + y) * ni_, ni_, + input, src_xy.t(), 0); + } + } + } + } while (dest_index.Increment()); +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool Reconfig::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_); + StrideMap::Index src_index(fwd_deltas.stride_map()); + do { + int in_t = src_index.t(); + StrideMap::Index dest_index(back_deltas->stride_map(), + src_index.index(FD_BATCH), + src_index.index(FD_HEIGHT) * y_scale_, + src_index.index(FD_WIDTH) * x_scale_); + // Unstack x_scale_ groups of y_scale_ inputs that are together. + for (int x = 0; x < x_scale_; ++x) { + for (int y = 0; y < y_scale_; ++y) { + StrideMap::Index dest_xy(dest_index); + if (dest_xy.AddOffset(x, FD_WIDTH) && dest_xy.AddOffset(y, FD_HEIGHT)) { + back_deltas->CopyTimeStepGeneral(dest_xy.t(), 0, ni_, fwd_deltas, + in_t, (x * y_scale_ + y) * ni_); + } + } + } + } while (src_index.Increment()); + return needs_to_backprop_; +} + + +} // namespace tesseract. diff --git a/lstm/reconfig.h b/lstm/reconfig.h new file mode 100644 index 0000000000..4409cf0a4a --- /dev/null +++ b/lstm/reconfig.h @@ -0,0 +1,86 @@ +/////////////////////////////////////////////////////////////////////// +// File: reconfig.h +// Description: Network layer that reconfigures the scaling vs feature +// depth. +// Author: Ray Smith +// Created: Wed Feb 26 15:37:42 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////// +#ifndef TESSERACT_LSTM_RECONFIG_H_ +#define TESSERACT_LSTM_RECONFIG_H_ + + +#include "genericvector.h" +#include "matrix.h" +#include "network.h" + +namespace tesseract { + +// Reconfigures (Shrinks) the inputs by concatenating an x_scale by y_scale tile +// of inputs together, producing a single, deeper output per tile. +// Note that fractional parts are truncated for efficiency, so make sure the +// input stride is a multiple of the y_scale factor! +class Reconfig : public Network { + public: + Reconfig(const STRING& name, int ni, int x_scale, int y_scale); + virtual ~Reconfig(); + + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const; + + virtual STRING spec() const { + STRING spec; + spec.add_str_int("S", y_scale_); + spec.add_str_int(",", x_scale_); + return spec; + } + + // Returns an integer reduction factor that the network applies to the + // time sequence. Assumes that any 2-d is already eliminated. Used for + // scaling bounding boxes of truth data. + // WARNING: if GlobalMinimax is used to vary the scale, this will return + // the last used scale factor. Call it before any forward, and it will return + // the minimum scale factor of the paths through the GlobalMinimax. + virtual int XScaleFactor() const; + + // Writes to the given file. Returns false in case of error. + virtual bool Serialize(TFile* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + virtual bool DeSerialize(bool swap, TFile* fp); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + + protected: + // Non-serialized data used to store parameters between forward and back. + StrideMap back_map_; + // Serialized data. + inT32 x_scale_; + inT32 y_scale_; +}; + +} // namespace tesseract. + + +#endif // TESSERACT_LSTM_SUBSAMPLE_H_ diff --git a/lstm/reversed.cpp b/lstm/reversed.cpp new file mode 100644 index 0000000000..9cdc4f96fd --- /dev/null +++ b/lstm/reversed.cpp @@ -0,0 +1,91 @@ +/////////////////////////////////////////////////////////////////////// +// File: reversed.cpp +// Description: Runs a single network on time-reversed input, reversing output. +// Author: Ray Smith +// Created: Thu May 02 08:42:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////// + +#include "reversed.h" + +#include + +#include "networkscratch.h" + +namespace tesseract { + +Reversed::Reversed(const STRING& name, NetworkType type) : Plumbing(name) { + type_ = type; +} +Reversed::~Reversed() { +} + +// Returns the shape output from the network given an input shape (which may +// be partially unknown ie zero). +StaticShape Reversed::OutputShape(const StaticShape& input_shape) const { + if (type_ == NT_XYTRANSPOSE) { + StaticShape x_shape(input_shape); + x_shape.set_width(input_shape.height()); + x_shape.set_height(input_shape.width()); + x_shape = stack_[0]->OutputShape(x_shape); + x_shape.SetShape(x_shape.batch(), x_shape.width(), x_shape.height(), + x_shape.depth()); + return x_shape; + } + return stack_[0]->OutputShape(input_shape); +} + +// Takes ownership of the given network to make it the reversed one. +void Reversed::SetNetwork(Network* network) { + stack_.clear(); + AddToStack(network); +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void Reversed::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + NetworkScratch::IO rev_input(input, scratch); + ReverseData(input, rev_input); + NetworkScratch::IO rev_output(input, scratch); + stack_[0]->Forward(debug, *rev_input, NULL, scratch, rev_output); + ReverseData(*rev_output, output); +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool Reversed::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + NetworkScratch::IO rev_input(fwd_deltas, scratch); + ReverseData(fwd_deltas, rev_input); + NetworkScratch::IO rev_output(fwd_deltas, scratch); + if (stack_[0]->Backward(debug, *rev_input, scratch, rev_output)) { + ReverseData(*rev_output, back_deltas); + return true; + } + return false; +} + +// Copies src to *dest with the reversal according to type_. +void Reversed::ReverseData(const NetworkIO& src, NetworkIO* dest) const { + if (type_ == NT_XREVERSED) + dest->CopyWithXReversal(src); + else if (type_ == NT_YREVERSED) + dest->CopyWithYReversal(src); + else + dest->CopyWithXYTranspose(src); +} + +} // namespace tesseract. diff --git a/lstm/reversed.h b/lstm/reversed.h new file mode 100644 index 0000000000..97c2aebbc0 --- /dev/null +++ b/lstm/reversed.h @@ -0,0 +1,89 @@ +/////////////////////////////////////////////////////////////////////// +// File: reversed.h +// Description: Runs a single network on time-reversed input, reversing output. +// Author: Ray Smith +// Created: Thu May 02 08:38:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_REVERSED_H_ +#define TESSERACT_LSTM_REVERSED_H_ + +#include "matrix.h" +#include "plumbing.h" + +namespace tesseract { + +// C++ Implementation of the Reversed class from lstm.py. +class Reversed : public Plumbing { + public: + explicit Reversed(const STRING& name, NetworkType type); + virtual ~Reversed(); + + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const; + + virtual STRING spec() const { + STRING spec(type_ == NT_XREVERSED ? "Rx" + : (type_ == NT_YREVERSED ? "Ry" : "Txy")); + // For most simple cases, we will output Rx<net> or Ry<net> where <net> is + // the network in stack_[0], but in the special case that <net> is an + // LSTM, we will just output the LSTM's spec modified to take the reversal + // into account. This is because when the user specified Lfy64, we actually + // generated TxyLfx64, and if the user specified Lrx64 we actually + // generated RxLfx64, and we want to display what the user asked for. + STRING net_spec = stack_[0]->spec(); + if (net_spec[0] == 'L') { + // Set up a from and to character according to the type of the reversal + // such that the LSTM spec gets modified to the spec that the user + // asked for. + char from = 'f'; + char to = 'r'; + if (type_ == NT_XYTRANSPOSE) { + from = 'x'; + to = 'y'; + } + // Change the from char to the to char. + for (int i = 0; i < net_spec.length(); ++i) { + if (net_spec[i] == from) net_spec[i] = to; + } + return net_spec; + } + spec += net_spec; + return spec; + } + + // Takes ownership of the given network to make it the reversed one. + void SetNetwork(Network* network); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + + private: + // Copies src to *dest with the reversal according to type_. + void ReverseData(const NetworkIO& src, NetworkIO* dest) const; +}; + +} // namespace tesseract. + +#endif // TESSERACT_LSTM_REVERSED_H_ diff --git a/lstm/series.cpp b/lstm/series.cpp new file mode 100644 index 0000000000..df7bcbddaf --- /dev/null +++ b/lstm/series.cpp @@ -0,0 +1,188 @@ +/////////////////////////////////////////////////////////////////////// +// File: series.cpp +// Description: Runs networks in series on the same input. +// Author: Ray Smith +// Created: Thu May 02 08:26:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+/////////////////////////////////////////////////////////////////////// + +#include "series.h" + +#include "fullyconnected.h" +#include "networkscratch.h" +#include "scrollview.h" +#include "tprintf.h" + +namespace tesseract { + +// ni_ and no_ will be set by AddToStack. +Series::Series(const STRING& name) : Plumbing(name) { + type_ = NT_SERIES; +} + +Series::~Series() { +} + +// Returns the shape output from the network given an input shape (which may +// be partially unknown ie zero). +StaticShape Series::OutputShape(const StaticShape& input_shape) const { + StaticShape result(input_shape); + int stack_size = stack_.size(); + for (int i = 0; i < stack_size; ++i) { + result = stack_[i]->OutputShape(result); + } + return result; +} + +// Sets up the network for training. Initializes weights using weights of +// scale `range` picked according to the random number generator `randomizer`. +// Note that series has its own implementation just for debug purposes. +int Series::InitWeights(float range, TRand* randomizer) { + num_weights_ = 0; + tprintf("Num outputs,weights in serial:\n"); + for (int i = 0; i < stack_.size(); ++i) { + int weights = stack_[i]->InitWeights(range, randomizer); + tprintf(" %s:%d, %d\n", + stack_[i]->spec().string(), stack_[i]->NumOutputs(), weights); + num_weights_ += weights; + } + tprintf("Total weights = %d\n", num_weights_); + return num_weights_; +} + +// Sets needs_to_backprop_ to needs_backprop and returns true if +// needs_backprop || any weights in this network so the next layer forward +// can be told to produce backprop for this layer if needed. +bool Series::SetupNeedsBackprop(bool needs_backprop) { + needs_to_backprop_ = needs_backprop; + for (int i = 0; i < stack_.size(); ++i) + needs_backprop = stack_[i]->SetupNeedsBackprop(needs_backprop); + return needs_backprop; +} + +// Returns an integer reduction factor that the network applies to the +// time sequence. Assumes that any 2-d is already eliminated. Used for +// scaling bounding boxes of truth data. +// WARNING: if GlobalMinimax is used to vary the scale, this will return +// the last used scale factor. Call it before any forward, and it will return +// the minimum scale factor of the paths through the GlobalMinimax. +int Series::XScaleFactor() const { + int factor = 1; + for (int i = 0; i < stack_.size(); ++i) + factor *= stack_[i]->XScaleFactor(); + return factor; +} + +// Provides the (minimum) x scale factor to the network (of interest only to +// input units) so they can determine how to scale bounding boxes. +void Series::CacheXScaleFactor(int factor) { + stack_[0]->CacheXScaleFactor(factor); +} + +// Runs forward propagation of activations on the input line. +// See NetworkCpp for a detailed discussion of the arguments. +void Series::Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output) { + int stack_size = stack_.size(); + ASSERT_HOST(stack_size > 1); + // Revolving intermediate buffers. + NetworkScratch::IO buffer1(input, scratch); + NetworkScratch::IO buffer2(input, scratch); + // Run each network in turn, giving the output of n as the input to n + 1, + // with the final network providing the real output. + stack_[0]->Forward(debug, input, input_transpose, scratch, buffer1); + for (int i = 1; i < stack_size; i += 2) { + stack_[i]->Forward(debug, *buffer1, NULL, scratch, + i + 1 < stack_size ? 
 buffer2 : output); + if (i + 1 == stack_size) return; + stack_[i + 1]->Forward(debug, *buffer2, NULL, scratch, + i + 2 < stack_size ? buffer1 : output); + } +} + +// Runs backward propagation of errors on the deltas line. +// See NetworkCpp for a detailed discussion of the arguments. +bool Series::Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas) { + if (!training()) return false; + int stack_size = stack_.size(); + ASSERT_HOST(stack_size > 1); + // Revolving intermediate buffers. + NetworkScratch::IO buffer1(fwd_deltas, scratch); + NetworkScratch::IO buffer2(fwd_deltas, scratch); + // Run each network in reverse order, giving the back_deltas output of n as + // the fwd_deltas input to n-1, with network 0 providing the real output. + if (!stack_.back()->training() || + !stack_.back()->Backward(debug, fwd_deltas, scratch, buffer1)) + return false; + for (int i = stack_size - 2; i >= 0; i -= 2) { + if (!stack_[i]->training() || + !stack_[i]->Backward(debug, *buffer1, scratch, + i > 0 ? buffer2 : back_deltas)) + return false; + if (i == 0) return needs_to_backprop_; + if (!stack_[i - 1]->training() || + !stack_[i - 1]->Backward(debug, *buffer2, scratch, + i > 1 ? buffer1 : back_deltas)) + return false; + } + return needs_to_backprop_; +} + +// Splits the series after the given index, returning the two parts, and +// deletes itself. The first part, up to the network with index last_start, +// goes into start, and the rest goes into end. +void Series::SplitAt(int last_start, Series** start, Series** end) { + *start = NULL; + *end = NULL; + if (last_start < 0 || last_start >= stack_.size()) { + tprintf("Invalid split index %d must be in range [0,%d]!\n", + last_start, stack_.size() - 1); + return; + } + Series* master_series = new Series("MasterSeries"); + Series* boosted_series = new Series("BoostedSeries"); + for (int s = 0; s <= last_start; ++s) { + if (s + 1 == stack_.size() && stack_[s]->type() == NT_SOFTMAX) { + // Change the softmax to a tanh. + FullyConnected* fc = reinterpret_cast<FullyConnected*>(stack_[s]); + fc->ChangeType(NT_TANH); + } + master_series->AddToStack(stack_[s]); + stack_[s] = NULL; + } + for (int s = last_start + 1; s < stack_.size(); ++s) { + boosted_series->AddToStack(stack_[s]); + stack_[s] = NULL; + } + *start = master_series; + *end = boosted_series; + delete this; +} + +// Appends the elements of the src series to this, removing from src and +// deleting it. +void Series::AppendSeries(Network* src) { + ASSERT_HOST(src->type() == NT_SERIES); + Series* src_series = reinterpret_cast<Series*>(src); + for (int s = 0; s < src_series->stack_.size(); ++s) { + AddToStack(src_series->stack_[s]); + src_series->stack_[s] = NULL; + } + delete src; +} + + +} // namespace tesseract. diff --git a/lstm/series.h b/lstm/series.h new file mode 100644 index 0000000000..cea1d82874 --- /dev/null +++ b/lstm/series.h @@ -0,0 +1,91 @@ +/////////////////////////////////////////////////////////////////////// +// File: series.h +// Description: Runs networks in series on the same input. +// Author: Ray Smith +// Created: Thu May 02 08:20:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_LSTM_SERIES_H_ +#define TESSERACT_LSTM_SERIES_H_ + +#include "plumbing.h" + +namespace tesseract { + +// Runs two or more networks in series (layers) on the same input. +class Series : public Plumbing { + public: + // ni_ and no_ will be set by AddToStack. + explicit Series(const STRING& name); + virtual ~Series(); + + // Returns the shape output from the network given an input shape (which may + // be partially unknown ie zero). + virtual StaticShape OutputShape(const StaticShape& input_shape) const; + + virtual STRING spec() const { + STRING spec("["); + for (int i = 0; i < stack_.size(); ++i) + spec += stack_[i]->spec(); + spec += "]"; + return spec; + } + + // Sets up the network for training. Initializes weights using weights of + // scale `range` picked according to the random number generator `randomizer`. + // Returns the number of weights initialized. + virtual int InitWeights(float range, TRand* randomizer); + + // Sets needs_to_backprop_ to needs_backprop and returns true if + // needs_backprop || any weights in this network so the next layer forward + // can be told to produce backprop for this layer if needed. + virtual bool SetupNeedsBackprop(bool needs_backprop); + + // Returns an integer reduction factor that the network applies to the + // time sequence. Assumes that any 2-d is already eliminated. Used for + // scaling bounding boxes of truth data. + // WARNING: if GlobalMinimax is used to vary the scale, this will return + // the last used scale factor. Call it before any forward, and it will return + // the minimum scale factor of the paths through the GlobalMinimax. + virtual int XScaleFactor() const; + + // Provides the (minimum) x scale factor to the network (of interest only to + // input units) so they can determine how to scale bounding boxes. + virtual void CacheXScaleFactor(int factor); + + // Runs forward propagation of activations on the input line. + // See Network for a detailed discussion of the arguments. + virtual void Forward(bool debug, const NetworkIO& input, + const TransposedArray* input_transpose, + NetworkScratch* scratch, NetworkIO* output); + + // Runs backward propagation of errors on the deltas line. + // See Network for a detailed discussion of the arguments. + virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, + NetworkScratch* scratch, + NetworkIO* back_deltas); + + // Splits the series after the given index, returning the two parts and + // deletes itself. The first part, upto network with index last_start, goes + // into start, and the rest goes into end. + void SplitAt(int last_start, Series** start, Series** end); + + // Appends the elements of the src series to this, removing from src and + // deleting it. + void AppendSeries(Network* src); +}; + +} // namespace tesseract. 
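Series::OutputShape in the implementation is a left fold of each member's shape transform over the stack, which is also why spec() prints as the members' specs concatenated inside [...]. The composition on its own (hypothetical minimal Layer interface, not the patch's Network class):

#include <memory>
#include <vector>

struct Shape { int batch, height, width, depth; };

struct Layer {
  virtual ~Layer() {}
  virtual Shape OutputShape(const Shape& in) const = 0;
};

// A series' output shape is each member applied to its predecessor's
// output in order: a fold over the stack.
Shape SeriesOutputShape(const std::vector<std::unique_ptr<Layer> >& stack,
                        Shape shape) {
  for (const auto& layer : stack) shape = layer->OutputShape(shape);
  return shape;
}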
+ +#endif // TESSERACT_LSTM_SERIES_H_ diff --git a/lstm/static_shape.h b/lstm/static_shape.h new file mode 100644 index 0000000000..25b8f03bab --- /dev/null +++ b/lstm/static_shape.h @@ -0,0 +1,80 @@ +/////////////////////////////////////////////////////////////////////// +// File: static_shape.h +// Description: Defines the size of the 4-d tensor input/output from a network. +// Author: Ray Smith +// Created: Fri Oct 14 09:07:31 PST 2016 +// +// (C) Copyright 2016, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// +#ifndef TESSERACT_LSTM_STATIC_SHAPE_H_ +#define TESSERACT_LSTM_STATIC_SHAPE_H_ + +#include "tprintf.h" + +namespace tesseract { + +// Enum describing the loss function to apply during training and/or the +// decoding method to apply at runtime. +enum LossType { + LT_NONE, // Undefined. + LT_CTC, // Softmax with standard CTC for training/decoding. + LT_SOFTMAX, // Outputs sum to 1 in fixed positions. + LT_LOGISTIC, // Logistic outputs with independent values. +}; + +// Simple class to hold the tensor shape that is known at network build time +// and the LossType of the loss function. +class StaticShape { + public: + StaticShape() + : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {} + int batch() const { return batch_; } + void set_batch(int value) { batch_ = value; } + int height() const { return height_; } + void set_height(int value) { height_ = value; } + int width() const { return width_; } + void set_width(int value) { width_ = value; } + int depth() const { return depth_; } + void set_depth(int value) { depth_ = value; } + LossType loss_type() const { return loss_type_; } + void set_loss_type(LossType value) { loss_type_ = value; } + void SetShape(int batch, int height, int width, int depth) { + batch_ = batch; + height_ = height; + width_ = width; + depth_ = depth; + } + + void Print() const { + tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_, + height_, width_, depth_, loss_type_); + } + + private: + // Size of the 4-D tensor input/output to a network. A value of zero is + // allowed for all except depth_, and means that the value is to be + // determined at runtime and regarded as variable. + // Number of elements in a batch, or number of frames in a video stream. + int batch_; + // Height of the image. + int height_; + // Width of the image. + int width_; + // Depth of the image. (Number of "nodes"). + int depth_; + // How to train/interpret the output. + LossType loss_type_; +}; + +} // namespace tesseract + +#endif // TESSERACT_LSTM_STATIC_SHAPE_H_ diff --git a/lstm/stridemap.cpp b/lstm/stridemap.cpp new file mode 100644 index 0000000000..3d95907ea0 --- /dev/null +++ b/lstm/stridemap.cpp @@ -0,0 +1,173 @@ +/////////////////////////////////////////////////////////////////////// +// File: stridemap.cpp +// Description: Indexing into a 4-d tensor held in a 2-d Array. +// Author: Ray Smith +// Created: Fri Sep 20 15:30:31 PST 2016 +// +// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#include "stridemap.h"
+
+namespace tesseract {
+
+// Returns true if *this is a valid index.
+bool StrideMap::Index::IsValid() const {
+  // Cheap check first.
+  for (int d = 0; d < FD_DIMSIZE; ++d) {
+    if (indices_[d] < 0) return false;
+  }
+  for (int d = 0; d < FD_DIMSIZE; ++d) {
+    if (indices_[d] > MaxIndexOfDim(static_cast<FlexDimensions>(d)))
+      return false;
+  }
+  return true;
+}
+
+// Returns true if the index of the given dimension is the last.
+bool StrideMap::Index::IsLast(FlexDimensions dimension) const {
+  return MaxIndexOfDim(dimension) == indices_[dimension];
+}
+
+// Given that the dimensions up to and including dim-1 are valid, returns the
+// maximum index for dimension dim.
+int StrideMap::Index::MaxIndexOfDim(FlexDimensions dim) const {
+  int max_index = stride_map_->shape_[dim] - 1;
+  if (dim == FD_BATCH) return max_index;
+  int batch = indices_[FD_BATCH];
+  if (dim == FD_HEIGHT) {
+    if (batch >= stride_map_->heights_.size() ||
+        stride_map_->heights_[batch] > max_index)
+      return max_index;
+    return stride_map_->heights_[batch] - 1;
+  }
+  if (batch >= stride_map_->widths_.size() ||
+      stride_map_->widths_[batch] > max_index)
+    return max_index;
+  return stride_map_->widths_[batch] - 1;
+}
+
+// Adds the given offset to the given dimension. Returns true if the result
+// makes a valid index.
+bool StrideMap::Index::AddOffset(int offset, FlexDimensions dimension) {
+  indices_[dimension] += offset;
+  SetTFromIndices();
+  return IsValid();
+}
+
+// Increments the index in some encapsulated way that guarantees to remain
+// valid until it returns false, meaning that the iteration is complete.
+bool StrideMap::Index::Increment() {
+  for (int d = FD_DIMSIZE - 1; d >= 0; --d) {
+    if (!IsLast(static_cast<FlexDimensions>(d))) {
+      t_ += stride_map_->t_increments_[d];
+      ++indices_[d];
+      return true;
+    }
+    t_ -= stride_map_->t_increments_[d] * indices_[d];
+    indices_[d] = 0;
+    // Now carry to the next dimension.
+  }
+  return false;
+}
+
+// Decrements the index in some encapsulated way that guarantees to remain
+// valid until it returns false, meaning that the iteration (that started
+// with InitToLast()) is complete.
+bool StrideMap::Index::Decrement() {
+  for (int d = FD_DIMSIZE - 1; d >= 0; --d) {
+    if (indices_[d] > 0) {
+      --indices_[d];
+      if (d == FD_BATCH) {
+        // The upper limits of the other dimensions may have changed as a
+        // result of a different batch index, so they have to be reset.
+        InitToLastOfBatch(indices_[FD_BATCH]);
+      } else {
+        t_ -= stride_map_->t_increments_[d];
+      }
+      return true;
+    }
+    indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
+    t_ += stride_map_->t_increments_[d] * indices_[d];
+    // Now borrow from the next dimension.
+  }
+  return false;
+}
+
+// Initializes the indices to the last valid location in the given batch
+// index.
+void StrideMap::Index::InitToLastOfBatch(int batch) {
+  indices_[FD_BATCH] = batch;
+  for (int d = FD_BATCH + 1; d < FD_DIMSIZE; ++d) {
+    indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
+  }
+  SetTFromIndices();
+}
+
+// Computes and sets t_ from the current indices_.
+void StrideMap::Index::SetTFromIndices() {
+  t_ = 0;
+  for (int d = 0; d < FD_DIMSIZE; ++d) {
+    t_ += stride_map_->t_increments_[d] * indices_[d];
+  }
+}
+
+// Sets up the stride for the given array of height, width pairs.
+void StrideMap::SetStride(const std::vector<std::pair<int, int>>& h_w_pairs) {
+  int max_height = 0;
+  int max_width = 0;
+  for (const std::pair<int, int>& hw : h_w_pairs) {
+    int height = hw.first;
+    int width = hw.second;
+    heights_.push_back(height);
+    widths_.push_back(width);
+    if (height > max_height) max_height = height;
+    if (width > max_width) max_width = width;
+  }
+  shape_[FD_BATCH] = heights_.size();
+  shape_[FD_HEIGHT] = max_height;
+  shape_[FD_WIDTH] = max_width;
+  ComputeTIncrements();
+}
+
+// Scales width and height dimensions by the given factors.
+void StrideMap::ScaleXY(int x_factor, int y_factor) {
+  for (int& height : heights_) height /= y_factor;
+  for (int& width : widths_) width /= x_factor;
+  shape_[FD_HEIGHT] /= y_factor;
+  shape_[FD_WIDTH] /= x_factor;
+  ComputeTIncrements();
+}
+
+// Reduces width to 1, across the batch, whatever the input size.
+void StrideMap::ReduceWidthTo1() {
+  widths_.assign(widths_.size(), 1);
+  shape_[FD_WIDTH] = 1;
+  ComputeTIncrements();
+}
+
+// Transposes the width and height dimensions.
+void StrideMap::TransposeXY() {
+  std::swap(shape_[FD_HEIGHT], shape_[FD_WIDTH]);
+  std::swap(heights_, widths_);
+  ComputeTIncrements();
+}
+
+// Computes t_increments_ from shape_.
+void StrideMap::ComputeTIncrements() {
+  t_increments_[FD_DIMSIZE - 1] = 1;
+  for (int d = FD_DIMSIZE - 2; d >= 0; --d) {
+    t_increments_[d] = t_increments_[d + 1] * shape_[d + 1];
+  }
+}
+
+} // namespace tesseract
diff --git a/lstm/stridemap.h b/lstm/stridemap.h
new file mode 100644
index 0000000000..60dd855b4f
--- /dev/null
+++ b/lstm/stridemap.h
@@ -0,0 +1,137 @@
+///////////////////////////////////////////////////////////////////////
+// File: stridemap.h
+// Description: Indexing into a 4-d tensor held in a 2-d Array.
+// Author: Ray Smith
+// Created: Fri Sep 20 16:00:31 PST 2016
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+#ifndef TESSERACT_LSTM_STRIDEMAP_H_
+#define TESSERACT_LSTM_STRIDEMAP_H_
+
+#include <string.h>
+#include <vector>
+#include "tprintf.h"
+
+namespace tesseract {
+
+// Enum describing the dimensions of the 'Tensor' in a NetworkIO.
+// A NetworkIO is analogous to a TF Tensor, except that the number of
+// dimensions is fixed (4), and they always have the same meaning. The
+// underlying representation is a 2-D array, for which the product
+// batch*height*width is always dim1 and depth is always dim2.
+// FlexDimensions is used only for batch, height, width with the StrideMap,
+// and therefore represents the runtime shape. The build-time shape is
+// defined by StaticShape.
+enum FlexDimensions {
+  FD_BATCH,    // Index of multiple images.
+  FD_HEIGHT,   // y-coordinate in image.
+  FD_WIDTH,    // x-coordinate in image.
+  FD_DIMSIZE,  // Number of flexible non-depth dimensions.
+};
+
+// Encapsulation of information relating to the mapping from [batch][y][x] to
+// the first index into the 2-d array underlying a NetworkIO.
+class StrideMap {
+ public:
+  // Class holding the non-depth indices.
+  class Index {
+   public:
+    explicit Index(const StrideMap& stride_map) : stride_map_(&stride_map) {
+      InitToFirst();
+    }
+    Index(const StrideMap& stride_map, int batch, int y, int x)
+        : stride_map_(&stride_map) {
+      indices_[FD_BATCH] = batch;
+      indices_[FD_HEIGHT] = y;
+      indices_[FD_WIDTH] = x;
+      SetTFromIndices();
+    }
+    // Accesses the index to the underlying array.
+    int t() const { return t_; }
+    int index(FlexDimensions dimension) const { return indices_[dimension]; }
+    // Initializes the indices to the first valid location.
+    void InitToFirst() {
+      memset(indices_, 0, sizeof(indices_));
+      t_ = 0;
+    }
+    // Initializes the indices to the last valid location.
+    void InitToLast() { InitToLastOfBatch(MaxIndexOfDim(FD_BATCH)); }
+    // Returns true if *this is a valid index.
+    bool IsValid() const;
+    // Returns true if the index of the given dimension is the last.
+    bool IsLast(FlexDimensions dimension) const;
+    // Given that the dimensions up to and including dim-1 are valid, returns
+    // the maximum index for dimension dim.
+    int MaxIndexOfDim(FlexDimensions dim) const;
+    // Adds the given offset to the given dimension. Returns true if the result
+    // makes a valid index.
+    bool AddOffset(int offset, FlexDimensions dimension);
+    // Increments the index in some encapsulated way that guarantees to remain
+    // valid until it returns false, meaning that the iteration is complete.
+    bool Increment();
+    // Decrements the index in some encapsulated way that guarantees to remain
+    // valid until it returns false, meaning that the iteration (that started
+    // with InitToLast()) is complete.
+    bool Decrement();
+
+   private:
+    // Initializes the indices to the last valid location in the given batch
+    // index.
+    void InitToLastOfBatch(int batch);
+    // Computes and sets t_ from the current indices_.
+    void SetTFromIndices();
+
+    // Map into which *this is an index.
+    const StrideMap* stride_map_;
+    // Index to the first dimension of the underlying array.
+    int t_;
+    // Indices into the individual dimensions.
+    int indices_[FD_DIMSIZE];
+  };
+
+  StrideMap() {
+    memset(shape_, 0, sizeof(shape_));
+    memset(t_increments_, 0, sizeof(t_increments_));
+  }
+  // Default copy constructor and operator= are OK to use here!
+
+  // Sets up the stride for the given array of height, width pairs.
+  void SetStride(const std::vector<std::pair<int, int>>& h_w_pairs);
+  // Scales width and height dimensions by the given factors.
+  void ScaleXY(int x_factor, int y_factor);
+  // Reduces width to 1, across the batch, whatever the input size.
+  void ReduceWidthTo1();
+  // Transposes the width and height dimensions.
+  void TransposeXY();
+  // Returns the size of the given dimension.
+  int Size(FlexDimensions dimension) const { return shape_[dimension]; }
+  // Returns the total width required.
+  int Width() const { return t_increments_[FD_BATCH] * shape_[FD_BATCH]; }
+
+ private:
+  // Computes t_increments_ from shape_.
+  void ComputeTIncrements();
+
+  // The size of each non-depth dimension.
+  int shape_[FD_DIMSIZE];
+  // Precomputed 't' increments for each dimension. This is the value of
+  // the given dimension in the packed 3-d array that the shape_ represents.
+  int t_increments_[FD_DIMSIZE];
+  // Vector of size shape_[FD_BATCH] holds the height of each image in a batch.
+  std::vector<int> heights_;
+  // Vector of size shape_[FD_BATCH] holds the width of each image in a batch.
+  std::vector<int> widths_;
+};
+
+} // namespace tesseract
+
+#endif // TESSERACT_LSTM_STRIDEMAP_H_
diff --git a/lstm/tfnetwork.cpp b/lstm/tfnetwork.cpp
new file mode 100644
index 0000000000..f7180a0903
--- /dev/null
+++ b/lstm/tfnetwork.cpp
@@ -0,0 +1,146 @@
+///////////////////////////////////////////////////////////////////////
+// File: tfnetwork.cpp
+// Description: Encapsulation of an entire tensorflow graph as a
+// Tesseract Network.
+// Author: Ray Smith
+// Created: Fri Feb 26 09:35:29 PST 2016
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+#ifdef INCLUDE_TENSORFLOW
+
+#include "tfnetwork.h"
+
+#include "allheaders.h"
+#include "input.h"
+#include "networkscratch.h"
+
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+
+namespace tesseract {
+
+TFNetwork::TFNetwork(const STRING& name) : Network(NT_TENSORFLOW, name, 0, 0) {}
+
+TFNetwork::~TFNetwork() {}
+
+int TFNetwork::InitFromProtoStr(const string& proto_str) {
+  if (!model_proto_.ParseFromString(proto_str)) return 0;
+  return InitFromProto();
+}
+
+// Writes to the given file. Returns false in case of error.
+// Should be overridden by subclasses, but called by their Serialize.
+bool TFNetwork::Serialize(TFile* fp) const {
+  if (!Network::Serialize(fp)) return false;
+  string proto_str;
+  model_proto_.SerializeToString(&proto_str);
+  GenericVector<char> data;
+  data.init_to_size(proto_str.size(), 0);
+  memcpy(&data[0], proto_str.data(), proto_str.size());
+  if (!data.Serialize(fp)) return false;
+  return true;
+}
+
+// Reads from the given file. Returns false in case of error.
+// If swap is true, assumes a big/little-endian swap is needed.
+// Should be overridden by subclasses, but NOT called by their DeSerialize.
+bool TFNetwork::DeSerialize(bool swap, TFile* fp) {
+  GenericVector<char> data;
+  if (!data.DeSerialize(swap, fp)) return false;
+  if (!model_proto_.ParseFromArray(&data[0], data.size())) {
+    return false;
+  }
+  return InitFromProto();
+}
+
+// Runs forward propagation of activations on the input line.
+// See Network for a detailed discussion of the arguments.
+void TFNetwork::Forward(bool debug, const NetworkIO& input,
+                        const TransposedArray* input_transpose,
+                        NetworkScratch* scratch, NetworkIO* output) {
+  std::vector<std::pair<string, Tensor>> tf_inputs;
+  int depth = input_shape_.depth();
+  ASSERT_HOST(depth == input.NumFeatures());
+  // TODO(rays) Allow batching. For now batch_size = 1.
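+  // Layout note (illustrative, assuming batch_size = 1): NetworkIO's
+  // underlying 2-d array is [batch*height*width, depth] with the width index
+  // advancing fastest, so its memory order already matches TF's row-major
+  // [1, height, width, depth] tensor, and a single memcpy below fills it.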
+  const StrideMap& stride_map = input.stride_map();
+  // TF requires a tensor of shape float[batch, height, width, depth].
+  TensorShape shape{1, stride_map.Size(FD_HEIGHT), stride_map.Size(FD_WIDTH),
+                    depth};
+  Tensor input_tensor(tensorflow::DT_FLOAT, shape);
+  // The flat() member gives a 1d array, with a data() member to get the data.
+  auto eigen_tensor = input_tensor.flat<float>();
+  memcpy(eigen_tensor.data(), input.f(0),
+         input.Width() * depth * sizeof(input.f(0)[0]));
+  // Add the tensor to the vector of inputs.
+  tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);
+
+  // Provide tensors giving the width and/or height of the image if they are
+  // required. Some tf ops require a separate tensor with knowledge of the
+  // size of the input as they cannot obtain it from the input tensor. This is
+  // usually true in the case of ops that process a batch of variable-sized
+  // objects.
+  if (!model_proto_.image_widths().empty()) {
+    TensorShape size_shape{1};
+    Tensor width_tensor(tensorflow::DT_INT32, size_shape);
+    auto eigen_wtensor = width_tensor.flat<tensorflow::int32>();
+    *eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
+    tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
+  }
+  if (!model_proto_.image_heights().empty()) {
+    TensorShape size_shape{1};
+    Tensor height_tensor(tensorflow::DT_INT32, size_shape);
+    auto eigen_htensor = height_tensor.flat<tensorflow::int32>();
+    *eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
+    tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
+  }
+  std::vector<string> target_layers = {model_proto_.output_layer()};
+  std::vector<Tensor> outputs;
+  Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
+  ASSERT_HOST(s.ok());
+  ASSERT_HOST(outputs.size() == 1);
+  const Tensor& output_tensor = outputs[0];
+  // Check the dimensions of the output.
+  ASSERT_HOST(output_tensor.shape().dims() == 2);
+  int output_dim0 = output_tensor.shape().dim_size(0);
+  int output_dim1 = output_tensor.shape().dim_size(1);
+  ASSERT_HOST(output_dim1 == output_shape_.depth());
+  output->Resize2d(false, output_dim0, output_dim1);
+  auto eigen_output = output_tensor.flat<float>();
+  memcpy(output->f(0), eigen_output.data(),
+         output_dim0 * output_dim1 * sizeof(output->f(0)[0]));
+}
+
+int TFNetwork::InitFromProto() {
+  spec_ = model_proto_.spec();
+  input_shape_.SetShape(
+      model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
+      std::max(0, model_proto_.x_size()), model_proto_.depth());
+  output_shape_.SetShape(model_proto_.batch_size(), 1, 0,
+                         model_proto_.num_classes());
+  output_shape_.set_loss_type(model_proto_.using_ctc() ? LT_CTC : LT_SOFTMAX);
+  ni_ = input_shape_.height();
+  no_ = output_shape_.depth();
+  // Initialize the session_ with the graph. Since we can't get the graph
+  // back from the session_, we have to keep the proto as well.
+  tensorflow::SessionOptions options;
+  session_.reset(NewSession(options));
+  Status s = session_->Create(model_proto_.graph());
+  if (s.ok()) return model_proto_.global_step();
+  tprintf("Session_->Create returned '%s'\n", s.error_message().c_str());
+  return 0;
+}
+
+} // namespace tesseract
+
+#endif // ifdef INCLUDE_TENSORFLOW
diff --git a/lstm/tfnetwork.h b/lstm/tfnetwork.h
new file mode 100644
index 0000000000..749706cd43
--- /dev/null
+++ b/lstm/tfnetwork.h
@@ -0,0 +1,91 @@
+///////////////////////////////////////////////////////////////////////
+// File: tfnetwork.h
+// Description: Encapsulation of an entire tensorflow graph as a
+// Tesseract Network.
+// Author: Ray Smith
+// Created: Fri Feb 26 09:35:29 PST 2016
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef TESSERACT_LSTM_TFNETWORK_H_
+#define TESSERACT_LSTM_TFNETWORK_H_
+
+#ifdef INCLUDE_TENSORFLOW
+
+#include <memory>
+#include <string>
+
+#include "network.h"
+#include "static_shape.h"
+#include "tfnetwork.proto.h"
+#include "third_party/tensorflow/core/framework/graph.pb.h"
+#include "third_party/tensorflow/core/public/session.h"
+
+namespace tesseract {
+
+class TFNetwork : public Network {
+ public:
+  explicit TFNetwork(const STRING& name);
+  virtual ~TFNetwork();
+
+  // Returns the required shape input to the network.
+  virtual StaticShape InputShape() const { return input_shape_; }
+  // Returns the shape output from the network given an input shape (which may
+  // be partially unknown ie zero).
+  virtual StaticShape OutputShape(const StaticShape& input_shape) const {
+    return output_shape_;
+  }
+
+  virtual STRING spec() const { return spec_.c_str(); }
+
+  // Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
+  // otherwise the global step of the serialized graph.
+  int InitFromProtoStr(const string& proto_str);
+  // The number of classes in this network should be equal to those in the
+  // recoder_ in LSTMRecognizer.
+  int num_classes() const { return output_shape_.depth(); }
+
+  // Writes to the given file. Returns false in case of error.
+  // Should be overridden by subclasses, but called by their Serialize.
+  virtual bool Serialize(TFile* fp) const;
+  // Reads from the given file. Returns false in case of error.
+  // If swap is true, assumes a big/little-endian swap is needed.
+  // Should be overridden by subclasses, but NOT called by their DeSerialize.
+  virtual bool DeSerialize(bool swap, TFile* fp);
+
+  // Runs forward propagation of activations on the input line.
+  // See Network for a detailed discussion of the arguments.
+  virtual void Forward(bool debug, const NetworkIO& input,
+                       const TransposedArray* input_transpose,
+                       NetworkScratch* scratch, NetworkIO* output);
+
+ private:
+  int InitFromProto();
+
+  // The original network definition for reference.
+  string spec_;
+  // Input tensor parameters.
+  StaticShape input_shape_;
+  // Output tensor parameters.
+  StaticShape output_shape_;
+  // The tensorflow graph is contained in here, inside the session.
+  std::unique_ptr<tensorflow::Session> session_;
+  // The serialized graph is also contained in here.
+  TFNetworkModel model_proto_;
+};
+
+} // namespace tesseract.
+
+#endif // ifdef INCLUDE_TENSORFLOW
+
+#endif // TESSERACT_LSTM_TFNETWORK_H_
diff --git a/lstm/tfnetwork.proto b/lstm/tfnetwork.proto
new file mode 100644
index 0000000000..0942fd2724
--- /dev/null
+++ b/lstm/tfnetwork.proto
@@ -0,0 +1,61 @@
+syntax = "proto3";
+
+package tesseract;
+
+// TODO(rays) How to make this usable both in Google and open source?
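+// Hypothetical C++ consumption sketch (TFNetwork is declared in
+// lstm/tfnetwork.h; the population steps are illustrative only):
+//   TFNetworkModel model;
+//   // ... fill in graph, global_step, sizes and tensor names ...
+//   string proto_str;
+//   model.SerializeToString(&proto_str);
+//   TFNetwork net("tf");
+//   int global_step = net.InitFromProtoStr(proto_str);  // 0 on failure.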
+import "third_party/tensorflow/core/framework/graph.proto"; + +// This proto is the interface between a python TF graph builder/trainer and +// the C++ world. The writer of this proto must provide fields as documented +// by the comments below. +// The graph must have a placeholder for NetworkIO, Widths and Heights. The +// following python code creates the appropriate placeholders: +// +// input_layer = tf.placeholder(tf.float32, +// shape=[batch_size, xsize, ysize, depth_dim], +// name='NetworkIO') +// widths = tf.placeholder(tf.int32, shape=[batch_size], name='Widths') +// heights = tf.placeholder(tf.int32, shape=[batch_size], name='Heights') +// # Flip x and y to the TF convention. +// input_layer = tf.transpose(input_layer, [0, 2, 1, 3]) +// +// The widths and heights will be set to indicate the post-scaling size of the +// input image(s). +// For now batch_size is ignored and set to 1. +// The graph should return a 2-dimensional float32 tensor called 'softmax' of +// shape [sequence_length, num_classes], where sequence_length is allowed to +// be variable, given by the tensor itself. +// TODO(rays) determine whether it is worth providing for batch_size >1 and if +// so, how. +message TFNetworkModel { + // The TF graph definition. Required. + tensorflow.GraphDef graph = 1; + // The training index. Required to be > 0. + int64 global_step = 2; + // The original network definition for reference. Optional + string spec = 3; + // Input tensor parameters. + // Values per pixel. Required to be 1 or 3. Inputs assumed to be float32. + int32 depth = 4; + // Image size. Required. Zero implies flexible sizes, fixed if non-zero. + // If x_size > 0, images will be cropped/padded to the given size, after + // any scaling required by the y_size. + // If y_size > 0, images will be scaled isotropically to the given height. + int32 x_size = 5; + int32 y_size = 6; + // Number of images in a batch. Optional. + int32 batch_size = 8; + // Output tensor parameters. + // Number of output classes. Required to match the depth of the softmax. + int32 num_classes = 9; + // True if this network needs CTC-like decoding, dropping duplicated labels. + // The decoder always drops the null character. + bool using_ctc = 10; + // Name of input image tensor. + string image_input = 11; + // Name of image height and width tensors. + string image_widths = 12; + string image_heights = 13; + // Name of output (softmax) tensor. + string output_layer = 14; +} diff --git a/lstm/weightmatrix.cpp b/lstm/weightmatrix.cpp new file mode 100644 index 0000000000..c596fcac59 --- /dev/null +++ b/lstm/weightmatrix.cpp @@ -0,0 +1,443 @@ +/////////////////////////////////////////////////////////////////////// +// File: weightmatrix.h +// Description: Hides distinction between float/int implementations. +// Author: Ray Smith +// Created: Tue Jun 17 11:46:20 PST 2014 +// +// (C) Copyright 2014, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+///////////////////////////////////////////////////////////////////////
+
+#include "weightmatrix.h"
+
+#undef NONX86_BUILD
+#if defined(ANDROID_BUILD) || defined(__PPC__) || defined(_ARCH_PPC64)
+#define NONX86_BUILD 1
+#endif
+
+#ifndef NONX86_BUILD
+#include <cpuid.h>
+#endif
+#include "dotproductavx.h"
+#include "dotproductsse.h"
+#include "statistc.h"
+#include "svutil.h"
+#include "tprintf.h"
+
+namespace tesseract {
+
+// Architecture detector. Add code here to detect any other architectures for
+// SIMD-based faster dot product functions. Intended to be a single static
+// object, but it does no real harm to have more than one.
+class SIMDDetect {
+ public:
+  SIMDDetect()
+      : arch_tested_(false), avx_available_(false), sse_available_(false) {}
+
+  // Returns true if AVX is available on this system.
+  bool IsAVXAvailable() {
+    if (!arch_tested_) TestArchitecture();
+    return avx_available_;
+  }
+  // Returns true if SSE4.1 is available on this system.
+  bool IsSSEAvailable() {
+    if (!arch_tested_) TestArchitecture();
+    return sse_available_;
+  }
+
+ private:
+  // Tests the architecture in a system-dependent way to detect AVX, SSE and
+  // any other available SIMD equipment.
+  void TestArchitecture() {
+    SVAutoLock lock(&arch_mutex_);
+    if (arch_tested_) return;
+#if defined(__linux__) && !defined(NONX86_BUILD)
+    if (__get_cpuid_max(0, NULL) >= 1) {
+      unsigned int eax, ebx, ecx, edx;
+      __get_cpuid(1, &eax, &ebx, &ecx, &edx);
+      sse_available_ = (ecx & 0x00080000) != 0;
+      avx_available_ = (ecx & 0x10000000) != 0;
+    }
+#endif
+    if (avx_available_) tprintf("Found AVX\n");
+    if (sse_available_) tprintf("Found SSE\n");
+    arch_tested_ = true;
+  }
+
+ private:
+  // Detect architecture in only a single thread.
+  SVMutex arch_mutex_;
+  // Flag set to true after TestArchitecture has been called.
+  bool arch_tested_;
+  // If true, then AVX has been detected.
+  bool avx_available_;
+  // If true, then SSE4.1 has been detected.
+  bool sse_available_;
+};
+
+static SIMDDetect detector;
+
+// Copies the whole input transposed, converted to double, into *this.
+void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
+  int width = input.dim1();
+  int num_features = input.dim2();
+  ResizeNoInit(num_features, width);
+  for (int t = 0; t < width; ++t) WriteStrided(t, input[t]);
+}
+
+// Sets up the network for training. Initializes weights using weights of
+// scale `range` picked according to the random number generator `randomizer`.
+int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
+                                   float weight_range, TRand* randomizer) {
+  int_mode_ = false;
+  use_ada_grad_ = ada_grad;
+  if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
+  wf_.Resize(no, ni, 0.0);
+  if (randomizer != NULL) {
+    for (int i = 0; i < no; ++i) {
+      for (int j = 0; j < ni; ++j) {
+        wf_[i][j] = randomizer->SignedRand(weight_range);
+      }
+    }
+  }
+  InitBackward();
+  return ni * no;
+}
+
+// Converts a float network to an int network. Each set of input weights that
+// corresponds to a single output weight is converted independently:
+// Compute the max absolute value of the weight set.
+// Scale so the max absolute value becomes MAX_INT8.
+// Round to integer.
+// Store a multiplicative scale factor (as a double) that will reproduce
+// the original value, subject to rounding errors.
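+// Worked example (illustrative numbers): a weight set {0.5, -2.0} has
+// max_abs = 2.0, so scale = 2.0 / 127 ~= 0.01575, and the stored inT8
+// values are round(0.5 / scale) = 32 and round(-2.0 / scale) = -127.
+// At run time, 32 * 0.01575 ~= 0.504 reproduces the original 0.5 to within
+// rounding error.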
+void WeightMatrix::ConvertToInt() {
+  wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
+  scales_.init_to_size(wi_.dim1(), 0.0);
+  int dim2 = wi_.dim2();
+  for (int t = 0; t < wi_.dim1(); ++t) {
+    double* f_line = wf_[t];
+    inT8* i_line = wi_[t];
+    double max_abs = 0.0;
+    for (int f = 0; f < dim2; ++f) {
+      double abs_val = fabs(f_line[f]);
+      if (abs_val > max_abs) max_abs = abs_val;
+    }
+    double scale = max_abs / MAX_INT8;
+    scales_[t] = scale;
+    if (scale == 0.0) scale = 1.0;
+    for (int f = 0; f < dim2; ++f) {
+      i_line[f] = IntCastRounded(f_line[f] / scale);
+    }
+  }
+  wf_.Resize(1, 1, 0.0);
+  int_mode_ = true;
+}
+
+// Allocates any needed memory for running Backward, and zeroes the deltas,
+// thus eliminating any existing momentum.
+void WeightMatrix::InitBackward() {
+  int no = int_mode_ ? wi_.dim1() : wf_.dim1();
+  int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
+  dw_.Resize(no, ni, 0.0);
+  updates_.Resize(no, ni, 0.0);
+  wf_t_.Transpose(wf_);
+}
+
+// Flag on mode to indicate that this weightmatrix uses inT8.
+const int kInt8Flag = 1;
+// Flag on mode to indicate that this weightmatrix uses ada grad.
+const int kAdaGradFlag = 4;
+// Flag on mode to indicate that this weightmatrix uses double. Set
+// independently of kInt8Flag as even in int mode the scales can
+// be float or double.
+const int kDoubleFlag = 128;
+
+// Writes to the given file. Returns false in case of error.
+bool WeightMatrix::Serialize(bool training, TFile* fp) const {
+  // For backward compatibility, add kDoubleFlag to mode to indicate the
+  // doubles format, without errs, so we can detect and read old format
+  // weight matrices.
+  uinT8 mode = (int_mode_ ? kInt8Flag : 0) |
+               (use_ada_grad_ ? kAdaGradFlag : 0) | kDoubleFlag;
+  if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
+  if (int_mode_) {
+    if (!wi_.Serialize(fp)) return false;
+    if (!scales_.Serialize(fp)) return false;
+  } else {
+    if (!wf_.Serialize(fp)) return false;
+    if (training && !updates_.Serialize(fp)) return false;
+    if (training && use_ada_grad_ && !dw_sq_sum_.Serialize(fp)) return false;
+  }
+  return true;
+}
+
+// Reads from the given file. Returns false in case of error.
+// If swap is true, assumes a big/little-endian swap is needed.
+bool WeightMatrix::DeSerialize(bool training, bool swap, TFile* fp) {
+  uinT8 mode = 0;
+  if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
+  int_mode_ = (mode & kInt8Flag) != 0;
+  use_ada_grad_ = (mode & kAdaGradFlag) != 0;
+  if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, swap, fp);
+  if (int_mode_) {
+    if (!wi_.DeSerialize(swap, fp)) return false;
+    if (!scales_.DeSerialize(swap, fp)) return false;
+  } else {
+    if (!wf_.DeSerialize(swap, fp)) return false;
+    if (training) {
+      InitBackward();
+      if (!updates_.DeSerialize(swap, fp)) return false;
+      if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(swap, fp)) return false;
+    }
+  }
+  return true;
+}
+
+// As DeSerialize, but reads an old (float) format WeightMatrix for
+// backward compatibility.
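+// Example mode bytes (from the flags above): a float matrix saved by this
+// code has mode = kDoubleFlag = 128; an int8 matrix has mode = kInt8Flag |
+// kDoubleFlag = 129. Old float-format files lack kDoubleFlag, which is how
+// DeSerialize detects them and dispatches here.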
+bool WeightMatrix::DeSerializeOld(bool training, bool swap, TFile* fp) {
+  GENERIC_2D_ARRAY<float> float_array;
+  if (int_mode_) {
+    if (!wi_.DeSerialize(swap, fp)) return false;
+    GenericVector<float> old_scales;
+    if (!old_scales.DeSerialize(swap, fp)) return false;
+    scales_.init_to_size(old_scales.size(), 0.0);
+    for (int i = 0; i < old_scales.size(); ++i) scales_[i] = old_scales[i];
+  } else {
+    if (!float_array.DeSerialize(swap, fp)) return false;
+    FloatToDouble(float_array, &wf_);
+  }
+  if (training) {
+    InitBackward();
+    if (!float_array.DeSerialize(swap, fp)) return false;
+    FloatToDouble(float_array, &updates_);
+    // Errs was only used in int training, which is now dead.
+    if (!float_array.DeSerialize(swap, fp)) return false;
+  }
+  return true;
+}
+
+// Computes matrix.vector v = Wu.
+// u is of size W.dim2() - 1 and the output v is of size W.dim1().
+// u is imagined to have an extra element at the end with value 1, to
+// implement the bias, but it doesn't actually have it.
+// Asserts that the call matches what we have.
+void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
+  ASSERT_HOST(!int_mode_);
+  MatrixDotVectorInternal(wf_, true, false, u, v);
+}
+
+void WeightMatrix::MatrixDotVector(const inT8* u, double* v) const {
+  ASSERT_HOST(int_mode_);
+  int num_out = wi_.dim1();
+  int num_in = wi_.dim2() - 1;
+  for (int i = 0; i < num_out; ++i) {
+    const inT8* Wi = wi_[i];
+    int total = 0;
+    if (detector.IsSSEAvailable()) {
+      total = IntDotProductSSE(u, Wi, num_in);
+    } else {
+      for (int j = 0; j < num_in; ++j) total += Wi[j] * u[j];
+    }
+    // Add in the bias and correct for integer values.
+    v[i] = (static_cast<double>(total) / MAX_INT8 + Wi[num_in]) * scales_[i];
+  }
+}
+
+// MatrixDotVector for peep weights, MultiplyAccumulate adds the
+// component-wise products of *this[0] and v to inout.
+void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
+  ASSERT_HOST(!int_mode_);
+  ASSERT_HOST(wf_.dim1() == 1);
+  int n = wf_.dim2();
+  const double* u = wf_[0];
+  for (int i = 0; i < n; ++i) {
+    inout[i] += u[i] * v[i];
+  }
+}
+
+// Computes vector.matrix v = uW.
+// u is of size W.dim1() and the output v is of size W.dim2() - 1.
+// The last result is discarded, as v is assumed to have an imaginary
+// last value of 1, as with MatrixDotVector.
+void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
+  ASSERT_HOST(!int_mode_);
+  MatrixDotVectorInternal(wf_t_, false, true, u, v);
+}
+
+// Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements from
+// u and v. In terms of the neural network, u is the gradients and v is the
+// inputs.
+// Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
+// Runs parallel if requested. Note that u and v must be transposed.
+void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
+                                      const TransposedArray& v,
+                                      bool in_parallel) {
+  ASSERT_HOST(!int_mode_);
+  int num_outputs = dw_.dim1();
+  ASSERT_HOST(u.dim1() == num_outputs);
+  ASSERT_HOST(u.dim2() == v.dim2());
+  int num_inputs = dw_.dim2() - 1;
+  int num_samples = u.dim2();
+  // v is missing the last element in dim1.
+  ASSERT_HOST(v.dim1() == num_inputs);
+#ifdef _OPENMP
+#pragma omp parallel for num_threads(4) if (in_parallel)
+#endif
+  for (int i = 0; i < num_outputs; ++i) {
+    double* dwi = dw_[i];
+    const double* ui = u[i];
+    for (int j = 0; j < num_inputs; ++j) {
+      dwi[j] = DotProduct(ui, v[j], num_samples);
+    }
+    // The last element of v is missing, presumed 1.0f.
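+    // Since that missing element is 1.0, the bias delta dwi[num_inputs] is
+    // just the sum of the output gradients ui over all samples, computed
+    // below.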
+    double total = 0.0;
+    for (int k = 0; k < num_samples; ++k) total += ui[k];
+    dwi[num_inputs] = total;
+  }
+}
+
+// Updates the weights using the given learning rate and momentum.
+// num_samples is the quotient to be used in the adagrad computation iff
+// use_ada_grad_ is true.
+void WeightMatrix::Update(double learning_rate, double momentum,
+                          int num_samples) {
+  ASSERT_HOST(!int_mode_);
+  if (use_ada_grad_ && num_samples > 0) {
+    dw_sq_sum_.SumSquares(dw_);
+    dw_.AdaGradScaling(dw_sq_sum_, num_samples);
+  }
+  dw_ *= learning_rate;
+  updates_ += dw_;
+  if (momentum > 0.0) wf_ += updates_;
+  if (momentum >= 0.0) updates_ *= momentum;
+  wf_t_.Transpose(wf_);
+}
+
+// Adds the dw_ in other to the dw_ in *this.
+void WeightMatrix::AddDeltas(const WeightMatrix& other) {
+  ASSERT_HOST(dw_.dim1() == other.dw_.dim1());
+  ASSERT_HOST(dw_.dim2() == other.dw_.dim2());
+  dw_ += other.dw_;
+}
+
+// Sums the products of weight updates in *this and other, splitting into
+// positive (same direction) in *same and negative (different direction) in
+// *changed.
+void WeightMatrix::CountAlternators(const WeightMatrix& other, double* same,
+                                    double* changed) const {
+  int num_outputs = updates_.dim1();
+  int num_inputs = updates_.dim2();
+  ASSERT_HOST(num_outputs == other.updates_.dim1());
+  ASSERT_HOST(num_inputs == other.updates_.dim2());
+  for (int i = 0; i < num_outputs; ++i) {
+    const double* this_i = updates_[i];
+    const double* other_i = other.updates_[i];
+    for (int j = 0; j < num_inputs; ++j) {
+      double product = this_i[j] * other_i[j];
+      if (product < 0.0)
+        *changed -= product;
+      else
+        *same += product;
+    }
+  }
+}
+
+// Helper computes an integer histogram bucket for a weight and adds it
+// to the histogram.
+const int kHistogramBuckets = 16;
+static void HistogramWeight(double weight, STATS* histogram) {
+  int bucket = kHistogramBuckets - 1;
+  if (weight != 0.0) {
+    double logval = -log2(fabs(weight));
+    bucket = ClipToRange(IntCastRounded(logval), 0, kHistogramBuckets - 1);
+  }
+  histogram->add(bucket, 1);
+}
+
+void WeightMatrix::Debug2D(const char* msg) {
+  STATS histogram(0, kHistogramBuckets);
+  if (int_mode_) {
+    for (int i = 0; i < wi_.dim1(); ++i) {
+      for (int j = 0; j < wi_.dim2(); ++j) {
+        HistogramWeight(wi_[i][j] * scales_[i], &histogram);
+      }
+    }
+  } else {
+    for (int i = 0; i < wf_.dim1(); ++i) {
+      for (int j = 0; j < wf_.dim2(); ++j) {
+        HistogramWeight(wf_[i][j], &histogram);
+      }
+    }
+  }
+  tprintf("%s\n", msg);
+  histogram.print();
+}
+
+// Computes and returns the dot product of the two n-vectors u and v.
+/* static */
+double WeightMatrix::DotProduct(const double* u, const double* v, int n) {
+  // Note: because the order of addition is different among the 3 DotProduct
+  // functions, the results can (and do) vary slightly (although they agree
+  // to within about 4e-15). This produces different results when running
+  // training, despite all random inputs being precisely equal.
+  // To get consistent results, use just one of these DotProduct functions.
+  // On a test multi-layer network, serial is 57% slower than sse, and avx
+  // is about 8% faster than sse. This suggests that the time is memory
+  // bandwidth constrained and could benefit from holding the reused vector
+  // in AVX registers.
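+  // Dispatch order below: prefer AVX, then SSE4.1, then the serial loop.
+  // All three compute the same u[0]*v[0] + ... + u[n-1]*v[n-1], differing
+  // only in summation order (hence the tiny differences noted above).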
+  if (detector.IsAVXAvailable()) return DotProductAVX(u, v, n);
+  if (detector.IsSSEAvailable()) return DotProductSSE(u, v, n);
+  double total = 0.0;
+  for (int k = 0; k < n; ++k) total += u[k] * v[k];
+  return total;
+}
+
+// Utility function converts an array of float to the corresponding array
+// of double.
+/* static */
+void WeightMatrix::FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
+                                 GENERIC_2D_ARRAY<double>* wd) {
+  int dim1 = wf.dim1();
+  int dim2 = wf.dim2();
+  wd->ResizeNoInit(dim1, dim2);
+  for (int i = 0; i < dim1; ++i) {
+    const float* wfi = wf[i];
+    double* wdi = (*wd)[i];
+    for (int j = 0; j < dim2; ++j) wdi[j] = static_cast<double>(wfi[j]);
+  }
+}
+
+// Computes matrix.vector v = Wu.
+// u is of size W.dim2() - add_bias_fwd and the output v is of size
+// W.dim1() - skip_bias_back.
+// If add_bias_fwd, u is imagined to have an extra element at the end with
+// value 1, to implement the bias weight.
+// If skip_bias_back, we are actually performing the backwards product on a
+// transposed matrix, so we need to drop the v output corresponding to the last
+// element in dim1.
+void WeightMatrix::MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
+                                           bool add_bias_fwd,
+                                           bool skip_bias_back, const double* u,
+                                           double* v) {
+  int num_results = w.dim1() - skip_bias_back;
+  int extent = w.dim2() - add_bias_fwd;
+  for (int i = 0; i < num_results; ++i) {
+    const double* wi = w[i];
+    double total = DotProduct(wi, u, extent);
+    if (add_bias_fwd) total += wi[extent];  // The bias value.
+    v[i] = total;
+  }
+}
+
+} // namespace tesseract.
+
+#undef NONX86_BUILD
diff --git a/lstm/weightmatrix.h b/lstm/weightmatrix.h
new file mode 100644
index 0000000000..bf533ac043
--- /dev/null
+++ b/lstm/weightmatrix.h
@@ -0,0 +1,183 @@
+///////////////////////////////////////////////////////////////////////
+// File: weightmatrix.h
+// Description: Hides distinction between float/int implementations.
+// Author: Ray Smith
+// Created: Tue Jun 17 09:05:39 PST 2014
+//
+// (C) Copyright 2014, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
+#define TESSERACT_LSTM_WEIGHTMATRIX_H_
+
+#include "genericvector.h"
+#include "matrix.h"
+#include "tprintf.h"
+
+namespace tesseract {
+
+// Convenience instantiation of GENERIC_2D_ARRAY<double> with additional
+// operations to write a strided vector, so the transposed form of the input
+// is memory-contiguous.
+class TransposedArray : public GENERIC_2D_ARRAY<double> {
+ public:
+  // Copies the whole input transposed, converted to double, into *this.
+  void Transpose(const GENERIC_2D_ARRAY<double>& input);
+  // Writes a vector of data representing a timestep (gradients or sources).
+  // The data is assumed to be of size1 in size (the strided dimension).
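+  // Example (illustrative): if the input is a width x num_features array,
+  // then after Transpose(input), element (f, t) of *this equals input(t, f),
+  // so row f is the memory-contiguous time series of feature f.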
+  void WriteStrided(int t, const float* data) {
+    int size1 = dim1();
+    for (int i = 0; i < size1; ++i) put(i, t, data[i]);
+  }
+  void WriteStrided(int t, const double* data) {
+    int size1 = dim1();
+    for (int i = 0; i < size1; ++i) put(i, t, data[i]);
+  }
+  // Prints the first and last num elements of the un-transposed array.
+  void PrintUnTransposed(int num) {
+    int num_features = dim1();
+    int width = dim2();
+    for (int y = 0; y < num_features; ++y) {
+      for (int t = 0; t < width; ++t) {
+        if (num == 0 || t < num || t + num >= width) {
+          tprintf(" %g", (*this)(y, t));
+        }
+      }
+      tprintf("\n");
+    }
+  }
+};  // class TransposedArray
+
+// Generic weight matrix for network layers. Can store the matrix as either
+// an array of floats or inT8. Provides functions to compute the forward and
+// backward steps with the matrix and updates to the weights.
+class WeightMatrix {
+ public:
+  WeightMatrix() : int_mode_(false), use_ada_grad_(false) {}
+  // Sets up the network for training. Initializes weights using weights of
+  // scale `range` picked according to the random number generator
+  // `randomizer`.
+  // Note the order is outputs, inputs, as this is the order of indices to
+  // the matrix, so the adjacent elements are multiplied by the input during
+  // a forward operation.
+  int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range,
+                       TRand* randomizer);
+
+  // Converts a float network to an int network. Each set of input weights that
+  // corresponds to a single output weight is converted independently:
+  // Compute the max absolute value of the weight set.
+  // Scale so the max absolute value becomes MAX_INT8.
+  // Round to integer.
+  // Store a multiplicative scale factor (as a double) that will reproduce
+  // the original value, subject to rounding errors.
+  void ConvertToInt();
+
+  // Accessors.
+  bool is_int_mode() const {
+    return int_mode_;
+  }
+  int NumOutputs() const { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
+  // Provides one set of weights. Only used by peep weight maxpool.
+  const double* GetWeights(int index) const { return wf_[index]; }
+  // Provides access to the deltas (dw_).
+  double GetDW(int i, int j) const { return dw_(i, j); }
+
+  // Allocates any needed memory for running Backward, and zeroes the deltas,
+  // thus eliminating any existing momentum.
+  void InitBackward();
+
+  // Writes to the given file. Returns false in case of error.
+  bool Serialize(bool training, TFile* fp) const;
+  // Reads from the given file. Returns false in case of error.
+  // If swap is true, assumes a big/little-endian swap is needed.
+  bool DeSerialize(bool training, bool swap, TFile* fp);
+  // As DeSerialize, but reads an old (float) format WeightMatrix for
+  // backward compatibility.
+  bool DeSerializeOld(bool training, bool swap, TFile* fp);
+
+  // Computes matrix.vector v = Wu.
+  // u is of size W.dim2() - 1 and the output v is of size W.dim1().
+  // u is imagined to have an extra element at the end with value 1, to
+  // implement the bias, but it doesn't actually have it.
+  // Asserts that the call matches what we have.
+  void MatrixDotVector(const double* u, double* v) const;
+  void MatrixDotVector(const inT8* u, double* v) const;
+  // MatrixDotVector for peep weights, MultiplyAccumulate adds the
+  // component-wise products of *this[0] and v to inout.
+  void MultiplyAccumulate(const double* v, double* inout);
+  // Computes vector.matrix v = uW.
+  // u is of size W.dim1() and the output v is of size W.dim2() - 1.
+  // The last result is discarded, as v is assumed to have an imaginary
+  // last value of 1, as with MatrixDotVector.
+  void VectorDotMatrix(const double* u, double* v) const;
+  // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
+  // from u and v. In terms of the neural network, u is the gradients and
+  // v is the inputs.
+  // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
+  // Runs parallel if requested. Note that inputs must be transposed.
+  void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
+                          bool parallel);
+  // Updates the weights using the given learning rate and momentum.
+  // num_samples is the quotient to be used in the adagrad computation iff
+  // use_ada_grad_ is true.
+  void Update(double learning_rate, double momentum, int num_samples);
+  // Adds the dw_ in other to the dw_ in *this.
+  void AddDeltas(const WeightMatrix& other);
+  // Sums the products of weight updates in *this and other, splitting into
+  // positive (same direction) in *same and negative (different direction) in
+  // *changed.
+  void CountAlternators(const WeightMatrix& other, double* same,
+                        double* changed) const;
+
+  void Debug2D(const char* msg);
+
+  // Computes and returns the dot product of the two n-vectors u and v.
+  static double DotProduct(const double* u, const double* v, int n);
+  // Utility function converts an array of float to the corresponding array
+  // of double.
+  static void FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
+                            GENERIC_2D_ARRAY<double>* wd);
+
+ private:
+  // Computes matrix.vector v = Wu.
+  // u is of size w.dim2() - add_bias_fwd and the output v is of size
+  // w.dim1() - skip_bias_back.
+  // If add_bias_fwd, an extra element at the end of w[i] is the bias weight
+  // and is added to v[i]. If skip_bias_back, the v output corresponding to
+  // the last element in dim1 is dropped.
+  static void MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
+                                      bool add_bias_fwd, bool skip_bias_back,
+                                      const double* u, double* v);
+
+ private:
+  // Choice between float and 8 bit int implementations.
+  GENERIC_2D_ARRAY<double> wf_;
+  GENERIC_2D_ARRAY<inT8> wi_;
+  // Transposed copy of wf_, used only for Backward, and set with each Update.
+  TransposedArray wf_t_;
+  // Which of wf_ and wi_ are we actually using.
+  bool int_mode_;
+  // True if we are running adagrad in this weight matrix.
+  bool use_ada_grad_;
+  // If we are using wi_, then scales_ is a factor to restore the row product
+  // with a vector to the correct range.
+  GenericVector<double> scales_;
+  // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
+  // amount to be added to wf_/wi_.
+  GENERIC_2D_ARRAY<double> dw_;
+  GENERIC_2D_ARRAY<double> updates_;
+  // Iff use_ada_grad_, the sum of squares of dw_. The number of samples is
+  // given to Update(). Serialized iff use_ada_grad_.
+  GENERIC_2D_ARRAY<double> dw_sq_sum_;
+};
+
+} // namespace tesseract.
+
+#endif // TESSERACT_LSTM_WEIGHTMATRIX_H_
diff --git a/textord/baselinedetect.cpp b/textord/baselinedetect.cpp
index a2b0173949..9bbd999e15 100644
--- a/textord/baselinedetect.cpp
+++ b/textord/baselinedetect.cpp
@@ -850,7 +850,8 @@ void BaselineDetect::ComputeBaselineSplinesAndXheights(const ICOORD& page_tr,
   Pix* pix_spline = pix_debug_ ?
pixConvertTo32(pix_debug_) : NULL;
   for (int i = 0; i < blocks_.size(); ++i) {
     BaselineBlock* bl_block = blocks_[i];
-    bl_block->PrepareForSplineFitting(page_tr, remove_noise);
+    if (enable_splines)
+      bl_block->PrepareForSplineFitting(page_tr, remove_noise);
     bl_block->FitBaselineSplines(enable_splines, show_final_rows, textord);
     if (pix_spline) {
       bl_block->DrawPixSpline(pix_spline);
diff --git a/textord/colpartition.cpp b/textord/colpartition.cpp
index e8e8502e86..0d0b4ca39e 100644
--- a/textord/colpartition.cpp
+++ b/textord/colpartition.cpp
@@ -1632,6 +1632,10 @@ TO_BLOCK* ColPartition::MakeBlock(const ICOORD& bleft, const ICOORD& tright,
                                   ColPartition_LIST* used_parts) {
   if (block_parts->empty())
     return NULL;  // Nothing to do.
+  // If the block_parts are not in reading order, then it will make an invalid
+  // block polygon and bounding_box, so sort by bounding box now just to make
+  // sure.
+  block_parts->sort(&ColPartition::SortByBBox);
   ColPartition_IT it(block_parts);
   ColPartition* part = it.data();
   PolyBlockType type = part->type();
diff --git a/textord/colpartition.h b/textord/colpartition.h
index 5c941cce15..7fcbc0004e 100644
--- a/textord/colpartition.h
+++ b/textord/colpartition.h
@@ -704,6 +704,25 @@ class ColPartition : public ELIST2_LINK {
   // doing a SideSearch when you want things in the same page column.
   bool IsInSameColumnAs(const ColPartition& part) const;
 
+  // Sort function to sort by bounding box.
+  static int SortByBBox(const void* p1, const void* p2) {
+    const ColPartition* part1 =
+        *reinterpret_cast<const ColPartition* const*>(p1);
+    const ColPartition* part2 =
+        *reinterpret_cast<const ColPartition* const*>(p2);
+    int mid_y1 = part1->bounding_box_.y_middle();
+    int mid_y2 = part2->bounding_box_.y_middle();
+    if ((part2->bounding_box_.bottom() <= mid_y1 &&
+         mid_y1 <= part2->bounding_box_.top()) ||
+        (part1->bounding_box_.bottom() <= mid_y2 &&
+         mid_y2 <= part1->bounding_box_.top())) {
+      // Sort by increasing x.
+      return part1->bounding_box_.x_middle() - part2->bounding_box_.x_middle();
+    }
+    // Sort by decreasing y.
+    return mid_y2 - mid_y1;
+  }
+
   // Sets the column bounds. Primarily used in testing.
   void set_first_column(int column) {
     first_column_ = column;
diff --git a/textord/tordmain.cpp b/textord/tordmain.cpp
index 0c433a1f27..0eaf843ec3 100644
--- a/textord/tordmain.cpp
+++ b/textord/tordmain.cpp
@@ -251,6 +251,7 @@ void Textord::filter_blobs(ICOORD page_tr,  // top right
                       &block->noise_blobs,
                       &block->small_blobs,
                       &block->large_blobs);
+    if (block->line_size == 0) block->line_size = 1;
     block->line_spacing = block->line_size *
          (tesseract::CCStruct::kDescenderFraction +
           tesseract::CCStruct::kXHeightFraction +
@@ -769,6 +770,7 @@ void Textord::TransferDiacriticsToBlockGroups(BLOBNBOX_LIST* diacritic_blobs,
   PointerVector<WERD> word_ptrs;
   for (int g = 0; g < groups.size(); ++g) {
     const BlockGroup* group = groups[g];
+    if (group->bounding_box.null_box()) continue;
     WordGrid word_grid(group->min_xheight, group->bounding_box.botleft(),
                        group->bounding_box.topright());
     for (int b = 0; b < group->blocks.size(); ++b) {
diff --git a/textord/tospace.cpp b/textord/tospace.cpp
index afec42f484..ef1f037b23 100644
--- a/textord/tospace.cpp
+++ b/textord/tospace.cpp
@@ -1323,9 +1323,10 @@ BOOL8 Textord::make_a_word_break(
   we may need to set PARTICULAR spaces to fuzzy or not. The values will ONLY
   be used if the function returns TRUE - ie the word is to be broken.
  */
-  blanks = (uinT8) (current_gap / row->space_size);
-  if (blanks < 1)
-    blanks = 1;
+  int num_blanks = current_gap;
+  if (row->space_size > 1.0f)
+    num_blanks = IntCastRounded(current_gap / row->space_size);
+  blanks = static_cast<uinT8>(ClipToRange(num_blanks, 1, MAX_UINT8));
   fuzzy_sp = FALSE;
   fuzzy_non = FALSE;
   /*
diff --git a/training/Makefile.am b/training/Makefile.am
index fe3d85bcdc..da505be724 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -3,6 +3,7 @@ AM_CPPFLAGS += \
     -DUSE_STD_NAMESPACE -DPANGO_ENABLE_ENGINE\
     -I$(top_srcdir)/ccmain -I$(top_srcdir)/api \
     -I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \
+    -I$(top_srcdir)/lstm -I$(top_srcdir)/arch \
     -I$(top_srcdir)/viewer \
     -I$(top_srcdir)/textord -I$(top_srcdir)/dict \
    -I$(top_srcdir)/classify -I$(top_srcdir)/display \
@@ -45,7 +46,7 @@ libtesseract_tessopt_la_SOURCES = \
     tessopt.cpp
 
 bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \
-               dawg2wordlist mftraining set_unicharset_properties shapeclustering \
+               dawg2wordlist lstmtraining mftraining set_unicharset_properties shapeclustering \
                text2image unicharset_extractor wordlist2dawg
 
 ambiguous_words_SOURCES = ambiguous_words.cpp
@@ -58,6 +59,9 @@ ambiguous_words_LDADD += \
                ../textord/libtesseract_textord.la \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -82,6 +86,9 @@ classifier_tester_LDADD += \
                ../textord/libtesseract_textord.la \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -115,6 +122,9 @@ cntraining_LDADD += \
                ../textord/libtesseract_textord.la \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -136,6 +146,9 @@ if USING_MULTIPLELIBS
 dawg2wordlist_LDADD += \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -150,6 +163,33 @@ dawg2wordlist_LDADD += \
                ../api/libtesseract.la
 endif
 
+lstmtraining_SOURCES = lstmtraining.cpp
+#lstmtraining_LDFLAGS = -static
+lstmtraining_LDADD = \
+               libtesseract_training.la \
+               libtesseract_tessopt.la \
+               $(libicu)
+if USING_MULTIPLELIBS
+lstmtraining_LDADD += \
+               ../textord/libtesseract_textord.la \
+               ../classify/libtesseract_classify.la \
+               ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
+               ../ccstruct/libtesseract_ccstruct.la \
+               ../cutil/libtesseract_cutil.la \
+               ../viewer/libtesseract_viewer.la \
+               ../ccmain/libtesseract_main.la \
+               ../cube/libtesseract_cube.la \
+               ../neural_networks/runtime/libtesseract_neural.la \
+               ../wordrec/libtesseract_wordrec.la \
+               ../ccutil/libtesseract_ccutil.la
+else
+lstmtraining_LDADD += \
+               ../api/libtesseract.la
+endif
+
 mftraining_SOURCES = mftraining.cpp mergenf.cpp
 #mftraining_LDFLAGS = -static
 mftraining_LDADD = \
@@ -160,6 +200,9 @@ mftraining_LDADD += \
                ../textord/libtesseract_textord.la \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -185,6 +228,9 @@ set_unicharset_properties_LDADD += \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
                ../ccstruct/libtesseract_ccstruct.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
                ../ccmain/libtesseract_main.la \
@@ -207,6 +253,9 @@ shapeclustering_LDADD += \
                ../textord/libtesseract_textord.la \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -230,6 +279,9 @@ text2image_LDADD += \
                ../textord/libtesseract_textord.la \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
@@ -266,6 +318,9 @@ if USING_MULTIPLELIBS
 wordlist2dawg_LDADD += \
                ../classify/libtesseract_classify.la \
                ../dict/libtesseract_dict.la \
+               ../arch/libtesseract_avx.la \
+               ../arch/libtesseract_sse.la \
+               ../lstm/libtesseract_lstm.la \
                ../ccstruct/libtesseract_ccstruct.la \
                ../cutil/libtesseract_cutil.la \
                ../viewer/libtesseract_viewer.la \
diff --git a/training/degradeimage.cpp b/training/degradeimage.cpp
index f9c3cfb048..333f3703dc 100644
--- a/training/degradeimage.cpp
+++ b/training/degradeimage.cpp
@@ -22,10 +22,36 @@
 #include <stdlib.h>
 #include "allheaders.h"  // from leptonica
+#include "genericvector.h"
 #include "helpers.h"  // For TRand.
+#include "rect.h"
 
 namespace tesseract {
 
+// A randomized perspective distortion can be applied to synthetic input.
+// The perspective distortion comes from leptonica, which uses 2 sets of 4
+// corners to determine the distortion. There are random values for each of
+// the x numbers x0..x3 and y0..y3, except for x2 and x3 which are instead
+// defined in terms of a single shear value. This reduces the degrees of
+// freedom enough to make the distortion more realistic than it would otherwise
+// be if all 8 coordinates could move independently.
+// One additional factor is used for the color of the pixels that don't exist
+// in the source image.
+// Name for each of the randomizing factors.
+enum FactorNames {
+  FN_INCOLOR,
+  FN_Y0,
+  FN_Y1,
+  FN_Y2,
+  FN_Y3,
+  FN_X0,
+  FN_X1,
+  FN_SHEAR,
+  // x2 = x1 - shear
+  // x3 = x0 + shear
+  FN_NUM_FACTORS
+};
+
 // Rotation is +/- kRotationRange radians.
 const float kRotationRange = 0.02f;
 // Number of grey levels to shift by for each exposure step.
@@ -144,4 +170,141 @@ Pix* DegradeImage(Pix* input, int exposure, TRand* randomizer,
   return input;
 }
 
+// Creates and returns a Pix distorted by various means according to the bool
+// flags. If boxes is not NULL, the boxes are resized/positioned according to
+// any spatial distortion and also by the integer reduction factor
+// box_reduction so they will match what the network will output.
+// Returns NULL on error.
The returned Pix must be pixDestroyed. +Pix* PrepareDistortedPix(const Pix* pix, bool perspective, bool invert, + bool white_noise, bool smooth_noise, bool blur, + int box_reduction, TRand* randomizer, + GenericVector<TBOX>* boxes) { + Pix* distorted = pixCopy(NULL, const_cast<Pix*>(pix)); + // Things to do to synthetic training data. + if (invert && randomizer->SignedRand(1.0) < 0) + pixInvert(distorted, distorted); + if ((white_noise || smooth_noise) && randomizer->SignedRand(1.0) > 0.0) { + // TODO(rays) Cook noise in a more thread-safe manner than rand(). + // Attempt to make the sequences reproducible. + srand(randomizer->IntRand()); + Pix* pixn = pixAddGaussianNoise(distorted, 8.0); + pixDestroy(&distorted); + if (smooth_noise) { + distorted = pixBlockconv(pixn, 1, 1); + pixDestroy(&pixn); + } else { + distorted = pixn; + } + } + if (blur && randomizer->SignedRand(1.0) > 0.0) { + Pix* blurred = pixBlockconv(distorted, 1, 1); + pixDestroy(&distorted); + distorted = blurred; + } + if (perspective) + GeneratePerspectiveDistortion(0, 0, randomizer, &distorted, boxes); + if (boxes != NULL) { + for (int b = 0; b < boxes->size(); ++b) { + (*boxes)[b].scale(1.0f / box_reduction); + if ((*boxes)[b].width() <= 0) + (*boxes)[b].set_right((*boxes)[b].left() + 1); + } + } + return distorted; +} + +// Distorts anything that has a non-null pointer with the same pseudo-random +// perspective distortion. Width and height only need to be set if there +// is no pix. If there is a pix, then they will be taken from there. +void GeneratePerspectiveDistortion(int width, int height, TRand* randomizer, + Pix** pix, GenericVector<TBOX>* boxes) { + if (pix != NULL && *pix != NULL) { + width = pixGetWidth(*pix); + height = pixGetHeight(*pix); + } + float* im_coeffs = NULL; + float* box_coeffs = NULL; + l_int32 incolor = + ProjectiveCoeffs(width, height, randomizer, &im_coeffs, &box_coeffs); + if (pix != NULL && *pix != NULL) { + // Transform the image. + Pix* transformed = pixProjective(*pix, im_coeffs, incolor); + if (transformed == NULL) { + tprintf("Projective transformation failed!!\n"); + // Free the coefficients to avoid leaking them on this early return. + free(im_coeffs); + free(box_coeffs); + return; + } + pixDestroy(pix); + *pix = transformed; + } + if (boxes != NULL) { + // Transform the boxes. + for (int b = 0; b < boxes->size(); ++b) { + int x1, y1, x2, y2; + const TBOX& box = (*boxes)[b]; + projectiveXformSampledPt(box_coeffs, box.left(), height - box.top(), &x1, + &y1); + projectiveXformSampledPt(box_coeffs, box.right(), height - box.bottom(), + &x2, &y2); + TBOX new_box1(x1, height - y2, x2, height - y1); + projectiveXformSampledPt(box_coeffs, box.left(), height - box.bottom(), + &x1, &y1); + projectiveXformSampledPt(box_coeffs, box.right(), height - box.top(), &x2, + &y2); + TBOX new_box2(x1, height - y1, x2, height - y2); + (*boxes)[b] = new_box1.bounding_union(new_box2); + } + } + free(im_coeffs); + free(box_coeffs); +} + +// Computes the coefficients of a randomized projective transformation. +// The image transform requires backward transformation coefficients, and the +// box transform the forward coefficients. +// Returns the incolor arg to pixProjective. +int ProjectiveCoeffs(int width, int height, TRand* randomizer, + float** im_coeffs, float** box_coeffs) { + // Setup "from" points. + Pta* src_pts = ptaCreate(4); + ptaAddPt(src_pts, 0.0f, 0.0f); + ptaAddPt(src_pts, width, 0.0f); + ptaAddPt(src_pts, width, height); + ptaAddPt(src_pts, 0.0f, height); + // Extract factors from pseudo-random sequence. + float factors[FN_NUM_FACTORS]; + float shear = 0.0f; // Shear is signed.
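+ // Note: the loop below fills factors[] in enum order, so FN_X0 and FN_X1 + // are already drawn by the time FN_SHEAR is processed; clipping shear to + // [-x0, x1] keeps the derived corners x2 = x1 - shear and x3 = x0 + shear + // inside the source rectangle.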
+ for (int i = 0; i < FN_NUM_FACTORS; ++i) { + // Everything is squared to make wild values rarer. + if (i == FN_SHEAR) { + // Shear is signed. + shear = randomizer->SignedRand(0.5 / 3.0); + shear = shear >= 0.0 ? shear * shear : -shear * shear; + // Keep the sheared points within the original rectangle. + if (shear < -factors[FN_X0]) shear = -factors[FN_X0]; + if (shear > factors[FN_X1]) shear = factors[FN_X1]; + factors[i] = shear; + } else if (i != FN_INCOLOR) { + factors[i] = fabs(randomizer->SignedRand(1.0)); + if (i <= FN_Y3) + factors[i] *= 5.0 / 8.0; + else + factors[i] *= 0.5; + factors[i] *= factors[i]; + } else { + // Draw the incolor factor too: it is read below to pick the fill + // color, and would otherwise be left uninitialized. + factors[i] = fabs(randomizer->SignedRand(1.0)); + } + } + // Setup "to" points. + Pta* dest_pts = ptaCreate(4); + ptaAddPt(dest_pts, factors[FN_X0] * width, factors[FN_Y0] * height); + ptaAddPt(dest_pts, (1.0f - factors[FN_X1]) * width, factors[FN_Y1] * height); + ptaAddPt(dest_pts, (1.0f - factors[FN_X1] + shear) * width, + (1 - factors[FN_Y2]) * height); + ptaAddPt(dest_pts, (factors[FN_X0] + shear) * width, + (1 - factors[FN_Y3]) * height); + getProjectiveXformCoeffs(dest_pts, src_pts, im_coeffs); + getProjectiveXformCoeffs(src_pts, dest_pts, box_coeffs); + ptaDestroy(&src_pts); + ptaDestroy(&dest_pts); + return factors[FN_INCOLOR] > 0.5f ? L_BRING_IN_WHITE : L_BRING_IN_BLACK; +} + } // namespace tesseract diff --git a/training/degradeimage.h b/training/degradeimage.h index 2add6282f8..a7af9565ff 100644 --- a/training/degradeimage.h +++ b/training/degradeimage.h @@ -20,12 +20,13 @@ #ifndef TESSERACT_TRAINING_DEGRADEIMAGE_H_ #define TESSERACT_TRAINING_DEGRADEIMAGE_H_ -struct Pix; +#include "allheaders.h" +#include "genericvector.h" +#include "helpers.h" // For TRand. +#include "rect.h" namespace tesseract { -class TRand; - // Degrade the pix as if by a print/copy/scan cycle with exposure > 0 // corresponding to darkening on the copier and <0 lighter and 0 not copied. // If rotation is not NULL, the clockwise rotation in radians is saved there. @@ -34,6 +35,27 @@ class TRand; struct Pix* DegradeImage(struct Pix* input, int exposure, TRand* randomizer, float* rotation); +// Creates and returns a Pix distorted by various means according to the bool +// flags. If boxes is not NULL, the boxes are resized/positioned according to +// any spatial distortion and also by the integer reduction factor +// box_reduction so they will match what the network will output. +// Returns NULL on error. The returned Pix must be pixDestroyed. +Pix* PrepareDistortedPix(const Pix* pix, bool perspective, bool invert, + bool white_noise, bool smooth_noise, bool blur, + int box_reduction, TRand* randomizer, + GenericVector<TBOX>* boxes); +// Distorts anything that has a non-null pointer with the same pseudo-random +// perspective distortion. Width and height only need to be set if there +// is no pix. If there is a pix, then they will be taken from there. +void GeneratePerspectiveDistortion(int width, int height, TRand* randomizer, + Pix** pix, GenericVector<TBOX>* boxes); +// Computes the coefficients of a randomized projective transformation. +// The image transform requires backward transformation coefficients, and the +// box transform the forward coefficients. +// Returns the incolor arg to pixProjective.
+int ProjectiveCoeffs(int width, int height, TRand* randomizer, + float** im_coeffs, float** box_coeffs); + } // namespace tesseract #endif // TESSERACT_TRAINING_DEGRADEIMAGE_H_ diff --git a/training/language-specific.sh b/training/language-specific.sh index a62f1e3cf3..15ccad47cb 100755 --- a/training/language-specific.sh +++ b/training/language-specific.sh @@ -868,6 +868,9 @@ set_lang_specific_parameters() { AMBIGS_FILTER_DENOMINATOR="100000" LEADING="32" MEAN_COUNT="40" # Default for latin script. + # Language to mix with this language for maximum accuracy. Defaults to eng. + # If no mix language helps, set it to the language itself. + MIX_LANG="eng" case ${lang} in # Latin languages. @@ -959,11 +962,13 @@ set_lang_specific_parameters() { WORD_DAWG_SIZE=1000000 test -z "$FONTS" && FONTS=( "${EARLY_LATIN_FONTS[@]}" );; - # Cyrillic script-based languages. + # Cyrillic script-based languages. It is bad to mix Latin with Cyrillic. rus ) test -z "$FONTS" && FONTS=( "${RUSSIAN_FONTS[@]}" ) + MIX_LANG="rus" NUMBER_DAWG_FACTOR=0.05 WORD_DAWG_SIZE=1000000 ;; aze_cyrl | bel | bul | kaz | mkd | srp | tgk | ukr | uzb_cyrl ) + MIX_LANG="${lang}" test -z "$FONTS" && FONTS=( "${RUSSIAN_FONTS[@]}" ) ;; # Special code for performing Cyrillic language-id that is trained on diff --git a/training/lstmtraining.cpp b/training/lstmtraining.cpp new file mode 100644 index 0000000000..f4d46cf9c4 --- /dev/null +++ b/training/lstmtraining.cpp @@ -0,0 +1,185 @@ +/////////////////////////////////////////////////////////////////////// +// File: lstmtraining.cpp +// Description: Training program for LSTM-based networks. +// Author: Ray Smith +// Created: Fri May 03 11:05:06 PST 2013 +// +// (C) Copyright 2013, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+/////////////////////////////////////////////////////////////////////// + +#ifndef USE_STD_NAMESPACE +#include "base/commandlineflags.h" +#endif +#include "commontraining.h" +#include "lstmtrainer.h" +#include "params.h" +#include "strngs.h" +#include "tprintf.h" +#include "unicharset_training_utils.h" + +INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment."); +STRING_PARAM_FLAG(net_spec, "[I1,48Lt1,100O]", "Network specification"); +INT_PARAM_FLAG(train_mode, 64, "Controls gross training behavior."); +INT_PARAM_FLAG(net_mode, 192, "Controls network behavior."); +INT_PARAM_FLAG(perfect_sample_delay, 4, + "How many imperfect samples between perfect ones."); +DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent."); +DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights."); +DOUBLE_PARAM_FLAG(learning_rate, 1.0e-4, "Weight factor for new deltas."); +DOUBLE_PARAM_FLAG(momentum, 0.9, "Decay factor for repeating deltas."); +INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images."); +STRING_PARAM_FLAG(continue_from, "", "Existing model to extend"); +STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models"); +STRING_PARAM_FLAG(script_dir, "", + "Required to set unicharset properties or" + " use unicharset compression."); +BOOL_PARAM_FLAG(stop_training, false, + "Just convert the training model to a runtime model."); +INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to" + " attach the new network defined by net_spec"); +BOOL_PARAM_FLAG(debug_network, false, + "Get info on distribution of weight values"); +INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations"); +DECLARE_STRING_PARAM_FLAG(U); + +// Number of training images to train on between calls to MaintainCheckpoints. +const int kNumPagesPerBatch = 100; + +// Apart from command-line flags, input is a collection of lstmf files that +// were previously created using tesseract with the lstm.train config file. +// The program iterates over the inputs, feeding the data to the network, +// until the error rate reaches a specified target or max_iterations is reached. +int main(int argc, char **argv) { + ParseArguments(&argc, &argv); + // Purify the model name in case it is based on the network string. + if (FLAGS_model_output.empty()) { + tprintf("Must provide a --model_output!\n"); + return 1; + } + STRING model_output = FLAGS_model_output.c_str(); + for (int i = 0; i < model_output.length(); ++i) { + if (model_output[i] == '[' || model_output[i] == ']') + model_output[i] = '-'; + if (model_output[i] == '(' || model_output[i] == ')') + model_output[i] = '_'; + } + // Setup the trainer. + STRING checkpoint_file = FLAGS_model_output.c_str(); + checkpoint_file += "_checkpoint"; + STRING checkpoint_bak = checkpoint_file + ".bak"; + tesseract::LSTMTrainer trainer( + NULL, NULL, NULL, NULL, FLAGS_model_output.c_str(), + checkpoint_file.c_str(), FLAGS_debug_interval, + static_cast<inT64>(FLAGS_max_image_MB) * 1048576); + + // Reading something from an existing model doesn't require many flags, + // so do it now and exit.
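+ // For example, a finished checkpoint can be converted to a recognition-only + // model with (file names here are hypothetical): + // lstmtraining --stop_training --continue_from eng.lstm_checkpoint + // --model_output eng.lstm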
+ if (FLAGS_stop_training || FLAGS_debug_network) { + if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) { + tprintf("Failed to read continue from: %s\n", + FLAGS_continue_from.c_str()); + return 1; + } + if (FLAGS_debug_network) { + trainer.DebugNetwork(); + } else { + if (FLAGS_train_mode & tesseract::TF_INT_MODE) + trainer.ConvertToInt(); + GenericVector<char> recognizer_data; + trainer.SaveRecognitionDump(&recognizer_data); + if (!tesseract::SaveDataToFile(recognizer_data, + FLAGS_model_output.c_str())) { + tprintf("Failed to write recognition model: %s\n", + FLAGS_model_output.c_str()); + } + } + return 0; + } + + // Get the list of files to process. + GenericVector<STRING> filenames; + for (int arg = 1; arg < argc; ++arg) { + filenames.push_back(STRING(argv[arg])); + } + + UNICHARSET unicharset; + // Checkpoints always take priority if they are available. + if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) || + trainer.TryLoadingCheckpoint(checkpoint_bak.string())) { + tprintf("Successfully restored trainer from %s\n", + checkpoint_file.string()); + } else { + if (!FLAGS_continue_from.empty()) { + // Load a past model file to improve upon. + if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) { + tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str()); + return 1; + } + tprintf("Continuing from %s\n", FLAGS_continue_from.c_str()); + } + if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) { + // We need a unicharset to start from scratch or append. + string unicharset_str; + // Character coding to be used by the classifier. + if (!unicharset.load_from_file(FLAGS_U.c_str())) { + tprintf("Error: must provide a -U unicharset!\n"); + return 1; + } + tesseract::SetupBasicProperties(true, &unicharset); + if (FLAGS_append_index >= 0) { + tprintf("Appending a new network to an old one!!\n"); + if (FLAGS_continue_from.empty()) { + tprintf("Must set --continue_from for appending!\n"); + return 1; + } + } + // We are initializing from scratch. + trainer.InitCharSet(unicharset, FLAGS_script_dir.c_str(), + FLAGS_train_mode); + if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, + FLAGS_net_mode, FLAGS_weight_range, + FLAGS_learning_rate, FLAGS_momentum)) { + tprintf("Failed to create network from spec: %s\n", + FLAGS_net_spec.c_str()); + return 1; + } + trainer.set_perfect_delay(FLAGS_perfect_sample_delay); + } + } + if (!trainer.LoadAllTrainingData(filenames)) { + tprintf("Load of images failed!!\n"); + return 1; + } + + bool best_dumped = true; + char* best_model_dump = NULL; + size_t best_model_size = 0; + STRING best_model_name; + do { + // Train a few. + int iteration = trainer.training_iteration(); + for (int target_iteration = iteration + kNumPagesPerBatch; + iteration < target_iteration; + iteration = trainer.training_iteration()) { + trainer.TrainOnLine(&trainer, false); + } + STRING log_str; + trainer.MaintainCheckpoints(NULL, &log_str); + tprintf("%s\n", log_str.string()); + } while (trainer.best_error_rate() > FLAGS_target_error_rate && + (trainer.training_iteration() < FLAGS_max_iterations || + FLAGS_max_iterations == 0)); + tprintf("Finished!
Error rate = %g\n", trainer.best_error_rate()); + return 0; +} /* main */ + + diff --git a/training/merge_unicharsets.cpp b/training/merge_unicharsets.cpp new file mode 100644 index 0000000000..60adf198b2 --- /dev/null +++ b/training/merge_unicharsets.cpp @@ -0,0 +1,52 @@ +/////////////////////////////////////////////////////////////////////// +// File: merge_unicharsets.cpp +// Description: Simple tool to merge two or more unicharsets. +// Author: Ray Smith +// Created: Wed Sep 30 16:09:01 PDT 2015 +// +// (C) Copyright 2015, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include <stdio.h> +#include "unicharset.h" + +int main(int argc, char** argv) { + // Print usage + if (argc < 4) { + printf("Usage: %s unicharset-in-1 ... unicharset-in-n unicharset-out\n", + argv[0]); + exit(1); + } + + UNICHARSET input_unicharset, result_unicharset; + for (int arg = 1; arg < argc - 1; ++arg) { + // Load the input unicharset + if (input_unicharset.load_from_file(argv[arg])) { + printf("Loaded unicharset of size %d from file %s\n", + input_unicharset.size(), argv[arg]); + result_unicharset.AppendOtherUnicharset(input_unicharset); + } else { + printf("Failed to load unicharset from file %s!!\n", argv[arg]); + exit(1); + } + } + + // Save the combined unicharset.
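+ // argv[argc - 1] is the output file, so e.g. (file names hypothetical) + // "merge_unicharsets eng.unicharset rus.unicharset merged.unicharset" + // writes the union of the input sets to merged.unicharset.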
+ if (result_unicharset.save_to_file(argv[argc - 1])) { + printf("Wrote unicharset file %s.\n", argv[argc - 1]); + } else { + printf("Cannot save unicharset file %s.\n", argv[argc - 1]); + exit(1); + } + return 0; +} diff --git a/training/mftraining.cpp b/training/mftraining.cpp index eabcbf32db..9e2e250927 100644 --- a/training/mftraining.cpp +++ b/training/mftraining.cpp @@ -302,6 +302,9 @@ int main (int argc, char **argv) { *shape_table, float_classes, inttemp_file.string(), pffmtable_file.string()); + for (int c = 0; c < unicharset->size(); ++c) { + FreeClassFields(&float_classes[c]); + } delete [] float_classes; FreeLabeledClassList(mf_classes); delete trainer; diff --git a/training/normstrngs.cpp b/training/normstrngs.cpp index acffeee13d..e7cac21f4b 100644 --- a/training/normstrngs.cpp +++ b/training/normstrngs.cpp @@ -113,12 +113,12 @@ bool is_double_quote(const char32 ch) { return false; } -STRING NormalizeUTF8String(const char* str8) { +STRING NormalizeUTF8String(bool decompose, const char* str8) { GenericVector<char32> str32, out_str32, norm_str; UTF8ToUTF32(str8, &str32); for (int i = 0; i < str32.length(); ++i) { norm_str.clear(); - NormalizeChar32(str32[i], &norm_str); + NormalizeChar32(str32[i], decompose, &norm_str); for (int j = 0; j < norm_str.length(); ++j) { out_str32.push_back(norm_str[j]); } @@ -128,10 +128,10 @@ STRING NormalizeUTF8String(const char* str8) { return out_str8; } -void NormalizeChar32(char32 ch, GenericVector<char32>* str) { +void NormalizeChar32(char32 ch, bool decompose, GenericVector<char32>* str) { IcuErrorCode error_code; const icu::Normalizer2* nfkc = icu::Normalizer2::getInstance( - NULL, "nfkc", UNORM2_COMPOSE, error_code); + NULL, "nfkc", decompose ? UNORM2_DECOMPOSE : UNORM2_COMPOSE, error_code); error_code.assertSuccess(); error_code.reset(); diff --git a/training/normstrngs.h b/training/normstrngs.h index 71e7b8da08..6fca3193ab 100644 --- a/training/normstrngs.h +++ b/training/normstrngs.h @@ -39,11 +39,16 @@ void UTF32ToUTF8(const GenericVector<char32>& str32, STRING* utf8_str); // assumption of this function is that the input is already as fully composed // as it can be, but may require some compatibility normalizations or just // OCR evaluation related normalizations. -void NormalizeChar32(char32 ch, GenericVector<char32>* str); +void NormalizeChar32(char32 ch, bool decompose, GenericVector<char32>* str); // Normalize a UTF8 string. Same as above, but for UTF8-encoded strings that // can contain multiple UTF32 code points. -STRING NormalizeUTF8String(const char* str8); +STRING NormalizeUTF8String(bool decompose, const char* str8); +// Default behavior is to compose, until it is proven that decomposition +// benefits at least one language. +inline STRING NormalizeUTF8String(const char* str8) { + return NormalizeUTF8String(false, str8); +} // Apply just the OCR-specific normalizations and return the normalized char. char32 OCRNormalize(char32 ch); diff --git a/training/unicharset_training_utils.cpp b/training/unicharset_training_utils.cpp index 10aaf0e6c3..efa3a22cd5 100644 --- a/training/unicharset_training_utils.cpp +++ b/training/unicharset_training_utils.cpp @@ -37,7 +37,8 @@ namespace tesseract { // Helper sets the character attribute properties and sets up the script table. // Does not set tops and bottoms. -void SetupBasicProperties(bool report_errors, UNICHARSET* unicharset) { +void SetupBasicProperties(bool report_errors, bool decompose, + UNICHARSET* unicharset) { for (int unichar_id = 0; unichar_id < unicharset->size(); ++unichar_id) { // Convert any custom ligatures.
const char* unichar_str = unicharset->id_to_unichar(unichar_id); @@ -129,7 +130,7 @@ void SetupBasicProperties(bool report_errors, UNICHARSET* unicharset) { } // Record normalized version of this unichar. - STRING normed_str = tesseract::NormalizeUTF8String(unichar_str); + STRING normed_str = tesseract::NormalizeUTF8String(decompose, unichar_str); if (unichar_id != 0 && normed_str.length() > 0) { unicharset->set_normed(unichar_id, normed_str.c_str()); } else { @@ -158,7 +159,7 @@ void SetPropertiesForInputFile(const string& script_dir, // Set unichar properties tprintf("Setting unichar properties\n"); - SetupBasicProperties(true, &unicharset); + SetupBasicProperties(true, false, &unicharset); string xheights_str; for (int s = 0; s < unicharset.get_script_table_size(); ++s) { // Load the unicharset for the script if available. diff --git a/training/unicharset_training_utils.h b/training/unicharset_training_utils.h index ff2262875d..f03e12ace4 100644 --- a/training/unicharset_training_utils.h +++ b/training/unicharset_training_utils.h @@ -33,7 +33,13 @@ namespace tesseract { // Helper sets the character attribute properties and sets up the script table. // Does not set tops and bottoms. -void SetupBasicProperties(bool report_errors, UNICHARSET* unicharset); +void SetupBasicProperties(bool report_errors, bool decompose, + UNICHARSET* unicharset); +// Default behavior is to compose, until it is proven that decomposition +// benefits at least one language. +inline void SetupBasicProperties(bool report_errors, UNICHARSET* unicharset) { + SetupBasicProperties(report_errors, false, unicharset); +} // Helper to set the properties for an input unicharset file and write them to // the output file. If an appropriate script unicharset can be found in the diff --git a/viewer/svutil.h b/viewer/svutil.h index ccfce917fe..c2f97e6ca4 100644 --- a/viewer/svutil.h +++ b/viewer/svutil.h @@ -102,6 +102,17 @@ class SVMutex { #endif }; +// Auto-unlocking object that locks a mutex on construction and unlocks it +// on destruction. +class SVAutoLock { + public: + explicit SVAutoLock(SVMutex* mutex) : mutex_(mutex) { mutex->Lock(); } + ~SVAutoLock() { mutex_->Unlock(); } + + private: + SVMutex* mutex_; +}; + /// The SVNetwork class takes care of the remote connection for ScrollView /// This means setting up and maintaining a remote connection, sending and /// receiving messages and closing the connection. diff --git a/wordrec/chopper.cpp b/wordrec/chopper.cpp index 69a458bc2c..850cfcabda 100644 --- a/wordrec/chopper.cpp +++ b/wordrec/chopper.cpp @@ -426,7 +426,7 @@ void Wordrec::chop_word_main(WERD_RES *word) { if (word->best_choice == NULL) { // SegSearch found no valid paths, so just use the leading diagonal. - word->FakeWordFromRatings(); + word->FakeWordFromRatings(TOP_CHOICE_PERM); } word->RebuildBestState(); // If we finished without a hyphen at the end of the word, let the next word diff --git a/wordrec/tface.cpp b/wordrec/tface.cpp index e21fcb8829..b1fc1779fb 100644 --- a/wordrec/tface.cpp +++ b/wordrec/tface.cpp @@ -49,7 +49,11 @@ void Wordrec::program_editup(const char *textbase, if (textbase != NULL) imagefile = textbase; InitFeatureDefs(&feature_defs_); InitAdaptiveClassifier(init_classifier); - if (init_dict) getDict().Load(Dict::GlobalDawgCache()); + if (init_dict) { + getDict().SetupForLoad(Dict::GlobalDawgCache()); + getDict().Load(tessdata_manager.GetDataFileName().string(), lang); + getDict().FinishLoad(); + } pass2_ok_split = chop_ok_split; }
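A note on the NormalizeUTF8String/NormalizeChar32 change above: the new decompose argument switches the ICU normalizer from composing (NFKC) to decomposing mode. A minimal sketch of the difference, assuming only the patched normstrngs.h API (the input strings are illustrative, not from the patch):

    #include "normstrngs.h"
    #include "strngs.h"

    void NormalizeExample() {
      // "e" + U+0301 (combining acute accent). Composed mode folds the pair
      // into the single code point U+00E9 ("é"); decomposed mode keeps (or
      // splits back into) the base letter plus the combining mark.
      STRING composed = tesseract::NormalizeUTF8String(false, "e\u0301");
      STRING decomposed = tesseract::NormalizeUTF8String(true, "\u00e9");
    }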
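Similarly, the SVAutoLock added to viewer/svutil.h is a standard RAII scoped lock over SVMutex. A usage sketch (the class and member names below are hypothetical, not part of the patch, and it assumes SVMutex is default-constructible):

    #include <string>
    #include "svutil.h"

    class MessageBuffer {
     public:
      void Append(const std::string& msg) {
        SVAutoLock lock(&mutex_);  // Locks mutex_ on construction.
        buffer_ += msg;            // Critical section.
      }  // lock's destructor unlocks mutex_ on every exit path.

     private:
      SVMutex mutex_;
      std::string buffer_;
    };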