Coverage Report

Created: 2025-06-13 07:15

/src/tesseract/src/lstm/recodebeam.cpp
Line
Count
Source (jump to first uncovered line)
1
///////////////////////////////////////////////////////////////////////
2
// File:        recodebeam.cpp
3
// Description: Beam search to decode from the re-encoded CJK as a sequence of
4
//              smaller numbers in place of a single large code.
5
// Author:      Ray Smith
6
//
7
// (C) Copyright 2015, Google Inc.
8
// Licensed under the Apache License, Version 2.0 (the "License");
9
// you may not use this file except in compliance with the License.
10
// You may obtain a copy of the License at
11
// http://www.apache.org/licenses/LICENSE-2.0
12
// Unless required by applicable law or agreed to in writing, software
13
// distributed under the License is distributed on an "AS IS" BASIS,
14
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
// See the License for the specific language governing permissions and
16
// limitations under the License.
17
//
18
///////////////////////////////////////////////////////////////////////
19
20
#include "recodebeam.h"
21
22
#include "networkio.h"
23
#include "pageres.h"
24
#include "unicharcompress.h"
25
26
#include <algorithm> // for std::reverse
27
28
namespace tesseract {
29
30
// The beam width at each code position.
31
const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
32
    5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
33
};
34
35
// Printable names used in debug output, indexed by NodeContinuation value.
static const char *kNodeContNames[] = {
    "Anything",
    "OnlyDup",
    "NoDup",
};
36
37
// Prints debug details of the node.
38
void RecodeNode::Print(int null_char, const UNICHARSET &unicharset,
39
0
                       int depth) const {
40
0
  if (code == null_char) {
41
0
    tprintf("null_char");
42
0
  } else {
43
0
    tprintf("label=%d, uid=%d=%s", code, unichar_id,
44
0
            unicharset.debug_str(unichar_id).c_str());
45
0
  }
46
0
  tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty,
47
0
          start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "",
48
0
          end_of_word ? " End" : "", permuter, code_hash);
49
0
  if (depth > 0 && prev != nullptr) {
50
0
    tprintf(" prev:");
51
0
    prev->Print(null_char, unicharset, depth - 1);
52
0
  } else {
53
0
    tprintf("\n");
54
0
  }
55
0
}
56
57
// Borrows the pointer, which is expected to survive until *this is deleted.
58
RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress &recoder,
59
                                   int null_char, bool simple_text, Dict *dict)
60
1
    : recoder_(recoder),
61
1
      beam_size_(0),
62
1
      top_code_(-1),
63
1
      second_code_(-1),
64
1
      dict_(dict),
65
1
      space_delimited_(true),
66
1
      is_simple_text_(simple_text),
67
1
      null_char_(null_char) {
68
1
  if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang()) {
69
0
    space_delimited_ = false;
70
0
  }
71
1
}
72
73
0
// Deletes every beam step held in the primary and secondary beams.
RecodeBeamSearch::~RecodeBeamSearch() {
  for (RecodeBeam *step : beam_) {
    delete step;
  }
  for (RecodeBeam *step : secondary_beam_) {
    delete step;
  }
}
81
82
// Decodes the set of network outputs, storing the lattice internally.
83
void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio,
84
                              double cert_offset, double worst_dict_cert,
85
96.1k
                              const UNICHARSET *charset, int lstm_choice_mode) {
86
96.1k
  beam_size_ = 0;
87
96.1k
  int width = output.Width();
88
96.1k
  if (lstm_choice_mode) {
89
0
    timesteps.clear();
90
0
  }
91
1.43M
  for (int t = 0; t < width; ++t) {
92
1.33M
    ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
93
1.33M
    DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
94
1.33M
               charset);
95
1.33M
    if (lstm_choice_mode) {
96
0
      SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
97
0
    }
98
1.33M
  }
99
96.1k
}
100
void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float> &output,
101
                              double dict_ratio, double cert_offset,
102
                              double worst_dict_cert,
103
0
                              const UNICHARSET *charset) {
104
0
  beam_size_ = 0;
105
0
  int width = output.dim1();
106
0
  for (int t = 0; t < width; ++t) {
107
0
    ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
108
0
    DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
109
0
  }
110
0
}
111
112
void RecodeBeamSearch::DecodeSecondaryBeams(
113
    const NetworkIO &output, double dict_ratio, double cert_offset,
114
0
    double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode) {
115
0
  for (auto data : secondary_beam_) {
116
0
    delete data;
117
0
  }
118
0
  secondary_beam_.clear();
119
0
  if (character_boundaries_.size() < 2) {
120
0
    return;
121
0
  }
122
0
  int width = output.Width();
123
0
  unsigned bucketNumber = 0;
124
0
  for (int t = 0; t < width; ++t) {
125
0
    while ((bucketNumber + 1) < character_boundaries_.size() &&
126
0
           t >= character_boundaries_[bucketNumber + 1]) {
127
0
      ++bucketNumber;
128
0
    }
129
0
    ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t),
130
0
                   output.NumFeatures(), kBeamWidths[0]);
131
0
    DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset,
132
0
                        worst_dict_cert, charset);
133
0
  }
134
0
}
135
136
void RecodeBeamSearch::SaveMostCertainChoices(const float *outputs,
137
                                              int num_outputs,
138
                                              const UNICHARSET *charset,
139
0
                                              int xCoord) {
140
0
  std::vector<std::pair<const char *, float>> choices;
141
0
  for (int i = 0; i < num_outputs; ++i) {
142
0
    if (outputs[i] >= 0.01f) {
143
0
      const char *character;
144
0
      if (i + 2 >= num_outputs) {
145
0
        character = "";
146
0
      } else if (i > 0) {
147
0
        character = charset->id_to_unichar_ext(i + 2);
148
0
      } else {
149
0
        character = charset->id_to_unichar_ext(i);
150
0
      }
151
0
      size_t pos = 0;
152
      // order the possible choices within one timestep
153
      // beginning with the most likely
154
0
      while (choices.size() > pos && choices[pos].second > outputs[i]) {
155
0
        pos++;
156
0
      }
157
0
      choices.insert(choices.begin() + pos,
158
0
                     std::pair<const char *, float>(character, outputs[i]));
159
0
    }
160
0
  }
161
0
  timesteps.push_back(choices);
162
0
}
163
164
0
void RecodeBeamSearch::segmentTimestepsByCharacters() {
165
0
  for (unsigned i = 1; i < character_boundaries_.size(); ++i) {
166
0
    std::vector<std::vector<std::pair<const char *, float>>> segment;
167
0
    for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i];
168
0
         ++j) {
169
0
      segment.push_back(timesteps[j]);
170
0
    }
171
0
    segmentedTimesteps.push_back(segment);
172
0
  }
173
0
}
174
std::vector<std::vector<std::pair<const char *, float>>>
175
RecodeBeamSearch::combineSegmentedTimesteps(
176
    std::vector<std::vector<std::vector<std::pair<const char *, float>>>>
177
0
        *segmentedTimesteps) {
178
0
  std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps;
179
0
  for (auto &segmentedTimestep : *segmentedTimesteps) {
180
0
    for (auto &j : segmentedTimestep) {
181
0
      combined_timesteps.push_back(j);
182
0
    }
183
0
  }
184
0
  return combined_timesteps;
185
0
}
186
187
// Appends character boundaries to *char_bounds_: a leading 0, then for each
// character the point half way between its end and the following start,
// except that the last appended entry is replaced by maxWidth.
// Requires starts to hold at least one more element than ends.
void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
                                               std::vector<int> *ends,
                                               std::vector<int> *char_bounds_,
                                               int maxWidth) {
  char_bounds_->push_back(0);
  const unsigned num_chars = ends->size();
  for (unsigned c = 0; c < num_chars; ++c) {
    const int char_end = (*ends)[c];
    const int half_gap = ((*starts)[c + 1] - char_end) / 2;
    char_bounds_->push_back(char_end + half_gap);
  }
  // The final boundary is always the full width.
  char_bounds_->back() = maxWidth;
}
199
200
// Returns the best path as labels/scores/xcoords similar to simple CTC.
201
void RecodeBeamSearch::ExtractBestPathAsLabels(
202
0
    std::vector<int> *labels, std::vector<int> *xcoords) const {
203
0
  labels->clear();
204
0
  xcoords->clear();
205
0
  std::vector<const RecodeNode *> best_nodes;
206
0
  ExtractBestPaths(&best_nodes, nullptr);
207
  // Now just run CTC on the best nodes.
208
0
  int t = 0;
209
0
  int width = best_nodes.size();
210
0
  while (t < width) {
211
0
    int label = best_nodes[t]->code;
212
0
    if (label != null_char_) {
213
0
      labels->push_back(label);
214
0
      xcoords->push_back(t);
215
0
    }
216
0
    while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
217
0
    }
218
0
  }
