/src/tesseract/src/lstm/reconfig.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: reconfig.cpp |
3 | | // Description: Network layer that reconfigures the scaling vs feature |
4 | | // depth. |
5 | | // Author: Ray Smith |
6 | | // |
7 | | // (C) Copyright 2014, Google Inc. |
8 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
9 | | // you may not use this file except in compliance with the License. |
10 | | // You may obtain a copy of the License at |
11 | | // http://www.apache.org/licenses/LICENSE-2.0 |
12 | | // Unless required by applicable law or agreed to in writing, software |
13 | | // distributed under the License is distributed on an "AS IS" BASIS, |
14 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | | // See the License for the specific language governing permissions and |
16 | | // limitations under the License. |
17 | | /////////////////////////////////////////////////////////////////////// |
18 | | |
19 | | #include "reconfig.h" |
20 | | |
21 | | namespace tesseract { |
22 | | |
23 | | Reconfig::Reconfig(const std::string &name, int ni, int x_scale, int y_scale) /* Builds an NT_RECONFIG layer with ni inputs; x_scale and y_scale are stored as given and are not range-checked here. */ |
24 | 4 | : Network(NT_RECONFIG, name, ni, ni * x_scale * y_scale) /* last Network argument is ni * x_scale * y_scale -- presumably the output size no_, matching the recomputation in DeSerialize (TODO confirm against Network's constructor) */ |
25 | 4 | , x_scale_(x_scale) |
26 | 4 | , y_scale_(y_scale) {} |
27 | | |
28 | | // Returns the output shape for the given input shape: a copy of input_shape whose height and width are integer-divided by y_scale_ and x_scale_ respectively and, unless type_ is NT_MAXPOOL, whose depth is multiplied by y_scale_ * x_scale_. |
29 | | // The input shape may be partially unknown (i.e. zero); a zero dimension divides or multiplies to zero here. NOTE(review): nothing in this function guards against y_scale_ or x_scale_ being zero before the divisions -- confirm that every path that sets them guarantees non-zero values. |
30 | 2 | StaticShape Reconfig::OutputShape(const StaticShape &input_shape) const { |
31 | 2 | StaticShape result = input_shape; |
32 | 2 | result.set_height(result.height() / y_scale_); |
33 | 2 | result.set_width(result.width() / x_scale_); |
34 | 2 | if (type_ != NT_MAXPOOL) { |
35 | 0 | result.set_depth(result.depth() * y_scale_ * x_scale_); |
36 | 0 | } |
37 | 2 | return result; |
38 | 2 | } |
39 | | |
40 | | // Returns an integer reduction factor that the network applies to the |
41 | | // time sequence; for this layer the value returned is simply x_scale_. |
42 | | // Assumes that any 2-d is already eliminated. Used for scaling bounding boxes of truth data. |
43 | | // WARNING: if GlobalMinimax is used to vary the scale, this will return |
44 | | // the last used scale factor. Call it before any forward, and it will return |
45 | | // the minimum scale factor of the paths through the GlobalMinimax. NOTE(review): this method never consults GlobalMinimax, so the warning looks like it describes the XScaleFactor contract in general -- confirm whether it applies to this layer. |
46 | 140k | int Reconfig::XScaleFactor() const { |
47 | 140k | return x_scale_; |
48 | 140k | } |
49 | | |
50 | | // Writes the layer to the given file: first whatever Network::Serialize writes, then x_scale_, then y_scale_, in that order. The steps are chained with &&, so the first failing step stops the sequence. Returns false in case of error, true if every step succeeded. |
51 | 0 | bool Reconfig::Serialize(TFile *fp) const { |
52 | 0 | return Network::Serialize(fp) && fp->Serialize(&x_scale_) && fp->Serialize(&y_scale_); |
53 | 0 | } |
54 | | |
55 | | // Reads x_scale_ and then y_scale_ from the given file (the same order in which Serialize writes them), then recomputes no_ from them. Returns false as soon as either read fails, true otherwise. Unlike Serialize, no base-class call is made here; presumably the Network fields are read by the caller -- TODO confirm. NOTE(review): the values read are not range-checked; a zero scale would later make OutputShape divide by zero, and the multiplication for no_ is unchecked. |
56 | 4 | bool Reconfig::DeSerialize(TFile *fp) { |
57 | 4 | if (!fp->DeSerialize(&x_scale_)) { |
58 | 0 | return false; |
59 | 0 | } |
60 | 4 | if (!fp->DeSerialize(&y_scale_)) { |
61 | 0 | return false; |
62 | 0 | } |
63 | 4 | no_ = ni_ * x_scale_ * y_scale_; /* keep the output size consistent with the freshly read scales, as the constructor does */ |
64 | 4 | return true; |
65 | 4 | } |
66 | | |
67 | | // Runs forward propagation of activations on the input line. Resizes output with ResizeScaled(input, x_scale_, y_scale_, no_), saves the input stride map in back_map_ (Backward uses it to size its result), then for every output position copies the ni_ features of each input position in the matching x_scale_-wide by y_scale_-high tile into the output timestep at feature offset (x * y_scale_ + y) * ni_. |
68 | | // Tile positions for which AddOffset returns false are skipped (presumably because they fall outside the input -- TODO confirm), so those output features keep whatever ResizeScaled left there. debug, input_transpose and scratch are not used by this function. See NetworkCpp for a detailed discussion of the arguments. |
69 | | void Reconfig::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, |
70 | 0 | NetworkScratch *scratch, NetworkIO *output) { |
71 | 0 | output->ResizeScaled(input, x_scale_, y_scale_, no_); |
72 | 0 | back_map_ = input.stride_map(); |
73 | 0 | StrideMap::Index dest_index(output->stride_map()); |
74 | 0 | do { |
75 | 0 | int out_t = dest_index.t(); |
76 | 0 | StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH), |
77 | 0 | dest_index.index(FD_HEIGHT) * y_scale_, |
78 | 0 | dest_index.index(FD_WIDTH) * x_scale_); |
79 | | // Stack x_scale_ groups of y_scale_ inputs together: x (width offset) is the outer index and y (height offset) the inner one, so the ni_-wide slots are laid out x-major. |
80 | 0 | for (int x = 0; x < x_scale_; ++x) { |
81 | 0 | for (int y = 0; y < y_scale_; ++y) { |
82 | 0 | StrideMap::Index src_xy(src_index); |
83 | 0 | if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) { |
84 | 0 | output->CopyTimeStepGeneral(out_t, (x * y_scale_ + y) * ni_, ni_, input, src_xy.t(), 0); |
85 | 0 | } |
86 | 0 | } |
87 | 0 | } |
88 | 0 | } while (dest_index.Increment()); |
89 | 0 | } |
90 | | |
91 | | // Runs backward propagation of errors on the deltas line, undoing the stacking done by Forward. Resizes back_deltas to back_map_ (the input stride map saved by Forward) with ni_ features and the int mode of fwd_deltas, then for every fwd_deltas position copies the ni_ values at feature offset (x * y_scale_ + y) * ni_ to the position at offset (x, y) inside the matching tile of back_deltas. |
92 | | // Tile positions for which AddOffset returns false are skipped. Returns needs_to_backprop_. debug and scratch are not used by this function. See NetworkCpp for a detailed discussion of the arguments. |
93 | | bool Reconfig::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
94 | 0 | NetworkIO *back_deltas) { |
95 | 0 | back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_); |
96 | 0 | StrideMap::Index src_index(fwd_deltas.stride_map()); |
97 | 0 | do { |
98 | 0 | int in_t = src_index.t(); |
99 | 0 | StrideMap::Index dest_index(back_deltas->stride_map(), src_index.index(FD_BATCH), |
100 | 0 | src_index.index(FD_HEIGHT) * y_scale_, |
101 | 0 | src_index.index(FD_WIDTH) * x_scale_); |
102 | | // Unstack x_scale_ groups of y_scale_ inputs that are together, reading the slots in the same x-major order in which Forward wrote them. |
103 | 0 | for (int x = 0; x < x_scale_; ++x) { |
104 | 0 | for (int y = 0; y < y_scale_; ++y) { |
105 | 0 | StrideMap::Index dest_xy(dest_index); |
106 | 0 | if (dest_xy.AddOffset(x, FD_WIDTH) && dest_xy.AddOffset(y, FD_HEIGHT)) { |
107 | 0 | back_deltas->CopyTimeStepGeneral(dest_xy.t(), 0, ni_, fwd_deltas, in_t, |
108 | 0 | (x * y_scale_ + y) * ni_); |
109 | 0 | } |
110 | 0 | } |
111 | 0 | } |
112 | 0 | } while (src_index.Increment()); |
113 | 0 | return needs_to_backprop_; |
114 | 0 | } |
115 | | |
116 | | } // namespace tesseract. |