/src/tesseract/src/lstm/lstm.h
///////////////////////////////////////////////////////////////////////
// File: lstm.h
// Description: Long Short-Term Memory recurrent neural network.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_LSTM_H_
#define TESSERACT_LSTM_LSTM_H_

#include "fullyconnected.h"
#include "network.h"

namespace tesseract {

// C++ Implementation of the LSTM class from lstm.py.
class LSTM : public Network {
public:
  // Enum for the different weights in LSTM, to reduce some of the I/O and
  // setup code to loops. The elements of the enum correspond to elements of an
  // array of WeightMatrix or a corresponding array of NetworkIO.
  enum WeightType {
    CI,  // Cell Inputs.
    GI,  // Gate at the input.
    GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
    GO,  // Gate at the output.
    GFS, // Forget gate at the memory, looking back in the other dimension.

    WT_COUNT // Number of WeightTypes.
  };
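
  // Illustrative sketch of the loop pattern the enum above enables
  // (ProcessGate is a hypothetical helper, not part of this class):
  //   for (int w = 0; w < WT_COUNT; ++w) {
  //     if (w == GFS && !Is2D()) continue;  // GFS applies only in 2-D mode.
  //     ProcessGate(gate_weights_[w]);      // e.g. serialize, init, or debug.
  //   }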

  // Constructor for NT_LSTM (regular 1 or 2-d LSTM), NT_LSTM_SOFTMAX (LSTM with
  // additional softmax layer included and fed back into the input at the next
  // timestep), or NT_LSTM_SOFTMAX_ENCODED (as LSTM_SOFTMAX, but the feedback
  // is binary encoded instead of categorical) only.
  // 2-d and bidi softmax LSTMs are not rejected, but cannot be built in the
  // conventional way, because the output would have to feed back both forwards
  // and backwards in time, which is impossible.
  TESS_API
  LSTM(const std::string &name, int num_inputs, int num_states, int num_outputs,
       bool two_dimensional, NetworkType type);
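  // Illustrative construction (hypothetical sizes, not prescribed values):
  //   LSTM lstm("lstm", /*num_inputs=*/40, /*num_states=*/96,
  //             /*num_outputs=*/96, /*two_dimensional=*/false, NT_LSTM);
  // For the softmax variants, num_outputs may differ from num_states (see ns_).
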
  ~LSTM() override;

  // Returns the shape output from the network given an input shape (which may
  // be partially unknown, i.e. zero).
  StaticShape OutputShape(const StaticShape &input_shape) const override;

  std::string spec() const override {
    std::string spec;
    if (type_ == NT_LSTM) {
      spec += "Lfx" + std::to_string(ns_);
    } else if (type_ == NT_LSTM_SUMMARY) {
      spec += "Lfxs" + std::to_string(ns_);
    } else if (type_ == NT_LSTM_SOFTMAX) {
      spec += "LS" + std::to_string(ns_);
    } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
      spec += "LE" + std::to_string(ns_);
    }
    if (softmax_ != nullptr) {
      spec += softmax_->spec();
    }
    return spec;
  }
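
  // Example (derived from the branches above): a plain NT_LSTM with ns_ == 256
  // yields "Lfx256"; an NT_LSTM_SOFTMAX yields "LS256" followed by the
  // attached softmax layer's own spec.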

  // Suspends/Enables training by setting the training_ flag. Serialize and
  // DeSerialize only operate on the run-time data if state is false.
  void SetEnableTraining(TrainingState state) override;

  // Sets up the network for training. Initializes the weights to random values
  // of scale `range`, picked according to the random number generator
  // `randomizer`.
  int InitWeights(float range, TRand *randomizer) override;
  // Recursively searches the network for softmaxes with old_no outputs,
  // and remaps their outputs according to code_map. See network.h for details.
  int RemapOutputs(int old_no, const std::vector<int> &code_map) override;

  // Converts a float network to an int network.
  void ConvertToInt() override;

  // Provides debug output on the weights.
  void DebugWeights() override;

  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile *fp) const override;
  // Reads from the given file. Returns false in case of error.
  bool DeSerialize(TFile *fp) override;

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
               NetworkScratch *scratch, NetworkIO *output) override;

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                NetworkIO *back_deltas) override;
  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override;
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override;
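  // A plausible reading of that contract (illustrative sketch, not the exact
  // implementation): for each weight i,
  //   p = delta_this[i] * delta_other[i];
  //   if (p >= 0) *same += p; else *changed -= p;  // accumulate magnitudes.
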
  // Prints the weights for debug purposes.
  void PrintW();
  // Prints the weight deltas for debug purposes.
  void PrintDW();

  // Returns true if this is a 2-d LSTM.
  bool Is2D() const {
    return is_2d_;
  }

private:
  // Resizes forward data to cope with an input image of the given width.
  void ResizeForward(const NetworkIO &input);

private:
  // Size of padded input to weight matrices = ni_ + no_ for 1-D operation
  // and ni_ + 2 * no_ for 2-D operation. Note that there is a phantom 1 input
  // for the bias that makes the weight matrices of size [na + 1][no].
  int32_t na_;
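  // Illustrative arithmetic for na_ above (hypothetical sizes): with ni_ = 40
  // and no_ = 96, na_ = 136 in 1-D operation and 232 in 2-D operation, so the
  // 1-D weight matrices are of size [137][96].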
  // Number of internal states. Equal to no_ except for a softmax LSTM.
  // ns_ is NOT serialized, but is calculated from gate_weights_.
  int32_t ns_;
  // Number of additional feedback states. The softmax types feed back
  // additional output information on top of the ns_ internal states.
  // In the case of a binary-coded (ENCODED) softmax, nf_ < no_.
  int32_t nf_;
  // Flag indicating 2-D operation.
  bool is_2d_;

  // Gate weight arrays of size [na + 1, no].
  WeightMatrix gate_weights_[WT_COUNT];
  // Used only if this is a softmax LSTM.
  FullyConnected *softmax_;
  // Input padded with previous output of size [width, na].
  NetworkIO source_;
  // Internal state used during forward operation, of size [width, ns].
  NetworkIO state_;
  // State of the 2-d maxpool, generated during forward, used during backward.
  GENERIC_2D_ARRAY<int8_t> which_fg_;
  // Internal state saved from forward, but used only during backward.
  NetworkIO node_values_[WT_COUNT];
  // Preserved input stride_map used for Backward when type_ is NT_LSTM_SUMMARY.
  StrideMap input_map_;
  int input_width_;
};

} // namespace tesseract

#endif // TESSERACT_LSTM_LSTM_H_