/src/sentencepiece/src/bpe_model.cc
Line | Count | Source |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "bpe_model.h" |
16 | | |
17 | | #include <functional> |
18 | | #include <memory> |
19 | | #include <queue> |
20 | | #include <random> |
21 | | #include <utility> |
22 | | #include <vector> |
23 | | |
24 | | #include "freelist.h" |
25 | | #include "third_party/absl/container/flat_hash_map.h" |
26 | | #include "util.h" |
27 | | |
28 | | namespace sentencepiece { |
29 | | namespace bpe { |
30 | | |
// Constructs a BPE model backed by `model_proto`.
//
// Only the address of `model_proto` is stored in `model_proto_`; no copy is
// made here, so the proto presumably has to outlive this Model (confirm
// against the callers). InitializePieces() is defined outside this file
// section; it presumably builds the piece lookup tables from the proto.
31 | 0 | Model::Model(const ModelProto &model_proto) { |
32 | 0 | model_proto_ = &model_proto; |
33 | 0 | InitializePieces(); |
34 | 0 | } |
35 | | |
// Destructor. The body is empty: nothing is released explicitly here.
36 | 0 | Model::~Model() {} |
37 | | |
// Segments `normalized` with BPE merges and returns (piece, id) pairs,
// optionally dropping merges at random (BPE-dropout).
//
// Parameters:
//   normalized: input text to segment. The returned string_views are built
//               from `normalized.data()`, so they alias the caller's buffer.
//   alpha:      merge-skip probability. `alpha <= 0` never skips a merge,
//               `alpha >= 1` skips every merge, and anything in between skips
//               a candidate merge when a uniform draw from [0, 1) is below
//               `alpha`.
//
// Returns an empty result when `status()` is not ok or the input is empty.
// Pieces whose id satisfies IsUnusedInlined() are split back into the two
// symbols they were merged from before being emitted.
38 | | std::vector<std::pair<absl::string_view, int>> Model::SampleEncode( |
39 | 0 | absl::string_view normalized, float alpha) const { |
40 | 0 | if (!status().ok() || normalized.empty()) { |
41 | 0 | return {}; |
42 | 0 | } |
43 | | |
44 | 0 | struct SymbolPair { |
45 | 0 | int left; // index (into `symbols`) of the left symbol of this pair |
46 | 0 | int right; // index (into `symbols`) of the right symbol of this pair |
47 | 0 | float score; // score of this pair. Larger is better. |
48 | 0 | size_t size; // length of the merged piece when the pair was queued; used to detect stale pairs |
49 | 0 | }; |
50 | 
|
// Ordering for the agenda (a std::priority_queue, whose top() is the greatest
// element): the pair with the highest score is popped first, and on equal
// scores the pair with the smaller `left` index wins.
51 | 0 | class SymbolPairComparator { |
52 | 0 | public: |
53 | 0 | const bool operator()(SymbolPair *h1, SymbolPair *h2) { |
54 | 0 | return (h1->score < h2->score || |
55 | 0 | (h1->score == h2->score && h1->left > h2->left)); |
56 | 0 | } |
57 | 0 | }; |
58 | 
|
59 | 0 | struct Symbol { |
60 | 0 | int prev; // prev index of this symbol. -1 for BOS. |
61 | 0 | int next; // next index of this symbol. -1 for EOS. |
62 | 0 | bool freeze; // this symbol must never be merged. |
63 | 0 | absl::string_view piece; |
64 | 0 | }; |
65 | 
|
66 | 0 | using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>, |
67 | 0 | SymbolPairComparator>; |
68 | 0 | Agenda agenda; |
69 | 0 | std::vector<Symbol> symbols; |
70 | 0 | symbols.reserve(normalized.size()); |
71 | | |
72 | | // Reverse merge rules. |
73 | | // key: merged symbol, value: pair of original symbols. |
74 | 0 | absl::flat_hash_map<absl::string_view, |
75 | 0 | std::pair<absl::string_view, absl::string_view>> |
76 | 0 | rev_merge; |
77 | | |
78 | | // Pre-allocates SymbolPair objects for efficiency. |
79 | 0 | constexpr size_t kPreallocateSymbolPairSize = 256; |
80 | 0 | model::FreeList<SymbolPair> symbol_pair_allocator(kPreallocateSymbolPairSize); |
81 | | |
82 | | // Looks up the symbol pair at [left, right] and inserts it into the agenda. |
// Nothing is added when either index is -1, either symbol is frozen, or the
// concatenated piece is not in `pieces_`. The candidate piece is the byte
// range that starts at the left symbol and spans both sizes, which relies on
// the two symbols being adjacent in the underlying buffer.
83 | 0 | auto MaybeAddNewSymbolPair = [this, &symbol_pair_allocator, &symbols, &agenda, |
84 | 0 | &rev_merge](int left, int right) { |
85 | 0 | if (left == -1 || right == -1 || symbols[left].freeze || |
86 | 0 | symbols[right].freeze) |
87 | 0 | return; |
88 | 0 | const absl::string_view piece( |
89 | 0 | symbols[left].piece.data(), |
90 | 0 | symbols[left].piece.size() + symbols[right].piece.size()); |
91 | 0 | const auto it = pieces_.find(piece); |
92 | 0 | if (it == pieces_.end()) { |
93 | 0 | return; |
94 | 0 | } |
95 | 0 | auto *h = symbol_pair_allocator.Allocate(); |
96 | 0 | h->left = left; |
97 | 0 | h->right = right; |
98 | 0 | h->score = GetScore(it->second); |
99 | 0 | h->size = piece.size(); |
100 | 0 | agenda.push(h); |
101 | | |
102 | | // Records the reverse merge so that an unused piece can be split back later. |
103 | 0 | if (IsUnusedInlined(it->second)) { |
104 | 0 | rev_merge[piece] = |
105 | 0 | std::make_pair(symbols[left].piece, symbols[right].piece); |
106 | 0 | } |
107 | 0 | }; |
108 | | |
109 | | // Splits the input into a character sequence; each symbol spans the `mblen` bytes returned by PrefixMatch. |
110 | 0 | int index = 0; |
111 | 0 | while (!normalized.empty()) { |
112 | 0 | Symbol s; |
113 | 0 | const int mblen = matcher_->PrefixMatch(normalized, &s.freeze); |
114 | 0 | s.piece = absl::string_view(normalized.data(), mblen); |
115 | 0 | s.prev = index == 0 ? -1 : index - 1; |
116 | 0 | normalized.remove_prefix(mblen); |
117 | 0 | s.next = normalized.empty() ? -1 : index + 1; |
118 | 0 | ++index; |
119 | 0 | symbols.emplace_back(s); |
120 | 0 | } |
121 | 
|
122 | 0 | if (symbols.empty()) { |
123 | 0 | return {}; |
124 | 0 | } |
125 | | |
126 | | // Looks up all adjacent symbol pairs (bigrams). |
127 | 0 | for (size_t i = 1; i < symbols.size(); ++i) { |
128 | 0 | MaybeAddNewSymbolPair(i - 1, i); |
129 | 0 | } |
130 | | |
131 | | // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf |
// `skip_merge` decides whether the current candidate merge is dropped. The
// random generator is fetched lazily, only when 0 < alpha < 1.
132 | 0 | std::mt19937 *rand_gen = nullptr; |
133 | 0 | auto skip_merge = [&]() { |
134 | 0 | if (alpha <= 0.0) return false; |
135 | 0 | if (alpha >= 1.0) return true; |
136 | 0 | if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator(); |
137 | 0 | std::uniform_real_distribution<> gen(0.0, 1.0); |
138 | 0 | return gen(*rand_gen) < alpha; |
139 | 0 | }; |
140 | | |
141 | | // Main loop: repeatedly pops the best pair and merges it. |
142 | 0 | while (!agenda.empty()) { |
143 | 0 | SymbolPair *top = agenda.top(); |
144 | 0 | agenda.pop(); |
145 | | |
146 | | // `top` is stale: one of its symbols was emptied by a merge or the combined size has changed. |
147 | 0 | if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() || |
148 | 0 | symbols[top->left].piece.size() + symbols[top->right].piece.size() != |
149 | 0 | top->size) { |
150 | 0 | continue; |
151 | 0 | } |
152 | | |
153 | | // Note that the original BPE-dropout paper assumes that all merged symbols are |
154 | | // precomputed, but here we randomly skip the merge operation inside this loop. |
155 | | // This implementation is theoretically equivalent to the original one. |
156 | 0 | if (skip_merge()) continue; |
157 | | |
158 | | // Replaces symbols with `top` rule: the left symbol grows to cover both pieces. |
159 | 0 | symbols[top->left].piece = absl::string_view( |
160 | 0 | symbols[top->left].piece.data(), |
161 | 0 | symbols[top->left].piece.size() + symbols[top->right].piece.size()); |
162 | | |
163 | | // Updates prev/next pointers so the right symbol is unlinked. |
164 | 0 | symbols[top->left].next = symbols[top->right].next; |
165 | 0 | if (symbols[top->right].next >= 0) { |
166 | 0 | symbols[symbols[top->right].next].prev = top->left; |
167 | 0 | } |
168 | 0 | symbols[top->right].piece = absl::string_view(""); |
169 | | |
170 | | // Adds the symbol pairs that are newly formed by the replacement. |
171 | 0 | MaybeAddNewSymbolPair(symbols[top->left].prev, top->left); |
172 | 0 | MaybeAddNewSymbolPair(top->left, symbols[top->left].next); |
173 | 0 | } |
174 | 
|
// `resegment` emits `w` as-is when it has no id (-1) or its id is not unused;
// otherwise it splits `w` using `rev_merge` and recurses into both halves.
175 | 0 | std::function<void(absl::string_view, EncodeResult *)> resegment; |
176 | 0 | resegment = [this, &resegment, &rev_merge](absl::string_view w, |
177 | 0 | EncodeResult *output) -> void { |
178 | 0 | const int id = PieceToId(w); |
179 | 0 | if (id == -1 || !IsUnusedInlined(id)) { |
180 | 0 | output->emplace_back(w, id); |
181 | 0 | return; |
182 | 0 | } |
183 | 0 | const auto p = rev_merge.find(w); |
184 | 0 | if (p == rev_merge.end()) { |
185 | | // This block should never be reached, as `rev_merge` stores all the |
186 | | // resegmentation info for unused ids. |
187 | 0 | output->emplace_back(w, id); |
188 | 0 | return; |
189 | 0 | } |
190 | | // Recursively resegments the left and right symbols. |
191 | 0 | resegment(p->second.first, output); |
192 | 0 | resegment(p->second.second, output); |
193 | 0 | }; |
194 | 
|
// Walks the surviving symbols through their `next` links, starting at index 0.
195 | 0 | EncodeResult output; |
196 | 0 | for (int index = 0; index != -1; index = symbols[index].next) { |
197 | 0 | if (index >= 0 && index < static_cast<int>(symbols.size())) { |
198 | 0 | resegment(symbols[index].piece, &output); |
199 | 0 | } |
200 | 0 | } |
201 | 
|
202 | 0 | return output; |
203 | 0 | } |
204 | | } // namespace bpe |
205 | | } // namespace sentencepiece |