Coverage Report

Created: 2026-01-10 06:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/src/sentencepiece_processor.h
Line
Count
Source
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#ifndef SENTENCEPIECE_PROCESSOR_H_
16
#define SENTENCEPIECE_PROCESSOR_H_
17
18
#include <cstdint>
19
#include <cstring>
20
#include <memory>
21
#include <string>
22
#include <string_view>
23
#include <utility>
24
#include <vector>
25
26
#ifdef _USE_EXTERNAL_ABSL
27
#include "third_party/absl/strings/string_view.h"
28
#else  // _USE_EXTERNAL_ABSL
29
#ifndef SWIG
30
namespace absl {
31
using std::string_view;
32
}  // namespace absl
33
#endif  // SWIG
34
#endif  // _USE_EXTERNAL_ABSL
35
36
namespace sentencepiece {
37
namespace util {
38
39
// Error categories carried by util::Status. kOk (0) is, by name, the success
// value; every other enumerator names a kind of failure.
// NOTE(review): the names and numeric values look like the canonical
// absl/gRPC status codes -- confirm before relying on that for interop.
enum class StatusCode : int {
40
  kOk = 0,
41
  kCancelled = 1,
42
  kUnknown = 2,
43
  kInvalidArgument = 3,
44
  kDeadlineExceeded = 4,
45
  kNotFound = 5,
46
  kAlreadyExists = 6,
47
  kPermissionDenied = 7,
48
  kResourceExhausted = 8,
49
  kFailedPrecondition = 9,
50
  kAborted = 10,
51
  kOutOfRange = 11,
52
  kUnimplemented = 12,
53
  kInternal = 13,
54
  kUnavailable = 14,
55
  kDataLoss = 15,
56
  kUnauthenticated = 16,
57
};
58
59
// Result type returned by most SentencePieceProcessor methods.
// All state is kept behind the opaque `rep_` pointer; a null `rep_` is what
// ok() reports as success.
class Status {
60
 public:
61
  Status();
62
  ~Status();
63
  Status(StatusCode code, absl::string_view error_message);
64
  // Copying is user-declared: `rep_` is a std::unique_ptr, which is not
  // copyable on its own.
  Status(const Status &s);
65
  // Returns void, so assignments cannot be chained (a = b = c).
  void operator=(const Status &s);
66
  bool operator==(const Status &s) const;
67
  bool operator!=(const Status &s) const;
68
  // True when no error representation is attached.
  inline bool ok() const { return rep_ == nullptr; }
69
70
  void set_error_message(const char *str);
71
  const char *error_message() const;
72
0
  // Alias that forwards to error_message().
  const char *message() const { return error_message(); }
73
  StatusCode code() const;
74
  std::string ToString() const;
75
76
  // NOTE(review): presumably a no-op used to mark a status as deliberately
  // unchecked -- the definition is not in this header; confirm.
  void IgnoreError();
77
78
 private:
79
  // Opaque error payload; only forward-declared in this header.
  struct Rep;
80
  std::unique_ptr<Rep> rep_;
81
};
82
}  // namespace util
83
84
// SentencePieceProcessor:
85
// Simple and language independent tokenizer and de-tokenizer for
86
// Neural Network Machine Translation.
87
//
88
// SentencePieceProcessor provides Encode() and Decode() methods,
89
// which correspond to tokenization and de-tokenization respectively.
90
//
91
// - Encode:
92
//   Given a raw source sentence, encode it into a sequence
93
//   of pieces or vocabulary ids.
94
//
95
// - Decode:
96
//   Given a sequence of pieces or vocabulary ids, decode it
97
//   into a de-tokenized raw sentence.
98
//
99
// SentencePieceProcessor provides a lossless data conversion
100
// that allows the original raw sentence to be perfectly reconstructed
101
// from the encoded data, i.e., Decode(Encode(input)) == input.
102
// This characteristic is useful, as we can make the de-tokenization
103
// completely language independent.
104
//
105
// Usage:
106
//   SentencePieceProcessor sp;
107
//   sp.Load("//path/to/model");
108
//
109
//   vector<string> sps;
110
//   sp.Encode("hello world.", &sps).IgnoreError();
111
//
112
//   vector<int> ids;
113
//   sp.Encode("hello world.", &ids).IgnoreError();
114
//
115
//   string detok;
116
//   sp.Decode(sps, &detok).IgnoreError();
117
//   CHECK_EQ("hello world.", detok);
118
//
119
//   sp.Decode(ids, &detok).IgnoreError();
120
//   CHECK_EQ("hello world.", detok);
121
//
122
//  We can also use SentencePieceText which manages the byte-offsets
123
//  between user input (output) and internal sentence pieces.
124
//
125
//   SentencePieceText spt;
126
//   sp.Encode("hello world.", &spt);
127
//   // Emits the byte range of each piece.
128
//   for (const auto &piece : spt.pieces()) {
129
//      LOG(INFO) << piece.begin() << " " << piece.end();
130
//   }
131
//
132
//   sp.Decode({0, 1, 2, 3..}, &spt);
133
//   for (const auto &piece : spt.pieces()) {
134
//      LOG(INFO) << piece.begin() << " " << piece.end();
135
//   }
136
//
137
138
class NBestSentencePieceText;
139
class ModelInterface;
140
class SentencePieceText;
141
class ModelProto;
142
class NormalizerSpec;
143
144
namespace normalizer {
145
class Normalizer;
146
}  // namespace normalizer
147
148
#ifndef SWIGGO
149
namespace util {
150
// Redefine std::string for serialized_proto interface as Python's string is
151
// a Unicode string. We can enforce the return value to be raw byte sequence
152
// with SWIG's typemap.
153
using bytes = std::string;
154
}  // namespace util
155
#endif  // SWIGGO
156
157
class NBestSentencePieceText;
158
class ModelInterface;
159
class SentencePieceText;
160
class SentencePieceText_SentencePiece;
161
162
// Wrapper class of SentencePieceText
163
// This wrapper only allows an immutable access to the proto and
164
// hides the actual implementation of protobuf.
165
// See sentencepiece.proto for the details of this class.
166
// Read-only view of a single SentencePieceText_SentencePiece entry.
// Instances bound to a proto can only be created by
// ImmutableSentencePieceText (friend); the public constructor is the default
// one.
class ImmutableSentencePieceText_ImmutableSentencePiece {
167
 public:
168
  ImmutableSentencePieceText_ImmutableSentencePiece();
169
  ~ImmutableSentencePieceText_ImmutableSentencePiece() = default;
170
171
  const std::string &piece() const;
172
  const std::string &surface() const;
173
  uint32_t id() const;
174
  // Span of this piece. The usage notes in this header describe these as a
  // byte range; ImmutableSentencePieceText::ConvertToUnicodeSpans() is
  // documented as converting the spans to Unicode character spans.
  uint32_t begin() const;
175
  uint32_t end() const;
176
177
  friend class ImmutableSentencePieceText;
178
179
 private:
180
  explicit ImmutableSentencePieceText_ImmutableSentencePiece(
181
      const SentencePieceText_SentencePiece &sp);
182
  // Non-owning: the defaulted destructor never deletes it. Null by default
  // member initializer.
  // NOTE(review): the pointee presumably must outlive this view -- confirm
  // how callers keep the owning proto alive.
  const SentencePieceText_SentencePiece *sp_ = nullptr;
183
};
184
185
// Read-only wrapper around a SentencePieceText proto that keeps the protobuf
// type out of the public interface (only a forward declaration is needed).
class ImmutableSentencePieceText {
186
 public:
187
  ImmutableSentencePieceText();
188
  virtual ~ImmutableSentencePieceText();
189
190
  // Returns all pieces by value (a new vector of piece views per call).
  std::vector<ImmutableSentencePieceText_ImmutableSentencePiece> pieces() const;
191
192
  size_t pieces_size() const;
193
  // Note: `index` is int while pieces_size() returns size_t.
  ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const;
194
195
  const std::string &text() const;
196
  float score() const;
197
198
  util::bytes SerializeAsString() const;
199
200
  // Returns the actual mutable proto.
201
  // Do not use this outside of SentencePieceProcessor, as
202
  // it returns the raw pointer managed by the shared_ptr.
203
  SentencePieceText *mutable_proto();
204
205
  // Converts the utf8 byte spans into Unicode char span.
206
  void ConvertToUnicodeSpans();
207
208
  friend class ImmutableNBestSentencePieceText;
209
210
 private:
211
  // Private: only this class and the friend ImmutableNBestSentencePieceText
  // can wrap an existing proto.
  explicit ImmutableSentencePieceText(const SentencePieceText &spt);
212
  // NOTE(review): presumably `spt_` is the proto the accessors read and
  // `rep_` is the storage owned when this object created the proto itself --
  // the relationship is defined in the .cc file; confirm.
  const SentencePieceText *spt_ = nullptr;
213
  std::shared_ptr<SentencePieceText> rep_;
214
};
215
216
// Wrapper class of NBestSentencePieceText
217
// This wrapper only allows an immutable access to the proto and
218
// hides the actual implementation of protobuf.
219
// See sentencepiece.proto for the details of this class.
220
class ImmutableNBestSentencePieceText {
221
 public:
222
  ImmutableNBestSentencePieceText();
223
  virtual ~ImmutableNBestSentencePieceText();
224
225
  // Returns all n-best entries by value (a new vector of wrappers per call).
  std::vector<ImmutableSentencePieceText> nbests() const;
226
227
  size_t nbests_size() const;
228
  // Note: `index` is int while nbests_size() returns size_t.
  ImmutableSentencePieceText nbests(int index) const;
229
230
  util::bytes SerializeAsString() const;
231
232
  // Returns the actual mutable proto.
233
  // Do not use this outside of SentencePieceProcessor, as
234
  // it returns the raw pointer managed by the shared_ptr.
235
  NBestSentencePieceText *mutable_proto();
236
237
  // NOTE(review): presumably applies the byte-span to Unicode-span
  // conversion to every n-best entry -- definition not visible here; confirm.
  void ConvertToUnicodeSpans();
238
239
 private:
240
  // Shared ownership of the wrapped proto.
  std::shared_ptr<NBestSentencePieceText> rep_;
241
};
242
243
// Main entry point of the library: loads a model and exposes the encode,
// decode, sampling, normalization and vocabulary APIs declared below.
// Status-returning methods report failures through util::Status; the
// value-returning convenience wrappers are generated by the DEFINE_SPP_*
// macros defined inside this class.
class SentencePieceProcessor {
244
 public:
245
  SentencePieceProcessor();
246
  virtual ~SentencePieceProcessor();
247
248
  // Loads model from `filename`.
249
  // Returns a non-OK status if `filename` cannot be loaded.
250
  virtual util::Status Load(absl::string_view filename);
251
252
  // Loads model from `filename`.
253
  // Crashes if `filename` cannot be loaded.
254
  virtual void LoadOrDie(absl::string_view filename);
255
256
  // Loads model from `model_proto`.
257
  // `model_proto` is copied.
258
  virtual util::Status Load(const ModelProto &model_proto);
259
260
  // Loads model from `model_proto`.
261
  // `model_proto` is moved.
262
  virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);
263
264
  // Loads model from `serialized`, which is a string-serialized model proto.
265
  // Useful to load the model from a platform independent blob object.
266
  virtual util::Status LoadFromSerializedProto(absl::string_view serialized);
267
268
  // Returns the status. Encode/Decode methods are valid when status is OK.
269
  virtual util::Status status() const;
270
271
  // Sets encode extra_option sequence.
272
  virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option);
