Coverage Report

Created: 2026-05-04 07:01

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/build/root/include/sentencepiece_trainer.h
Line
Count
Source
1
// Copyright 2018 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#ifndef SENTENCEPIECE_TRAINER_H_
16
#define SENTENCEPIECE_TRAINER_H_
17
18
#include <string>
19
#include <unordered_map>
20
#include <vector>
21
22
#include "sentencepiece_processor.h"
23
24
namespace sentencepiece {
25
26
class TrainerSpec;
27
class NormalizerSpec;
28
class ModelProto;
29
30
namespace pretokenizer {
31
class PretokenizerForTrainingInterface;
32
}  // namespace pretokenizer
33
34
namespace normalizer {
35
class Normalizer;
36
}  // namespace normalizer
37
38
// Iterator over the training sentences.
39
// Training sentences are loaded sequentially as follows:
40
//
41
// for (; !it.done(); it.Next()) {
42
//    const std::string &s = it.value();
43
// }
44
// RETURN_IF_ERROR(it.status());
45
//
46
class SentenceIterator {
47
 public:
48
  virtual ~SentenceIterator() {}
49
  // Returns true if iteration finishes (including error case).
50
  // Use the SentenceIterator::status() method to check whether
51
  // all sentences are loaded successfully.
52
  virtual bool done() const = 0;
53
  virtual void Next() = 0;
54
  virtual const std::string &value() const = 0;
55
  virtual util::Status status() const = 0;
56
};
57
58
class SentencePieceTrainer {
59
 public:
60
  // Trains SentencePiece model with `trainer_spec`.
61
  // Default `normalizer_spec` is used.
62
  // When `sentence_iterator` is passed, load sentences from the iterator.
63
  static util::Status Train(const TrainerSpec &trainer_spec,
64
                            SentenceIterator *sentence_iterator = nullptr,
65
                            std::string *serialized_model_proto = nullptr);
66
67
  // Trains SentencePiece model with `trainer_spec` and
68
  // `normalizer_spec`.
69
  // When `sentence_iterator` is passed, load sentences from the iterator.
70
  static util::Status Train(const TrainerSpec &trainer_spec,
71
                            const NormalizerSpec &normalizer_spec,
72
                            SentenceIterator *sentence_iterator = nullptr,
73
                            std::string *serialized_model_proto = nullptr);
74
75
  // Trains SentencePiece model with `trainer_spec`, `normalizer_spec`
76
  // and `denormalizer_spec`.
77
  // When `sentence_iterator` is passed, load sentences from the iterator.
78
  static util::Status Train(const TrainerSpec &trainer_spec,
79
                            const NormalizerSpec &normalizer_spec,
80
                            const NormalizerSpec &denormalizer_spec,
81
                            SentenceIterator *sentence_iterator = nullptr,
82
                            std::string *serialized_model_proto = nullptr);
83
  // Trains SentencePiece model with command-line string in `args`,
84
  // e.g.,
85
  // '--input=data --model_prefix=m --vocab_size=8192 --model_type=unigram'
86
  // When `sentence_iterator` is passed, load sentences from the iterator.
87
  static util::Status Train(absl::string_view args,
88
                            SentenceIterator *sentence_iterator = nullptr,
89
                            std::string *serialized_model_proto = nullptr);
90
91
  // Trains SentencePiece model with the map in `kwargs`.
92
  // e.g., {{"input", "data"}, {"model_prefix", "m"}, {"vocab_size", "8192"}...}
93
  static util::Status Train(
94
      const std::unordered_map<std::string, std::string> &kwargs,
95
      SentenceIterator *sentence_iterator = nullptr,
96
      std::string *serialized_model_proto = nullptr);
97
98
  // The same as above, but passes the list of sentences.
99
  static util::Status Train(absl::string_view args,
100
                            const std::vector<std::string> &sentences,
101
                            std::string *serialized_model_proto = nullptr);
102
103
  // The same as above, but passes the list of sentences.
104
  static util::Status Train(
105
      const std::unordered_map<std::string, std::string> &kwargs,
106
      const std::vector<std::string> &sentences,
107
      std::string *serialized_model_proto = nullptr);
108
109
  // Handy function to make a normalizer spec from the pre-compiled
110
  // normalization name. Do not use this method in production as it crashes
111
  // when `name` is invalid. Useful for unittesting.
112
  static NormalizerSpec GetNormalizerSpec(absl::string_view name);
113
114
  // Populates necessary fields (precompiled_charmap) from
115
  // `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`.
116
  static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec,
117
                                             bool is_denormalizer = false);
118
119
  // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the
120
  // std::unordered_map in `kwargs`.
121
  static util::Status MergeSpecsFromArgs(
122
      const std::unordered_map<std::string, std::string> &kwargs,
123
      TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec,
124
      NormalizerSpec *denormalizer_spec);
125
126
  // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the
127
  // command line flags in `args`.
128
  static util::Status MergeSpecsFromArgs(absl::string_view args,
129
                                         TrainerSpec *trainer_spec,
130
                                         NormalizerSpec *normalizer_spec,
131
                                         NormalizerSpec *denormalizer_spec);
132
133
  // Injects a global pre-tokenizer that is applied at training time.
134
  // Pretokenizer is only used for extracting pieces.
135
  // TODO(taku): It would be better to inject per `trainer_spec`.
136
  static util::Status SetPretokenizerForTraining(
137
      const pretokenizer::PretokenizerForTrainingInterface *pretokenizer);
138
139
  // Returns the current pretokenizer. If no pretokenizer is defined, returns
140
  // nullptr.
141
  static const pretokenizer::PretokenizerForTrainingInterface *
142
  GetPretokenizerForTraining();
143
144
  // Helper function to set `field_name=value` in `message`.
145
  // When `field_name` is repeated, multiple values can be passed
146
  // with comma-separated values. `field_name` must not be a nested message.
147
  // The bodies of these functions are automatically generated with
148
  // data/gen_spec_parser.pl
149
  static util::Status SetProtoField(absl::string_view name,
150
                                    absl::string_view value,
151
                                    TrainerSpec *message);
152
153
  static util::Status SetProtoField(absl::string_view name,
154
                                    absl::string_view value,
155
                                    NormalizerSpec *message);
156
157
  // Populates model type from string representation, e.g., "bpe".
158
  // Supported model: "unigram", "bpe", "word", "char".
159
  static util::Status PopulateModelTypeFromString(absl::string_view type,
160
                                                  TrainerSpec *trainer_spec);
161
162
 private:
163
0
  SentencePieceTrainer() {}
164
0
  ~SentencePieceTrainer() {}
165
};
166
167
class SentencePieceNormalizer {
168
 public:
169
  SentencePieceNormalizer();
170
  virtual ~SentencePieceNormalizer();
171
172
  virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);
