/src/tesseract/src/lstm/convolve.cpp
Line  | Count  | Source (jump to first uncovered line)  | 
1  |  | ///////////////////////////////////////////////////////////////////////  | 
2  |  | // File:        convolve.cpp  | 
3  |  | // Description: Convolutional layer that stacks the inputs over its rectangle  | 
4  |  | //              and pulls in random data to fill out-of-input inputs.  | 
5  |  | //              Output is therefore same size as its input, but deeper.  | 
6  |  | // Author:      Ray Smith  | 
7  |  | //  | 
8  |  | // (C) Copyright 2014, Google Inc.  | 
9  |  | // Licensed under the Apache License, Version 2.0 (the "License");  | 
10  |  | // you may not use this file except in compliance with the License.  | 
11  |  | // You may obtain a copy of the License at  | 
12  |  | // http://www.apache.org/licenses/LICENSE-2.0  | 
13  |  | // Unless required by applicable law or agreed to in writing, software  | 
14  |  | // distributed under the License is distributed on an "AS IS" BASIS,  | 
15  |  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
16  |  | // See the License for the specific language governing permissions and  | 
17  |  | // limitations under the License.  | 
18  |  | ///////////////////////////////////////////////////////////////////////  | 
19  |  |  | 
20  |  | #ifdef HAVE_CONFIG_H  | 
21  |  | #  include "config_auto.h"  | 
22  |  | #endif  | 
23  |  |  | 
24  |  | #include "convolve.h"  | 
25  |  |  | 
26  |  | #include "networkscratch.h"  | 
27  |  | #include "serialis.h"  | 
28  |  |  | 
29  |  | namespace tesseract { | 
30  |  |  | 
31  |  | Convolve::Convolve(const std::string &name, int ni, int half_x, int half_y)  | 
32  | 4  |     : Network(NT_CONVOLVE, name, ni, ni * (2 * half_x + 1) * (2 * half_y + 1))  | 
33  | 4  |     , half_x_(half_x)  | 
34  | 4  |     , half_y_(half_y) {} | 
35  |  |  | 
36  |  | // Writes to the given file. Returns false in case of error.  | 
37  | 0  | bool Convolve::Serialize(TFile *fp) const { | 
38  | 0  |   return Network::Serialize(fp) && fp->Serialize(&half_x_) && fp->Serialize(&half_y_);  | 
39  | 0  | }  | 
40  |  |  | 
41  |  | // Reads from the given file. Returns false in case of error.  | 
42  | 4  | bool Convolve::DeSerialize(TFile *fp) { | 
43  | 4  |   if (!fp->DeSerialize(&half_x_)) { | 
44  | 0  |     return false;  | 
45  | 0  |   }  | 
46  | 4  |   if (!fp->DeSerialize(&half_y_)) { | 
47  | 0  |     return false;  | 
48  | 0  |   }  | 
49  | 4  |   no_ = ni_ * (2 * half_x_ + 1) * (2 * half_y_ + 1);  | 
50  | 4  |   return true;  | 
51  | 4  | }  | 
52  |  |  | 
53  |  | // Runs forward propagation of activations on the input line.  | 
54  |  | // See NetworkCpp for a detailed discussion of the arguments.  | 
55  |  | void Convolve::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,  | 
56  | 274k  |                        NetworkScratch *scratch, NetworkIO *output) { | 
57  | 274k  |   output->Resize(input, no_);  | 
58  | 274k  |   int y_scale = 2 * half_y_ + 1;  | 
59  | 274k  |   StrideMap::Index dest_index(output->stride_map());  | 
60  | 369M  |   do { | 
61  |  |     // Stack x_scale groups of y_scale * ni_ inputs together.  | 
62  | 369M  |     int t = dest_index.t();  | 
63  | 369M  |     int out_ix = 0;  | 
64  | 1.47G  |     for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) { | 
65  | 1.10G  |       StrideMap::Index x_index(dest_index);  | 
66  | 1.10G  |       if (!x_index.AddOffset(x, FD_WIDTH)) { | 
67  |  |         // This x is outside the image.  | 
68  | 19.7M  |         output->Randomize(t, out_ix, y_scale * ni_, randomizer_);  | 
69  | 1.08G  |       } else { | 
70  | 1.08G  |         int out_iy = out_ix;  | 
71  | 4.35G  |         for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) { | 
72  | 3.26G  |           StrideMap::Index y_index(x_index);  | 
73  | 3.26G  |           if (!y_index.AddOffset(y, FD_HEIGHT)) { | 
74  |  |             // This y is outside the image.  | 
75  | 60.4M  |             output->Randomize(t, out_iy, ni_, randomizer_);  | 
76  | 3.20G  |           } else { | 
77  | 3.20G  |             output->CopyTimeStepGeneral(t, out_iy, ni_, input, y_index.t(), 0);  | 
78  | 3.20G  |           }  | 
79  | 3.26G  |         }  | 
80  | 1.08G  |       }  | 
81  | 1.10G  |     }  | 
82  | 369M  |   } while (dest_index.Increment());  | 
83  |  | #ifndef GRAPHICS_DISABLED  | 
84  |  |   if (debug) { | 
85  |  |     DisplayForward(*output);  | 
86  |  |   }  | 
87  |  | #endif  | 
88  | 274k  | }  | 
89  |  |  | 
90  |  | // Runs backward propagation of errors on the deltas line.  | 
91  |  | // See NetworkCpp for a detailed discussion of the arguments.  | 
92  |  | bool Convolve::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,  | 
93  | 0  |                         NetworkIO *back_deltas) { | 
94  | 0  |   back_deltas->Resize(fwd_deltas, ni_);  | 
95  | 0  |   NetworkScratch::IO delta_sum;  | 
96  | 0  |   delta_sum.ResizeFloat(fwd_deltas, ni_, scratch);  | 
97  | 0  |   delta_sum->Zero();  | 
98  | 0  |   int y_scale = 2 * half_y_ + 1;  | 
99  | 0  |   StrideMap::Index src_index(fwd_deltas.stride_map());  | 
100  | 0  |   do { | 
101  |  |     // Stack x_scale groups of y_scale * ni_ inputs together.  | 
102  | 0  |     int t = src_index.t();  | 
103  | 0  |     int out_ix = 0;  | 
104  | 0  |     for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) { | 
105  | 0  |       StrideMap::Index x_index(src_index);  | 
106  | 0  |       if (x_index.AddOffset(x, FD_WIDTH)) { | 
107  | 0  |         int out_iy = out_ix;  | 
108  | 0  |         for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) { | 
109  | 0  |           StrideMap::Index y_index(x_index);  | 
110  | 0  |           if (y_index.AddOffset(y, FD_HEIGHT)) { | 
111  | 0  |             fwd_deltas.AddTimeStepPart(t, out_iy, ni_, delta_sum->f(y_index.t()));  | 
112  | 0  |           }  | 
113  | 0  |         }  | 
114  | 0  |       }  | 
115  | 0  |     }  | 
116  | 0  |   } while (src_index.Increment());  | 
117  | 0  |   back_deltas->CopyAll(*delta_sum);  | 
118  | 0  |   return true;  | 
119  | 0  | }  | 
120  |  |  | 
121  |  | } // namespace tesseract.  |