/src/tesseract/src/lstm/maxpool.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: maxpool.cpp |
3 | | // Description: Standard Max-Pooling layer. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2014, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #include "maxpool.h" |
19 | | |
20 | | namespace tesseract { |
21 | | |
22 | | Maxpool::Maxpool(const std::string &name, int ni, int x_scale, int y_scale) |
23 | 4 | : Reconfig(name, ni, x_scale, y_scale) { |
24 | 4 | type_ = NT_MAXPOOL; |
25 | 4 | no_ = ni; |
26 | 4 | } |
27 | | |
28 | | // Reads from the given file. Returns false in case of error. |
29 | 4 | bool Maxpool::DeSerialize(TFile *fp) { |
30 | 4 | bool result = Reconfig::DeSerialize(fp); |
31 | 4 | no_ = ni_; |
32 | 4 | return result; |
33 | 4 | } |
34 | | |
35 | | // Runs forward propagation of activations on the input line. |
36 | | // See NetworkCpp for a detailed discussion of the arguments. |
37 | | void Maxpool::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, |
38 | 256k | NetworkScratch *scratch, NetworkIO *output) { |
39 | 256k | output->ResizeScaled(input, x_scale_, y_scale_, no_); |
40 | 256k | maxes_.ResizeNoInit(output->Width(), ni_); |
41 | 256k | back_map_ = input.stride_map(); |
42 | | |
43 | 256k | StrideMap::Index dest_index(output->stride_map()); |
44 | 33.8M | do { |
45 | 33.8M | int out_t = dest_index.t(); |
46 | 33.8M | StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH), |
47 | 33.8M | dest_index.index(FD_HEIGHT) * y_scale_, |
48 | 33.8M | dest_index.index(FD_WIDTH) * x_scale_); |
49 | | // Find the max input out of x_scale_ groups of y_scale_ inputs. |
50 | | // Do it independently for each input dimension. |
51 | 33.8M | int *max_line = maxes_[out_t]; |
52 | 33.8M | int in_t = src_index.t(); |
53 | 33.8M | output->CopyTimeStepFrom(out_t, input, in_t); |
54 | 575M | for (int i = 0; i < ni_; ++i) { |
55 | 542M | max_line[i] = in_t; |
56 | 542M | } |
57 | 135M | for (int x = 0; x < x_scale_; ++x) { |
58 | 406M | for (int y = 0; y < y_scale_; ++y) { |
59 | 304M | StrideMap::Index src_xy(src_index); |
60 | 304M | if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) { |
61 | 304M | output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line); |
62 | 304M | } |
63 | 304M | } |
64 | 101M | } |
65 | 33.8M | } while (dest_index.Increment()); |
66 | 256k | } |
67 | | |
68 | | // Runs backward propagation of errors on the deltas line. |
69 | | // See NetworkCpp for a detailed discussion of the arguments. |
70 | | bool Maxpool::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
71 | 0 | NetworkIO *back_deltas) { |
72 | 0 | back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_); |
73 | 0 | back_deltas->MaxpoolBackward(fwd_deltas, maxes_); |
74 | 0 | return true; |
75 | 0 | } |
76 | | |
77 | | } // namespace tesseract. |