Coverage Report

Created: 2025-12-31 06:25

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/tesseract/src/lstm/network.h
Line
Count
Source
1
///////////////////////////////////////////////////////////////////////
2
// File:        network.h
3
// Description: Base class for neural network implementations.
4
// Author:      Ray Smith
5
//
6
// (C) Copyright 2013, Google Inc.
7
// Licensed under the Apache License, Version 2.0 (the "License");
8
// you may not use this file except in compliance with the License.
9
// You may obtain a copy of the License at
10
// http://www.apache.org/licenses/LICENSE-2.0
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
///////////////////////////////////////////////////////////////////////
17
18
#ifndef TESSERACT_LSTM_NETWORK_H_
19
#define TESSERACT_LSTM_NETWORK_H_
20
21
#include "helpers.h"
22
#include "matrix.h"
23
#include "networkio.h"
24
#include "serialis.h"
25
#include "static_shape.h"
26
27
#include <cmath>
28
#include <cstdio>
29
30
struct Pix;
31
32
namespace tesseract {
33
34
class ScrollView;
35
class TBOX;
36
class ImageData;
37
class NetworkScratch;
38
39
// Run-time type tag of a Network. Keep in sync with kTypeNames (presumably a
// name table indexed by these values -- confirm against network.cpp).
// Every enumerator is numbered explicitly so its position can be checked at a
// glance; when adding a type, insert it before NT_COUNT and renumber the
// entries that follow it.
enum NetworkType {
  NT_NONE = 0,  // The naked base class.
  NT_INPUT = 1, // Inputs from an image.

  // --- Plumbing: networks that combine other networks or rearrange inputs.
  NT_CONVOLVE = 2,     // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL = 3,      // Chooses the max result from a rectangle.
  NT_PARALLEL = 4,     // Runs networks in parallel.
  NT_REPLICATED = 5,   // Runs identical networks in parallel.
  NT_PAR_RL_LSTM = 6,  // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM = 7,  // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM = 8,  // Runs 4 LSTMs in parallel.
  NT_SERIES = 9,       // Executes a sequence of layers.
  NT_RECONFIG = 10,    // Scales the time/y size but makes the output deeper.
  NT_XREVERSED = 11,   // Reverses the x direction of the inputs/outputs.
  NT_YREVERSED = 12,   // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE = 13, // Transposes x and y (for just a single op).

  // --- Functional: networks that actually calculate stuff.
  NT_LSTM = 14,           // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY = 15,   // LSTM that only keeps its last output.
  NT_LOGISTIC = 16,       // Fully connected logistic nonlinearity.
  NT_POSCLIP = 17,        // Fully connected rect lin version of logistic.
  NT_SYMCLIP = 18,        // Fully connected rect lin version of tanh.
  NT_TANH = 19,           // Fully connected with tanh nonlinearity.
  NT_RELU = 20,           // Fully connected with rectifier nonlinearity.
  NT_LINEAR = 21,         // Fully connected with no nonlinearity.
  NT_SOFTMAX = 22,        // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC = 23, // Softmax uses exponential normalization, no CTC.

  // --- Softmax LSTMs: both carry an extra softmax layer on top, but inside,
  // whose outputs are fed back to the LSTM input at the next timestep. The
  // ENCODED variant binary-encodes the softmax outputs (adding log2 of the
  // number of outputs as extra inputs); the plain variant feeds back every
  // softmax output as an extra input.
  NT_LSTM_SOFTMAX = 24,         // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED = 25, // 1-d LSTM with built-in binary encoded softmax.

  // --- A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW = 26,

