Coverage Report

Created: 2025-06-13 07:02

/src/tesseract/src/lstm/fullyconnected.cpp
Line |  Count | Source
   1 |        | ///////////////////////////////////////////////////////////////////////
   2 |        | // File:        fullyconnected.cpp
   3 |        | // Description: Simple feed-forward layer with various non-linearities.
   4 |        | // Author:      Ray Smith
   5 |        | //
   6 |        | // (C) Copyright 2014, Google Inc.
   7 |        | // Licensed under the Apache License, Version 2.0 (the "License");
   8 |        | // you may not use this file except in compliance with the License.
   9 |        | // You may obtain a copy of the License at
  10 |        | // http://www.apache.org/licenses/LICENSE-2.0
  11 |        | // Unless required by applicable law or agreed to in writing, software
  12 |        | // distributed under the License is distributed on an "AS IS" BASIS,
  13 |        | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14 |        | // See the License for the specific language governing permissions and
  15 |        | // limitations under the License.
  16 |        | ///////////////////////////////////////////////////////////////////////
  17 |        |
  18 |        | #ifdef HAVE_CONFIG_H
  19 |        | #  include "config_auto.h"
  20 |        | #endif
  21 |        |
  22 |        | #include "fullyconnected.h"
  23 |        |
  24 |        | #ifdef _OPENMP
  25 |        | #  include <omp.h>
  26 |        | #endif
  27 |        | #include <cstdio>
  28 |        | #include <cstdlib>
  29 |        |
  30 |        | #include "functions.h"
  31 |        | #include "networkscratch.h"
  32 |        |
  33 |        | // Number of threads to use for parallel calculation of Forward and Backward.
  34 |        | #ifdef _OPENMP
  35 |        | const int kNumThreads = 4;
  36 |        | #else
  37 |        | const int kNumThreads = 1;
  38 |        | #endif
  39 |        |
  40 |        | namespace tesseract {
  41 |        |
  42 |        | FullyConnected::FullyConnected(const std::string &name, int ni, int no, NetworkType type)
  43 |      4 |     : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {}
  44 |        |
  45 |        | // Returns the shape output from the network given an input shape (which may
  46 |        | // be partially unknown ie zero).
  47 |      2 | StaticShape FullyConnected::OutputShape(const StaticShape &input_shape) const {
  48 |      2 |   LossType loss_type = LT_NONE;
  49 |      2 |   if (type_ == NT_SOFTMAX) {
  50 |      1 |     loss_type = LT_CTC;
  51 |      1 |   } else if (type_ == NT_SOFTMAX_NO_CTC) {
  52 |      0 |     loss_type = LT_SOFTMAX;
  53 |      1 |   } else if (type_ == NT_LOGISTIC) {
  54 |      0 |     loss_type = LT_LOGISTIC;
  55 |      0 |   }
  56 |      2 |   StaticShape result(input_shape);
  57 |      2 |   result.set_depth(no_);
  58 |      2 |   result.set_loss_type(loss_type);
  59 |      2 |   return result;
  60 |      2 | }
  61 |        |
  62 |        | // Suspends/Enables training by setting the training_ flag.
  63 |      0 | void FullyConnected::SetEnableTraining(TrainingState state) {
  64 |      0 |   if (state == TS_RE_ENABLE) {
  65 |        |     // Enable only from temp disabled.
  66 |      0 |     if (training_ == TS_TEMP_DISABLE) {
  67 |      0 |       training_ = TS_ENABLED;
  68 |      0 |     }
  69 |      0 |   } else if (state == TS_TEMP_DISABLE) {
  70 |        |     // Temp disable only from enabled.
  71 |      0 |     if (training_ == TS_ENABLED) {
  72 |      0 |       training_ = state;
  73 |      0 |     }
  74 |      0 |   } else {
  75 |      0 |     if (state == TS_ENABLED && training_ != TS_ENABLED) {
  76 |      0 |       weights_.InitBackward();
  77 |      0 |     }
  78 |      0 |     training_ = state;
  79 |      0 |   }
  80 |      0 | }
  81 |        |
  82 |        | // Sets up the network for training. Initializes weights using weights of
  83 |        | // scale `range` picked according to the random number generator `randomizer`.
  84 |      0 | int FullyConnected::InitWeights(float range, TRand *randomizer) {
  85 |      0 |   Network::SetRandomizer(randomizer);
  86 |      0 |   num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM), range, randomizer);
  87 |      0 |   return num_weights_;
  88 |      0 | }
  89 |        |
  90 |        | // Recursively searches the network for softmaxes with old_no outputs,
  91 |        | // and remaps their outputs according to code_map. See network.h for details.
  92 |        |
  93 |      0 | int FullyConnected::RemapOutputs(int old_no, const std::vector<int> &code_map) {
  94 |      0 |   if (type_ == NT_SOFTMAX && no_ == old_no) {
  95 |      0 |     num_weights_ = weights_.RemapOutputs(code_map);
  96 |      0 |     no_ = code_map.size();
  97 |      0 |   }
  98 |      0 |   return num_weights_;
  99 |      0 | }
 100 |        |
 101 |        | // Converts a float network to an int network.
 102 |      0 | void FullyConnected::ConvertToInt() {
 103 |      0 |   weights_.ConvertToInt();
 104 |      0 | }
 105 |        |
 106 |        | // Provides debug output on the weights.
 107 |      0 | void FullyConnected::DebugWeights() {
 108 |      0 |   weights_.Debug2D(name_.c_str());
 109 |      0 | }
 110 |        |
 111 |        | // Writes to the given file. Returns false in case of error.
 112 |      0 | bool FullyConnected::Serialize(TFile *fp) const {
 113 |      0 |   if (!Network::Serialize(fp)) {
 114 |      0 |     return false;
 115 |      0 |   }
 116 |      0 |   if (!weights_.Serialize(IsTraining(), fp)) {
 117 |      0 |     return false;
 118 |      0 |   }
 119 |      0 |   return true;
 120 |      0 | }
 121 |        |
 122 |        | // Reads from the given file. Returns false in case of error.
 123 |      4 | bool FullyConnected::DeSerialize(TFile *fp) {
 124 |      4 |   return weights_.DeSerialize(IsTraining(), fp);
 125 |      4 | }
 126 |        |
 127 |        | // Runs forward propagation of activations on the input line.
 128 |        | // See NetworkCpp for a detailed discussion of the arguments.
 129 |        | void FullyConnected::Forward(bool debug, const NetworkIO &input,
 130 |        |                              const TransposedArray *input_transpose, NetworkScratch *scratch,
 131 |   172k |                              NetworkIO *output) {
 132 |   172k |   int width = input.Width();
 133 |   172k |   if (type_ == NT_SOFTMAX) {
 134 |  86.2k |     output->ResizeFloat(input, no_);
 135 |  86.2k |   } else {
 136 |  86.2k |     output->Resize(input, no_);
 137 |  86.2k |   }
 138 |   172k |   SetupForward(input, input_transpose);
 139 |   172k |   std::vector<NetworkScratch::FloatVec> temp_lines(kNumThreads);
 140 |   172k |   std::vector<NetworkScratch::FloatVec> curr_input(kNumThreads);
 141 |   172k |   int ro = no_;
 142 |   172k |   if (IntSimdMatrix::intSimdMatrix) {
 143 |   172k |     ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
 144 |   172k |   }
 145 |   345k |   for (int i = 0; i < kNumThreads; ++i) {
 146 |   172k |     temp_lines[i].Init(ro, scratch);
 147 |   172k |     curr_input[i].Init(ni_, scratch);
 148 |   172k |   }
 149 |        | #ifdef _OPENMP
 150 |        | #  pragma omp parallel for num_threads(kNumThreads)
 151 |        |   for (int t = 0; t < width; ++t) {
 152 |        |     // Thread-local pointer to temporary storage.
 153 |        |     int thread_id = omp_get_thread_num();
 154 |        | #else
 155 |  85.5M |   for (int t = 0; t < width; ++t) {
 156 |        |     // Thread-local pointer to temporary storage.
 157 |  85.3M |     int thread_id = 0;
 158 |  85.3M | #endif
 159 |  85.3M |     TFloat *temp_line = temp_lines[thread_id];
 160 |  85.3M |     if (input.int_mode()) {
 161 |  85.3M |       ForwardTimeStep(input.i(t), t, temp_line);
 162 |  85.3M |     } else {
 163 |      0 |       input.ReadTimeStep(t, curr_input[thread_id]);
 164 |      0 |       ForwardTimeStep(curr_input[thread_id], t, temp_line);
 165 |      0 |     }
 166 |  85.3M |     output->WriteTimeStep(t, temp_line);
 167 |  85.3M |     if (IsTraining() && type_ != NT_SOFTMAX) {
 168 |      0 |       acts_.CopyTimeStepFrom(t, *output, t);
 169 |      0 |     }
 170 |  85.3M |   }
 171 |        |   // Zero all the elements that are in the padding around images that allows
 172 |        |   // multiple different-sized images to exist in a single array.
 173 |        |   // acts_ is only used if this is not a softmax op.
 174 |   172k |   if (IsTraining() && type_ != NT_SOFTMAX) {
 175 |      0 |     acts_.ZeroInvalidElements();
 176 |      0 |   }
 177 |   172k |   output->ZeroInvalidElements();
 178 |        | #if DEBUG_DETAIL > 0
 179 |        |   tprintf("F Output:%s\n", name_.c_str());
 180 |        |   output->Print(10);
 181 |        | #endif
 182 |        | #ifndef GRAPHICS_DISABLED
 183 |        |   if (debug) {
 184 |        |     DisplayForward(*output);
 185 |        |   }
 186 |        | #endif
 187 |   172k | }
 188 |        |
 189 |        | // Components of Forward so FullyConnected can be reused inside LSTM.
 190 |   172k | void FullyConnected::SetupForward(const NetworkIO &input, const TransposedArray *input_transpose) {
 191 |        |   // Softmax output is always float, so save the input type.
 192 |   172k |   int_mode_ = input.int_mode();
 193 |   172k |   if (IsTraining()) {
 194 |      0 |     acts_.Resize(input, no_);
 195 |        |     // Source_ is a transposed copy of input. It isn't needed if provided.
 196 |      0 |     external_source_ = input_transpose;
 197 |      0 |     if (external_source_ == nullptr) {
 198 |      0 |       source_t_.ResizeNoInit(ni_, input.Width());
 199 |      0 |     }
 200 |      0 |   }
 201 |   172k | }
 202 |        |
 203 |  85.3M | void FullyConnected::ForwardTimeStep(int t, TFloat *output_line) {
 204 |  85.3M |   if (type_ == NT_TANH) {
 205 |  84.5M |     FuncInplace<GFunc>(no_, output_line);
 206 |  84.5M |   } else if (type_ == NT_LOGISTIC) {
 207 |      0 |     FuncInplace<FFunc>(no_, output_line);
 208 |   755k |   } else if (type_ == NT_POSCLIP) {
 209 |      0 |     FuncInplace<ClipFFunc>(no_, output_line);
 210 |   755k |   } else if (type_ == NT_SYMCLIP) {
 211 |      0 |     FuncInplace<ClipGFunc>(no_, output_line);
 212 |   755k |   } else if (type_ == NT_RELU) {
 213 |      0 |     FuncInplace<Relu>(no_, output_line);
 214 |   755k |   } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
 215 |   755k |     SoftmaxInPlace(no_, output_line);
 216 |   755k |   } else if (type_ != NT_LINEAR) {
 217 |      0 |     ASSERT_HOST("Invalid fully-connected type!" == nullptr);
 218 |      0 |   }
 219 |  85.3M | }
 220 |        |
 221 |      0 | void FullyConnected::ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line) {
 222 |        |   // input is copied to source_ line-by-line for cache coherency.
 223 |      0 |   if (IsTraining() && external_source_ == nullptr) {
 224 |      0 |     source_t_.WriteStrided(t, d_input);
 225 |      0 |   }
 226 |      0 |   weights_.MatrixDotVector(d_input, output_line);
 227 |      0 |   ForwardTimeStep(t, output_line);
 228 |      0 | }
 229 |        |
 230 |  85.3M | void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line) {
 231 |        |   // input is copied to source_ line-by-line for cache coherency.
 232 |  85.3M |   weights_.MatrixDotVector(i_input, output_line);
 233 |  85.3M |   ForwardTimeStep(t, output_line);
 234 |  85.3M | }
 235 |        |
 236 |        | // Runs backward propagation of errors on the deltas line.
 237 |        | // See NetworkCpp for a detailed discussion of the arguments.
 238 |        | bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
 239 |      0 |                               NetworkIO *back_deltas) {
 240 |        | #ifndef GRAPHICS_DISABLED
 241 |        |   if (debug) {
 242 |        |     DisplayBackward(fwd_deltas);
 243 |        |   }
 244 |        | #endif
 245 |      0 |   back_deltas->Resize(fwd_deltas, ni_);
 246 |      0 |   std::vector<NetworkScratch::FloatVec> errors(kNumThreads);
 247 |      0 |   for (int i = 0; i < kNumThreads; ++i) {
 248 |      0 |     errors[i].Init(no_, scratch);
 249 |      0 |   }
 250 |      0 |   std::vector<NetworkScratch::FloatVec> temp_backprops;
 251 |      0 |   if (needs_to_backprop_) {
 252 |      0 |     temp_backprops.resize(kNumThreads);
 253 |      0 |     for (int i = 0; i < kNumThreads; ++i) {
 254 |      0 |       temp_backprops[i].Init(ni_, scratch);
 255 |      0 |     }
 256 |      0 |   }
 257 |      0 |   int width = fwd_deltas.Width();
 258 |      0 |   NetworkScratch::GradientStore errors_t;
 259 |      0 |   errors_t.Init(no_, width, scratch);
 260 |        | #ifdef _OPENMP
 261 |        | #  pragma omp parallel for num_threads(kNumThreads)
 262 |        |   for (int t = 0; t < width; ++t) {
 263 |        |     int thread_id = omp_get_thread_num();
 264 |        | #else
 265 |      0 |   for (int t = 0; t < width; ++t) {
 266 |      0 |     int thread_id = 0;
 267 |      0 | #endif
 268 |      0 |     TFloat *backprop = nullptr;
 269 |      0 |     if (needs_to_backprop_) {
 270 |      0 |       backprop = temp_backprops[thread_id];
 271 |      0 |     }
 272 |      0 |     TFloat *curr_errors = errors[thread_id];
 273 |      0 |     BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
 274 |      0 |     if (backprop != nullptr) {
 275 |      0 |       back_deltas->WriteTimeStep(t, backprop);
 276 |      0 |     }
 277 |      0 |   }
 278 |      0 |   FinishBackward(*errors_t.get());
 279 |      0 |   if (needs_to_backprop_) {
 280 |      0 |     back_deltas->ZeroInvalidElements();
 281 |        | #if DEBUG_DETAIL > 0
 282 |        |     tprintf("F Backprop:%s\n", name_.c_str());
 283 |        |     back_deltas->Print(10);
 284 |        | #endif
 285 |      0 |     return true;
 286 |      0 |   }
 287 |      0 |   return false; // No point going further back.
 288 |      0 | }
 289 |        |
 290 |        | void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors,
 291 |      0 |                                       TransposedArray *errors_t, TFloat *backprop) {
 292 |      0 |   if (type_ == NT_TANH) {
 293 |      0 |     acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
 294 |      0 |   } else if (type_ == NT_LOGISTIC) {
 295 |      0 |     acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
 296 |      0 |   } else if (type_ == NT_POSCLIP) {
 297 |      0 |     acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
 298 |      0 |   } else if (type_ == NT_SYMCLIP) {
 299 |      0 |     acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
 300 |      0 |   } else if (type_ == NT_RELU) {
 301 |      0 |     acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
 302 |      0 |   } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC || type_ == NT_LINEAR) {
 303 |      0 |     fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
 304 |      0 |   } else {
 305 |      0 |     ASSERT_HOST("Invalid fully-connected type!" == nullptr);
 306 |      0 |   }
 307 |        |   // Generate backprop only if needed by the lower layer.
 308 |      0 |   if (backprop != nullptr) {
 309 |      0 |     weights_.VectorDotMatrix(curr_errors, backprop);
 310 |      0 |   }
 311 |      0 |   errors_t->WriteStrided(t, curr_errors);
 312 |      0 | }
 313 |        |
 314 |      0 | void FullyConnected::FinishBackward(const TransposedArray &errors_t) {
 315 |      0 |   if (external_source_ == nullptr) {
 316 |      0 |     weights_.SumOuterTransposed(errors_t, source_t_, true);
 317 |      0 |   } else {
 318 |      0 |     weights_.SumOuterTransposed(errors_t, *external_source_, true);
 319 |      0 |   }
 320 |      0 | }
 321 |        |
 322 |        | // Updates the weights using the given learning rate, momentum and adam_beta.
 323 |        | // num_samples is used in the adam computation iff use_adam_ is true.
 324 |      0 | void FullyConnected::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
 325 |      0 |   weights_.Update(learning_rate, momentum, adam_beta, num_samples);
 326 |      0 | }
 327 |        |
 328 |        | // Sums the products of weight updates in *this and other, splitting into
 329 |        | // positive (same direction) in *same and negative (different direction) in
 330 |        | // *changed.
 331 |      0 | void FullyConnected::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
 332 |      0 |   ASSERT_HOST(other.type() == type_);
 333 |      0 |   const auto *fc = static_cast<const FullyConnected *>(&other);
 334 |      0 |   weights_.CountAlternators(fc->weights_, same, changed);
 335 |      0 | }
 336 |        |
 337 |        | } // namespace tesseract.
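
Note: the zero counts above are concentrated in the training-side paths (SetEnableTraining, InitWeights, RemapOutputs, Serialize, Backward, BackwardTimeStep, FinishBackward, Update, CountAlternators); only the inference path (DeSerialize plus Forward with int inputs) is exercised. A minimal sketch of a standalone driver that would touch three of those uncovered entry points, using only signatures visible in the listing; the layer sizes, the helpers.h include for TRand, and the build/link setup are assumptions, not part of this report:

// Hypothetical coverage driver (not part of fullyconnected.cpp).
#include "fullyconnected.h" // FullyConnected, NetworkType, TrainingState (per the listing)
#include "helpers.h"        // TRand -- assumed location of the RNG helper

int main() {
  // ni = 16 inputs, no = 8 outputs, tanh non-linearity; the values are arbitrary.
  tesseract::FullyConnected fc("fc_probe", 16, 8, tesseract::NT_TANH);
  fc.SetEnableTraining(tesseract::TS_ENABLED); // line 63: 0 executions in the report
  tesseract::TRand randomizer;
  fc.InitWeights(0.1f, &randomizer);           // line 84: 0 executions in the report
  fc.Update(0.001f, 0.9f, 0.999f, 1);          // line 324: 0 executions in the report
  return 0;
}

Exercising Backward (lines 238-288) would additionally require a NetworkIO of forward deltas and a NetworkScratch, which are not sketched here.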