/src/tesseract/src/lstm/convolve.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: convolve.cpp |
3 | | // Description: Convolutional layer that stacks the inputs over its rectangle |
4 | | // and pulls in random data to fill out-of-input inputs. |
5 | | // Output is therefore same size as its input, but deeper. |
6 | | // Author: Ray Smith |
7 | | // |
8 | | // (C) Copyright 2014, Google Inc. |
9 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
10 | | // you may not use this file except in compliance with the License. |
11 | | // You may obtain a copy of the License at |
12 | | // http://www.apache.org/licenses/LICENSE-2.0 |
13 | | // Unless required by applicable law or agreed to in writing, software |
14 | | // distributed under the License is distributed on an "AS IS" BASIS, |
15 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
16 | | // See the License for the specific language governing permissions and |
17 | | // limitations under the License. |
18 | | /////////////////////////////////////////////////////////////////////// |
19 | | |
#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "convolve.h"

#include "networkscratch.h"
#include "serialis.h"

#include <cstdint>
28 | | |
29 | | namespace tesseract { |
30 | | |
31 | | Convolve::Convolve(const std::string &name, int ni, int half_x, int half_y) |
32 | 4 | : Network(NT_CONVOLVE, name, ni, ni * (2 * half_x + 1) * (2 * half_y + 1)) |
33 | 4 | , half_x_(half_x) |
34 | 4 | , half_y_(half_y) {} |
35 | | |
36 | | // Writes to the given file. Returns false in case of error. |
37 | 0 | bool Convolve::Serialize(TFile *fp) const { |
38 | 0 | return Network::Serialize(fp) && fp->Serialize(&half_x_) && fp->Serialize(&half_y_); |
39 | 0 | } |
40 | | |
41 | | // Reads from the given file. Returns false in case of error. |
42 | 4 | bool Convolve::DeSerialize(TFile *fp) { |
43 | 4 | if (!fp->DeSerialize(&half_x_)) { |
44 | 0 | return false; |
45 | 0 | } |
46 | 4 | if (!fp->DeSerialize(&half_y_)) { |
47 | 0 | return false; |
48 | 0 | } |
49 | 4 | no_ = ni_ * (2 * half_x_ + 1) * (2 * half_y_ + 1); |
50 | 4 | return true; |
51 | 4 | } |
52 | | |
53 | | // Runs forward propagation of activations on the input line. |
54 | | // See NetworkCpp for a detailed discussion of the arguments. |
55 | | void Convolve::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, |
56 | 274k | NetworkScratch *scratch, NetworkIO *output) { |
57 | 274k | output->Resize(input, no_); |
58 | 274k | int y_scale = 2 * half_y_ + 1; |
59 | 274k | StrideMap::Index dest_index(output->stride_map()); |
60 | 369M | do { |
61 | | // Stack x_scale groups of y_scale * ni_ inputs together. |
62 | 369M | int t = dest_index.t(); |
63 | 369M | int out_ix = 0; |
64 | 1.47G | for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) { |
65 | 1.10G | StrideMap::Index x_index(dest_index); |
66 | 1.10G | if (!x_index.AddOffset(x, FD_WIDTH)) { |
67 | | // This x is outside the image. |
68 | 19.7M | output->Randomize(t, out_ix, y_scale * ni_, randomizer_); |
69 | 1.08G | } else { |
70 | 1.08G | int out_iy = out_ix; |
71 | 4.35G | for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) { |
72 | 3.26G | StrideMap::Index y_index(x_index); |
73 | 3.26G | if (!y_index.AddOffset(y, FD_HEIGHT)) { |
74 | | // This y is outside the image. |
75 | 60.4M | output->Randomize(t, out_iy, ni_, randomizer_); |
76 | 3.20G | } else { |
77 | 3.20G | output->CopyTimeStepGeneral(t, out_iy, ni_, input, y_index.t(), 0); |
78 | 3.20G | } |
79 | 3.26G | } |
80 | 1.08G | } |
81 | 1.10G | } |
82 | 369M | } while (dest_index.Increment()); |
83 | | #ifndef GRAPHICS_DISABLED |
84 | | if (debug) { |
85 | | DisplayForward(*output); |
86 | | } |
87 | | #endif |
88 | 274k | } |
89 | | |
90 | | // Runs backward propagation of errors on the deltas line. |
91 | | // See NetworkCpp for a detailed discussion of the arguments. |
92 | | bool Convolve::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
93 | 0 | NetworkIO *back_deltas) { |
94 | 0 | back_deltas->Resize(fwd_deltas, ni_); |
95 | 0 | NetworkScratch::IO delta_sum; |
96 | 0 | delta_sum.ResizeFloat(fwd_deltas, ni_, scratch); |
97 | 0 | delta_sum->Zero(); |
98 | 0 | int y_scale = 2 * half_y_ + 1; |
99 | 0 | StrideMap::Index src_index(fwd_deltas.stride_map()); |
100 | 0 | do { |
101 | | // Stack x_scale groups of y_scale * ni_ inputs together. |
102 | 0 | int t = src_index.t(); |
103 | 0 | int out_ix = 0; |
104 | 0 | for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) { |
105 | 0 | StrideMap::Index x_index(src_index); |
106 | 0 | if (x_index.AddOffset(x, FD_WIDTH)) { |
107 | 0 | int out_iy = out_ix; |
108 | 0 | for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) { |
109 | 0 | StrideMap::Index y_index(x_index); |
110 | 0 | if (y_index.AddOffset(y, FD_HEIGHT)) { |
111 | 0 | fwd_deltas.AddTimeStepPart(t, out_iy, ni_, delta_sum->f(y_index.t())); |
112 | 0 | } |
113 | 0 | } |
114 | 0 | } |
115 | 0 | } |
116 | 0 | } while (src_index.Increment()); |
117 | 0 | back_deltas->CopyAll(*delta_sum); |
118 | 0 | return true; |
119 | 0 | } |
120 | | |
121 | | } // namespace tesseract. |