/src/sentencepiece/src/unigram_model.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef UNIGRAM_MODEL_H_ |
16 | | #define UNIGRAM_MODEL_H_ |
17 | | |
18 | | #include <memory> |
19 | | #include <string> |
20 | | #include <utility> |
21 | | #include <vector> |
22 | | |
23 | | #include "common.h" |
24 | | #include "freelist.h" |
25 | | #include "model_interface.h" |
26 | | #include "sentencepiece_model.pb.h" |
27 | | #include "third_party/darts_clone/darts.h" |
28 | | |
29 | | namespace sentencepiece { |
30 | | namespace unigram { |
31 | | |
32 | | // Lattice represents a search space of sentence piece segmentation. |
33 | | class Lattice { |
34 | | public: |
35 | | Lattice(); |
36 | | virtual ~Lattice(); |
37 | | |
38 | | struct Node { |
39 | | absl::string_view piece; // Sentence piece representation. |
40 | | uint32 pos; // Unicode position in the sentence. |
41 | | uint32 length; // Unicode length, not UT8 byte. |
42 | | uint32 node_id; // unique id in the current lattice. |
43 | | int id; // vocab id. (maybe -1 for UNK) |
44 | | float score; // logprob of this sentencepiece. |
45 | | float backtrace_score; // backtrace info used in Viterbi. |
46 | | Node *prev; // best previous node on Viterbi path. |
47 | | |
48 | | std::string DebugString() const; |
49 | | }; |
50 | | |
51 | | // Returns bos node. |
52 | | Node *bos_node() const; |
53 | | |
54 | | // Returns eos node. |
55 | | Node *eos_node() const; |
56 | | |
57 | | // Returns nodes starting at |pos|. |
58 | | const std::vector<Node *> &begin_nodes(int pos) const; |
59 | | |
60 | | // Returns nodes ending at |pos|. |
61 | | const std::vector<Node *> &end_nodes(int pos) const; |
62 | | |
63 | | // Returns Unicode character length. |
64 | | int size() const; |
65 | | |
66 | | // Returns multi-byte (utf8) length. |
67 | | int utf8_size() const; |
68 | | |
69 | | // Returns the substring of sentence. sentence[pos:] |
70 | | const char *surface(int pos) const; |
71 | | |
72 | | // Returns immutable sentence. The same as surface(0) |
73 | | const char *sentence() const; |
74 | | |
75 | | // Clears the lattice. |
76 | | void Clear(); |
77 | | |
78 | | // Sets new sentence. |
79 | | void SetSentence(absl::string_view sentence); |
80 | | |
81 | | // Inserts a new node at [pos, pos + length - 1]. |
82 | | // After calling this method, The caller must set Node::score and Node::id. |
83 | | Node *Insert(int pos, int length); |
84 | | |
85 | | using LatticePathWithScore = std::pair<std::vector<Node *>, float>; |
86 | | |
87 | | // Returns Viterbi path. All nodes must be populated in advance. |
88 | | LatticePathWithScore Viterbi(); |
89 | | |
90 | | // Runs forwards/backwards algorithm, returns vector with normalised |
91 | | // transition probs. |
92 | | std::vector<float> ForwardAlgorithm(float theta) const; |
93 | | std::vector<float> BackwardAlgorithm(float theta) const; |
94 | | |
95 | | // Returns n-best results. |
96 | | std::vector<LatticePathWithScore> NBest(size_t nbest_size, bool sample, |
97 | | float theta); |
98 | | |
99 | | // Samples one path from the lattice according to the |
100 | | // generation probability (Product of piece probabilities). |
101 | | // `theta` is a smoothing parameter. |
102 | | std::vector<Node *> Sample(float theta); |
103 | | |
104 | | // Calculates the entropy of the lattice. |
105 | | float CalculateEntropy(float theta) const; |
106 | | |
107 | | // Populates marginal probability of every node in this lattice. |
108 | | // |freq| is the frequency of the sentence. |
109 | | // for (auto *node : all_nodes_) { |
110 | | // (*expected)[node->id] += marginal_prob_of_node * freq; |
111 | | // } |
112 | | // Returns the log-likelihood of this sentence. |
113 | | float PopulateMarginal(float freq, std::vector<float> *expected) const; |
114 | | |
115 | | private: |
116 | | // Returns new node. |
117 | | // Lattice class has the ownership of the returned value. |
118 | | Node *NewNode(); |
119 | | |
120 | | absl::string_view sentence_; |
121 | | std::vector<const char *> surface_; |
122 | | std::vector<std::vector<Node *>> begin_nodes_; |
123 | | std::vector<std::vector<Node *>> end_nodes_; |
124 | | model::FreeList<Node> node_allocator_; |
125 | | }; |
126 | | |
127 | | class Model : public ModelInterface { |
128 | | public: |
129 | | explicit Model(const ModelProto &model_proto); |
130 | 0 | Model() {} |
131 | | ~Model() override; |
132 | | |
133 | | EncodeResult Encode(absl::string_view normalized) const override; |
134 | | |
135 | | NBestEncodeResult NBestEncode(absl::string_view normalized, |
136 | | int nbest_size) const override; |
137 | | |
138 | | EncodeResult SampleEncode(absl::string_view normalized, |
139 | | float theta) const override; |
140 | | |
141 | | NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, |
142 | | float theta, int samples, bool wor, |
143 | | bool include_best) const override; |
144 | | |
145 | | float CalculateEntropy(absl::string_view normalized, |
146 | | float theta) const override; |
147 | | |
148 | 0 | bool IsSampleEncodeAvailable() const override { return true; } |
149 | | |
150 | 0 | bool IsSampleEncodeAndScoreAvailable() const override { return true; } |
151 | | |
152 | 0 | bool IsCalculateEntropyAvailable() const override { return true; } |
153 | | |
154 | 0 | bool IsNBestEncodeAvailable() const override { return true; } |
155 | | |
156 | | // Returns the minimum score in sentence pieces. |
157 | | // min_score() - 10 is used for the cost of unknown sentence. |
158 | 0 | float min_score() const { return min_score_; } |
159 | | |
160 | | // Returns the maximum score in sentence pieces. |
161 | | // max_score() is used for the cost of user defined symbols. |
162 | 0 | float max_score() const { return max_score_; } |
163 | | |
164 | | // Populates all sentence pieces to the |lattice|. |
165 | | // After calling this function, lattice.Viterbi() returns the |
166 | | // best segmentation. |
167 | | void PopulateNodes(Lattice *lattice) const; |
168 | | |
169 | | // Returns a vocab id of |piece|. |
170 | | int PieceToId(absl::string_view piece) const override; |
171 | | |
172 | | // Verifies if two outputs are equivalent by comparing their scores. |
173 | | bool VerifyOutputsEquivalent(absl::string_view expected, |
174 | | absl::string_view actual) const override; |
175 | | |
176 | | enum EncoderVersion { |
177 | | kOptimized, // The optimized encoder. |
178 | | kOriginal // The original encoder. |
179 | | }; |
180 | | |
181 | 0 | void SetEncoderVersion(EncoderVersion encoder_version) { |
182 | 0 | encoder_version_ = encoder_version; |
183 | 0 | } |
184 | | |
185 | | // Returns the current encoder version in use. |
186 | 0 | EncoderVersion GetEncoderVersion() const { return encoder_version_; } |
187 | | |
188 | | protected: |
189 | | // Builds a Trie index. |
190 | | void BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces); |
191 | | |
192 | | // The optimized Viterbi encode. |
193 | | // Main differences from the original function: |
194 | | // 1. Memorizes the best path at each postion so far, |
195 | | // 2. No need to store the Lattice nodes, |
196 | | // 3. Works in utf-8 directly, |
197 | | // 4. Defines a new struct with fewer fields than Lattice, |
198 | | // 5. Does not depend on `class Lattice` nor call `SetSentence()`, |
199 | | // `PopulateNodes()`, or `Viterbi()`. It does everything in one function. |
200 | | // For detailed explanations please see the comments inside the function body. |
201 | | EncodeResult EncodeOptimized(absl::string_view normalized) const; |
202 | | |
203 | | float min_score_ = 0.0; |
204 | | float max_score_ = 0.0; |
205 | | std::unique_ptr<Darts::DoubleArray> trie_; |
206 | | |
207 | | // Maximum size of the return value of Trie, which corresponds |
208 | | // to the maximum size of shared common prefix in the sentence pieces. |
209 | | int trie_results_size_; |
210 | | |
211 | | // encoder version. |
212 | | EncoderVersion encoder_version_ = kOptimized; |
213 | | }; |
214 | | |
215 | | } // namespace unigram |
216 | | } // namespace sentencepiece |
217 | | #endif // UNIGRAM_MODEL_H_ |