219
0
  xcoords->push_back(width);
220
0
}
221
222
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
223
// duplicates, nulls and intermediate parts.
224
void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
225
    bool debug, const UNICHARSET *unicharset, std::vector<int> *unichar_ids,
226
    std::vector<float> *certs, std::vector<float> *ratings,
227
0
    std::vector<int> *xcoords) const {
228
0
  std::vector<const RecodeNode *> best_nodes;
229
0
  ExtractBestPaths(&best_nodes, nullptr);
230
0
  ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
231
0
  if (debug) {
232
0
    DebugPath(unicharset, best_nodes);
233
0
    DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
234
0
                     *xcoords);
235
0
  }
236
0
}
237
238
// Returns the best path as a set of WERD_RES.
239
void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box,
240
                                              float scale_factor, bool debug,
241
                                              const UNICHARSET *unicharset,
242
                                              PointerVector<WERD_RES> *words,
243
96.1k
                                              int lstm_choice_mode) {
244
96.1k
  words->truncate(0);
245
96.1k
  std::vector<int> unichar_ids;
246
96.1k
  std::vector<float> certs;
247
96.1k
  std::vector<float> ratings;
248
96.1k
  std::vector<int> xcoords;
249
96.1k
  std::vector<const RecodeNode *> best_nodes;
250
96.1k
  std::vector<const RecodeNode *> second_nodes;
251
96.1k
  character_boundaries_.clear();
252
96.1k
  ExtractBestPaths(&best_nodes, &second_nodes);
253
96.1k
  if (debug) {
254
0
    DebugPath(unicharset, best_nodes);
255
0
    ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
256
0
                            &xcoords);
257
0
    tprintf("\nSecond choice path:\n");
258
0
    DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
259
0
                     xcoords);
260
0
  }
261
  // If lstm choice mode is required in granularity level 2, it stores the x
262
  // Coordinates of every chosen character, to match the alternative choices to
263
  // it.
264
96.1k
  ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
265
96.1k
                          &character_boundaries_);
266
96.1k
  int num_ids = unichar_ids.size();
267
96.1k
  if (debug) {
268
0
    DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
269
0
                     xcoords);
270
0
  }
271
  // Convert labels to unichar-ids.
272
96.1k
  int word_end = 0;
273
96.1k
  float prev_space_cert = 0.0f;
274
186k
  for (int word_start = 0; word_start < num_ids; word_start = word_end) {
275
127k
    for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
276
      // A word is terminated when a space character or start_of_word flag is
277
      // hit. We also want to force a separate word for every non
278
      // space-delimited character when not in a dictionary context.
279
51.1k
      if (unichar_ids[word_end] == UNICHAR_SPACE) {
280
14.5k
        break;
281
14.5k
      }
282
36.5k
      int index = xcoords[word_end];
283
36.5k
      if (best_nodes[index]->start_of_word) {
284
44
        break;
285
44
      }
286
36.5k
      if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
287
36.5k
          (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
288
21.5k
           !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) {
289
0
        break;
290
0
      }
291
36.5k
    }
292
90.7k
    float space_cert = 0.0f;
293
90.7k
    if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
294
14.5k
      space_cert = certs[word_end];
295
14.5k
    }
296
90.7k
    bool leading_space =
297
90.7k
        word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
298
    // Create a WERD_RES for the output word.
299
90.7k
    WERD_RES *word_res =
300
90.7k
        InitializeWord(leading_space, line_box, word_start, word_end,
301
90.7k
                       std::min(space_cert, prev_space_cert), unicharset,
302
90.7k
                       xcoords, scale_factor);
303
218k
    for (int i = word_start; i < word_end; ++i) {
304
127k
      auto *choices = new BLOB_CHOICE_LIST;
305
127k
      BLOB_CHOICE_IT bc_it(choices);
306
127k
      auto *choice = new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1,
307
127k
                                     1.0f, static_cast<float>(INT16_MAX), 0.0f,
308
127k
                                     BCC_STATIC_CLASSIFIER);
309
127k
      int col = i - word_start;
310
127k
      choice->set_matrix_cell(col, col);
311
127k
      bc_it.add_after_then_move(choice);
312
127k
      word_res->ratings->put(col, col, choices);
313
127k
    }
314
90.7k
    int index = xcoords[word_end - 1];
315
90.7k
    word_res->FakeWordFromRatings(best_nodes[index]->permuter);
316
90.7k
    words->push_back(word_res);
317
90.7k
    prev_space_cert = space_cert;
318
90.7k
    if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
319
14.5k
      ++word_end;
320
14.5k
    }
321
90.7k
  }
322
96.1k
}
323
324
struct greater_than {
325
0
  inline bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) const {
326
0
    return (node1->score > node2->score);
327
0
  }
328
};
329
330
void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs,
331
                                  const UNICHARSET *charset,
332
0
                                  bool secondary) const {
333
0
  std::vector<std::vector<const RecodeNode *>> topology;
334
0
  std::unordered_set<const RecodeNode *> visited;
335
0
  const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_;
336
  // create the topology
337
0
  for (int step = beam.size() - 1; step >= 0; --step) {
338
0
    std::vector<const RecodeNode *> layer;
339
0
    topology.push_back(layer);
340
0
  }
341
  // fill the topology with depths first
342
0
  for (int step = beam.size() - 1; step >= 0; --step) {
343
0
    std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap();
344
0
    for (auto &&node : heaps) {
345
0
      int backtracker = 0;
346
0
      const RecodeNode *curr = &node.data();
347
0
      while (curr != nullptr && !visited.count(curr)) {
348
0
        visited.insert(curr);
349
0
        topology[step - backtracker].push_back(curr);
350
0
        curr = curr->prev;
351
0
        ++backtracker;
352
0
      }
353
0
    }
354
0
  }
355
0
  int ct = 0;
356
0
  unsigned cb = 1;
357
0
  for (const std::vector<const RecodeNode *> &layer : topology) {
358
0
    if (cb >= character_boundaries_.size()) {
359
0
      break;
360
0
    }
361
0
    if (ct == character_boundaries_[cb]) {
362
0
      tprintf("***\n");
363
0
      ++cb;
364
0
    }
365
0
    for (const RecodeNode *node : layer) {
366
0
      const char *code;
367
0
      int intCode;
368
0
      if (node->unichar_id != INVALID_UNICHAR_ID) {
369
0
        code = charset->id_to_unichar(node->unichar_id);
370
0
        intCode = node->unichar_id;
371
0
      } else if (node->code == null_char_) {
372
0
        intCode = 0;
373
0
        code = " ";
374
0
      } else {
375
0
        intCode = 666;
376
0
        code = "*";
377
0
      }
378
0
      int intPrevCode = 0;
379
0
      const char *prevCode;
380
0
      float prevScore = 0;
381
0
      if (node->prev != nullptr) {
382
0
        prevScore = node->prev->score;
383
0
        if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
384
0
          prevCode = charset->id_to_unichar(node->prev->unichar_id);
385
0
          intPrevCode = node->prev->unichar_id;
386
0
        } else if (node->code == null_char_) {
387
0
          intPrevCode = 0;
388
0
          prevCode = " ";
389
0
        } else {
390
0
          prevCode = "*";
391
0
          intPrevCode = 666;
392
0
        }
393
0
      } else {
394
0
        prevCode = " ";
395
0
      }
396
0
      if (uids) {
397
0
        tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode,
398
0
                node->score);
399
0
      } else {
400
0
        tprintf("%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score);
401
0
      }
402
0
    }
403
0
    tprintf("-\n");
404
0
    ++ct;
405
0
  }
