Coverage Report

Created: 2025-06-13 07:15

/src/tesseract/src/lstm/network.cpp

Line  Count  Source
   1         ///////////////////////////////////////////////////////////////////////
   2         // File:        network.cpp
   3         // Description: Base class for neural network implementations.
   4         // Author:      Ray Smith
   5         //
   6         // (C) Copyright 2013, Google Inc.
   7         // Licensed under the Apache License, Version 2.0 (the "License");
   8         // you may not use this file except in compliance with the License.
   9         // You may obtain a copy of the License at
  10         // http://www.apache.org/licenses/LICENSE-2.0
  11         // Unless required by applicable law or agreed to in writing, software
  12         // distributed under the License is distributed on an "AS IS" BASIS,
  13         // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14         // See the License for the specific language governing permissions and
  15         // limitations under the License.
  16         ///////////////////////////////////////////////////////////////////////
  17
  18         // Include automatically generated configuration file if running autoconf.
  19         #ifdef HAVE_CONFIG_H
  20         #  include "config_auto.h"
  21         #endif
  22
  23         #include "network.h"
  24
  25         #include <cstdlib>
  26
  27         // This base class needs to know about all its sub-classes because of the
  28         // factory deserializing method: CreateFromFile.
  29         #include <allheaders.h>
  30         #include "convolve.h"
  31         #include "fullyconnected.h"
  32         #include "input.h"
  33         #include "lstm.h"
  34         #include "maxpool.h"
  35         #include "parallel.h"
  36         #include "reconfig.h"
  37         #include "reversed.h"
  38         #include "scrollview.h"
  39         #include "series.h"
  40         #include "statistc.h"
  41         #include "tprintf.h"
  42
  43         namespace tesseract {
  44
  45         #ifndef GRAPHICS_DISABLED
  46
  47         // Min and max window sizes.
  48         const int kMinWinSize = 500;
  49         const int kMaxWinSize = 2000;
  50         // Window frame sizes need adding on to make the content fit.
  51         const int kXWinFrameSize = 30;
  52         const int kYWinFrameSize = 80;
  53
  54         #endif // !GRAPHICS_DISABLED
  55
  56         // String names corresponding to the NetworkType enum.
  57         // Keep in sync with NetworkType.
  58         // Names used in Serialization to allow re-ordering/addition/deletion of
  59         // layer types in NetworkType without invalidating existing network files.
  60         static char const *const kTypeNames[NT_COUNT] = {
  61             "Invalid",     "Input",
  62             "Convolve",    "Maxpool",
  63             "Parallel",    "Replicated",
  64             "ParBidiLSTM", "DepParUDLSTM",
  65             "Par2dLSTM",   "Series",
  66             "Reconfig",    "RTLReversed",
  67             "TTBReversed", "XYTranspose",
  68             "LSTM",        "SummLSTM",
  69             "Logistic",    "LinLogistic",
  70             "LinTanh",     "Tanh",
  71             "Relu",        "Linear",
  72             "Softmax",     "SoftmaxNoCTC",
  73             "LSTMSoftmax", "LSTMBinarySoftmax",
  74             "TensorFlow",
  75         };
  76
  77         Network::Network()
  78      0      : type_(NT_NONE)
  79      0      , training_(TS_ENABLED)
  80      0      , needs_to_backprop_(true)
  81      0      , network_flags_(0)
  82      0      , ni_(0)
  83      0      , no_(0)
  84      0      , num_weights_(0)
  85      0      , forward_win_(nullptr)
  86      0      , backward_win_(nullptr)
  87      0      , randomizer_(nullptr) {}
  88         Network::Network(NetworkType type, const std::string &name, int ni, int no)
  89     26      : type_(type)
  90     26      , training_(TS_ENABLED)
  91     26      , needs_to_backprop_(true)
  92     26      , network_flags_(0)
  93     26      , ni_(ni)
  94     26      , no_(no)
  95     26      , num_weights_(0)
  96     26      , name_(name)
  97     26      , forward_win_(nullptr)
  98     26      , backward_win_(nullptr)
  99     26      , randomizer_(nullptr) {}
 100
 101         // Suspends/Enables/Permanently disables training by setting the training_
 102         // flag. Serialize and DeSerialize only operate on the run-time data if state
 103         // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
 104         // temporarily disable layers in state TS_ENABLED, allowing a trainer to
 105         // serialize as if it were a recognizer.
 106         // TS_RE_ENABLE will re-enable layers that were previously in any disabled
 107         // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
 108         // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
 109         // recognizer can be converted back to a trainer.
 110      0  void Network::SetEnableTraining(TrainingState state) {
 111      0    if (state == TS_RE_ENABLE) {
 112             // Enable only from temp disabled.
 113      0      if (training_ == TS_TEMP_DISABLE) {
 114      0        training_ = TS_ENABLED;
 115      0      }
 116      0    } else if (state == TS_TEMP_DISABLE) {
 117             // Temp disable only from enabled.
 118      0      if (training_ == TS_ENABLED) {
 119      0        training_ = state;
 120      0      }
 121      0    } else {
 122      0      training_ = state;
 123      0    }
 124      0  }
 125
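The training-state machine described in the comment above is easiest to see in use. Below is a minimal sketch, not taken from this file: `net` (a Network whose layers are in TS_ENABLED) and `fp` (a TFile already opened for writing) are hypothetical, and only SetEnableTraining and Serialize from this class are used.

    // Save a trainer as if it were a recognizer, then resume training.
    net->SetEnableTraining(TS_TEMP_DISABLE); // flips only layers that are TS_ENABLED
    if (!net->Serialize(fp)) {               // run-time training data is skipped while temp-disabled
      tprintf("Serialization failed\n");
    }
    net->SetEnableTraining(TS_RE_ENABLE);    // re-enables only layers in TS_TEMP_DISABLE

As the comment notes, TS_RE_ENABLE on a layer that is fully TS_DISABLED is not just a flag flip: the weight deltas are reinitialized (in derived classes) so the recognizer can be trained again.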
 126         // Sets flags that control the action of the network. See NetworkFlags enum
 127         // for bit values.
 128      0  void Network::SetNetworkFlags(uint32_t flags) {
 129      0    network_flags_ = flags;
 130      0  }
 131
 132         // Sets up the network for training. Initializes weights using weights of
 133         // scale `range` picked according to the random number generator `randomizer`.
 134      0  int Network::InitWeights([[maybe_unused]] float range, TRand *randomizer) {
 135      0    randomizer_ = randomizer;
 136      0    return 0;
 137      0  }
 138
 139         // Provides a pointer to a TRand for any networks that care to use it.
 140         // Note that randomizer is a borrowed pointer that should outlive the network
 141         // and should not be deleted by any of the networks.
 142     18  void Network::SetRandomizer(TRand *randomizer) {
 143     18    randomizer_ = randomizer;
 144     18  }
 145
 146         // Sets needs_to_backprop_ to needs_backprop and returns true if
 147         // needs_backprop || any weights in this network so the next layer forward
 148         // can be told to produce backprop for this layer if needed.
 149      0  bool Network::SetupNeedsBackprop(bool needs_backprop) {
 150      0    needs_to_backprop_ = needs_backprop;
 151      0    return needs_backprop || num_weights_ > 0;
 152      0  }
 153
 154         // Writes to the given file. Returns false in case of error.
 155      0  bool Network::Serialize(TFile *fp) const {
 156      0    int8_t data = NT_NONE;
 157      0    if (!fp->Serialize(&data)) {
 158      0      return false;
 159      0    }
 160      0    std::string type_name = kTypeNames[type_];
 161      0    if (!fp->Serialize(type_name)) {
 162      0      return false;
 163      0    }
 164      0    data = training_;
 165      0    if (!fp->Serialize(&data)) {
 166      0      return false;
 167      0    }
 168      0    data = needs_to_backprop_;
 169      0    if (!fp->Serialize(&data)) {
 170      0      return false;
 171      0    }
 172      0    if (!fp->Serialize(&network_flags_)) {
 173      0      return false;
 174      0    }
 175      0    if (!fp->Serialize(&ni_)) {
 176      0      return false;
 177      0    }
 178      0    if (!fp->Serialize(&no_)) {
 179      0      return false;
 180      0    }
 181      0    if (!fp->Serialize(&num_weights_)) {
 182      0      return false;
 183      0    }
 184      0    uint32_t length = name_.length();
 185      0    if (!fp->Serialize(&length)) {
 186      0      return false;
 187      0    }
 188      0    return fp->Serialize(name_.c_str(), length);
 189      0  }
 190
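For readers following Network::Serialize above, the record it writes is, in order (inferred directly from the calls in the function body; the exact encoding of the std::string field is whatever TFile::Serialize(const std::string &) produces, assumed here to be length-prefixed):

    int8_t    NT_NONE marker, signalling that a type-name string follows
    string    kTypeNames[type_]
    int8_t    training_
    int8_t    needs_to_backprop_
    int32_t   network_flags_
    int32_t   ni_
    int32_t   no_
    int32_t   num_weights_
    uint32_t  length of name_
    char[]    name_ (length bytes)

getNetworkType and CreateFromFile below read these fields back in the same order before dispatching on the layer type.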
 191     26  static NetworkType getNetworkType(TFile *fp) {
 192     26    int8_t data;
 193     26    if (!fp->DeSerialize(&data)) {
 194      0      return NT_NONE;
 195      0    }
 196     26    if (data == NT_NONE) {
 197     26      std::string type_name;
 198     26      if (!fp->DeSerialize(type_name)) {
 199      0        return NT_NONE;
 200      0      }
 201    318      for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
 202    292      }
 203     26      if (data == NT_COUNT) {
 204      0        tprintf("Invalid network layer type:%s\n", type_name.c_str());
 205      0        return NT_NONE;
 206      0      }
 207     26    }
 208     26    return static_cast<NetworkType>(data);
 209     26  }
 210
 211         // Reads from the given file. Returns nullptr in case of error.
 212         // Determines the type of the serialized class and calls its DeSerialize
 213         // on a new object of the appropriate type, which is returned.
 214     26  Network *Network::CreateFromFile(TFile *fp) {
 215     26    NetworkType type;       // Type of the derived network class.
 216     26    TrainingState training; // Are we currently training?
 217     26    bool needs_to_backprop; // This network needs to output back_deltas.
 218     26    int32_t network_flags;  // Behavior control flags in NetworkFlags.
 219     26    int32_t ni;             // Number of input values.
 220     26    int32_t no;             // Number of output values.
 221     26    int32_t num_weights;    // Number of weights in this and sub-network.
 222     26    std::string name;       // A unique name for this layer.
 223     26    int8_t data;
 224     26    Network *network = nullptr;
 225     26    type = getNetworkType(fp);
 226     26    if (!fp->DeSerialize(&data)) {
 227      0      return nullptr;
 228      0    }
 229     26    training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
 230     26    if (!fp->DeSerialize(&data)) {
 231      0      return nullptr;
 232      0    }
 233     26    needs_to_backprop = data != 0;
 234     26    if (!fp->DeSerialize(&network_flags)) {
 235      0      return nullptr;
 236      0    }
 237     26    if (!fp->DeSerialize(&ni)) {
 238      0      return nullptr;
 239      0    }
 240     26    if (!fp->DeSerialize(&no)) {
 241      0      return nullptr;
 242      0    }
 243     26    if (!fp->DeSerialize(&num_weights)) {
 244      0      return nullptr;
 245      0    }
 246     26    if (!fp->DeSerialize(name)) {
 247      0      return nullptr;
 248      0    }
 249
 250     26    switch (type) {
 251      2      case NT_CONVOLVE:
 252      2        network = new Convolve(name, ni, 0, 0);
 253      2        break;
 254      2      case NT_INPUT:
 255      2        network = new Input(name, ni, no);
 256      2        break;
 257      6      case NT_LSTM:
 258      6      case NT_LSTM_SOFTMAX:
 259      6      case NT_LSTM_SOFTMAX_ENCODED:
 260      8      case NT_LSTM_SUMMARY:
 261      8        network = new LSTM(name, ni, no, no, false, type);
 262      8        break;
 263      2      case NT_MAXPOOL:
 264      2        network = new Maxpool(name, ni, 0, 0);
 265      2        break;
 266             // All variants of Parallel.
 267      0      case NT_PARALLEL:
 268      0      case NT_REPLICATED:
 269      0      case NT_PAR_RL_LSTM:
 270      0      case NT_PAR_UD_LSTM:
 271      0      case NT_PAR_2D_LSTM:
 272      0        network = new Parallel(name, type);
 273      0        break;
 274      0      case NT_RECONFIG:
 275      0        network = new Reconfig(name, ni, 0, 0);
 276      0        break;
 277             // All variants of reversed.
 278      2      case NT_XREVERSED:
 279      2      case NT_YREVERSED:
 280      4      case NT_XYTRANSPOSE:
 281      4        network = new Reversed(name, type);
 282      4        break;
 283      4      case NT_SERIES:
 284      4        network = new Series(name);
 285      4        break;
 286      0      case NT_TENSORFLOW:
 287      0        tprintf("Unsupported TensorFlow model\n");
 288      0        break;
 289             // All variants of FullyConnected.
 290      2      case NT_SOFTMAX:
 291      2      case NT_SOFTMAX_NO_CTC:
 292      2      case NT_RELU:
 293      4      case NT_TANH:
 294      4      case NT_LINEAR:
 295      4      case NT_LOGISTIC:
 296      4      case NT_POSCLIP:
 297      4      case NT_SYMCLIP:
 298      4        network = new FullyConnected(name, ni, no, type);
 299      4        break;
 300      0      default:
 301      0        break;
 302     26    }
 303     26    if (network) {
 304     26      network->training_ = training;
 305     26      network->needs_to_backprop_ = needs_to_backprop;
 306     26      network->network_flags_ = network_flags;
 307     26      network->num_weights_ = num_weights;
 308     26      if (!network->DeSerialize(fp)) {
 309      0        delete network;
 310      0        network = nullptr;
 311      0      }
 312     26    }
 313     26    return network;
 314     26  }
 315
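A minimal usage sketch for the factory above. It assumes TFile::Open(filename, reader) from Tesseract's serialis.h (passing nullptr for the default file reader) and a hypothetical file containing a single serialized layer:

    TFile fp;
    if (fp.Open("layer.bin", nullptr)) {   // hypothetical input file
      Network *net = Network::CreateFromFile(&fp);
      if (net == nullptr) {
        tprintf("Could not deserialize network layer\n");
      } else {
        // net is the concrete subclass (Input, LSTM, Series, ...) chosen from the
        // type record, with its own DeSerialize already applied.
        delete net;
      }
    }

In practice the main callers are the container layers (e.g. Series via Plumbing::DeSerialize), which use CreateFromFile recursively to rebuild a whole network from one file.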
 316         // Returns a random number in [-range, range].
 317      0  TFloat Network::Random(TFloat range) {
 318      0    ASSERT_HOST(randomizer_ != nullptr);
 319      0    return randomizer_->SignedRand(range);
 320      0  }
 321
 322         #ifndef GRAPHICS_DISABLED
 323
 324         // === Debug image display methods. ===
 325         // Displays the image of the matrix to the forward window.
 326         void Network::DisplayForward(const NetworkIO &matrix) {
 327           Image image = matrix.ToPix();
 328           ClearWindow(false, name_.c_str(), pixGetWidth(image), pixGetHeight(image), &forward_win_);
 329           DisplayImage(image, forward_win_);
 330           forward_win_->Update();
 331         }
 332
 333         // Displays the image of the matrix to the backward window.
 334         void Network::DisplayBackward(const NetworkIO &matrix) {
 335           Image image = matrix.ToPix();
 336           std::string window_name = name_ + "-back";
 337           ClearWindow(false, window_name.c_str(), pixGetWidth(image), pixGetHeight(image), &backward_win_);
 338           DisplayImage(image, backward_win_);
 339           backward_win_->Update();
 340         }
 341
 342         // Creates the window if needed, otherwise clears it.
 343         void Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height,
 344                                   ScrollView **window) {
 345           if (*window == nullptr) {
 346             int min_size = std::min(width, height);
 347             if (min_size < kMinWinSize) {
 348               if (min_size < 1) {
 349                 min_size = 1;
 350               }
 351               width = width * kMinWinSize / min_size;
 352               height = height * kMinWinSize / min_size;
 353             }
 354             width += kXWinFrameSize;
 355             height += kYWinFrameSize;
 356             if (width > kMaxWinSize) {
 357               width = kMaxWinSize;
 358             }
 359             if (height > kMaxWinSize) {
 360               height = kMaxWinSize;
 361             }
 362             *window = new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords);
 363             tprintf("Created window %s of size %d, %d\n", window_name, width, height);
 364           } else {
 365             (*window)->Clear();
 366           }
 367         }
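As a worked example of the sizing in ClearWindow above: a 100 x 50 pixel image has min_size 50, which is below kMinWinSize, so both dimensions are scaled by 500 / 50 = 10 to give 1000 x 500; adding the frame sizes yields a 1030 x 580 window, which is under kMaxWinSize, so no clamping is applied.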
 368
 369         // Displays the pix in the given window. and returns the height of the pix.
 370         // The pix is pixDestroyed.
 371         int Network::DisplayImage(Image pix, ScrollView *window) {
 372           int height = pixGetHeight(pix);
 373           window->Draw(pix, 0, 0);
 374           pix.destroy();
 375           return height;
 376         }
 377         #endif // !GRAPHICS_DISABLED
 378
 379         } // namespace tesseract.