273
274
  // Sets decode extra_option sequence.
275
  virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option);
276
277
  //////////////////////////////////////////////////////////////
278
  // Vocabulary restriction.
279
  // Background:
280
  // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
281
282
  // Restricts the vocabulary set.
283
  // The input sentences are encoded into the tokens in `valid_vocab`.
284
  virtual util::Status SetVocabulary(
285
      const std::vector<absl::string_view> &valid_vocab);
286
287
  // Reverts the vocabulary restriction.
288
  virtual util::Status ResetVocabulary();
289
290
  // Loads the valid vocabulary set from `filename` in TSV format.
291
  // Format:  <token> <tab> <freq>.
292
  // Any token with frequency < threshold will be treated as OOV.
293
  virtual util::Status LoadVocabulary(absl::string_view filename,
294
                                      int threshold);
295
296
  //////////////////////////////////////////////////////////////
297
  // Simple Encode and Decode API.
298
  //
299
  // Given a UTF8 input, encodes it into a sequence of sentence pieces.
300
  virtual util::Status Encode(absl::string_view input,
301
                              std::vector<std::string> *pieces) const;
302
303
  // Given a UTF8 input, encodes it into a sequence of ids.
304
  virtual util::Status Encode(absl::string_view input,
305
                              std::vector<int> *ids) const;