406
0
  tprintf("***\n");
407
0
}
408
409
0
void RecodeBeamSearch::extractSymbolChoices(const UNICHARSET *unicharset) {
410
0
  if (character_boundaries_.size() < 2) {
411
0
    return;
412
0
  }
413
  // For the first iteration the original beam is analyzed. After that a
414
  // new beam is calculated based on the results from the original beam.
415
0
  std::vector<RecodeBeam *> &currentBeam =
416
0
      secondary_beam_.empty() ? beam_ : secondary_beam_;
417
0
  character_boundaries_[0] = 0;
418
0
  for (unsigned j = 1; j < character_boundaries_.size(); ++j) {
419
0
    std::vector<int> unichar_ids;
420
0
    std::vector<float> certs;
421
0
    std::vector<float> ratings;
422
0
    std::vector<int> xcoords;
423
0
    int backpath = character_boundaries_[j] - character_boundaries_[j - 1];
424
0
    std::vector<tesseract::RecodePair> &heaps =
425
0
        currentBeam.at(character_boundaries_[j] - 1)->beams_->heap();
426
0
    std::vector<const RecodeNode *> best_nodes;
427
0
    std::vector<const RecodeNode *> best;
428
    // Scan the segmented node chain for valid unichar ids.
429
0
    for (auto &&entry : heaps) {
430
0
      bool validChar = false;
431
0
      int backcounter = 0;
432
0
      const RecodeNode *node = &entry.data();
433
0
      while (node != nullptr && backcounter < backpath) {
434
0
        if (node->code != null_char_ &&
435
0
            node->unichar_id != INVALID_UNICHAR_ID) {
436
0
          validChar = true;
437
0
          break;
438
0
        }
439
0
        node = node->prev;
440
0
        ++backcounter;
441
0
      }
442
0
      if (validChar) {
443
0
        best.push_back(&entry.data());
444
0
      }
445
0
    }
446
    // find the best rated segmented node chain and extract the unichar id.
447
0
    if (!best.empty()) {
448
0
      std::sort(best.begin(), best.end(), greater_than());
449
0
      ExtractPath(best[0], &best_nodes, backpath);
450
0
      ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
451
0
                              &xcoords);
452
0
    }
453
0
    if (!unichar_ids.empty()) {
454
0
      int bestPos = 0;
455
0
      for (unsigned i = 1; i < unichar_ids.size(); ++i) {
456
0
        if (ratings[i] < ratings[bestPos]) {
457
0
          bestPos = i;
458
0
        }
459
0
      }
460
#if 0 // TODO: bestCode is currently unused (see commit 2dd5d0d60).
461
      int bestCode = -10;
462
      for (auto &node : best_nodes) {
463
        if (node->unichar_id == unichar_ids[bestPos]) {
464
          bestCode = node->code;
465
        }
466
      }
467
#endif
468
      // Exclude the best choice for the followup decoding.
469
0
      std::unordered_set<int> excludeCodeList;
470
0
      for (auto &best_node : best_nodes) {
471
0
        if (best_node->code != null_char_) {
472
0
          excludeCodeList.insert(best_node->code);
473
0
        }
474
0
      }
475
0
      if (j - 1 < excludedUnichars.size()) {
476
0
        for (auto elem : excludeCodeList) {
477
0
          excludedUnichars[j - 1].insert(elem);
478
0
        }
479
0
      } else {
480
0
        excludedUnichars.push_back(excludeCodeList);
481
0
      }
482
      // Save the best choice for the choice iterator.
483
0
      if (j - 1 < ctc_choices.size()) {
484
0
        int id = unichar_ids[bestPos];
485
0
        const char *result = unicharset->id_to_unichar_ext(id);
486
0
        float rating = ratings[bestPos];
487
0
        ctc_choices[j - 1].push_back(
488
0
            std::pair<const char *, float>(result, rating));
489
0
      } else {
490
0
        std::vector<std::pair<const char *, float>> choice;
491
0
        int id = unichar_ids[bestPos];
492
0
        const char *result = unicharset->id_to_unichar_ext(id);
493
0
        float rating = ratings[bestPos];
494
0
        choice.emplace_back(result, rating);
495
0
        ctc_choices.push_back(choice);
496
0
      }
497
      // fill the blank spot with an empty array
498
0
    } else {
499
0
      if (j - 1 >= excludedUnichars.size()) {
500
0
        std::unordered_set<int> excludeCodeList;
501
0
        excludedUnichars.push_back(excludeCodeList);
502
0
      }
503
0
      if (j - 1 >= ctc_choices.size()) {
504
0
        std::vector<std::pair<const char *, float>> choice;
505
0
        ctc_choices.push_back(choice);
506
0
      }
507
0
    }
508
0
  }
509
0
  for (auto data : secondary_beam_) {
510
0
    delete data;
511
0
  }
512
0
  secondary_beam_.clear();
513
0
}
514
515
// Generates debug output of the content of the beams after a Decode.
516
0
void RecodeBeamSearch::DebugBeams(const UNICHARSET &unicharset) const {
517
0
  for (int p = 0; p < beam_size_; ++p) {
518
0
    for (int d = 0; d < 2; ++d) {
519
0
      for (int c = 0; c < NC_COUNT; ++c) {
520
0
        auto cont = static_cast<NodeContinuation>(c);
521
0
        int index = BeamIndex(d, cont, 0);
522
0
        if (beam_[p]->beams_[index].empty()) {
523
0
          continue;
524
0
        }
525
        // Print all the best scoring nodes for each unichar found.
526
0
        tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict",
527
0
                kNodeContNames[c]);
528
0
        DebugBeamPos(unicharset, beam_[p]->beams_[index]);
529
0
      }
530
0
    }
531
0
  }
