Coverage Report

Created: 2025-08-26 06:02

/src/sentencepiece/src/model_interface.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#ifndef MODEL_INTERFACE_H_
16
#define MODEL_INTERFACE_H_
17
18
#include <memory>
19
#include <set>
20
#include <string>
21
#include <utility>
22
#include <vector>
23
24
#include "common.h"
25
#include "normalizer.h"
26
#include "sentencepiece_model.pb.h"
27
#include "sentencepiece_processor.h"
28
#include "third_party/absl/container/flat_hash_map.h"
29
#include "third_party/absl/strings/string_view.h"
30
#include "third_party/darts_clone/darts.h"
31
#include "util.h"
32
33
namespace sentencepiece {
34
35
// "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]
36
std::vector<absl::string_view> SplitIntoWords(
37
    absl::string_view text, bool treat_ws_as_suffix = false,
38
    bool allow_ws_only_pieces = false);
39
40
// Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>").
41
std::string ByteToPiece(unsigned char c);
42
43
// Converts piece to byte (e.g., "<0x3A>" -> 58). Returns -1 if `piece` is not
44
// a valid byte piece.
45
int PieceToByte(absl::string_view piece);
46
47
using EncodeResult = std::vector<std::pair<absl::string_view, int>>;
48
using NBestEncodeResult = std::vector<std::pair<EncodeResult, float>>;
49
50
class ModelProto;
51
52
// Underlying model interface.
53
// Given a normalized string, returns a sequence of sentence pieces with ids.
54
class ModelInterface {
55
 public:
56
  using PieceToIdMap = absl::flat_hash_map<absl::string_view, int>;
57
  //                                           string_util::string_view_hash>;
58
59
  absl::string_view unk_piece() const;
60
  absl::string_view bos_piece() const;
61
  absl::string_view eos_piece() const;
62
  absl::string_view pad_piece() const;
63
64
  // `model_proto` should not be deleted until ModelInterface is destroyed.
65
  explicit ModelInterface(const ModelProto &model_proto);
66
0
  // Default constructor. Does no work of its own, so every member keeps its
  // in-class default (model_proto_ is nullptr, unk_id_ is 0).
  ModelInterface() {}
67
68
  virtual ~ModelInterface();
69
70
  // Returns the stored status_ by value.
71
  // Encode/Decode functions are valid only when status is OK.
72
0
  virtual util::Status status() const { return status_; }
73
74
0
  // Returns a reference to the ModelProto pointed to by model_proto_. The
  // pointer is dereferenced without a null check and is nullptr by default, so
  // this must only be called once model_proto_ has been set.
  virtual const ModelProto &model_proto() const { return *model_proto_; }
75
76
0
  // Returns a non-owning pointer to the PrefixMatcher held by matcher_
  // (ownership stays with this object), or nullptr if matcher_ is empty.
  virtual const normalizer::PrefixMatcher *prefix_matcher() const {
77
0
    return matcher_.get();
78
0
  }
79
80
  // Given a normalized string, returns a sequence of sentence pieces with ids.
81
  // The concatenation of pieces must be the same as `normalized`.
82
  virtual EncodeResult Encode(absl::string_view normalized) const = 0;
83
84
  // Like Encode(), but returns nbest results, each paired with a score.
  // This base implementation ignores both arguments: it logs "Not implemented."
  // via LOG(ERROR) and returns an empty NBestEncodeResult.
85
  virtual NBestEncodeResult NBestEncode(absl::string_view normalized,
86
0
                                        int nbest_size) const {
87
0
    LOG(ERROR) << "Not implemented.";
88
0
    return NBestEncodeResult();
89
0
  }
90
91
  // Sampling counterpart of Encode(). This base implementation ignores
  // `normalized` and `alpha`: it logs "Not implemented." via LOG(ERROR) and
  // returns an empty EncodeResult.
  // NOTE(review): the exact meaning of `alpha` is defined by overriding
  // models and is not visible here -- confirm against the implementations.
  virtual EncodeResult SampleEncode(absl::string_view normalized,
92
0
                                    float alpha) const {
93
0
    LOG(ERROR) << "Not implemented.";
94
0
    return EncodeResult();
95
0
  }
96
97
  // Sample `samples` many tokenisations from the segmentation lattice.
98
  // If `wor` is true, the samples are taken without replacement, and the scores
99
  // are the inclusion probabilities of the elements in the sample; otherwise
100
  // the samples are taken with replacement and the scores are the log-probs of
101
  // sample elements.
102
  // If `include_best` is true, the best tokenisation is always included in the
103
  // sample, and the remaining elements are sampled excluding the best.
104
  // This base implementation ignores all arguments: it logs "Not implemented."
  // via LOG(ERROR) and returns a single entry made of an empty EncodeResult
  // and a score of 0.
  virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
105
                                                 float alpha, int samples,
106
                                                 bool wor,
107
0
                                                 bool include_best) const {
108
0
    LOG(ERROR) << "Not implemented.";
109
0
    return {{EncodeResult(), 0.0}};
110
0
  }
111
112
  // Calculates the entropy of the segmentation lattice with inverse temperature
113
  // `alpha`. Uses a novel dynamic program to calculate the entropy.
  // This base implementation ignores both arguments: it logs "Not implemented."
  // via LOG(ERROR) and returns 0.0.
114
  virtual float CalculateEntropy(absl::string_view normalized,
115
0
                                 float alpha) const {
116
0
    LOG(ERROR) << "Not implemented.";
117
0
    return 0.0;
118
0
  }
119
120
  // Returns true if SampleEncode returns a valid result. Always false here.
121
0
  virtual bool IsSampleEncodeAvailable() const { return false; }
122
123
  // Returns true if NBestEncode returns a valid result. Always false here.
124
0
  virtual bool IsNBestEncodeAvailable() const { return false; }
125
126
  // Returns true if SampleEncodeAndScore returns a valid result. Always false
  // here.
127
0
  virtual bool IsSampleEncodeAndScoreAvailable() const { return false; }
128
129
  // Returns true if CalculateEntropy returns a valid result. Always false here.
130
0
  virtual bool IsCalculateEntropyAvailable() const { return false; }
131
132
  // Returns the vocab id of `piece`.
133
  // Returns UNK(0) if `piece` is unknown.
134
  virtual int PieceToId(absl::string_view piece) const;
135
136
  // Returns the string representation of the vocab entry with `id`.
137
  // `id` must satisfy 0 <= id < GetPieceSize(). Neither the range nor
  // model_proto_ being non-null is checked here; `id` is passed straight to
  // model_proto_->pieces(id).
  // NOTE(review): the returned reference presumably points into the
  // ModelProto and so is valid only while that proto is alive -- confirm.
138
0
  virtual const std::string &IdToPiece(int id) const {
139
0
    return model_proto_->pieces(id).piece();
140
0
  }
141
142
  // Returns the size of sentence pieces, which is the same
143
  // as the size of vocabulary for NMT. Returns 0 when model_proto_ is null.
144
0
  virtual int GetPieceSize() const {
145
0
    if (!model_proto_) return 0;
146
0
    return model_proto_->pieces_size();
147
0
  }
148
149
  // Returns the score of the piece with `id`.
150
  // Score represents a log probability of the piece.
151
  // We can roughly estimate the unigram frequency of the piece.
  // `id` is passed straight to model_proto_->pieces(id); neither its range nor
  // model_proto_ being non-null is checked here.
152
0
  virtual float GetScore(int id) const {
153
0
    return model_proto_->pieces(id).score();
154
0
  }
155
156
  // Returns true if the piece with `id` has type UNKNOWN. `id` is passed
  // straight to model_proto_->pieces(id) with no range or null check.
157
0
  virtual bool IsUnknown(int id) const {
158
0
    return (model_proto_->pieces(id).type() ==
159
0
            ModelProto::SentencePiece::UNKNOWN);
160
0
  }
161
162
  // Returns true if the piece with `id` has type CONTROL. `id` is passed
  // straight to model_proto_->pieces(id) with no range or null check.
163
0
  virtual bool IsControl(int id) const {
164
0
    return (model_proto_->pieces(id).type() ==
165
0
            ModelProto::SentencePiece::CONTROL);
166
0
  }
167
168
  // Returns true if the piece with `id` has type UNUSED. `id` is passed
  // straight to model_proto_->pieces(id) with no range or null check.
169
0
  virtual bool IsUnused(int id) const {
170
0
    return (model_proto_->pieces(id).type() ==
171
0
            ModelProto::SentencePiece::UNUSED);
172
0
  }
173
174
  // Returns true if the piece with `id` has type USER_DEFINED. `id` is passed
  // straight to model_proto_->pieces(id) with no range or null check.
175
0
  virtual bool IsUserDefined(int id) const {
176
0
    return (model_proto_->pieces(id).type() ==
177
0
            ModelProto::SentencePiece::USER_DEFINED);
178
0
  }
179
180
  // Returns true if the piece with `id` has type BYTE. `id` is passed
  // straight to model_proto_->pieces(id) with no range or null check.
181
0
  virtual bool IsByte(int id) const {
182
0
    return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::BYTE);
183
0
  }
