/src/tesseract/src/lstm/static_shape.h
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: static_shape.h |
3 | | // Description: Defines the size of the 4-d tensor input/output from a network. |
4 | | // Author: Ray Smith |
5 | | // Created: Fri Oct 14 09:07:31 PST 2016 |
6 | | // |
7 | | // (C) Copyright 2016, Google Inc. |
8 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
9 | | // you may not use this file except in compliance with the License. |
10 | | // You may obtain a copy of the License at |
11 | | // http://www.apache.org/licenses/LICENSE-2.0 |
12 | | // Unless required by applicable law or agreed to in writing, software |
13 | | // distributed under the License is distributed on an "AS IS" BASIS, |
14 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | | // See the License for the specific language governing permissions and |
16 | | // limitations under the License. |
17 | | /////////////////////////////////////////////////////////////////////// |
18 | | |
19 | | #ifndef TESSERACT_LSTM_STATIC_SHAPE_H_ |
20 | | #define TESSERACT_LSTM_STATIC_SHAPE_H_ |
21 | | |
22 | | #include "serialis.h" // for TFile |
23 | | #include "tprintf.h" // for tprintf |
24 | | |
25 | | namespace tesseract { |
26 | | |
27 | | // Enum describing the loss function to apply during training and/or the |
28 | | // decoding method to apply at runtime. |
29 | | enum LossType { |
30 | | LT_NONE, // Undefined. |
31 | | LT_CTC, // Softmax with standard CTC for training/decoding. |
32 | | LT_SOFTMAX, // Outputs sum to 1 in fixed positions. |
33 | | LT_LOGISTIC, // Logistic outputs with independent values. |
34 | | }; |
35 | | |
36 | | // Simple class to hold the tensor shape that is known at network build time |
37 | | // and the LossType of the loss function. |
38 | | class StaticShape { |
39 | | public: |
40 | 6 | StaticShape() : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {} |
41 | 2 | int batch() const { |
42 | 2 | return batch_; |
43 | 2 | } |
44 | 0 | void set_batch(int value) { |
45 | 0 | batch_ = value; |
46 | 0 | } |
47 | 769k | int height() const { |
48 | 769k | return height_; |
49 | 769k | } |
50 | 4 | void set_height(int value) { |
51 | 4 | height_ = value; |
52 | 4 | } |
53 | 256k | int width() const { |
54 | 256k | return width_; |
55 | 256k | } |
56 | 6 | void set_width(int value) { |
57 | 6 | width_ = value; |
58 | 6 | } |
59 | 769k | int depth() const { |
60 | 769k | return depth_; |
61 | 769k | } |
62 | 14 | void set_depth(int value) { |
63 | 14 | depth_ = value; |
64 | 14 | } |
65 | 2 | LossType loss_type() const { |
66 | 2 | return loss_type_; |
67 | 2 | } |
68 | 4 | void set_loss_type(LossType value) { |
69 | 4 | loss_type_ = value; |
70 | 4 | } |
71 | 2 | void SetShape(int batch, int height, int width, int depth) { |
72 | 2 | batch_ = batch; |
73 | 2 | height_ = height; |
74 | 2 | width_ = width; |
75 | 2 | depth_ = depth; |
76 | 2 | } |
77 | | |
78 | 0 | void Print() const { |
79 | 0 | tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_, height_, width_, depth_, |
80 | 0 | loss_type_); |
81 | 0 | } |
82 | | |
83 | 4 | bool DeSerialize(TFile *fp) { |
84 | 4 | int32_t tmp = LT_NONE; |
85 | 4 | bool result = fp->DeSerialize(&batch_) && fp->DeSerialize(&height_) && |
86 | 4 | fp->DeSerialize(&width_) && fp->DeSerialize(&depth_) && fp->DeSerialize(&tmp); |
87 | 4 | loss_type_ = static_cast<LossType>(tmp); |
88 | 4 | return result; |
89 | 4 | } |
90 | | |
91 | 0 | bool Serialize(TFile *fp) const { |
92 | 0 | int32_t tmp = loss_type_; |
93 | 0 | return fp->Serialize(&batch_) && fp->Serialize(&height_) && fp->Serialize(&width_) && |
94 | 0 | fp->Serialize(&depth_) && fp->Serialize(&tmp); |
95 | 0 | } |
96 | | |
97 | | private: |
98 | | // Size of the 4-D tensor input/output to a network. A value of zero is |
99 | | // allowed for all except depth_ and means to be determined at runtime, and |
100 | | // regarded as variable. |
101 | | // Number of elements in a batch, or number of frames in a video stream. |
102 | | int32_t batch_; |
103 | | // Height of the image. |
104 | | int32_t height_; |
105 | | // Width of the image. |
106 | | int32_t width_; |
107 | | // Depth of the image. (Number of "nodes"). |
108 | | int32_t depth_; |
109 | | // How to train/interpret the output. |
110 | | LossType loss_type_; |
111 | | }; |
112 | | |
113 | | } // namespace tesseract |
114 | | |
115 | | #endif // TESSERACT_LSTM_STATIC_SHAPE_H_ |