532
0
}
533
534
// Generates debug output of the content of a single beam position.
535
void RecodeBeamSearch::DebugBeamPos(const UNICHARSET &unicharset,
536
0
                                    const RecodeHeap &heap) const {
537
0
  std::vector<const RecodeNode *> unichar_bests(unicharset.size());
538
0
  const RecodeNode *null_best = nullptr;
539
0
  int heap_size = heap.size();
540
0
  for (int i = 0; i < heap_size; ++i) {
541
0
    const RecodeNode *node = &heap.get(i).data();
542
0
    if (node->unichar_id == INVALID_UNICHAR_ID) {
543
0
      if (null_best == nullptr || null_best->score < node->score) {
544
0
        null_best = node;
545
0
      }
546
0
    } else {
547
0
      if (unichar_bests[node->unichar_id] == nullptr ||
548
0
          unichar_bests[node->unichar_id]->score < node->score) {
549
0
        unichar_bests[node->unichar_id] = node;
550
0
      }
551
0
    }
552
0
  }
553
0
  for (auto &unichar_best : unichar_bests) {
554
0
    if (unichar_best != nullptr) {
555
0
      const RecodeNode &node = *unichar_best;
556
0
      node.Print(null_char_, unicharset, 1);
557
0
    }
558
0
  }
559
0
  if (null_best != nullptr) {
560
0
    null_best->Print(null_char_, unicharset, 1);
561
0
  }
562
0
}
563
564
// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
565
// duplicates, nulls and intermediate parts.
566
/* static */
567
void RecodeBeamSearch::ExtractPathAsUnicharIds(
568
    const std::vector<const RecodeNode *> &best_nodes,
569
    std::vector<int> *unichar_ids, std::vector<float> *certs,
570
    std::vector<float> *ratings, std::vector<int> *xcoords,
571
96.1k
    std::vector<int> *character_boundaries) {
572
96.1k
  unichar_ids->clear();
573
96.1k
  certs->clear();
574
96.1k
  ratings->clear();
575
96.1k
  xcoords->clear();
576
96.1k
  std::vector<int> starts;
577
96.1k
  std::vector<int> ends;
578
  // Backtrack extracting only valid, non-duplicate unichar-ids.
579
96.1k
  int t = 0;
580
96.1k
  int width = best_nodes.size();
581
317k
  while (t < width) {
582
221k
    double certainty = 0.0;
583
221k
    double rating = 0.0;
584
1.24M
    while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
585
1.02M
      double cert = best_nodes[t++]->certainty;
586
1.02M
      if (cert < certainty) {
587
386k
        certainty = cert;
588
386k
      }
589
1.02M
      rating -= cert;
590
1.02M
    }
591
221k
    starts.push_back(t);
592
221k
    if (t < width) {
593
141k
      int unichar_id = best_nodes[t]->unichar_id;
594
141k
      if (unichar_id == UNICHAR_SPACE && !certs->empty() &&
595
141k
          best_nodes[t]->permuter != NO_PERM) {
596
        // All the rating and certainty go on the previous character except
597
        // for the space itself.
598
12.4k
        if (certainty < certs->back()) {
599
3.47k
          certs->back() = certainty;
600
3.47k
        }
601
12.4k
        ratings->back() += rating;
602
12.4k
        certainty = 0.0;
603
12.4k
        rating = 0.0;
604
12.4k
      }
605
141k
      unichar_ids->push_back(unichar_id);
606
141k
      xcoords->push_back(t);
607
312k
      do {
608
312k
        double cert = best_nodes[t++]->certainty;
609
        // Special-case NO-PERM space to forget the certainty of the previous
610
        // nulls. See long comment in ContinueContext.
611
312k
        if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
612
173k
                                 best_nodes[t - 1]->permuter == NO_PERM)) {
613
173k
          certainty = cert;
614
173k
        }
615
312k
        rating -= cert;
616
312k
      } while (t < width && best_nodes[t]->duplicate);
617
141k
      ends.push_back(t);
618
141k
      certs->push_back(certainty);
619
141k
      ratings->push_back(rating);
620
141k
    } else if (!certs->empty()) {
621
62.8k
      if (certainty < certs->back()) {
622
13.0k
        certs->back() = certainty;
623
13.0k
      }
624
62.8k
      ratings->back() += rating;
625
62.8k
    }
626
221k
  }
627
96.1k
  starts.push_back(width);
628
96.1k
  if (character_boundaries != nullptr) {
629
96.1k
    calculateCharBoundaries(&starts, &ends, character_boundaries, width);
630
96.1k
  }
631
96.1k
  xcoords->push_back(width);
632
96.1k
}
633
634
// Sets up a word with the ratings matrix and fake blobs with boxes in the
635
// right places.
636
WERD_RES *RecodeBeamSearch::InitializeWord(bool leading_space,
637
                                           const TBOX &line_box, int word_start,
638
                                           int word_end, float space_certainty,
639
                                           const UNICHARSET *unicharset,
640
                                           const std::vector<int> &xcoords,
641
90.7k
                                           float scale_factor) {
642
  // Make a fake blob for each non-zero label.
643
90.7k
  C_BLOB_LIST blobs;
644
90.7k
  C_BLOB_IT b_it(&blobs);
645
218k
  for (int i = word_start; i < word_end; ++i) {
646
127k
    if (static_cast<unsigned>(i + 1) < character_boundaries_.size()) {
647
127k
      TBOX box(static_cast<int16_t>(
648
127k
                   std::floor(character_boundaries_[i] * scale_factor)) +
649
127k
                   line_box.left(),
650
127k
               line_box.bottom(),
651
127k
               static_cast<int16_t>(
652
127k
                   std::ceil(character_boundaries_[i + 1] * scale_factor)) +
653
127k
                   line_box.left(),
654
127k
               line_box.top());
655
127k
      b_it.add_after_then_move(C_BLOB::FakeBlob(box));
656
127k
    }
657
127k
  }
658
  // Make a fake word from the blobs.
659
90.7k
  WERD *word = new WERD(&blobs, leading_space, nullptr);
660
  // Make a WERD_RES from the word.
661
90.7k
  auto *word_res = new WERD_RES(word);
662
90.7k
  word_res->end = word_end - word_start + leading_space;
663
90.7k
  word_res->uch_set = unicharset;
664
90.7k
  word_res->combination = true; // Give it ownership of the word.
665
90.7k
  word_res->space_certainty = space_certainty;
666
90.7k
  word_res->ratings = new MATRIX(word_end - word_start, 1);
667
90.7k
  return word_res;
668
90.7k
}
669
670
// Fills top_n_flags_ with bools that are true iff the corresponding output
671
// is one of the top_n.
672
void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs,
673
1.33M
                                   int top_n) {
674
1.33M
  top_n_flags_.clear();
675
1.33M
  top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
676
1.33M
  top_code_ = -1;
677
1.33M
  second_code_ = -1;
678
1.33M
  top_heap_.clear();
679
149M
  for (int i = 0; i < num_outputs; ++i) {
680
148M
    if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) {
681
23.3M
      TopPair entry(outputs[i], i);
682
23.3M
      top_heap_.Push(&entry);
683
23.3M
      if (top_heap_.size() > top_n) {
684
16.6M
        top_heap_.Pop(&entry);
685
16.6M
      }
686
23.3M
    }
687
148M
  }
688
8.02M
  while (!top_heap_.empty()) {
689
6.68M
    TopPair entry;
690
6.68M
    top_heap_.Pop(&entry);
691
6.68M
    if (top_heap_.size() > 1) {
692
4.01M
      top_n_flags_[entry.data()] = TN_TOPN;
693
4.01M
    } else {
694
2.67M
      top_n_flags_[entry.data()] = TN_TOP2;
695
2.67M
      if (top_heap_.empty()) {
696
1.33M
        top_code_ = entry.data();
697
1.33M
      } else {
698
1.33M
        second_code_ = entry.data();
699
1.33M
      }
700
2.67M
    }
701
6.68M
  }
702
1.33M
  top_n_flags_[null_char_] = TN_TOP2;
703
1.33M
}
704
705
void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList,
706
                                      const float *outputs, int num_outputs,
707
0
                                      int top_n) {
708
0
  top_n_flags_.clear();
709
0
  top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
710
0
  top_code_ = -1;
711
0
  second_code_ = -1;
712
0
  top_heap_.clear();
713
0
  for (int i = 0; i < num_outputs; ++i) {
714
0
    if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) &&
715
0
        !exList->count(i)) {
716
0
      TopPair entry(outputs[i], i);
717
0
      top_heap_.Push(&entry);
718
0
      if (top_heap_.size() > top_n) {
719
0
        top_heap_.Pop(&entry);
720
0
      }
721
0
    }
722
0
  }
723
0
  while (!top_heap_.empty()) {
724
0
    TopPair entry;
725
0
    top_heap_.Pop(&entry);
726
0
    if (top_heap_.size() > 1) {
727
0
      top_n_flags_[entry.data()] = TN_TOPN;
728
0
    } else {
729
0
      top_n_flags_[entry.data()] = TN_TOP2;
730
0
      if (top_heap_.empty()) {
731
0
        top_code_ = entry.data();
732
0
      } else {
733
0
        second_code_ = entry.data();
734
0
      }
735
0
    }
736
0
  }
