/src/tesseract/src/lstm/weightmatrix.h
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: weightmatrix.h |
3 | | // Description: Hides distinction between float/int implementations. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2014, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_ |
19 | | #define TESSERACT_LSTM_WEIGHTMATRIX_H_ |
20 | | |
21 | | #include <memory> |
22 | | #include <vector> |
23 | | #include "intsimdmatrix.h" |
24 | | #include "matrix.h" |
25 | | #include "tesstypes.h" |
26 | | #include "tprintf.h" |
27 | | |
28 | | namespace tesseract { |
29 | | |
30 | | // Convenience instantiation of GENERIC_2D_ARRAY<TFloat> with additional |
31 | | // operations to write a strided vector, so the transposed form of the input |
32 | | // is memory-contiguous. |
33 | | class TransposedArray : public GENERIC_2D_ARRAY<TFloat> { |
34 | | public: |
35 | | // Copies the whole input transposed, converted to TFloat, into *this. |
36 | | void Transpose(const GENERIC_2D_ARRAY<TFloat> &input); |
37 | | // Writes a vector of data representing a timestep (gradients or sources). |
38 | | // The data is assumed to be of size1 in size (the strided dimension). |
39 | | ~TransposedArray() override; |
40 | 0 | void WriteStrided(int t, const float *data) { |
41 | 0 | int size1 = dim1(); |
42 | 0 | for (int i = 0; i < size1; ++i) { |
43 | 0 | put(i, t, data[i]); |
44 | 0 | } |
45 | 0 | } |
46 | 0 | void WriteStrided(int t, const double *data) { |
47 | 0 | int size1 = dim1(); |
48 | 0 | for (int i = 0; i < size1; ++i) { |
49 | 0 | put(i, t, data[i]); |
50 | 0 | } |
51 | 0 | } |
52 | | // Prints the first and last num elements of the un-transposed array. |
53 | 0 | void PrintUnTransposed(int num) { |
54 | 0 | int num_features = dim1(); |
55 | 0 | int width = dim2(); |
56 | 0 | for (int y = 0; y < num_features; ++y) { |
57 | 0 | for (int t = 0; t < width; ++t) { |
58 | 0 | if (num == 0 || t < num || t + num >= width) { |
59 | 0 | tprintf(" %g", static_cast<double>((*this)(y, t))); |
60 | 0 | } |
61 | 0 | } |
62 | 0 | tprintf("\n"); |
63 | 0 | } |
64 | 0 | } |
65 | | }; // class TransposedArray |
66 | | |
67 | | // Generic weight matrix for network layers. Can store the matrix as either |
68 | | // an array of floats or int8_t. Provides functions to compute the forward and |
69 | | // backward steps with the matrix and updates to the weights. |
70 | | class WeightMatrix { |
71 | | public: |
72 | 88 | WeightMatrix() : int_mode_(false), use_adam_(false) {} |
73 | | // Sets up the network for training. Initializes weights using weights of |
74 | | // scale `range` picked according to the random number generator `randomizer`. |
75 | | // Note the order is outputs, inputs, as this is the order of indices to |
76 | | // the matrix, so the adjacent elements are multiplied by the input during |
77 | | // a forward operation. |
78 | | int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer); |
79 | | // Changes the number of outputs to the size of the given code_map, copying |
80 | | // the old weight matrix entries for each output from code_map[output] where |
81 | | // non-negative, and uses the mean (over all outputs) of the existing weights |
82 | | // for all outputs with negative code_map entries. Returns the new number of |
83 | | // weights. |
84 | | int RemapOutputs(const std::vector<int> &code_map); |
85 | | |
86 | | // Converts a float network to an int network. Each set of input weights that |
87 | | // corresponds to a single output weight is converted independently: |
88 | | // Compute the max absolute value of the weight set. |
89 | | // Scale so the max absolute value becomes INT8_MAX. |
90 | | // Round to integer. |
91 | | // Store a multiplicative scale factor (as a float) that will reproduce |
92 | | // the original value, subject to rounding errors. |
93 | | void ConvertToInt(); |
94 | | // Returns the size rounded up to an internal factor used by the SIMD |
95 | | // implementation for its input. |
96 | 1.08M | int RoundInputs(int size) const { |
97 | 1.08M | if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) { |
98 | 0 | return size; |
99 | 0 | } |
100 | 1.08M | return IntSimdMatrix::intSimdMatrix->RoundInputs(size); |
101 | 1.08M | } |
102 | | |
103 | | // Accessors. |
104 | 0 | bool is_int_mode() const { |
105 | 0 | return int_mode_; |
106 | 0 | } |
107 | 16 | int NumOutputs() const { |
108 | 16 | return int_mode_ ? wi_.dim1() : wf_.dim1(); |
109 | 16 | } |
110 | | // Provides one set of weights. Only used by peep weight maxpool. |
111 | 0 | const TFloat *GetWeights(int index) const { |
112 | 0 | return wf_[index]; |
113 | 0 | } |
114 | | // Provides access to the deltas (dw_). |
115 | 0 | TFloat GetDW(int i, int j) const { |
116 | 0 | return dw_(i, j); |
117 | 0 | } |
118 | | |
119 | | // Allocates any needed memory for running Backward, and zeroes the deltas, |
120 | | // thus eliminating any existing momentum. |
121 | | void InitBackward(); |
122 | | |
123 | | // Writes to the given file. Returns false in case of error. |
124 | | bool Serialize(bool training, TFile *fp) const; |
125 | | // Reads from the given file. Returns false in case of error. |
126 | | bool DeSerialize(bool training, TFile *fp); |
127 | | // As DeSerialize, but reads an old (float) format WeightMatrix for |
128 | | // backward compatibility. |
129 | | bool DeSerializeOld(bool training, TFile *fp); |
130 | | |
131 | | // Computes matrix.vector v = Wu. |
132 | | // u is of size W.dim2() - 1 and the output v is of size W.dim1(). |
133 | | // u is imagined to have an extra element at the end with value 1, to |
134 | | // implement the bias, but it doesn't actually have it. |
135 | | // Asserts that the call matches what we have. |
136 | | void MatrixDotVector(const TFloat *u, TFloat *v) const; |
137 | | void MatrixDotVector(const int8_t *u, TFloat *v) const; |
138 | | // MatrixDotVector for peep weights, MultiplyAccumulate adds the |
139 | | // component-wise products of *this[0] and v to inout. |
140 | | void MultiplyAccumulate(const TFloat *v, TFloat *inout); |
141 | | // Computes vector.matrix v = uW. |
142 | | // u is of size W.dim1() and the output v is of size W.dim2() - 1. |
143 | | // The last result is discarded, as v is assumed to have an imaginary |
144 | | // last value of 1, as with MatrixDotVector. |
145 | | void VectorDotMatrix(const TFloat *u, TFloat *v) const; |
146 | | // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements |
147 | | // from u and v, starting with u[i][offset] and v[j][offset]. |
148 | | // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0. |
149 | | // Runs parallel if requested. Note that inputs must be transposed. |
150 | | void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel); |
151 | | // Updates the weights using the given learning rate, momentum and adam_beta. |
152 | | // num_samples is used in the Adam correction factor. |
153 | | void Update(float learning_rate, float momentum, float adam_beta, int num_samples); |
154 | | // Adds the dw_ in other to the dw_ is *this. |
155 | | void AddDeltas(const WeightMatrix &other); |
156 | | // Sums the products of weight updates in *this and other, splitting into |
157 | | // positive (same direction) in *same and negative (different direction) in |
158 | | // *changed. |
159 | | void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const; |
160 | | |
161 | | void Debug2D(const char *msg); |
162 | | |
163 | | private: |
164 | | // Choice between float and 8 bit int implementations. |
165 | | GENERIC_2D_ARRAY<TFloat> wf_; |
166 | | GENERIC_2D_ARRAY<int8_t> wi_; |
167 | | // Transposed copy of wf_, used only for Backward, and set with each Update. |
168 | | TransposedArray wf_t_; |
169 | | // Which of wf_ and wi_ are we actually using. |
170 | | bool int_mode_; |
171 | | // True if we are running adam in this weight matrix. |
172 | | bool use_adam_; |
173 | | // If we are using wi_, then scales_ is a factor to restore the row product |
174 | | // with a vector to the correct range. |
175 | | std::vector<TFloat> scales_; |
176 | | // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying |
177 | | // amount to be added to wf_/wi_. |
178 | | GENERIC_2D_ARRAY<TFloat> dw_; |
179 | | GENERIC_2D_ARRAY<TFloat> updates_; |
180 | | // Iff use_adam_, the sum of squares of dw_. The number of samples is |
181 | | // given to Update(). Serialized iff use_adam_. |
182 | | GENERIC_2D_ARRAY<TFloat> dw_sq_sum_; |
183 | | // The weights matrix reorganized in whatever way suits this instance. |
184 | | std::vector<int8_t> shaped_w_; |
185 | | }; |
186 | | |
187 | | } // namespace tesseract. |
188 | | |
189 | | #endif // TESSERACT_LSTM_WEIGHTMATRIX_H_ |