/src/sentencepiece/src/model_interface.cc
Line | Count | Source |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License.! |
14 | | |
15 | | #include "model_interface.h" |
16 | | |
17 | | #include <algorithm> |
18 | | |
19 | | #include "sentencepiece_model.pb.h" |
20 | | #include "third_party/absl/strings/str_format.h" |
21 | | #include "util.h" |
22 | | |
23 | | namespace sentencepiece { |
24 | | |
25 | | ModelInterface::ModelInterface(const ModelProto &model_proto) |
26 | 0 | : model_proto_(&model_proto), status_(util::OkStatus()) {} |
27 | 0 | ModelInterface::~ModelInterface() {} |
28 | | |
29 | | #define RETURN_PIECE(name, default_value) \ |
30 | 0 | if (model_proto_->trainer_spec().name().empty()) return default_value; \ |
31 | 0 | return model_proto_->trainer_spec().name(); |
32 | | |
33 | 0 | absl::string_view ModelInterface::unk_piece() const { |
34 | 0 | RETURN_PIECE(unk_piece, "<unk>"); |
35 | 0 | } |
36 | | |
37 | 0 | absl::string_view ModelInterface::bos_piece() const { |
38 | 0 | RETURN_PIECE(bos_piece, "<s>"); |
39 | 0 | } |
40 | | |
41 | 0 | absl::string_view ModelInterface::eos_piece() const { |
42 | 0 | RETURN_PIECE(eos_piece, "</s>"); |
43 | 0 | } |
44 | | |
45 | 0 | absl::string_view ModelInterface::pad_piece() const { |
46 | 0 | RETURN_PIECE(pad_piece, "<pad>"); |
47 | 0 | } |
48 | | |
49 | | #undef RETURN_PIECE |
50 | | |
51 | 0 | int ModelInterface::PieceToId(absl::string_view piece) const { |
52 | 0 | auto it = reserved_id_map_.find(piece); |
53 | 0 | if (it != reserved_id_map_.end()) { |
54 | 0 | return it->second; |
55 | 0 | } |
56 | 0 | auto it2 = pieces_.find(piece); |
57 | 0 | if (it2 != pieces_.end()) { |
58 | 0 | return it2->second; |
59 | 0 | } |
60 | 0 | return unk_id_; |
61 | 0 | } |
62 | | |
63 | 0 | void ModelInterface::InitializePieces() { |
64 | 0 | pieces_.clear(); |
65 | 0 | reserved_id_map_.clear(); |
66 | 0 | unk_id_ = -1; |
67 | |
|
68 | 0 | std::set<absl::string_view> user_defined_symbols; |
69 | 0 | std::vector<bool> byte_found(256, false); |
70 | |
|
71 | 0 | int pieces_size = 0; |
72 | 0 | int reserved_id_map_size = 0; |
73 | 0 | for (int i = 0; i < model_proto_->pieces_size(); ++i) { |
74 | 0 | const auto &sp = model_proto_->pieces(i); |
75 | 0 | const bool is_normal_piece = |
76 | 0 | (sp.type() == ModelProto::SentencePiece::NORMAL || |
77 | 0 | sp.type() == ModelProto::SentencePiece::USER_DEFINED || |
78 | 0 | sp.type() == ModelProto::SentencePiece::UNUSED); |
79 | 0 | if (is_normal_piece) { |
80 | 0 | ++pieces_size; |
81 | 0 | } else { |
82 | 0 | ++reserved_id_map_size; |
83 | 0 | } |
84 | 0 | } |
85 | 0 | pieces_.reserve(pieces_size); |
86 | 0 | reserved_id_map_.reserve(reserved_id_map_size); |
87 | |
|
88 | 0 | for (int i = 0; i < model_proto_->pieces_size(); ++i) { |
89 | 0 | const auto &sp = model_proto_->pieces(i); |
90 | 0 | if (sp.piece().empty()) { |
91 | 0 | status_ = util::InternalError("piece must not be empty."); |
92 | 0 | return; |
93 | 0 | } |
94 | | |
95 | 0 | const bool is_normal_piece = |
96 | 0 | (sp.type() == ModelProto::SentencePiece::NORMAL || |
97 | 0 | sp.type() == ModelProto::SentencePiece::USER_DEFINED || |
98 | 0 | sp.type() == ModelProto::SentencePiece::UNUSED); |
99 | 0 | if (!port::InsertIfNotPresent( |
100 | 0 | is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) { |
101 | 0 | status_ = util::InternalError(sp.piece() + " is already defined."); |
102 | 0 | return; |
103 | 0 | } |
104 | | |
105 | 0 | if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) { |
106 | 0 | user_defined_symbols.insert(sp.piece()); |
107 | 0 | } |
108 | |
|
109 | 0 | if (sp.type() == ModelProto::SentencePiece::UNKNOWN) { |
110 | 0 | if (unk_id_ >= 0) { |
111 | 0 | status_ = util::InternalError("unk is already defined."); |
112 | 0 | return; |
113 | 0 | } |
114 | 0 | unk_id_ = i; |
115 | 0 | } |
116 | | |
117 | 0 | if (sp.type() == ModelProto::SentencePiece::BYTE) { |
118 | 0 | if (!model_proto_->trainer_spec().byte_fallback()) { |
119 | 0 | status_ = |
120 | 0 | util::InternalError("byte piece " + sp.piece() + |
121 | 0 | " is found although `byte_fallback` is false."); |
122 | 0 | return; |
123 | 0 | } |
124 | 0 | const int byte = PieceToByte(sp.piece()); |
125 | 0 | if (0 <= byte && byte < 256) { |
126 | 0 | byte_found[byte] = true; |
127 | 0 | } else { |
128 | 0 | status_ = |
129 | 0 | util::InternalError("byte piece " + sp.piece() + " is invalid."); |
130 | 0 | return; |
131 | 0 | } |
132 | 0 | } |
133 | 0 | } |
134 | | |
135 | 0 | if (unk_id_ == -1) { |
136 | 0 | status_ = util::InternalError("unk is not defined."); |
137 | 0 | return; |
138 | 0 | } |
139 | | |
140 | 0 | if (model_proto_->trainer_spec().byte_fallback()) { |
141 | | // Checks that there are 256 byte pieces. |
142 | 0 | if (std::find(byte_found.begin(), byte_found.end(), false) != |
143 | 0 | byte_found.end()) { |
144 | 0 | status_ = util::InternalError( |
145 | 0 | "there are not 256 byte pieces although `byte_fallback` is true."); |
146 | 0 | return; |
147 | 0 | } |
148 | 0 | } |
149 | | |
150 | 0 | matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols); |
151 | 0 | } |
152 | | |
153 | | std::vector<absl::string_view> SplitIntoWords(absl::string_view text, |
154 | | bool treat_ws_as_suffix, |
155 | 0 | bool allow_ws_only_pieces) { |
156 | 0 | const char *begin = text.data(); |
157 | 0 | const char *end = text.data() + text.size(); |
158 | | |
159 | | // Space symbol (U+2581) |
160 | 0 | const absl::string_view kSpaceSymbol = "\xe2\x96\x81"; |
161 | 0 | bool in_ws_sequence = false; |
162 | |
|
163 | 0 | std::vector<absl::string_view> result; |
164 | 0 | if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences. |
165 | 0 | if (begin < end) result.emplace_back(begin, 0); |
166 | 0 | while (begin < end) { |
167 | 0 | const int mblen = |
168 | 0 | std::min<int>(string_util::OneCharLen(begin), end - begin); |
169 | 0 | const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; |
170 | |
|
171 | 0 | if (is_ws) { // keep track of sequences consecutive ws tokens. |
172 | 0 | in_ws_sequence = true; |
173 | 0 | } else if (in_ws_sequence) { |
174 | 0 | if (allow_ws_only_pieces) result.emplace_back(begin, 0); |
175 | |
|
176 | 0 | in_ws_sequence = false; |
177 | 0 | } |
178 | |
|
179 | 0 | result.back() = |
180 | 0 | absl::string_view(result.back().data(), result.back().size() + mblen); |
181 | 0 | begin += mblen; |
182 | |
|
183 | 0 | if (begin < end && is_ws && !allow_ws_only_pieces) |
184 | 0 | result.emplace_back(begin, 0); |
185 | 0 | } |
186 | 0 | } else { |
187 | 0 | while (begin < end) { |
188 | 0 | const int mblen = |
189 | 0 | std::min<int>(string_util::OneCharLen(begin), end - begin); |
190 | 0 | bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; |
191 | | |
192 | | // if is whitespace (and not in sequence if allow_ws_only_pieces is True) |
193 | 0 | if (begin == text.data() || |
194 | 0 | (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) { |
195 | 0 | result.emplace_back(begin, 0); // add empty string piece. |
196 | 0 | in_ws_sequence = true; |
197 | 0 | } |
198 | |
|
199 | 0 | if (in_ws_sequence && !is_ws) in_ws_sequence = false; |
200 | |
|
201 | 0 | result.back() = |
202 | 0 | absl::string_view(result.back().data(), result.back().size() + mblen); |
203 | 0 | begin += mblen; |
204 | 0 | } |
205 | 0 | } |
206 | |
|
207 | 0 | return result; |
208 | 0 | } |
209 | | |
210 | 0 | std::string ByteToPiece(unsigned char c) { |
211 | 0 | return absl::StrFormat("<0x%02X>", c); |
212 | 0 | } |
213 | | |
214 | 0 | int PieceToByte(absl::string_view piece) { |
215 | 0 | using PieceToByteMap = absl::flat_hash_map<std::string, unsigned char>; |
216 | 0 | static const auto *const kMap = []() -> PieceToByteMap * { |
217 | 0 | auto *m = new PieceToByteMap(); |
218 | 0 | for (int i = 0; i < 256; ++i) { |
219 | 0 | (*m)[ByteToPiece(i)] = i; |
220 | 0 | } |
221 | 0 | return m; |
222 | 0 | }(); |
223 | 0 | const auto it = kMap->find(std::string(piece)); |
224 | 0 | if (it == kMap->end()) { |
225 | 0 | return -1; |
226 | 0 | } else { |
227 | 0 | return it->second; |
228 | 0 | } |
229 | 0 | } |
230 | | |
231 | | } // namespace sentencepiece |