737
0
  top_n_flags_[null_char_] = TN_TOP2;
738
0
}
739
740
// Adds the computation for the current time-step to the beam. Call at each
741
// time-step in sequence from left to right. outputs is the activation vector
742
// for the current timestep.
743
void RecodeBeamSearch::DecodeStep(const float *outputs, int t,
744
                                  double dict_ratio, double cert_offset,
745
                                  double worst_dict_cert,
746
1.33M
                                  const UNICHARSET *charset, bool debug) {
747
1.33M
  if (t == static_cast<int>(beam_.size())) {
748
677
    beam_.push_back(new RecodeBeam);
749
677
  }
750
1.33M
  RecodeBeam *step = beam_[t];
751
1.33M
  beam_size_ = t + 1;
752
1.33M
  step->Clear();
753
1.33M
  if (t == 0) {
754
    // The first step can only use singles and initials.
755
96.1k
    ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
756
96.1k
                    charset, dict_ratio, cert_offset, worst_dict_cert, step);
757
96.1k
    if (dict_ != nullptr) {
758
96.1k
      ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
759
96.1k
                      TN_TOP2, charset, dict_ratio, cert_offset,
760
96.1k
                      worst_dict_cert, step);
761
96.1k
    }
762
1.24M
  } else {
763
1.24M
    RecodeBeam *prev = beam_[t - 1];
764
1.24M
    if (debug) {
765
0
      int beam_index = BeamIndex(true, NC_ANYTHING, 0);
766
0
      for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
767
0
        std::vector<const RecodeNode *> path;
768
0
        ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
769
0
        tprintf("Step %d: Dawg beam %d:\n", t, i);
770
0
        DebugPath(charset, path);
771
0
      }
772
0
      beam_index = BeamIndex(false, NC_ANYTHING, 0);
773
0
      for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
774
0
        std::vector<const RecodeNode *> path;
775
0
        ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
776
0
        tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
777
0
        DebugPath(charset, path);
778
0
      }
779
0
    }
780
1.24M
    int total_beam = 0;
781
    // Work through the scores by group (top-2, top-n, the rest) while the beam
782
    // is empty. This enables extending the context using only the top-n results
783
    // first, which may have an empty intersection with the valid codes, so we
784
    // fall back to the rest if the beam is empty.
785
2.48M
    for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
786
1.24M
      auto top_n = static_cast<TopNState>(tn);
787
75.7M
      for (int index = 0; index < kNumBeams; ++index) {
788
        // Working backwards through the heaps doesn't guarantee that we see the
789
        // best first, but it comes before a lot of the worst, so it is slightly
790
        // more efficient than going forwards.
791
91.2M
        for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
792
16.8M
          ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
793
16.8M
                          top_n, charset, dict_ratio, cert_offset,
794
16.8M
                          worst_dict_cert, step);
795
16.8M
        }
796
74.4M
      }
797
75.7M
      for (int index = 0; index < kNumBeams; ++index) {
798
74.4M
        if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) {
799
24.8M
          total_beam += step->beams_[index].size();
800
24.8M
        }
801
74.4M
      }
802
1.24M
    }
803
    // Special case for the best initial dawg. Push it on the heap if good
804
    // enough, but there is only one, so it doesn't blow up the beam.
805
4.96M
    for (int c = 0; c < NC_COUNT; ++c) {
806
3.72M
      if (step->best_initial_dawgs_[c].code >= 0) {
807
332k
        int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
808
332k
        RecodeHeap *dawg_heap = &step->beams_[index];
809
332k
        PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
810
332k
                         dawg_heap);
811
332k
      }
812
3.72M
    }
813
1.24M
  }
814
1.33M
}
815
816
void RecodeBeamSearch::DecodeSecondaryStep(
817
    const float *outputs, int t, double dict_ratio, double cert_offset,
818
0
    double worst_dict_cert, const UNICHARSET *charset, bool debug) {
819
0
  if (t == static_cast<int>(secondary_beam_.size())) {
820
0
    secondary_beam_.push_back(new RecodeBeam);
821
0
  }
822
0
  RecodeBeam *step = secondary_beam_[t];
823
0
  step->Clear();
824
0
  if (t == 0) {
825
    // The first step can only use singles and initials.
826
0
    ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
827
0
                    charset, dict_ratio, cert_offset, worst_dict_cert, step);
828
0
    if (dict_ != nullptr) {
829
0
      ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
830
0
                      TN_TOP2, charset, dict_ratio, cert_offset,
831
0
                      worst_dict_cert, step);
832
0
    }
833
0
  } else {
834
0
    RecodeBeam *prev = secondary_beam_[t - 1];
835
0
    if (debug) {
836
0
      int beam_index = BeamIndex(true, NC_ANYTHING, 0);
837
0
      for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
838
0
        std::vector<const RecodeNode *> path;
839
0
        ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
840
0
        tprintf("Step %d: Dawg beam %d:\n", t, i);
841
0
        DebugPath(charset, path);
842
0
      }
843
0
      beam_index = BeamIndex(false, NC_ANYTHING, 0);
844
0
      for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
845
0
        std::vector<const RecodeNode *> path;
846
0
        ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
847
0
        tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
848
0
        DebugPath(charset, path);
849
0
      }
850
0
    }
851
0
    int total_beam = 0;
852
    // Work through the scores by group (top-2, top-n, the rest) while the beam
853
    // is empty. This enables extending the context using only the top-n results
854
    // first, which may have an empty intersection with the valid codes, so we
855
    // fall back to the rest if the beam is empty.
856
0
    for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
857
0
      auto top_n = static_cast<TopNState>(tn);
858
0
      for (int index = 0; index < kNumBeams; ++index) {
859
        // Working backwards through the heaps doesn't guarantee that we see the
860
        // best first, but it comes before a lot of the worst, so it is slightly
861
        // more efficient than going forwards.
862
0
        for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
863
0
          ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
864
0
                          top_n, charset, dict_ratio, cert_offset,
865
0
                          worst_dict_cert, step);
866
0
        }
867
0
      }
868
0
      for (int index = 0; index < kNumBeams; ++index) {
869
0
        if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) {
870
0
          total_beam += step->beams_[index].size();
871
0
        }
872
0
      }
873
0
    }
874
    // Special case for the best initial dawg. Push it on the heap if good
875
    // enough, but there is only one, so it doesn't blow up the beam.
876
0
    for (int c = 0; c < NC_COUNT; ++c) {
877
0
      if (step->best_initial_dawgs_[c].code >= 0) {
878
0
        int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
879
0
        RecodeHeap *dawg_heap = &step->beams_[index];
880
0
        PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
881
0
                         dawg_heap);
882
0
      }
883
0
    }
884
0
  }
