/src/tesseract/src/lstm/reversed.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: reversed.cpp |
3 | | // Description: Runs a single network on time-reversed input, reversing output. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2013, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #include "reversed.h" |
19 | | |
20 | | #include <cstdio> |
21 | | |
22 | | #include "networkscratch.h" |
23 | | |
24 | | namespace tesseract { |
25 | | |
26 | 8 | Reversed::Reversed(const std::string &name, NetworkType type) : Plumbing(name) { |
27 | 8 | type_ = type; |
28 | 8 | } |
29 | | |
30 | | // Returns the shape output from the network given an input shape (which may |
31 | | // be partially unknown ie zero). |
32 | 4 | StaticShape Reversed::OutputShape(const StaticShape &input_shape) const { |
33 | 4 | if (type_ == NT_XYTRANSPOSE) { |
34 | 2 | StaticShape x_shape(input_shape); |
35 | 2 | x_shape.set_width(input_shape.height()); |
36 | 2 | x_shape.set_height(input_shape.width()); |
37 | 2 | x_shape = stack_[0]->OutputShape(x_shape); |
38 | 2 | x_shape.SetShape(x_shape.batch(), x_shape.width(), x_shape.height(), x_shape.depth()); |
39 | 2 | return x_shape; |
40 | 2 | } |
41 | 2 | return stack_[0]->OutputShape(input_shape); |
42 | 4 | } |
43 | | |
44 | | // Takes ownership of the given network to make it the reversed one. |
45 | 0 | void Reversed::SetNetwork(Network *network) { |
46 | 0 | stack_.clear(); |
47 | 0 | AddToStack(network); |
48 | 0 | } |
49 | | |
50 | | // Runs forward propagation of activations on the input line. |
51 | | // See NetworkCpp for a detailed discussion of the arguments. |
52 | | void Reversed::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, |
53 | 543k | NetworkScratch *scratch, NetworkIO *output) { |
54 | 543k | NetworkScratch::IO rev_input(input, scratch); |
55 | 543k | ReverseData(input, rev_input); |
56 | 543k | NetworkScratch::IO rev_output(input, scratch); |
57 | 543k | stack_[0]->Forward(debug, *rev_input, nullptr, scratch, rev_output); |
58 | 543k | ReverseData(*rev_output, output); |
59 | 543k | } |
60 | | |
61 | | // Runs backward propagation of errors on the deltas line. |
62 | | // See NetworkCpp for a detailed discussion of the arguments. |
63 | | bool Reversed::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, |
64 | 0 | NetworkIO *back_deltas) { |
65 | 0 | NetworkScratch::IO rev_input(fwd_deltas, scratch); |
66 | 0 | ReverseData(fwd_deltas, rev_input); |
67 | 0 | NetworkScratch::IO rev_output(fwd_deltas, scratch); |
68 | 0 | if (stack_[0]->Backward(debug, *rev_input, scratch, rev_output)) { |
69 | 0 | ReverseData(*rev_output, back_deltas); |
70 | 0 | return true; |
71 | 0 | } |
72 | 0 | return false; |
73 | 0 | } |
74 | | |
75 | | // Copies src to *dest with the reversal according to type_. |
76 | 1.08M | void Reversed::ReverseData(const NetworkIO &src, NetworkIO *dest) const { |
77 | 1.08M | if (type_ == NT_XREVERSED) { |
78 | 543k | dest->CopyWithXReversal(src); |
79 | 543k | } else if (type_ == NT_YREVERSED) { |
80 | 0 | dest->CopyWithYReversal(src); |
81 | 543k | } else { |
82 | 543k | dest->CopyWithXYTranspose(src); |
83 | 543k | } |
84 | 1.08M | } |
85 | | |
86 | | } // namespace tesseract. |