/src/tesseract/src/wordrec/language_model.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: language_model.cpp |
3 | | // Description: Functions that utilize the knowledge about the properties, |
4 | | // structure and statistics of the language to help recognition. |
5 | | // Author: Daria Antonova |
6 | | // |
7 | | // (C) Copyright 2009, Google Inc. |
8 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
9 | | // you may not use this file except in compliance with the License. |
10 | | // You may obtain a copy of the License at |
11 | | // http://www.apache.org/licenses/LICENSE-2.0 |
12 | | // Unless required by applicable law or agreed to in writing, software |
13 | | // distributed under the License is distributed on an "AS IS" BASIS, |
14 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | | // See the License for the specific language governing permissions and |
16 | | // limitations under the License. |
17 | | // |
18 | | /////////////////////////////////////////////////////////////////////// |
19 | | |
20 | | #include "language_model.h" |
21 | | #include <tesseract/unichar.h> // for UNICHAR_ID, INVALID_UNICHAR_ID |
22 | | #include <cassert> // for assert |
23 | | #include <cmath> // for log2, pow |
24 | | #include "blamer.h" // for BlamerBundle |
25 | | #include "ccutil.h" // for CCUtil |
26 | | #include "dawg.h" // for NO_EDGE, Dawg, Dawg::kPatternUn... |
27 | | #include "errcode.h" // for ASSERT_HOST |
28 | | #include "lm_state.h" // for ViterbiStateEntry, ViterbiState... |
29 | | #include "matrix.h" // for MATRIX_COORD |
30 | | #include "pageres.h" // for WERD_RES |
31 | | #include "params.h" // for IntParam, BoolParam, DoubleParam |
32 | | #include "params_training_featdef.h" // for ParamsTrainingHypothesis, PTRAI... |
33 | | #include "tprintf.h" // for tprintf |
34 | | #include "unicharset.h" // for UNICHARSET |
35 | | #include "unicity_table.h" // for UnicityTable |
36 | | |
37 | | template <typename T> |
38 | | class UnicityTable; |
39 | | |
40 | | namespace tesseract { |
41 | | |
42 | | class LMPainPoints; |
43 | | struct FontInfo; |
44 | | |
#if defined(ANDROID)
// Fallback for Android toolchains whose libm lacks log2():
// computes log2(n) as ln(n) / ln(2).
static inline double log2(double n) {
  return log(n) / log(2.0);
}
#endif // ANDROID
50 | | |
// Upper bound on the average ngram cost of a path.
// NOTE(review): the enforcement site is outside this excerpt — confirm usage.
const float LanguageModel::kMaxAvgNgramCost = 25.0f;
52 | | |
// Constructor: registers every language-model tunable with the parameter
// list owned by dict->getCCUtil()->params(), so values can be overridden
// from Tesseract config files. The *_MEMBER macros both set the default
// and hook the member into that params() list.
// Note: dawg_args_ takes ownership of a newly allocated DawgPositionVector,
// released in ~LanguageModel() (raw owning pointer).
LanguageModel::LanguageModel(const UnicityTable<FontInfo> *fontinfo_table, Dict *dict)
    : INT_MEMBER(language_model_debug_level, 0, "Language model debug level",
                 dict->getCCUtil()->params())
    , BOOL_INIT_MEMBER(language_model_ngram_on, false,
                       "Turn on/off the use of character ngram model", dict->getCCUtil()->params())
    , INT_MEMBER(language_model_ngram_order, 8, "Maximum order of the character ngram model",
                 dict->getCCUtil()->params())
    , INT_MEMBER(language_model_viterbi_list_max_num_prunable, 10,
                 "Maximum number of prunable (those for which"
                 " PrunablePath() is true) entries in each viterbi list"
                 " recorded in BLOB_CHOICEs",
                 dict->getCCUtil()->params())
    , INT_MEMBER(language_model_viterbi_list_max_size, 500,
                 "Maximum size of viterbi lists recorded in BLOB_CHOICEs",
                 dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_small_prob, 0.000001,
                    "To avoid overly small denominators use this as the "
                    "floor of the probability returned by the ngram model.",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_nonmatch_score, -40.0,
                    "Average classifier score of a non-matching unichar.",
                    dict->getCCUtil()->params())
    , BOOL_MEMBER(language_model_ngram_use_only_first_uft8_step, false,
                  "Use only the first UTF8 step of the given string"
                  " when computing log probabilities.",
                  dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_scale_factor, 0.03,
                    "Strength of the character ngram model relative to the"
                    " character classifier ",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_rating_factor, 16.0,
                    "Factor to bring log-probs into the same range as ratings"
                    " when multiplied by outline length ",
                    dict->getCCUtil()->params())
    , BOOL_MEMBER(language_model_ngram_space_delimited_language, true,
                  "Words are delimited by space", dict->getCCUtil()->params())
    , INT_MEMBER(language_model_min_compound_length, 3, "Minimum length of compound words",
                 dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_non_freq_dict_word, 0.1,
                    "Penalty for words not in the frequent word dictionary",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_non_dict_word, 0.15, "Penalty for non-dictionary words",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_punc, 0.2, "Penalty for inconsistent punctuation",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_case, 0.1, "Penalty for inconsistent case",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_script, 0.5, "Penalty for inconsistent script",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_chartype, 0.3, "Penalty for inconsistent character type",
                    dict->getCCUtil()->params())
    ,
    // TODO(daria, rays): enable font consistency checking
    // after improving font analysis.
    double_MEMBER(language_model_penalty_font, 0.00, "Penalty for inconsistent font",
                  dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_spacing, 0.05, "Penalty for inconsistent spacing",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_increment, 0.01, "Penalty increment",
                    dict->getCCUtil()->params())
    , INT_MEMBER(wordrec_display_segmentations, 0, "Display Segmentations (ScrollView)",
                 dict->getCCUtil()->params())
    , BOOL_INIT_MEMBER(language_model_use_sigmoidal_certainty, false,
                       "Use sigmoidal score for certainty", dict->getCCUtil()->params())
    , dawg_args_(nullptr, new DawgPositionVector(), NO_PERM)
    , fontinfo_table_(fontinfo_table)
    , dict_(dict) {
  // A valid Dict is required: every parameter above is registered through it.
  ASSERT_HOST(dict_ != nullptr);
}
122 | | |
// Destructor releases the DawgPositionVector allocated for
// dawg_args_.updated_dawgs in the constructor (held as a raw owning pointer).
LanguageModel::~LanguageModel() {
  delete dawg_args_.updated_dawgs;
}
126 | | |
127 | | void LanguageModel::InitForWord(const WERD_CHOICE *prev_word, bool fixed_pitch, |
128 | 98.6k | float max_char_wh_ratio, float rating_cert_scale) { |
129 | 98.6k | fixed_pitch_ = fixed_pitch; |
130 | 98.6k | max_char_wh_ratio_ = max_char_wh_ratio; |
131 | 98.6k | rating_cert_scale_ = rating_cert_scale; |
132 | 98.6k | acceptable_choice_found_ = false; |
133 | 98.6k | correct_segmentation_explored_ = false; |
134 | | |
135 | | // Initialize vectors with beginning DawgInfos. |
136 | 98.6k | very_beginning_active_dawgs_.clear(); |
137 | 98.6k | dict_->init_active_dawgs(&very_beginning_active_dawgs_, false); |
138 | 98.6k | beginning_active_dawgs_.clear(); |
139 | 98.6k | dict_->default_dawgs(&beginning_active_dawgs_, false); |
140 | | |
141 | | // Fill prev_word_str_ with the last language_model_ngram_order |
142 | | // unichars from prev_word. |
143 | 98.6k | if (language_model_ngram_on) { |
144 | 0 | if (prev_word != nullptr && !prev_word->unichar_string().empty()) { |
145 | 0 | prev_word_str_ = prev_word->unichar_string(); |
146 | 0 | if (language_model_ngram_space_delimited_language) { |
147 | 0 | prev_word_str_ += ' '; |
148 | 0 | } |
149 | 0 | } else { |
150 | 0 | prev_word_str_ = " "; |
151 | 0 | } |
152 | 0 | const char *str_ptr = prev_word_str_.c_str(); |
153 | 0 | const char *str_end = str_ptr + prev_word_str_.length(); |
154 | 0 | int step; |
155 | 0 | prev_word_unichar_step_len_ = 0; |
156 | 0 | while (str_ptr != str_end && (step = UNICHAR::utf8_step(str_ptr))) { |
157 | 0 | str_ptr += step; |
158 | 0 | ++prev_word_unichar_step_len_; |
159 | 0 | } |
160 | 0 | ASSERT_HOST(str_ptr == str_end); |
161 | 0 | } |
162 | 98.6k | } |
163 | | |
164 | | /** |
165 | | * Helper scans the collection of predecessors for competing siblings that |
166 | | * have the same letter with the opposite case, setting competing_vse. |
167 | | */ |
168 | 2.49M | static void ScanParentsForCaseMix(const UNICHARSET &unicharset, LanguageModelState *parent_node) { |
169 | 2.49M | if (parent_node == nullptr) { |
170 | 238k | return; |
171 | 238k | } |
172 | 2.25M | ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries); |
173 | 15.0M | for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) { |
174 | 12.7M | ViterbiStateEntry *vse = vit.data(); |
175 | 12.7M | vse->competing_vse = nullptr; |
176 | 12.7M | UNICHAR_ID unichar_id = vse->curr_b->unichar_id(); |
177 | 12.7M | if (unicharset.get_isupper(unichar_id) || unicharset.get_islower(unichar_id)) { |
178 | 7.98M | UNICHAR_ID other_case = unicharset.get_other_case(unichar_id); |
179 | 7.98M | if (other_case == unichar_id) { |
180 | 385k | continue; // Not in unicharset. |
181 | 385k | } |
182 | | // Find other case in same list. There could be multiple entries with |
183 | | // the same unichar_id, but in theory, they should all point to the |
184 | | // same BLOB_CHOICE, and that is what we will be using to decide |
185 | | // which to keep. |
186 | 7.60M | ViterbiStateEntry_IT vit2(&parent_node->viterbi_state_entries); |
187 | 7.60M | for (vit2.mark_cycle_pt(); |
188 | 82.0M | !vit2.cycled_list() && vit2.data()->curr_b->unichar_id() != other_case; vit2.forward()) { |
189 | 74.4M | } |
190 | 7.60M | if (!vit2.cycled_list()) { |
191 | 1.10M | vse->competing_vse = vit2.data(); |
192 | 1.10M | } |
193 | 7.60M | } |
194 | 12.7M | } |
195 | 2.25M | } |
196 | | |
197 | | /** |
198 | | * Helper returns true if the given choice has a better case variant before |
199 | | * it in the choice_list that is not distinguishable by size. |
200 | | */ |
201 | | static bool HasBetterCaseVariant(const UNICHARSET &unicharset, const BLOB_CHOICE *choice, |
202 | 7.60M | BLOB_CHOICE_LIST *choices) { |
203 | 7.60M | UNICHAR_ID choice_id = choice->unichar_id(); |
204 | 7.60M | UNICHAR_ID other_case = unicharset.get_other_case(choice_id); |
205 | 7.60M | if (other_case == choice_id || other_case == INVALID_UNICHAR_ID) { |
206 | 4.16M | return false; // Not upper or lower or not in unicharset. |
207 | 4.16M | } |
208 | 3.44M | if (unicharset.SizesDistinct(choice_id, other_case)) { |
209 | 1.91M | return false; // Can be separated by size. |
210 | 1.91M | } |
211 | 1.53M | BLOB_CHOICE_IT bc_it(choices); |
212 | 4.23M | for (bc_it.mark_cycle_pt(); !bc_it.cycled_list(); bc_it.forward()) { |
213 | 4.23M | BLOB_CHOICE *better_choice = bc_it.data(); |
214 | 4.23M | if (better_choice->unichar_id() == other_case) { |
215 | 112k | return true; // Found an earlier instance of other_case. |
216 | 4.12M | } else if (better_choice == choice) { |
217 | 1.41M | return false; // Reached the original choice. |
218 | 1.41M | } |
219 | 4.23M | } |
220 | 0 | return false; // Should never happen, but just in case. |
221 | 1.53M | } |
222 | | |
223 | | /** |
224 | | * UpdateState has the job of combining the ViterbiStateEntry lists on each |
225 | | * of the choices on parent_list with each of the blob choices in curr_list, |
226 | | * making a new ViterbiStateEntry for each sensible path. |
227 | | * |
228 | | * This could be a huge set of combinations, creating a lot of work only to |
229 | | * be truncated by some beam limit, but only certain kinds of paths will |
230 | | * continue at the next step: |
231 | | * - paths that are liked by the language model: either a DAWG or the n-gram |
232 | | * model, where active. |
233 | | * - paths that represent some kind of top choice. The old permuter permuted |
234 | | * the top raw classifier score, the top upper case word and the top lower- |
235 | | * case word. UpdateState now concentrates its top-choice paths on top |
236 | | * lower-case, top upper-case (or caseless alpha), and top digit sequence, |
237 | | * with allowance for continuation of these paths through blobs where such |
238 | | * a character does not appear in the choices list. |
239 | | * |
240 | | * GetNextParentVSE enforces some of these models to minimize the number of |
241 | | * calls to AddViterbiStateEntry, even prior to looking at the language model. |
242 | | * Thus an n-blob sequence of [l1I] will produce 3n calls to |
243 | | * AddViterbiStateEntry instead of 3^n. |
244 | | * |
245 | | * Of course it isn't quite that simple as Title Case is handled by allowing |
246 | | * lower case to continue an upper case initial, but it has to be detected |
247 | | * in the combiner so it knows which upper case letters are initial alphas. |
248 | | */ |
249 | | bool LanguageModel::UpdateState(bool just_classified, int curr_col, int curr_row, |
250 | | BLOB_CHOICE_LIST *curr_list, LanguageModelState *parent_node, |
251 | | LMPainPoints *pain_points, WERD_RES *word_res, |
252 | 2.49M | BestChoiceBundle *best_choice_bundle, BlamerBundle *blamer_bundle) { |
253 | 2.49M | if (language_model_debug_level > 0) { |
254 | 0 | tprintf("\nUpdateState: col=%d row=%d %s", curr_col, curr_row, |
255 | 0 | just_classified ? "just_classified" : ""); |
256 | 0 | if (language_model_debug_level > 5) { |
257 | 0 | tprintf("(parent=%p)\n", static_cast<void *>(parent_node)); |
258 | 0 | } else { |
259 | 0 | tprintf("\n"); |
260 | 0 | } |
261 | 0 | } |
262 | | // Initialize helper variables. |
263 | 2.49M | bool word_end = (curr_row + 1 >= word_res->ratings->dimension()); |
264 | 2.49M | bool new_changed = false; |
265 | 2.49M | float denom = (language_model_ngram_on) ? ComputeDenom(curr_list) : 1.0f; |
266 | 2.49M | const UNICHARSET &unicharset = dict_->getUnicharset(); |
267 | 2.49M | BLOB_CHOICE *first_lower = nullptr; |
268 | 2.49M | BLOB_CHOICE *first_upper = nullptr; |
269 | 2.49M | BLOB_CHOICE *first_digit = nullptr; |
270 | 2.49M | bool has_alnum_mix = false; |
271 | 2.49M | if (parent_node != nullptr) { |
272 | 2.25M | int result = SetTopParentLowerUpperDigit(parent_node); |
273 | 2.25M | if (result < 0) { |
274 | 0 | if (language_model_debug_level > 0) { |
275 | 0 | tprintf("No parents found to process\n"); |
276 | 0 | } |
277 | 0 | return false; |
278 | 0 | } |
279 | 2.25M | if (result > 0) { |
280 | 262k | has_alnum_mix = true; |
281 | 262k | } |
282 | 2.25M | } |
283 | 2.49M | if (!GetTopLowerUpperDigit(curr_list, &first_lower, &first_upper, &first_digit)) { |
284 | 2.24M | has_alnum_mix = false; |
285 | 2.24M | }; |
286 | 2.49M | ScanParentsForCaseMix(unicharset, parent_node); |
287 | 2.49M | if (language_model_debug_level > 3 && parent_node != nullptr) { |
288 | 0 | parent_node->Print("Parent viterbi list"); |
289 | 0 | } |
290 | 2.49M | LanguageModelState *curr_state = best_choice_bundle->beam[curr_row]; |
291 | | |
292 | | // Call AddViterbiStateEntry() for each parent+child ViterbiStateEntry. |
293 | 2.49M | ViterbiStateEntry_IT vit; |
294 | 2.49M | BLOB_CHOICE_IT c_it(curr_list); |
295 | 11.7M | for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) { |
296 | 9.23M | BLOB_CHOICE *choice = c_it.data(); |
297 | | // TODO(antonova): make sure commenting this out if ok for ngram |
298 | | // model scoring (I think this was introduced to fix ngram model quirks). |
299 | | // Skip nullptr unichars unless it is the only choice. |
300 | | // if (!curr_list->singleton() && c_it.data()->unichar_id() == 0) continue; |
301 | 9.23M | UNICHAR_ID unichar_id = choice->unichar_id(); |
302 | 9.23M | if (unicharset.get_fragment(unichar_id)) { |
303 | 0 | continue; // Skip fragments. |
304 | 0 | } |
305 | | // Set top choice flags. |
306 | 9.23M | LanguageModelFlagsType blob_choice_flags = kXhtConsistentFlag; |
307 | 9.23M | if (c_it.at_first() || !new_changed) { |
308 | 5.71M | blob_choice_flags |= kSmallestRatingFlag; |
309 | 5.71M | } |
310 | 9.23M | if (first_lower == choice) { |
311 | 2.49M | blob_choice_flags |= kLowerCaseFlag; |
312 | 2.49M | } |
313 | 9.23M | if (first_upper == choice) { |
314 | 2.49M | blob_choice_flags |= kUpperCaseFlag; |
315 | 2.49M | } |
316 | 9.23M | if (first_digit == choice) { |
317 | 2.49M | blob_choice_flags |= kDigitFlag; |
318 | 2.49M | } |
319 | | |
320 | 9.23M | if (parent_node == nullptr) { |
321 | | // Process the beginning of a word. |
322 | | // If there is a better case variant that is not distinguished by size, |
323 | | // skip this blob choice, as we have no choice but to accept the result |
324 | | // of the character classifier to distinguish between them, even if |
325 | | // followed by an upper case. |
326 | | // With words like iPoc, and other CamelBackWords, the lower-upper |
327 | | // transition can only be achieved if the classifier has the correct case |
328 | | // as the top choice, and leaving an initial I lower down the list |
329 | | // increases the chances of choosing IPoc simply because it doesn't |
330 | | // include such a transition. iPoc will beat iPOC and ipoc because |
331 | | // the other words are baseline/x-height inconsistent. |
332 | 977k | if (HasBetterCaseVariant(unicharset, choice, curr_list)) { |
333 | 23.5k | continue; |
334 | 23.5k | } |
335 | | // Upper counts as lower at the beginning of a word. |
336 | 953k | if (blob_choice_flags & kUpperCaseFlag) { |
337 | 226k | blob_choice_flags |= kLowerCaseFlag; |
338 | 226k | } |
339 | 953k | new_changed |= AddViterbiStateEntry(blob_choice_flags, denom, word_end, curr_col, curr_row, |
340 | 953k | choice, curr_state, nullptr, pain_points, word_res, |
341 | 953k | best_choice_bundle, blamer_bundle); |
342 | 8.25M | } else { |
343 | | // Get viterbi entries from each parent ViterbiStateEntry. |
344 | 8.25M | vit.set_to_list(&parent_node->viterbi_state_entries); |
345 | 8.25M | int vit_counter = 0; |
346 | 8.25M | vit.mark_cycle_pt(); |
347 | 8.25M | ViterbiStateEntry *parent_vse = nullptr; |
348 | 8.25M | LanguageModelFlagsType top_choice_flags; |
349 | 26.6M | while ((parent_vse = |
350 | 26.6M | GetNextParentVSE(just_classified, has_alnum_mix, c_it.data(), blob_choice_flags, |
351 | 26.6M | unicharset, word_res, &vit, &top_choice_flags)) != nullptr) { |
352 | | // Skip pruned entries and do not look at prunable entries if already |
353 | | // examined language_model_viterbi_list_max_num_prunable of those. |
354 | 18.3M | if (PrunablePath(*parent_vse) && |
355 | 18.3M | (++vit_counter > language_model_viterbi_list_max_num_prunable || |
356 | 3.77M | (language_model_ngram_on && parent_vse->ngram_info->pruned))) { |
357 | 256k | continue; |
358 | 256k | } |
359 | | // If the parent has no alnum choice, (ie choice is the first in a |
360 | | // string of alnum), and there is a better case variant that is not |
361 | | // distinguished by size, skip this blob choice/parent, as with the |
362 | | // initial blob treatment above. |
363 | 18.1M | if (!parent_vse->HasAlnumChoice(unicharset) && |
364 | 18.1M | HasBetterCaseVariant(unicharset, choice, curr_list)) { |
365 | 88.6k | continue; |
366 | 88.6k | } |
367 | | // Create a new ViterbiStateEntry if BLOB_CHOICE in c_it.data() |
368 | | // looks good according to the Dawgs or character ngram model. |
369 | 18.0M | new_changed |= AddViterbiStateEntry(top_choice_flags, denom, word_end, curr_col, curr_row, |
370 | 18.0M | c_it.data(), curr_state, parent_vse, pain_points, |
371 | 18.0M | word_res, best_choice_bundle, blamer_bundle); |
372 | 18.0M | } |
373 | 8.25M | } |
374 | 9.23M | } |
375 | 2.49M | return new_changed; |
376 | 2.49M | } |
377 | | |
378 | | /** |
379 | | * Finds the first lower and upper case letter and first digit in curr_list. |
380 | | * For non-upper/lower languages, alpha counts as upper. |
381 | | * Uses the first character in the list in place of empty results. |
382 | | * Returns true if both alpha and digits are found. |
383 | | */ |
384 | | bool LanguageModel::GetTopLowerUpperDigit(BLOB_CHOICE_LIST *curr_list, BLOB_CHOICE **first_lower, |
385 | | BLOB_CHOICE **first_upper, |
386 | 2.49M | BLOB_CHOICE **first_digit) const { |
387 | 2.49M | BLOB_CHOICE_IT c_it(curr_list); |
388 | 2.49M | const UNICHARSET &unicharset = dict_->getUnicharset(); |
389 | 2.49M | BLOB_CHOICE *first_unichar = nullptr; |
390 | 11.7M | for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) { |
391 | 9.23M | UNICHAR_ID unichar_id = c_it.data()->unichar_id(); |
392 | 9.23M | if (unicharset.get_fragment(unichar_id)) { |
393 | 0 | continue; // skip fragments |
394 | 0 | } |
395 | 9.23M | if (first_unichar == nullptr) { |
396 | 2.49M | first_unichar = c_it.data(); |
397 | 2.49M | } |
398 | 9.23M | if (*first_lower == nullptr && unicharset.get_islower(unichar_id)) { |
399 | 2.04M | *first_lower = c_it.data(); |
400 | 2.04M | } |
401 | 9.23M | if (*first_upper == nullptr && unicharset.get_isalpha(unichar_id) && |
402 | 9.23M | !unicharset.get_islower(unichar_id)) { |
403 | 804k | *first_upper = c_it.data(); |
404 | 804k | } |
405 | 9.23M | if (*first_digit == nullptr && unicharset.get_isdigit(unichar_id)) { |
406 | 256k | *first_digit = c_it.data(); |
407 | 256k | } |
408 | 9.23M | } |
409 | 2.49M | ASSERT_HOST(first_unichar != nullptr); |
410 | 2.49M | bool mixed = (*first_lower != nullptr || *first_upper != nullptr) && *first_digit != nullptr; |
411 | 2.49M | if (*first_lower == nullptr) { |
412 | 444k | *first_lower = first_unichar; |
413 | 444k | } |
414 | 2.49M | if (*first_upper == nullptr) { |
415 | 1.69M | *first_upper = first_unichar; |
416 | 1.69M | } |
417 | 2.49M | if (*first_digit == nullptr) { |
418 | 2.23M | *first_digit = first_unichar; |
419 | 2.23M | } |
420 | 2.49M | return mixed; |
421 | 2.49M | } |
422 | | |
423 | | /** |
424 | | * Forces there to be at least one entry in the overall set of the |
425 | | * viterbi_state_entries of each element of parent_node that has the |
426 | | * top_choice_flag set for lower, upper and digit using the same rules as |
427 | | * GetTopLowerUpperDigit, setting the flag on the first found suitable |
428 | | * candidate, whether or not the flag is set on some other parent. |
429 | | * Returns 1 if both alpha and digits are found among the parents, -1 if no |
430 | | * parents are found at all (a legitimate case), and 0 otherwise. |
431 | | */ |
432 | 2.25M | int LanguageModel::SetTopParentLowerUpperDigit(LanguageModelState *parent_node) const { |
433 | 2.25M | if (parent_node == nullptr) { |
434 | 0 | return -1; |
435 | 0 | } |
436 | 2.25M | UNICHAR_ID top_id = INVALID_UNICHAR_ID; |
437 | 2.25M | ViterbiStateEntry *top_lower = nullptr; |
438 | 2.25M | ViterbiStateEntry *top_upper = nullptr; |
439 | 2.25M | ViterbiStateEntry *top_digit = nullptr; |
440 | 2.25M | ViterbiStateEntry *top_choice = nullptr; |
441 | 2.25M | float lower_rating = 0.0f; |
442 | 2.25M | float upper_rating = 0.0f; |
443 | 2.25M | float digit_rating = 0.0f; |
444 | 2.25M | float top_rating = 0.0f; |
445 | 2.25M | const UNICHARSET &unicharset = dict_->getUnicharset(); |
446 | 2.25M | ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries); |
447 | 15.0M | for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) { |
448 | 12.7M | ViterbiStateEntry *vse = vit.data(); |
449 | | // INVALID_UNICHAR_ID should be treated like a zero-width joiner, so scan |
450 | | // back to the real character if needed. |
451 | 12.7M | ViterbiStateEntry *unichar_vse = vse; |
452 | 12.7M | UNICHAR_ID unichar_id = unichar_vse->curr_b->unichar_id(); |
453 | 12.7M | float rating = unichar_vse->curr_b->rating(); |
454 | 12.7M | while (unichar_id == INVALID_UNICHAR_ID && unichar_vse->parent_vse != nullptr) { |
455 | 0 | unichar_vse = unichar_vse->parent_vse; |
456 | 0 | unichar_id = unichar_vse->curr_b->unichar_id(); |
457 | 0 | rating = unichar_vse->curr_b->rating(); |
458 | 0 | } |
459 | 12.7M | if (unichar_id != INVALID_UNICHAR_ID) { |
460 | 12.7M | if (unicharset.get_islower(unichar_id)) { |
461 | 6.55M | if (top_lower == nullptr || lower_rating > rating) { |
462 | 2.07M | top_lower = vse; |
463 | 2.07M | lower_rating = rating; |
464 | 2.07M | } |
465 | 6.55M | } else if (unicharset.get_isalpha(unichar_id)) { |
466 | 1.43M | if (top_upper == nullptr || upper_rating > rating) { |
467 | 650k | top_upper = vse; |
468 | 650k | upper_rating = rating; |
469 | 650k | } |
470 | 4.76M | } else if (unicharset.get_isdigit(unichar_id)) { |
471 | 609k | if (top_digit == nullptr || digit_rating > rating) { |
472 | 278k | top_digit = vse; |
473 | 278k | digit_rating = rating; |
474 | 278k | } |
475 | 609k | } |
476 | 12.7M | } |
477 | 12.7M | if (top_choice == nullptr || top_rating > rating) { |
478 | 2.99M | top_choice = vse; |
479 | 2.99M | top_rating = rating; |
480 | 2.99M | top_id = unichar_id; |
481 | 2.99M | } |
482 | 12.7M | } |
483 | 2.25M | if (top_choice == nullptr) { |
484 | 0 | return -1; |
485 | 0 | } |
486 | 2.25M | bool mixed = (top_lower != nullptr || top_upper != nullptr) && top_digit != nullptr; |
487 | 2.25M | if (top_lower == nullptr) { |
488 | 488k | top_lower = top_choice; |
489 | 488k | } |
490 | 2.25M | top_lower->top_choice_flags |= kLowerCaseFlag; |
491 | 2.25M | if (top_upper == nullptr) { |
492 | 1.64M | top_upper = top_choice; |
493 | 1.64M | } |
494 | 2.25M | top_upper->top_choice_flags |= kUpperCaseFlag; |
495 | 2.25M | if (top_digit == nullptr) { |
496 | 1.98M | top_digit = top_choice; |
497 | 1.98M | } |
498 | 2.25M | top_digit->top_choice_flags |= kDigitFlag; |
499 | 2.25M | top_choice->top_choice_flags |= kSmallestRatingFlag; |
500 | 2.25M | if (top_id != INVALID_UNICHAR_ID && dict_->compound_marker(top_id) && |
501 | 2.25M | (top_choice->top_choice_flags & (kLowerCaseFlag | kUpperCaseFlag | kDigitFlag))) { |
502 | | // If the compound marker top choice carries any of the top alnum flags, |
503 | | // then give it all of them, allowing words like I-295 to be chosen. |
504 | 0 | top_choice->top_choice_flags |= kLowerCaseFlag | kUpperCaseFlag | kDigitFlag; |
505 | 0 | } |
506 | 2.25M | return mixed ? 1 : 0; |
507 | 2.25M | } |
508 | | |
509 | | /** |
510 | | * Finds the next ViterbiStateEntry with which the given unichar_id can |
511 | | * combine sensibly, taking into account any mixed alnum/mixed case |
512 | | * situation, and whether this combination has been inspected before. |
513 | | */ |
514 | | ViterbiStateEntry *LanguageModel::GetNextParentVSE(bool just_classified, bool mixed_alnum, |
515 | | const BLOB_CHOICE *bc, |
516 | | LanguageModelFlagsType blob_choice_flags, |
517 | | const UNICHARSET &unicharset, WERD_RES *word_res, |
518 | | ViterbiStateEntry_IT *vse_it, |
519 | 26.6M | LanguageModelFlagsType *top_choice_flags) const { |
520 | 57.8M | for (; !vse_it->cycled_list(); vse_it->forward()) { |
521 | 49.6M | ViterbiStateEntry *parent_vse = vse_it->data(); |
522 | | // Only consider the parent if it has been updated or |
523 | | // if the current ratings cell has just been classified. |
524 | 49.6M | if (!just_classified && !parent_vse->updated) { |
525 | 30.6M | continue; |
526 | 30.6M | } |
527 | 19.0M | if (language_model_debug_level > 2) { |
528 | 0 | parent_vse->Print("Considering"); |
529 | 0 | } |
530 | | // If the parent is non-alnum, then upper counts as lower. |
531 | 19.0M | *top_choice_flags = blob_choice_flags; |
532 | 19.0M | if ((blob_choice_flags & kUpperCaseFlag) && !parent_vse->HasAlnumChoice(unicharset)) { |
533 | 1.57M | *top_choice_flags |= kLowerCaseFlag; |
534 | 1.57M | } |
535 | 19.0M | *top_choice_flags &= parent_vse->top_choice_flags; |
536 | 19.0M | UNICHAR_ID unichar_id = bc->unichar_id(); |
537 | 19.0M | const BLOB_CHOICE *parent_b = parent_vse->curr_b; |
538 | 19.0M | UNICHAR_ID parent_id = parent_b->unichar_id(); |
539 | | // Digits do not bind to alphas if there is a mix in both parent and current |
540 | | // or if the alpha is not the top choice. |
541 | 19.0M | if (unicharset.get_isdigit(unichar_id) && unicharset.get_isalpha(parent_id) && |
542 | 19.0M | (mixed_alnum || *top_choice_flags == 0)) { |
543 | 197k | continue; // Digits don't bind to alphas. |
544 | 197k | } |
545 | | // Likewise alphas do not bind to digits if there is a mix in both or if |
546 | | // the digit is not the top choice. |
547 | 18.8M | if (unicharset.get_isalpha(unichar_id) && unicharset.get_isdigit(parent_id) && |
548 | 18.8M | (mixed_alnum || *top_choice_flags == 0)) { |
549 | 391k | continue; // Alphas don't bind to digits. |
550 | 391k | } |
551 | | // If there is a case mix of the same alpha in the parent list, then |
552 | | // competing_vse is non-null and will be used to determine whether |
553 | | // or not to bind the current blob choice. |
554 | 18.4M | if (parent_vse->competing_vse != nullptr) { |
555 | 1.25M | const BLOB_CHOICE *competing_b = parent_vse->competing_vse->curr_b; |
556 | 1.25M | UNICHAR_ID other_id = competing_b->unichar_id(); |
557 | 1.25M | if (language_model_debug_level >= 5) { |
558 | 0 | tprintf("Parent %s has competition %s\n", unicharset.id_to_unichar(parent_id), |
559 | 0 | unicharset.id_to_unichar(other_id)); |
560 | 0 | } |
561 | 1.25M | if (unicharset.SizesDistinct(parent_id, other_id)) { |
562 | | // If other_id matches bc wrt position and size, and parent_id, doesn't, |
563 | | // don't bind to the current parent. |
564 | 525k | if (bc->PosAndSizeAgree(*competing_b, word_res->x_height, |
565 | 525k | language_model_debug_level >= 5) && |
566 | 525k | !bc->PosAndSizeAgree(*parent_b, word_res->x_height, language_model_debug_level >= 5)) { |
567 | 78.7k | continue; // Competing blobchoice has a better vertical match. |
568 | 78.7k | } |
569 | 525k | } |
570 | 1.25M | } |
571 | 18.3M | vse_it->forward(); |
572 | 18.3M | return parent_vse; // This one is good! |
573 | 18.4M | } |
574 | 8.25M | return nullptr; // Ran out of possibilities. |
575 | 26.6M | } |
576 | | |
// Adds a new ViterbiStateEntry for blob choice b at matrix position
// (curr_col, curr_row) to curr_state, unless one of the language-model
// components (dawg, ngram, top-choice, consistency) rejects the path.
// Returns true iff the entry was actually added to the list.
// Ownership: dawg_info/ngram_info are handed to the new ViterbiStateEntry on
// success; on every early-exit path they (or the whole new_vse) are deleted
// here.
bool LanguageModel::AddViterbiStateEntry(LanguageModelFlagsType top_choice_flags, float denom,
                                         bool word_end, int curr_col, int curr_row, BLOB_CHOICE *b,
                                         LanguageModelState *curr_state,
                                         ViterbiStateEntry *parent_vse, LMPainPoints *pain_points,
                                         WERD_RES *word_res, BestChoiceBundle *best_choice_bundle,
                                         BlamerBundle *blamer_bundle) {
  ViterbiStateEntry_IT vit;
  if (language_model_debug_level > 1) {
    tprintf(
        "AddViterbiStateEntry for unichar %s rating=%.4f"
        " certainty=%.4f top_choice_flags=0x%x",
        dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->rating(), b->certainty(),
        top_choice_flags);
    if (language_model_debug_level > 5) {
      tprintf(" parent_vse=%p\n", static_cast<void *>(parent_vse));
    } else {
      tprintf("\n");
    }
  }
  ASSERT_HOST(curr_state != nullptr);
  // Check whether the list is full.
  if (curr_state->viterbi_state_entries_length >= language_model_viterbi_list_max_size) {
    if (language_model_debug_level > 1) {
      tprintf("AddViterbiStateEntry: viterbi list is full!\n");
    }
    return false;
  }

  // Invoke Dawg language model component.
  LanguageModelDawgInfo *dawg_info = GenerateDawgInfo(word_end, curr_col, curr_row, *b, parent_vse);

  float outline_length = AssociateUtils::ComputeOutlineLength(rating_cert_scale_, *b);
  // Invoke Ngram language model component.
  LanguageModelNgramInfo *ngram_info = nullptr;
  if (language_model_ngram_on) {
    ngram_info =
        GenerateNgramInfo(dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->certainty(),
                          denom, curr_col, curr_row, outline_length, parent_vse);
    ASSERT_HOST(ngram_info != nullptr);
  }
  bool liked_by_language_model =
      dawg_info != nullptr || (ngram_info != nullptr && !ngram_info->pruned);
  // Quick escape if not liked by the language model, can't be consistent
  // xheight, and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components very early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Check consistency of the path and set the relevant consistency_info.
  LMConsistencyInfo consistency_info(parent_vse != nullptr ? &parent_vse->consistency_info
                                                           : nullptr);
  // Start with just the x-height consistency, as it provides significant
  // pruning opportunity.
  consistency_info.ComputeXheightConsistency(
      b, dict_->getUnicharset().get_ispunctuation(b->unichar_id()));
  // Turn off xheight consistent flag if not consistent.
  if (consistency_info.InconsistentXHeight()) {
    top_choice_flags &= ~kXhtConsistentFlag;
  }

  // Quick escape if not liked by the language model, not consistent xheight,
  // and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Compute the rest of the consistency info.
  FillConsistencyInfo(curr_col, word_end, b, parent_vse, word_res, &consistency_info);
  if (dawg_info != nullptr && consistency_info.invalid_punc) {
    consistency_info.invalid_punc = false; // do not penalize dict words
  }

  // Compute cost of associating the blobs that represent the current unichar.
  AssociateStats associate_stats;
  ComputeAssociateStats(curr_col, curr_row, max_char_wh_ratio_, parent_vse, word_res,
                        &associate_stats);
  if (parent_vse != nullptr) {
    // Accumulate shape cost/badness along the whole path.
    associate_stats.shape_cost += parent_vse->associate_stats.shape_cost;
    associate_stats.bad_shape |= parent_vse->associate_stats.bad_shape;
  }

  // Create the new ViterbiStateEntry and compute the adjusted cost of the
  // path. The debug string is only materialized when debugging is on.
  auto *new_vse = new ViterbiStateEntry(parent_vse, b, 0.0, outline_length, consistency_info,
                                        associate_stats, top_choice_flags, dawg_info, ngram_info,
                                        (language_model_debug_level > 0)
                                            ? dict_->getUnicharset().id_to_unichar(b->unichar_id())
                                            : nullptr);
  new_vse->cost = ComputeAdjustedPathCost(new_vse);
  if (language_model_debug_level >= 3) {
    tprintf("Adjusted cost = %g\n", new_vse->cost);
  }

  // Invoke Top Choice language model component to make the final adjustments
  // to new_vse->top_choice_flags.
  if (!curr_state->viterbi_state_entries.empty() && new_vse->top_choice_flags) {
    GenerateTopChoiceInfo(new_vse, parent_vse, curr_state);
  }

  // If language model components did not like this unichar - return.
  bool keep = new_vse->top_choice_flags || liked_by_language_model;
  if (!(top_choice_flags & kSmallestRatingFlag) && // no non-top choice paths
      consistency_info.inconsistent_script) {      // with inconsistent script
    keep = false;
  }
  if (!keep) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components did not like this entry\n");
    }
    delete new_vse;
    return false;
  }

  // Discard this entry if it represents a prunable path and
  // language_model_viterbi_list_max_num_prunable such entries with a lower
  // cost have already been recorded.
  if (PrunablePath(*new_vse) &&
      (curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) &&
      new_vse->cost >= curr_state->viterbi_state_entries_prunable_max_cost) {
    if (language_model_debug_level > 1) {
      tprintf("Discarded ViterbiEntry with high cost %g max cost %g\n", new_vse->cost,
              curr_state->viterbi_state_entries_prunable_max_cost);
    }
    delete new_vse;
    return false;
  }

  // Update best choice if needed.
  if (word_end) {
    UpdateBestChoice(new_vse, pain_points, word_res, best_choice_bundle, blamer_bundle);
    // Discard the entry if UpdateBestChoice() found flaws in it.
    if (new_vse->cost >= WERD_CHOICE::kBadRating && new_vse != best_choice_bundle->best_vse) {
      if (language_model_debug_level > 1) {
        tprintf("Discarded ViterbiEntry with high cost %g\n", new_vse->cost);
      }
      delete new_vse;
      return false;
    }
  }

  // Add the new ViterbiStateEntry to curr_state->viterbi_state_entries.
  curr_state->viterbi_state_entries.add_sorted(ViterbiStateEntry::Compare, false, new_vse);
  curr_state->viterbi_state_entries_length++;
  if (PrunablePath(*new_vse)) {
    curr_state->viterbi_state_entries_prunable_length++;
  }

  // Update lms->viterbi_state_entries_prunable_max_cost and clear
  // top_choice_flags of entries with ratings_sum than new_vse->ratings_sum.
  if ((curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) ||
      new_vse->top_choice_flags) {
    ASSERT_HOST(!curr_state->viterbi_state_entries.empty());
    int prunable_counter = language_model_viterbi_list_max_num_prunable;
    vit.set_to_list(&(curr_state->viterbi_state_entries));
    for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
      ViterbiStateEntry *curr_vse = vit.data();
      // Clear the appropriate top choice flags of the entries in the
      // list that have cost higher than new_vse->cost
      // (since they will not be top choices any more).
      if (curr_vse->top_choice_flags && curr_vse != new_vse && curr_vse->cost > new_vse->cost) {
        curr_vse->top_choice_flags &= ~(new_vse->top_choice_flags);
      }
      if (prunable_counter > 0 && PrunablePath(*curr_vse)) {
        --prunable_counter;
      }
      // Update curr_state->viterbi_state_entries_prunable_max_cost:
      // the list is cost-sorted, so the cost at which the counter hits
      // zero is the cheapest cost a new prunable entry must beat.
      if (prunable_counter == 0) {
        curr_state->viterbi_state_entries_prunable_max_cost = vit.data()->cost;
        if (language_model_debug_level > 1) {
          tprintf("Set viterbi_state_entries_prunable_max_cost to %g\n",
                  curr_state->viterbi_state_entries_prunable_max_cost);
        }
        prunable_counter = -1; // stop counting
      }
    }
  }

  // Print the newly created ViterbiStateEntry.
  if (language_model_debug_level > 2) {
    new_vse->Print("New");
    if (language_model_debug_level > 5) {
      curr_state->Print("Updated viterbi list");
    }
  }

  return true;
}
775 | | |
776 | | void LanguageModel::GenerateTopChoiceInfo(ViterbiStateEntry *new_vse, |
777 | | const ViterbiStateEntry *parent_vse, |
778 | 8.58M | LanguageModelState *lms) { |
779 | 8.58M | ViterbiStateEntry_IT vit(&(lms->viterbi_state_entries)); |
780 | 8.58M | for (vit.mark_cycle_pt(); |
781 | 24.3M | !vit.cycled_list() && new_vse->top_choice_flags && new_vse->cost >= vit.data()->cost; |
782 | 15.7M | vit.forward()) { |
783 | | // Clear the appropriate flags if the list already contains |
784 | | // a top choice entry with a lower cost. |
785 | 15.7M | new_vse->top_choice_flags &= ~(vit.data()->top_choice_flags); |
786 | 15.7M | } |
787 | 8.58M | if (language_model_debug_level > 2) { |
788 | 0 | tprintf("GenerateTopChoiceInfo: top_choice_flags=0x%x\n", new_vse->top_choice_flags); |
789 | 0 | } |
790 | 8.58M | } |
791 | | |
// Runs the dictionary (dawg) component of the language model for blob
// choice b appended to the path ending in parent_vse. Returns a new
// heap-allocated LanguageModelDawgInfo if the extended path is still a
// valid dictionary/punctuation/number path, or nullptr otherwise (the
// caller owns the result). Mutates the shared member dawg_args_ and leaves
// dawg_args_.active_dawgs == nullptr on the normal exit path.
LanguageModelDawgInfo *LanguageModel::GenerateDawgInfo(bool word_end, int curr_col, int curr_row,
                                                       const BLOB_CHOICE &b,
                                                       const ViterbiStateEntry *parent_vse) {
  // Initialize active_dawgs from parent_vse if it is not nullptr.
  // Otherwise use very_beginning_active_dawgs_.
  if (parent_vse == nullptr) {
    dawg_args_.active_dawgs = &very_beginning_active_dawgs_;
    dawg_args_.permuter = NO_PERM;
  } else {
    if (parent_vse->dawg_info == nullptr) {
      return nullptr; // not a dict word path
    }
    dawg_args_.active_dawgs = &parent_vse->dawg_info->active_dawgs;
    dawg_args_.permuter = parent_vse->dawg_info->permuter;
  }

  // Deal with hyphenated words: a word-final hyphen keeps the current
  // active dawgs alive so the next word can continue them.
  if (word_end && dict_->has_hyphen_end(&dict_->getUnicharset(), b.unichar_id(), curr_col == 0)) {
    if (language_model_debug_level > 0) {
      tprintf("Hyphenated word found\n");
    }
    return new LanguageModelDawgInfo(dawg_args_.active_dawgs, COMPOUND_PERM);
  }

  // Deal with compound words.
  if (dict_->compound_marker(b.unichar_id()) &&
      (parent_vse == nullptr || parent_vse->dawg_info->permuter != NUMBER_PERM)) {
    if (language_model_debug_level > 0) {
      tprintf("Found compound marker\n");
    }
    // Do not allow compound operators at the beginning and end of the word.
    // Do not allow more than one compound operator per word.
    // Do not allow compounding of words with lengths shorter than
    // language_model_min_compound_length
    if (parent_vse == nullptr || word_end || dawg_args_.permuter == COMPOUND_PERM ||
        parent_vse->length < language_model_min_compound_length) {
      return nullptr;
    }

    // Check that the path terminated before the current character is a word.
    bool has_word_ending = false;
    for (unsigned i = 0; i < parent_vse->dawg_info->active_dawgs.size(); ++i) {
      const DawgPosition &pos = parent_vse->dawg_info->active_dawgs[i];
      const Dawg *pdawg = pos.dawg_index < 0 ? nullptr : dict_->GetDawg(pos.dawg_index);
      if (pdawg == nullptr || pos.back_to_punc) {
        continue;
      };
      if (pdawg->type() == DAWG_TYPE_WORD && pos.dawg_ref != NO_EDGE &&
          pdawg->end_of_word(pos.dawg_ref)) {
        has_word_ending = true;
        break;
      }
    }
    if (!has_word_ending) {
      return nullptr;
    }

    if (language_model_debug_level > 0) {
      tprintf("Compound word found\n");
    }
    // Restart dictionary matching for the second half of the compound.
    return new LanguageModelDawgInfo(&beginning_active_dawgs_, COMPOUND_PERM);
  } // done dealing with compound words

  LanguageModelDawgInfo *dawg_info = nullptr;

  // Call LetterIsOkay().
  // Use the normalized IDs so that all shapes of ' can be allowed in words
  // like don't.
  const auto &normed_ids = dict_->getUnicharset().normed_ids(b.unichar_id());
  DawgPositionVector tmp_active_dawgs;
  for (unsigned i = 0; i < normed_ids.size(); ++i) {
    if (language_model_debug_level > 2) {
      tprintf("Test Letter OK for unichar %d, normed %d\n", b.unichar_id(), normed_ids[i]);
    }
    // word_end only applies to the last normalized component.
    dict_->LetterIsOkay(&dawg_args_, dict_->getUnicharset(), normed_ids[i],
                        word_end && i == normed_ids.size() - 1);
    if (dawg_args_.permuter == NO_PERM) {
      break; // the letter was rejected; no dawg path survives
    } else if (i < normed_ids.size() - 1) {
      // Feed the updated dawg positions back in as the active set for the
      // next normalized component (copy needed: updated_dawgs is reused).
      tmp_active_dawgs = *dawg_args_.updated_dawgs;
      dawg_args_.active_dawgs = &tmp_active_dawgs;
    }
    if (language_model_debug_level > 2) {
      tprintf("Letter was OK for unichar %d, normed %d\n", b.unichar_id(), normed_ids[i]);
    }
  }
  // Reset so no dangling pointer to tmp_active_dawgs/parent data survives.
  dawg_args_.active_dawgs = nullptr;
  if (dawg_args_.permuter != NO_PERM) {
    dawg_info = new LanguageModelDawgInfo(dawg_args_.updated_dawgs, dawg_args_.permuter);
  } else if (language_model_debug_level > 3) {
    tprintf("Letter %s not OK!\n", dict_->getUnicharset().id_to_unichar(b.unichar_id()));
  }

  return dawg_info;
}
887 | | |
// Runs the character-ngram component of the language model: scores unichar
// against the context accumulated along parent_vse's path (or against
// prev_word_str_ at the start of a word) and returns a new heap-allocated
// LanguageModelNgramInfo carrying the updated context and costs (caller
// owns it). curr_col/curr_row are part of the component interface but are
// not used by this implementation.
LanguageModelNgramInfo *LanguageModel::GenerateNgramInfo(const char *unichar, float certainty,
                                                         float denom, int curr_col, int curr_row,
                                                         float outline_length,
                                                         const ViterbiStateEntry *parent_vse) {
  // Initialize parent context.
  const char *pcontext_ptr = "";
  int pcontext_unichar_step_len = 0;
  if (parent_vse == nullptr) {
    // Word start: seed the context with the previously recognized word.
    pcontext_ptr = prev_word_str_.c_str();
    pcontext_unichar_step_len = prev_word_unichar_step_len_;
  } else {
    pcontext_ptr = parent_vse->ngram_info->context.c_str();
    pcontext_unichar_step_len = parent_vse->ngram_info->context_unichar_step_len;
  }
  // Compute p(unichar | parent context).
  int unichar_step_len = 0;
  bool pruned = false;
  float ngram_cost;
  float ngram_and_classifier_cost = ComputeNgramCost(unichar, certainty, denom, pcontext_ptr,
                                                     &unichar_step_len, &pruned, &ngram_cost);
  // Normalize just the ngram_and_classifier_cost by outline_length.
  // The ngram_cost is used by the params_model, so it needs to be left as-is,
  // and the params model cost will be normalized by outline_length.
  ngram_and_classifier_cost *= outline_length / language_model_ngram_rating_factor;
  // Add the ngram_cost of the parent (costs accumulate along the path).
  if (parent_vse != nullptr) {
    ngram_and_classifier_cost += parent_vse->ngram_info->ngram_and_classifier_cost;
    ngram_cost += parent_vse->ngram_info->ngram_cost;
  }

  // Shorten parent context string by num_remove unichars so that the stored
  // context never exceeds language_model_ngram_order unichar steps.
  int num_remove = (unichar_step_len + pcontext_unichar_step_len - language_model_ngram_order);
  if (num_remove > 0) {
    pcontext_unichar_step_len -= num_remove;
  }
  while (num_remove > 0 && *pcontext_ptr != '\0') {
    pcontext_ptr += UNICHAR::utf8_step(pcontext_ptr);
    --num_remove;
  }

  // A pruned parent path stays pruned regardless of this step's probability.
  if (parent_vse != nullptr && parent_vse->ngram_info->pruned) {
    pruned = true;
  }

  // Construct and return the new LanguageModelNgramInfo.
  auto *ngram_info = new LanguageModelNgramInfo(pcontext_ptr, pcontext_unichar_step_len, pruned,
                                                ngram_cost, ngram_and_classifier_cost);
  ngram_info->context += unichar;
  ngram_info->context_unichar_step_len += unichar_step_len;
  assert(ngram_info->context_unichar_step_len <= language_model_ngram_order);
  return ngram_info;
}
941 | | |
942 | | float LanguageModel::ComputeNgramCost(const char *unichar, float certainty, float denom, |
943 | | const char *context, int *unichar_step_len, |
944 | 0 | bool *found_small_prob, float *ngram_cost) { |
945 | 0 | const char *context_ptr = context; |
946 | 0 | char *modified_context = nullptr; |
947 | 0 | char *modified_context_end = nullptr; |
948 | 0 | const char *unichar_ptr = unichar; |
949 | 0 | const char *unichar_end = unichar_ptr + strlen(unichar_ptr); |
950 | 0 | float prob = 0.0f; |
951 | 0 | int step = 0; |
952 | 0 | while (unichar_ptr < unichar_end && (step = UNICHAR::utf8_step(unichar_ptr)) > 0) { |
953 | 0 | if (language_model_debug_level > 1) { |
954 | 0 | tprintf("prob(%s | %s)=%g\n", unichar_ptr, context_ptr, |
955 | 0 | dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step)); |
956 | 0 | } |
957 | 0 | prob += dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step); |
958 | 0 | ++(*unichar_step_len); |
959 | 0 | if (language_model_ngram_use_only_first_uft8_step) { |
960 | 0 | break; |
961 | 0 | } |
962 | 0 | unichar_ptr += step; |
963 | | // If there are multiple UTF8 characters present in unichar, context is |
964 | | // updated to include the previously examined characters from str, |
965 | | // unless use_only_first_uft8_step is true. |
966 | 0 | if (unichar_ptr < unichar_end) { |
967 | 0 | if (modified_context == nullptr) { |
968 | 0 | size_t context_len = strlen(context); |
969 | 0 | modified_context = new char[context_len + strlen(unichar_ptr) + step + 1]; |
970 | 0 | memcpy(modified_context, context, context_len); |
971 | 0 | modified_context_end = modified_context + context_len; |
972 | 0 | context_ptr = modified_context; |
973 | 0 | } |
974 | 0 | strncpy(modified_context_end, unichar_ptr - step, step); |
975 | 0 | modified_context_end += step; |
976 | 0 | *modified_context_end = '\0'; |
977 | 0 | } |
978 | 0 | } |
979 | 0 | prob /= static_cast<float>(*unichar_step_len); // normalize |
980 | 0 | if (prob < language_model_ngram_small_prob) { |
981 | 0 | if (language_model_debug_level > 0) { |
982 | 0 | tprintf("Found small prob %g\n", prob); |
983 | 0 | } |
984 | 0 | *found_small_prob = true; |
985 | 0 | prob = language_model_ngram_small_prob; |
986 | 0 | } |
987 | 0 | *ngram_cost = -1 * std::log2(prob); |
988 | 0 | float ngram_and_classifier_cost = -1 * std::log2(CertaintyScore(certainty) / denom) + |
989 | 0 | *ngram_cost * language_model_ngram_scale_factor; |
990 | 0 | if (language_model_debug_level > 1) { |
991 | 0 | tprintf("-log [ p(%s) * p(%s | %s) ] = -log2(%g*%g) = %g\n", unichar, unichar, context_ptr, |
992 | 0 | CertaintyScore(certainty) / denom, prob, ngram_and_classifier_cost); |
993 | 0 | } |
994 | 0 | delete[] modified_context; |
995 | 0 | return ngram_and_classifier_cost; |
996 | 0 | } |
997 | | |
998 | 0 | float LanguageModel::ComputeDenom(BLOB_CHOICE_LIST *curr_list) { |
999 | 0 | if (curr_list->empty()) { |
1000 | 0 | return 1.0f; |
1001 | 0 | } |
1002 | 0 | float denom = 0.0f; |
1003 | 0 | int len = 0; |
1004 | 0 | BLOB_CHOICE_IT c_it(curr_list); |
1005 | 0 | for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) { |
1006 | 0 | ASSERT_HOST(c_it.data() != nullptr); |
1007 | 0 | ++len; |
1008 | 0 | denom += CertaintyScore(c_it.data()->certainty()); |
1009 | 0 | } |
1010 | 0 | assert(len != 0); |
1011 | | // The ideal situation would be to have the classifier scores for |
1012 | | // classifying each position as each of the characters in the unicharset. |
1013 | | // Since we cannot do this because of speed, we add a very crude estimate |
1014 | | // of what these scores for the "missing" classifications would sum up to. |
1015 | 0 | denom += |
1016 | 0 | (dict_->getUnicharset().size() - len) * CertaintyScore(language_model_ngram_nonmatch_score); |
1017 | |
|
1018 | 0 | return denom; |
1019 | 0 | } |
1020 | | |
1021 | | void LanguageModel::FillConsistencyInfo(int curr_col, bool word_end, BLOB_CHOICE *b, |
1022 | | ViterbiStateEntry *parent_vse, WERD_RES *word_res, |
1023 | 9.07M | LMConsistencyInfo *consistency_info) { |
1024 | 9.07M | const UNICHARSET &unicharset = dict_->getUnicharset(); |
1025 | 9.07M | UNICHAR_ID unichar_id = b->unichar_id(); |
1026 | 9.07M | BLOB_CHOICE *parent_b = parent_vse != nullptr ? parent_vse->curr_b : nullptr; |
1027 | | |
1028 | | // Check punctuation validity. |
1029 | 9.07M | if (unicharset.get_ispunctuation(unichar_id)) { |
1030 | 2.17M | consistency_info->num_punc++; |
1031 | 2.17M | } |
1032 | 9.07M | if (dict_->GetPuncDawg() != nullptr && !consistency_info->invalid_punc) { |
1033 | 0 | if (dict_->compound_marker(unichar_id) && parent_b != nullptr && |
1034 | 0 | (unicharset.get_isalpha(parent_b->unichar_id()) || |
1035 | 0 | unicharset.get_isdigit(parent_b->unichar_id()))) { |
1036 | | // reset punc_ref for compound words |
1037 | 0 | consistency_info->punc_ref = NO_EDGE; |
1038 | 0 | } else { |
1039 | 0 | bool is_apos = dict_->is_apostrophe(unichar_id); |
1040 | 0 | bool prev_is_numalpha = |
1041 | 0 | (parent_b != nullptr && (unicharset.get_isalpha(parent_b->unichar_id()) || |
1042 | 0 | unicharset.get_isdigit(parent_b->unichar_id()))); |
1043 | 0 | UNICHAR_ID pattern_unichar_id = |
1044 | 0 | (unicharset.get_isalpha(unichar_id) || unicharset.get_isdigit(unichar_id) || |
1045 | 0 | (is_apos && prev_is_numalpha)) |
1046 | 0 | ? Dawg::kPatternUnicharID |
1047 | 0 | : unichar_id; |
1048 | 0 | if (consistency_info->punc_ref == NO_EDGE || pattern_unichar_id != Dawg::kPatternUnicharID || |
1049 | 0 | dict_->GetPuncDawg()->edge_letter(consistency_info->punc_ref) != |
1050 | 0 | Dawg::kPatternUnicharID) { |
1051 | 0 | NODE_REF node = Dict::GetStartingNode(dict_->GetPuncDawg(), consistency_info->punc_ref); |
1052 | 0 | consistency_info->punc_ref = (node != NO_EDGE) ? dict_->GetPuncDawg()->edge_char_of( |
1053 | 0 | node, pattern_unichar_id, word_end) |
1054 | 0 | : NO_EDGE; |
1055 | 0 | if (consistency_info->punc_ref == NO_EDGE) { |
1056 | 0 | consistency_info->invalid_punc = true; |
1057 | 0 | } |
1058 | 0 | } |
1059 | 0 | } |
1060 | 0 | } |
1061 | | |
1062 | | // Update case related counters. |
1063 | 9.07M | if (parent_vse != nullptr && !word_end && dict_->compound_marker(unichar_id)) { |
1064 | | // Reset counters if we are dealing with a compound word. |
1065 | 0 | consistency_info->num_lower = 0; |
1066 | 0 | consistency_info->num_non_first_upper = 0; |
1067 | 9.07M | } else if (unicharset.get_islower(unichar_id)) { |
1068 | 4.22M | consistency_info->num_lower++; |
1069 | 4.85M | } else if ((parent_b != nullptr) && unicharset.get_isupper(unichar_id)) { |
1070 | 1.35M | if (unicharset.get_isupper(parent_b->unichar_id()) || consistency_info->num_lower > 0 || |
1071 | 1.35M | consistency_info->num_non_first_upper > 0) { |
1072 | 1.13M | consistency_info->num_non_first_upper++; |
1073 | 1.13M | } |
1074 | 1.35M | } |
1075 | | |
1076 | | // Initialize consistency_info->script_id (use script of unichar_id |
1077 | | // if it is not Common, use script id recorded by the parent otherwise). |
1078 | | // Set inconsistent_script to true if the script of the current unichar |
1079 | | // is not consistent with that of the parent. |
1080 | 9.07M | consistency_info->script_id = unicharset.get_script(unichar_id); |
1081 | | // Hiragana and Katakana can mix with Han. |
1082 | 9.07M | if (dict_->getUnicharset().han_sid() != dict_->getUnicharset().null_sid()) { |
1083 | 0 | if ((unicharset.hiragana_sid() != unicharset.null_sid() && |
1084 | 0 | consistency_info->script_id == unicharset.hiragana_sid()) || |
1085 | 0 | (unicharset.katakana_sid() != unicharset.null_sid() && |
1086 | 0 | consistency_info->script_id == unicharset.katakana_sid())) { |
1087 | 0 | consistency_info->script_id = dict_->getUnicharset().han_sid(); |
1088 | 0 | } |
1089 | 0 | } |
1090 | | |
1091 | 9.07M | if (parent_vse != nullptr && |
1092 | 9.07M | (parent_vse->consistency_info.script_id != dict_->getUnicharset().common_sid())) { |
1093 | 7.00M | int parent_script_id = parent_vse->consistency_info.script_id; |
1094 | | // If script_id is Common, use script id of the parent instead. |
1095 | 7.00M | if (consistency_info->script_id == dict_->getUnicharset().common_sid()) { |
1096 | 2.41M | consistency_info->script_id = parent_script_id; |
1097 | 2.41M | } |
1098 | 7.00M | if (consistency_info->script_id != parent_script_id) { |
1099 | 33.1k | consistency_info->inconsistent_script = true; |
1100 | 33.1k | } |
1101 | 7.00M | } |
1102 | | |
1103 | | // Update chartype related counters. |
1104 | 9.07M | if (unicharset.get_isalpha(unichar_id)) { |
1105 | 5.75M | consistency_info->num_alphas++; |
1106 | 5.75M | } else if (unicharset.get_isdigit(unichar_id)) { |
1107 | 283k | consistency_info->num_digits++; |
1108 | 3.03M | } else if (!unicharset.get_ispunctuation(unichar_id)) { |
1109 | 854k | consistency_info->num_other++; |
1110 | 854k | } |
1111 | | |
1112 | | // Check font and spacing consistency. |
1113 | 9.07M | if (fontinfo_table_->size() > 0 && parent_b != nullptr) { |
1114 | 8.12M | int fontinfo_id = -1; |
1115 | 8.12M | if (parent_b->fontinfo_id() == b->fontinfo_id() || |
1116 | 8.12M | parent_b->fontinfo_id2() == b->fontinfo_id()) { |
1117 | 1.90M | fontinfo_id = b->fontinfo_id(); |
1118 | 6.21M | } else if (parent_b->fontinfo_id() == b->fontinfo_id2() || |
1119 | 6.21M | parent_b->fontinfo_id2() == b->fontinfo_id2()) { |
1120 | 498k | fontinfo_id = b->fontinfo_id2(); |
1121 | 498k | } |
1122 | 8.12M | if (language_model_debug_level > 1) { |
1123 | 0 | tprintf( |
1124 | 0 | "pfont %s pfont %s font %s font2 %s common %s(%d)\n", |
1125 | 0 | (parent_b->fontinfo_id() >= 0) ? fontinfo_table_->at(parent_b->fontinfo_id()).name : "", |
1126 | 0 | (parent_b->fontinfo_id2() >= 0) ? fontinfo_table_->at(parent_b->fontinfo_id2()).name |
1127 | 0 | : "", |
1128 | 0 | (b->fontinfo_id() >= 0) ? fontinfo_table_->at(b->fontinfo_id()).name : "", |
1129 | 0 | (fontinfo_id >= 0) ? fontinfo_table_->at(fontinfo_id).name : "", |
1130 | 0 | (fontinfo_id >= 0) ? fontinfo_table_->at(fontinfo_id).name : "", fontinfo_id); |
1131 | 0 | } |
1132 | 8.12M | if (!word_res->blob_widths.empty()) { // if we have widths/gaps info |
1133 | 8.12M | bool expected_gap_found = false; |
1134 | 8.12M | float expected_gap = 0.0f; |
1135 | 8.12M | int temp_gap; |
1136 | 8.12M | if (fontinfo_id >= 0) { // found a common font |
1137 | 2.40M | ASSERT_HOST(fontinfo_id < fontinfo_table_->size()); |
1138 | 2.40M | if (fontinfo_table_->at(fontinfo_id) |
1139 | 2.40M | .get_spacing(parent_b->unichar_id(), unichar_id, &temp_gap)) { |
1140 | 1.66M | expected_gap = temp_gap; |
1141 | 1.66M | expected_gap_found = true; |
1142 | 1.66M | } |
1143 | 5.71M | } else { |
1144 | 5.71M | consistency_info->inconsistent_font = true; |
1145 | | // Get an average of the expected gaps in each font |
1146 | 5.71M | int num_addends = 0; |
1147 | 5.71M | int temp_fid; |
1148 | 28.5M | for (int i = 0; i < 4; ++i) { |
1149 | 22.8M | if (i == 0) { |
1150 | 5.71M | temp_fid = parent_b->fontinfo_id(); |
1151 | 17.1M | } else if (i == 1) { |
1152 | 5.71M | temp_fid = parent_b->fontinfo_id2(); |
1153 | 11.4M | } else if (i == 2) { |
1154 | 5.71M | temp_fid = b->fontinfo_id(); |
1155 | 5.71M | } else { |
1156 | 5.71M | temp_fid = b->fontinfo_id2(); |
1157 | 5.71M | } |
1158 | 22.8M | ASSERT_HOST(temp_fid < 0 || fontinfo_table_->size()); |
1159 | 22.8M | if (temp_fid >= 0 && fontinfo_table_->at(temp_fid).get_spacing(parent_b->unichar_id(), |
1160 | 22.6M | unichar_id, &temp_gap)) { |
1161 | 13.0M | expected_gap += temp_gap; |
1162 | 13.0M | num_addends++; |
1163 | 13.0M | } |
1164 | 22.8M | } |
1165 | 5.71M | if (num_addends > 0) { |
1166 | 3.30M | expected_gap /= static_cast<float>(num_addends); |
1167 | 3.30M | expected_gap_found = true; |
1168 | 3.30M | } |
1169 | 5.71M | } |
1170 | 8.12M | if (expected_gap_found) { |
1171 | 4.96M | int actual_gap = word_res->GetBlobsGap(curr_col - 1); |
1172 | 4.96M | if (actual_gap == 0) { |
1173 | 913k | consistency_info->num_inconsistent_spaces++; |
1174 | 4.05M | } else { |
1175 | 4.05M | float gap_ratio = expected_gap / actual_gap; |
1176 | | // TODO(rays) The gaps seem to be way off most of the time, saved by |
1177 | | // the error here that the ratio was compared to 1/2, when it should |
1178 | | // have been 0.5f. Find the source of the gaps discrepancy and put |
1179 | | // the 0.5f here in place of 0.0f. |
1180 | | // Test on 2476595.sj, pages 0 to 6. (In French.) |
1181 | 4.05M | if (gap_ratio < 0.0f || gap_ratio > 2.0f) { |
1182 | 1.83M | consistency_info->num_inconsistent_spaces++; |
1183 | 1.83M | } |
1184 | 4.05M | } |
1185 | 4.96M | if (language_model_debug_level > 1) { |
1186 | 0 | tprintf("spacing for %s(%d) %s(%d) col %d: expected %g actual %d\n", |
1187 | 0 | unicharset.id_to_unichar(parent_b->unichar_id()), parent_b->unichar_id(), |
1188 | 0 | unicharset.id_to_unichar(unichar_id), unichar_id, curr_col, expected_gap, |
1189 | 0 | actual_gap); |
1190 | 0 | } |
1191 | 4.96M | } |
1192 | 8.12M | } |
1193 | 8.12M | } |
1194 | 9.07M | } |
1195 | | |
1196 | 9.07M | float LanguageModel::ComputeAdjustedPathCost(ViterbiStateEntry *vse) { |
1197 | 9.07M | ASSERT_HOST(vse != nullptr); |
1198 | 9.07M | if (params_model_.Initialized()) { |
1199 | 0 | float features[PTRAIN_NUM_FEATURE_TYPES]; |
1200 | 0 | ExtractFeaturesFromPath(*vse, features); |
1201 | 0 | float cost = params_model_.ComputeCost(features); |
1202 | 0 | if (language_model_debug_level > 3) { |
1203 | 0 | tprintf("ComputeAdjustedPathCost %g ParamsModel features:\n", cost); |
1204 | 0 | if (language_model_debug_level >= 5) { |
1205 | 0 | for (int f = 0; f < PTRAIN_NUM_FEATURE_TYPES; ++f) { |
1206 | 0 | tprintf("%s=%g\n", kParamsTrainingFeatureTypeName[f], features[f]); |
1207 | 0 | } |
1208 | 0 | } |
1209 | 0 | } |
1210 | 0 | return cost * vse->outline_length; |
1211 | 9.07M | } else { |
1212 | 9.07M | float adjustment = 1.0f; |
1213 | 9.07M | if (vse->dawg_info == nullptr || vse->dawg_info->permuter != FREQ_DAWG_PERM) { |
1214 | 9.07M | adjustment += language_model_penalty_non_freq_dict_word; |
1215 | 9.07M | } |
1216 | 9.07M | if (vse->dawg_info == nullptr) { |
1217 | 9.07M | adjustment += language_model_penalty_non_dict_word; |
1218 | 9.07M | if (vse->length > language_model_min_compound_length) { |
1219 | 5.99M | adjustment += |
1220 | 5.99M | ((vse->length - language_model_min_compound_length) * language_model_penalty_increment); |
1221 | 5.99M | } |
1222 | 9.07M | } |
1223 | 9.07M | if (vse->associate_stats.shape_cost > 0) { |
1224 | 0 | adjustment += vse->associate_stats.shape_cost / static_cast<float>(vse->length); |
1225 | 0 | } |
1226 | 9.07M | if (language_model_ngram_on) { |
1227 | 0 | ASSERT_HOST(vse->ngram_info != nullptr); |
1228 | 0 | return vse->ngram_info->ngram_and_classifier_cost * adjustment; |
1229 | 9.07M | } else { |
1230 | 9.07M | adjustment += ComputeConsistencyAdjustment(vse->dawg_info, vse->consistency_info); |
1231 | 9.07M | return vse->ratings_sum * adjustment; |
1232 | 9.07M | } |
1233 | 9.07M | } |
1234 | 9.07M | } |
1235 | | |
// Constructs a WERD_CHOICE for the path ending at *vse and logs it on
// word_res: first as a raw choice (if it beats the current raw rating),
// then — with its rating replaced by vse->cost — as a cooked choice via
// Dict::adjust_word() + LogNewCookedChoice().  If the word becomes
// word_res->best_choice, best_choice_bundle and the dictionary hyphen
// state are updated.  When blamer_bundle is non-null the path is also
// recorded as a params-training hypothesis.
// Ownership: LogNewCookedChoice() takes ownership of word; on the
// guided-segsearch early-exit path the word is deleted locally instead.
// NOTE(review): pain_points is not referenced anywhere in this body.
void LanguageModel::UpdateBestChoice(ViterbiStateEntry *vse, LMPainPoints *pain_points,
                                     WERD_RES *word_res, BestChoiceBundle *best_choice_bundle,
                                     BlamerBundle *blamer_bundle) {
  bool truth_path;
  WERD_CHOICE *word =
      ConstructWord(vse, word_res, &best_choice_bundle->fixpt, blamer_bundle, &truth_path);
  ASSERT_HOST(word != nullptr);
  if (dict_->stopper_debug_level >= 1) {
    std::string word_str;
    word->string_and_lengths(&word_str, nullptr);
    vse->Print(word_str.c_str());
  }
  if (language_model_debug_level > 0) {
    word->print("UpdateBestChoice() constructed word");
  }
  // Record features from the current path if necessary.
  ParamsTrainingHypothesis curr_hyp;
  if (blamer_bundle != nullptr) {
    if (vse->dawg_info != nullptr) {
      // Keep dawg_info->permuter in sync with the permuter chosen by
      // ConstructWord() before extracting features from the path.
      vse->dawg_info->permuter = static_cast<PermuterType>(word->permuter());
    }
    ExtractFeaturesFromPath(*vse, curr_hyp.features);
    word->string_and_lengths(&(curr_hyp.str), nullptr);
    curr_hyp.cost = vse->cost; // record cost for error rate computations
    if (language_model_debug_level > 0) {
      tprintf("Raw features extracted from %s (cost=%g) [ ", curr_hyp.str.c_str(), curr_hyp.cost);
      for (float feature : curr_hyp.features) {
        tprintf("%g ", feature);
      }
      tprintf("]\n");
    }
    // Record the current hypothesis in params_training_bundle.
    blamer_bundle->AddHypothesis(curr_hyp);
    if (truth_path) {
      blamer_bundle->UpdateBestRating(word->rating());
    }
  }
  if (blamer_bundle != nullptr && blamer_bundle->GuidedSegsearchStillGoing()) {
    // The word was constructed solely for blamer_bundle->AddHypothesis, so
    // we no longer need it.
    delete word;
    return;
  }
  if (word_res->chopped_word != nullptr && !word_res->chopped_word->blobs.empty()) {
    word->SetScriptPositions(false, word_res->chopped_word, language_model_debug_level);
  }
  // Update and log new raw_choice if needed.
  if (word_res->raw_choice == nullptr || word->rating() < word_res->raw_choice->rating()) {
    if (word_res->LogNewRawChoice(word) && language_model_debug_level > 0) {
      tprintf("Updated raw choice\n");
    }
  }
  // Set the modified rating for best choice to vse->cost and log best choice.
  word->set_rating(vse->cost);
  // Call LogNewChoice() for best choice from Dict::adjust_word() since it
  // computes adjust_factor that is used by the adaption code (e.g. by
  // ClassifyAdaptableWord() to compute adaption acceptance thresholds).
  // Note: the rating of the word is not adjusted.
  dict_->adjust_word(word, vse->dawg_info == nullptr, vse->consistency_info.xht_decision, 0.0,
                     false, language_model_debug_level > 0);
  // Hand ownership of the word over to the word_res.
  if (!word_res->LogNewCookedChoice(dict_->tessedit_truncate_wordchoice_log,
                                    dict_->stopper_debug_level >= 1, word)) {
    // The word was so bad that it was deleted.
    return;
  }
  if (word_res->best_choice == word) {
    // Word was the new best.
    if (dict_->AcceptableChoice(*word, vse->consistency_info.xht_decision) &&
        AcceptablePath(*vse)) {
      acceptable_choice_found_ = true;
    }
    // Update best_choice_bundle.
    best_choice_bundle->updated = true;
    best_choice_bundle->best_vse = vse;
    if (language_model_debug_level > 0) {
      tprintf("Updated best choice\n");
      word->print_state("New state ");
    }
    // Update hyphen state if we are dealing with a dictionary word.
    if (vse->dawg_info != nullptr) {
      if (dict_->has_hyphen_end(*word)) {
        dict_->set_hyphen_word(*word, *(dawg_args_.active_dawgs));
      } else {
        dict_->reset_hyphen_vars(true);
      }
    }

    if (blamer_bundle != nullptr) {
      blamer_bundle->set_best_choice_is_dict_and_top_choice(vse->dawg_info != nullptr &&
                                                            vse->top_choice_flags);
    }
  }
#ifndef GRAPHICS_DISABLED
  if (wordrec_display_segmentations && word_res->chopped_word != nullptr) {
    word->DisplaySegmentation(word_res->chopped_word);
  }
#endif
}
1335 | | |
1336 | 0 | void LanguageModel::ExtractFeaturesFromPath(const ViterbiStateEntry &vse, float features[]) { |
1337 | 0 | memset(features, 0, sizeof(float) * PTRAIN_NUM_FEATURE_TYPES); |
1338 | | // Record dictionary match info. |
1339 | 0 | int len = vse.length <= kMaxSmallWordUnichars ? 0 : vse.length <= kMaxMediumWordUnichars ? 1 : 2; |
1340 | 0 | if (vse.dawg_info != nullptr) { |
1341 | 0 | int permuter = vse.dawg_info->permuter; |
1342 | 0 | if (permuter == NUMBER_PERM || permuter == USER_PATTERN_PERM) { |
1343 | 0 | if (vse.consistency_info.num_digits == vse.length) { |
1344 | 0 | features[PTRAIN_DIGITS_SHORT + len] = 1.0f; |
1345 | 0 | } else { |
1346 | 0 | features[PTRAIN_NUM_SHORT + len] = 1.0f; |
1347 | 0 | } |
1348 | 0 | } else if (permuter == DOC_DAWG_PERM) { |
1349 | 0 | features[PTRAIN_DOC_SHORT + len] = 1.0f; |
1350 | 0 | } else if (permuter == SYSTEM_DAWG_PERM || permuter == USER_DAWG_PERM || |
1351 | 0 | permuter == COMPOUND_PERM) { |
1352 | 0 | features[PTRAIN_DICT_SHORT + len] = 1.0f; |
1353 | 0 | } else if (permuter == FREQ_DAWG_PERM) { |
1354 | 0 | features[PTRAIN_FREQ_SHORT + len] = 1.0f; |
1355 | 0 | } |
1356 | 0 | } |
1357 | | // Record shape cost feature (normalized by path length). |
1358 | 0 | features[PTRAIN_SHAPE_COST_PER_CHAR] = |
1359 | 0 | vse.associate_stats.shape_cost / static_cast<float>(vse.length); |
1360 | | // Record ngram cost. (normalized by the path length). |
1361 | 0 | features[PTRAIN_NGRAM_COST_PER_CHAR] = 0.0f; |
1362 | 0 | if (vse.ngram_info != nullptr) { |
1363 | 0 | features[PTRAIN_NGRAM_COST_PER_CHAR] = |
1364 | 0 | vse.ngram_info->ngram_cost / static_cast<float>(vse.length); |
1365 | 0 | } |
1366 | | // Record consistency-related features. |
1367 | | // Disabled this feature for due to its poor performance. |
1368 | | // features[PTRAIN_NUM_BAD_PUNC] = vse.consistency_info.NumInconsistentPunc(); |
1369 | 0 | features[PTRAIN_NUM_BAD_CASE] = vse.consistency_info.NumInconsistentCase(); |
1370 | 0 | features[PTRAIN_XHEIGHT_CONSISTENCY] = vse.consistency_info.xht_decision; |
1371 | 0 | features[PTRAIN_NUM_BAD_CHAR_TYPE] = |
1372 | 0 | vse.dawg_info == nullptr ? vse.consistency_info.NumInconsistentChartype() : 0.0f; |
1373 | 0 | features[PTRAIN_NUM_BAD_SPACING] = vse.consistency_info.NumInconsistentSpaces(); |
1374 | | // Disabled this feature for now due to its poor performance. |
1375 | | // features[PTRAIN_NUM_BAD_FONT] = vse.consistency_info.inconsistent_font; |
1376 | | |
1377 | | // Classifier-related features. |
1378 | 0 | if (vse.outline_length > 0.0f) { |
1379 | 0 | features[PTRAIN_RATING_PER_CHAR] = vse.ratings_sum / vse.outline_length; |
1380 | 0 | } else { |
1381 | | // Avoid FP division by 0. |
1382 | 0 | features[PTRAIN_RATING_PER_CHAR] = 0.0f; |
1383 | 0 | } |
1384 | 0 | } |
1385 | | |
// Builds a heap-allocated WERD_CHOICE for the path ending at *vse by
// following parent_vse pointers back to the start of the word, filling in
// one blob choice per unichar slot (back to front).  Along the way it
// re-computes the width-to-height ratio variance contribution to
// vse->associate_stats.shape_cost, chooses the word's permuter, and — when
// truth_path is requested — checks each segmentation position against
// blamer_bundle.  The caller owns the returned word.
WERD_CHOICE *LanguageModel::ConstructWord(ViterbiStateEntry *vse, WERD_RES *word_res,
                                          DANGERR *fixpt, BlamerBundle *blamer_bundle,
                                          bool *truth_path) {
  if (truth_path != nullptr) {
    // Start optimistic: length matches the correct segmentation; each loop
    // iteration below may still flip this to false on a position mismatch.
    *truth_path =
        (blamer_bundle != nullptr && vse->length == blamer_bundle->correct_segmentation_length());
  }
  BLOB_CHOICE *curr_b = vse->curr_b;
  ViterbiStateEntry *curr_vse = vse;

  int i;
  bool compound = dict_->hyphenated(); // treat hyphenated words as compound

  // Re-compute the variance of the width-to-height ratios (since we now
  // can compute the mean over the whole word).
  float full_wh_ratio_mean = 0.0f;
  if (vse->associate_stats.full_wh_ratio_var != 0.0f) {
    // Temporarily remove the old variance from shape_cost; it is re-added
    // (updated) after the loop below.
    vse->associate_stats.shape_cost -= vse->associate_stats.full_wh_ratio_var;
    full_wh_ratio_mean =
        (vse->associate_stats.full_wh_ratio_total / static_cast<float>(vse->length));
    vse->associate_stats.full_wh_ratio_var = 0.0f;
  }

  // Construct a WERD_CHOICE by tracing parent pointers.
  auto *word = new WERD_CHOICE(word_res->uch_set, vse->length);
  word->set_length(vse->length);
  int total_blobs = 0;
  // Walk from the last unichar backwards; i indexes the slot being filled.
  for (i = (vse->length - 1); i >= 0; --i) {
    if (blamer_bundle != nullptr && truth_path != nullptr && *truth_path &&
        !blamer_bundle->MatrixPositionCorrect(i, curr_b->matrix_cell())) {
      *truth_path = false;
    }
    // The number of blobs used for this choice is row - col + 1.
    int num_blobs = curr_b->matrix_cell().row - curr_b->matrix_cell().col + 1;
    total_blobs += num_blobs;
    word->set_blob_choice(i, num_blobs, curr_b);
    // Update the width-to-height ratio variance. Useful non-space delimited
    // languages to ensure that the blobs are of uniform width.
    // Skip leading and trailing punctuation when computing the variance.
    if ((full_wh_ratio_mean != 0.0f &&
         ((curr_vse != vse && curr_vse->parent_vse != nullptr) ||
          !dict_->getUnicharset().get_ispunctuation(curr_b->unichar_id())))) {
      vse->associate_stats.full_wh_ratio_var +=
          pow(full_wh_ratio_mean - curr_vse->associate_stats.full_wh_ratio, 2);
      if (language_model_debug_level > 2) {
        tprintf("full_wh_ratio_var += (%g-%g)^2\n", full_wh_ratio_mean,
                curr_vse->associate_stats.full_wh_ratio);
      }
    }

    // Mark the word as compound if compound permuter was set for any of
    // the unichars on the path (usually this will happen for unichars
    // that are compounding operators, like "-" and "/").
    if (!compound && curr_vse->dawg_info && curr_vse->dawg_info->permuter == COMPOUND_PERM) {
      compound = true;
    }

    // Update curr_* pointers.
    curr_vse = curr_vse->parent_vse;
    if (curr_vse == nullptr) {
      // Reached the root of the path; i should now be 0 (asserted below).
      break;
    }
    curr_b = curr_vse->curr_b;
  }
  ASSERT_HOST(i == 0); // check that we recorded all the unichar ids.
  ASSERT_HOST(total_blobs == word_res->ratings->dimension());
  // Re-adjust shape cost to include the updated width-to-height variance.
  if (full_wh_ratio_mean != 0.0f) {
    vse->associate_stats.shape_cost += vse->associate_stats.full_wh_ratio_var;
  }

  word->set_rating(vse->ratings_sum);
  word->set_certainty(vse->min_certainty);
  word->set_x_heights(vse->consistency_info.BodyMinXHeight(),
                      vse->consistency_info.BodyMaxXHeight());
  // Choose the permuter: dictionary info wins, then ngram, then top choice.
  if (vse->dawg_info != nullptr) {
    word->set_permuter(compound ? COMPOUND_PERM : vse->dawg_info->permuter);
  } else if (language_model_ngram_on && !vse->ngram_info->pruned) {
    word->set_permuter(NGRAM_PERM);
  } else if (vse->top_choice_flags) {
    word->set_permuter(TOP_CHOICE_PERM);
  } else {
    word->set_permuter(NO_PERM);
  }
  // NoDangerousAmbig presumably returns true when no dangerous ambiguity is
  // found (hence the negation) and may fill fixpt — confirm in Dict.
  word->set_dangerous_ambig_found_(!dict_->NoDangerousAmbig(word, fixpt, true, word_res->ratings));
  return word;
}
1473 | | |
1474 | | } // namespace tesseract |