306
307
  // Given a sequence of pieces, decodes it into a detokenized output.
308
  virtual util::Status Decode(const std::vector<std::string> &pieces,
309
                              std::string *detokenized) const;
310
311
  // Given a sequence of pieces, decodes it into a detokenized output.
312
  virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
313
                              std::string *detokenized) const;
314
315
  // Given a sequence of ids, decodes it into a detokenized output.
316
  virtual util::Status Decode(const std::vector<int> &ids,
317
                              std::string *detokenized) const;
318
319
  //////////////////////////////////////////////////////////////
320
  // NBest API.
321
  //
322
  // Same as Encode, but returns nbest results.
323
  virtual util::Status NBestEncode(
324
      absl::string_view input, int nbest_size,
325
      std::vector<std::vector<std::string>> *pieces) const;
326
327
  // Same as Encode, but returns nbest results.
328
  virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
329
                                   std::vector<std::vector<int>> *ids) const;
330
331
  //////////////////////////////////////////////////////////////
332
  // Sampling API.
333
  //
334
  // Unigram and BPE support sampling mode.
335
  // - Unigram (--model_type=unigram):
336
  // `nbest_size`: When `nbest_size` is positive value, approximately samples
337
  // one segmentation from nbest candidates. When `nbest_size` is negative
