/src/sentencepiece/src/model_interface.cc
Line | Count | Source |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License.! |
14 | | |
15 | | #include "model_interface.h" |
16 | | |
17 | | #include <algorithm> |
18 | | |
19 | | #include "sentencepiece_model.pb.h" |
20 | | #include "third_party/absl/strings/str_format.h" |
21 | | #include "util.h" |
22 | | |
23 | | namespace sentencepiece { |
24 | | |
25 | | ModelInterface::ModelInterface(const ModelProto &model_proto) |
26 | 0 | : model_proto_(&model_proto), status_(util::OkStatus()) {} |
27 | 3.17k | ModelInterface::~ModelInterface() {} |
28 | | |
29 | | #define RETURN_PIECE(name, default_value) \ |
30 | 8.13k | if (model_proto_->trainer_spec().name().empty()) return default_value; \ |
31 | 8.12k | return model_proto_->trainer_spec().name(); |
32 | | |
33 | 2.03k | absl::string_view ModelInterface::unk_piece() const { |
34 | 2.03k | RETURN_PIECE(unk_piece, "<unk>"); |
35 | 0 | } |
36 | | |
37 | 2.03k | absl::string_view ModelInterface::bos_piece() const { |
38 | 2.03k | RETURN_PIECE(bos_piece, "<s>"); |
39 | 0 | } |
40 | | |
41 | 2.03k | absl::string_view ModelInterface::eos_piece() const { |
42 | 2.03k | RETURN_PIECE(eos_piece, "</s>"); |
43 | 0 | } |
44 | | |
45 | 2.03k | absl::string_view ModelInterface::pad_piece() const { |
46 | 2.03k | RETURN_PIECE(pad_piece, "<pad>"); |
47 | 0 | } |
48 | | |
49 | | #undef RETURN_PIECE |
50 | | |
51 | 4.42M | int ModelInterface::PieceToId(absl::string_view piece) const { |
52 | 4.42M | if (auto it = reserved_id_map_.find(piece); it != reserved_id_map_.end()) { |
53 | 20.4k | return it->second; |
54 | 20.4k | } |
55 | 4.40M | if (auto it = pieces_.find(piece); it != pieces_.end()) { |
56 | 109k | return it->second; |
57 | 109k | } |
58 | 4.29M | return unk_id_; |
59 | 4.40M | } |
60 | | |
61 | 3.17k | void ModelInterface::InitializePieces() { |
62 | 3.17k | pieces_.clear(); |
63 | 3.17k | reserved_id_map_.clear(); |
64 | 3.17k | unk_id_ = -1; |
65 | | |
66 | 3.17k | std::set<absl::string_view> user_defined_symbols; |
67 | 3.17k | std::vector<bool> byte_found(256, false); |
68 | | |
69 | 3.17k | int pieces_size = 0; |
70 | 3.17k | int reserved_id_map_size = 0; |
71 | 1.31M | for (int i = 0; i < model_proto_->pieces_size(); ++i) { |
72 | 1.31M | const auto &sp = model_proto_->pieces(i); |
73 | 1.31M | static constexpr size_t kMaxPieceSize = 8192; |
74 | 1.31M | if (sp.piece().size() >= kMaxPieceSize) { |
75 | 0 | status_ = util::InternalError("piece size must be less than 8k."); |
76 | 0 | return; |
77 | 0 | } |
78 | 1.31M | const bool is_normal_piece = |
79 | 1.31M | (sp.type() == ModelProto::SentencePiece::NORMAL || |
80 | 26.3k | sp.type() == ModelProto::SentencePiece::USER_DEFINED || |
81 | 25.6k | sp.type() == ModelProto::SentencePiece::UNUSED); |
82 | 1.31M | if (is_normal_piece) { |
83 | 1.28M | ++pieces_size; |
84 | 1.28M | } else { |
85 | 25.6k | ++reserved_id_map_size; |
86 | 25.6k | } |
87 | 1.31M | } |
88 | 3.17k | pieces_.reserve(pieces_size); |
89 | 3.17k | reserved_id_map_.reserve(reserved_id_map_size); |
90 | | |
91 | 47.0k | for (int i = 0; i < model_proto_->pieces_size(); ++i) { |
92 | 44.8k | const auto &sp = model_proto_->pieces(i); |
93 | 44.8k | if (sp.piece().empty()) { |
94 | 48 | status_ = util::InternalError("piece must not be empty."); |
95 | 48 | return; |
96 | 48 | } |
97 | 44.7k | if (sp.piece().find('\0') != absl::string_view::npos) { |
98 | 850 | status_ = util::InternalError("piece must not include null character."); |
99 | 850 | return; |
100 | 850 | } |
101 | 43.9k | const bool is_normal_piece = |
102 | 43.9k | (sp.type() == ModelProto::SentencePiece::NORMAL || |
103 | 21.2k | sp.type() == ModelProto::SentencePiece::USER_DEFINED || |
104 | 20.6k | sp.type() == ModelProto::SentencePiece::UNUSED); |
105 | 43.9k | if (!port::InsertIfNotPresent( |
106 | 43.9k | is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) { |
107 | 10 | status_ = util::InternalError(sp.piece() + " is already defined."); |
108 | 10 | return; |
109 | 10 | } |
110 | | |
111 | 43.9k | if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) { |
112 | 602 | user_defined_symbols.insert(sp.piece()); |
113 | 602 | } |
114 | | |
115 | 43.9k | if (sp.type() == ModelProto::SentencePiece::UNKNOWN) { |
116 | 2.96k | if (unk_id_ >= 0) { |
117 | 2 | status_ = util::InternalError("unk is already defined."); |
118 | 2 | return; |
119 | 2 | } |
120 | 2.96k | unk_id_ = i; |
121 | 2.96k | } |
122 | | |
123 | 43.9k | if (sp.type() == ModelProto::SentencePiece::BYTE) { |
124 | 14.7k | if (!model_proto_->trainer_spec().byte_fallback()) { |
125 | 6 | status_ = |
126 | 6 | util::InternalError("byte piece " + sp.piece() + |
127 | 6 | " is found although `byte_fallback` is false."); |
128 | 6 | return; |
129 | 6 | } |
130 | 14.7k | const int byte = PieceToByte(sp.piece()); |
131 | 14.7k | if (0 <= byte && byte < 256) { |
132 | 14.7k | byte_found[byte] = true; |
133 | 14.7k | } else { |
134 | 32 | status_ = |
135 | 32 | util::InternalError("byte piece " + sp.piece() + " is invalid."); |
136 | 32 | return; |
137 | 32 | } |
138 | 14.7k | } |
139 | 43.9k | } |
140 | | |
141 | 2.23k | if (unk_id_ == -1) { |
142 | 107 | status_ = util::InternalError("unk is not defined."); |
143 | 107 | return; |
144 | 107 | } |
145 | | |
146 | 2.12k | if (model_proto_->trainer_spec().byte_fallback()) { |
147 | | // Checks that there are 256 byte pieces. |
148 | 67 | if (std::find(byte_found.begin(), byte_found.end(), false) != |
149 | 67 | byte_found.end()) { |
150 | 39 | status_ = util::InternalError( |
151 | 39 | "there are not 256 byte pieces although `byte_fallback` is true."); |
152 | 39 | return; |
153 | 39 | } |
154 | 67 | } |
155 | | |
156 | 2.08k | matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols); |
157 | 2.08k | } |
158 | | |
159 | | std::vector<absl::string_view> SplitIntoWords(absl::string_view text, |
160 | | bool treat_ws_as_suffix, |
161 | 761 | bool allow_ws_only_pieces) { |
162 | 761 | const char *begin = text.data(); |
163 | 761 | const char *end = text.data() + text.size(); |
164 | | |
165 | | // Space symbol (U+2581) |
166 | 761 | constexpr absl::string_view kSpaceSymbol = "\xe2\x96\x81"; |
167 | 761 | bool in_ws_sequence = false; |
168 | | |
169 | 761 | std::vector<absl::string_view> result; |
170 | 761 | if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences. |
171 | 0 | if (begin < end) result.emplace_back(begin, 0); |
172 | 0 | while (begin < end) { |
173 | 0 | const int mblen = |
174 | 0 | std::min<int>(string_util::OneCharLen(begin), end - begin); |
175 | 0 | const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; |
176 | |
|
177 | 0 | if (is_ws) { // keep track of sequences consecutive ws tokens. |
178 | 0 | in_ws_sequence = true; |
179 | 0 | } else if (in_ws_sequence) { |
180 | 0 | if (allow_ws_only_pieces) result.emplace_back(begin, 0); |
181 | |
|
182 | 0 | in_ws_sequence = false; |
183 | 0 | } |
184 | |
|
185 | 0 | result.back() = |
186 | 0 | absl::string_view(result.back().data(), result.back().size() + mblen); |
187 | 0 | begin += mblen; |
188 | |
|
189 | 0 | if (begin < end && is_ws && !allow_ws_only_pieces) |
190 | 0 | result.emplace_back(begin, 0); |
191 | 0 | } |
192 | 761 | } else { |
193 | 19.0M | while (begin < end) { |
194 | 19.0M | const int mblen = |
195 | 19.0M | std::min<int>(string_util::OneCharLen(begin), end - begin); |
196 | 19.0M | bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; |
197 | | |
198 | | // if is whitespace (and not in sequence if allow_ws_only_pieces is True) |
199 | 19.0M | if (begin == text.data() || |
200 | 19.0M | (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) { |
201 | 107k | result.emplace_back(begin, 0); // add empty string piece. |
202 | 107k | in_ws_sequence = true; |
203 | 107k | } |
204 | | |
205 | 19.0M | if (in_ws_sequence && !is_ws) in_ws_sequence = false; |
206 | | |
207 | 19.0M | result.back() = |
208 | 19.0M | absl::string_view(result.back().data(), result.back().size() + mblen); |
209 | 19.0M | begin += mblen; |
210 | 19.0M | } |
211 | 761 | } |
212 | | |
213 | 761 | return result; |
214 | 761 | } |
215 | | |
216 | 1.62k | const std::string &ByteToPiece(unsigned char c) { |
217 | 1.62k | static const std::vector<std::string> *const kBytePieces = []() { |
218 | 1 | auto *v = new std::vector<std::string>(256); |
219 | 257 | for (int i = 0; i < 256; ++i) { |
220 | 256 | (*v)[i] = absl::StrFormat("<0x%02X>", i); |
221 | 256 | } |
222 | 1 | return v; |
223 | 1 | }(); |
224 | 1.62k | return (*kBytePieces)[c]; |
225 | 1.62k | } |
226 | | |
227 | 15.3k | int PieceToByte(absl::string_view piece) { |
228 | 15.3k | using PieceToByteMap = absl::flat_hash_map<absl::string_view, unsigned char>; |
229 | 15.3k | static const auto *const kMap = []() -> PieceToByteMap * { |
230 | 1 | auto *m = new PieceToByteMap(); |
231 | 257 | for (int i = 0; i < 256; ++i) { |
232 | 256 | (*m)[ByteToPiece(i)] = i; |
233 | 256 | } |
234 | 1 | return m; |
235 | 1 | }(); |
236 | | |
237 | 15.3k | if (const auto it = kMap->find(piece); it != kMap->end()) { |
238 | 15.3k | return it->second; |
239 | 15.3k | } |
240 | | |
241 | 32 | return -1; |
242 | 15.3k | } |
243 | | |
244 | | } // namespace sentencepiece |