Coverage Report

Created: 2025-07-23 07:12

/src/tesseract/src/lstm/plumbing.h
Line
Count
Source (jump to first uncovered line)
1
///////////////////////////////////////////////////////////////////////
2
// File:        plumbing.h
3
// Description: Base class for networks that organize other networks
4
//              eg series or parallel.
5
// Author:      Ray Smith
6
//
7
// (C) Copyright 2014, Google Inc.
8
// Licensed under the Apache License, Version 2.0 (the "License");
9
// you may not use this file except in compliance with the License.
10
// You may obtain a copy of the License at
11
// http://www.apache.org/licenses/LICENSE-2.0
12
// Unless required by applicable law or agreed to in writing, software
13
// distributed under the License is distributed on an "AS IS" BASIS,
14
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
// See the License for the specific language governing permissions and
16
// limitations under the License.
17
///////////////////////////////////////////////////////////////////////
18
19
#ifndef TESSERACT_LSTM_PLUMBING_H_
20
#define TESSERACT_LSTM_PLUMBING_H_
21
22
#include "matrix.h"
23
#include "network.h"
24
25
namespace tesseract {
26
27
// Holds a collection of other networks and forwards calls to each of them.
28
class TESS_API Plumbing : public Network {
29
public:
30
  // Constructs a Plumbing with the given name.
  // ni_ and no_ will be set by AddToStack.
31
  explicit Plumbing(const std::string &name);
32
0
  // Deletes every network pointer stored in stack_, so the stack entries are
  // destroyed together with this object.
  ~Plumbing() override {
33
0
    for (auto data : stack_) {
34
0
      delete data;
35
0
    }
36
0
  }
37
38
  // Returns the required shape input to the network, which is the input shape
  // reported by the first element of the stack.
  // NOTE: stack_[0] is accessed without checking that the stack is non-empty.
39
282k
  StaticShape InputShape() const override {
40
282k
    return stack_[0]->InputShape();
41
282k
  }
42
0
  // Placeholder implementation: returns a fixed message stating that
  // sub-classes must implement spec().
  std::string spec() const override {
43
0
    return "Sub-classes of Plumbing must implement spec()!";
44
0
  }
45
46
  // Returns true if the given type is derived from Plumbing, and thus contains
47
  // multiple sub-networks that can have their own learning rate.
  // This override always returns true.
48
0
  bool IsPlumbingType() const override {
49
0
    return true;
50
0
  }
51
52
  // Suspends/Enables training by setting the training_ flag. Serialize and
53
  // DeSerialize only operate on the run-time data if state is false.
54
  void SetEnableTraining(TrainingState state) override;
55
56
  // Sets flags that control the action of the network. See NetworkFlags enum
57
  // for bit values.
58
  void SetNetworkFlags(uint32_t flags) override;
59
60
  // Sets up the network for training. Initializes weights using weights of
61
  // scale `range` picked according to the random number generator `randomizer`.
62
  // Note that randomizer is a borrowed pointer that should outlive the network
63
  // and should not be deleted by any of the networks.
64
  // Returns the number of weights initialized.
65
  int InitWeights(float range, TRand *randomizer) override;
66
  // Recursively searches the network for softmaxes with old_no outputs,
67
  // and remaps their outputs according to code_map. See network.h for details.
68
  int RemapOutputs(int old_no, const std::vector<int> &code_map) override;
69
70
  // Converts a float network to an int network.
71
  void ConvertToInt() override;
72
73
  // Provides a pointer to a TRand for any networks that care to use it.
74
  // Note that randomizer is a borrowed pointer that should outlive the network
75
  // and should not be deleted by any of the networks.
76
  void SetRandomizer(TRand *randomizer) override;
77
78
  // Adds the given network to the stack.
  // NOTE(review): the destructor deletes every entry of stack_, so a network
  // passed here is presumably owned by this object afterwards - confirm
  // against the definition.
79
  virtual void AddToStack(Network *network);
80
81
  // Sets needs_to_backprop_ to needs_backprop and returns true if
82
  // needs_backprop || any weights in this network so the next layer forward
83
  // can be told to produce backprop for this layer if needed.
84
  bool SetupNeedsBackprop(bool needs_backprop) override;
85
86
  // Returns an integer reduction factor that the network applies to the
87
  // time sequence. Assumes that any 2-d is already eliminated. Used for
88
  // scaling bounding boxes of truth data.
89
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
90
  // the last used scale factor. Call it before any forward, and it will return
91
  // the minimum scale factor of the paths through the GlobalMinimax.
92
  int XScaleFactor() const override;
93
94
  // Provides the (minimum) x scale factor to the network (of interest only to
95
  // input units) so they can determine how to scale bounding boxes.
96
  void CacheXScaleFactor(int factor) override;
97
98
  // Provides debug output on the weights.
99
  void DebugWeights() override;
100
101
  // Returns the current stack as a const reference to stack_.
102
0
  const std::vector<Network *> &stack() const {
103
0
    return stack_;
104
0
  }
105
  // Returns a set of strings representing the layer-ids of all layers below.
  // NOTE(review): the declared return type is void, so the ids are presumably
  // delivered through `layers` - confirm against the definition.
106
  void EnumerateLayers(const std::string *prefix, std::vector<std::string> &layers) const;
107
  // Returns a pointer to the network layer corresponding to the given id.
108
  Network *GetLayer(const char *id) const;
109
  // Returns the learning rate for a specific layer of the stack.
  // ASSERT_HOST checks that LayerLearningRatePtr(id) returned a non-null
  // pointer before it is dereferenced.
110
0
  float LayerLearningRate(const char *id) {
111
0
    const float *lr_ptr = LayerLearningRatePtr(id);
112
0
    ASSERT_HOST(lr_ptr != nullptr);
113
0
    return *lr_ptr;
114
0
  }
115
  // Scales the learning rate for a specific layer of the stack.
  // The stored rate is multiplied in place. factor is a double while the rate
  // is a float, so the product is converted back to float on assignment.
  // ASSERT_HOST checks that LayerLearningRatePtr(id) returned a non-null
  // pointer before it is dereferenced.
116
0
  void ScaleLayerLearningRate(const char *id, double factor) {
117
0
    float *lr_ptr = LayerLearningRatePtr(id);
118
0
    ASSERT_HOST(lr_ptr != nullptr);
119
0
    *lr_ptr *= factor;
120
0
  }
121
122
  // Set the learning rate for a specific layer of the stack to the given value.
  // ASSERT_HOST checks that LayerLearningRatePtr(id) returned a non-null
  // pointer before it is written through.
123
0
  void SetLayerLearningRate(const char *id, float learning_rate) {
124
0
    float *lr_ptr = LayerLearningRatePtr(id);
125
0
    ASSERT_HOST(lr_ptr != nullptr);
126
0
    *lr_ptr = learning_rate;
127
0
  }
128
129
  // Returns a pointer to the learning rate for the given layer id.
  // The inline callers above assert that the result is not nullptr, so a null
  // return is treated by them as a fatal error.
130
  float *LayerLearningRatePtr(const char *id);
131
132
  // Writes to the given file. Returns false in case of error.
133
  bool Serialize(TFile *fp) const override;
134
  // Reads from the given file. Returns false in case of error.
135
  bool DeSerialize(TFile *fp) override;
136
137
  // Updates the weights using the given learning rate, momentum and adam_beta.
138
  // num_samples is used in the adam computation iff use_adam_ is true.
139
  void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override;
140
  // Sums the products of weight updates in *this and other, splitting into
141
  // positive (same direction) in *same and negative (different direction) in
142
  // *changed.
143
  void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override;
144
145
protected:
146
  // The networks. Every pointer stored here is deleted by ~Plumbing().
147
  std::vector<Network *> stack_;
148
  // Layer-specific learning rate iff network_flags_ & NF_LAYER_SPECIFIC_LR.
149
  // One element for each element of stack_.
150
  std::vector<float> learning_rates_;
151
};
152
153
} // namespace tesseract.
154
155
#endif // TESSERACT_LSTM_PLUMBING_H_