/src/tesseract/src/lstm/reconfig.h
Line | Count | Source (jump to first uncovered line) |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: reconfig.h |
3 | | // Description: Network layer that reconfigures the scaling vs feature |
4 | | // depth. |
5 | | // Author: Ray Smith |
6 | | // |
7 | | // (C) Copyright 2014, Google Inc. |
8 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
9 | | // you may not use this file except in compliance with the License. |
10 | | // You may obtain a copy of the License at |
11 | | // http://www.apache.org/licenses/LICENSE-2.0 |
12 | | // Unless required by applicable law or agreed to in writing, software |
13 | | // distributed under the License is distributed on an "AS IS" BASIS, |
14 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | | // See the License for the specific language governing permissions and |
16 | | // limitations under the License. |
17 | | /////////////////////////////////////////////////////////////////////// |
18 | | |
19 | | #ifndef TESSERACT_LSTM_RECONFIG_H_ |
20 | | #define TESSERACT_LSTM_RECONFIG_H_ |
21 | | |
22 | | #include "matrix.h" |
23 | | #include "network.h" |
24 | | |
25 | | namespace tesseract { |
26 | | |
27 | | // Reconfigures (Shrinks) the inputs by concatenating an x_scale by y_scale tile |
28 | | // of inputs together, producing a single, deeper output per tile. |
29 | | // Note that fractional parts are truncated for efficiency, so make sure the |
30 | | // input stride is a multiple of the y_scale factor! |
31 | | class Reconfig : public Network { |
32 | | public: |
33 | | TESS_API |
34 | | Reconfig(const std::string &name, int ni, int x_scale, int y_scale); |
35 | 0 | ~Reconfig() override = default; |
36 | | |
37 | | // Returns the shape output from the network given an input shape (which may |
38 | | // be partially unknown ie zero). |
39 | | StaticShape OutputShape(const StaticShape &input_shape) const override; |
40 | | |
41 | 0 | std::string spec() const override { |
42 | 0 | return "S" + std::to_string(y_scale_) + "," + std::to_string(x_scale_); |
43 | 0 | } |
44 | | |
45 | | // Returns an integer reduction factor that the network applies to the |
46 | | // time sequence. Assumes that any 2-d is already eliminated. Used for |
47 | | // scaling bounding boxes of truth data. |
48 | | // WARNING: if GlobalMinimax is used to vary the scale, this will return |
49 | | // the last used scale factor. Call it before any forward, and it will return |
50 | | // the minimum scale factor of the paths through the GlobalMinimax. |
51 | | int XScaleFactor() const override; |
52 | | |
53 | | // Writes to the given file. Returns false in case of error. |
54 | | bool Serialize(TFile *fp) const override; |
55 | | // Reads from the given file. Returns false in case of error. |
56 | | bool DeSerialize(TFile *fp) override; |
57 | | |
58 | | // Runs forward propagation of activations on the input line. |
59 | | // See Network for a detailed discussion of the arguments. |
60 | | void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, |
61 | | NetworkScratch *scratch, NetworkIO *output) override; |
62 | | |
63 | | // Runs backward propagation of errors on the deltas line. |
64 | | // See Network for a detailed discussion of the arguments. |
65 | | bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
66 | | NetworkIO *back_deltas) override; |
67 | | |
68 | | private: |
69 | 0 | void DebugWeights() override { |
70 | 0 | tprintf("Must override Network::DebugWeights for type %d\n", type_); |
71 | 0 | } |
72 | | |
73 | | protected: |
74 | | // Non-serialized data used to store parameters between forward and back. |
75 | | StrideMap back_map_; |
76 | | // Serialized data. |
77 | | int32_t x_scale_; |
78 | | int32_t y_scale_; |
79 | | }; |
80 | | |
81 | | } // namespace tesseract. |
82 | | |
83 | | #endif // TESSERACT_LSTM_RECONFIG_H_ |