/src/tesseract/src/lstm/input.h
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: input.h |
3 | | // Description: Input layer class for neural network implementations. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2014, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #ifndef TESSERACT_LSTM_INPUT_H_ |
19 | | #define TESSERACT_LSTM_INPUT_H_ |
20 | | |
21 | | #include "network.h" |
22 | | |
23 | | namespace tesseract { |
24 | | |
25 | | class ScrollView; |
26 | | |
27 | | class Input : public Network { |
28 | | public: |
29 | | TESS_API |
30 | | Input(const std::string &name, int ni, int no); |
31 | | TESS_API |
32 | | Input(const std::string &name, const StaticShape &shape); |
33 | | ~Input() override = default; |
34 | | |
35 | 0 | std::string spec() const override { |
36 | 0 | return std::to_string(shape_.batch()) + "," + |
37 | 0 | std::to_string(shape_.height()) + "," + |
38 | 0 | std::to_string(shape_.width()) + "," + |
39 | 0 | std::to_string(shape_.depth()); |
40 | 0 | } |
41 | | |
42 | | // Returns the required shape input to the network (the stored shape_). |
43 | 270k | StaticShape InputShape() const override { |
44 | 270k | return shape_; |
45 | 270k | } |
46 | | // Returns the shape output from the network given an input shape (which may |
47 | | // be partially unknown, i.e. zero). input_shape is ignored; shape_ is returned. |
48 | | StaticShape OutputShape( |
49 | 2 | [[maybe_unused]] const StaticShape &input_shape) const override { |
50 | 2 | return shape_; |
51 | 2 | } |
52 | | // Writes to the given file. Returns false in case of error. |
53 | | // Should be overridden by subclasses, but called by their Serialize. |
54 | | bool Serialize(TFile *fp) const override; |
55 | | // Reads from the given file. Returns false in case of error. |
56 | | bool DeSerialize(TFile *fp) override; |
57 | | |
58 | | // Returns an integer reduction factor that the network applies to the |
59 | | // time sequence. Assumes that any 2-d is already eliminated. Used for |
60 | | // scaling bounding boxes of truth data. |
61 | | // WARNING: if GlobalMinimax is used to vary the scale, this will return |
62 | | // the last used scale factor. Call it before any forward, and it will return |
63 | | // the minimum scale factor of the paths through the GlobalMinimax. |
64 | | int XScaleFactor() const override; |
65 | | |
66 | | // Provides the (minimum) x scale factor to the network (of interest only to |
67 | | // input units) so they can determine how to scale bounding boxes. |
68 | | void CacheXScaleFactor(int factor) override; |
69 | | |
70 | | // Runs forward propagation of activations on the input line. |
71 | | // See Network for a detailed discussion of the arguments. |
72 | | void Forward(bool debug, const NetworkIO &input, |
73 | | const TransposedArray *input_transpose, NetworkScratch *scratch, |
74 | | NetworkIO *output) override; |
75 | | |
76 | | // Runs backward propagation of errors on the deltas line. |
77 | | // See Network for a detailed discussion of the arguments. |
78 | | bool Backward(bool debug, const NetworkIO &fwd_deltas, |
79 | | NetworkScratch *scratch, NetworkIO *back_deltas) override; |
80 | | // Creates and returns an Image of appropriate size for the network from the |
81 | | // image_data. If non-null, *image_scale returns the image scale factor used. |
82 | | // Returns nullptr on error. |
83 | | /* static: declared static, so no Input instance is needed to call it. */ |
84 | | static Image PrepareLSTMInputs(const ImageData &image_data, |
85 | | const Network *network, int min_width, |
86 | | TRand *randomizer, float *image_scale); |
87 | | // Converts the given pix to a NetworkIO of height and depth appropriate to |
88 | | // the given StaticShape: |
89 | | // If depth == 3, convert to 24 bit color, otherwise normalized grey. |
90 | | // Scale to target height, if the shape's height is > 1, or its depth if the |
91 | | // height == 1. If height == 0 then no scaling. |
92 | | // NOTE: It isn't safe for multiple threads to call this on the same pix. |
93 | | static void PreparePixInput(const StaticShape &shape, const Image pix, |
94 | | TRand *randomizer, NetworkIO *input); |
95 | | |
96 | | private: |
97 | 0 | void DebugWeights() override { |
98 | 0 | tprintf("Must override Network::DebugWeights for type %d\n", type_); |
99 | 0 | } |
100 | | |
101 | | // Input shape determines how images are dealt with. |
102 | | StaticShape shape_; |
103 | | // Cached total network x scale factor for scaling bounding boxes. NOTE(review): no in-class initializer; presumably set by the constructors - confirm. |
104 | | int cached_x_scale_; |
105 | | }; |
106 | | |
107 | | } // namespace tesseract. |
108 | | |
109 | | #endif // TESSERACT_LSTM_INPUT_H_ |