338
  // value, samples one segmentation from the hypotheses (Lattice) according to
339
  // the generation probabilities using forward-filtering and backward-sampling
340
  // algorithm.
341
  // `alpha`: Smoothing parameter (inverse temperature). The best segmentation
342
  // (Viterbi segmentation) is more likely sampled when setting larger alpha.
343
  // When alpha is 0.0, one segmentation is uniformly sampled from the nbest or
344
  // lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
345
  // in https://arxiv.org/abs/1804.10959  (nbest_size < 0 means l = infinity)
346
  //
347
  // - BPE (--model_type=bpe):
348
  // `alpha`: The dropout probability `p` of bpe merge operations in
349
  // https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so
350
  // nbest_size parameter is ignored in BPE.
351
  virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
352
                                    float alpha,
353
                                    std::vector<std::string> *pieces) const;
354
355
  // Same as above, but returns a sequence of ids.
356
  virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
357
                                    float alpha, std::vector<int> *ids) const;
358
359
  //////////////////////////////////////////////////////////////
360
  // SampleEncodeAndScore API.
361
  //
362
  // Sample `num_samples` many tokenisations from the segmentation lattice.
363
  // These methods are only available in model_type=unigram.
364
  //
365
  // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in
366
  // the `SampleEncode` method.
