/src/tesseract/src/lstm/reversed.h
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: reversed.h |
3 | | // Description: Runs a single network on time-reversed input, reversing output. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2013, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #ifndef TESSERACT_LSTM_REVERSED_H_ |
19 | | #define TESSERACT_LSTM_REVERSED_H_ |
20 | | |
21 | | #include "matrix.h" |
22 | | #include "plumbing.h" |
23 | | |
24 | | namespace tesseract { |
25 | | |
26 | | // C++ Implementation of the Reversed class from lstm.py. |
27 | | class Reversed : public Plumbing { |
28 | | public: |
29 | | TESS_API |
30 | | explicit Reversed(const std::string &name, NetworkType type); |
31 | | ~Reversed() override = default; |
32 | | |
33 | | // Returns the shape output from the network given an input shape (which may |
34 | | // be partially unknown ie zero). |
35 | | StaticShape OutputShape(const StaticShape &input_shape) const override; |
36 | | |
37 | 0 | std::string spec() const override { |
38 | 0 | std::string spec(type_ == NT_XREVERSED ? "Rx" : (type_ == NT_YREVERSED ? "Ry" : "Txy")); |
39 | | // For most simple cases, we will output Rx<net> or Ry<net> where <net> is |
40 | | // the network in stack_[0], but in the special case that <net> is an |
41 | | // LSTM, we will just output the LSTM's spec modified to take the reversal |
42 | | // into account. This is because when the user specified Lfy64, we actually |
43 | | // generated TxyLfx64, and if the user specified Lrx64 we actually |
44 | | // generated RxLfx64, and we want to display what the user asked for. |
45 | 0 | std::string net_spec(stack_[0]->spec()); |
46 | 0 | if (net_spec[0] == 'L') { |
47 | | // Setup a from and to character according to the type of the reversal |
48 | | // such that the LSTM spec gets modified to the spec that the user |
49 | | // asked for |
50 | 0 | char from = 'f'; |
51 | 0 | char to = 'r'; |
52 | 0 | if (type_ == NT_XYTRANSPOSE) { |
53 | 0 | from = 'x'; |
54 | 0 | to = 'y'; |
55 | 0 | } |
56 | | // Change the from char to the to char. |
57 | 0 | for (auto &it : net_spec) { |
58 | 0 | if (it == from) { |
59 | 0 | it = to; |
60 | 0 | } |
61 | 0 | } |
62 | 0 | spec += net_spec; |
63 | 0 | return spec; |
64 | 0 | } |
65 | 0 | spec += net_spec; |
66 | 0 | return spec; |
67 | 0 | } |
68 | | |
69 | | // Takes ownership of the given network to make it the reversed one. |
70 | | TESS_API |
71 | | void SetNetwork(Network *network); |
72 | | |
73 | | // Runs forward propagation of activations on the input line. |
74 | | // See Network for a detailed discussion of the arguments. |
75 | | void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, |
76 | | NetworkScratch *scratch, NetworkIO *output) override; |
77 | | |
78 | | // Runs backward propagation of errors on the deltas line. |
79 | | // See Network for a detailed discussion of the arguments. |
80 | | bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
81 | | NetworkIO *back_deltas) override; |
82 | | |
83 | | private: |
84 | | // Copies src to *dest with the reversal according to type_. |
85 | | void ReverseData(const NetworkIO &src, NetworkIO *dest) const; |
86 | | }; |
87 | | |
88 | | } // namespace tesseract. |
89 | | |
90 | | #endif // TESSERACT_LSTM_REVERSED_H_ |