Coverage Report

Created: 2026-05-30 06:49

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/src/bpe_model.cc
Line
Count
Source
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "bpe_model.h"
16
17
#include <cstdint>
18
#include <functional>
19
#include <memory>
20
#include <queue>
21
#include <random>
22
#include <utility>
23
#include <vector>
24
25
#include "freelist.h"
26
#include "model_interface.h"
27
#include "sentencepiece_model.pb.h"
28
#include "third_party/absl/base/attributes.h"
29
#include "third_party/absl/container/flat_hash_map.h"
30
#include "third_party/absl/random/random.h"
31
#include "third_party/absl/strings/string_view.h"
32
#include "util.h"
33
34
namespace sentencepiece {
35
namespace bpe {
36
37
287
Model::Model(const ModelProto &model_proto) {
38
287
  model_proto_ = &model_proto;
39
287
  InitializePieces();
40
287
}
41
42
287
Model::~Model() {}
43
44
std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
45
379
    absl::string_view normalized, float alpha) const {
46
379
  if (!status().ok() || normalized.empty()) {
47
3
    return {};
48
3
  }
49
50
376
  struct SymbolPair {
51
376
    union {
52
376
      float score;  // score of this pair. large is better.
53
376
      int32_t int_score;
54
376
    };
55
376
    uint32_t left;      // left index of this pair
56
376
    int right;          // right index of this pair
57
376
    unsigned int size;  // length of this piece
58
376
  };
59
60
376
  class SymbolPairComparator {
61
376
   public:
62
376
    ABSL_ATTRIBUTE_ALWAYS_INLINE inline bool operator()(const SymbolPair &h1,
63
376
                                                        const SymbolPair &h2) {
64
0
      const int32_t i1 = h1.int_score;
65
0
      const int32_t i2 = h2.int_score;
66
67
      // Fast path for the common case where both scores are negative because
68
      // they are log-probabilities.
69
      // Note: we use the fact that the IEEE 754 floating point format allows
70
      // comparing the integer representations of negative floats, which is
71
      // cheaper than using float comparison. And it works the same way for
72
      // little endian and big endian machines because the IEEE 754 format is
73
      // aligned with the endianness.
74
      // `(i1 & i2) < 0` is an efficient way to check `i1 < 0 && i2 < 0`.
75
0
      if ((i1 & i2) < 0) {
76
        // For negative floats, their integer representation order is the
77
        // reverse of the float order. That is, for two negative floats f1, f2,
78
        // f1 < f2 iff i1 > i2.
79
0
        return (i1 > i2) || (i1 == i2 && h1.left > h2.left);
80
0
      }
81
82
      // Slow path for uncommon cases (mixed signs or both positive).
83
      // Note: the comparison between NaN and +0 and +1 can be different than
84
      // if we used float numbers but it should not influence the result.
85
0
      bool score_less;
86
      // If signs are different ((i1 ^ i2) < 0), the negative score is smaller.
87
0
      if ((i1 ^ i2) < 0) {
88
0
        score_less = i1 < 0;
89
0
      } else {
90
        // If signs are the same (and not both negative), they must both be
91
        // non-negative. For non-negative floats, integer order is the same as
92
        // float order.
93
0
        score_less = i1 < i2;
94
0
      }
95
96
0
      return score_less || (i1 == i2 && h1.left > h2.left);
97
0
    }
98
376
  };
99
100
376
  struct Symbol {
101
376
    int prev;     // prev index of this symbol. -1 for BOS.
102
376
    int next;     // next index of this symbol. -1 for EOS.
103
376
    bool freeze;  // this symbol is never merged.
104
376
    absl::string_view piece;
105
376
  };
106
107
376
  std::vector<Symbol> symbols;
108
376
  symbols.reserve(normalized.size());
109
110
  // Splits the input into Symbols doing longest prefix match of the input
111
  // from pieces(type:UNUSED) in the vocabulary.
112
  // Does character splitting as a fallback of longest prefix match.
113
376
  int index = 0;
114
820k
  while (!normalized.empty()) {
115
820k
    Symbol s;
116
820k
    const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
117
820k
    s.piece = absl::string_view(normalized.data(), mblen);
118
820k
    s.prev = index == 0 ? -1 : index - 1;
119
820k
    normalized.remove_prefix(mblen);
120
820k
    s.next = normalized.empty() ? -1 : index + 1;
121
820k
    ++index;
122
820k
    symbols.emplace_back(s);
123
820k
  }
124
125
376
  if (symbols.empty()) {
126
0
    return {};
127
0
  }
128
129
376
  std::vector<SymbolPair> agenda_vec;
130
376
  agenda_vec.reserve(symbols.size());
131
132
  // Reverse merge rules.
133
  // key: merged symbol, value: pair of original symbols.
134
376
  absl::flat_hash_map<absl::string_view,
135
376
                      std::pair<absl::string_view, absl::string_view>>
136
376
      rev_merge;
137
138
  // Lookup all bigrams.
139
376
  if (symbols.size() > 1) {
140
376
    int left = 0;
141
376
    int right = 1;
142
376
    Symbol *symbol_left = &symbols[left];
143
376
    Symbol *symbol_right = &symbols[right];
144
820k
    for (; right < symbols.size();
145
819k
         left = right, symbol_left = symbol_right, ++right, ++symbol_right) {
146
819k
      if (symbol_left->freeze || symbol_right->freeze) continue;
147
819k
      const absl::string_view piece(
148
819k
          symbol_left->piece.data(),
149
819k
          symbol_left->piece.size() + symbol_right->piece.size());
150
819k
      const auto it = pieces_.find(piece);
151
819k
      if (it == pieces_.end()) continue;
152
0
      SymbolPair &h = agenda_vec.emplace_back();
153
0
      h.left = left;
154
0
      h.right = right;
155
0
      h.score = GetScore(it->second);
156
0
      h.size = piece.size();
157
158
      // Makes `rev_merge` for resegmentation.
159
0
      if (IsUnusedInlined(it->second))
160
0
        rev_merge[piece] =
161
0
            std::make_pair(symbol_left->piece, symbol_right->piece);
162
0
    }
163
376
  }
164
165
376
  using Agenda = std::priority_queue<SymbolPair, std::vector<SymbolPair>,
166
376
                                     SymbolPairComparator>;
167
376
  Agenda agenda(SymbolPairComparator(), std::move(agenda_vec));
168
  // Looks up the new symbol pair at [left, right] and inserts it into agenda.
169
376
  auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge](
170
376
                                   int left, int right) {
171
0
    if (left == -1 || right == -1) return;
172
0
    const Symbol &left_symbol = symbols[left];
173
0
    const Symbol &right_symbol = symbols[right];
174
0
    if (left_symbol.freeze || right_symbol.freeze) return;
175
0
    const absl::string_view piece(
176
0
        left_symbol.piece.data(),
177
0
        left_symbol.piece.size() + right_symbol.piece.size());
178
0
    const auto it = pieces_.find(piece);
179
0
    if (it == pieces_.end()) {
180
0
      return;
181
0
    }
182
0
    const int id = it->second;
183
0
    SymbolPair h;
184
0
    h.left = left;
185
0
    h.right = right;
186
0
    h.score = GetScore(id);
187
0
    h.size = piece.size();
188
0
    agenda.push(h);
189
190
    // Makes `rev_merge` for resegmentation.
191
0
    if (IsUnusedInlined(id))
192
0
      rev_merge[piece] = std::make_pair(left_symbol.piece, right_symbol.piece);
193
0
  };
194
195
376
  absl::BitGen *rand_gen = nullptr;
196
  // Main loop.
197
376
  while (!agenda.empty()) {
198
    // Pop the top pair if it is stale.
199
0
    const SymbolPair &top_ref = agenda.top();
200
0
    if (symbols[top_ref.left].piece.empty() ||
201
0
        symbols[top_ref.right].piece.empty() ||
202
0
        (symbols[top_ref.left].piece.size() +
203
0
             symbols[top_ref.right].piece.size() !=
204
0
         top_ref.size)) {
205
0
      agenda.pop();
206
0
      continue;
207
0
    }
208
209
0
    SymbolPair top = agenda.top();
210
0
    agenda.pop();
211
212
0
    Symbol &left_symbol = symbols[top.left];
213
0
    Symbol &right_symbol = symbols[top.right];
214
215
    // Note that the original BPE-dropout paper assumes that all merged symbols
216
    // are precomputed, but here we randomly skip merges inside this loop.
217
    // This implementation is theoretically equivalent to the original one.
218
    // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
219
0
    if (alpha > 0.0) {
220
0
      if (alpha >= 1.0) continue;
221
0
      if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator();
222
0
      std::uniform_real_distribution<> gen(0.0, 1.0);
223
0
      if (gen(*rand_gen) < alpha) continue;
224
0
    }
225
226
    // Replaces symbols with `top` rule.
227
0
    left_symbol.piece =
228
0
        absl::string_view(left_symbol.piece.data(),
229
0
                          left_symbol.piece.size() + right_symbol.piece.size());
230
231
    // Updates prev/next pointers.
232
0
    left_symbol.next = right_symbol.next;
233
0
    if (right_symbol.next >= 0) {
234
0
      symbols[right_symbol.next].prev = top.left;
235
0
    }
236
0
    right_symbol.piece = absl::string_view("");
237
238
    // Adds symbol pairs that are newly formed after the symbol replacement.
239
0
    MaybeAddNewSymbolPair(left_symbol.prev, top.left);
240
0
    MaybeAddNewSymbolPair(top.left, left_symbol.next);
241
0
  }
242
243
376
  std::function<void(absl::string_view, EncodeResult *)> resegment;
244
376
  resegment = [this, &resegment, &rev_merge](absl::string_view w,
245
820k
                                             EncodeResult *output) -> void {
246
820k
    const int id = PieceToId(w);
247
820k
    if (id == -1 || !IsUnusedInlined(id)) {
248
820k
      output->emplace_back(w, id);
249
820k
      return;
250
820k
    }
251
0
    const auto p = rev_merge.find(w);
252
0
    if (p == rev_merge.end()) {
253
      // This block will never be called, as `rev_merge` stores all the
254
      // resegmentation info for unused id.
255
0
      output->emplace_back(w, id);
256
0
      return;
257
0
    }
258
    // Recursively resegment left and right symbols.
259
0
    resegment(p->second.first, output);
260
0
    resegment(p->second.second, output);
261
0
  };
262
263
376
  EncodeResult output;
264
376
  output.reserve(symbols.size());
265
820k
  for (int index = 0; index != -1; index = symbols[index].next) {
266
820k
    if (index >= 0 && index < static_cast<int>(symbols.size())) {
267
820k
      resegment(symbols[index].piece, &output);
268
820k
    }
269
820k
  }
270
271
376
  return output;
272
376
}
273
}  // namespace bpe
274
}  // namespace sentencepiece