184
185
0
  // Returns true if a model proto is set and its trainer_spec() has
  // byte_fallback() enabled. Returns false when model_proto_ is null.
  virtual bool ByteFallbackEnabled() const {
186
0
    return model_proto_ && model_proto_->trainer_spec().byte_fallback();
187
0
  }
188
189
  // Verifies if the `expected` and `actual` outputs are equivalent. `expected`
190
  // and `actual` are sentence pieces joined by space (` `). Normally it means
191
  // that the two strings are identical. In some models, due to float rounding
192
  // errors, the strings may not be identical, but they may still be equivalent
193
  // provided their scores are close enough (by some epsilon).
  // This base implementation requires the two strings to compare equal.
194
  virtual bool VerifyOutputsEquivalent(absl::string_view expected,
195
0
                                       absl::string_view actual) const {
196
0
    return expected == actual;
197
0
  }
198
199
 protected:
200
  void InitializePieces();
201
202
  // Non-virtual (inlined) implementation for faster execution.
  // Returns the score of the piece with `id`, read directly from
  // model_proto_->pieces(id) with no range or null check.
203
0
  inline float GetScoreInlined(int id) const {
204
0
    return model_proto_->pieces(id).score();
205
0
  }
206
207
0
  // Non-virtual check: true if the piece with `id` has type UNKNOWN.
  // Reads model_proto_->pieces(id) with no range or null check.
  inline bool IsUnknownInlined(int id) const {
208
0
    return (model_proto_->pieces(id).type() ==
209
0
            ModelProto::SentencePiece::UNKNOWN);
210
0
  }
