Coverage Report

Created: 2025-11-16 06:50

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/tesseract/src/lstm/weightmatrix.h
Line
Count
Source
1
///////////////////////////////////////////////////////////////////////
2
// File:        weightmatrix.h
3
// Description: Hides distinction between float/int implementations.
4
// Author:      Ray Smith
5
//
6
// (C) Copyright 2014, Google Inc.
7
// Licensed under the Apache License, Version 2.0 (the "License");
8
// you may not use this file except in compliance with the License.
9
// You may obtain a copy of the License at
10
// http://www.apache.org/licenses/LICENSE-2.0
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
///////////////////////////////////////////////////////////////////////
17
18
#ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
19
#define TESSERACT_LSTM_WEIGHTMATRIX_H_
20
21
#include <memory>
22
#include <vector>
23
#include "intsimdmatrix.h"
24
#include "matrix.h"
25
#include "tesstypes.h"
26
#include "tprintf.h"
27
28
namespace tesseract {
29
30
// Convenience instantiation of GENERIC_2D_ARRAY<TFloat> with additional
31
// operations to write a strided vector, so the transposed form of the input
32
// is memory-contiguous.
33
class TransposedArray : public GENERIC_2D_ARRAY<TFloat> {
34
public:
35
  // Copies the whole input transposed, converted to TFloat, into *this.
36
  void Transpose(const GENERIC_2D_ARRAY<TFloat> &input);
37
  // Writes a vector of data representing a timestep (gradients or sources).
38
  // The data is assumed to be of size1 in size (the strided dimension).
39
  ~TransposedArray() override;
40
0
  void WriteStrided(int t, const float *data) {
41
0
    int size1 = dim1();
42
0
    for (int i = 0; i < size1; ++i) {
43
0
      put(i, t, data[i]);
44
0
    }
45
0
  }
46
0
  void WriteStrided(int t, const double *data) {
47
0
    int size1 = dim1();
48
0
    for (int i = 0; i < size1; ++i) {
49
0
      put(i, t, data[i]);
50
0
    }
51
0
  }
52
  // Prints the first and last num elements of the un-transposed array.
53
0
  void PrintUnTransposed(int num) {
54
0
    int num_features = dim1();
55
0
    int width = dim2();
56
0
    for (int y = 0; y < num_features; ++y) {
57
0
      for (int t = 0; t < width; ++t) {
58
0
        if (num == 0 || t < num || t + num >= width) {
59
0
          tprintf(" %g", static_cast<double>((*this)(y, t)));
60
0
        }
61
0
      }
62
0
      tprintf("\n");
63
0
    }
64
0
  }
65
}; // class TransposedArray
66
67
// Generic weight matrix for network layers. Can store the matrix as either
68
// an array of floats or int8_t. Provides functions to compute the forward and
69
// backward steps with the matrix and updates to the weights.
70
class WeightMatrix {
71
public:
72
88
  WeightMatrix() : int_mode_(false), use_adam_(false) {}
73
  // Sets up the network for training. Initializes weights using weights of
74
  // scale `range` picked according to the random number generator `randomizer`.
75
  // Note the order is outputs, inputs, as this is the order of indices to
76
  // the matrix, so the adjacent elements are multiplied by the input during
77
  // a forward operation.
78
  int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer);
79
  // Changes the number of outputs to the size of the given code_map, copying
80
  // the old weight matrix entries for each output from code_map[output] where
81
  // non-negative, and uses the mean (over all outputs) of the existing weights
82
  // for all outputs with negative code_map entries. Returns the new number of
83
  // weights.
84
  int RemapOutputs(const std::vector<int> &code_map);
85
86
  // Converts a float network to an int network. Each set of input weights that
87
  // corresponds to a single output weight is converted independently:
88
  // Compute the max absolute value of the weight set.
89
  // Scale so the max absolute value becomes INT8_MAX.
90
  // Round to integer.
91
  // Store a multiplicative scale factor (as a float) that will reproduce
92
  // the original value, subject to rounding errors.
93
  void ConvertToInt();
94
  // Returns the size rounded up to an internal factor used by the SIMD
95
  // implementation for its input.
96
1.08M
  int RoundInputs(int size) const {
97
1.08M
    if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) {
98
0
      return size;
99
0
    }
100
1.08M
    return IntSimdMatrix::intSimdMatrix->RoundInputs(size);
101
1.08M
  }
