Coverage Report

Created: 2026-03-31 06:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/src/bpe_model.cc
Line
Count
Source
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "bpe_model.h"
16
17
#include <functional>
18
#include <memory>
19
#include <queue>
20
#include <random>
21
#include <utility>
22
#include <vector>
23
24
#include "freelist.h"
25
#include "third_party/absl/container/flat_hash_map.h"
26
#include "util.h"
27
28
namespace sentencepiece {
29
namespace bpe {
30
31
0
Model::Model(const ModelProto &model_proto) {
32
0
  model_proto_ = &model_proto;
33
0
  InitializePieces();
34
0
}
35
36
0
Model::~Model() {}
37
38
std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
39
0
    absl::string_view normalized, float alpha) const {
40
0
  if (!status().ok() || normalized.empty()) {
41
0
    return {};
42
0
  }
43
44
0
  struct SymbolPair {
45
0
    int left;     // left index of this pair
46
0
    int right;    // right index of this pair
47
0
    float score;  // score of this pair. large is better.
48
0
    size_t size;  // length of this piece
49
0
  };
50
51
0
  class SymbolPairComparator {
52
0
   public:
53
0
    const bool operator()(SymbolPair *h1, SymbolPair *h2) {
54
0
      return (h1->score < h2->score ||
55
0
              (h1->score == h2->score && h1->left > h2->left));
56
0
    }
57
0
  };
58
59
0
  struct Symbol {
60
0
    int prev;     // prev index of this symbol. -1 for BOS.
61
0
    int next;     // next index of this symbol. -1 for EOS.
62
0
    bool freeze;  // this symbol is never merged.
63
0
    absl::string_view piece;
64
0
  };
65
66
0
  using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
67
0
                                     SymbolPairComparator>;
68
0
  Agenda agenda;
69
0
  std::vector<Symbol> symbols;
70
0
  symbols.reserve(normalized.size());
71
72
  // Reverse merge rules.
73
  // key: merged symbol, value: pair of original symbols.
74
0
  absl::flat_hash_map<absl::string_view,
75
0
                      std::pair<absl::string_view, absl::string_view>>
76
0
      rev_merge;
77
78
  // Pre-allocates SymbolPair for efficiency.
79
0
  constexpr size_t kPreallocateSymbolPairSize = 256;
80
0
  model::FreeList<SymbolPair> symbol_pair_allocator(kPreallocateSymbolPairSize);
81
82
  // Looks up the symbol pair at [left, right] and inserts it into the agenda.
83
0
  auto MaybeAddNewSymbolPair = [this, &symbol_pair_allocator, &symbols, &agenda,
84
0
                                &rev_merge](int left, int right) {
85
0
    if (left == -1 || right == -1 || symbols[left].freeze ||
86
0
        symbols[right].freeze)
87
0
      return;
88
0
    const absl::string_view piece(
89
0
        symbols[left].piece.data(),
90
0
        symbols[left].piece.size() + symbols[right].piece.size());
91
0
    const auto it = pieces_.find(piece);
92
0
    if (it == pieces_.end()) {
93
0
      return;
94
0
    }
95
0
    auto *h = symbol_pair_allocator.Allocate();
96
0
    h->left = left;
97
0
    h->right = right;
98
0
    h->score = GetScore(it->second);
99
0
    h->size = piece.size();
100
0
    agenda.push(h);
101
102
    // Makes `rev_merge` for resegmentation.
103
0
    if (IsUnusedInlined(it->second)) {
104
0
      rev_merge[piece] =
105
0
          std::make_pair(symbols[left].piece, symbols[right].piece);
106
0
    }
107
0
  };
108
109
  // Splits the input into a character sequence.
110
0
  int index = 0;
111
0
  while (!normalized.empty()) {
112
0
    Symbol s;
113
0
    const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
114
0
    s.piece = absl::string_view(normalized.data(), mblen);
115
0
    s.prev = index == 0 ? -1 : index - 1;
116
0
    normalized.remove_prefix(mblen);
117
0
    s.next = normalized.empty() ? -1 : index + 1;
118
0
    ++index;
119
0
    symbols.emplace_back(s);
120
0
  }
121
122
0
  if (symbols.empty()) {
123
0
    return {};
124
0
  }
125
126
  // Looks up all bigrams.
127
0
  for (size_t i = 1; i < symbols.size(); ++i) {
128
0
    MaybeAddNewSymbolPair(i - 1, i);
129
0
  }
130
131
  // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
132
0
  std::mt19937 *rand_gen = nullptr;
133
0
  auto skip_merge = [&]() {
134
0
    if (alpha <= 0.0) return false;
135
0
    if (alpha >= 1.0) return true;
136
0
    if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator();
137
0
    std::uniform_real_distribution<> gen(0.0, 1.0);
138
0
    return gen(*rand_gen) < alpha;
139
0
  };
140
141
  // Main loop.
142
0
  while (!agenda.empty()) {
143
0
    SymbolPair *top = agenda.top();
144
0
    agenda.pop();
145
146
    // Skips `top` if it is stale: a symbol was merged away or changed size.
147
0
    if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
148
0
        symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
149
0
            top->size) {
150
0
      continue;
151
0
    }
152
153
    // Note that original BPE-dropout paper assumes that all merged symbols are
154
    // pre computed, but here we randomly skip merge operation inside this loop.
155
    // This implementation is theoretically equivalent to the original one.
156
0
    if (skip_merge()) continue;
157
158
    // Replaces symbols with `top` rule.
159
0
    symbols[top->left].piece = absl::string_view(
160
0
        symbols[top->left].piece.data(),
161
0
        symbols[top->left].piece.size() + symbols[top->right].piece.size());
162
163
    // Updates prev/next pointers.
164
0
    symbols[top->left].next = symbols[top->right].next;
165
0
    if (symbols[top->right].next >= 0) {
166
0
      symbols[symbols[top->right].next].prev = top->left;
167
0
    }
168
0
    symbols[top->right].piece = absl::string_view("");
169
170
    // Adds the symbol pairs newly formed by the symbol replacement.
171
0
    MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
172
0
    MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
173
0
  }
174
175
0
  std::function<void(absl::string_view, EncodeResult *)> resegment;
176
0
  resegment = [this, &resegment, &rev_merge](absl::string_view w,
177
0
                                             EncodeResult *output) -> void {
178
0
    const int id = PieceToId(w);
179
0
    if (id == -1 || !IsUnusedInlined(id)) {
180
0
      output->emplace_back(w, id);
181
0
      return;
182
0
    }
183
0
    const auto p = rev_merge.find(w);
184
0
    if (p == rev_merge.end()) {
185
      // This block should never be reached, as `rev_merge` stores all the
186
      // resegmentation info for unused id.
187
0
      output->emplace_back(w, id);
188
0
      return;
189
0
    }
190
    // Recursively resegment left and right symbols.
191
0
    resegment(p->second.first, output);
192
0
    resegment(p->second.second, output);
193
0
  };
194
195
0
  EncodeResult output;
196
0
  for (int index = 0; index != -1; index = symbols[index].next) {
197
0
    if (index >= 0 && index < static_cast<int>(symbols.size())) {
198
0
      resegment(symbols[index].piece, &output);
199
0
    }
200
0
  }
201
202
0
  return output;
203
0
}
204
}  // namespace bpe
205
}  // namespace sentencepiece