211
212
0
  // Non-virtual check: true if the piece with `id` has type CONTROL.
  // Reads model_proto_->pieces(id) with no range or null check.
  inline bool IsControlInlined(int id) const {
213
0
    return (model_proto_->pieces(id).type() ==
214
0
            ModelProto::SentencePiece::CONTROL);
215
0
  }
216
217
0
  // Non-virtual check: true if the piece with `id` has type UNUSED.
  // Reads model_proto_->pieces(id) with no range or null check.
  inline bool IsUnusedInlined(int id) const {
218
0
    return (model_proto_->pieces(id).type() ==
219
0
            ModelProto::SentencePiece::UNUSED);
220
0
  }
221
222
0
  // Non-virtual check: true if the piece with `id` has type USER_DEFINED.
  // Reads model_proto_->pieces(id) with no range or null check.
  inline bool IsUserDefinedInlined(int id) const {
223
0
    return (model_proto_->pieces(id).type() ==
224
0
            ModelProto::SentencePiece::USER_DEFINED);
225
0
  }
226
227
0
  // Non-virtual check: true if the piece with `id` has type BYTE.
  // Reads model_proto_->pieces(id) with no range or null check.
  inline bool IsByteInlined(int id) const {
228
0
    return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::BYTE);
229
0
  }
230
231
  const ModelProto *model_proto_ = nullptr;
232
233
  // PrefixMatcher for user defined symbols.
234
  std::unique_ptr<normalizer::PrefixMatcher> matcher_;
235
236
  // piece -> id map for normal pieces
237
  PieceToIdMap pieces_;
238
239
  // piece -> id map for control, unknown, and byte pieces
240
  PieceToIdMap reserved_id_map_;
241
242
  // Id of the unknown piece; initialized to 0.
243
  int unk_id_ = 0;
244
245
  // status.
246
  util::Status status_;
247
};
248
}  // namespace sentencepiece
249
#endif  // MODEL_INTERFACE_H_