forked from tesseract-ocr/tesseract
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput.cpp
152 lines (137 loc) · 5.48 KB
/
input.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
///////////////////////////////////////////////////////////////////////
// File: input.cpp
// Description: Input layer class for neural network implementations.
// Author: Ray Smith
// Created: Thu Mar 13 09:10:34 PDT 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "input.h"
#include "allheaders.h"
#include "imagedata.h"
#include "pageres.h"
#include "scrollview.h"
namespace tesseract {
// Height cap for variable-height inputs: taller images get scaled anyway.
// Handed to ImageData::PreScale together with the network's target height.
constexpr int kMaxInputHeight = 48;
Input::Input(const STRING& name, int ni, int no)
: Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}
Input::Input(const STRING& name, const StaticShape& shape)
: Network(NT_INPUT, name, shape.height(), shape.depth()),
shape_(shape),
cached_x_scale_(1) {
if (shape.height() == 1) ni_ = shape.depth();
}
// Nothing is released explicitly; the compiler-generated cleanup suffices.
Input::~Input() = default;
// Writes this layer to fp: first the Network base data, then the raw bytes of
// shape_ as a single item. Returns false as soon as either step fails (the
// shape is not written if the base write fails).
bool Input::Serialize(TFile* fp) const {
  return Network::Serialize(fp) &&
         fp->FWrite(&shape_, sizeof(shape_), 1) == 1;
}
// Reads shape_ from fp as a single item of sizeof(shape_) bytes.
// Returns false unless exactly one item was read.
// NOTE(review): unlike Serialize, this does not touch the Network base data;
// presumably the caller has already consumed it - confirm against the code
// that constructs networks from a file.
// NOTE(review): the read is endian-aware (FReadEndian) while the write is a
// plain FWrite of the whole struct - confirm the element size is as intended.
bool Input::DeSerialize(TFile* fp) {
  const auto items_read = fp->FReadEndian(&shape_, sizeof(shape_), 1);
  return items_read == 1;
}
// Returns the integer factor by which the network reduces the time sequence,
// assuming any 2-d has already been eliminated. Used for scaling bounding
// boxes of truth data. This layer always reports a factor of 1.
int Input::XScaleFactor() const {
  constexpr int kNoReduction = 1;
  return kNoReduction;
}
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
void Input::CacheXScaleFactor(int factor) {
cached_x_scale_ = factor;
}
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
void Input::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
*output = input;
}
// Backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
// Calling this on an input layer is treated as a mistake: a message is
// printed, none of the arguments are touched, and false is returned.
bool Input::Backward(bool debug, const NetworkIO& fwd_deltas,
                     NetworkScratch* scratch,
                     NetworkIO* back_deltas) {
  const bool handled = false;
  tprintf("Input::Backward should not be called!!\n");
  return handled;
}
// Creates and returns a Pix of appropriate size for the network from the
// image_data. If non-null, *image_scale returns the image scale factor used
// (the pointer is forwarded to ImageData::PreScale).
// Returns nullptr on error: either PreScale produced no image, or the scaled
// width or height is smaller than min_width. A dimension exactly equal to
// min_width is accepted, consistent with the "min width" wording of the
// diagnostic and with the height test.
// On success the Pix is not destroyed here, so presumably the caller is
// responsible for releasing it. randomizer is not used by this function.
/* static */
Pix* Input::PrepareLSTMInputs(const ImageData& image_data,
                              const Network* network, int min_width,
                              TRand* randomizer, float* image_scale) {
  // Note that NumInputs() is defined as input image height.
  const int target_height = network->NumInputs();
  // Initialized so the size check below never reads indeterminate values.
  int width = 0;
  int height = 0;
  Pix* pix = image_data.PreScale(target_height, kMaxInputHeight, image_scale,
                                 &width, &height, nullptr);
  if (pix == nullptr) {
    tprintf("Bad pix from ImageData!\n");
    return nullptr;
  }
  // Reject only images strictly smaller than the minimum in either dimension.
  if (width < min_width || height < min_width) {
    tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width,
            height, min_width);
    pixDestroy(&pix);
    return nullptr;
  }
  return pix;
}
// Converts the given pix to a NetworkIO of height and depth appropriate to the
// given StaticShape:
// If depth == 3, convert to 24 bit color, otherwise normalized grey.
// Scale to target height, if the shape's height is > 1, or its depth if the
// height == 1. If height == 0 then no scaling.
// NOTE: It isn't safe for multiple threads to call this on the same pix.
/* static */
void Input::PreparePixInput(const StaticShape& shape, const Pix* pix,
TRand* randomizer, NetworkIO* input) {
bool color = shape.depth() == 3;
Pix* var_pix = const_cast<Pix*>(pix);
int depth = pixGetDepth(var_pix);
Pix* normed_pix = nullptr;
// On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
// colormap, so we just have to deal with depth conversion here.
if (color) {
// Force RGB.
if (depth == 32)
normed_pix = pixClone(var_pix);
else
normed_pix = pixConvertTo32(var_pix);
} else {
// Convert non-8-bit images to 8 bit.
if (depth == 8)
normed_pix = pixClone(var_pix);
else
normed_pix = pixConvertTo8(var_pix, false);
}
int height = pixGetHeight(normed_pix);
int target_height = shape.height();
if (target_height == 1) target_height = shape.depth();
if (target_height != 0 && target_height != height) {
// Get the scaled image.
float im_factor = static_cast<float>(target_height) / height;
Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
pixDestroy(&normed_pix);
normed_pix = scaled_pix;
}
input->FromPix(shape, normed_pix, randomizer);
pixDestroy(&normed_pix);
}
} // namespace tesseract.