885
0
}
886
887
// Adds to the appropriate beams the legal (according to recoder)
888
// continuations of context prev, which is of the given length, using the
889
// given network outputs to provide scores to the choices. Uses only those
890
// choices for which top_n_flags[index] == top_n_flag.
891
void RecodeBeamSearch::ContinueContext(
892
    const RecodeNode *prev, int index, const float *outputs,
893
    TopNState top_n_flag, const UNICHARSET *charset, double dict_ratio,
894
17.0M
    double cert_offset, double worst_dict_cert, RecodeBeam *step) {
895
17.0M
  RecodedCharID prefix;
896
17.0M
  RecodedCharID full_code;
897
17.0M
  const RecodeNode *previous = prev;
898
17.0M
  int length = LengthFromBeamsIndex(index);
899
17.0M
  bool use_dawgs = IsDawgFromBeamsIndex(index);
900
17.0M
  NodeContinuation prev_cont = ContinuationFromBeamsIndex(index);
901
17.0M
  for (int p = length - 1; p >= 0 && previous != nullptr; --p) {
902
0
    while (previous->duplicate || previous->code == null_char_) {
903
0
      previous = previous->prev;
904
0
    }
905
0
    prefix.Set(p, previous->code);
906
0
    full_code.Set(p, previous->code);
907
0
    previous = previous->prev;
908
0
  }
909
17.0M
  if (prev != nullptr && !is_simple_text_) {
910
16.8M
    if (top_n_flags_[prev->code] == top_n_flag) {
911
12.2M
      if (prev_cont != NC_NO_DUP) {
912
11.4M
        float cert =
913
11.4M
            NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
914
11.4M
        PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
915
11.4M
                                cert, worst_dict_cert, dict_ratio, use_dawgs,
916
11.4M
                                NC_ANYTHING, prev, step);
917
11.4M
      }
918
12.2M
      if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 &&
919
12.2M
          prev->code != null_char_) {
920
1.46M
        float cert = NetworkIO::ProbToCertainty(outputs[prev->code] +
921
1.46M
                                                outputs[null_char_]) +
922
1.46M
                     cert_offset;
923
1.46M
        PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
924
1.46M
                                cert, worst_dict_cert, dict_ratio, use_dawgs,
925
1.46M
                                NC_NO_DUP, prev, step);
926
1.46M
      }
927
12.2M
    }
928
16.8M
    if (prev_cont == NC_ONLY_DUP) {
929
5.76M
      return;
930
5.76M
    }
931
11.0M
    if (prev->code != null_char_ && length > 0 &&
932
11.0M
        top_n_flags_[null_char_] == top_n_flag) {
933
      // Allow nulls within multi code sequences, as the nulls within are not
934
      // explicitly included in the code sequence.
935
0
      float cert =
936
0
          NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
937
0
      PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID,
938
0
                              cert, worst_dict_cert, dict_ratio, use_dawgs,
939
0
                              NC_ANYTHING, prev, step);
940
0
    }
941
11.0M
  }
942
11.2M
  const std::vector<int> *final_codes = recoder_.GetFinalCodes(prefix);
943
11.2M
  if (final_codes != nullptr) {
944
1.24G
    for (int code : *final_codes) {
945
1.24G
      if (top_n_flags_[code] != top_n_flag) {
946
1.22G
        continue;
947
1.22G
      }
948
23.1M
      if (prev != nullptr && prev->code == code && !is_simple_text_) {
949
9.36M
        continue;
950
9.36M
      }
951
13.7M
      float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
952
13.7M
      if (cert < kMinCertainty && code != null_char_) {
953
228
        continue;
954
228
      }
955
13.7M
      full_code.Set(length, code);
956
13.7M
      int unichar_id = recoder_.DecodeUnichar(full_code);
957
      // Map the null char to INVALID.
958
13.7M
      if (length == 0 && code == null_char_) {
959
4.12M
        unichar_id = INVALID_UNICHAR_ID;
960
4.12M
      }
961
13.7M
      if (unichar_id != INVALID_UNICHAR_ID && charset != nullptr &&
962
13.7M
          !charset->get_enabled(unichar_id)) {
963
0
        continue; // disabled by whitelist/blacklist
964
0
      }
965
13.7M
      ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
966
13.7M
                      use_dawgs, NC_ANYTHING, prev, step);
967
13.7M
      if (top_n_flag == TN_TOP2 && code != null_char_) {
968
9.61M
        float prob = outputs[code] + outputs[null_char_];
969
9.61M
        if (prev != nullptr && prev_cont == NC_ANYTHING &&
970
9.61M
            prev->code != null_char_ &&
971
9.61M
            ((prev->code == top_code_ && code == second_code_) ||
972
1.27M
             (code == top_code_ && prev->code == second_code_))) {
973
95.3k
          prob += outputs[prev->code];
974
95.3k
        }
975
9.61M
        cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
976
9.61M
        ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
977
9.61M
                        use_dawgs, NC_ONLY_DUP, prev, step);
978
9.61M
      }
979
13.7M
    }
980
11.2M
  }
981
11.2M
  const std::vector<int> *next_codes = recoder_.GetNextCodes(prefix);
982
11.2M
  if (next_codes != nullptr) {
983
0
    for (int code : *next_codes) {
984
0
      if (top_n_flags_[code] != top_n_flag) {
985
0
        continue;
986
0
      }
987
0
      if (prev != nullptr && prev->code == code && !is_simple_text_) {
988
0
        continue;
989
0
      }
990
0
      float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
991
0
      PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert,
992
0
                              worst_dict_cert, dict_ratio, use_dawgs,
993
0
                              NC_ANYTHING, prev, step);
994
0
      if (top_n_flag == TN_TOP2 && code != null_char_) {
995
0
        float prob = outputs[code] + outputs[null_char_];
996
0
        if (prev != nullptr && prev_cont == NC_ANYTHING &&
997
0
            prev->code != null_char_ &&
998
0
            ((prev->code == top_code_ && code == second_code_) ||
999
0
             (code == top_code_ && prev->code == second_code_))) {
1000
0
          prob += outputs[prev->code];
1001
0
        }
1002
0
        cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
1003
0
        PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID,
1004
0
                                cert, worst_dict_cert, dict_ratio, use_dawgs,
1005
0
                                NC_ONLY_DUP, prev, step);
1006
0
      }
1007
0
    }
1008
0
  }
1009
11.2M
}
1010
1011
// Continues for a new unichar, using dawg or non-dawg as per flag.
1012
void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert,
1013
                                       float worst_dict_cert, float dict_ratio,
1014
                                       bool use_dawgs, NodeContinuation cont,
1015
                                       const RecodeNode *prev,
1016
23.3M
                                       RecodeBeam *step) {
1017
23.3M
  if (use_dawgs) {
1018
9.02M
    if (cert > worst_dict_cert) {
1019
8.19M
      ContinueDawg(code, unichar_id, cert, cont, prev, step);
1020
8.19M
    }
1021
14.3M
  } else {
1022
14.3M
    RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1023
14.3M
    PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false,
1024
14.3M
                     false, false, false, cert * dict_ratio, prev, nullptr,
1025
14.3M
                     nodawg_heap);
1026
14.3M
    if (dict_ != nullptr &&
1027
14.3M
        ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
1028
14.3M
         !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
1029
      // Any top choice position that can start a new word, ie a space or
1030
      // any non-space-delimited character, should also be considered
1031
      // by the dawg search, so push initial dawg to the dawg heap.
1032
1.20M
      float dawg_cert = cert;
1033
1.20M
      PermuterType permuter = TOP_CHOICE_PERM;
1034
      // Since we use the space either side of a dictionary word in the
1035
      // certainty of the word, (to properly handle weak spaces) and the
1036
      // space is coming from a non-dict word, we need special conditions
1037
      // to avoid degrading the certainty of the dict word that follows.
1038
      // With a space we don't multiply the certainty by dict_ratio, and we
1039
      // flag the space with NO_PERM to indicate that we should not use the
1040
      // predecessor nulls to generate the confidence for the space, as they
1041
      // have already been multiplied by dict_ratio, and we can't go back to
1042
      // insert more entries in any previous heaps.
1043
1.20M
      if (unichar_id == UNICHAR_SPACE) {
1044
1.20M
        permuter = NO_PERM;
1045
1.20M
      } else {
1046
0
        dawg_cert *= dict_ratio;
1047
0
      }
1048
1.20M
      PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
1049
1.20M
                              dawg_cert, cont, prev, step);
1050
1.20M
    }
1051
14.3M
  }
1052
23.3M
}
1053
1054
// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
1055
// appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
1056
// is a valid continuation of whatever is in prev.
1057
void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert,
1058
                                    NodeContinuation cont,