102
103
  // Accessors.
104
0
  bool is_int_mode() const {
105
0
    return int_mode_;
106
0
  }
107
16
  int NumOutputs() const {
108
16
    return int_mode_ ? wi_.dim1() : wf_.dim1();
109
16
  }
110
  // Provides one set of weights. Only used by peep weight maxpool.
111
0
  const TFloat *GetWeights(int index) const {
112
0
    return wf_[index];
113
0
  }
114
  // Provides access to the deltas (dw_).
115
0
  TFloat GetDW(int i, int j) const {
116
0
    return dw_(i, j);
117
0
  }
118
119
  // Allocates any needed memory for running Backward, and zeroes the deltas,
120
  // thus eliminating any existing momentum.
121
  void InitBackward();
122
123
  // Writes to the given file. Returns false in case of error.
124
  bool Serialize(bool training, TFile *fp) const;
125
  // Reads from the given file. Returns false in case of error.
126
  bool DeSerialize(bool training, TFile *fp);
127
  // As DeSerialize, but reads an old (float) format WeightMatrix for
128
  // backward compatibility.
129
  bool DeSerializeOld(bool training, TFile *fp);
130
131
  // Computes matrix.vector v = Wu.
132
  // u is of size W.dim2() - 1 and the output v is of size W.dim1().
133
  // u is imagined to have an extra element at the end with value 1, to
134
  // implement the bias, but it doesn't actually have it.
135
  // Asserts that the call matches what we have.
136
  void MatrixDotVector(const TFloat *u, TFloat *v) const;
137
  void MatrixDotVector(const int8_t *u, TFloat *v) const;
138
  // MatrixDotVector for peep weights, MultiplyAccumulate adds the
139
  // component-wise products of *this[0] and v to inout.
140
  void MultiplyAccumulate(const TFloat *v, TFloat *inout);
141
  // Computes vector.matrix v = uW.
142
  // u is of size W.dim1() and the output v is of size W.dim2() - 1.
143
  // The last result is discarded, as v is assumed to have an imaginary
144
  // last value of 1, as with MatrixDotVector.
145
  void VectorDotMatrix(const TFloat *u, TFloat *v) const;
146
  // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
147
  // from u and v, starting with u[i][offset] and v[j][offset].
148
  // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
149
  // Runs parallel if requested. Note that inputs must be transposed.
150
  void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel);
151
  // Updates the weights using the given learning rate, momentum and adam_beta.
152
  // num_samples is used in the Adam correction factor.
153
  void Update(float learning_rate, float momentum, float adam_beta, int num_samples);
154
  // Adds the dw_ in other to the dw_ in *this.
155
  void AddDeltas(const WeightMatrix &other);
156
  // Sums the products of weight updates in *this and other, splitting into
157
  // positive (same direction) in *same and negative (different direction) in
158
  // *changed.
159
  void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const;
160
161
  void Debug2D(const char *msg);
162
163
private:
164
  // Choice between float and 8 bit int implementations.
165
  GENERIC_2D_ARRAY<TFloat> wf_;
166
  GENERIC_2D_ARRAY<int8_t> wi_;
167
  // Transposed copy of wf_, used only for Backward, and set with each Update.
168
  TransposedArray wf_t_;
169
  // Which of wf_ and wi_ are we actually using.
170
  bool int_mode_;
171
  // True if we are running adam in this weight matrix.
172
  bool use_adam_;
173
  // If we are using wi_, then scales_ is a factor to restore the row product
174
  // with a vector to the correct range.
175
  std::vector<TFloat> scales_;
176
  // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
177
  // amount to be added to wf_/wi_.
178
  GENERIC_2D_ARRAY<TFloat> dw_;
179
  GENERIC_2D_ARRAY<TFloat> updates_;
180
  // Iff use_adam_, the sum of squares of dw_. The number of samples is
181
  // given to Update(). Serialized iff use_adam_.
182
  GENERIC_2D_ARRAY<TFloat> dw_sq_sum_;
183
  // The weights matrix reorganized in whatever way suits this instance.
184
  std::vector<int8_t> shaped_w_;
185
};
186
187
} // namespace tesseract.
188
189
#endif // TESSERACT_LSTM_WEIGHTMATRIX_H_