/src/sentencepiece/src/model_interface.cc
Line | Count | Source |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License.! |
14 | | |
15 | | #include "model_interface.h" |
16 | | |
17 | | #include <algorithm> |
18 | | |
19 | | #include "sentencepiece_model.pb.h" |
20 | | #include "third_party/absl/strings/str_format.h" |
21 | | #include "util.h" |
22 | | |
23 | | namespace sentencepiece { |
24 | | |
25 | | ModelInterface::ModelInterface(const ModelProto &model_proto) |
26 | 0 | : model_proto_(&model_proto), status_(util::OkStatus()) {} |
27 | 3.17k | ModelInterface::~ModelInterface() {} |
28 | | |
29 | | #define RETURN_PIECE(name, default_value) \ |
30 | 8.13k | if (model_proto_->trainer_spec().name().empty()) return default_value; \ |
31 | 8.12k | return model_proto_->trainer_spec().name(); |
32 | | |
33 | 2.03k | absl::string_view ModelInterface::unk_piece() const { |
34 | 2.03k | RETURN_PIECE(unk_piece, "<unk>"); |
35 | 0 | } |
36 | | |
37 | 2.03k | absl::string_view ModelInterface::bos_piece() const { |
38 | 2.03k | RETURN_PIECE(bos_piece, "<s>"); |
39 | 0 | } |
40 | | |
41 | 2.03k | absl::string_view ModelInterface::eos_piece() const { |
42 | 2.03k | RETURN_PIECE(eos_piece, "</s>"); |
43 | 0 | } |
44 | | |
45 | 2.03k | absl::string_view ModelInterface::pad_piece() const { |
46 | 2.03k | RETURN_PIECE(pad_piece, "<pad>"); |
47 | 0 | } |
48 | | |
49 | | #undef RETURN_PIECE |
50 | | |
51 | 4.42M | int ModelInterface::PieceToId(absl::string_view piece) const { |
52 | 4.42M | if (auto it = reserved_id_map_.find(piece); it != reserved_id_map_.end()) { |
53 | 20.4k | return it->second; |
54 | 20.4k | } |
55 | 4.40M | if (auto it = pieces_.find(piece); it != pieces_.end()) { |
56 | 109k | return it->second; |
57 | 109k | } |
58 | 4.29M | return unk_id_; |
59 | 4.40M | } |
60 | | |
61 | 3.17k | void ModelInterface::InitializePieces() { |
62 | 3.17k | pieces_.clear(); |
63 | 3.17k | reserved_id_map_.clear(); |
64 | 3.17k | unk_id_ = -1; |
65 | | |
66 | 3.17k | std::set<absl::string_view> user_defined_symbols; |
67 | 3.17k | std::vector<bool> byte_found(256, false); |
68 | | |
69 | 3.17k | int pieces_size = 0; |
70 | 3.17k | int reserved_id_map_size = 0; |
71 | 1.31M | for (int i = 0; i < model_proto_->pieces_size(); ++i) { |
72 | 1.31M | const auto &sp = model_proto_->pieces(i); |
73 | 1.31M | const bool is_normal_piece = |
74 | 1.31M | (sp.type() == ModelProto::SentencePiece::NORMAL || |
75 | 26.3k | sp.type() == ModelProto::SentencePiece::USER_DEFINED || |
76 | 25.6k | sp.type() == ModelProto::SentencePiece::UNUSED); |
77 | 1.31M | if (is_normal_piece) { |
78 | 1.28M | ++pieces_size; |
79 | 1.28M | } else { |
80 | 25.6k | ++reserved_id_map_size; |
81 | 25.6k | } |
82 | 1.31M | } |
83 | 3.17k | pieces_.reserve(pieces_size); |
84 | 3.17k | reserved_id_map_.reserve(reserved_id_map_size); |
85 | | |
86 | 47.0k | for (int i = 0; i < model_proto_->pieces_size(); ++i) { |
87 | 44.8k | const auto &sp = model_proto_->pieces(i); |
88 | 44.8k | if (sp.piece().empty()) { |
89 | 48 | status_ = util::InternalError("piece must not be empty."); |
90 | 48 | return; |
91 | 48 | } |
92 | 44.7k | if (sp.piece().find('\0') != absl::string_view::npos) { |
93 | 850 | status_ = util::InternalError("piece must not include null character."); |
94 | 850 | return; |
95 | 850 | } |
96 | 43.9k | const bool is_normal_piece = |
97 | 43.9k | (sp.type() == ModelProto::SentencePiece::NORMAL || |
98 | 21.2k | sp.type() == ModelProto::SentencePiece::USER_DEFINED || |
99 | 20.6k | sp.type() == ModelProto::SentencePiece::UNUSED); |
100 | 43.9k | if (!port::InsertIfNotPresent( |
101 | 43.9k | is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) { |
102 | 10 | status_ = util::InternalError(sp.piece() + " is already defined."); |
103 | 10 | return; |
104 | 10 | } |
105 | | |
106 | 43.9k | if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) { |
107 | 602 | user_defined_symbols.insert(sp.piece()); |
108 | 602 | } |
109 | | |
110 | 43.9k | if (sp.type() == ModelProto::SentencePiece::UNKNOWN) { |
111 | 2.96k | if (unk_id_ >= 0) { |
112 | 2 | status_ = util::InternalError("unk is already defined."); |
113 | 2 | return; |
114 | 2 | } |
115 | 2.96k | unk_id_ = i; |
116 | 2.96k | } |
117 | | |
118 | 43.9k | if (sp.type() == ModelProto::SentencePiece::BYTE) { |
119 | 14.7k | if (!model_proto_->trainer_spec().byte_fallback()) { |
120 | 6 | status_ = |
121 | 6 | util::InternalError("byte piece " + sp.piece() + |
122 | 6 | " is found although `byte_fallback` is false."); |
123 | 6 | return; |
124 | 6 | } |
125 | 14.7k | const int byte = PieceToByte(sp.piece()); |
126 | 14.7k | if (0 <= byte && byte < 256) { |
127 | 14.7k | byte_found[byte] = true; |
128 | 14.7k | } else { |
129 | 32 | status_ = |
130 | 32 | util::InternalError("byte piece " + sp.piece() + " is invalid."); |
131 | 32 | return; |
132 | 32 | } |
133 | 14.7k | } |
134 | 43.9k | } |
135 | | |
136 | 2.23k | if (unk_id_ == -1) { |
137 | 107 | status_ = util::InternalError("unk is not defined."); |
138 | 107 | return; |
139 | 107 | } |
140 | | |
141 | 2.12k | if (model_proto_->trainer_spec().byte_fallback()) { |
142 | | // Checks that there are 256 byte pieces. |
143 | 67 | if (std::find(byte_found.begin(), byte_found.end(), false) != |
144 | 67 | byte_found.end()) { |
145 | 39 | status_ = util::InternalError( |
146 | 39 | "there are not 256 byte pieces although `byte_fallback` is true."); |
147 | 39 | return; |
148 | 39 | } |
149 | 67 | } |
150 | | |
151 | 2.08k | matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols); |
152 | 2.08k | } |
153 | | |
154 | | std::vector<absl::string_view> SplitIntoWords(absl::string_view text, |
155 | | bool treat_ws_as_suffix, |
156 | 761 | bool allow_ws_only_pieces) { |
157 | 761 | const char *begin = text.data(); |
158 | 761 | const char *end = text.data() + text.size(); |
159 | | |
160 | | // Space symbol (U+2581) |
161 | 761 | constexpr absl::string_view kSpaceSymbol = "\xe2\x96\x81"; |
162 | 761 | bool in_ws_sequence = false; |
163 | | |
164 | 761 | std::vector<absl::string_view> result; |
165 | 761 | if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences. |
166 | 0 | if (begin < end) result.emplace_back(begin, 0); |
167 | 0 | while (begin < end) { |
168 | 0 | const int mblen = |
169 | 0 | std::min<int>(string_util::OneCharLen(begin), end - begin); |
170 | 0 | const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; |
171 | |
|
172 | 0 | if (is_ws) { // keep track of sequences consecutive ws tokens. |
173 | 0 | in_ws_sequence = true; |
174 | 0 | } else if (in_ws_sequence) { |
175 | 0 | if (allow_ws_only_pieces) result.emplace_back(begin, 0); |
176 | |
|
177 | 0 | in_ws_sequence = false; |
178 | 0 | } |
179 | |
|
180 | 0 | result.back() = |
181 | 0 | absl::string_view(result.back().data(), result.back().size() + mblen); |
182 | 0 | begin += mblen; |
183 | |
|
184 | 0 | if (begin < end && is_ws && !allow_ws_only_pieces) |
185 | 0 | result.emplace_back(begin, 0); |
186 | 0 | } |
187 | 761 | } else { |
188 | 19.0M | while (begin < end) { |
189 | 19.0M | const int mblen = |
190 | 19.0M | std::min<int>(string_util::OneCharLen(begin), end - begin); |
191 | 19.0M | bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; |
192 | | |
193 | | // if is whitespace (and not in sequence if allow_ws_only_pieces is True) |
194 | 19.0M | if (begin == text.data() || |
195 | 19.0M | (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) { |
196 | 107k | result.emplace_back(begin, 0); // add empty string piece. |
197 | 107k | in_ws_sequence = true; |
198 | 107k | } |
199 | | |
200 | 19.0M | if (in_ws_sequence && !is_ws) in_ws_sequence = false; |
201 | | |
202 | 19.0M | result.back() = |
203 | 19.0M | absl::string_view(result.back().data(), result.back().size() + mblen); |
204 | 19.0M | begin += mblen; |
205 | 19.0M | } |
206 | 761 | } |
207 | | |
208 | 761 | return result; |
209 | 761 | } |
210 | | |
211 | 1.62k | std::string ByteToPiece(unsigned char c) { |
212 | 1.62k | return absl::StrFormat("<0x%02X>", c); |
213 | 1.62k | } |
214 | | |
215 | 15.3k | int PieceToByte(absl::string_view piece) { |
216 | 15.3k | using PieceToByteMap = absl::flat_hash_map<std::string, unsigned char>; |
217 | 15.3k | static const auto *const kMap = []() -> PieceToByteMap * { |
218 | 1 | auto *m = new PieceToByteMap(); |
219 | 257 | for (int i = 0; i < 256; ++i) { |
220 | 256 | (*m)[ByteToPiece(i)] = i; |
221 | 256 | } |
222 | 1 | return m; |
223 | 1 | }(); |
224 | | |
225 | 15.3k | if (const auto it = kMap->find(piece); it != kMap->end()) { |
226 | 15.3k | return it->second; |
227 | 15.3k | } |
228 | | |
229 | 32 | return -1; |
230 | 15.3k | } |
231 | | |
232 | | } // namespace sentencepiece |