367
  // `wor`: If `wor` is true, the samples are taken without replacement, and the
368
  // scores are the inclusion probabilities of the elements in the sample;
369
  // otherwise the samples are taken with replacement and the scores are the
370
  // log-probs of sample elements.
371
  // `include_best`: If `include_best` is true, the best tokenisation is always
372
  // included in the sample, and the remaining elements are sampled excluding
373
  // the best.
374
  virtual util::Status SampleEncodeAndScore(
375
      absl::string_view input, int num_samples, float alpha, bool wor,
376
      bool include_best,
377
      std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;
378
379
  // Same as above, but returns a sequence of ids.
380
  virtual util::Status SampleEncodeAndScore(
381
      absl::string_view input, int num_samples, float alpha, bool wor,
382
      bool include_best,
383
      std::vector<std::pair<std::vector<int>, float>> *ids) const;
384
385
  //////////////////////////////////////////////////////////////
386
  // Entropy API.
387
  //
388
  // This is only available in model_type=unigram.
389
  // Calculates the entropy of the possible tokenisations.
390
  virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
391
                                        float *entropy) const;
392
393
  //////////////////////////////////////////////////////////////
394
  // Advanced API returning SentencePieceText, which manages
395
  // utf8-byte alignments between user-input/detokenized text
396
  // and internal sentencepiece sequence.
397
  //
398
  // Given a UTF8 input, encodes it into SentencePieceText.
399
  //
400
  // When using these APIs, sentencepiece.pb.h header files must be included.
401
  // We can also use ImmutableSentencePieceText as follows.
402
  //
403
  // ImmutableSentencePieceText spt;
404
  // Encode("hello", spt.mutable_proto()).IgnoreError();
405
  // std::cout << spt.pieces_size() << std::endl;
406
  virtual util::Status Encode(absl::string_view input,
407
                              SentencePieceText *spt) const;
408
409
  virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
410
                                   NBestSentencePieceText *nbest_spt) const;
411
412
  virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
413
                                    float alpha, SentencePieceText *spt) const;
414
415
  virtual util::Status SampleEncodeAndScore(
416
      absl::string_view input, int num_samples, float alpha, bool wor,
417
      bool include_best, NBestSentencePieceText *samples_spt) const;
418
419
  // DEPRECATED: Remove this API and use std::vector<std::string_view>
420
  virtual util::Status Decode(const std::vector<std::string> &pieces,
421
                              SentencePieceText *spt) const;
422
423
  virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
424
                              SentencePieceText *spt) const;
425
426
  virtual util::Status Decode(const std::vector<int> &ids,
427
                              SentencePieceText *spt) const;
428
  // Error hook used by the DEFINE_SPP_* body macros below. Under SWIG a
  // non-OK status is thrown; otherwise the check has an empty body, so the
  // error is dropped and the wrapper returns whatever `output` holds.
#ifdef SWIG
429
#define SPP_SWIG_CHECK_AND_THROW \
430
  if (!status.ok()) throw status;
431
#else
432
#define SPP_SWIG_CHECK_AND_THROW \
433
0
  if (!status.ok()) {            \
434
0
  }
435
#endif  // SWIG
436
437
  // Function-body template: declares `OutType output`, calls
  // FuncName(<args>, &output), runs the error hook, and returns `output`.
  // NOTE(review): `OutType output;` is default-initialized, so for a scalar
  // OutType (the float CalculateEntropy wrapper below) the returned value is
  // indeterminate if FuncName fails without writing to it. Consider
  // `OutType output{};`.
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
438
0
  OutType output;                                           \
439
0
  const auto status = FuncName(__VA_ARGS__, &output);       \
440
0
  SPP_SWIG_CHECK_AND_THROW;                                 \
441
0
  return output;