  NT_COUNT = 27 // Number of types above; used as an array size.
};
79
80
// Bit flags controlling Network behavior. In principle each individual
// network element can carry its own set of flags (see Network::TestFlag).
// Values are single bits so they can be OR-ed together.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 1 << 6, // Separate learning rate for each layer.
  NF_ADAM = 1 << 7,              // Weight-specific learning rate.
};
87
88
// Training state of a network, and the requested state passed to
// SetEnableTraining. The first three values are valid contents of training_;
// TS_RE_ENABLE is only meaningful as an argument to SetEnableTraining.
enum TrainingState {
  TS_DISABLED = 0,     // Disabled permanently.
  TS_ENABLED = 1,      // Enabled for backprop and to write a training dump.
                       // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE = 2, // Temporarily disabled to write a recognition dump.
  TS_RE_ENABLE = 3,    // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};
98
99
// Base class for network types. Not quite an abstract base class, but almost.
100
// Most of the time no isolated Network exists, except prior to
101
// deserialization.
102
class TESS_API Network {
103
public:
104
  Network();
105
  Network(NetworkType type, const std::string &name, int ni, int no);
106
0
  virtual ~Network() = default;
107
108
  // Accessors.
109
0
  NetworkType type() const {
110
0
    return type_;
111
0
  }
112
499M
  bool IsTraining() const {
113
499M
    return training_ == TS_ENABLED;
114
499M
  }
115
0
  bool needs_to_backprop() const {
116
0
    return needs_to_backprop_;
117
0
  }
118
0
  int num_weights() const {
119
0
    return num_weights_;
120
0
  }
121
141k
  int NumInputs() const {
122
141k
    return ni_;
123
141k
  }
124
48
  int NumOutputs() const {
125
48
    return no_;
126
48
  }
127
  // Returns the required shape input to the network.
128
0
  virtual StaticShape InputShape() const {
129
0
    StaticShape result;
130
0
    return result;
131
0
  }
132
  // Returns the shape output from the network given an input shape (which may
133
  // be partially unknown ie zero).
134
2
  virtual StaticShape OutputShape(const StaticShape &input_shape) const {
135
2
    StaticShape result(input_shape);
136
2
    result.set_depth(no_);
137
2
    return result;
138
2
  }
139
0
  const std::string &name() const {
140
0
    return name_;
141
0
  }
142
  virtual std::string spec() const = 0;
143
0
  bool TestFlag(NetworkFlags flag) const {
144
0
    return (network_flags_ & flag) != 0;
145
0
  }
146
147
  // Initialization and administrative functions that are mostly provided
148
  // by Plumbing.
149
  // Returns true if the given type is derived from Plumbing, and thus contains
150
  // multiple sub-networks that can have their own learning rate.
151
0
  virtual bool IsPlumbingType() const {
152
0
    return false;
153
0
  }
154
155
  // Suspends/Enables/Permanently disables training by setting the training_
156
  // flag. Serialize and DeSerialize only operate on the run-time data if state
157
  // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
158
  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
159
  // serialize as if it were a recognizer.
160
  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
161
  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
162
  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
163
  // recognizer can be converted back to a trainer.
164
  virtual void SetEnableTraining(TrainingState state);
165
166
  // Sets flags that control the action of the network. See NetworkFlags enum
167
  // for bit values.
168
  virtual void SetNetworkFlags(uint32_t flags);
169
170
  // Sets up the network for training. Initializes weights using weights of
171
  // scale `range` picked according to the random number generator `randomizer`.
172
  // Note that randomizer is a borrowed pointer that should outlive the network
173
  // and should not be deleted by any of the networks.
174
  // Returns the number of weights initialized.
175
  virtual int InitWeights(float range, TRand *randomizer);
176
  // Changes the number of outputs to the outside world to the size of the given
177
  // code_map. Recursively searches the entire network for Softmax layers that
178
  // have exactly old_no outputs, and operates only on those, leaving all others
179
  // unchanged. This enables networks with multiple output layers to get all
180
  // their softmaxes updated, but if an internal layer, uses one of those
181
  // softmaxes for input, then the inputs will effectively be scrambled.
182
  // TODO(rays) Fix this before any such network is implemented.
183
  // The softmaxes are resized by copying the old weight matrix entries for each
184
  // output from code_map[output] where non-negative, and uses the mean (over
185
  // all outputs) of the existing weights for all outputs with negative code_map
186
  // entries. Returns the new number of weights.
187
  virtual int RemapOutputs([[maybe_unused]] int old_no,
188
0
                           [[maybe_unused]] const std::vector<int> &code_map) {
189
0
    return 0;
190
0
  }
191
192
  // Converts a float network to an int network.
193
0
  virtual void ConvertToInt() {}
194
195
  // Provides a pointer to a TRand for any networks that care to use it.
196
  // Note that randomizer is a borrowed pointer that should outlive the network
197
  // and should not be deleted by any of the networks.
198
  virtual void SetRandomizer(TRand *randomizer);
199
200
  // Sets needs_to_backprop_ to needs_backprop and returns true if
201
  // needs_backprop || any weights in this network so the next layer forward
202
  // can be told to produce backprop for this layer if needed.
203
  virtual bool SetupNeedsBackprop(bool needs_backprop);
204
205
  // Returns the most recent reduction factor that the network applied to the
206
  // time sequence. Assumes that any 2-d is already eliminated. Used for
207
  // scaling bounding boxes of truth data and calculating result bounding boxes.
208
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
209
  // the last used scale factor. Call it before any forward, and it will return
210
  // the minimum scale factor of the paths through the GlobalMinimax.
211
991k
  virtual int XScaleFactor() const {
212
991k
    return 1;
213
991k
  }
214
215
  // Provides the (minimum) x scale factor to the network (of interest only to
216
  // input units) so they can determine how to scale bounding boxes.
217
0
  virtual void CacheXScaleFactor([[maybe_unused]] int factor) {}
218
219
  // Provides debug output on the weights.
220
  virtual void DebugWeights() = 0;
221
222
  // Writes to the given file. Returns false in case of error.
223
  // Should be overridden by subclasses, but called by their Serialize.
224
  virtual bool Serialize(TFile *fp) const;
225
  // Reads from the given file. Returns false in case of error.
226
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
227
  virtual bool DeSerialize(TFile *fp) = 0;
228
229
public:
230
  // Updates the weights using the given learning rate, momentum and adam_beta.
231
  // num_samples is used in the adam computation iff use_adam_ is true.
232
  virtual void Update([[maybe_unused]] float learning_rate,
233
                      [[maybe_unused]] float momentum,
234
                      [[maybe_unused]] float adam_beta,
235
0
                      [[maybe_unused]] int num_samples) {}
236
  // Sums the products of weight updates in *this and other, splitting into
237
  // positive (same direction) in *same and negative (different direction) in
238
  // *changed.
239
  virtual void CountAlternators([[maybe_unused]] const Network &other,
240
                                [[maybe_unused]] TFloat *same,
241
0
                                [[maybe_unused]] TFloat *changed) const {}
242
243
  // Reads from the given file. Returns nullptr in case of error.
244
  // Determines the type of the serialized class and calls its DeSerialize
245
  // on a new object of the appropriate type, which is returned.
246
  static Network *CreateFromFile(TFile *fp);
247
248
  // Runs forward propagation of activations on the input line.
249
  // Note that input and output are both 2-d arrays.
250
  // The 1st index is the time element. In a 1-d network, it might be the pixel
251
  // position on the textline. In a 2-d network, the linearization is defined
252
  // by the stride_map. (See networkio.h).
253
  // The 2nd index of input is the network inputs/outputs, and the dimension
254
  // of the input must match NumInputs() of this network.
255
  // The output array will be resized as needed so that its 1st dimension is
256
  // always equal to the number of output values, and its second dimension is
257
  // always NumOutputs(). Note that all this detail is encapsulated away inside
258
  // NetworkIO, as are the internals of the scratch memory space used by the
259
  // network. See networkscratch.h for that.
260
  // If input_transpose is not nullptr, then it contains the transpose of input,
261
  // and the caller guarantees that it will still be valid on the next call to
262
  // backward. The callee is therefore at liberty to save the pointer and
263
  // reference it on a call to backward. This is a bit ugly, but it makes it
264
  // possible for a replicating parallel to calculate the input transpose once
265
  // instead of all the replicated networks having to do it.
266
  virtual void Forward(bool debug, const NetworkIO &input,
267
                       const TransposedArray *input_transpose,
268
                       NetworkScratch *scratch, NetworkIO *output) = 0;
269
270
  // Runs backward propagation of errors on fwdX_deltas.
271
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
272
  // Returns false if back_deltas was not set, due to there being no point in
273
  // propagating further backwards. Thus most complete networks will always
274
  // return false from Backward!
275
  virtual bool Backward(bool debug, const NetworkIO &fwd_deltas,
276
                        NetworkScratch *scratch, NetworkIO *back_deltas) = 0;
277
278
  // === Debug image display methods. ===
279
  // Displays the image of the matrix to the forward window.
280
  void DisplayForward(const NetworkIO &matrix);
281
  // Displays the image of the matrix to the backward window.
282
  void DisplayBackward(const NetworkIO &matrix);
283
284
  // Creates the window if needed, otherwise clears it.
285
  static void ClearWindow(bool tess_coords, const char *window_name, int width,
286
                          int height, ScrollView **window);
287
288
  // Displays the pix in the given window. and returns the height of the pix.
289
  // The pix is pixDestroyed.
290
  static int DisplayImage(Image pix, ScrollView *window);
291
292
protected:
293
  // Returns a random number in [-range, range].
294
  TFloat Random(TFloat range);
295
296
protected:
297
  NetworkType type_;       // Type of the derived network class.
298
  TrainingState training_; // Are we currently training?
299
  bool needs_to_backprop_; // This network needs to output back_deltas.
300
  int32_t network_flags_;  // Behavior control flags in NetworkFlags.
301
  int32_t ni_;             // Number of input values.
302
  int32_t no_;             // Number of output values.
303
  int32_t num_weights_;    // Number of weights in this and sub-network.
304
  std::string name_;       // A unique name for this layer.
305
306
  // NOT-serialized debug data.
307
  ScrollView *forward_win_;  // Recognition debug display window.
308
  ScrollView *backward_win_; // Training debug display window.
309
  TRand *randomizer_;        // Random number generator.
310
};
311
312
} // namespace tesseract.
313
314
#endif // TESSERACT_LSTM_NETWORK_H_