/src/tesseract/src/lstm/recodebeam.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: recodebeam.cpp |
3 | | // Description: Beam search to decode from the re-encoded CJK as a sequence of |
4 | | // smaller numbers in place of a single large code. |
5 | | // Author: Ray Smith |
6 | | // |
7 | | // (C) Copyright 2015, Google Inc. |
8 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
9 | | // you may not use this file except in compliance with the License. |
10 | | // You may obtain a copy of the License at |
11 | | // http://www.apache.org/licenses/LICENSE-2.0 |
12 | | // Unless required by applicable law or agreed to in writing, software |
13 | | // distributed under the License is distributed on an "AS IS" BASIS, |
14 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | | // See the License for the specific language governing permissions and |
16 | | // limitations under the License. |
17 | | // |
18 | | /////////////////////////////////////////////////////////////////////// |
19 | | |
20 | | #include "recodebeam.h" |
21 | | |
22 | | #include "networkio.h" |
23 | | #include "pageres.h" |
24 | | #include "unicharcompress.h" |
25 | | |
26 | | #include <algorithm> // for std::reverse |
27 | | |
28 | | namespace tesseract { |
29 | | |
30 | | // The beam width at each code position. |
31 | | const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = { |
32 | | 5, 10, 16, 16, 16, 16, 16, 16, 16, 16, |
33 | | }; |
34 | | |
35 | | static const char *kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"}; |
36 | | |
37 | | // Prints debug details of the node. |
38 | | void RecodeNode::Print(int null_char, const UNICHARSET &unicharset, |
39 | 0 | int depth) const { |
40 | 0 | if (code == null_char) { |
41 | 0 | tprintf("null_char"); |
42 | 0 | } else { |
43 | 0 | tprintf("label=%d, uid=%d=%s", code, unichar_id, |
44 | 0 | unicharset.debug_str(unichar_id).c_str()); |
45 | 0 | } |
46 | 0 | tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty, |
47 | 0 | start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "", |
48 | 0 | end_of_word ? " End" : "", permuter, code_hash); |
49 | 0 | if (depth > 0 && prev != nullptr) { |
50 | 0 | tprintf(" prev:"); |
51 | 0 | prev->Print(null_char, unicharset, depth - 1); |
52 | 0 | } else { |
53 | 0 | tprintf("\n"); |
54 | 0 | } |
55 | 0 | } |
56 | | |
57 | | // Borrows the pointer, which is expected to survive until *this is deleted. |
58 | | RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress &recoder, |
59 | | int null_char, bool simple_text, Dict *dict) |
60 | 1 | : recoder_(recoder), |
61 | 1 | beam_size_(0), |
62 | 1 | top_code_(-1), |
63 | 1 | second_code_(-1), |
64 | 1 | dict_(dict), |
65 | 1 | space_delimited_(true), |
66 | 1 | is_simple_text_(simple_text), |
67 | 1 | null_char_(null_char) { |
68 | 1 | if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang()) { |
69 | 0 | space_delimited_ = false; |
70 | 0 | } |
71 | 1 | } |
72 | | |
73 | 0 | RecodeBeamSearch::~RecodeBeamSearch() { |
74 | 0 | for (auto data : beam_) { |
75 | 0 | delete data; |
76 | 0 | } |
77 | 0 | for (auto data : secondary_beam_) { |
78 | 0 | delete data; |
79 | 0 | } |
80 | 0 | } |
81 | | |
82 | | // Decodes the set of network outputs, storing the lattice internally. |
83 | | void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio, |
84 | | double cert_offset, double worst_dict_cert, |
85 | 96.1k | const UNICHARSET *charset, int lstm_choice_mode) { |
86 | 96.1k | beam_size_ = 0; |
87 | 96.1k | int width = output.Width(); |
88 | 96.1k | if (lstm_choice_mode) { |
89 | 0 | timesteps.clear(); |
90 | 0 | } |
91 | 1.43M | for (int t = 0; t < width; ++t) { |
92 | 1.33M | ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]); |
93 | 1.33M | DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert, |
94 | 1.33M | charset); |
95 | 1.33M | if (lstm_choice_mode) { |
96 | 0 | SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t); |
97 | 0 | } |
98 | 1.33M | } |
99 | 96.1k | } |
100 | | void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float> &output, |
101 | | double dict_ratio, double cert_offset, |
102 | | double worst_dict_cert, |
103 | 0 | const UNICHARSET *charset) { |
104 | 0 | beam_size_ = 0; |
105 | 0 | int width = output.dim1(); |
106 | 0 | for (int t = 0; t < width; ++t) { |
107 | 0 | ComputeTopN(output[t], output.dim2(), kBeamWidths[0]); |
108 | 0 | DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset); |
109 | 0 | } |
110 | 0 | } |
111 | | |
112 | | void RecodeBeamSearch::DecodeSecondaryBeams( |
113 | | const NetworkIO &output, double dict_ratio, double cert_offset, |
114 | 0 | double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode) { |
115 | 0 | for (auto data : secondary_beam_) { |
116 | 0 | delete data; |
117 | 0 | } |
118 | 0 | secondary_beam_.clear(); |
119 | 0 | if (character_boundaries_.size() < 2) { |
120 | 0 | return; |
121 | 0 | } |
122 | 0 | int width = output.Width(); |
123 | 0 | unsigned bucketNumber = 0; |
124 | 0 | for (int t = 0; t < width; ++t) { |
125 | 0 | while ((bucketNumber + 1) < character_boundaries_.size() && |
126 | 0 | t >= character_boundaries_[bucketNumber + 1]) { |
127 | 0 | ++bucketNumber; |
128 | 0 | } |
129 | 0 | ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t), |
130 | 0 | output.NumFeatures(), kBeamWidths[0]); |
131 | 0 | DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset, |
132 | 0 | worst_dict_cert, charset); |
133 | 0 | } |
134 | 0 | } |
135 | | |
136 | | void RecodeBeamSearch::SaveMostCertainChoices(const float *outputs, |
137 | | int num_outputs, |
138 | | const UNICHARSET *charset, |
139 | 0 | int xCoord) { |
140 | 0 | std::vector<std::pair<const char *, float>> choices; |
141 | 0 | for (int i = 0; i < num_outputs; ++i) { |
142 | 0 | if (outputs[i] >= 0.01f) { |
143 | 0 | const char *character; |
144 | 0 | if (i + 2 >= num_outputs) { |
145 | 0 | character = ""; |
146 | 0 | } else if (i > 0) { |
147 | 0 | character = charset->id_to_unichar_ext(i + 2); |
148 | 0 | } else { |
149 | 0 | character = charset->id_to_unichar_ext(i); |
150 | 0 | } |
151 | 0 | size_t pos = 0; |
152 | | // order the possible choices within one timestep |
153 | | // beginning with the most likely |
154 | 0 | while (choices.size() > pos && choices[pos].second > outputs[i]) { |
155 | 0 | pos++; |
156 | 0 | } |
157 | 0 | choices.insert(choices.begin() + pos, |
158 | 0 | std::pair<const char *, float>(character, outputs[i])); |
159 | 0 | } |
160 | 0 | } |
161 | 0 | timesteps.push_back(choices); |
162 | 0 | } |
163 | | |
164 | 0 | void RecodeBeamSearch::segmentTimestepsByCharacters() { |
165 | 0 | for (unsigned i = 1; i < character_boundaries_.size(); ++i) { |
166 | 0 | std::vector<std::vector<std::pair<const char *, float>>> segment; |
167 | 0 | for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i]; |
168 | 0 | ++j) { |
169 | 0 | segment.push_back(timesteps[j]); |
170 | 0 | } |
171 | 0 | segmentedTimesteps.push_back(segment); |
172 | 0 | } |
173 | 0 | } |
174 | | std::vector<std::vector<std::pair<const char *, float>>> |
175 | | RecodeBeamSearch::combineSegmentedTimesteps( |
176 | | std::vector<std::vector<std::vector<std::pair<const char *, float>>>> |
177 | 0 | *segmentedTimesteps) { |
178 | 0 | std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps; |
179 | 0 | for (auto &segmentedTimestep : *segmentedTimesteps) { |
180 | 0 | for (auto &j : segmentedTimestep) { |
181 | 0 | combined_timesteps.push_back(j); |
182 | 0 | } |
183 | 0 | } |
184 | 0 | return combined_timesteps; |
185 | 0 | } |
186 | | |
187 | | void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts, |
188 | | std::vector<int> *ends, |
189 | | std::vector<int> *char_bounds_, |
190 | 96.1k | int maxWidth) { |
191 | 96.1k | char_bounds_->push_back(0); |
192 | 238k | for (unsigned i = 0; i < ends->size(); ++i) { |
193 | 141k | int middle = ((*starts)[i + 1] - (*ends)[i]) / 2; |
194 | 141k | char_bounds_->push_back((*ends)[i] + middle); |
195 | 141k | } |
196 | 96.1k | char_bounds_->pop_back(); |
197 | 96.1k | char_bounds_->push_back(maxWidth); |
198 | 96.1k | } |
199 | | |
200 | | // Returns the best path as labels/scores/xcoords similar to simple CTC. |
201 | | void RecodeBeamSearch::ExtractBestPathAsLabels( |
202 | 0 | std::vector<int> *labels, std::vector<int> *xcoords) const { |
203 | 0 | labels->clear(); |
204 | 0 | xcoords->clear(); |
205 | 0 | std::vector<const RecodeNode *> best_nodes; |
206 | 0 | ExtractBestPaths(&best_nodes, nullptr); |
207 | | // Now just run CTC on the best nodes. |
208 | 0 | int t = 0; |
209 | 0 | int width = best_nodes.size(); |
210 | 0 | while (t < width) { |
211 | 0 | int label = best_nodes[t]->code; |
212 | 0 | if (label != null_char_) { |
213 | 0 | labels->push_back(label); |
214 | 0 | xcoords->push_back(t); |
215 | 0 | } |
216 | 0 | while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) { |
217 | 0 | } |
218 | 0 | } |
219 | 0 | xcoords->push_back(width); |
220 | 0 | } |
221 | | |
222 | | // Returns the best path as unichar-ids/certs/ratings/xcoords skipping |
223 | | // duplicates, nulls and intermediate parts. |
224 | | void RecodeBeamSearch::ExtractBestPathAsUnicharIds( |
225 | | bool debug, const UNICHARSET *unicharset, std::vector<int> *unichar_ids, |
226 | | std::vector<float> *certs, std::vector<float> *ratings, |
227 | 0 | std::vector<int> *xcoords) const { |
228 | 0 | std::vector<const RecodeNode *> best_nodes; |
229 | 0 | ExtractBestPaths(&best_nodes, nullptr); |
230 | 0 | ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords); |
231 | 0 | if (debug) { |
232 | 0 | DebugPath(unicharset, best_nodes); |
233 | 0 | DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings, |
234 | 0 | *xcoords); |
235 | 0 | } |
236 | 0 | } |
237 | | |
238 | | // Returns the best path as a set of WERD_RES. |
239 | | void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box, |
240 | | float scale_factor, bool debug, |
241 | | const UNICHARSET *unicharset, |
242 | | PointerVector<WERD_RES> *words, |
243 | 96.1k | int lstm_choice_mode) { |
244 | 96.1k | words->truncate(0); |
245 | 96.1k | std::vector<int> unichar_ids; |
246 | 96.1k | std::vector<float> certs; |
247 | 96.1k | std::vector<float> ratings; |
248 | 96.1k | std::vector<int> xcoords; |
249 | 96.1k | std::vector<const RecodeNode *> best_nodes; |
250 | 96.1k | std::vector<const RecodeNode *> second_nodes; |
251 | 96.1k | character_boundaries_.clear(); |
252 | 96.1k | ExtractBestPaths(&best_nodes, &second_nodes); |
253 | 96.1k | if (debug) { |
254 | 0 | DebugPath(unicharset, best_nodes); |
255 | 0 | ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings, |
256 | 0 | &xcoords); |
257 | 0 | tprintf("\nSecond choice path:\n"); |
258 | 0 | DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings, |
259 | 0 | xcoords); |
260 | 0 | } |
261 | | // If lstm choice mode is required in granularity level 2, it stores the x |
262 | | // Coordinates of every chosen character, to match the alternative choices to |
263 | | // it. |
264 | 96.1k | ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords, |
265 | 96.1k | &character_boundaries_); |
266 | 96.1k | int num_ids = unichar_ids.size(); |
267 | 96.1k | if (debug) { |
268 | 0 | DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings, |
269 | 0 | xcoords); |
270 | 0 | } |
271 | | // Convert labels to unichar-ids. |
272 | 96.1k | int word_end = 0; |
273 | 96.1k | float prev_space_cert = 0.0f; |
274 | 186k | for (int word_start = 0; word_start < num_ids; word_start = word_end) { |
275 | 127k | for (word_end = word_start + 1; word_end < num_ids; ++word_end) { |
276 | | // A word is terminated when a space character or start_of_word flag is |
277 | | // hit. We also want to force a separate word for every non |
278 | | // space-delimited character when not in a dictionary context. |
279 | 51.1k | if (unichar_ids[word_end] == UNICHAR_SPACE) { |
280 | 14.5k | break; |
281 | 14.5k | } |
282 | 36.5k | int index = xcoords[word_end]; |
283 | 36.5k | if (best_nodes[index]->start_of_word) { |
284 | 44 | break; |
285 | 44 | } |
286 | 36.5k | if (best_nodes[index]->permuter == TOP_CHOICE_PERM && |
287 | 36.5k | (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) || |
288 | 21.5k | !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) { |
289 | 0 | break; |
290 | 0 | } |
291 | 36.5k | } |
292 | 90.7k | float space_cert = 0.0f; |
293 | 90.7k | if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) { |
294 | 14.5k | space_cert = certs[word_end]; |
295 | 14.5k | } |
296 | 90.7k | bool leading_space = |
297 | 90.7k | word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE; |
298 | | // Create a WERD_RES for the output word. |
299 | 90.7k | WERD_RES *word_res = |
300 | 90.7k | InitializeWord(leading_space, line_box, word_start, word_end, |
301 | 90.7k | std::min(space_cert, prev_space_cert), unicharset, |
302 | 90.7k | xcoords, scale_factor); |
303 | 218k | for (int i = word_start; i < word_end; ++i) { |
304 | 127k | auto *choices = new BLOB_CHOICE_LIST; |
305 | 127k | BLOB_CHOICE_IT bc_it(choices); |
306 | 127k | auto *choice = new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1, |
307 | 127k | 1.0f, static_cast<float>(INT16_MAX), 0.0f, |
308 | 127k | BCC_STATIC_CLASSIFIER); |
309 | 127k | int col = i - word_start; |
310 | 127k | choice->set_matrix_cell(col, col); |
311 | 127k | bc_it.add_after_then_move(choice); |
312 | 127k | word_res->ratings->put(col, col, choices); |
313 | 127k | } |
314 | 90.7k | int index = xcoords[word_end - 1]; |
315 | 90.7k | word_res->FakeWordFromRatings(best_nodes[index]->permuter); |
316 | 90.7k | words->push_back(word_res); |
317 | 90.7k | prev_space_cert = space_cert; |
318 | 90.7k | if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) { |
319 | 14.5k | ++word_end; |
320 | 14.5k | } |
321 | 90.7k | } |
322 | 96.1k | } |
323 | | |
324 | | struct greater_than { |
325 | 0 | inline bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) const { |
326 | 0 | return (node1->score > node2->score); |
327 | 0 | } |
328 | | }; |
329 | | |
330 | | void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs, |
331 | | const UNICHARSET *charset, |
332 | 0 | bool secondary) const { |
333 | 0 | std::vector<std::vector<const RecodeNode *>> topology; |
334 | 0 | std::unordered_set<const RecodeNode *> visited; |
335 | 0 | const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_; |
336 | | // create the topology |
337 | 0 | for (int step = beam.size() - 1; step >= 0; --step) { |
338 | 0 | std::vector<const RecodeNode *> layer; |
339 | 0 | topology.push_back(layer); |
340 | 0 | } |
341 | | // fill the topology with depths first |
342 | 0 | for (int step = beam.size() - 1; step >= 0; --step) { |
343 | 0 | std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap(); |
344 | 0 | for (auto &&node : heaps) { |
345 | 0 | int backtracker = 0; |
346 | 0 | const RecodeNode *curr = &node.data(); |
347 | 0 | while (curr != nullptr && !visited.count(curr)) { |
348 | 0 | visited.insert(curr); |
349 | 0 | topology[step - backtracker].push_back(curr); |
350 | 0 | curr = curr->prev; |
351 | 0 | ++backtracker; |
352 | 0 | } |
353 | 0 | } |
354 | 0 | } |
355 | 0 | int ct = 0; |
356 | 0 | unsigned cb = 1; |
357 | 0 | for (const std::vector<const RecodeNode *> &layer : topology) { |
358 | 0 | if (cb >= character_boundaries_.size()) { |
359 | 0 | break; |
360 | 0 | } |
361 | 0 | if (ct == character_boundaries_[cb]) { |
362 | 0 | tprintf("***\n"); |
363 | 0 | ++cb; |
364 | 0 | } |
365 | 0 | for (const RecodeNode *node : layer) { |
366 | 0 | const char *code; |
367 | 0 | int intCode; |
368 | 0 | if (node->unichar_id != INVALID_UNICHAR_ID) { |
369 | 0 | code = charset->id_to_unichar(node->unichar_id); |
370 | 0 | intCode = node->unichar_id; |
371 | 0 | } else if (node->code == null_char_) { |
372 | 0 | intCode = 0; |
373 | 0 | code = " "; |
374 | 0 | } else { |
375 | 0 | intCode = 666; |
376 | 0 | code = "*"; |
377 | 0 | } |
378 | 0 | int intPrevCode = 0; |
379 | 0 | const char *prevCode; |
380 | 0 | float prevScore = 0; |
381 | 0 | if (node->prev != nullptr) { |
382 | 0 | prevScore = node->prev->score; |
383 | 0 | if (node->prev->unichar_id != INVALID_UNICHAR_ID) { |
384 | 0 | prevCode = charset->id_to_unichar(node->prev->unichar_id); |
385 | 0 | intPrevCode = node->prev->unichar_id; |
386 | 0 | } else if (node->code == null_char_) { |
387 | 0 | intPrevCode = 0; |
388 | 0 | prevCode = " "; |
389 | 0 | } else { |
390 | 0 | prevCode = "*"; |
391 | 0 | intPrevCode = 666; |
392 | 0 | } |
393 | 0 | } else { |
394 | 0 | prevCode = " "; |
395 | 0 | } |
396 | 0 | if (uids) { |
397 | 0 | tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode, |
398 | 0 | node->score); |
399 | 0 | } else { |
400 | 0 | tprintf("%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score); |
401 | 0 | } |
402 | 0 | } |
403 | 0 | tprintf("-\n"); |
404 | 0 | ++ct; |
405 | 0 | } |
406 | 0 | tprintf("***\n"); |
407 | 0 | } |
408 | | |
409 | 0 | void RecodeBeamSearch::extractSymbolChoices(const UNICHARSET *unicharset) { |
410 | 0 | if (character_boundaries_.size() < 2) { |
411 | 0 | return; |
412 | 0 | } |
413 | | // For the first iteration the original beam is analyzed. After that a |
414 | | // new beam is calculated based on the results from the original beam. |
415 | 0 | std::vector<RecodeBeam *> ¤tBeam = |
416 | 0 | secondary_beam_.empty() ? beam_ : secondary_beam_; |
417 | 0 | character_boundaries_[0] = 0; |
418 | 0 | for (unsigned j = 1; j < character_boundaries_.size(); ++j) { |
419 | 0 | std::vector<int> unichar_ids; |
420 | 0 | std::vector<float> certs; |
421 | 0 | std::vector<float> ratings; |
422 | 0 | std::vector<int> xcoords; |
423 | 0 | int backpath = character_boundaries_[j] - character_boundaries_[j - 1]; |
424 | 0 | std::vector<tesseract::RecodePair> &heaps = |
425 | 0 | currentBeam.at(character_boundaries_[j] - 1)->beams_->heap(); |
426 | 0 | std::vector<const RecodeNode *> best_nodes; |
427 | 0 | std::vector<const RecodeNode *> best; |
428 | | // Scan the segmented node chain for valid unichar ids. |
429 | 0 | for (auto &&entry : heaps) { |
430 | 0 | bool validChar = false; |
431 | 0 | int backcounter = 0; |
432 | 0 | const RecodeNode *node = &entry.data(); |
433 | 0 | while (node != nullptr && backcounter < backpath) { |
434 | 0 | if (node->code != null_char_ && |
435 | 0 | node->unichar_id != INVALID_UNICHAR_ID) { |
436 | 0 | validChar = true; |
437 | 0 | break; |
438 | 0 | } |
439 | 0 | node = node->prev; |
440 | 0 | ++backcounter; |
441 | 0 | } |
442 | 0 | if (validChar) { |
443 | 0 | best.push_back(&entry.data()); |
444 | 0 | } |
445 | 0 | } |
446 | | // find the best rated segmented node chain and extract the unichar id. |
447 | 0 | if (!best.empty()) { |
448 | 0 | std::sort(best.begin(), best.end(), greater_than()); |
449 | 0 | ExtractPath(best[0], &best_nodes, backpath); |
450 | 0 | ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, |
451 | 0 | &xcoords); |
452 | 0 | } |
453 | 0 | if (!unichar_ids.empty()) { |
454 | 0 | int bestPos = 0; |
455 | 0 | for (unsigned i = 1; i < unichar_ids.size(); ++i) { |
456 | 0 | if (ratings[i] < ratings[bestPos]) { |
457 | 0 | bestPos = i; |
458 | 0 | } |
459 | 0 | } |
460 | | #if 0 // TODO: bestCode is currently unused (see commit 2dd5d0d60). |
461 | | int bestCode = -10; |
462 | | for (auto &node : best_nodes) { |
463 | | if (node->unichar_id == unichar_ids[bestPos]) { |
464 | | bestCode = node->code; |
465 | | } |
466 | | } |
467 | | #endif |
468 | | // Exclude the best choice for the followup decoding. |
469 | 0 | std::unordered_set<int> excludeCodeList; |
470 | 0 | for (auto &best_node : best_nodes) { |
471 | 0 | if (best_node->code != null_char_) { |
472 | 0 | excludeCodeList.insert(best_node->code); |
473 | 0 | } |
474 | 0 | } |
475 | 0 | if (j - 1 < excludedUnichars.size()) { |
476 | 0 | for (auto elem : excludeCodeList) { |
477 | 0 | excludedUnichars[j - 1].insert(elem); |
478 | 0 | } |
479 | 0 | } else { |
480 | 0 | excludedUnichars.push_back(excludeCodeList); |
481 | 0 | } |
482 | | // Save the best choice for the choice iterator. |
483 | 0 | if (j - 1 < ctc_choices.size()) { |
484 | 0 | int id = unichar_ids[bestPos]; |
485 | 0 | const char *result = unicharset->id_to_unichar_ext(id); |
486 | 0 | float rating = ratings[bestPos]; |
487 | 0 | ctc_choices[j - 1].push_back( |
488 | 0 | std::pair<const char *, float>(result, rating)); |
489 | 0 | } else { |
490 | 0 | std::vector<std::pair<const char *, float>> choice; |
491 | 0 | int id = unichar_ids[bestPos]; |
492 | 0 | const char *result = unicharset->id_to_unichar_ext(id); |
493 | 0 | float rating = ratings[bestPos]; |
494 | 0 | choice.emplace_back(result, rating); |
495 | 0 | ctc_choices.push_back(choice); |
496 | 0 | } |
497 | | // fill the blank spot with an empty array |
498 | 0 | } else { |
499 | 0 | if (j - 1 >= excludedUnichars.size()) { |
500 | 0 | std::unordered_set<int> excludeCodeList; |
501 | 0 | excludedUnichars.push_back(excludeCodeList); |
502 | 0 | } |
503 | 0 | if (j - 1 >= ctc_choices.size()) { |
504 | 0 | std::vector<std::pair<const char *, float>> choice; |
505 | 0 | ctc_choices.push_back(choice); |
506 | 0 | } |
507 | 0 | } |
508 | 0 | } |
509 | 0 | for (auto data : secondary_beam_) { |
510 | 0 | delete data; |
511 | 0 | } |
512 | 0 | secondary_beam_.clear(); |
513 | 0 | } |
514 | | |
515 | | // Generates debug output of the content of the beams after a Decode. |
516 | 0 | void RecodeBeamSearch::DebugBeams(const UNICHARSET &unicharset) const { |
517 | 0 | for (int p = 0; p < beam_size_; ++p) { |
518 | 0 | for (int d = 0; d < 2; ++d) { |
519 | 0 | for (int c = 0; c < NC_COUNT; ++c) { |
520 | 0 | auto cont = static_cast<NodeContinuation>(c); |
521 | 0 | int index = BeamIndex(d, cont, 0); |
522 | 0 | if (beam_[p]->beams_[index].empty()) { |
523 | 0 | continue; |
524 | 0 | } |
525 | | // Print all the best scoring nodes for each unichar found. |
526 | 0 | tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict", |
527 | 0 | kNodeContNames[c]); |
528 | 0 | DebugBeamPos(unicharset, beam_[p]->beams_[index]); |
529 | 0 | } |
530 | 0 | } |
531 | 0 | } |
532 | 0 | } |
533 | | |
534 | | // Generates debug output of the content of a single beam position. |
535 | | void RecodeBeamSearch::DebugBeamPos(const UNICHARSET &unicharset, |
536 | 0 | const RecodeHeap &heap) const { |
537 | 0 | std::vector<const RecodeNode *> unichar_bests(unicharset.size()); |
538 | 0 | const RecodeNode *null_best = nullptr; |
539 | 0 | int heap_size = heap.size(); |
540 | 0 | for (int i = 0; i < heap_size; ++i) { |
541 | 0 | const RecodeNode *node = &heap.get(i).data(); |
542 | 0 | if (node->unichar_id == INVALID_UNICHAR_ID) { |
543 | 0 | if (null_best == nullptr || null_best->score < node->score) { |
544 | 0 | null_best = node; |
545 | 0 | } |
546 | 0 | } else { |
547 | 0 | if (unichar_bests[node->unichar_id] == nullptr || |
548 | 0 | unichar_bests[node->unichar_id]->score < node->score) { |
549 | 0 | unichar_bests[node->unichar_id] = node; |
550 | 0 | } |
551 | 0 | } |
552 | 0 | } |
553 | 0 | for (auto &unichar_best : unichar_bests) { |
554 | 0 | if (unichar_best != nullptr) { |
555 | 0 | const RecodeNode &node = *unichar_best; |
556 | 0 | node.Print(null_char_, unicharset, 1); |
557 | 0 | } |
558 | 0 | } |
559 | 0 | if (null_best != nullptr) { |
560 | 0 | null_best->Print(null_char_, unicharset, 1); |
561 | 0 | } |
562 | 0 | } |
563 | | |
564 | | // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping |
565 | | // duplicates, nulls and intermediate parts. |
566 | | /* static */ |
567 | | void RecodeBeamSearch::ExtractPathAsUnicharIds( |
568 | | const std::vector<const RecodeNode *> &best_nodes, |
569 | | std::vector<int> *unichar_ids, std::vector<float> *certs, |
570 | | std::vector<float> *ratings, std::vector<int> *xcoords, |
571 | 96.1k | std::vector<int> *character_boundaries) { |
572 | 96.1k | unichar_ids->clear(); |
573 | 96.1k | certs->clear(); |
574 | 96.1k | ratings->clear(); |
575 | 96.1k | xcoords->clear(); |
576 | 96.1k | std::vector<int> starts; |
577 | 96.1k | std::vector<int> ends; |
578 | | // Backtrack extracting only valid, non-duplicate unichar-ids. |
579 | 96.1k | int t = 0; |
580 | 96.1k | int width = best_nodes.size(); |
581 | 317k | while (t < width) { |
582 | 221k | double certainty = 0.0; |
583 | 221k | double rating = 0.0; |
584 | 1.24M | while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) { |
585 | 1.02M | double cert = best_nodes[t++]->certainty; |
586 | 1.02M | if (cert < certainty) { |
587 | 386k | certainty = cert; |
588 | 386k | } |
589 | 1.02M | rating -= cert; |
590 | 1.02M | } |
591 | 221k | starts.push_back(t); |
592 | 221k | if (t < width) { |
593 | 141k | int unichar_id = best_nodes[t]->unichar_id; |
594 | 141k | if (unichar_id == UNICHAR_SPACE && !certs->empty() && |
595 | 141k | best_nodes[t]->permuter != NO_PERM) { |
596 | | // All the rating and certainty go on the previous character except |
597 | | // for the space itself. |
598 | 12.4k | if (certainty < certs->back()) { |
599 | 3.47k | certs->back() = certainty; |
600 | 3.47k | } |
601 | 12.4k | ratings->back() += rating; |
602 | 12.4k | certainty = 0.0; |
603 | 12.4k | rating = 0.0; |
604 | 12.4k | } |
605 | 141k | unichar_ids->push_back(unichar_id); |
606 | 141k | xcoords->push_back(t); |
607 | 312k | do { |
608 | 312k | double cert = best_nodes[t++]->certainty; |
609 | | // Special-case NO-PERM space to forget the certainty of the previous |
610 | | // nulls. See long comment in ContinueContext. |
611 | 312k | if (cert < certainty || (unichar_id == UNICHAR_SPACE && |
612 | 173k | best_nodes[t - 1]->permuter == NO_PERM)) { |
613 | 173k | certainty = cert; |
614 | 173k | } |
615 | 312k | rating -= cert; |
616 | 312k | } while (t < width && best_nodes[t]->duplicate); |
617 | 141k | ends.push_back(t); |
618 | 141k | certs->push_back(certainty); |
619 | 141k | ratings->push_back(rating); |
620 | 141k | } else if (!certs->empty()) { |
621 | 62.8k | if (certainty < certs->back()) { |
622 | 13.0k | certs->back() = certainty; |
623 | 13.0k | } |
624 | 62.8k | ratings->back() += rating; |
625 | 62.8k | } |
626 | 221k | } |
627 | 96.1k | starts.push_back(width); |
628 | 96.1k | if (character_boundaries != nullptr) { |
629 | 96.1k | calculateCharBoundaries(&starts, &ends, character_boundaries, width); |
630 | 96.1k | } |
631 | 96.1k | xcoords->push_back(width); |
632 | 96.1k | } |
633 | | |
634 | | // Sets up a word with the ratings matrix and fake blobs with boxes in the |
635 | | // right places. |
636 | | WERD_RES *RecodeBeamSearch::InitializeWord(bool leading_space, |
637 | | const TBOX &line_box, int word_start, |
638 | | int word_end, float space_certainty, |
639 | | const UNICHARSET *unicharset, |
640 | | const std::vector<int> &xcoords, |
641 | 90.7k | float scale_factor) { |
642 | | // Make a fake blob for each non-zero label. |
643 | 90.7k | C_BLOB_LIST blobs; |
644 | 90.7k | C_BLOB_IT b_it(&blobs); |
645 | 218k | for (int i = word_start; i < word_end; ++i) { |
646 | 127k | if (static_cast<unsigned>(i + 1) < character_boundaries_.size()) { |
647 | 127k | TBOX box(static_cast<int16_t>( |
648 | 127k | std::floor(character_boundaries_[i] * scale_factor)) + |
649 | 127k | line_box.left(), |
650 | 127k | line_box.bottom(), |
651 | 127k | static_cast<int16_t>( |
652 | 127k | std::ceil(character_boundaries_[i + 1] * scale_factor)) + |
653 | 127k | line_box.left(), |
654 | 127k | line_box.top()); |
655 | 127k | b_it.add_after_then_move(C_BLOB::FakeBlob(box)); |
656 | 127k | } |
657 | 127k | } |
658 | | // Make a fake word from the blobs. |
659 | 90.7k | WERD *word = new WERD(&blobs, leading_space, nullptr); |
660 | | // Make a WERD_RES from the word. |
661 | 90.7k | auto *word_res = new WERD_RES(word); |
662 | 90.7k | word_res->end = word_end - word_start + leading_space; |
663 | 90.7k | word_res->uch_set = unicharset; |
664 | 90.7k | word_res->combination = true; // Give it ownership of the word. |
665 | 90.7k | word_res->space_certainty = space_certainty; |
666 | 90.7k | word_res->ratings = new MATRIX(word_end - word_start, 1); |
667 | 90.7k | return word_res; |
668 | 90.7k | } |
669 | | |
670 | | // Fills top_n_flags_ with bools that are true iff the corresponding output |
671 | | // is one of the top_n. |
672 | | void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs, |
673 | 1.33M | int top_n) { |
674 | 1.33M | top_n_flags_.clear(); |
675 | 1.33M | top_n_flags_.resize(num_outputs, TN_ALSO_RAN); |
676 | 1.33M | top_code_ = -1; |
677 | 1.33M | second_code_ = -1; |
678 | 1.33M | top_heap_.clear(); |
679 | 149M | for (int i = 0; i < num_outputs; ++i) { |
680 | 148M | if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) { |
681 | 23.3M | TopPair entry(outputs[i], i); |
682 | 23.3M | top_heap_.Push(&entry); |
683 | 23.3M | if (top_heap_.size() > top_n) { |
684 | 16.6M | top_heap_.Pop(&entry); |
685 | 16.6M | } |
686 | 23.3M | } |
687 | 148M | } |
688 | 8.02M | while (!top_heap_.empty()) { |
689 | 6.68M | TopPair entry; |
690 | 6.68M | top_heap_.Pop(&entry); |
691 | 6.68M | if (top_heap_.size() > 1) { |
692 | 4.01M | top_n_flags_[entry.data()] = TN_TOPN; |
693 | 4.01M | } else { |
694 | 2.67M | top_n_flags_[entry.data()] = TN_TOP2; |
695 | 2.67M | if (top_heap_.empty()) { |
696 | 1.33M | top_code_ = entry.data(); |
697 | 1.33M | } else { |
698 | 1.33M | second_code_ = entry.data(); |
699 | 1.33M | } |
700 | 2.67M | } |
701 | 6.68M | } |
702 | 1.33M | top_n_flags_[null_char_] = TN_TOP2; |
703 | 1.33M | } |
704 | | |
705 | | void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList, |
706 | | const float *outputs, int num_outputs, |
707 | 0 | int top_n) { |
708 | 0 | top_n_flags_.clear(); |
709 | 0 | top_n_flags_.resize(num_outputs, TN_ALSO_RAN); |
710 | 0 | top_code_ = -1; |
711 | 0 | second_code_ = -1; |
712 | 0 | top_heap_.clear(); |
713 | 0 | for (int i = 0; i < num_outputs; ++i) { |
714 | 0 | if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) && |
715 | 0 | !exList->count(i)) { |
716 | 0 | TopPair entry(outputs[i], i); |
717 | 0 | top_heap_.Push(&entry); |
718 | 0 | if (top_heap_.size() > top_n) { |
719 | 0 | top_heap_.Pop(&entry); |
720 | 0 | } |
721 | 0 | } |
722 | 0 | } |
723 | 0 | while (!top_heap_.empty()) { |
724 | 0 | TopPair entry; |
725 | 0 | top_heap_.Pop(&entry); |
726 | 0 | if (top_heap_.size() > 1) { |
727 | 0 | top_n_flags_[entry.data()] = TN_TOPN; |
728 | 0 | } else { |
729 | 0 | top_n_flags_[entry.data()] = TN_TOP2; |
730 | 0 | if (top_heap_.empty()) { |
731 | 0 | top_code_ = entry.data(); |
732 | 0 | } else { |
733 | 0 | second_code_ = entry.data(); |
734 | 0 | } |
735 | 0 | } |
736 | 0 | } |
737 | 0 | top_n_flags_[null_char_] = TN_TOP2; |
738 | 0 | } |
739 | | |
740 | | // Adds the computation for the current time-step to the beam. Call at each |
741 | | // time-step in sequence from left to right. outputs is the activation vector |
742 | | // for the current timestep. |
743 | | void RecodeBeamSearch::DecodeStep(const float *outputs, int t, |
744 | | double dict_ratio, double cert_offset, |
745 | | double worst_dict_cert, |
746 | 1.33M | const UNICHARSET *charset, bool debug) { |
747 | 1.33M | if (t == static_cast<int>(beam_.size())) { |
748 | 677 | beam_.push_back(new RecodeBeam); |
749 | 677 | } |
750 | 1.33M | RecodeBeam *step = beam_[t]; |
751 | 1.33M | beam_size_ = t + 1; |
752 | 1.33M | step->Clear(); |
753 | 1.33M | if (t == 0) { |
754 | | // The first step can only use singles and initials. |
755 | 96.1k | ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2, |
756 | 96.1k | charset, dict_ratio, cert_offset, worst_dict_cert, step); |
757 | 96.1k | if (dict_ != nullptr) { |
758 | 96.1k | ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, |
759 | 96.1k | TN_TOP2, charset, dict_ratio, cert_offset, |
760 | 96.1k | worst_dict_cert, step); |
761 | 96.1k | } |
762 | 1.24M | } else { |
763 | 1.24M | RecodeBeam *prev = beam_[t - 1]; |
764 | 1.24M | if (debug) { |
765 | 0 | int beam_index = BeamIndex(true, NC_ANYTHING, 0); |
766 | 0 | for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { |
767 | 0 | std::vector<const RecodeNode *> path; |
768 | 0 | ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); |
769 | 0 | tprintf("Step %d: Dawg beam %d:\n", t, i); |
770 | 0 | DebugPath(charset, path); |
771 | 0 | } |
772 | 0 | beam_index = BeamIndex(false, NC_ANYTHING, 0); |
773 | 0 | for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { |
774 | 0 | std::vector<const RecodeNode *> path; |
775 | 0 | ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); |
776 | 0 | tprintf("Step %d: Non-Dawg beam %d:\n", t, i); |
777 | 0 | DebugPath(charset, path); |
778 | 0 | } |
779 | 0 | } |
780 | 1.24M | int total_beam = 0; |
781 | | // Work through the scores by group (top-2, top-n, the rest) while the beam |
782 | | // is empty. This enables extending the context using only the top-n results |
783 | | // first, which may have an empty intersection with the valid codes, so we |
784 | | // fall back to the rest if the beam is empty. |
785 | 2.48M | for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) { |
786 | 1.24M | auto top_n = static_cast<TopNState>(tn); |
787 | 75.7M | for (int index = 0; index < kNumBeams; ++index) { |
788 | | // Working backwards through the heaps doesn't guarantee that we see the |
789 | | // best first, but it comes before a lot of the worst, so it is slightly |
790 | | // more efficient than going forwards. |
791 | 91.2M | for (int i = prev->beams_[index].size() - 1; i >= 0; --i) { |
792 | 16.8M | ContinueContext(&prev->beams_[index].get(i).data(), index, outputs, |
793 | 16.8M | top_n, charset, dict_ratio, cert_offset, |
794 | 16.8M | worst_dict_cert, step); |
795 | 16.8M | } |
796 | 74.4M | } |
797 | 75.7M | for (int index = 0; index < kNumBeams; ++index) { |
798 | 74.4M | if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) { |
799 | 24.8M | total_beam += step->beams_[index].size(); |
800 | 24.8M | } |
801 | 74.4M | } |
802 | 1.24M | } |
803 | | // Special case for the best initial dawg. Push it on the heap if good |
804 | | // enough, but there is only one, so it doesn't blow up the beam. |
805 | 4.96M | for (int c = 0; c < NC_COUNT; ++c) { |
806 | 3.72M | if (step->best_initial_dawgs_[c].code >= 0) { |
807 | 332k | int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0); |
808 | 332k | RecodeHeap *dawg_heap = &step->beams_[index]; |
809 | 332k | PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c], |
810 | 332k | dawg_heap); |
811 | 332k | } |
812 | 3.72M | } |
813 | 1.24M | } |
814 | 1.33M | } |
815 | | |
816 | | void RecodeBeamSearch::DecodeSecondaryStep( |
817 | | const float *outputs, int t, double dict_ratio, double cert_offset, |
818 | 0 | double worst_dict_cert, const UNICHARSET *charset, bool debug) { |
819 | 0 | if (t == static_cast<int>(secondary_beam_.size())) { |
820 | 0 | secondary_beam_.push_back(new RecodeBeam); |
821 | 0 | } |
822 | 0 | RecodeBeam *step = secondary_beam_[t]; |
823 | 0 | step->Clear(); |
824 | 0 | if (t == 0) { |
825 | | // The first step can only use singles and initials. |
826 | 0 | ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2, |
827 | 0 | charset, dict_ratio, cert_offset, worst_dict_cert, step); |
828 | 0 | if (dict_ != nullptr) { |
829 | 0 | ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, |
830 | 0 | TN_TOP2, charset, dict_ratio, cert_offset, |
831 | 0 | worst_dict_cert, step); |
832 | 0 | } |
833 | 0 | } else { |
834 | 0 | RecodeBeam *prev = secondary_beam_[t - 1]; |
835 | 0 | if (debug) { |
836 | 0 | int beam_index = BeamIndex(true, NC_ANYTHING, 0); |
837 | 0 | for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { |
838 | 0 | std::vector<const RecodeNode *> path; |
839 | 0 | ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); |
840 | 0 | tprintf("Step %d: Dawg beam %d:\n", t, i); |
841 | 0 | DebugPath(charset, path); |
842 | 0 | } |
843 | 0 | beam_index = BeamIndex(false, NC_ANYTHING, 0); |
844 | 0 | for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { |
845 | 0 | std::vector<const RecodeNode *> path; |
846 | 0 | ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); |
847 | 0 | tprintf("Step %d: Non-Dawg beam %d:\n", t, i); |
848 | 0 | DebugPath(charset, path); |
849 | 0 | } |
850 | 0 | } |
851 | 0 | int total_beam = 0; |
852 | | // Work through the scores by group (top-2, top-n, the rest) while the beam |
853 | | // is empty. This enables extending the context using only the top-n results |
854 | | // first, which may have an empty intersection with the valid codes, so we |
855 | | // fall back to the rest if the beam is empty. |
856 | 0 | for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) { |
857 | 0 | auto top_n = static_cast<TopNState>(tn); |
858 | 0 | for (int index = 0; index < kNumBeams; ++index) { |
859 | | // Working backwards through the heaps doesn't guarantee that we see the |
860 | | // best first, but it comes before a lot of the worst, so it is slightly |
861 | | // more efficient than going forwards. |
862 | 0 | for (int i = prev->beams_[index].size() - 1; i >= 0; --i) { |
863 | 0 | ContinueContext(&prev->beams_[index].get(i).data(), index, outputs, |
864 | 0 | top_n, charset, dict_ratio, cert_offset, |
865 | 0 | worst_dict_cert, step); |
866 | 0 | } |
867 | 0 | } |
868 | 0 | for (int index = 0; index < kNumBeams; ++index) { |
869 | 0 | if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) { |
870 | 0 | total_beam += step->beams_[index].size(); |
871 | 0 | } |
872 | 0 | } |
873 | 0 | } |
874 | | // Special case for the best initial dawg. Push it on the heap if good |
875 | | // enough, but there is only one, so it doesn't blow up the beam. |
876 | 0 | for (int c = 0; c < NC_COUNT; ++c) { |
877 | 0 | if (step->best_initial_dawgs_[c].code >= 0) { |
878 | 0 | int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0); |
879 | 0 | RecodeHeap *dawg_heap = &step->beams_[index]; |
880 | 0 | PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c], |
881 | 0 | dawg_heap); |
882 | 0 | } |
883 | 0 | } |
884 | 0 | } |
885 | 0 | } |
886 | | |
887 | | // Adds to the appropriate beams the legal (according to recoder) |
888 | | // continuations of context prev, which is of the given length, using the |
889 | | // given network outputs to provide scores to the choices. Uses only those |
890 | | // choices for which top_n_flags[index] == top_n_flag. |
891 | | void RecodeBeamSearch::ContinueContext( |
892 | | const RecodeNode *prev, int index, const float *outputs, |
893 | | TopNState top_n_flag, const UNICHARSET *charset, double dict_ratio, |
894 | 17.0M | double cert_offset, double worst_dict_cert, RecodeBeam *step) { |
895 | 17.0M | RecodedCharID prefix; |
896 | 17.0M | RecodedCharID full_code; |
897 | 17.0M | const RecodeNode *previous = prev; |
898 | 17.0M | int length = LengthFromBeamsIndex(index); |
899 | 17.0M | bool use_dawgs = IsDawgFromBeamsIndex(index); |
900 | 17.0M | NodeContinuation prev_cont = ContinuationFromBeamsIndex(index); |
901 | 17.0M | for (int p = length - 1; p >= 0 && previous != nullptr; --p) { |
902 | 0 | while (previous->duplicate || previous->code == null_char_) { |
903 | 0 | previous = previous->prev; |
904 | 0 | } |
905 | 0 | prefix.Set(p, previous->code); |
906 | 0 | full_code.Set(p, previous->code); |
907 | 0 | previous = previous->prev; |
908 | 0 | } |
909 | 17.0M | if (prev != nullptr && !is_simple_text_) { |
910 | 16.8M | if (top_n_flags_[prev->code] == top_n_flag) { |
911 | 12.2M | if (prev_cont != NC_NO_DUP) { |
912 | 11.4M | float cert = |
913 | 11.4M | NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset; |
914 | 11.4M | PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id, |
915 | 11.4M | cert, worst_dict_cert, dict_ratio, use_dawgs, |
916 | 11.4M | NC_ANYTHING, prev, step); |
917 | 11.4M | } |
918 | 12.2M | if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 && |
919 | 12.2M | prev->code != null_char_) { |
920 | 1.46M | float cert = NetworkIO::ProbToCertainty(outputs[prev->code] + |
921 | 1.46M | outputs[null_char_]) + |
922 | 1.46M | cert_offset; |
923 | 1.46M | PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id, |
924 | 1.46M | cert, worst_dict_cert, dict_ratio, use_dawgs, |
925 | 1.46M | NC_NO_DUP, prev, step); |
926 | 1.46M | } |
927 | 12.2M | } |
928 | 16.8M | if (prev_cont == NC_ONLY_DUP) { |
929 | 5.76M | return; |
930 | 5.76M | } |
931 | 11.0M | if (prev->code != null_char_ && length > 0 && |
932 | 11.0M | top_n_flags_[null_char_] == top_n_flag) { |
933 | | // Allow nulls within multi code sequences, as the nulls within are not |
934 | | // explicitly included in the code sequence. |
935 | 0 | float cert = |
936 | 0 | NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset; |
937 | 0 | PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID, |
938 | 0 | cert, worst_dict_cert, dict_ratio, use_dawgs, |
939 | 0 | NC_ANYTHING, prev, step); |
940 | 0 | } |
941 | 11.0M | } |
942 | 11.2M | const std::vector<int> *final_codes = recoder_.GetFinalCodes(prefix); |
943 | 11.2M | if (final_codes != nullptr) { |
944 | 1.24G | for (int code : *final_codes) { |
945 | 1.24G | if (top_n_flags_[code] != top_n_flag) { |
946 | 1.22G | continue; |
947 | 1.22G | } |
948 | 23.1M | if (prev != nullptr && prev->code == code && !is_simple_text_) { |
949 | 9.36M | continue; |
950 | 9.36M | } |
951 | 13.7M | float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset; |
952 | 13.7M | if (cert < kMinCertainty && code != null_char_) { |
953 | 228 | continue; |
954 | 228 | } |
955 | 13.7M | full_code.Set(length, code); |
956 | 13.7M | int unichar_id = recoder_.DecodeUnichar(full_code); |
957 | | // Map the null char to INVALID. |
958 | 13.7M | if (length == 0 && code == null_char_) { |
959 | 4.12M | unichar_id = INVALID_UNICHAR_ID; |
960 | 4.12M | } |
961 | 13.7M | if (unichar_id != INVALID_UNICHAR_ID && charset != nullptr && |
962 | 13.7M | !charset->get_enabled(unichar_id)) { |
963 | 0 | continue; // disabled by whitelist/blacklist |
964 | 0 | } |
965 | 13.7M | ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio, |
966 | 13.7M | use_dawgs, NC_ANYTHING, prev, step); |
967 | 13.7M | if (top_n_flag == TN_TOP2 && code != null_char_) { |
968 | 9.61M | float prob = outputs[code] + outputs[null_char_]; |
969 | 9.61M | if (prev != nullptr && prev_cont == NC_ANYTHING && |
970 | 9.61M | prev->code != null_char_ && |
971 | 9.61M | ((prev->code == top_code_ && code == second_code_) || |
972 | 1.27M | (code == top_code_ && prev->code == second_code_))) { |
973 | 95.3k | prob += outputs[prev->code]; |
974 | 95.3k | } |
975 | 9.61M | cert = NetworkIO::ProbToCertainty(prob) + cert_offset; |
976 | 9.61M | ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio, |
977 | 9.61M | use_dawgs, NC_ONLY_DUP, prev, step); |
978 | 9.61M | } |
979 | 13.7M | } |
980 | 11.2M | } |
981 | 11.2M | const std::vector<int> *next_codes = recoder_.GetNextCodes(prefix); |
982 | 11.2M | if (next_codes != nullptr) { |
983 | 0 | for (int code : *next_codes) { |
984 | 0 | if (top_n_flags_[code] != top_n_flag) { |
985 | 0 | continue; |
986 | 0 | } |
987 | 0 | if (prev != nullptr && prev->code == code && !is_simple_text_) { |
988 | 0 | continue; |
989 | 0 | } |
990 | 0 | float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset; |
991 | 0 | PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert, |
992 | 0 | worst_dict_cert, dict_ratio, use_dawgs, |
993 | 0 | NC_ANYTHING, prev, step); |
994 | 0 | if (top_n_flag == TN_TOP2 && code != null_char_) { |
995 | 0 | float prob = outputs[code] + outputs[null_char_]; |
996 | 0 | if (prev != nullptr && prev_cont == NC_ANYTHING && |
997 | 0 | prev->code != null_char_ && |
998 | 0 | ((prev->code == top_code_ && code == second_code_) || |
999 | 0 | (code == top_code_ && prev->code == second_code_))) { |
1000 | 0 | prob += outputs[prev->code]; |
1001 | 0 | } |
1002 | 0 | cert = NetworkIO::ProbToCertainty(prob) + cert_offset; |
1003 | 0 | PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, |
1004 | 0 | cert, worst_dict_cert, dict_ratio, use_dawgs, |
1005 | 0 | NC_ONLY_DUP, prev, step); |
1006 | 0 | } |
1007 | 0 | } |
1008 | 0 | } |
1009 | 11.2M | } |
1010 | | |
1011 | | // Continues for a new unichar, using dawg or non-dawg as per flag. |
1012 | | void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert, |
1013 | | float worst_dict_cert, float dict_ratio, |
1014 | | bool use_dawgs, NodeContinuation cont, |
1015 | | const RecodeNode *prev, |
1016 | 23.3M | RecodeBeam *step) { |
1017 | 23.3M | if (use_dawgs) { |
1018 | 9.02M | if (cert > worst_dict_cert) { |
1019 | 8.19M | ContinueDawg(code, unichar_id, cert, cont, prev, step); |
1020 | 8.19M | } |
1021 | 14.3M | } else { |
1022 | 14.3M | RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)]; |
1023 | 14.3M | PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false, |
1024 | 14.3M | false, false, false, cert * dict_ratio, prev, nullptr, |
1025 | 14.3M | nodawg_heap); |
1026 | 14.3M | if (dict_ != nullptr && |
1027 | 14.3M | ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) || |
1028 | 14.3M | !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) { |
1029 | | // Any top choice position that can start a new word, ie a space or |
1030 | | // any non-space-delimited character, should also be considered |
1031 | | // by the dawg search, so push initial dawg to the dawg heap. |
1032 | 1.20M | float dawg_cert = cert; |
1033 | 1.20M | PermuterType permuter = TOP_CHOICE_PERM; |
1034 | | // Since we use the space either side of a dictionary word in the |
1035 | | // certainty of the word, (to properly handle weak spaces) and the |
1036 | | // space is coming from a non-dict word, we need special conditions |
1037 | | // to avoid degrading the certainty of the dict word that follows. |
1038 | | // With a space we don't multiply the certainty by dict_ratio, and we |
1039 | | // flag the space with NO_PERM to indicate that we should not use the |
1040 | | // predecessor nulls to generate the confidence for the space, as they |
1041 | | // have already been multiplied by dict_ratio, and we can't go back to |
1042 | | // insert more entries in any previous heaps. |
1043 | 1.20M | if (unichar_id == UNICHAR_SPACE) { |
1044 | 1.20M | permuter = NO_PERM; |
1045 | 1.20M | } else { |
1046 | 0 | dawg_cert *= dict_ratio; |
1047 | 0 | } |
1048 | 1.20M | PushInitialDawgIfBetter(code, unichar_id, permuter, false, false, |
1049 | 1.20M | dawg_cert, cont, prev, step); |
1050 | 1.20M | } |
1051 | 14.3M | } |
1052 | 23.3M | } |
1053 | | |
1054 | | // Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev, |
1055 | | // appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id |
1056 | | // is a valid continuation of whatever is in prev. |
1057 | | void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert, |
1058 | | NodeContinuation cont, |
1059 | 8.19M | const RecodeNode *prev, RecodeBeam *step) { |
1060 | 8.19M | RecodeHeap *dawg_heap = &step->beams_[BeamIndex(true, cont, 0)]; |
1061 | 8.19M | RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)]; |
1062 | 8.19M | if (unichar_id == INVALID_UNICHAR_ID) { |
1063 | 1.09M | PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false, |
1064 | 1.09M | false, false, cert, prev, nullptr, dawg_heap); |
1065 | 1.09M | return; |
1066 | 1.09M | } |
1067 | | // Avoid dictionary probe if score a total loss. |
1068 | 7.10M | float score = cert; |
1069 | 7.10M | if (prev != nullptr) { |
1070 | 6.99M | score += prev->score; |
1071 | 6.99M | } |
1072 | 7.10M | if (dawg_heap->size() >= kBeamWidths[0] && |
1073 | 7.10M | score <= dawg_heap->PeekTop().data().score && |
1074 | 7.10M | nodawg_heap->size() >= kBeamWidths[0] && |
1075 | 7.10M | score <= nodawg_heap->PeekTop().data().score) { |
1076 | 10.5k | return; |
1077 | 10.5k | } |
1078 | 7.08M | const RecodeNode *uni_prev = prev; |
1079 | | // Prev may be a partial code, null_char, or duplicate, so scan back to the |
1080 | | // last valid unichar_id. |
1081 | 105M | while (uni_prev != nullptr && |
1082 | 105M | (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) { |
1083 | 97.9M | uni_prev = uni_prev->prev; |
1084 | 97.9M | } |
1085 | 7.08M | if (unichar_id == UNICHAR_SPACE) { |
1086 | 1.18M | if (uni_prev != nullptr && uni_prev->end_of_word) { |
1087 | | // Space is good. Push initial state, to the dawg beam and a regular |
1088 | | // space to the top choice beam. |
1089 | 625k | PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false, |
1090 | 625k | false, cert, cont, prev, step); |
1091 | 625k | PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter, |
1092 | 625k | false, false, false, false, cert, prev, nullptr, |
1093 | 625k | nodawg_heap); |
1094 | 625k | } |
1095 | 1.18M | return; |
1096 | 5.90M | } else if (uni_prev != nullptr && uni_prev->start_of_dawg && |
1097 | 5.90M | uni_prev->unichar_id != UNICHAR_SPACE && |
1098 | 5.90M | dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) && |
1099 | 5.90M | dict_->getUnicharset().IsSpaceDelimited(unichar_id)) { |
1100 | 0 | return; // Can't break words between space delimited chars. |
1101 | 0 | } |
1102 | 5.90M | DawgPositionVector initial_dawgs; |
1103 | 5.90M | auto *updated_dawgs = new DawgPositionVector; |
1104 | 5.90M | DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM); |
1105 | 5.90M | bool word_start = false; |
1106 | 5.90M | if (uni_prev == nullptr) { |
1107 | | // Starting from beginning of line. |
1108 | 1.32M | dict_->default_dawgs(&initial_dawgs, false); |
1109 | 1.32M | word_start = true; |
1110 | 4.57M | } else if (uni_prev->dawgs != nullptr) { |
1111 | | // Continuing a previous dict word. |
1112 | 4.57M | dawg_args.active_dawgs = uni_prev->dawgs; |
1113 | 4.57M | word_start = uni_prev->start_of_dawg; |
1114 | 4.57M | } else { |
1115 | 0 | return; // Can't continue if not a dict word. |
1116 | 0 | } |
1117 | 5.90M | auto permuter = static_cast<PermuterType>(dict_->def_letter_is_okay( |
1118 | 5.90M | &dawg_args, dict_->getUnicharset(), unichar_id, false)); |
1119 | 5.90M | if (permuter != NO_PERM) { |
1120 | 1.99M | PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false, |
1121 | 1.99M | word_start, dawg_args.valid_end, false, cert, prev, |
1122 | 1.99M | dawg_args.updated_dawgs, dawg_heap); |
1123 | 1.99M | if (dawg_args.valid_end && !space_delimited_) { |
1124 | | // We can start another word right away, so push initial state as well, |
1125 | | // to the dawg beam, and the regular character to the top choice beam, |
1126 | | // since non-dict words can start here too. |
1127 | 0 | PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true, |
1128 | 0 | cert, cont, prev, step); |
1129 | 0 | PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false, |
1130 | 0 | word_start, true, false, cert, prev, nullptr, |
1131 | 0 | nodawg_heap); |
1132 | 0 | } |
1133 | 3.90M | } else { |
1134 | 3.90M | delete updated_dawgs; |
1135 | 3.90M | } |
1136 | 5.90M | } |
1137 | | |
1138 | | // Adds a RecodeNode composed of the tuple (code, unichar_id, |
1139 | | // initial-dawg-state, prev, cert) to the given heap if/ there is room or if |
1140 | | // better than the current worst element if already full. |
1141 | | void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id, |
1142 | | PermuterType permuter, |
1143 | | bool start, bool end, float cert, |
1144 | | NodeContinuation cont, |
1145 | | const RecodeNode *prev, |
1146 | 1.83M | RecodeBeam *step) { |
1147 | 1.83M | RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont]; |
1148 | 1.83M | float score = cert; |
1149 | 1.83M | if (prev != nullptr) { |
1150 | 1.82M | score += prev->score; |
1151 | 1.82M | } |
1152 | 1.83M | if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) { |
1153 | 806k | auto *initial_dawgs = new DawgPositionVector; |
1154 | 806k | dict_->default_dawgs(initial_dawgs, false); |
1155 | 806k | RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert, |
1156 | 806k | score, prev, initial_dawgs, |
1157 | 806k | ComputeCodeHash(code, false, prev)); |
1158 | 806k | *best_initial_dawg = node; |
1159 | 806k | } |
1160 | 1.83M | } |
1161 | | |
1162 | | // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, |
1163 | | // false, false, false, false, cert, prev, nullptr) to heap if there is room |
1164 | | // or if better than the current worst element if already full. |
1165 | | /* static */ |
1166 | | void RecodeBeamSearch::PushDupOrNoDawgIfBetter( |
1167 | | int length, bool dup, int code, int unichar_id, float cert, |
1168 | | float worst_dict_cert, float dict_ratio, bool use_dawgs, |
1169 | 12.9M | NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step) { |
1170 | 12.9M | int index = BeamIndex(use_dawgs, cont, length); |
1171 | 12.9M | if (use_dawgs) { |
1172 | 4.45M | if (cert > worst_dict_cert) { |
1173 | 4.26M | PushHeapIfBetter(kBeamWidths[length], code, unichar_id, |
1174 | 4.26M | prev ? prev->permuter : NO_PERM, false, false, false, |
1175 | 4.26M | dup, cert, prev, nullptr, &step->beams_[index]); |
1176 | 4.26M | } |
1177 | 8.49M | } else { |
1178 | 8.49M | cert *= dict_ratio; |
1179 | 8.49M | if (cert >= kMinCertainty || code == null_char_) { |
1180 | 8.43M | PushHeapIfBetter(kBeamWidths[length], code, unichar_id, |
1181 | 8.43M | prev ? prev->permuter : TOP_CHOICE_PERM, false, false, |
1182 | 8.43M | false, dup, cert, prev, nullptr, &step->beams_[index]); |
1183 | 8.43M | } |
1184 | 8.49M | } |
1185 | 12.9M | } |
1186 | | |
1187 | | // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, |
1188 | | // dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room |
1189 | | // or if better than the current worst element if already full. |
1190 | | void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id, |
1191 | | PermuterType permuter, bool dawg_start, |
1192 | | bool word_start, bool end, bool dup, |
1193 | | float cert, const RecodeNode *prev, |
1194 | | DawgPositionVector *d, |
1195 | 30.7M | RecodeHeap *heap) { |
1196 | 30.7M | float score = cert; |
1197 | 30.7M | if (prev != nullptr) { |
1198 | 30.2M | score += prev->score; |
1199 | 30.2M | } |
1200 | 30.7M | if (heap->size() < max_size || score > heap->PeekTop().data().score) { |
1201 | 24.3M | uint64_t hash = ComputeCodeHash(code, dup, prev); |
1202 | 24.3M | RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end, |
1203 | 24.3M | dup, cert, score, prev, d, hash); |
1204 | 24.3M | if (UpdateHeapIfMatched(&node, heap)) { |
1205 | 4.22M | return; |
1206 | 4.22M | } |
1207 | 20.1M | RecodePair entry(score, node); |
1208 | 20.1M | heap->Push(&entry); |
1209 | 20.1M | ASSERT_HOST(entry.data().dawgs == nullptr); |
1210 | 20.1M | if (heap->size() > max_size) { |
1211 | 2.30M | heap->Pop(&entry); |
1212 | 2.30M | } |
1213 | 20.1M | } else { |
1214 | 6.42M | delete d; |
1215 | 6.42M | } |
1216 | 30.7M | } |
1217 | | |
1218 | | // Adds a RecodeNode to heap if there is room |
1219 | | // or if better than the current worst element if already full. |
1220 | | void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode *node, |
1221 | 332k | RecodeHeap *heap) { |
1222 | 332k | if (heap->size() < max_size || node->score > heap->PeekTop().data().score) { |
1223 | 275k | if (UpdateHeapIfMatched(node, heap)) { |
1224 | 0 | return; |
1225 | 0 | } |
1226 | 275k | RecodePair entry(node->score, *node); |
1227 | 275k | heap->Push(&entry); |
1228 | 275k | ASSERT_HOST(entry.data().dawgs == nullptr); |
1229 | 275k | if (heap->size() > max_size) { |
1230 | 26.5k | heap->Pop(&entry); |
1231 | 26.5k | } |
1232 | 275k | } |
1233 | 332k | } |
1234 | | |
1235 | | // Searches the heap for a matching entry, and updates the score with |
1236 | | // reshuffle if needed. Returns true if there was a match. |
1237 | | bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node, |
1238 | 24.6M | RecodeHeap *heap) { |
1239 | | // TODO(rays) consider hash map instead of linear search. |
1240 | | // It might not be faster because the hash map would have to be updated |
1241 | | // every time a heap reshuffle happens, and that would be a lot of overhead. |
1242 | 24.6M | std::vector<RecodePair> &nodes = heap->heap(); |
1243 | 50.6M | for (auto &i : nodes) { |
1244 | 50.6M | RecodeNode &node = i.data(); |
1245 | 50.6M | if (node.code == new_node->code && node.code_hash == new_node->code_hash && |
1246 | 50.6M | node.permuter == new_node->permuter && |
1247 | 50.6M | node.start_of_dawg == new_node->start_of_dawg) { |
1248 | 4.22M | if (new_node->score > node.score) { |
1249 | | // The new one is better. Update the entire node in the heap and |
1250 | | // reshuffle. |
1251 | 2.31M | node = *new_node; |
1252 | 2.31M | i.key() = node.score; |
1253 | 2.31M | heap->Reshuffle(&i); |
1254 | 2.31M | } |
1255 | 4.22M | return true; |
1256 | 4.22M | } |
1257 | 50.6M | } |
1258 | 20.3M | return false; |
1259 | 24.6M | } |
1260 | | |
1261 | | // Computes and returns the code-hash for the given code and prev. |
1262 | | uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup, |
1263 | 25.1M | const RecodeNode *prev) const { |
1264 | 25.1M | uint64_t hash = prev == nullptr ? 0 : prev->code_hash; |
1265 | 25.1M | if (!dup && code != null_char_) { |
1266 | 11.6M | int num_classes = recoder_.code_range(); |
1267 | 11.6M | uint64_t carry = (((hash >> 32) * num_classes) >> 32); |
1268 | 11.6M | hash *= num_classes; |
1269 | 11.6M | hash += carry; |
1270 | 11.6M | hash += code; |
1271 | 11.6M | } |
1272 | 25.1M | return hash; |
1273 | 25.1M | } |
1274 | | |
1275 | | // Backtracks to extract the best path through the lattice that was built |
1276 | | // during Decode. On return the best_nodes vector essentially contains the set |
1277 | | // of code, score pairs that make the optimal path with the constraint that |
1278 | | // the recoder can decode the code sequence back to a sequence of unichar-ids. |
1279 | | void RecodeBeamSearch::ExtractBestPaths( |
1280 | | std::vector<const RecodeNode *> *best_nodes, |
1281 | 96.1k | std::vector<const RecodeNode *> *second_nodes) const { |
1282 | | // Scan both beams to extract the best and second best paths. |
1283 | 96.1k | const RecodeNode *best_node = nullptr; |
1284 | 96.1k | const RecodeNode *second_best_node = nullptr; |
1285 | 96.1k | const RecodeBeam *last_beam = beam_[beam_size_ - 1]; |
1286 | 384k | for (int c = 0; c < NC_COUNT; ++c) { |
1287 | 288k | if (c == NC_ONLY_DUP) { |
1288 | 96.1k | continue; |
1289 | 96.1k | } |
1290 | 192k | auto cont = static_cast<NodeContinuation>(c); |
1291 | 576k | for (int is_dawg = 0; is_dawg < 2; ++is_dawg) { |
1292 | 384k | int beam_index = BeamIndex(is_dawg, cont, 0); |
1293 | 384k | int heap_size = last_beam->beams_[beam_index].size(); |
1294 | 1.17M | for (int h = 0; h < heap_size; ++h) { |
1295 | 791k | const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data(); |
1296 | 791k | if (is_dawg) { |
1297 | | // dawg_node may be a null_char, or duplicate, so scan back to the |
1298 | | // last valid unichar_id. |
1299 | 291k | const RecodeNode *dawg_node = node; |
1300 | 2.22M | while (dawg_node != nullptr && |
1301 | 2.22M | (dawg_node->unichar_id == INVALID_UNICHAR_ID || |
1302 | 2.16M | dawg_node->duplicate)) { |
1303 | 1.93M | dawg_node = dawg_node->prev; |
1304 | 1.93M | } |
1305 | 291k | if (dawg_node == nullptr || |
1306 | 291k | (!dawg_node->end_of_word && |
1307 | 230k | dawg_node->unichar_id != UNICHAR_SPACE)) { |
1308 | | // Dawg node is not valid. |
1309 | 85.8k | continue; |
1310 | 85.8k | } |
1311 | 291k | } |
1312 | 706k | if (best_node == nullptr || node->score > best_node->score) { |
1313 | 493k | second_best_node = best_node; |
1314 | 493k | best_node = node; |
1315 | 493k | } else if (second_best_node == nullptr || |
1316 | 212k | node->score > second_best_node->score) { |
1317 | 137k | second_best_node = node; |
1318 | 137k | } |
1319 | 706k | } |
1320 | 384k | } |
1321 | 192k | } |
1322 | 96.1k | if (second_nodes != nullptr) { |
1323 | 96.1k | ExtractPath(second_best_node, second_nodes); |
1324 | 96.1k | } |
1325 | 96.1k | ExtractPath(best_node, best_nodes); |
1326 | 96.1k | } |
1327 | | |
1328 | | // Helper backtracks through the lattice from the given node, storing the |
1329 | | // path and reversing it. |
1330 | | void RecodeBeamSearch::ExtractPath( |
1331 | 192k | const RecodeNode *node, std::vector<const RecodeNode *> *path) const { |
1332 | 192k | path->clear(); |
1333 | 2.86M | while (node != nullptr) { |
1334 | 2.67M | path->push_back(node); |
1335 | 2.67M | node = node->prev; |
1336 | 2.67M | } |
1337 | 192k | std::reverse(path->begin(), path->end()); |
1338 | 192k | } |
1339 | | |
1340 | | void RecodeBeamSearch::ExtractPath(const RecodeNode *node, |
1341 | | std::vector<const RecodeNode *> *path, |
1342 | 0 | int limiter) const { |
1343 | 0 | int pathcounter = 0; |
1344 | 0 | path->clear(); |
1345 | 0 | while (node != nullptr && pathcounter < limiter) { |
1346 | 0 | path->push_back(node); |
1347 | 0 | node = node->prev; |
1348 | 0 | ++pathcounter; |
1349 | 0 | } |
1350 | 0 | std::reverse(path->begin(), path->end()); |
1351 | 0 | } |
1352 | | |
1353 | | // Helper prints debug information on the given lattice path. |
1354 | | void RecodeBeamSearch::DebugPath( |
1355 | | const UNICHARSET *unicharset, |
1356 | 0 | const std::vector<const RecodeNode *> &path) const { |
1357 | 0 | for (unsigned c = 0; c < path.size(); ++c) { |
1358 | 0 | const RecodeNode &node = *path[c]; |
1359 | 0 | tprintf("%u ", c); |
1360 | 0 | node.Print(null_char_, *unicharset, 1); |
1361 | 0 | } |
1362 | 0 | } |
1363 | | |
1364 | | // Helper prints debug information on the given unichar path. |
1365 | | void RecodeBeamSearch::DebugUnicharPath( |
1366 | | const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path, |
1367 | | const std::vector<int> &unichar_ids, const std::vector<float> &certs, |
1368 | 0 | const std::vector<float> &ratings, const std::vector<int> &xcoords) const { |
1369 | 0 | auto num_ids = unichar_ids.size(); |
1370 | 0 | double total_rating = 0.0; |
1371 | 0 | for (unsigned c = 0; c < num_ids; ++c) { |
1372 | 0 | int coord = xcoords[c]; |
1373 | 0 | tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c], |
1374 | 0 | unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c], |
1375 | 0 | path[coord]->start_of_word, path[coord]->end_of_word, |
1376 | 0 | path[coord]->permuter); |
1377 | 0 | total_rating += ratings[c]; |
1378 | 0 | } |
1379 | 0 | tprintf("Path total rating = %g\n", total_rating); |
1380 | 0 | } |
1381 | | |
1382 | | } // namespace tesseract. |