/src/sentencepiece/src/model_interface.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef MODEL_INTERFACE_H_ |
16 | | #define MODEL_INTERFACE_H_ |
17 | | |
18 | | #include <memory> |
19 | | #include <set> |
20 | | #include <string> |
21 | | #include <utility> |
22 | | #include <vector> |
23 | | |
24 | | #include "common.h" |
25 | | #include "normalizer.h" |
26 | | #include "sentencepiece_model.pb.h" |
27 | | #include "sentencepiece_processor.h" |
28 | | #include "third_party/absl/container/flat_hash_map.h" |
29 | | #include "third_party/absl/strings/string_view.h" |
30 | | #include "third_party/darts_clone/darts.h" |
31 | | #include "util.h" |
32 | | |
33 | | namespace sentencepiece { |
34 | | |
35 | | // Splits `text` into words, e.g. "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]. |
36 | | std::vector<absl::string_view> SplitIntoWords( |
37 | | absl::string_view text, bool treat_ws_as_suffix = false, |
38 | | bool allow_ws_only_pieces = false); |
39 | | |
40 | | // Converts a byte (0-255) to its piece representation (e.g., 58 -> "<0x3A>"). |
41 | | std::string ByteToPiece(unsigned char c); |
42 | | |
43 | | // Converts a piece to its byte value (e.g., "<0x3A>" -> 58). Returns -1 if `piece` is not |
44 | | // a valid byte piece. |
45 | | int PieceToByte(absl::string_view piece); |
46 | | |
47 | | using EncodeResult = std::vector<std::pair<absl::string_view, int>>; |
48 | | using NBestEncodeResult = std::vector<std::pair<EncodeResult, float>>; |
49 | | |
50 | | class ModelProto; |
51 | | |
52 | | // Underlying model interface (abstract: Encode() is pure virtual). |
53 | | // Given a normalized string, returns a sequence of sentence pieces with ids. |
54 | | class ModelInterface { |
55 | | public: |
56 | | using PieceToIdMap = absl::flat_hash_map<absl::string_view, int>; |
57 | | // NOTE(review): leftover fragment `string_util::string_view_hash>;` -- presumably from an older map type with a custom hasher. |
58 | | |
59 | | absl::string_view unk_piece() const; |
60 | | absl::string_view bos_piece() const; |
61 | | absl::string_view eos_piece() const; |
62 | | absl::string_view pad_piece() const; |
63 | | |
64 | | // `model_proto` must outlive this object: it should not be deleted until ModelInterface is destroyed. |
65 | | explicit ModelInterface(const ModelProto &model_proto); |
66 | 0 | ModelInterface() {} |
67 | | |
68 | | virtual ~ModelInterface(); |
69 | | |
70 | | // Returns the stored status (status_). |
71 | | // Encode/Decode functions are valid only when status is OK. |
72 | 0 | virtual util::Status status() const { return status_; } |
73 | | |
74 | 0 | virtual const ModelProto &model_proto() const { return *model_proto_; } |
75 | | |
76 | 0 | virtual const normalizer::PrefixMatcher *prefix_matcher() const { |
77 | 0 | return matcher_.get(); |
78 | 0 | } |
79 | | |
80 | | // Given a normalized string, returns a sequence of sentence pieces with ids. |
81 | | // The concatenation of pieces must be the same as `normalized`. |
82 | | virtual EncodeResult Encode(absl::string_view normalized) const = 0; |
83 | | |
84 | | // Same as Encode(), but returns n-best results with scores. This default logs an error and returns an empty result. |
85 | | virtual NBestEncodeResult NBestEncode(absl::string_view normalized, |
86 | 0 | int nbest_size) const { |
87 | 0 | LOG(ERROR) << "Not implemented."; |
88 | 0 | return NBestEncodeResult(); |
89 | 0 | } |
90 | | |
91 | | virtual EncodeResult SampleEncode(absl::string_view normalized, |
92 | 0 | float alpha) const { |
93 | 0 | LOG(ERROR) << "Not implemented."; |
94 | 0 | return EncodeResult(); |
95 | 0 | } |
96 | | |
97 | | // Samples `samples` many tokenisations from the segmentation lattice. |
98 | | // If `wor` is true, the samples are taken without replacement, and the scores |
99 | | // are the inclusion probabilities of the elements in the sample; otherwise |
100 | | // the samples are taken with replacement and the scores are the log-probs of |
101 | | // the sample elements. |
102 | | // If `include_best` is true, the best tokenisation is always included in the |
103 | | // sample, and the remaining elements are sampled excluding the best. This default logs an error and returns one empty result with score 0.0. |
104 | | virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, |
105 | | float alpha, int samples, |
106 | | bool wor, |
107 | 0 | bool include_best) const { |
108 | 0 | LOG(ERROR) << "Not implemented."; |
109 | 0 | return {{EncodeResult(), 0.0}}; |
110 | 0 | } |
111 | | |
112 | | // Calculates the entropy of the segmentation lattice with inverse temperature |
113 | | // `alpha`. Uses a novel dynamic program to calculate the entropy. This default logs an error and returns 0.0. |
114 | | virtual float CalculateEntropy(absl::string_view normalized, |
115 | 0 | float alpha) const { |
116 | 0 | LOG(ERROR) << "Not implemented."; |
117 | 0 | return 0.0; |
118 | 0 | } |
119 | | |
120 | | // Returns true if SampleEncode returns a valid result. This default returns false. |
121 | 0 | virtual bool IsSampleEncodeAvailable() const { return false; } |
122 | | |
123 | | // Returns true if NBestEncode returns a valid result. This default returns false. |
124 | 0 | virtual bool IsNBestEncodeAvailable() const { return false; } |
125 | | |
126 | | // Returns true if SampleEncodeAndScore returns a valid result. This default returns false. |
127 | 0 | virtual bool IsSampleEncodeAndScoreAvailable() const { return false; } |
128 | | |
129 | | // Returns true if CalculateEntropy returns a valid result. This default returns false. |
130 | 0 | virtual bool IsCalculateEntropyAvailable() const { return false; } |
131 | | |
132 | | // Returns the vocab id of `piece`. |
133 | | // Returns UNK(0) if `piece` is unknown. |
134 | | virtual int PieceToId(absl::string_view piece) const; |
135 | | |
136 | | // Returns the string representation of vocab with `id`. |
137 | | // id must be 0 <= id < GetPieceSize(). Note: model_proto_ is dereferenced without a null check. |
138 | 0 | virtual const std::string &IdToPiece(int id) const { |
139 | 0 | return model_proto_->pieces(id).piece(); |
140 | 0 | } |
141 | | |
142 | | // Returns the number of sentence pieces, which is the same |
143 | | // as the size of vocabulary for NMT. Returns 0 if no model proto is set. |
144 | 0 | virtual int GetPieceSize() const { |
145 | 0 | if (!model_proto_) return 0; |
146 | 0 | return model_proto_->pieces_size(); |
147 | 0 | } |
148 | | |
149 | | // Returns the score of `id`. |
150 | | // Score represents a log probability of the piece. |
151 | | // We can roughly estimate the unigram frequency of the piece from it. |
152 | 0 | virtual float GetScore(int id) const { |
153 | 0 | return model_proto_->pieces(id).score(); |
154 | 0 | } |
155 | | |
156 | | // Returns true if the piece with `id` has type UNKNOWN. |
157 | 0 | virtual bool IsUnknown(int id) const { |
158 | 0 | return (model_proto_->pieces(id).type() == |
159 | 0 | ModelProto::SentencePiece::UNKNOWN); |
160 | 0 | } |
161 | | |
162 | | // Returns true if the piece with `id` has type CONTROL. |
163 | 0 | virtual bool IsControl(int id) const { |
164 | 0 | return (model_proto_->pieces(id).type() == |
165 | 0 | ModelProto::SentencePiece::CONTROL); |
166 | 0 | } |
167 | | |
168 | | // Returns true if the piece with `id` has type UNUSED. |
169 | 0 | virtual bool IsUnused(int id) const { |
170 | 0 | return (model_proto_->pieces(id).type() == |
171 | 0 | ModelProto::SentencePiece::UNUSED); |
172 | 0 | } |
173 | | |
174 | | // Returns true if the piece with `id` has type USER_DEFINED. |
175 | 0 | virtual bool IsUserDefined(int id) const { |
176 | 0 | return (model_proto_->pieces(id).type() == |
177 | 0 | ModelProto::SentencePiece::USER_DEFINED); |
178 | 0 | } |
179 | | |
180 | | // Returns true if the piece with `id` has type BYTE. |
181 | 0 | virtual bool IsByte(int id) const { |
182 | 0 | return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::BYTE); |
183 | 0 | } |
184 | | |
185 | 0 | virtual bool ByteFallbackEnabled() const { |
186 | 0 | return model_proto_ && model_proto_->trainer_spec().byte_fallback(); |
187 | 0 | } |
188 | | |
189 | | // Verifies if the `expected` and `actual` outputs are equivalent. `expected` |
190 | | // and `actual` are sentence pieces joined by space (` `). Normally it means |
191 | | // that the two strings are identical. In some models, due to float rounding |
192 | | // errors, the strings may not be identical, but they may be still equivalent |
193 | | // provided their scores are close enough (by some epsilon). This default compares the two strings exactly. |
194 | | virtual bool VerifyOutputsEquivalent(absl::string_view expected, |
195 | 0 | absl::string_view actual) const { |
196 | 0 | return expected == actual; |
197 | 0 | } |
198 | | |
199 | | protected: |
200 | | void InitializePieces(); |
201 | | |
202 | | // Non-virtual (inlined) counterparts of the accessors above, for faster execution. |
203 | 0 | inline float GetScoreInlined(int id) const { |
204 | 0 | return model_proto_->pieces(id).score(); |
205 | 0 | } |
206 | | |
207 | 0 | inline bool IsUnknownInlined(int id) const { |
208 | 0 | return (model_proto_->pieces(id).type() == |
209 | 0 | ModelProto::SentencePiece::UNKNOWN); |
210 | 0 | } |
211 | | |
212 | 0 | inline bool IsControlInlined(int id) const { |
213 | 0 | return (model_proto_->pieces(id).type() == |
214 | 0 | ModelProto::SentencePiece::CONTROL); |
215 | 0 | } |
216 | | |
217 | 0 | inline bool IsUnusedInlined(int id) const { |
218 | 0 | return (model_proto_->pieces(id).type() == |
219 | 0 | ModelProto::SentencePiece::UNUSED); |
220 | 0 | } |
221 | | |
222 | 0 | inline bool IsUserDefinedInlined(int id) const { |
223 | 0 | return (model_proto_->pieces(id).type() == |
224 | 0 | ModelProto::SentencePiece::USER_DEFINED); |
225 | 0 | } |
226 | | |
227 | 0 | inline bool IsByteInlined(int id) const { |
228 | 0 | return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::BYTE); |
229 | 0 | } |
230 | | |
231 | | const ModelProto *model_proto_ = nullptr; |
232 | | |
233 | | // PrefixMatcher for user defined symbols. |
234 | | std::unique_ptr<normalizer::PrefixMatcher> matcher_; |
235 | | |
236 | | // piece -> id map for normal pieces. |
237 | | PieceToIdMap pieces_; |
238 | | |
239 | | // piece -> id map for control, unknown, and byte pieces. |
240 | | PieceToIdMap reserved_id_map_; |
241 | | |
242 | | // Id of the unknown piece; defaults to 0. |
243 | | int unk_id_ = 0; |
244 | | |
245 | | // Status returned by status(). |
246 | | util::Status status_; |
247 | | }; |
248 | | } // namespace sentencepiece |
249 | | #endif // MODEL_INTERFACE_H_ |