/src/tesseract/src/lstm/network.h
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: network.h |
3 | | // Description: Base class for neural network implementations. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2013, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #ifndef TESSERACT_LSTM_NETWORK_H_ |
19 | | #define TESSERACT_LSTM_NETWORK_H_ |
20 | | |
21 | | #include "helpers.h" |
22 | | #include "matrix.h" |
23 | | #include "networkio.h" |
24 | | #include "serialis.h" |
25 | | #include "static_shape.h" |
26 | | |
27 | | #include <cmath> |
28 | | #include <cstdio> |
29 | | |
30 | | struct Pix; |
31 | | |
32 | | namespace tesseract { |
33 | | |
34 | | class ScrollView; |
35 | | class TBOX; |
36 | | class ImageData; |
37 | | class NetworkScratch; |
38 | | |
39 | | // Enum to store the run-time type of a Network. Keep in sync with kTypeNames. |
40 | | enum NetworkType { |
41 | | NT_NONE, // The naked base class. |
42 | | NT_INPUT, // Inputs from an image. |
43 | | // Plumbing networks combine other networks or rearrange the inputs. |
44 | | NT_CONVOLVE, // Duplicates inputs in a sliding window neighborhood. |
45 | | NT_MAXPOOL, // Chooses the max result from a rectangle. |
46 | | NT_PARALLEL, // Runs networks in parallel. |
47 | | NT_REPLICATED, // Runs identical networks in parallel. |
48 | | NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel. |
49 | | NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel. |
50 | | NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel. |
51 | | NT_SERIES, // Executes a sequence of layers. |
52 | | NT_RECONFIG, // Scales the time/y size but makes the output deeper. |
53 | | NT_XREVERSED, // Reverses the x-direction of the inputs/outputs. |
54 | | NT_YREVERSED, // Reverses the y-direction of the inputs/outputs. |
55 | | NT_XYTRANSPOSE, // Transposes x and y (for just a single op). |
56 | | // Functional networks actually calculate stuff. |
57 | | NT_LSTM, // Long-Short-Term-Memory block. |
58 | | NT_LSTM_SUMMARY, // LSTM that only keeps its last output. |
59 | | NT_LOGISTIC, // Fully connected logistic nonlinearity. |
60 | | NT_POSCLIP, // Fully connected rectified-linear version of logistic. |
61 | | NT_SYMCLIP, // Fully connected rectified-linear version of tanh. |
62 | | NT_TANH, // Fully connected with tanh nonlinearity. |
63 | | NT_RELU, // Fully connected with rectifier nonlinearity. |
64 | | NT_LINEAR, // Fully connected with no nonlinearity. |
65 | | NT_SOFTMAX, // Softmax uses exponential normalization, with CTC. |
66 | | NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC. |
67 | | // The SOFTMAX LSTMs both have an extra softmax layer on top, built into the |
68 | | // layer itself, with its outputs fed back to the LSTM's input at the next timestep. |
69 | | // The ENCODED version binary encodes the softmax outputs, providing log2 of |
70 | | // the number of outputs as additional inputs, and the other version just |
71 | | // provides all the softmax outputs as additional inputs. |
72 | | NT_LSTM_SOFTMAX, // 1-d LSTM with built-in fully connected softmax. |
73 | | NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax. |
74 | | // A TensorFlow graph encapsulated as a Tesseract network. |
75 | | NT_TENSORFLOW, |
76 | | |
77 | | NT_COUNT // Array size. |
78 | | }; |
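 | | // [Editor's example] A minimal sketch of using the run-time type tag. |
 | | // IsFunctionalType is hypothetical (not part of Tesseract) and relies only |
 | | // on the enum ordering above, where NT_LSTM starts the block of networks |
 | | // that "actually calculate stuff"; counting NT_TENSORFLOW as functional is |
 | | // an assumption. |
 | | // |
 | | //   inline bool IsFunctionalType(NetworkType type) { |
 | | //     return type >= NT_LSTM && type < NT_COUNT; |
 | | //   } |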
79 | | |
80 | | // Enum of Network behavior flags. Can in theory be set for each individual |
81 | | // network element. |
82 | | enum NetworkFlags { |
83 | | // Network forward/backprop behavior. |
84 | | NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer. |
85 | | NF_ADAM = 128, // Weight-specific learning rate. |
86 | | }; |
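 | | // [Editor's example] The flags are independent bit values, so they combine |
 | | // with bitwise OR and are queried through Network::TestFlag(). A sketch, |
 | | // with `net` a hypothetical Network*: |
 | | // |
 | | //   net->SetNetworkFlags(NF_LAYER_SPECIFIC_LR | NF_ADAM); |
 | | //   if (net->TestFlag(NF_ADAM)) { |
 | | //     // Update() will use the Adam computation for this layer. |
 | | //   } |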
87 | | |
88 | | // State of training and desired state used in SetEnableTraining. |
89 | | enum TrainingState { |
90 | | // Valid states of training_. |
91 | | TS_DISABLED, // Disabled permanently. |
92 | | TS_ENABLED, // Enabled for backprop and to write a training dump. |
93 | | // (As an argument to SetEnableTraining, TS_ENABLED re-enables from ANY disabled state.) |
94 | | TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump. |
95 | | // Valid only for SetEnableTraining. |
96 | | TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED. |
97 | | }; |
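 | | // [Editor's example] Schematic of the checkpoint flow these states enable |
 | | // (see SetEnableTraining below); `net` and the open TFile `fp` are |
 | | // hypothetical, not trainer code copied from Tesseract: |
 | | // |
 | | //   net->SetEnableTraining(TS_TEMP_DISABLE); // hide training data |
 | | //   net->Serialize(&fp);                     // dump reads as a recognizer |
 | | //   net->SetEnableTraining(TS_RE_ENABLE);    // resume training; layers in |
 | | //                                            // TS_DISABLED stay disabled |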
98 | | |
99 | | // Base class for network types. Not quite an abstract base class, but almost. |
100 | | // Most of the time no isolated Network exists, except prior to |
101 | | // deserialization. |
102 | | class TESS_API Network { |
103 | | public: |
104 | | Network(); |
105 | | Network(NetworkType type, const std::string &name, int ni, int no); |
106 | 0 | virtual ~Network() = default; |
107 | | |
108 | | // Accessors. |
109 | 0 | NetworkType type() const { |
110 | 0 | return type_; |
111 | 0 | } |
112 | 499M | bool IsTraining() const { |
113 | 499M | return training_ == TS_ENABLED; |
114 | 499M | } |
115 | 0 | bool needs_to_backprop() const { |
116 | 0 | return needs_to_backprop_; |
117 | 0 | } |
118 | 0 | int num_weights() const { |
119 | 0 | return num_weights_; |
120 | 0 | } |
121 | 141k | int NumInputs() const { |
122 | 141k | return ni_; |
123 | 141k | } |
124 | 48 | int NumOutputs() const { |
125 | 48 | return no_; |
126 | 48 | } |
127 | | // Returns the required shape input to the network. |
128 | 0 | virtual StaticShape InputShape() const { |
129 | 0 | StaticShape result; |
130 | 0 | return result; |
131 | 0 | } |
132 | | // Returns the shape output from the network given an input shape (which may |
133 | | // be partially unknown, i.e. zero). |
134 | 2 | virtual StaticShape OutputShape(const StaticShape &input_shape) const { |
135 | 2 | StaticShape result(input_shape); |
136 | 2 | result.set_depth(no_); |
137 | 2 | return result; |
138 | 2 | } |
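 | | // [Editor's example] Sketch of threading a shape through a stack of layers |
 | | // using only these accessors; `layers` is a hypothetical container of |
 | | // Network*: |
 | | // |
 | | //   StaticShape shape = layers.front()->InputShape(); |
 | | //   for (Network *layer : layers) { |
 | | //     shape = layer->OutputShape(shape); // base: same shape, depth = no_ |
 | | //   } |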
139 | 0 | const std::string &name() const { |
140 | 0 | return name_; |
141 | 0 | } |
142 | | virtual std::string spec() const = 0; |
143 | 0 | bool TestFlag(NetworkFlags flag) const { |
144 | 0 | return (network_flags_ & flag) != 0; |
145 | 0 | } |
146 | | |
147 | | // Initialization and administrative functions that are mostly provided |
148 | | // by Plumbing. |
149 | | // Returns true if the given type is derived from Plumbing, and thus contains |
150 | | // multiple sub-networks that can have their own learning rate. |
151 | 0 | virtual bool IsPlumbingType() const { |
152 | 0 | return false; |
153 | 0 | } |
154 | | |
155 | | // Suspends/Enables/Permanently disables training by setting the training_ |
156 | | // flag. Serialize and DeSerialize only operate on the run-time data if state |
157 | | // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will |
158 | | // temporarily disable layers in state TS_ENABLED, allowing a trainer to |
159 | | // serialize as if it were a recognizer. |
160 | | // TS_RE_ENABLE will re-enable layers that were previously in any disabled |
161 | | // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in |
162 | | // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a |
163 | | // recognizer can be converted back to a trainer. |
164 | | virtual void SetEnableTraining(TrainingState state); |
165 | | |
166 | | // Sets flags that control the action of the network. See NetworkFlags enum |
167 | | // for bit values. |
168 | | virtual void SetNetworkFlags(uint32_t flags); |
169 | | |
170 | | // Sets up the network for training. Initializes weights to random values of |
171 | | // scale `range`, picked according to the random number generator `randomizer`. |
172 | | // Note that randomizer is a borrowed pointer that should outlive the network |
173 | | // and should not be deleted by any of the networks. |
174 | | // Returns the number of weights initialized. |
175 | | virtual int InitWeights(float range, TRand *randomizer); |
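 | | // [Editor's example] Typical setup sketch; TRand comes from helpers.h, and |
 | | // the seed and range values are invented: |
 | | // |
 | | //   TRand randomizer; |
 | | //   randomizer.set_seed(42); |
 | | //   int n = net->InitWeights(0.1f, &randomizer); // must outlive `net` |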
176 | | // Changes the number of outputs to the outside world to the size of the given |
177 | | // code_map. Recursively searches the entire network for Softmax layers that |
178 | | // have exactly old_no outputs, and operates only on those, leaving all others |
179 | | // unchanged. This enables networks with multiple output layers to get all |
180 | | // their softmaxes updated; but if an internal layer uses one of those |
181 | | // softmaxes for input, then its inputs will effectively be scrambled. |
182 | | // TODO(rays) Fix this before any such network is implemented. |
183 | | // The softmaxes are resized by copying the old weight matrix entries for each |
184 | | // output from code_map[output] where non-negative, and by using the mean (over |
185 | | // all outputs) of the existing weights for outputs with negative code_map |
186 | | // entries. Returns the new number of weights. |
187 | | virtual int RemapOutputs([[maybe_unused]] int old_no, |
188 | 0 | [[maybe_unused]] const std::vector<int> &code_map) { |
189 | 0 | return 0; |
190 | 0 | } |
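 | | // [Editor's example] Worked example of the code_map contract (values are |
 | | // invented): new output 0 copies old output 2's weights, new output 1 |
 | | // copies old output 0's, and new output 2, whose entry is negative, gets |
 | | // the mean of the old weights: |
 | | // |
 | | //   std::vector<int> code_map = {2, 0, -1}; |
 | | //   int new_weights = net->RemapOutputs(/*old_no=*/3, code_map); |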
191 | | |
192 | | // Converts a float network to an int network. |
193 | 0 | virtual void ConvertToInt() {} |
194 | | |
195 | | // Provides a pointer to a TRand for any networks that care to use it. |
196 | | // Note that randomizer is a borrowed pointer that should outlive the network |
197 | | // and should not be deleted by any of the networks. |
198 | | virtual void SetRandomizer(TRand *randomizer); |
199 | | |
200 | | // Sets needs_to_backprop_ to needs_backprop and returns true if |
201 | | // needs_backprop is true or this network contains any weights, so that the |
202 | | // next layer forward can be told to produce backprop for this layer if needed. |
203 | | virtual bool SetupNeedsBackprop(bool needs_backprop); |
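 | | // [Editor's note] One plausible reading of this contract, as a sketch (the |
 | | // real body lives in network.cpp and is not copied here): |
 | | // |
 | | //   bool Network::SetupNeedsBackprop(bool needs_backprop) { |
 | | //     needs_to_backprop_ = needs_backprop; |
 | | //     return needs_backprop || num_weights_ > 0; |
 | | //   } |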
204 | | |
205 | | // Returns the most recent reduction factor that the network applied to the |
206 | | // time sequence. Assumes that any 2-d is already eliminated. Used for |
207 | | // scaling bounding boxes of truth data and calculating result bounding boxes. |
208 | | // WARNING: if GlobalMinimax is used to vary the scale, this will return |
209 | | // the last used scale factor. Call it before any forward, and it will return |
210 | | // the minimum scale factor of the paths through the GlobalMinimax. |
211 | 991k | virtual int XScaleFactor() const { |
212 | 991k | return 1; |
213 | 991k | } |
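 | | // [Editor's example] Intended use, as a sketch: a truth-box edge at input |
 | | // pixel x lands at output timestep x / XScaleFactor(): |
 | | // |
 | | //   int t = box_left_x / net->XScaleFactor(); // names are hypothetical |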
214 | | |
215 | | // Provides the (minimum) x scale factor to the network (of interest only to |
216 | | // input units) so they can determine how to scale bounding boxes. |
217 | 0 | virtual void CacheXScaleFactor([[maybe_unused]] int factor) {} |
218 | | |
219 | | // Provides debug output on the weights. |
220 | | virtual void DebugWeights() = 0; |
221 | | |
222 | | // Writes to the given file. Returns false in case of error. |
223 | | // Should be overridden by subclasses, but called by their Serialize. |
224 | | virtual bool Serialize(TFile *fp) const; |
225 | | // Reads from the given file. Returns false in case of error. |
226 | | // Should be overridden by subclasses, but NOT called by their DeSerialize. |
227 | | virtual bool DeSerialize(TFile *fp) = 0; |
228 | | |
229 | | public: |
230 | | // Updates the weights using the given learning rate, momentum and adam_beta. |
231 | | // num_samples is used in the adam computation iff the NF_ADAM flag is set. |
232 | | virtual void Update([[maybe_unused]] float learning_rate, |
233 | | [[maybe_unused]] float momentum, |
234 | | [[maybe_unused]] float adam_beta, |
235 | 0 | [[maybe_unused]] int num_samples) {} |
236 | | // Sums the products of weight updates in *this and other, splitting into |
237 | | // positive (same direction) in *same and negative (different direction) in |
238 | | // *changed. |
239 | | virtual void CountAlternators([[maybe_unused]] const Network &other, |
240 | | [[maybe_unused]] TFloat *same, |
241 | 0 | [[maybe_unused]] TFloat *changed) const {} |
242 | | |
243 | | // Reads from the given file. Returns nullptr in case of error. |
244 | | // Determines the type of the serialized class and calls its DeSerialize |
245 | | // on a new object of the appropriate type, which is returned. |
246 | | static Network *CreateFromFile(TFile *fp); |
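 | | // [Editor's example] Deserialization sketch; assumes TFile::Open(filename, |
 | | // reader) from serialis.h accepts nullptr for the default reader, and the |
 | | // model path is invented: |
 | | // |
 | | //   TFile fp; |
 | | //   if (fp.Open("model.lstm", nullptr)) { |
 | | //     Network *net = Network::CreateFromFile(&fp); |
 | | //     if (net == nullptr) { /* handle the error */ } |
 | | //   } |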
247 | | |
248 | | // Runs forward propagation of activations on the input line. |
249 | | // Note that input and output are both 2-d arrays. |
250 | | // The 1st index is the time element. In a 1-d network, it might be the pixel |
251 | | // position on the textline. In a 2-d network, the linearization is defined |
252 | | // by the stride_map. (See networkio.h). |
253 | | // The 2nd index of input is the network inputs/outputs, and the dimension |
254 | | // of the input must match NumInputs() of this network. |
255 | | // The output array will be resized as needed so that its 1st dimension is |
256 | | // always equal to the number of output values, and its second dimension is |
257 | | // always NumOutputs(). Note that all this detail is encapsulated away inside |
258 | | // NetworkIO, as are the internals of the scratch memory space used by the |
259 | | // network. See networkscratch.h for that. |
260 | | // If input_transpose is not nullptr, then it contains the transpose of input, |
261 | | // and the caller guarantees that it will still be valid on the next call to |
262 | | // backward. The callee is therefore at liberty to save the pointer and |
263 | | // reference it on a call to backward. This is a bit ugly, but it makes it |
264 | | // possible for a replicating parallel to calculate the input transpose once |
265 | | // instead of all the replicated networks having to do it. |
266 | | virtual void Forward(bool debug, const NetworkIO &input, |
267 | | const TransposedArray *input_transpose, |
268 | | NetworkScratch *scratch, NetworkIO *output) = 0; |
269 | | |
270 | | // Runs backward propagation of errors on fwd_deltas. |
271 | | // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward. |
272 | | // Returns false if back_deltas was not set, due to there being no point in |
273 | | // propagating further backwards. Thus most complete networks will always |
274 | | // return false from Backward! |
275 | | virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, |
276 | | NetworkScratch *scratch, NetworkIO *back_deltas) = 0; |
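 | | // [Editor's example] Schematic of one pass at the whole-network level; the |
 | | // NetworkIO/NetworkScratch setup is simplified (see networkio.h and |
 | | // networkscratch.h), and `net`, `inputs`, `fwd_deltas` are hypothetical: |
 | | // |
 | | //   NetworkScratch scratch; |
 | | //   NetworkIO outputs, back_deltas; |
 | | //   net->Forward(/*debug=*/false, inputs, nullptr, &scratch, &outputs); |
 | | //   // ...compute fwd_deltas from outputs against the target... |
 | | //   bool propagated = net->Backward(false, fwd_deltas, &scratch, &back_deltas); |
 | | //   // For a complete network, propagated is normally false (see above). |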
277 | | |
278 | | // === Debug image display methods. === |
279 | | // Displays the image of the matrix to the forward window. |
280 | | void DisplayForward(const NetworkIO &matrix); |
281 | | // Displays the image of the matrix to the backward window. |
282 | | void DisplayBackward(const NetworkIO &matrix); |
283 | | |
284 | | // Creates the window if needed, otherwise clears it. |
285 | | static void ClearWindow(bool tess_coords, const char *window_name, int width, |
286 | | int height, ScrollView **window); |
287 | | |
288 | | // Displays the pix in the given window and returns the height of the pix. |
289 | | // The pix is pixDestroyed. |
290 | | static int DisplayImage(Image pix, ScrollView *window); |
291 | | |
292 | | protected: |
293 | | // Returns a random number in [-range, range]. |
294 | | TFloat Random(TFloat range); |
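 | | // [Editor's note] A plausible implementation, given SetRandomizer() above |
 | | // and TRand::SignedRand() from helpers.h; the real definition is in |
 | | // network.cpp and may differ: |
 | | // |
 | | //   TFloat Network::Random(TFloat range) { |
 | | //     return randomizer_->SignedRand(range); |
 | | //   } |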
295 | | |
296 | | protected: |
297 | | NetworkType type_; // Type of the derived network class. |
298 | | TrainingState training_; // Are we currently training? |
299 | | bool needs_to_backprop_; // This network needs to output back_deltas. |
300 | | int32_t network_flags_; // Behavior control flags in NetworkFlags. |
301 | | int32_t ni_; // Number of input values. |
302 | | int32_t no_; // Number of output values. |
303 | | int32_t num_weights_; // Number of weights in this and sub-network. |
304 | | std::string name_; // A unique name for this layer. |
305 | | |
306 | | // NOT-serialized debug data. |
307 | | ScrollView *forward_win_; // Recognition debug display window. |
308 | | ScrollView *backward_win_; // Training debug display window. |
309 | | TRand *randomizer_; // Random number generator. |
310 | | }; |
311 | | |
312 | | } // namespace tesseract. |
313 | | |
314 | | #endif // TESSERACT_LSTM_NETWORK_H_ |