442
443
  // Same shape, but FuncName fills output.mutable_proto() and the wrapper
  // returns output.SerializeAsString().
#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...)     \
444
0
  OutType output;                                                    \
445
0
  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
446
0
  SPP_SWIG_CHECK_AND_THROW;                                          \
447
0
  return output.SerializeAsString();
448
449
  // Same shape, but the immutable wrapper object itself is returned.
#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...)      \
450
0
  OutType output;                                                    \
451
0
  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
452
0
  SPP_SWIG_CHECK_AND_THROW;                                          \
453
0
  return output;
454
455
  //////////////////////////////////////////////////////////////
456
  // Handy methods that return the result directly.
457
  // These functions ignore internal errors.
458
  virtual std::vector<std::string> EncodeAsPieces(
459
0
      absl::string_view input) const {
460
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
461
0
  }
462
463
0
  virtual std::vector<int> EncodeAsIds(absl::string_view input) const {
464
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
465
0
  }
466
467
  virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
468
0
      absl::string_view input, int nbest_size) const {
469
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(
470
0
        NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
471
0
  }
472
473
  virtual std::vector<std::vector<int>> NBestEncodeAsIds(
474
0
      absl::string_view input, int nbest_size) const {
475
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
476
0
                                input, nbest_size);
477
0
  }
478
479
  virtual std::vector<std::string> SampleEncodeAsPieces(absl::string_view input,
480
                                                        int nbest_size,
481
0
                                                        float alpha) const {
482
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
483
0
                                nbest_size, alpha);
484
0
  }
485
486
  virtual std::vector<int> SampleEncodeAsIds(absl::string_view input,
487
                                             int nbest_size,
488
0
                                             float alpha) const {
489
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
490
0
                                nbest_size, alpha);
491
0
  }
492
493
  virtual std::vector<std::pair<std::vector<std::string>, float>>
494
  SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
495
0
                               float alpha, bool wor, bool include_best) const {
496
0
    using _T = std::vector<std::pair<std::vector<std::string>, float>>;
497
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
498
0
                                alpha, wor, include_best);
499
0
  }
500
501
  virtual std::vector<std::pair<std::vector<int>, float>>
502
  SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
503
0
                            float alpha, bool wor, bool include_best) const {
504
0
    using _T = std::vector<std::pair<std::vector<int>, float>>;
505
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
506
0
                                alpha, wor, include_best);
507
0
  }
508
509
  // DEPRECATED: Remove this API and use std::vector<std::string_view>
510
  virtual std::string DecodePieces(
511
0
      const std::vector<std::string> &pieces) const {
512
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
513
0
  }
514
515
  virtual std::string DecodePieces(
516
0
      const std::vector<absl::string_view> &pieces) const {
517
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
518
0
  }
519
520
0
  virtual std::string DecodeIds(const std::vector<int> &ids) const {
521
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
522
0
  }
523
524
0
  virtual float CalculateEntropy(absl::string_view text, float alpha) const {
525
0
    DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
526
0
  }
527
528
  //////////////////////////////////////////////////////////////
529
  // SerializedProto API. (DEPRECATED). Use ImmutableProto API.
530
  // They are used in Python interface. Returns serialized proto.
531
  // In python module, we can get access to the full Proto after
532
  // deserializing the returned byte sequence.
533
0
  virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
534
0
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
535
0
  }
536
537
  virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
538
                                                    int nbest_size,
539
                                                    float alpha) const {
540
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
541
                                     input, nbest_size, alpha);
542
  }
543
544
  virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
545
0
                                                   int nbest_size) const {
546
0
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(
547
0
        NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
548
0
  }
549
550
  virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
551
      absl::string_view input, int num_samples, float alpha, bool wor,
552
0
      bool include_best) const {
553
0
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
554
0
                                     ImmutableNBestSentencePieceText, input,
555
0
                                     num_samples, alpha, wor, include_best);
556
0
  }
