/src/tesseract/src/lstm/network.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: network.cpp |
3 | | // Description: Base class for neural network implementations. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2013, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | // Include automatically generated configuration file if running autoconf. |
19 | | #ifdef HAVE_CONFIG_H |
20 | | # include "config_auto.h" |
21 | | #endif |
22 | | |
23 | | #include "network.h" |
24 | | |
25 | | #include <cstdlib> |
26 | | |
27 | | // This base class needs to know about all its sub-classes because of the |
28 | | // factory deserializing method: CreateFromFile. |
29 | | #include <allheaders.h> |
30 | | #include "convolve.h" |
31 | | #include "fullyconnected.h" |
32 | | #include "input.h" |
33 | | #include "lstm.h" |
34 | | #include "maxpool.h" |
35 | | #include "parallel.h" |
36 | | #include "reconfig.h" |
37 | | #include "reversed.h" |
38 | | #include "scrollview.h" |
39 | | #include "series.h" |
40 | | #include "statistc.h" |
41 | | #include "tprintf.h" |
42 | | |
43 | | namespace tesseract { |
44 | | |
45 | | #ifndef GRAPHICS_DISABLED |
46 | | |
47 | | // Min and max window sizes. |
48 | | const int kMinWinSize = 500; |
49 | | const int kMaxWinSize = 2000; |
50 | | // Window frame sizes need adding on to make the content fit. |
51 | | const int kXWinFrameSize = 30; |
52 | | const int kYWinFrameSize = 80; |
53 | | |
54 | | #endif // !GRAPHICS_DISABLED |
55 | | |
56 | | // String names corresponding to the NetworkType enum. |
57 | | // Keep in sync with NetworkType. |
58 | | // Names used in Serialization to allow re-ordering/addition/deletion of |
59 | | // layer types in NetworkType without invalidating existing network files. |
60 | | static char const *const kTypeNames[NT_COUNT] = { |
61 | | "Invalid", "Input", |
62 | | "Convolve", "Maxpool", |
63 | | "Parallel", "Replicated", |
64 | | "ParBidiLSTM", "DepParUDLSTM", |
65 | | "Par2dLSTM", "Series", |
66 | | "Reconfig", "RTLReversed", |
67 | | "TTBReversed", "XYTranspose", |
68 | | "LSTM", "SummLSTM", |
69 | | "Logistic", "LinLogistic", |
70 | | "LinTanh", "Tanh", |
71 | | "Relu", "Linear", |
72 | | "Softmax", "SoftmaxNoCTC", |
73 | | "LSTMSoftmax", "LSTMBinarySoftmax", |
74 | | "TensorFlow", |
75 | | }; |
76 | | |
77 | | Network::Network() |
78 | 0 | : type_(NT_NONE) |
79 | 0 | , training_(TS_ENABLED) |
80 | 0 | , needs_to_backprop_(true) |
81 | 0 | , network_flags_(0) |
82 | 0 | , ni_(0) |
83 | 0 | , no_(0) |
84 | 0 | , num_weights_(0) |
85 | 0 | , forward_win_(nullptr) |
86 | 0 | , backward_win_(nullptr) |
87 | 0 | , randomizer_(nullptr) {} |
88 | | Network::Network(NetworkType type, const std::string &name, int ni, int no) |
89 | 26 | : type_(type) |
90 | 26 | , training_(TS_ENABLED) |
91 | 26 | , needs_to_backprop_(true) |
92 | 26 | , network_flags_(0) |
93 | 26 | , ni_(ni) |
94 | 26 | , no_(no) |
95 | 26 | , num_weights_(0) |
96 | 26 | , name_(name) |
97 | 26 | , forward_win_(nullptr) |
98 | 26 | , backward_win_(nullptr) |
99 | 26 | , randomizer_(nullptr) {} |
100 | | |
101 | | // Suspends/Enables/Permanently disables training by setting the training_ |
102 | | // flag. Serialize and DeSerialize only operate on the run-time data if state |
103 | | // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will |
104 | | // temporarily disable layers in state TS_ENABLED, allowing a trainer to |
105 | | // serialize as if it were a recognizer. |
106 | | // TS_RE_ENABLE will re-enable layers that were previously in any disabled |
107 | | // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in |
108 | | // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a |
109 | | // recognizer can be converted back to a trainer. |
110 | 0 | void Network::SetEnableTraining(TrainingState state) { |
111 | 0 | if (state == TS_RE_ENABLE) { |
112 | | // Enable only from temp disabled. |
113 | 0 | if (training_ == TS_TEMP_DISABLE) { |
114 | 0 | training_ = TS_ENABLED; |
115 | 0 | } |
116 | 0 | } else if (state == TS_TEMP_DISABLE) { |
117 | | // Temp disable only from enabled. |
118 | 0 | if (training_ == TS_ENABLED) { |
119 | 0 | training_ = state; |
120 | 0 | } |
121 | 0 | } else { |
122 | 0 | training_ = state; |
123 | 0 | } |
124 | 0 | } |
125 | | |
126 | | // Sets flags that control the action of the network. See NetworkFlags enum |
127 | | // for bit values. |
128 | 0 | void Network::SetNetworkFlags(uint32_t flags) { |
129 | 0 | network_flags_ = flags; |
130 | 0 | } |
131 | | |
132 | | // Sets up the network for training. Initializes weights to values of
133 | | // scale `range` picked according to the random number generator `randomizer`.
134 | 0 | int Network::InitWeights([[maybe_unused]] float range, TRand *randomizer) { |
135 | 0 | randomizer_ = randomizer; |
136 | 0 | return 0; |
137 | 0 | } |
138 | | |
139 | | // Provides a pointer to a TRand for any networks that care to use it. |
140 | | // Note that randomizer is a borrowed pointer that should outlive the network |
141 | | // and should not be deleted by any of the networks. |
142 | 18 | void Network::SetRandomizer(TRand *randomizer) { |
143 | 18 | randomizer_ = randomizer; |
144 | 18 | } |
145 | | |
146 | | // Sets needs_to_backprop_ to needs_backprop and returns true if
147 | | // needs_backprop is true or this network has any weights, so the next layer
148 | | // in the forward direction can be told to produce backprop for this layer if needed.
149 | 0 | bool Network::SetupNeedsBackprop(bool needs_backprop) { |
150 | 0 | needs_to_backprop_ = needs_backprop; |
151 | 0 | return needs_backprop || num_weights_ > 0; |
152 | 0 | } |
153 | | |
154 | | // Writes to the given file. Returns false in case of error. |
155 | 0 | bool Network::Serialize(TFile *fp) const { |
156 | 0 | int8_t data = NT_NONE; |
157 | 0 | if (!fp->Serialize(&data)) { |
158 | 0 | return false; |
159 | 0 | } |
160 | 0 | std::string type_name = kTypeNames[type_]; |
161 | 0 | if (!fp->Serialize(type_name)) { |
162 | 0 | return false; |
163 | 0 | } |
164 | 0 | data = training_; |
165 | 0 | if (!fp->Serialize(&data)) { |
166 | 0 | return false; |
167 | 0 | } |
168 | 0 | data = needs_to_backprop_; |
169 | 0 | if (!fp->Serialize(&data)) { |
170 | 0 | return false; |
171 | 0 | } |
172 | 0 | if (!fp->Serialize(&network_flags_)) { |
173 | 0 | return false; |
174 | 0 | } |
175 | 0 | if (!fp->Serialize(&ni_)) { |
176 | 0 | return false; |
177 | 0 | } |
178 | 0 | if (!fp->Serialize(&no_)) { |
179 | 0 | return false; |
180 | 0 | } |
181 | 0 | if (!fp->Serialize(&num_weights_)) { |
182 | 0 | return false; |
183 | 0 | } |
184 | 0 | uint32_t length = name_.length(); |
185 | 0 | if (!fp->Serialize(&length)) { |
186 | 0 | return false; |
187 | 0 | } |
188 | 0 | return fp->Serialize(name_.c_str(), length); |
189 | 0 | } |
190 | | |
191 | 26 | static NetworkType getNetworkType(TFile *fp) { |
192 | 26 | int8_t data; |
193 | 26 | if (!fp->DeSerialize(&data)) { |
194 | 0 | return NT_NONE; |
195 | 0 | } |
196 | 26 | if (data == NT_NONE) { |
197 | 26 | std::string type_name; |
198 | 26 | if (!fp->DeSerialize(type_name)) { |
199 | 0 | return NT_NONE; |
200 | 0 | } |
201 | 318 | for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) { |
202 | 292 | } |
203 | 26 | if (data == NT_COUNT) { |
204 | 0 | tprintf("Invalid network layer type:%s\n", type_name.c_str()); |
205 | 0 | return NT_NONE; |
206 | 0 | } |
207 | 26 | } |
208 | 26 | return static_cast<NetworkType>(data); |
209 | 26 | } |
210 | | |
211 | | // Reads from the given file. Returns nullptr in case of error. |
212 | | // Determines the type of the serialized class and calls its DeSerialize |
213 | | // on a new object of the appropriate type, which is returned. |
214 | 26 | Network *Network::CreateFromFile(TFile *fp) { |
215 | 26 | NetworkType type; // Type of the derived network class. |
216 | 26 | TrainingState training; // Are we currently training? |
217 | 26 | bool needs_to_backprop; // This network needs to output back_deltas. |
218 | 26 | int32_t network_flags; // Behavior control flags in NetworkFlags. |
219 | 26 | int32_t ni; // Number of input values. |
220 | 26 | int32_t no; // Number of output values. |
221 | 26 | int32_t num_weights; // Number of weights in this and sub-network. |
222 | 26 | std::string name; // A unique name for this layer. |
223 | 26 | int8_t data; |
224 | 26 | Network *network = nullptr; |
225 | 26 | type = getNetworkType(fp); |
226 | 26 | if (!fp->DeSerialize(&data)) { |
227 | 0 | return nullptr; |
228 | 0 | } |
229 | 26 | training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED; |
230 | 26 | if (!fp->DeSerialize(&data)) { |
231 | 0 | return nullptr; |
232 | 0 | } |
233 | 26 | needs_to_backprop = data != 0; |
234 | 26 | if (!fp->DeSerialize(&network_flags)) { |
235 | 0 | return nullptr; |
236 | 0 | } |
237 | 26 | if (!fp->DeSerialize(&ni)) { |
238 | 0 | return nullptr; |
239 | 0 | } |
240 | 26 | if (!fp->DeSerialize(&no)) { |
241 | 0 | return nullptr; |
242 | 0 | } |
243 | 26 | if (!fp->DeSerialize(&num_weights)) { |
244 | 0 | return nullptr; |
245 | 0 | } |
246 | 26 | if (!fp->DeSerialize(name)) { |
247 | 0 | return nullptr; |
248 | 0 | } |
249 | | |
250 | 26 | switch (type) { |
251 | 2 | case NT_CONVOLVE: |
252 | 2 | network = new Convolve(name, ni, 0, 0); |
253 | 2 | break; |
254 | 2 | case NT_INPUT: |
255 | 2 | network = new Input(name, ni, no); |
256 | 2 | break; |
257 | 6 | case NT_LSTM: |
258 | 6 | case NT_LSTM_SOFTMAX: |
259 | 6 | case NT_LSTM_SOFTMAX_ENCODED: |
260 | 8 | case NT_LSTM_SUMMARY: |
261 | 8 | network = new LSTM(name, ni, no, no, false, type); |
262 | 8 | break; |
263 | 2 | case NT_MAXPOOL: |
264 | 2 | network = new Maxpool(name, ni, 0, 0); |
265 | 2 | break; |
266 | | // All variants of Parallel. |
267 | 0 | case NT_PARALLEL: |
268 | 0 | case NT_REPLICATED: |
269 | 0 | case NT_PAR_RL_LSTM: |
270 | 0 | case NT_PAR_UD_LSTM: |
271 | 0 | case NT_PAR_2D_LSTM: |
272 | 0 | network = new Parallel(name, type); |
273 | 0 | break; |
274 | 0 | case NT_RECONFIG: |
275 | 0 | network = new Reconfig(name, ni, 0, 0); |
276 | 0 | break; |
277 | | // All variants of reversed. |
278 | 2 | case NT_XREVERSED: |
279 | 2 | case NT_YREVERSED: |
280 | 4 | case NT_XYTRANSPOSE: |
281 | 4 | network = new Reversed(name, type); |
282 | 4 | break; |
283 | 4 | case NT_SERIES: |
284 | 4 | network = new Series(name); |
285 | 4 | break; |
286 | 0 | case NT_TENSORFLOW: |
287 | 0 | tprintf("Unsupported TensorFlow model\n"); |
288 | 0 | break; |
289 | | // All variants of FullyConnected. |
290 | 2 | case NT_SOFTMAX: |
291 | 2 | case NT_SOFTMAX_NO_CTC: |
292 | 2 | case NT_RELU: |
293 | 4 | case NT_TANH: |
294 | 4 | case NT_LINEAR: |
295 | 4 | case NT_LOGISTIC: |
296 | 4 | case NT_POSCLIP: |
297 | 4 | case NT_SYMCLIP: |
298 | 4 | network = new FullyConnected(name, ni, no, type); |
299 | 4 | break; |
300 | 0 | default: |
301 | 0 | break; |
302 | 26 | } |
303 | 26 | if (network) { |
304 | 26 | network->training_ = training; |
305 | 26 | network->needs_to_backprop_ = needs_to_backprop; |
306 | 26 | network->network_flags_ = network_flags; |
307 | 26 | network->num_weights_ = num_weights; |
308 | 26 | if (!network->DeSerialize(fp)) { |
309 | 0 | delete network; |
310 | 0 | network = nullptr; |
311 | 0 | } |
312 | 26 | } |
313 | 26 | return network; |
314 | 26 | } |
315 | | |
316 | | // Returns a random number in [-range, range]. |
317 | 0 | TFloat Network::Random(TFloat range) { |
318 | 0 | ASSERT_HOST(randomizer_ != nullptr); |
319 | 0 | return randomizer_->SignedRand(range); |
320 | 0 | } |
321 | | |
322 | | #ifndef GRAPHICS_DISABLED |
323 | | |
324 | | // === Debug image display methods. === |
325 | | // Displays the image of the matrix to the forward window. |
326 | | void Network::DisplayForward(const NetworkIO &matrix) { |
327 | | Image image = matrix.ToPix(); |
328 | | ClearWindow(false, name_.c_str(), pixGetWidth(image), pixGetHeight(image), &forward_win_); |
329 | | DisplayImage(image, forward_win_); |
330 | | forward_win_->Update(); |
331 | | } |
332 | | |
333 | | // Displays the image of the matrix to the backward window. |
334 | | void Network::DisplayBackward(const NetworkIO &matrix) { |
335 | | Image image = matrix.ToPix(); |
336 | | std::string window_name = name_ + "-back"; |
337 | | ClearWindow(false, window_name.c_str(), pixGetWidth(image), pixGetHeight(image), &backward_win_); |
338 | | DisplayImage(image, backward_win_); |
339 | | backward_win_->Update(); |
340 | | } |
341 | | |
342 | | // Creates the window if needed, otherwise clears it. |
343 | | void Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height, |
344 | | ScrollView **window) { |
345 | | if (*window == nullptr) { |
346 | | int min_size = std::min(width, height); |
347 | | if (min_size < kMinWinSize) { |
348 | | if (min_size < 1) { |
349 | | min_size = 1; |
350 | | } |
351 | | width = width * kMinWinSize / min_size; |
352 | | height = height * kMinWinSize / min_size; |
353 | | } |
354 | | width += kXWinFrameSize; |
355 | | height += kYWinFrameSize; |
356 | | if (width > kMaxWinSize) { |
357 | | width = kMaxWinSize; |
358 | | } |
359 | | if (height > kMaxWinSize) { |
360 | | height = kMaxWinSize; |
361 | | } |
362 | | *window = new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords); |
363 | | tprintf("Created window %s of size %d, %d\n", window_name, width, height); |
364 | | } else { |
365 | | (*window)->Clear(); |
366 | | } |
367 | | } |
368 | | |
369 | | // Displays the pix in the given window and returns the height of the pix.
370 | | // The pix is destroyed after display.
371 | | int Network::DisplayImage(Image pix, ScrollView *window) { |
372 | | int height = pixGetHeight(pix); |
373 | | window->Draw(pix, 0, 0); |
374 | | pix.destroy(); |
375 | | return height; |
376 | | } |
377 | | #endif // !GRAPHICS_DISABLED |
378 | | |
379 | | } // namespace tesseract. |
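
For reference, the name-based type serialization implemented by kTypeNames and getNetworkType above can be illustrated with a minimal, self-contained sketch. The DemoType enum and kDemoNames table below are hypothetical stand-ins, not Tesseract types: a type is written as its string name and recovered by a linear search of the name table, so entries in NetworkType can be re-ordered or added without invalidating existing network files.

#include <cstdio>
#include <string>

// Hypothetical stand-ins for NetworkType and kTypeNames.
enum DemoType { DT_NONE, DT_INPUT, DT_LSTM, DT_COUNT };
static const char *const kDemoNames[DT_COUNT] = {"Invalid", "Input", "LSTM"};

// Serialization stores the name rather than the numeric enum value.
static std::string TypeToName(DemoType type) {
  return kDemoNames[type];
}

// Deserialization searches the name table, mirroring the loop in
// getNetworkType; an unknown name maps to DT_NONE.
static DemoType NameToType(const std::string &name) {
  int i = 0;
  while (i < DT_COUNT && name != kDemoNames[i]) {
    ++i;
  }
  return i == DT_COUNT ? DT_NONE : static_cast<DemoType>(i);
}

int main() {
  std::string stored = TypeToName(DT_LSTM);
  std::printf("restored type = %d\n", NameToType(stored)); // prints 2 (DT_LSTM)
  return 0;
}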