Coverage Report

Created: 2025-07-11 06:49

/src/sentencepiece/src/unigram_model.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#ifndef UNIGRAM_MODEL_H_
16
#define UNIGRAM_MODEL_H_
17
18
#include <memory>
19
#include <string>
20
#include <utility>
21
#include <vector>
22
23
#include "common.h"
24
#include "freelist.h"
25
#include "model_interface.h"
26
#include "sentencepiece_model.pb.h"
27
#include "third_party/darts_clone/darts.h"
28
29
namespace sentencepiece {
30
namespace unigram {
31
32
// Lattice represents a search space of sentence piece segmentation.
33
class Lattice {
34
 public:
35
  Lattice();
36
  virtual ~Lattice();
37
38
  struct Node {
39
    absl::string_view piece;  // Sentence piece representation.
40
    uint32 pos;               // Unicode position in the sentence.
41
    uint32 length;            // Unicode length, not UT8 byte.
42
    uint32 node_id;           // unique id in the current lattice.
43
    int id;                   // vocab id. (maybe -1 for UNK)
44
    float score;              // logprob of this sentencepiece.
45
    float backtrace_score;    // backtrace info used in Viterbi.
46
    Node *prev;               // best previous node on Viterbi path.
47
48
    std::string DebugString() const;
49
  };
50
51
  // Returns bos node.
52
  Node *bos_node() const;
53
54
  // Returns eos node.
55
  Node *eos_node() const;
56
57
  // Returns nodes starting at |pos|.
58
  const std::vector<Node *> &begin_nodes(int pos) const;
59
60
  // Returns nodes ending at |pos|.
61
  const std::vector<Node *> &end_nodes(int pos) const;
62
63
  // Returns Unicode character length.
64
  int size() const;
65
66
  // Returns multi-byte (utf8) length.
67
  int utf8_size() const;
68
69
  // Returns the substring of sentence. sentence[pos:]
70
  const char *surface(int pos) const;
71
72
  // Returns immutable sentence. The same as surface(0)
73
  const char *sentence() const;
74
75
  // Clears the lattice.
76
  void Clear();
77
78
  // Sets new sentence.
79
  void SetSentence(absl::string_view sentence);
80
81
  // Inserts a new node at [pos, pos + length - 1].
82
  // After calling this method, The caller must set Node::score and Node::id.
83
  Node *Insert(int pos, int length);
84
85
  using LatticePathWithScore = std::pair<std::vector<Node *>, float>;
86
87
  // Returns Viterbi path. All nodes must be populated in advance.
88
  LatticePathWithScore Viterbi();
89
90
  // Runs forwards/backwards algorithm, returns vector with normalised
91
  // transition probs.
92
  std::vector<float> ForwardAlgorithm(float theta) const;
93
  std::vector<float> BackwardAlgorithm(float theta) const;
94
95
  // Returns n-best results.
96
  std::vector<LatticePathWithScore> NBest(size_t nbest_size, bool sample,
97
                                          float theta);
98
99
  // Samples one path from the lattice according to the
100
  // generation probability (Product of piece probabilities).
101
  // `theta` is a smoothing parameter.
102
  std::vector<Node *> Sample(float theta);
103
104
  // Calculates the entropy of the lattice.
105
  float CalculateEntropy(float theta) const;
106
107
  // Populates marginal probability of every node in this lattice.
108
  // |freq| is the frequency of the sentence.
109
  //  for (auto *node : all_nodes_) {
110
  //    (*expected)[node->id] += marginal_prob_of_node * freq;
111
  //  }
112
  // Returns the log-likelihood of this sentence.
113
  float PopulateMarginal(float freq, std::vector<float> *expected) const;
114
115
 private:
116
  // Returns new node.
117
  // Lattice class has the ownership of the returned value.
118
  Node *NewNode();
119
120
  absl::string_view sentence_;
121
  std::vector<const char *> surface_;
122
  std::vector<std::vector<Node *>> begin_nodes_;
123
  std::vector<std::vector<Node *>> end_nodes_;
124
  model::FreeList<Node> node_allocator_;
125
};
126
127
class Model : public ModelInterface {
128
 public:
129
  explicit Model(const ModelProto &model_proto);
130
0
  Model() {}
131
  ~Model() override;
132
133
  EncodeResult Encode(absl::string_view normalized) const override;
134
135
  NBestEncodeResult NBestEncode(absl::string_view normalized,
136
                                int nbest_size) const override;
137
138
  EncodeResult SampleEncode(absl::string_view normalized,
139
                            float theta) const override;
140
141
  NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
142
                                         float theta, int samples, bool wor,
143
                                         bool include_best) const override;
144
145
  float CalculateEntropy(absl::string_view normalized,
146
                         float theta) const override;
147
148
0
  bool IsSampleEncodeAvailable() const override { return true; }
149
150
0
  bool IsSampleEncodeAndScoreAvailable() const override { return true; }
151
152
0
  bool IsCalculateEntropyAvailable() const override { return true; }
153
154
0
  bool IsNBestEncodeAvailable() const override { return true; }
155
156
  // Returns the minimum score in sentence pieces.
157
  // min_score() - 10 is used for the cost of unknown sentence.
158
0
  float min_score() const { return min_score_; }
159
160
  // Returns the maximum score in sentence pieces.
161
  // max_score() is used for the cost of user defined symbols.
162
0
  float max_score() const { return max_score_; }
163
164
  // Populates all sentence pieces to the |lattice|.
165
  // After calling this function, lattice.Viterbi() returns the
166
  // best segmentation.
167
  void PopulateNodes(Lattice *lattice) const;
168
169
  // Returns a vocab id of |piece|.
170
  int PieceToId(absl::string_view piece) const override;
171
172
  // Verifies if two outputs are equivalent by comparing their scores.
173
  bool VerifyOutputsEquivalent(absl::string_view expected,
174
                               absl::string_view actual) const override;
175
176
  enum EncoderVersion {
177
    kOptimized,  // The optimized encoder.
178
    kOriginal    // The original encoder.
179
  };
180
181
0
  void SetEncoderVersion(EncoderVersion encoder_version) {
182
0
    encoder_version_ = encoder_version;
183
0
  }
184
185
  // Returns the current encoder version in use.
186
0
  EncoderVersion GetEncoderVersion() const { return encoder_version_; }
187
188
 protected:
189
  // Builds a Trie index.
190
  void BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces);
191
192
  // The optimized Viterbi encode.
193
  // Main differences from the original function:
194
  // 1. Memorizes the best path at each postion so far,
195
  // 2. No need to store the Lattice nodes,
196
  // 3. Works in utf-8 directly,
197
  // 4. Defines a new struct with fewer fields than Lattice,
198
  // 5. Does not depend on `class Lattice` nor call `SetSentence()`,
199
  // `PopulateNodes()`, or `Viterbi()`. It does everything in one function.
200
  // For detailed explanations please see the comments inside the function body.
201
  EncodeResult EncodeOptimized(absl::string_view normalized) const;
202
203
  float min_score_ = 0.0;
204
  float max_score_ = 0.0;
205
  std::unique_ptr<Darts::DoubleArray> trie_;
206
207
  // Maximum size of the return value of Trie, which corresponds
208
  // to the maximum size of shared common prefix in the sentence pieces.
209
  int trie_results_size_;
210
211
  // encoder version.
212
  EncoderVersion encoder_version_ = kOptimized;
213
};
214
215
}  // namespace unigram
216
}  // namespace sentencepiece
217
#endif  // UNIGRAM_MODEL_H_