/src/tesseract/src/wordrec/params_model.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: params_model.cpp |
3 | | // Description: Trained language model parameters. |
4 | | // Author: David Eger |
5 | | // |
6 | | // (C) Copyright 2012, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | // |
17 | | /////////////////////////////////////////////////////////////////////// |
18 | | |
#include "params_model.h"

#include <cctype>
#include <cmath>
#include <cstdio>
#include <locale>  // for std::locale::classic
#include <sstream> // for std::istringstream, std::ostringstream

#include "bitvector.h"
#include "helpers.h"  // for ClipToRange
#include "serialis.h" // for TFile
#include "tprintf.h"
29 | | |
30 | | namespace tesseract { |
31 | | |
// Divisor applied to the raw weighted feature sum before it becomes a cost.
constexpr float kScoreScaleFactor = 100.0f;
// Lower clip bound for any cost handed back to callers.
constexpr float kMinFinalCost = 0.001f;
// Upper clip bound for any cost handed back to callers.
constexpr float kMaxFinalCost = 100.0f;
38 | | |
39 | 0 | void ParamsModel::Print() { |
40 | 0 | for (int p = 0; p < PTRAIN_NUM_PASSES; ++p) { |
41 | 0 | tprintf("ParamsModel for pass %d lang %s\n", p, lang_.c_str()); |
42 | 0 | for (unsigned i = 0; i < weights_vec_[p].size(); ++i) { |
43 | 0 | tprintf("%s = %g\n", kParamsTrainingFeatureTypeName[i], weights_vec_[p][i]); |
44 | 0 | } |
45 | 0 | } |
46 | 0 | } |
47 | | |
48 | 0 | void ParamsModel::Copy(const ParamsModel &other_model) { |
49 | 0 | for (int p = 0; p < PTRAIN_NUM_PASSES; ++p) { |
50 | 0 | weights_vec_[p] = other_model.weights_for_pass(static_cast<PassEnum>(p)); |
51 | 0 | } |
52 | 0 | } |
53 | | |
54 | | // Given a (modifiable) line, parse out a key / value pair. |
55 | | // Return true on success. |
56 | 0 | bool ParamsModel::ParseLine(char *line, char **key, float *val) { |
57 | 0 | if (line[0] == '#') { |
58 | 0 | return false; |
59 | 0 | } |
60 | 0 | int end_of_key = 0; |
61 | 0 | while (line[end_of_key] && !(isascii(line[end_of_key]) && isspace(line[end_of_key]))) { |
62 | 0 | end_of_key++; |
63 | 0 | } |
64 | 0 | if (!line[end_of_key]) { |
65 | 0 | tprintf("ParamsModel::Incomplete line %s\n", line); |
66 | 0 | return false; |
67 | 0 | } |
68 | 0 | line[end_of_key++] = 0; |
69 | 0 | *key = line; |
70 | 0 | if (sscanf(line + end_of_key, " %f", val) != 1) { |
71 | 0 | return false; |
72 | 0 | } |
73 | 0 | return true; |
74 | 0 | } |
75 | | |
76 | | // Applies params model weights to the given features. |
77 | | // Assumes that features is an array of size PTRAIN_NUM_FEATURE_TYPES. |
78 | | // The cost is set to a number that can be multiplied by the outline length, |
79 | | // as with the old ratings scheme. This enables words of different length |
80 | | // and combinations of words to be compared meaningfully. |
81 | 0 | float ParamsModel::ComputeCost(const float features[]) const { |
82 | 0 | float unnorm_score = 0.0; |
83 | 0 | for (int f = 0; f < PTRAIN_NUM_FEATURE_TYPES; ++f) { |
84 | 0 | unnorm_score += weights_vec_[pass_][f] * features[f]; |
85 | 0 | } |
86 | 0 | return ClipToRange(-unnorm_score / kScoreScaleFactor, kMinFinalCost, kMaxFinalCost); |
87 | 0 | } |
88 | | |
89 | 0 | bool ParamsModel::Equivalent(const ParamsModel &that) const { |
90 | 0 | float epsilon = 0.0001f; |
91 | 0 | for (int p = 0; p < PTRAIN_NUM_PASSES; ++p) { |
92 | 0 | if (weights_vec_[p].size() != that.weights_vec_[p].size()) { |
93 | 0 | return false; |
94 | 0 | } |
95 | 0 | for (unsigned i = 0; i < weights_vec_[p].size(); i++) { |
96 | 0 | if (weights_vec_[p][i] != that.weights_vec_[p][i] && |
97 | 0 | std::fabs(weights_vec_[p][i] - that.weights_vec_[p][i]) > epsilon) { |
98 | 0 | return false; |
99 | 0 | } |
100 | 0 | } |
101 | 0 | } |
102 | 0 | return true; |
103 | 0 | } |
104 | | |
105 | 0 | bool ParamsModel::LoadFromFp(const char *lang, TFile *fp) { |
106 | 0 | const int kMaxLineSize = 100; |
107 | 0 | char line[kMaxLineSize]; |
108 | 0 | BitVector present; |
109 | 0 | present.Init(PTRAIN_NUM_FEATURE_TYPES); |
110 | 0 | lang_ = lang; |
111 | | // Load weights for passes with adaption on. |
112 | 0 | std::vector<float> &weights = weights_vec_[pass_]; |
113 | 0 | weights.clear(); |
114 | 0 | weights.resize(PTRAIN_NUM_FEATURE_TYPES, 0.0f); |
115 | |
|
116 | 0 | while (fp->FGets(line, kMaxLineSize) != nullptr) { |
117 | 0 | char *key = nullptr; |
118 | 0 | float value; |
119 | 0 | if (!ParseLine(line, &key, &value)) { |
120 | 0 | continue; |
121 | 0 | } |
122 | 0 | int idx = ParamsTrainingFeatureByName(key); |
123 | 0 | if (idx < 0) { |
124 | 0 | tprintf("ParamsModel::Unknown parameter %s\n", key); |
125 | 0 | continue; |
126 | 0 | } |
127 | 0 | if (!present[idx]) { |
128 | 0 | present.SetValue(idx, true); |
129 | 0 | } |
130 | 0 | weights[idx] = value; |
131 | 0 | } |
132 | 0 | bool complete = (present.NumSetBits() == PTRAIN_NUM_FEATURE_TYPES); |
133 | 0 | if (!complete) { |
134 | 0 | for (int i = 0; i < PTRAIN_NUM_FEATURE_TYPES; i++) { |
135 | 0 | if (!present[i]) { |
136 | 0 | tprintf("Missing field %s.\n", kParamsTrainingFeatureTypeName[i]); |
137 | 0 | } |
138 | 0 | } |
139 | 0 | lang_ = ""; |
140 | 0 | weights.clear(); |
141 | 0 | } |
142 | 0 | return complete; |
143 | 0 | } |
144 | | |
145 | 0 | bool ParamsModel::SaveToFile(const char *full_path) const { |
146 | 0 | const std::vector<float> &weights = weights_vec_[pass_]; |
147 | 0 | if (weights.size() != PTRAIN_NUM_FEATURE_TYPES) { |
148 | 0 | tprintf("Refusing to save ParamsModel that has not been initialized.\n"); |
149 | 0 | return false; |
150 | 0 | } |
151 | 0 | FILE *fp = fopen(full_path, "wb"); |
152 | 0 | if (!fp) { |
153 | 0 | tprintf("Could not open %s for writing.\n", full_path); |
154 | 0 | return false; |
155 | 0 | } |
156 | 0 | bool all_good = true; |
157 | 0 | for (unsigned i = 0; i < weights.size(); i++) { |
158 | 0 | if (fprintf(fp, "%s %f\n", kParamsTrainingFeatureTypeName[i], weights[i]) < 0) { |
159 | 0 | all_good = false; |
160 | 0 | } |
161 | 0 | } |
162 | 0 | fclose(fp); |
163 | 0 | return all_good; |
164 | 0 | } |
165 | | |
166 | | } // namespace tesseract |