557
558
  // TODO(taku): Remove this API and use std::vector<std::string_view>
559
  virtual util::bytes DecodePiecesAsSerializedProto(
560
0
      const std::vector<std::string> &pieces) const {
561
0
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
562
0
                                     pieces);
563
0
  }
564
565
  virtual util::bytes DecodePiecesAsSerializedProto(
566
0
      const std::vector<absl::string_view> &pieces) const {
567
0
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
568
0
                                     pieces);
569
0
  }
570
571
  virtual util::bytes DecodeIdsAsSerializedProto(
572
0
      const std::vector<int> &ids) const {
573
0
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
574
0
  }
575
576
  //////////////////////////////////////////////////////////////
577
  // ImmutableProto API.
578
  virtual ImmutableSentencePieceText EncodeAsImmutableProto(
579
0
      absl::string_view input) const {
580
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
581
0
  }
582
583
  virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
584
0
      absl::string_view input, int nbest_size, float alpha) const {
585
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
586
0
                                    input, nbest_size, alpha);
587
0
  }
588
589
  virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
590
0
      absl::string_view input, int nbest_size) const {
591
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
592
0
        NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
593
0
  }
594
595
  virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
596
      absl::string_view input, int num_samples, float alpha, bool wor,
597
0
      bool include_best) const {
598
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
599
0
                                    ImmutableNBestSentencePieceText, input,
600
0
                                    num_samples, alpha, wor, include_best);
601
0
  }
602
603
  // TODO(taku): Remove this API and use std::vector<std::string_view>
604
  virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
605
0
      const std::vector<std::string> &pieces) const {
606
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
607
0
  }
608
609
  virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
610
0
      const std::vector<absl::string_view> &pieces) const {
611
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
612
0
  }
613
614
  virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
615
0
      const std::vector<int> &ids) const {
616
0
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
617
0
  }
618
619
  // The body macros are local to this class definition. Note that
  // SPP_SWIG_CHECK_AND_THROW is not undefined here.
#undef DEFINE_SPP_DIRECT_FUNC_IMPL
620
#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
621
#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL
622
623
  //////////////////////////////////////////////////////////////
624
  // Normalization methods.
625
626
  // Normalize `input`.
627
  virtual util::Status Normalize(absl::string_view input,
628
                                 std::string *normalized) const;
629
630
  // Normalize `input`. Stores the utf8-byte offset from
631
  // the normalized string to the original input.
632
  virtual util::Status Normalize(absl::string_view input,
633
                                 std::string *normalized,
634
                                 std::vector<size_t> *norm_to_orig) const;
635
636
  // Convenience overload that returns the normalized string directly; it has
  // no Status to report a failure through.
  virtual std::string Normalize(absl::string_view input) const;
637
638
  //////////////////////////////////////////////////////////////
639
  // Vocabulary management methods.
640
  //
641
  // Returns the size of sentence pieces, which is the same as
642
  // the size of vocabulary for NMT.
643
  virtual int GetPieceSize() const;
644
645
  // Returns the vocab id of `piece`.
646
  // Returns UNK(0) if `piece` is unknown.
647
  virtual int PieceToId(absl::string_view piece) const;
648
649
  // Returns the string representation of vocab with `id`.
650
  virtual const std::string &IdToPiece(int id) const;
651
652
  // Returns the score of `id`.
653
  // Usually score is an emission log probability of unigram language
654
  // model.
655
  virtual float GetScore(int id) const;
656
657
  // Returns true if `id` is unknown symbol.
658
  virtual bool IsUnknown(int id) const;
659
660
  // Returns true if `id` is control symbol.
661
  virtual bool IsControl(int id) const;
662
663
  // Returns true if `id` is unused symbol.
664
  virtual bool IsUnused(int id) const;
665
666
  // Returns true if `id` is byte symbol.
667
  virtual bool IsByte(int id) const;
