/src/sentencepiece/src/bpe_model.cc
Line | Count | Source |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License.! |
14 | | |
15 | | #include "bpe_model.h" |
16 | | |
17 | | #include <cstdint> |
18 | | #include <functional> |
19 | | #include <memory> |
20 | | #include <queue> |
21 | | #include <random> |
22 | | #include <utility> |
23 | | #include <vector> |
24 | | |
25 | | #include "freelist.h" |
26 | | #include "model_interface.h" |
27 | | #include "sentencepiece_model.pb.h" |
28 | | #include "third_party/absl/base/attributes.h" |
29 | | #include "third_party/absl/container/flat_hash_map.h" |
30 | | #include "third_party/absl/random/random.h" |
31 | | #include "third_party/absl/strings/string_view.h" |
32 | | #include "util.h" |
33 | | |
34 | | namespace sentencepiece { |
35 | | namespace bpe { |
36 | | |
37 | 287 | Model::Model(const ModelProto &model_proto) { |
38 | 287 | model_proto_ = &model_proto; |
39 | 287 | InitializePieces(); |
40 | 287 | } |
41 | | |
42 | 287 | Model::~Model() {} |
43 | | |
44 | | std::vector<std::pair<absl::string_view, int>> Model::SampleEncode( |
45 | 379 | absl::string_view normalized, float alpha) const { |
46 | 379 | if (!status().ok() || normalized.empty()) { |
47 | 3 | return {}; |
48 | 3 | } |
49 | | |
50 | 376 | struct SymbolPair { |
51 | 376 | union { |
52 | 376 | float score; // score of this pair. large is better. |
53 | 376 | int32_t int_score; |
54 | 376 | }; |
55 | 376 | uint32_t left; // left index of this pair |
56 | 376 | int right; // right index of this pair |
57 | 376 | unsigned int size; // length of this piece |
58 | 376 | }; |
59 | | |
60 | 376 | class SymbolPairComparator { |
61 | 376 | public: |
62 | 376 | ABSL_ATTRIBUTE_ALWAYS_INLINE inline bool operator()(const SymbolPair &h1, |
63 | 376 | const SymbolPair &h2) { |
64 | 0 | const int32_t i1 = h1.int_score; |
65 | 0 | const int32_t i2 = h2.int_score; |
66 | | |
67 | | // Fast path for the common case where both scores are negative because |
68 | | // they are log-probabilities. |
69 | | // Note: we use the fact that IEEE 754 floating point format enables |
70 | | // to compare the integer representation of negative floats which is |
71 | | // cheaper than using float comparison. And it works the same way for |
72 | | // little endian and big endian machines because the IEEE 754 format is |
73 | | // aligned with the endianness. |
74 | | // `(i1 & i2) < 0` is an efficient way to check `i1 < 0 && i2 < 0`. |
75 | 0 | if ((i1 & i2) < 0) { |
76 | | // For negative floats, their integer representation order is the |
77 | | // reverse of the float order. That is, for two negative floats f1, f2, |
78 | | // f1 < f2 iff i1 > i2. |
79 | 0 | return (i1 > i2) || (i1 == i2 && h1.left > h2.left); |
80 | 0 | } |
81 | | |
82 | | // Slow path for uncommon cases (mixed signs or both positive). |
83 | | // Note: the comparison between NaN and +0 and +1 can be different than |
84 | | // if we used float numbers but it should not influence the result. |
85 | 0 | bool score_less; |
86 | | // If signs are different ((i1 ^ i2) < 0), the negative score is smaller. |
87 | 0 | if ((i1 ^ i2) < 0) { |
88 | 0 | score_less = i1 < 0; |
89 | 0 | } else { |
90 | | // If signs are the same (and not both negative), they must both be |
91 | | // non-negative. For non-negative floats, integer order is the same as |
92 | | // float order. |
93 | 0 | score_less = i1 < i2; |
94 | 0 | } |
95 | |
|
96 | 0 | return score_less || (i1 == i2 && h1.left > h2.left); |
97 | 0 | } |
98 | 376 | }; |
99 | | |
100 | 376 | struct Symbol { |
101 | 376 | int prev; // prev index of this symbol. -1 for BOS. |
102 | 376 | int next; // next index of tihs symbol. -1 for EOS. |
103 | 376 | bool freeze; // this symbol is never be merged. |
104 | 376 | absl::string_view piece; |
105 | 376 | }; |
106 | | |
107 | 376 | std::vector<Symbol> symbols; |
108 | 376 | symbols.reserve(normalized.size()); |
109 | | |
110 | | // Splits the input into Symbols doing longest prefix match of the input |
111 | | // from pieces(type:UNUSED) in the vocabulary. |
112 | | // Does character splitting as a fallback of longest prefix match. |
113 | 376 | int index = 0; |
114 | 820k | while (!normalized.empty()) { |
115 | 820k | Symbol s; |
116 | 820k | const int mblen = matcher_->PrefixMatch(normalized, &s.freeze); |
117 | 820k | s.piece = absl::string_view(normalized.data(), mblen); |
118 | 820k | s.prev = index == 0 ? -1 : index - 1; |
119 | 820k | normalized.remove_prefix(mblen); |
120 | 820k | s.next = normalized.empty() ? -1 : index + 1; |
121 | 820k | ++index; |
122 | 820k | symbols.emplace_back(s); |
123 | 820k | } |
124 | | |
125 | 376 | if (symbols.empty()) { |
126 | 0 | return {}; |
127 | 0 | } |
128 | | |
129 | 376 | std::vector<SymbolPair> agenda_vec; |
130 | 376 | agenda_vec.reserve(symbols.size()); |
131 | | |
132 | | // Reverse merge rules. |
133 | | // key: merged symbol, value: pair of original symbols. |
134 | 376 | absl::flat_hash_map<absl::string_view, |
135 | 376 | std::pair<absl::string_view, absl::string_view>> |
136 | 376 | rev_merge; |
137 | | |
138 | | // Lookup all bigrams. |
139 | 376 | if (symbols.size() > 1) { |
140 | 376 | int left = 0; |
141 | 376 | int right = 1; |
142 | 376 | Symbol *symbol_left = &symbols[left]; |
143 | 376 | Symbol *symbol_right = &symbols[right]; |
144 | 820k | for (; right < symbols.size(); |
145 | 819k | left = right, symbol_left = symbol_right, ++right, ++symbol_right) { |
146 | 819k | if (symbol_left->freeze || symbol_right->freeze) continue; |
147 | 819k | const absl::string_view piece( |
148 | 819k | symbol_left->piece.data(), |
149 | 819k | symbol_left->piece.size() + symbol_right->piece.size()); |
150 | 819k | const auto it = pieces_.find(piece); |
151 | 819k | if (it == pieces_.end()) continue; |
152 | 0 | SymbolPair &h = agenda_vec.emplace_back(); |
153 | 0 | h.left = left; |
154 | 0 | h.right = right; |
155 | 0 | h.score = GetScore(it->second); |
156 | 0 | h.size = piece.size(); |
157 | | |
158 | | // Makes `rev_merge` for resegmentation. |
159 | 0 | if (IsUnusedInlined(it->second)) |
160 | 0 | rev_merge[piece] = |
161 | 0 | std::make_pair(symbol_left->piece, symbol_right->piece); |
162 | 0 | } |
163 | 376 | } |
164 | | |
165 | 376 | using Agenda = std::priority_queue<SymbolPair, std::vector<SymbolPair>, |
166 | 376 | SymbolPairComparator>; |
167 | 376 | Agenda agenda(SymbolPairComparator(), std::move(agenda_vec)); |
168 | | // Lookup new symbol pair at [left, right] and inserts it to agenda. |
169 | 376 | auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge]( |
170 | 376 | int left, int right) { |
171 | 0 | if (left == -1 || right == -1) return; |
172 | 0 | const Symbol &left_symbol = symbols[left]; |
173 | 0 | const Symbol &right_symbol = symbols[right]; |
174 | 0 | if (left_symbol.freeze || right_symbol.freeze) return; |
175 | 0 | const absl::string_view piece( |
176 | 0 | left_symbol.piece.data(), |
177 | 0 | left_symbol.piece.size() + right_symbol.piece.size()); |
178 | 0 | const auto it = pieces_.find(piece); |
179 | 0 | if (it == pieces_.end()) { |
180 | 0 | return; |
181 | 0 | } |
182 | 0 | const int id = it->second; |
183 | 0 | SymbolPair h; |
184 | 0 | h.left = left; |
185 | 0 | h.right = right; |
186 | 0 | h.score = GetScore(id); |
187 | 0 | h.size = piece.size(); |
188 | 0 | agenda.push(h); |
189 | | |
190 | | // Makes `rev_merge` for resegmentation. |
191 | 0 | if (IsUnusedInlined(id)) |
192 | 0 | rev_merge[piece] = std::make_pair(left_symbol.piece, right_symbol.piece); |
193 | 0 | }; |
194 | | |
195 | 376 | absl::BitGen *rand_gen = nullptr; |
196 | | // Main loop. |
197 | 376 | while (!agenda.empty()) { |
198 | | // Pop the top pair if it is stale. |
199 | 0 | const SymbolPair &top_ref = agenda.top(); |
200 | 0 | if (symbols[top_ref.left].piece.empty() || |
201 | 0 | symbols[top_ref.right].piece.empty() || |
202 | 0 | (symbols[top_ref.left].piece.size() + |
203 | 0 | symbols[top_ref.right].piece.size() != |
204 | 0 | top_ref.size)) { |
205 | 0 | agenda.pop(); |
206 | 0 | continue; |
207 | 0 | } |
208 | | |
209 | 0 | SymbolPair top = agenda.top(); |
210 | 0 | agenda.pop(); |
211 | |
|
212 | 0 | Symbol &left_symbol = symbols[top.left]; |
213 | 0 | Symbol &right_symbol = symbols[top.right]; |
214 | | |
215 | | // Note that original BPE-dropout paper assumes that all merged symbols are |
216 | | // pre computed, but here we randomly skip merge operation inside this loop. |
217 | | // This implementation is theoretically equivalent to the original one. |
218 | | // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf |
219 | 0 | if (alpha > 0.0) { |
220 | 0 | if (alpha >= 1.0) continue; |
221 | 0 | if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator(); |
222 | 0 | std::uniform_real_distribution<> gen(0.0, 1.0); |
223 | 0 | if (gen(*rand_gen) < alpha) continue; |
224 | 0 | } |
225 | | |
226 | | // Replaces symbols with `top` rule. |
227 | 0 | left_symbol.piece = |
228 | 0 | absl::string_view(left_symbol.piece.data(), |
229 | 0 | left_symbol.piece.size() + right_symbol.piece.size()); |
230 | | |
231 | | // Updates prev/next pointers. |
232 | 0 | left_symbol.next = right_symbol.next; |
233 | 0 | if (right_symbol.next >= 0) { |
234 | 0 | symbols[right_symbol.next].prev = top.left; |
235 | 0 | } |
236 | 0 | right_symbol.piece = absl::string_view(""); |
237 | | |
238 | | // Adds new symbol pairs which are newly added after symbol replacement. |
239 | 0 | MaybeAddNewSymbolPair(left_symbol.prev, top.left); |
240 | 0 | MaybeAddNewSymbolPair(top.left, left_symbol.next); |
241 | 0 | } |
242 | | |
243 | 376 | std::function<void(absl::string_view, EncodeResult *)> resegment; |
244 | 376 | resegment = [this, &resegment, &rev_merge](absl::string_view w, |
245 | 820k | EncodeResult *output) -> void { |
246 | 820k | const int id = PieceToId(w); |
247 | 820k | if (id == -1 || !IsUnusedInlined(id)) { |
248 | 820k | output->emplace_back(w, id); |
249 | 820k | return; |
250 | 820k | } |
251 | 0 | const auto p = rev_merge.find(w); |
252 | 0 | if (p == rev_merge.end()) { |
253 | | // This block will never be called, as `rev_merge` stores all the |
254 | | // resegmentation info for unused id. |
255 | 0 | output->emplace_back(w, id); |
256 | 0 | return; |
257 | 0 | } |
258 | | // Recursively resegment left and right symbols. |
259 | 0 | resegment(p->second.first, output); |
260 | 0 | resegment(p->second.second, output); |
261 | 0 | }; |
262 | | |
263 | 376 | EncodeResult output; |
264 | 376 | output.reserve(symbols.size()); |
265 | 820k | for (int index = 0; index != -1; index = symbols[index].next) { |
266 | 820k | if (index >= 0 && index < static_cast<int>(symbols.size())) { |
267 | 820k | resegment(symbols[index].piece, &output); |
268 | 820k | } |
269 | 820k | } |
270 | | |
271 | 376 | return output; |
272 | 376 | } |
273 | | } // namespace bpe |
274 | | } // namespace sentencepiece |