/src/tesseract/src/lstm/fullyconnected.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: fullyconnected.cpp |
3 | | // Description: Simple feed-forward layer with various non-linearities. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2014, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #ifdef HAVE_CONFIG_H |
19 | | # include "config_auto.h" |
20 | | #endif |
21 | | |
22 | | #include "fullyconnected.h" |
23 | | |
24 | | #ifdef _OPENMP |
25 | | # include <omp.h> |
26 | | #endif |
27 | | #include <cstdio> |
28 | | #include <cstdlib> |
29 | | |
30 | | #include "functions.h" |
31 | | #include "networkscratch.h" |
32 | | |
33 | | // Number of threads to use for parallel calculation of Forward and Backward. |
34 | | #ifdef _OPENMP |
35 | | const int kNumThreads = 4; |
36 | | #else |
37 | | const int kNumThreads = 1; |
38 | | #endif |
39 | | |
40 | | namespace tesseract { |
41 | | |
42 | | FullyConnected::FullyConnected(const std::string &name, int ni, int no, NetworkType type) |
43 | 4 | : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {} |
44 | | |
45 | | // Returns the shape output from the network given an input shape (which may |
46 | | // be partially unknown, i.e. zero). |
47 | 2 | StaticShape FullyConnected::OutputShape(const StaticShape &input_shape) const { |
48 | 2 | LossType loss_type = LT_NONE; |
49 | 2 | if (type_ == NT_SOFTMAX) { |
50 | 1 | loss_type = LT_CTC; |
51 | 1 | } else if (type_ == NT_SOFTMAX_NO_CTC) { |
52 | 0 | loss_type = LT_SOFTMAX; |
53 | 1 | } else if (type_ == NT_LOGISTIC) { |
54 | 0 | loss_type = LT_LOGISTIC; |
55 | 0 | } |
56 | 2 | StaticShape result(input_shape); |
57 | 2 | result.set_depth(no_); |
58 | 2 | result.set_loss_type(loss_type); |
59 | 2 | return result; |
60 | 2 | } |
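
A hypothetical caller of OutputShape, showing the depth/loss mapping (the layer object and the numbers are made up for illustration; only the setters visible above and the matching getters are assumed):

    StaticShape in_shape;                              // dims left at 0 == unknown
    in_shape.set_depth(96);                            // incoming feature depth
    StaticShape out_shape = softmax_layer.OutputShape(in_shape);
    // For an NT_SOFTMAX layer with no_ == 111:
    //   out_shape.depth()     == 111
    //   out_shape.loss_type() == LT_CTC
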
61 | | |
62 | | // Suspends/Enables training by setting the training_ flag. |
63 | 0 | void FullyConnected::SetEnableTraining(TrainingState state) { |
64 | 0 | if (state == TS_RE_ENABLE) { |
65 | | // Enable only from temp disabled. |
66 | 0 | if (training_ == TS_TEMP_DISABLE) { |
67 | 0 | training_ = TS_ENABLED; |
68 | 0 | } |
69 | 0 | } else if (state == TS_TEMP_DISABLE) { |
70 | | // Temp disable only from enabled. |
71 | 0 | if (training_ == TS_ENABLED) { |
72 | 0 | training_ = state; |
73 | 0 | } |
74 | 0 | } else { |
75 | 0 | if (state == TS_ENABLED && training_ != TS_ENABLED) { |
76 | 0 | weights_.InitBackward(); |
77 | 0 | } |
78 | 0 | training_ = state; |
79 | 0 | } |
80 | 0 | } |
81 | | |
82 | | // Sets up the network for training. Initializes the weights to random values |
83 | | // of scale `range`, picked according to the random number generator `randomizer`. |
84 | 0 | int FullyConnected::InitWeights(float range, TRand *randomizer) { |
85 | 0 | Network::SetRandomizer(randomizer); |
86 | 0 | num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM), range, randomizer); |
87 | 0 | return num_weights_; |
88 | 0 | } |
89 | | |
90 | | // Recursively searches the network for softmaxes with old_no outputs, |
91 | | // and remaps their outputs according to code_map. See network.h for details. |
92 | | |
93 | 0 | int FullyConnected::RemapOutputs(int old_no, const std::vector<int> &code_map) { |
94 | 0 | if (type_ == NT_SOFTMAX && no_ == old_no) { |
95 | 0 | num_weights_ = weights_.RemapOutputs(code_map); |
96 | 0 | no_ = code_map.size(); |
97 | 0 | } |
98 | 0 | return num_weights_; |
99 | 0 | } |
100 | | |
101 | | // Converts a float network to an int network. |
102 | 0 | void FullyConnected::ConvertToInt() { |
103 | 0 | weights_.ConvertToInt(); |
104 | 0 | } |
105 | | |
106 | | // Provides debug output on the weights. |
107 | 0 | void FullyConnected::DebugWeights() { |
108 | 0 | weights_.Debug2D(name_.c_str()); |
109 | 0 | } |
110 | | |
111 | | // Writes to the given file. Returns false in case of error. |
112 | 0 | bool FullyConnected::Serialize(TFile *fp) const { |
113 | 0 | if (!Network::Serialize(fp)) { |
114 | 0 | return false; |
115 | 0 | } |
116 | 0 | if (!weights_.Serialize(IsTraining(), fp)) { |
117 | 0 | return false; |
118 | 0 | } |
119 | 0 | return true; |
120 | 0 | } |
121 | | |
122 | | // Reads from the given file. Returns false in case of error. |
123 | 4 | bool FullyConnected::DeSerialize(TFile *fp) { |
124 | 4 | return weights_.DeSerialize(IsTraining(), fp); |
125 | 4 | } |
126 | | |
127 | | // Runs forward propagation of activations on the input line. |
128 | | // See NetworkCpp for a detailed discussion of the arguments. |
129 | | void FullyConnected::Forward(bool debug, const NetworkIO &input, |
130 | | const TransposedArray *input_transpose, NetworkScratch *scratch, |
131 | 172k | NetworkIO *output) { |
132 | 172k | int width = input.Width(); |
133 | 172k | if (type_ == NT_SOFTMAX) { |
134 | 86.2k | output->ResizeFloat(input, no_); |
135 | 86.2k | } else { |
136 | 86.2k | output->Resize(input, no_); |
137 | 86.2k | } |
138 | 172k | SetupForward(input, input_transpose); |
139 | 172k | std::vector<NetworkScratch::FloatVec> temp_lines(kNumThreads); |
140 | 172k | std::vector<NetworkScratch::FloatVec> curr_input(kNumThreads); |
141 | 172k | int ro = no_; |
142 | 172k | if (IntSimdMatrix::intSimdMatrix) { |
143 | 172k | ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro); |
144 | 172k | } |
145 | 345k | for (int i = 0; i < kNumThreads; ++i) { |
146 | 172k | temp_lines[i].Init(ro, scratch); |
147 | 172k | curr_input[i].Init(ni_, scratch); |
148 | 172k | } |
149 | | #ifdef _OPENMP |
150 | | # pragma omp parallel for num_threads(kNumThreads) |
151 | | for (int t = 0; t < width; ++t) { |
152 | | // Thread-local pointer to temporary storage. |
153 | | int thread_id = omp_get_thread_num(); |
154 | | #else |
155 | 85.5M | for (int t = 0; t < width; ++t) { |
156 | | // Thread-local pointer to temporary storage. |
157 | 85.3M | int thread_id = 0; |
158 | 85.3M | #endif |
159 | 85.3M | TFloat *temp_line = temp_lines[thread_id]; |
160 | 85.3M | if (input.int_mode()) { |
161 | 85.3M | ForwardTimeStep(input.i(t), t, temp_line); |
162 | 85.3M | } else { |
163 | 0 | input.ReadTimeStep(t, curr_input[thread_id]); |
164 | 0 | ForwardTimeStep(curr_input[thread_id], t, temp_line); |
165 | 0 | } |
166 | 85.3M | output->WriteTimeStep(t, temp_line); |
167 | 85.3M | if (IsTraining() && type_ != NT_SOFTMAX) { |
168 | 0 | acts_.CopyTimeStepFrom(t, *output, t); |
169 | 0 | } |
170 | 85.3M | } |
171 | | // Zero all the elements that are in the padding around images, which allows |
172 | | // multiple different-sized images to exist in a single array. |
173 | | // acts_ is only used if this is not a softmax op. |
174 | 172k | if (IsTraining() && type_ != NT_SOFTMAX) { |
175 | 0 | acts_.ZeroInvalidElements(); |
176 | 0 | } |
177 | 172k | output->ZeroInvalidElements(); |
178 | | #if DEBUG_DETAIL > 0 |
179 | | tprintf("F Output:%s\n", name_.c_str()); |
180 | | output->Print(10); |
181 | | #endif |
182 | | #ifndef GRAPHICS_DISABLED |
183 | | if (debug) { |
184 | | DisplayForward(*output); |
185 | | } |
186 | | #endif |
187 | 172k | } |
188 | | |
189 | | // Components of Forward so FullyConnected can be reused inside LSTM. |
190 | 172k | void FullyConnected::SetupForward(const NetworkIO &input, const TransposedArray *input_transpose) { |
191 | | // Softmax output is always float, so save the input type. |
192 | 172k | int_mode_ = input.int_mode(); |
193 | 172k | if (IsTraining()) { |
194 | 0 | acts_.Resize(input, no_); |
195 | | // Source_ is a transposed copy of input. It isn't needed if provided. |
196 | 0 | external_source_ = input_transpose; |
197 | 0 | if (external_source_ == nullptr) { |
198 | 0 | source_t_.ResizeNoInit(ni_, input.Width()); |
199 | 0 | } |
200 | 0 | } |
201 | 172k | } |
202 | | |
203 | 85.3M | void FullyConnected::ForwardTimeStep(int t, TFloat *output_line) { |
204 | 85.3M | if (type_ == NT_TANH) { |
205 | 84.5M | FuncInplace<GFunc>(no_, output_line); |
206 | 84.5M | } else if (type_ == NT_LOGISTIC) { |
207 | 0 | FuncInplace<FFunc>(no_, output_line); |
208 | 755k | } else if (type_ == NT_POSCLIP) { |
209 | 0 | FuncInplace<ClipFFunc>(no_, output_line); |
210 | 755k | } else if (type_ == NT_SYMCLIP) { |
211 | 0 | FuncInplace<ClipGFunc>(no_, output_line); |
212 | 755k | } else if (type_ == NT_RELU) { |
213 | 0 | FuncInplace<Relu>(no_, output_line); |
214 | 755k | } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) { |
215 | 755k | SoftmaxInPlace(no_, output_line); |
216 | 755k | } else if (type_ != NT_LINEAR) { |
217 | 0 | ASSERT_HOST("Invalid fully-connected type!" == nullptr); |
218 | 0 | } |
219 | 85.3M | } |
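
The FuncInplace/SoftmaxInPlace calls above come from functions.h. As a rough scalar sketch of what each branch applies (illustrative only; the real functors may use table-driven approximations and operate on TFloat):

    #include <algorithm>
    #include <cmath>

    // Approximate scalar equivalents of the activation branches above.
    static double TanhAct(double x)     { return std::tanh(x); }                // NT_TANH (GFunc)
    static double LogisticAct(double x) { return 1.0 / (1.0 + std::exp(-x)); }  // NT_LOGISTIC (FFunc)
    static double ReluAct(double x)     { return std::max(0.0, x); }            // NT_RELU
    static double PosClipAct(double x)  { return std::clamp(x, 0.0, 1.0); }     // NT_POSCLIP (ClipFFunc)
    static double SymClipAct(double x)  { return std::clamp(x, -1.0, 1.0); }    // NT_SYMCLIP (ClipGFunc)

    // Softmax normalizes the whole output line rather than mapping elementwise.
    static void SoftmaxSketch(int n, double *line) {
      double max_val = *std::max_element(line, line + n);  // subtract max for numerical stability
      double total = 0.0;
      for (int i = 0; i < n; ++i) {
        line[i] = std::exp(line[i] - max_val);
        total += line[i];
      }
      for (int i = 0; i < n; ++i) {
        line[i] /= total;
      }
    }
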
220 | | |
221 | 0 | void FullyConnected::ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line) { |
222 | | // input is copied to source_ line-by-line for cache coherency. |
223 | 0 | if (IsTraining() && external_source_ == nullptr) { |
224 | 0 | source_t_.WriteStrided(t, d_input); |
225 | 0 | } |
226 | 0 | weights_.MatrixDotVector(d_input, output_line); |
227 | 0 | ForwardTimeStep(t, output_line); |
228 | 0 | } |
229 | | |
230 | 85.3M | void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line) { |
231 | | // input is copied to source_ line-by-line for cache coherency. |
232 | 85.3M | weights_.MatrixDotVector(i_input, output_line); |
233 | 85.3M | ForwardTimeStep(t, output_line); |
234 | 85.3M | } |
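
Both ForwardTimeStep overloads reduce to output = W * [input, 1] followed by the non-linearity, where W is the no_ x (ni_ + 1) matrix allocated in InitWeights and the extra column acts as the bias. A naive reference sketch, not the SIMD/int8 code in WeightMatrix:

    // Naive reference for the matrix * [input, 1] product (row-major W; the
    // last column of each row is the bias applied to an implicit input of 1).
    static void MatrixDotVectorSketch(int no, int ni, const double *W,
                                      const double *input, double *output) {
      for (int r = 0; r < no; ++r) {
        const double *row = W + r * (ni + 1);
        double sum = row[ni];  // bias
        for (int c = 0; c < ni; ++c) {
          sum += row[c] * input[c];
        }
        output[r] = sum;
      }
    }
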
235 | | |
236 | | // Runs backward propagation of errors on the deltas line. |
237 | | // See NetworkCpp for a detailed discussion of the arguments. |
238 | | bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
239 | 0 | NetworkIO *back_deltas) { |
240 | | #ifndef GRAPHICS_DISABLED |
241 | | if (debug) { |
242 | | DisplayBackward(fwd_deltas); |
243 | | } |
244 | | #endif |
245 | 0 | back_deltas->Resize(fwd_deltas, ni_); |
246 | 0 | std::vector<NetworkScratch::FloatVec> errors(kNumThreads); |
247 | 0 | for (int i = 0; i < kNumThreads; ++i) { |
248 | 0 | errors[i].Init(no_, scratch); |
249 | 0 | } |
250 | 0 | std::vector<NetworkScratch::FloatVec> temp_backprops; |
251 | 0 | if (needs_to_backprop_) { |
252 | 0 | temp_backprops.resize(kNumThreads); |
253 | 0 | for (int i = 0; i < kNumThreads; ++i) { |
254 | 0 | temp_backprops[i].Init(ni_, scratch); |
255 | 0 | } |
256 | 0 | } |
257 | 0 | int width = fwd_deltas.Width(); |
258 | 0 | NetworkScratch::GradientStore errors_t; |
259 | 0 | errors_t.Init(no_, width, scratch); |
260 | | #ifdef _OPENMP |
261 | | # pragma omp parallel for num_threads(kNumThreads) |
262 | | for (int t = 0; t < width; ++t) { |
263 | | int thread_id = omp_get_thread_num(); |
264 | | #else |
265 | 0 | for (int t = 0; t < width; ++t) { |
266 | 0 | int thread_id = 0; |
267 | 0 | #endif |
268 | 0 | TFloat *backprop = nullptr; |
269 | 0 | if (needs_to_backprop_) { |
270 | 0 | backprop = temp_backprops[thread_id]; |
271 | 0 | } |
272 | 0 | TFloat *curr_errors = errors[thread_id]; |
273 | 0 | BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop); |
274 | 0 | if (backprop != nullptr) { |
275 | 0 | back_deltas->WriteTimeStep(t, backprop); |
276 | 0 | } |
277 | 0 | } |
278 | 0 | FinishBackward(*errors_t.get()); |
279 | 0 | if (needs_to_backprop_) { |
280 | 0 | back_deltas->ZeroInvalidElements(); |
281 | | #if DEBUG_DETAIL > 0 |
282 | | tprintf("F Backprop:%s\n", name_.c_str()); |
283 | | back_deltas->Print(10); |
284 | | #endif |
285 | 0 | return true; |
286 | 0 | } |
287 | 0 | return false; // No point going further back. |
288 | 0 | } |
289 | | |
290 | | void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, |
291 | 0 | TransposedArray *errors_t, TFloat *backprop) { |
292 | 0 | if (type_ == NT_TANH) { |
293 | 0 | acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors); |
294 | 0 | } else if (type_ == NT_LOGISTIC) { |
295 | 0 | acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors); |
296 | 0 | } else if (type_ == NT_POSCLIP) { |
297 | 0 | acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors); |
298 | 0 | } else if (type_ == NT_SYMCLIP) { |
299 | 0 | acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors); |
300 | 0 | } else if (type_ == NT_RELU) { |
301 | 0 | acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors); |
302 | 0 | } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC || type_ == NT_LINEAR) { |
303 | 0 | fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors. |
304 | 0 | } else { |
305 | 0 | ASSERT_HOST("Invalid fully-connected type!" == nullptr); |
306 | 0 | } |
307 | | // Generate backprop only if needed by the lower layer. |
308 | 0 | if (backprop != nullptr) { |
309 | 0 | weights_.VectorDotMatrix(curr_errors, backprop); |
310 | 0 | } |
311 | 0 | errors_t->WriteStrided(t, curr_errors); |
312 | 0 | } |
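
The FuncMultiply calls combine the incoming deltas with the derivative of the activation, evaluated from the stored activations acts_ (i.e. from y = f(x), not from the pre-activation x). A scalar sketch (illustrative names; the *Prime functors live in functions.h):

    // Derivatives written in terms of the activation value y = f(x), which is
    // what acts_ holds after Forward.
    static double TanhPrimeSketch(double y)     { return 1.0 - y * y; }         // GPrime
    static double LogisticPrimeSketch(double y) { return y * (1.0 - y); }       // FPrime
    static double ReluPrimeSketch(double y)     { return y > 0.0 ? 1.0 : 0.0; }

    // acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors) then amounts to
    //   curr_errors[i] = TanhPrimeSketch(acts_(t, i)) * fwd_deltas(t, i)
    // for each output i, while the softmax/linear branch passes the deltas
    // through unchanged.
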
313 | | |
314 | 0 | void FullyConnected::FinishBackward(const TransposedArray &errors_t) { |
315 | 0 | if (external_source_ == nullptr) { |
316 | 0 | weights_.SumOuterTransposed(errors_t, source_t_, true); |
317 | 0 | } else { |
318 | 0 | weights_.SumOuterTransposed(errors_t, *external_source_, true); |
319 | 0 | } |
320 | 0 | } |
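
SumOuterTransposed accumulates the weight gradient as the sum over timesteps of the outer product of the per-step error vector with the (bias-extended) input vector, taken from either source_t_ or the externally supplied transpose. A naive sketch under the same W layout as the forward sketch above (not the real TransposedArray storage):

    // dW[r][c] += sum over t of errors[t][r] * inputs[t][c], with an extra
    // bias column fed by an implicit input of 1.
    static void SumOuterSketch(int no, int ni, int width,
                               const double *errors,  // width x no, row-major
                               const double *inputs,  // width x ni, row-major
                               double *dW) {          // no x (ni + 1), row-major
      for (int t = 0; t < width; ++t) {
        for (int r = 0; r < no; ++r) {
          const double err = errors[t * no + r];
          double *dw_row = dW + r * (ni + 1);
          for (int c = 0; c < ni; ++c) {
            dw_row[c] += err * inputs[t * ni + c];
          }
          dw_row[ni] += err;  // bias gradient
        }
      }
    }
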
321 | | |
322 | | // Updates the weights using the given learning rate, momentum and adam_beta. |
323 | | // num_samples is used in the adam computation iff use_adam_ is true. |
324 | 0 | void FullyConnected::Update(float learning_rate, float momentum, float adam_beta, int num_samples) { |
325 | 0 | weights_.Update(learning_rate, momentum, adam_beta, num_samples); |
326 | 0 | } |
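
Update forwards to WeightMatrix::Update. As a generic sketch of a momentum step over the gradient accumulated during Backward (not the exact arithmetic or sign conventions in weightmatrix.cpp, and omitting the Adam branch selected by NF_ADAM):

    // Generic momentum-SGD step; gradient, velocity and weights are all
    // no * (ni + 1) arrays flattened row-major.
    static void MomentumUpdateSketch(int n, float learning_rate, float momentum,
                                     const double *gradient, double *velocity,
                                     double *weights) {
      for (int i = 0; i < n; ++i) {
        velocity[i] = momentum * velocity[i] + learning_rate * gradient[i];
        weights[i] += velocity[i];  // sign depends on how the gradient was accumulated
      }
    }
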
327 | | |
328 | | // Sums the products of weight updates in *this and other, splitting into |
329 | | // positive (same direction) in *same and negative (different direction) in |
330 | | // *changed. |
331 | 0 | void FullyConnected::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const { |
332 | 0 | ASSERT_HOST(other.type() == type_); |
333 | 0 | const auto *fc = static_cast<const FullyConnected *>(&other); |
334 | 0 | weights_.CountAlternators(fc->weights_, same, changed); |
335 | 0 | } |
336 | | |
337 | | } // namespace tesseract. |