668
669
  // Returns the reserved id.
670
  // Returns -1 if not defined.
671
672
  // Returns unknown (<unk>) id.
673
  virtual int unk_id() const;
674
675
  // Returns BOS (<s>) id.
676
  virtual int bos_id() const;
677
678
  // Returns EOS (</s>) id.
679
  virtual int eos_id() const;
680
681
  // Returns PAD (<pad>) id.
682
  virtual int pad_id() const;
683
684
  //////////////////////////////////////////////////////////////
685
  // Model management.
686
  //
687
  // Allows injection of a mock model instance. `model` is moved.
688
  void SetModel(std::unique_ptr<ModelInterface> &&model);
689
690
  // Allows injection of a normalizer instance. `normalizer` is moved.
691
  void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
692
693
  // Returns immutable model proto. Useful to obtain extended
694
  // or experimental parameters encoded in model_proto.
695
  const ModelProto &model_proto() const;
696
697
  // Returns immutable model proto as std::string.
698
  // Useful to save the state of this instance via Python's pickle object.
699
  util::bytes serialized_model_proto() const;
700
701
  // Returns mutable normalizer_spec.
702
  // Updating the internal normalization during the encoding/decoding is not
703
  // recommended and may result in unexpected behavior. Use at your own risk.
  // Note: this is a const method that hands out a non-const pointer.
704
  NormalizerSpec *mutable_normalizer_spec() const;
705
706
 private:
707
  // Options parsed from the extra_option strings passed to
  // SetEncodeExtraOptions() / SetDecodeExtraOptions().
  enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };
708
709
  util::Status ParseExtraOptions(absl::string_view extra_option,
710
                                 std::vector<ExtraOption> *extra_options) const;
711
712
  util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
713
                                 SentencePieceText *spt) const;
714
715
  util::Status PopulateSentencePieceText(
716
      absl::string_view input, absl::string_view normalized,
717
      const std::vector<size_t> &norm_to_orig,
718
      const std::vector<std::pair<absl::string_view, int>> &result,
719
      SentencePieceText *spt) const;
720
721
  std::unique_ptr<ModelInterface> model_;
722
  std::unique_ptr<normalizer::Normalizer> normalizer_;
723
  std::unique_ptr<normalizer::Normalizer> denormalizer_;
724
725
  // Underlying model protocol buffer. The same lifetime as model_.
726
  std::unique_ptr<ModelProto> model_proto_;
727
728
  std::vector<ExtraOption> encode_extra_options_;
729
  std::vector<ExtraOption> decode_extra_options_;
730
};
731
732
// Set seed value of random generator.
733
// Do not set static_cast<unsigned int>(-1),
734
// as this seed is reserved for initializing from
735
// std::random_device.
736
void SetRandomGeneratorSeed(unsigned int seed);
737
738
// Set the global log level. The default loglevel is 0.
739
// The log is emitted only when min_log_level >= output_log_level.
// NOTE(review): the direction of this comparison cannot be checked from this
// header -- verify it against the logging implementation.
740
void SetMinLogLevel(int v);
741
742
// IO related functions to absorb model formats.
743
namespace io {
744
// Loads `model_proto` from `filename`.
745
// We can instantiate SentencePieceProcessor as follows:
746
//
747
//  auto model_proto = absl::make_unique<ModelProto>();
748
//  io::LoadModelProto("//path/spm.model", model_proto.get());
749
//  SentencePieceProcessor sp;
750
//  CHECK_OK(sp.Load(std::move(model_proto)));
// The unnamed first parameter is the `filename` mentioned above.
751
util::Status LoadModelProto(absl::string_view, ModelProto *model_proto);
752
753
// Saves `model_proto` as `filename`.
// The unnamed first parameter is the `filename` mentioned above.
754
util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto);
755
}  // namespace io
756
}  // namespace sentencepiece
757
#endif  // SENTENCEPIECE_PROCESSOR_H_