1059
8.19M
                                    const RecodeNode *prev, RecodeBeam *step) {
1060
8.19M
  RecodeHeap *dawg_heap = &step->beams_[BeamIndex(true, cont, 0)];
1061
8.19M
  RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1062
8.19M
  if (unichar_id == INVALID_UNICHAR_ID) {
1063
1.09M
    PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false,
1064
1.09M
                     false, false, cert, prev, nullptr, dawg_heap);
1065
1.09M
    return;
1066
1.09M
  }
1067
  // Avoid dictionary probe if score a total loss.
1068
7.10M
  float score = cert;
1069
7.10M
  if (prev != nullptr) {
1070
6.99M
    score += prev->score;
1071
6.99M
  }
1072
7.10M
  if (dawg_heap->size() >= kBeamWidths[0] &&
1073
7.10M
      score <= dawg_heap->PeekTop().data().score &&
1074
7.10M
      nodawg_heap->size() >= kBeamWidths[0] &&
1075
7.10M
      score <= nodawg_heap->PeekTop().data().score) {
1076
10.5k
    return;
1077
10.5k
  }
1078
7.08M
  const RecodeNode *uni_prev = prev;
1079
  // Prev may be a partial code, null_char, or duplicate, so scan back to the
1080
  // last valid unichar_id.
1081
105M
  while (uni_prev != nullptr &&
1082
105M
         (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) {
1083
97.9M
    uni_prev = uni_prev->prev;
1084
97.9M
  }
1085
7.08M
  if (unichar_id == UNICHAR_SPACE) {
1086
1.18M
    if (uni_prev != nullptr && uni_prev->end_of_word) {
1087
      // Space is good. Push initial state, to the dawg beam and a regular
1088
      // space to the top choice beam.
1089
625k
      PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
1090
625k
                              false, cert, cont, prev, step);
1091
625k
      PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1092
625k
                       false, false, false, false, cert, prev, nullptr,
1093
625k
                       nodawg_heap);
1094
625k
    }
1095
1.18M
    return;
1096
5.90M
  } else if (uni_prev != nullptr && uni_prev->start_of_dawg &&
1097
5.90M
             uni_prev->unichar_id != UNICHAR_SPACE &&
1098
5.90M
             dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
1099
5.90M
             dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
1100
0
    return; // Can't break words between space delimited chars.
1101
0
  }
1102
5.90M
  DawgPositionVector initial_dawgs;
1103
5.90M
  auto *updated_dawgs = new DawgPositionVector;
1104
5.90M
  DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
1105
5.90M
  bool word_start = false;
1106
5.90M
  if (uni_prev == nullptr) {
1107
    // Starting from beginning of line.
1108
1.32M
    dict_->default_dawgs(&initial_dawgs, false);
1109
1.32M
    word_start = true;
1110
4.57M
  } else if (uni_prev->dawgs != nullptr) {
1111
    // Continuing a previous dict word.
1112
4.57M
    dawg_args.active_dawgs = uni_prev->dawgs;
1113
4.57M
    word_start = uni_prev->start_of_dawg;
1114
4.57M
  } else {
1115
0
    return; // Can't continue if not a dict word.
1116
0
  }
1117
5.90M
  auto permuter = static_cast<PermuterType>(dict_->def_letter_is_okay(
1118
5.90M
      &dawg_args, dict_->getUnicharset(), unichar_id, false));
1119
5.90M
  if (permuter != NO_PERM) {
1120
1.99M
    PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1121
1.99M
                     word_start, dawg_args.valid_end, false, cert, prev,
1122
1.99M
                     dawg_args.updated_dawgs, dawg_heap);
1123
1.99M
    if (dawg_args.valid_end && !space_delimited_) {
1124
      // We can start another word right away, so push initial state as well,
1125
      // to the dawg beam, and the regular character to the top choice beam,
1126
      // since non-dict words can start here too.
1127
0
      PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
1128
0
                              cert, cont, prev, step);
1129
0
      PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1130
0
                       word_start, true, false, cert, prev, nullptr,
1131
0
                       nodawg_heap);
1132
0
    }
1133
3.90M
  } else {
1134
3.90M
    delete updated_dawgs;
1135
3.90M
  }
1136
5.90M
}
1137
1138
// Adds a RecodeNode composed of the tuple (code, unichar_id,
1139
// initial-dawg-state, prev, cert) to the given heap if/ there is room or if
1140
// better than the current worst element if already full.
1141
void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
1142
                                               PermuterType permuter,
1143
                                               bool start, bool end, float cert,
1144
                                               NodeContinuation cont,
1145
                                               const RecodeNode *prev,
1146
1.83M
                                               RecodeBeam *step) {
1147
1.83M
  RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont];
1148
1.83M
  float score = cert;
1149
1.83M
  if (prev != nullptr) {
1150
1.82M
    score += prev->score;
1151
1.82M
  }
1152
1.83M
  if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1153
806k
    auto *initial_dawgs = new DawgPositionVector;
1154
806k
    dict_->default_dawgs(initial_dawgs, false);
1155
806k
    RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
1156
806k
                    score, prev, initial_dawgs,
1157
806k
                    ComputeCodeHash(code, false, prev));
1158
806k
    *best_initial_dawg = node;
1159
806k
  }
1160
1.83M
}
1161
1162
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1163
// false, false, false, false, cert, prev, nullptr) to heap if there is room
1164
// or if better than the current worst element if already full.
1165
/* static */
1166
void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1167
    int length, bool dup, int code, int unichar_id, float cert,
1168
    float worst_dict_cert, float dict_ratio, bool use_dawgs,
1169
12.9M
    NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step) {
1170
12.9M
  int index = BeamIndex(use_dawgs, cont, length);
1171
12.9M
  if (use_dawgs) {
1172
4.45M
    if (cert > worst_dict_cert) {
1173
4.26M
      PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1174
4.26M
                       prev ? prev->permuter : NO_PERM, false, false, false,
1175
4.26M
                       dup, cert, prev, nullptr, &step->beams_[index]);
1176
4.26M
    }
1177
8.49M
  } else {
1178
8.49M
    cert *= dict_ratio;
1179
8.49M
    if (cert >= kMinCertainty || code == null_char_) {
1180
8.43M
      PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1181
8.43M
                       prev ? prev->permuter : TOP_CHOICE_PERM, false, false,
1182
8.43M
                       false, dup, cert, prev, nullptr, &step->beams_[index]);
1183
8.43M
    }
1184
8.49M
  }
1185
12.9M
}
1186
1187
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1188
// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
1189
// or if better than the current worst element if already full.
1190
void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
1191
                                        PermuterType permuter, bool dawg_start,
1192
                                        bool word_start, bool end, bool dup,
1193
                                        float cert, const RecodeNode *prev,
1194
                                        DawgPositionVector *d,
1195
30.7M
                                        RecodeHeap *heap) {
1196
30.7M
  float score = cert;
1197
30.7M
  if (prev != nullptr) {
1198
30.2M
    score += prev->score;
1199
30.2M
  }
1200
30.7M
  if (heap->size() < max_size || score > heap->PeekTop().data().score) {
1201
24.3M
    uint64_t hash = ComputeCodeHash(code, dup, prev);
1202
24.3M
    RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1203
24.3M
                    dup, cert, score, prev, d, hash);
1204
24.3M
    if (UpdateHeapIfMatched(&node, heap)) {
1205
4.22M
      return;
1206
4.22M
    }
1207
20.1M
    RecodePair entry(score, node);
1208
20.1M
    heap->Push(&entry);
1209
20.1M
    ASSERT_HOST(entry.data().dawgs == nullptr);
1210
20.1M
    if (heap->size() > max_size) {
1211
2.30M
      heap->Pop(&entry);
1212
2.30M
    }
1213
20.1M
  } else {
1214
6.42M
    delete d;
1215
6.42M
  }