173
174
  virtual util::Status Load(absl::string_view filename);
175
176
  virtual util::Status LoadFromSerializedProto(absl::string_view serialized);
177
178
  virtual util::Status LoadFromRuleTSV(absl::string_view filename);
179
180
  virtual util::Status LoadFromRuleName(absl::string_view name);
181
182
  virtual util::Status Normalize(absl::string_view input,
183
                                 std::string *normalized) const;
184
185
  virtual util::Status Normalize(absl::string_view input,
186
                                 std::string *normalized,
187
                                 std::vector<size_t> *norm_to_orig) const;
188
189
  virtual std::string Normalize(absl::string_view input) const;
190
191
  virtual NormalizerSpec *mutable_normalizer_spec() const;
192
193
  virtual std::string serialized_model_proto() const;
194
195
 private:
196
  std::unique_ptr<normalizer::Normalizer> normalizer_;
197
  std::unique_ptr<ModelProto> model_proto_;
198
};
199
200
// Converts the UTF-8 byte spans into Unicode character spans.
201
void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm,
202
                               std::vector<size_t> *norm_to_orig);
203
204
// Sets data dir including the pre-compiled normalization data.
205
// The implementation is found in util.cc
206
void SetDataDir(absl::string_view data_dir);
207
208
}  // namespace sentencepiece
209
210
#endif  // SENTENCEPIECE_TRAINER_H_