/src/sentencepiece/src/sentencepiece_processor.h
Line | Count | Source |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef SENTENCEPIECE_PROCESSOR_H_ |
16 | | #define SENTENCEPIECE_PROCESSOR_H_ |
17 | | |
18 | | #include <cstdint> |
19 | | #include <cstring> |
20 | | #include <memory> |
21 | | #include <string> |
22 | | #include <string_view> |
23 | | #include <utility> |
24 | | #include <vector> |
25 | | |
26 | | #ifdef _USE_EXTERNAL_ABSL |
27 | | #include "third_party/absl/strings/string_view.h" |
28 | | #else // _USE_EXTERNAL_ABSL |
29 | | #ifndef SWIG |
30 | | namespace absl { |
31 | | using std::string_view; |
32 | | } // namespace absl |
33 | | #endif // SWIG |
34 | | #endif // _USE_EXTERNAL_ABSL |
35 | | |
36 | | namespace sentencepiece { |
37 | | namespace util { |
38 | | |
39 | | enum class StatusCode : int { |
40 | | kOk = 0, |
41 | | kCancelled = 1, |
42 | | kUnknown = 2, |
43 | | kInvalidArgument = 3, |
44 | | kDeadlineExceeded = 4, |
45 | | kNotFound = 5, |
46 | | kAlreadyExists = 6, |
47 | | kPermissionDenied = 7, |
48 | | kResourceExhausted = 8, |
49 | | kFailedPrecondition = 9, |
50 | | kAborted = 10, |
51 | | kOutOfRange = 11, |
52 | | kUnimplemented = 12, |
53 | | kInternal = 13, |
54 | | kUnavailable = 14, |
55 | | kDataLoss = 15, |
56 | | kUnauthenticated = 16, |
57 | | }; |
58 | | |
59 | | class Status { |
60 | | public: |
61 | | Status(); |
62 | | ~Status(); |
63 | | Status(StatusCode code, absl::string_view error_message); |
64 | | Status(const Status &s); |
65 | | void operator=(const Status &s); |
66 | | bool operator==(const Status &s) const; |
67 | | bool operator!=(const Status &s) const; |
68 | | inline bool ok() const { return rep_ == nullptr; } |
69 | | |
70 | | void set_error_message(const char *str); |
71 | | const char *error_message() const; |
72 | 0 | const char *message() const { return error_message(); } |
73 | | StatusCode code() const; |
74 | | std::string ToString() const; |
75 | | |
76 | | void IgnoreError(); |
77 | | |
78 | | private: |
79 | | struct Rep; |
80 | | std::unique_ptr<Rep> rep_; |
81 | | }; |
82 | | } // namespace util |
83 | | |
84 | | // SentencePieceProcessor: |
85 | | // Simple and language independent tokenizer and de-tokenizer for |
86 | | // Neural Network Machine Translation. |
87 | | // |
88 | | // SentencePieceProcessor provides Encode() and Decode() methods, |
89 | | // which correspond to tokenization and de-tokenization respectively. |
90 | | // |
91 | | // - Encode: |
92 | | // Given a raw source sentence, encode it into a sequence |
93 | | // of pieces or vocabulary ids. |
94 | | // |
95 | | // - Decode: |
96 | | // Given a sequence of pieces or vocabulary ids, decode it |
97 | | // into a de-tokenized raw sentence. |
98 | | // |
99 | | // SentencePieceProcessor provides a lossless data conversion |
100 | | // that allows the original raw sentence to be perfectly reconstructed |
101 | | // from the encoded data, i.e., Decode(Encode(input)) == input. |
102 | | // This characteristic is useful, as we can make the de-tokenization |
103 | | // completely language independent. |
104 | | // |
105 | | // Usage: |
106 | | // SentencePieceProcessor sp; |
107 | | // sp.Load("//path/to/model"); |
108 | | // |
109 | | // vector<string> sps; |
110 | | // sp.Encode("hello world.", &sps).IgnoreError(); |
111 | | // |
112 | | // vector<int> ids; |
113 | | // sp.Encode("hello world.", &ids).IgnoreError(); |
114 | | // |
115 | | // string detok; |
116 | | // sp.Decode(sps, &detok).IgnoreError(); |
117 | | // CHECK_EQ("hello world.", detok); |
118 | | // |
119 | | // sp.Decode(ids, &detok).IgnoreError(); |
120 | | // CHECK_EQ("hello world.", detok); |
121 | | // |
122 | | // We can also use SentencePieceText which manages the byte-offsets |
123 | | // between user input (output) and internal sentence pieces. |
124 | | // |
125 | | // SentencePieceText spt; |
126 | | // sp.Encode("hello world.", &spt); |
127 | | // // Emits the byte range of each piece. |
128 | | // for (const auto &piece : spt.pieces()) { |
129 | | // LOG(INFO) << piece.begin() << " " << piece.end(); |
130 | | // } |
131 | | // |
132 | | // sp.Decode({0, 1, 2, 3..}, &spt); |
133 | | // for (const auto &piece : spt.pieces()) { |
134 | | // LOG(INFO) << piece.begin() << " " << piece.end(); |
135 | | // } |
136 | | // |
137 | | |
138 | | class NBestSentencePieceText; |
139 | | class ModelInterface; |
140 | | class SentencePieceText; |
141 | | class ModelProto; |
142 | | class NormalizerSpec; |
143 | | |
144 | | namespace normalizer { |
145 | | class Normalizer; |
146 | | } // namespace normalizer |
147 | | |
148 | | #ifndef SWIGGO |
149 | | namespace util { |
150 | | // Redefine std::string for serialized_proto interface as Python's string is |
151 | | // a Unicode string. We can enforce the return value to be raw byte sequence |
152 | | // with SWIG's typemap. |
153 | | using bytes = std::string; |
154 | | } // namespace util |
155 | | #endif // SWIGGO |
156 | | |
157 | | class NBestSentencePieceText; |
158 | | class ModelInterface; |
159 | | class SentencePieceText; |
160 | | class SentencePieceText_SentencePiece; |
161 | | |
162 | | // Wrapper class of SentencePieceText |
163 | | // This wrapper only allows an immutable access to the proto and |
164 | | // hides the actual implementation of protobuf. |
165 | | // See sentencepiece.proto for the details of this class. |
166 | | class ImmutableSentencePieceText_ImmutableSentencePiece { |
167 | | public: |
168 | | ImmutableSentencePieceText_ImmutableSentencePiece(); |
169 | | ~ImmutableSentencePieceText_ImmutableSentencePiece() = default; |
170 | | |
171 | | const std::string &piece() const; |
172 | | const std::string &surface() const; |
173 | | uint32_t id() const; |
174 | | uint32_t begin() const; |
175 | | uint32_t end() const; |
176 | | |
177 | | friend class ImmutableSentencePieceText; |
178 | | |
179 | | private: |
180 | | explicit ImmutableSentencePieceText_ImmutableSentencePiece( |
181 | | const SentencePieceText_SentencePiece &sp); |
182 | | const SentencePieceText_SentencePiece *sp_ = nullptr; |
183 | | }; |
184 | | |
185 | | class ImmutableSentencePieceText { |
186 | | public: |
187 | | ImmutableSentencePieceText(); |
188 | | virtual ~ImmutableSentencePieceText(); |
189 | | |
190 | | std::vector<ImmutableSentencePieceText_ImmutableSentencePiece> pieces() const; |
191 | | |
192 | | size_t pieces_size() const; |
193 | | ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const; |
194 | | |
195 | | const std::string &text() const; |
196 | | float score() const; |
197 | | |
198 | | util::bytes SerializeAsString() const; |
199 | | |
200 | | // Returns the actual mutable proto. |
201 | | // Do not use this outside of SentencePieceProcessor, as |
202 | | // it returns the raw pointer managed by the shared_ptr. |
203 | | SentencePieceText *mutable_proto(); |
204 | | |
205 | | // Converts the utf8 byte spans into Unicode char span. |
206 | | void ConvertToUnicodeSpans(); |
207 | | |
208 | | friend class ImmutableNBestSentencePieceText; |
209 | | |
210 | | private: |
211 | | explicit ImmutableSentencePieceText(const SentencePieceText &spt); |
212 | | const SentencePieceText *spt_ = nullptr; |
213 | | std::shared_ptr<SentencePieceText> rep_; |
214 | | }; |
215 | | |
216 | | // Wrapper class of NBestSentencePieceText. |
217 | | // This wrapper only allows an immutable access to the proto and |
218 | | // hides the actual implementation of protobuf. |
219 | | // See sentencepiece.proto for the details of this class. |
220 | | class ImmutableNBestSentencePieceText { |
221 | | public: |
222 | | ImmutableNBestSentencePieceText(); |
223 | | virtual ~ImmutableNBestSentencePieceText(); |
224 | | |
225 | | std::vector<ImmutableSentencePieceText> nbests() const; |
226 | | |
227 | | size_t nbests_size() const; |
228 | | ImmutableSentencePieceText nbests(int index) const; |
229 | | |
230 | | util::bytes SerializeAsString() const; |
231 | | |
232 | | // Returns the actual mutable proto. |
233 | | // Do not use this outside of SentencePieceProcessor, as |
234 | | // it returns the raw pointer managed by the shared_ptr. |
235 | | NBestSentencePieceText *mutable_proto(); |
236 | | |
237 | | void ConvertToUnicodeSpans(); |
238 | | |
239 | | private: |
240 | | std::shared_ptr<NBestSentencePieceText> rep_; |
241 | | }; |
242 | | |
243 | | class SentencePieceProcessor { |
244 | | public: |
245 | | SentencePieceProcessor(); |
246 | | virtual ~SentencePieceProcessor(); |
247 | | |
248 | | // Loads model from `filename`. |
249 | | // Returns a non-OK status if `filename` cannot be loaded. |
250 | | virtual util::Status Load(absl::string_view filename); |
251 | | |
252 | | // Loads model from `filename`. |
253 | | // Crashes if `filename` cannot be loaded. |
254 | | virtual void LoadOrDie(absl::string_view filename); |
255 | | |
256 | | // Loads model from `model_proto`. |
257 | | // `model_proto` is copied. |
258 | | virtual util::Status Load(const ModelProto &model_proto); |
259 | | |
260 | | // Loads model from `model_proto`. |
261 | | // `model_proto` is moved. |
262 | | virtual util::Status Load(std::unique_ptr<ModelProto> model_proto); |
263 | | |
264 | | // Loads model from `serialized`, which is a string-serialized model proto. |
265 | | // Useful to load the model from a platform independent blob object. |
266 | | virtual util::Status LoadFromSerializedProto(absl::string_view serialized); |
267 | | |
268 | | // Returns the status. Encode/Decode methods are valid when status is OK. |
269 | | virtual util::Status status() const; |
270 | | |
271 | | // Sets encode extra_option sequence. |
272 | | virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option); |
273 | | |
274 | | // Sets decode extra_option sequence. |
275 | | virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option); |
276 | | |
277 | | ////////////////////////////////////////////////////////////// |
278 | | // Vocabulary restriction. |
279 | | // Background: |
280 | | // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt |
281 | | |
282 | | // Restricts the vocabulary set. |
283 | | // The input sentences are encoded into the tokens in `valid_vocab`. |
284 | | virtual util::Status SetVocabulary( |
285 | | const std::vector<absl::string_view> &valid_vocab); |
286 | | |
287 | | // Reverts the vocabulary restriction. |
288 | | virtual util::Status ResetVocabulary(); |
289 | | |
290 | | // Loads the valid vocabulary set from `filename` in TSV format. |
291 | | // Format: <token> <tab> <freq>. |
292 | | // Any token with frequency < threshold will be treated as OOV. |
293 | | virtual util::Status LoadVocabulary(absl::string_view filename, |
294 | | int threshold); |
295 | | |
296 | | ////////////////////////////////////////////////////////////// |
297 | | // Simple Encode and Decode API. |
298 | | // |
299 | | // Given a UTF8 input, encodes it into a sequence of sentence pieces. |
300 | | virtual util::Status Encode(absl::string_view input, |
301 | | std::vector<std::string> *pieces) const; |
302 | | |
303 | | // Given a UTF8 input, encodes it into a sequence of ids. |
304 | | virtual util::Status Encode(absl::string_view input, |
305 | | std::vector<int> *ids) const; |
306 | | |
307 | | // Given a sequence of pieces, decodes it into a detokenized output. |
308 | | virtual util::Status Decode(const std::vector<std::string> &pieces, |
309 | | std::string *detokenized) const; |
310 | | |
311 | | // Given a sequence of pieces, decodes it into a detokenized output. |
312 | | virtual util::Status Decode(const std::vector<absl::string_view> &pieces, |
313 | | std::string *detokenized) const; |
314 | | |
315 | | // Given a sequence of ids, decodes it into a detokenized output. |
316 | | virtual util::Status Decode(const std::vector<int> &ids, |
317 | | std::string *detokenized) const; |
318 | | |
319 | | ////////////////////////////////////////////////////////////// |
320 | | // NBest API. |
321 | | // |
322 | | // Same as Encode, but returns nbest results. |
323 | | virtual util::Status NBestEncode( |
324 | | absl::string_view input, int nbest_size, |
325 | | std::vector<std::vector<std::string>> *pieces) const; |
326 | | |
327 | | // Same as Encode, but returns nbest results. |
328 | | virtual util::Status NBestEncode(absl::string_view input, int nbest_size, |
329 | | std::vector<std::vector<int>> *ids) const; |
330 | | |
331 | | ////////////////////////////////////////////////////////////// |
332 | | // Sampling API. |
333 | | // |
334 | | // Unigram and BPE support sampling mode. |
335 | | // - Unigram (--model_type=unigram): |
336 | | // `nbest_size`: When `nbest_size` is positive value, approximately samples |
337 | | // one segmentation from nbest candidates. When `nbest_size` is negative |
338 | | // value, samples one segmentation from the hypotheses (Lattice) according to |
339 | | // the generation probabilities using forward-filtering and backward-sampling |
340 | | // algorithm. |
341 | | // `alpha`: Smoothing parameter (inverse temperature). The best segmentation |
342 | | // (Viterbi segmentation) is more likely sampled when setting larger alpha. |
343 | | // When alpha is 0.0, one segmentation is uniformly sampled from the nbest or |
344 | | // lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha` |
345 | | // in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity) |
346 | | // |
347 | | // - BPE (--model_type=bpe): |
348 | | // `alpha`: The dropout probability `p` of bpe merge operations in |
349 | | // https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so |
350 | | // nbest_size parameter is ignored in BPE. |
351 | | virtual util::Status SampleEncode(absl::string_view input, int nbest_size, |
352 | | float alpha, |
353 | | std::vector<std::string> *pieces) const; |
354 | | |
355 | | // Same as above, but returns a sequence of ids. |
356 | | virtual util::Status SampleEncode(absl::string_view input, int nbest_size, |
357 | | float alpha, std::vector<int> *ids) const; |
358 | | |
359 | | ////////////////////////////////////////////////////////////// |
360 | | // SampleEncodeAndScore API. |
361 | | // |
362 | | // Samples `num_samples` tokenisations from the segmentation lattice. |
363 | | // These methods are only available in model_type=unigram. |
364 | | // |
365 | | // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in |
366 | | // the `SampleEncode` methods. |
367 | | // `wor`: If `wor` is true, the samples are taken without replacement, and the |
368 | | // scores are the inclusion probabilities of the elements in the sample; |
369 | | // otherwise the samples are taken with replacement and the scores are the |
370 | | // log-probs of the sample elements. |
371 | | // `include_best`: If `include_best` is true, the best tokenisation is always |
372 | | // included in the sample, and the remaining elements are sampled excluding |
373 | | // the best. |
374 | | virtual util::Status SampleEncodeAndScore( |
375 | | absl::string_view input, int num_samples, float alpha, bool wor, |
376 | | bool include_best, |
377 | | std::vector<std::pair<std::vector<std::string>, float>> *pieces) const; |
378 | | |
379 | | // Same as above, but returns a sequence of ids. |
380 | | virtual util::Status SampleEncodeAndScore( |
381 | | absl::string_view input, int num_samples, float alpha, bool wor, |
382 | | bool include_best, |
383 | | std::vector<std::pair<std::vector<int>, float>> *ids) const; |
384 | | |
385 | | ////////////////////////////////////////////////////////////// |
386 | | // Entropy API. |
387 | | // |
388 | | // This is only available in model_type=unigram. |
389 | | // Calculates the entropy of the possible tokenisations. |
390 | | virtual util::Status CalculateEntropy(absl::string_view input, float alpha, |
391 | | float *entropy) const; |
392 | | |
393 | | ////////////////////////////////////////////////////////////// |
394 | | // Advanced API returning SentencePieceText, which manages |
395 | | // utf8-byte alignments between user-input/detokenized text |
396 | | // and internal sentencepiece sequence. |
397 | | // |
398 | | // Given a UTF8 input, encodes it into SentencePieceText. |
399 | | // |
400 | | // When using these APIs, sentencepiece.pb.h header files must be included. |
401 | | // We can also use ImmutableSentencePieceText as follows. |
402 | | // |
403 | | // ImmutableSentencePieceText spt; |
404 | | // Encode("hello", spt.mutable_proto()).IgnoreError(); |
405 | | // std::cout << spt.pieces_size() << std::endl; |
406 | | virtual util::Status Encode(absl::string_view input, |
407 | | SentencePieceText *spt) const; |
408 | | |
409 | | virtual util::Status NBestEncode(absl::string_view input, int nbest_size, |
410 | | NBestSentencePieceText *nbest_spt) const; |
411 | | |
412 | | virtual util::Status SampleEncode(absl::string_view input, int nbest_size, |
413 | | float alpha, SentencePieceText *spt) const; |
414 | | |
415 | | virtual util::Status SampleEncodeAndScore( |
416 | | absl::string_view input, int num_samples, float alpha, bool wor, |
417 | | bool include_best, NBestSentencePieceText *samples_spt) const; |
418 | | |
419 | | // DEPRECATED: Remove this API and use std::vector<std::string_view> |
420 | | virtual util::Status Decode(const std::vector<std::string> &pieces, |
421 | | SentencePieceText *spt) const; |
422 | | |
423 | | virtual util::Status Decode(const std::vector<absl::string_view> &pieces, |
424 | | SentencePieceText *spt) const; |
425 | | |
426 | | virtual util::Status Decode(const std::vector<int> &ids, |
427 | | SentencePieceText *spt) const; |
428 | | #ifdef SWIG |
429 | | #define SPP_SWIG_CHECK_AND_THROW \ |
430 | | if (!status.ok()) throw status; |
431 | | #else |
432 | | #define SPP_SWIG_CHECK_AND_THROW \ |
433 | 0 | if (!status.ok()) { \ |
434 | 0 | } |
435 | | #endif // SWIG |
436 | | |
437 | | #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \ |
438 | 0 | OutType output; \ |
439 | 0 | const auto status = FuncName(__VA_ARGS__, &output); \ |
440 | 0 | SPP_SWIG_CHECK_AND_THROW; \ |
441 | 0 | return output; |
442 | | |
443 | | #define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \ |
444 | 0 | OutType output; \ |
445 | 0 | const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ |
446 | 0 | SPP_SWIG_CHECK_AND_THROW; \ |
447 | 0 | return output.SerializeAsString(); |
448 | | |
449 | | #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \ |
450 | 0 | OutType output; \ |
451 | 0 | const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ |
452 | 0 | SPP_SWIG_CHECK_AND_THROW; \ |
453 | 0 | return output; |
454 | | |
455 | | ////////////////////////////////////////////////////////////// |
456 | | // Handy methods that return the result directly. |
457 | | // These functions ignore internal errors. |
458 | | virtual std::vector<std::string> EncodeAsPieces( |
459 | 0 | absl::string_view input) const { |
460 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input); |
461 | 0 | } |
462 | | |
463 | 0 | virtual std::vector<int> EncodeAsIds(absl::string_view input) const { |
464 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input); |
465 | 0 | } |
466 | | |
467 | | virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces( |
468 | 0 | absl::string_view input, int nbest_size) const { |
469 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL( |
470 | 0 | NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size); |
471 | 0 | } |
472 | | |
473 | | virtual std::vector<std::vector<int>> NBestEncodeAsIds( |
474 | 0 | absl::string_view input, int nbest_size) const { |
475 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>, |
476 | 0 | input, nbest_size); |
477 | 0 | } |
478 | | |
479 | | virtual std::vector<std::string> SampleEncodeAsPieces(absl::string_view input, |
480 | | int nbest_size, |
481 | 0 | float alpha) const { |
482 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input, |
483 | 0 | nbest_size, alpha); |
484 | 0 | } |
485 | | |
486 | | virtual std::vector<int> SampleEncodeAsIds(absl::string_view input, |
487 | | int nbest_size, |
488 | 0 | float alpha) const { |
489 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input, |
490 | 0 | nbest_size, alpha); |
491 | 0 | } |
492 | | |
493 | | virtual std::vector<std::pair<std::vector<std::string>, float>> |
494 | | SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples, |
495 | 0 | float alpha, bool wor, bool include_best) const { |
496 | 0 | using _T = std::vector<std::pair<std::vector<std::string>, float>>; |
497 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples, |
498 | 0 | alpha, wor, include_best); |
499 | 0 | } |
500 | | |
501 | | virtual std::vector<std::pair<std::vector<int>, float>> |
502 | | SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples, |
503 | 0 | float alpha, bool wor, bool include_best) const { |
504 | 0 | using _T = std::vector<std::pair<std::vector<int>, float>>; |
505 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples, |
506 | 0 | alpha, wor, include_best); |
507 | 0 | } |
508 | | |
509 | | // DEPRECATED: Remove this API and use std::vector<std::string_view> |
510 | | virtual std::string DecodePieces( |
511 | 0 | const std::vector<std::string> &pieces) const { |
512 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces); |
513 | 0 | } |
514 | | |
515 | | virtual std::string DecodePieces( |
516 | 0 | const std::vector<absl::string_view> &pieces) const { |
517 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces); |
518 | 0 | } |
519 | | |
520 | 0 | virtual std::string DecodeIds(const std::vector<int> &ids) const { |
521 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids); |
522 | 0 | } |
523 | | |
524 | 0 | virtual float CalculateEntropy(absl::string_view text, float alpha) const { |
525 | 0 | DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha); |
526 | 0 | } |
527 | | |
528 | | ////////////////////////////////////////////////////////////// |
529 | | // SerializedProto API. (DEPRECATED). Use ImmutableProto API. |
530 | | // They are used in Python interface. Returns serialized proto. |
531 | | // In the Python module, we can get access to the full Proto after |
532 | | // deserializing the returned byte sequence. |
533 | 0 | virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const { |
534 | 0 | DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input); |
535 | 0 | } |
536 | | |
537 | | virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input, |
538 | | int nbest_size, |
539 | | float alpha) const { |
540 | | DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText, |
541 | | input, nbest_size, alpha); |
542 | | } |
543 | | |
544 | | virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input, |
545 | 0 | int nbest_size) const { |
546 | 0 | DEFINE_SPP_SERIALIZED_PROTO_IMPL( |
547 | 0 | NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size); |
548 | 0 | } |
549 | | |
550 | | virtual util::bytes SampleEncodeAndScoreAsSerializedProto( |
551 | | absl::string_view input, int num_samples, float alpha, bool wor, |
552 | 0 | bool include_best) const { |
553 | 0 | DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore, |
554 | 0 | ImmutableNBestSentencePieceText, input, |
555 | 0 | num_samples, alpha, wor, include_best); |
556 | 0 | } |
557 | | |
558 | | // TODO(taku): Remove this API and use std::vector<std::string_view> |
559 | | virtual util::bytes DecodePiecesAsSerializedProto( |
560 | 0 | const std::vector<std::string> &pieces) const { |
561 | 0 | DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, |
562 | 0 | pieces); |
563 | 0 | } |
564 | | |
565 | | virtual util::bytes DecodePiecesAsSerializedProto( |
566 | 0 | const std::vector<absl::string_view> &pieces) const { |
567 | 0 | DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, |
568 | 0 | pieces); |
569 | 0 | } |
570 | | |
571 | | virtual util::bytes DecodeIdsAsSerializedProto( |
572 | 0 | const std::vector<int> &ids) const { |
573 | 0 | DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids); |
574 | 0 | } |
575 | | |
576 | | ////////////////////////////////////////////////////////////// |
577 | | // ImmutableProto API. |
578 | | virtual ImmutableSentencePieceText EncodeAsImmutableProto( |
579 | 0 | absl::string_view input) const { |
580 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input); |
581 | 0 | } |
582 | | |
583 | | virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto( |
584 | 0 | absl::string_view input, int nbest_size, float alpha) const { |
585 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText, |
586 | 0 | input, nbest_size, alpha); |
587 | 0 | } |
588 | | |
589 | | virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto( |
590 | 0 | absl::string_view input, int nbest_size) const { |
591 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL( |
592 | 0 | NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size); |
593 | 0 | } |
594 | | |
595 | | virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto( |
596 | | absl::string_view input, int num_samples, float alpha, bool wor, |
597 | 0 | bool include_best) const { |
598 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore, |
599 | 0 | ImmutableNBestSentencePieceText, input, |
600 | 0 | num_samples, alpha, wor, include_best); |
601 | 0 | } |
602 | | |
603 | | // TODO(taku): Remove this API and use std::vector<std::string_view> |
604 | | virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto( |
605 | 0 | const std::vector<std::string> &pieces) const { |
606 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces); |
607 | 0 | } |
608 | | |
609 | | virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto( |
610 | 0 | const std::vector<absl::string_view> &pieces) const { |
611 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces); |
612 | 0 | } |
613 | | |
614 | | virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto( |
615 | 0 | const std::vector<int> &ids) const { |
616 | 0 | DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids); |
617 | 0 | } |
618 | | |
619 | | #undef DEFINE_SPP_DIRECT_FUNC_IMPL |
620 | | #undef DEFINE_SPP_SERIALIZED_PROTO_IMPL |
621 | | #undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL |
622 | | |
623 | | ////////////////////////////////////////////////////////////// |
624 | | // Normalization methods. |
625 | | |
626 | | // Normalize `input`. |
627 | | virtual util::Status Normalize(absl::string_view input, |
628 | | std::string *normalized) const; |
629 | | |
630 | | // Normalize `input`. Stores the utf8-byte offset from |
631 | | // the normalized string to the original input. |
632 | | virtual util::Status Normalize(absl::string_view input, |
633 | | std::string *normalized, |
634 | | std::vector<size_t> *norm_to_orig) const; |
635 | | |
636 | | virtual std::string Normalize(absl::string_view input) const; |
637 | | |
638 | | ////////////////////////////////////////////////////////////// |
639 | | // Vocabulary management methods. |
640 | | // |
641 | | // Returns the size of sentence pieces, which is the same as |
642 | | // the size of vocabulary for NMT. |
643 | | virtual int GetPieceSize() const; |
644 | | |
645 | | // Returns the vocab id of `piece`. |
646 | | // Returns UNK(0) if `piece` is unknown. |
647 | | virtual int PieceToId(absl::string_view piece) const; |
648 | | |
649 | | // Returns the string representation of vocab with `id`. |
650 | | virtual const std::string &IdToPiece(int id) const; |
651 | | |
652 | | // Returns the score of `id`. |
653 | | // Usually score is an emission log probability of unigram language |
654 | | // model. |
655 | | virtual float GetScore(int id) const; |
656 | | |
657 | | // Returns true if `id` is unknown symbol. |
658 | | virtual bool IsUnknown(int id) const; |
659 | | |
660 | | // Returns true if `id` is control symbol. |
661 | | virtual bool IsControl(int id) const; |
662 | | |
663 | | // Returns true if `id` is unused symbol. |
664 | | virtual bool IsUnused(int id) const; |
665 | | |
666 | | // Returns true if `id` is byte symbol. |
667 | | virtual bool IsByte(int id) const; |
668 | | |
669 | | // Returns the reserved id. |
670 | | // Returns -1 if not defined. |
671 | | |
672 | | // Returns unknown (<unk>) id. |
673 | | virtual int unk_id() const; |
674 | | |
675 | | // Returns BOS (<s>) id. |
676 | | virtual int bos_id() const; |
677 | | |
678 | | // Returns EOS (</s>) id. |
679 | | virtual int eos_id() const; |
680 | | |
681 | | // Returns PAD (<pad>) id. |
682 | | virtual int pad_id() const; |
683 | | |
684 | | ////////////////////////////////////////////////////////////// |
685 | | // Model management. |
686 | | // |
687 | | // Allows injection of a mock model instance. `model` is moved. |
688 | | void SetModel(std::unique_ptr<ModelInterface> &&model); |
689 | | |
690 | | // Allows injection of a normalizer instance. `normalizer` is moved. |
691 | | void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer); |
692 | | |
693 | | // Returns immutable model proto. Useful to obtain extended |
694 | | // or experimental parameters encoded in model_proto. |
695 | | const ModelProto &model_proto() const; |
696 | | |
697 | | // Returns the immutable model proto as std::string. |
698 | | // Useful to save the state of this instance via Python's pickle object. |
699 | | util::bytes serialized_model_proto() const; |
700 | | |
701 | | // Returns mutable normalizer_spec. |
702 | | // Updating the internal normalization during encoding/decoding is not |
703 | | // recommended and may result in unexpected behavior. Use at your own risk. |
704 | | NormalizerSpec *mutable_normalizer_spec() const; |
705 | | |
706 | | private: |
707 | | enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE }; |
708 | | |
709 | | util::Status ParseExtraOptions(absl::string_view extra_option, |
710 | | std::vector<ExtraOption> *extra_options) const; |
711 | | |
712 | | util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options, |
713 | | SentencePieceText *spt) const; |
714 | | |
715 | | util::Status PopulateSentencePieceText( |
716 | | absl::string_view input, absl::string_view normalized, |
717 | | const std::vector<size_t> &norm_to_orig, |
718 | | const std::vector<std::pair<absl::string_view, int>> &result, |
719 | | SentencePieceText *spt) const; |
720 | | |
721 | | std::unique_ptr<ModelInterface> model_; |
722 | | std::unique_ptr<normalizer::Normalizer> normalizer_; |
723 | | std::unique_ptr<normalizer::Normalizer> denormalizer_; |
724 | | |
725 | | // Underlying model protocol buffer. The same lifetime as model_. |
726 | | std::unique_ptr<ModelProto> model_proto_; |
727 | | |
728 | | std::vector<ExtraOption> encode_extra_options_; |
729 | | std::vector<ExtraOption> decode_extra_options_; |
730 | | }; |
731 | | |
732 | | // Set seed value of random generator. |
733 | | // Do not set static_cast<unsigned int>(-1), |
734 | | // as this seed is reserved for initializing from |
735 | | // std::random_device. |
736 | | void SetRandomGeneratorSeed(unsigned int seed); |
737 | | |
738 | | // Set the global log level. The default loglevel is 0. |
739 | | // The log is emitted only when min_log_level >= output_log_level. |
740 | | void SetMinLogLevel(int v); |
741 | | |
742 | | // IO related functions to absorb model formats. |
743 | | namespace io { |
744 | | // Loads `model_proto` from `filename`. |
745 | | // We can instantiate SentencePieceProcessor as follows: |
746 | | // |
747 | | // auto model_proto = absl::make_unique<ModelProto>(); |
748 | | // io::LoadModelProto("//path/spm.model", model_proto.get()); |
749 | | // SentencePieceProcessor sp; |
750 | | // CHECK_OK(sp.Load(std::move(model_proto))); |
751 | | util::Status LoadModelProto(absl::string_view, ModelProto *model_proto); |
752 | | |
753 | | // Saves `model_proto` as `filename`. |
754 | | util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto); |
755 | | } // namespace io |
756 | | } // namespace sentencepiece |
757 | | #endif // SENTENCEPIECE_PROCESSOR_H_ |