1216
30.7M
}
1217
1218
// Adds a RecodeNode to heap if there is room
1219
// or if better than the current worst element if already full.
1220
void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode *node,
1221
332k
                                        RecodeHeap *heap) {
1222
332k
  if (heap->size() < max_size || node->score > heap->PeekTop().data().score) {
1223
275k
    if (UpdateHeapIfMatched(node, heap)) {
1224
0
      return;
1225
0
    }
1226
275k
    RecodePair entry(node->score, *node);
1227
275k
    heap->Push(&entry);
1228
275k
    ASSERT_HOST(entry.data().dawgs == nullptr);
1229
275k
    if (heap->size() > max_size) {
1230
26.5k
      heap->Pop(&entry);
1231
26.5k
    }
1232
275k
  }
1233
332k
}
1234
1235
// Searches the heap for a matching entry, and updates the score with
1236
// reshuffle if needed. Returns true if there was a match.
1237
bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node,
1238
24.6M
                                           RecodeHeap *heap) {
1239
  // TODO(rays) consider hash map instead of linear search.
1240
  // It might not be faster because the hash map would have to be updated
1241
  // every time a heap reshuffle happens, and that would be a lot of overhead.
1242
24.6M
  std::vector<RecodePair> &nodes = heap->heap();
1243
50.6M
  for (auto &i : nodes) {
1244
50.6M
    RecodeNode &node = i.data();
1245
50.6M
    if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1246
50.6M
        node.permuter == new_node->permuter &&
1247
50.6M
        node.start_of_dawg == new_node->start_of_dawg) {
1248
4.22M
      if (new_node->score > node.score) {
1249
        // The new one is better. Update the entire node in the heap and
1250
        // reshuffle.
1251
2.31M
        node = *new_node;
1252
2.31M
        i.key() = node.score;
1253
2.31M
        heap->Reshuffle(&i);
1254
2.31M
      }
1255
4.22M
      return true;
1256
4.22M
    }
1257
50.6M
  }
1258
20.3M
  return false;
1259
24.6M
}
1260
1261
// Computes and returns the code-hash for the given code and prev.
1262
uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup,
1263
25.1M
                                           const RecodeNode *prev) const {
1264
25.1M
  uint64_t hash = prev == nullptr ? 0 : prev->code_hash;
1265
25.1M
  if (!dup && code != null_char_) {
1266
11.6M
    int num_classes = recoder_.code_range();
1267
11.6M
    uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1268
11.6M
    hash *= num_classes;
1269
11.6M
    hash += carry;
1270
11.6M
    hash += code;
1271
11.6M
  }
1272
25.1M
  return hash;
1273
25.1M
}
1274
1275
// Backtracks to extract the best path through the lattice that was built
1276
// during Decode. On return the best_nodes vector essentially contains the set
1277
// of code, score pairs that make the optimal path with the constraint that
1278
// the recoder can decode the code sequence back to a sequence of unichar-ids.
1279
void RecodeBeamSearch::ExtractBestPaths(
1280
    std::vector<const RecodeNode *> *best_nodes,
1281
96.1k
    std::vector<const RecodeNode *> *second_nodes) const {
1282
  // Scan both beams to extract the best and second best paths.
1283
96.1k
  const RecodeNode *best_node = nullptr;
1284
96.1k
  const RecodeNode *second_best_node = nullptr;
1285
96.1k
  const RecodeBeam *last_beam = beam_[beam_size_ - 1];
1286
384k
  for (int c = 0; c < NC_COUNT; ++c) {
1287
288k
    if (c == NC_ONLY_DUP) {
1288
96.1k
      continue;
1289
96.1k
    }
1290
192k
    auto cont = static_cast<NodeContinuation>(c);
1291
576k
    for (int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1292
384k
      int beam_index = BeamIndex(is_dawg, cont, 0);
1293
384k
      int heap_size = last_beam->beams_[beam_index].size();
1294
1.17M
      for (int h = 0; h < heap_size; ++h) {
1295
791k
        const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data();
1296
791k
        if (is_dawg) {
1297
          // dawg_node may be a null_char, or duplicate, so scan back to the
1298
          // last valid unichar_id.
1299
291k
          const RecodeNode *dawg_node = node;
1300
2.22M
          while (dawg_node != nullptr &&
1301
2.22M
                 (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1302
2.16M
                  dawg_node->duplicate)) {
1303
1.93M
            dawg_node = dawg_node->prev;
1304
1.93M
          }
1305
291k
          if (dawg_node == nullptr ||
1306
291k
              (!dawg_node->end_of_word &&
1307
230k
               dawg_node->unichar_id != UNICHAR_SPACE)) {
1308
            // Dawg node is not valid.
1309
85.8k
            continue;
1310
85.8k
          }
1311
291k
        }
1312
706k
        if (best_node == nullptr || node->score > best_node->score) {
1313
493k
          second_best_node = best_node;
1314
493k
          best_node = node;
1315
493k
        } else if (second_best_node == nullptr ||
1316
212k
                   node->score > second_best_node->score) {
1317
137k
          second_best_node = node;
1318
137k
        }
1319
706k
      }
1320
384k
    }
1321
192k
  }
1322
96.1k
  if (second_nodes != nullptr) {
1323
96.1k
    ExtractPath(second_best_node, second_nodes);
1324
96.1k
  }
1325
96.1k
  ExtractPath(best_node, best_nodes);
1326
96.1k
}
1327
1328
// Helper backtracks through the lattice from the given node, storing the
1329
// path and reversing it.
1330
void RecodeBeamSearch::ExtractPath(
1331
192k
    const RecodeNode *node, std::vector<const RecodeNode *> *path) const {
1332
192k
  path->clear();
1333
2.86M
  while (node != nullptr) {
1334
2.67M
    path->push_back(node);
1335
2.67M
    node = node->prev;
1336
2.67M
  }
1337
192k
  std::reverse(path->begin(), path->end());
1338
192k
}
1339
1340
void RecodeBeamSearch::ExtractPath(const RecodeNode *node,
1341
                                   std::vector<const RecodeNode *> *path,
1342
0
                                   int limiter) const {
1343
0
  int pathcounter = 0;
1344
0
  path->clear();
1345
0
  while (node != nullptr && pathcounter < limiter) {
1346
0
    path->push_back(node);
1347
0
    node = node->prev;
1348
0
    ++pathcounter;
1349
0
  }
1350
0
  std::reverse(path->begin(), path->end());
1351
0
}
1352
1353
// Helper prints debug information on the given lattice path.
1354
void RecodeBeamSearch::DebugPath(
1355
    const UNICHARSET *unicharset,
1356
0
    const std::vector<const RecodeNode *> &path) const {
1357
0
  for (unsigned c = 0; c < path.size(); ++c) {
1358
0
    const RecodeNode &node = *path[c];
1359
0
    tprintf("%u ", c);
1360
0
    node.Print(null_char_, *unicharset, 1);
1361
0
  }
1362
0
}
1363
1364
// Helper prints debug information on the given unichar path.
1365
void RecodeBeamSearch::DebugUnicharPath(
1366
    const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path,
1367
    const std::vector<int> &unichar_ids, const std::vector<float> &certs,
1368
0
    const std::vector<float> &ratings, const std::vector<int> &xcoords) const {
1369
0
  auto num_ids = unichar_ids.size();
1370
0
  double total_rating = 0.0;
1371
0
  for (unsigned c = 0; c < num_ids; ++c) {
1372
0
    int coord = xcoords[c];
1373
0
    tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1374
0
            unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c],
1375
0
            path[coord]->start_of_word, path[coord]->end_of_word,
1376
0
            path[coord]->permuter);
1377
0
    total_rating += ratings[c];
1378
0
  }
1379
0
  tprintf("Path total rating = %g\n", total_rating);
1380
0
}
1381
1382
} // namespace tesseract.