Coverage Report

Created: 2026-06-07 07:04

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/src/model_interface.cc
Line
Count
Source
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "model_interface.h"
16
17
#include <algorithm>
18
19
#include "sentencepiece_model.pb.h"
20
#include "third_party/absl/strings/str_format.h"
21
#include "util.h"
22
23
namespace sentencepiece {
24
25
// Stores the address of `model_proto` and starts with an OK status.
// NOTE(review): only the address is kept, so the proto presumably has to
// outlive this object -- confirm against the callers.
ModelInterface::ModelInterface(const ModelProto &model_proto)
    : model_proto_(&model_proto), status_(util::OkStatus()) {}

// Nothing to release explicitly; members clean up after themselves.
ModelInterface::~ModelInterface() = default;
28
29
#define RETURN_PIECE(name, default_value)                                \
30
8.13k
  if (model_proto_->trainer_spec().name().empty()) return default_value; \
31
8.12k
  return model_proto_->trainer_spec().name();
32
33
2.03k
absl::string_view ModelInterface::unk_piece() const {
34
2.03k
  RETURN_PIECE(unk_piece, "<unk>");
35
0
}
36
37
2.03k
absl::string_view ModelInterface::bos_piece() const {
38
2.03k
  RETURN_PIECE(bos_piece, "<s>");
39
0
}
40
41
2.03k
absl::string_view ModelInterface::eos_piece() const {
42
2.03k
  RETURN_PIECE(eos_piece, "</s>");
43
0
}
44
45
2.03k
absl::string_view ModelInterface::pad_piece() const {
46
2.03k
  RETURN_PIECE(pad_piece, "<pad>");
47
0
}
48
49
#undef RETURN_PIECE
50
51
4.42M
// Maps `piece` to its id. `reserved_id_map_` is consulted first, then
// `pieces_`; when neither contains the piece, `unk_id_` is returned.
int ModelInterface::PieceToId(absl::string_view piece) const {
  const auto reserved_hit = reserved_id_map_.find(piece);
  if (reserved_hit != reserved_id_map_.end()) {
    return reserved_hit->second;
  }

  const auto piece_hit = pieces_.find(piece);
  if (piece_hit != pieces_.end()) {
    return piece_hit->second;
  }

  return unk_id_;
}
60
61
3.17k
void ModelInterface::InitializePieces() {
62
3.17k
  pieces_.clear();
63
3.17k
  reserved_id_map_.clear();
64
3.17k
  unk_id_ = -1;
65
66
3.17k
  std::set<absl::string_view> user_defined_symbols;
67
3.17k
  std::vector<bool> byte_found(256, false);
68
69
3.17k
  int pieces_size = 0;
70
3.17k
  int reserved_id_map_size = 0;
71
1.31M
  for (int i = 0; i < model_proto_->pieces_size(); ++i) {
72
1.31M
    const auto &sp = model_proto_->pieces(i);
73
1.31M
    static constexpr size_t kMaxPieceSize = 8192;
74
1.31M
    if (sp.piece().size() >= kMaxPieceSize) {
75
0
      status_ = util::InternalError("piece size must be less than 8k.");
76
0
      return;
77
0
    }
78
1.31M
    const bool is_normal_piece =
79
1.31M
        (sp.type() == ModelProto::SentencePiece::NORMAL ||
80
26.3k
         sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
81
25.6k
         sp.type() == ModelProto::SentencePiece::UNUSED);
82
1.31M
    if (is_normal_piece) {
83
1.28M
      ++pieces_size;
84
1.28M
    } else {
85
25.6k
      ++reserved_id_map_size;
86
25.6k
    }
87
1.31M
  }
88
3.17k
  pieces_.reserve(pieces_size);
89
3.17k
  reserved_id_map_.reserve(reserved_id_map_size);
90
91
47.0k
  for (int i = 0; i < model_proto_->pieces_size(); ++i) {
92
44.8k
    const auto &sp = model_proto_->pieces(i);
93
44.8k
    if (sp.piece().empty()) {
94
48
      status_ = util::InternalError("piece must not be empty.");
95
48
      return;
96
48
    }
97
44.7k
    if (sp.piece().find('\0') != absl::string_view::npos) {
98
850
      status_ = util::InternalError("piece must not include null character.");
99
850
      return;
100
850
    }
101
43.9k
    const bool is_normal_piece =
102
43.9k
        (sp.type() == ModelProto::SentencePiece::NORMAL ||
103
21.2k
         sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
104
20.6k
         sp.type() == ModelProto::SentencePiece::UNUSED);
105
43.9k
    if (!port::InsertIfNotPresent(
106
43.9k
            is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) {
107
10
      status_ = util::InternalError(sp.piece() + " is already defined.");
108
10
      return;
109
10
    }
110
111
43.9k
    if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
112
602
      user_defined_symbols.insert(sp.piece());
113
602
    }
114
115
43.9k
    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) {
116
2.96k
      if (unk_id_ >= 0) {
117
2
        status_ = util::InternalError("unk is already defined.");
118
2
        return;
119
2
      }
120
2.96k
      unk_id_ = i;
121
2.96k
    }
122
123
43.9k
    if (sp.type() == ModelProto::SentencePiece::BYTE) {
124
14.7k
      if (!model_proto_->trainer_spec().byte_fallback()) {
125
6
        status_ =
126
6
            util::InternalError("byte piece " + sp.piece() +
127
6
                                " is found although `byte_fallback` is false.");
128
6
        return;
129
6
      }
130
14.7k
      const int byte = PieceToByte(sp.piece());
131
14.7k
      if (0 <= byte && byte < 256) {
132
14.7k
        byte_found[byte] = true;
133
14.7k
      } else {
134
32
        status_ =
135
32
            util::InternalError("byte piece " + sp.piece() + " is invalid.");
136
32
        return;
137
32
      }
138
14.7k
    }
139
43.9k
  }
140
141
2.23k
  if (unk_id_ == -1) {
142
107
    status_ = util::InternalError("unk is not defined.");
143
107
    return;
144
107
  }
145
146
2.12k
  if (model_proto_->trainer_spec().byte_fallback()) {
147
    // Checks that there are 256 byte pieces.
148
67
    if (std::find(byte_found.begin(), byte_found.end(), false) !=
149
67
        byte_found.end()) {
150
39
      status_ = util::InternalError(
151
39
          "there are not 256 byte pieces although `byte_fallback` is true.");
152
39
      return;
153
39
    }
154
67
  }
155
156
2.08k
  matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols);
157
2.08k
}
158
159
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
160
                                              bool treat_ws_as_suffix,
161
761
                                              bool allow_ws_only_pieces) {
162
761
  const char *begin = text.data();
163
761
  const char *end = text.data() + text.size();
164
165
  // Space symbol (U+2581)
166
761
  constexpr absl::string_view kSpaceSymbol = "\xe2\x96\x81";
167
761
  bool in_ws_sequence = false;
168
169
761
  std::vector<absl::string_view> result;
170
761
  if (treat_ws_as_suffix) {  // put ws tokens at the end of non-ws sequences.
171
0
    if (begin < end) result.emplace_back(begin, 0);
172
0
    while (begin < end) {
173
0
      const int mblen =
174
0
          std::min<int>(string_util::OneCharLen(begin), end - begin);
175
0
      const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
176
177
0
      if (is_ws) {  // keep track of sequences consecutive ws tokens.
178
0
        in_ws_sequence = true;
179
0
      } else if (in_ws_sequence) {
180
0
        if (allow_ws_only_pieces) result.emplace_back(begin, 0);
181
182
0
        in_ws_sequence = false;
183
0
      }
184
185
0
      result.back() =
186
0
          absl::string_view(result.back().data(), result.back().size() + mblen);
187
0
      begin += mblen;
188
189
0
      if (begin < end && is_ws && !allow_ws_only_pieces)
190
0
        result.emplace_back(begin, 0);
191
0
    }
192
761
  } else {
193
19.0M
    while (begin < end) {
194
19.0M
      const int mblen =
195
19.0M
          std::min<int>(string_util::OneCharLen(begin), end - begin);
196
19.0M
      bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
197
198
      // if is whitespace (and not in sequence if allow_ws_only_pieces is True)
199
19.0M
      if (begin == text.data() ||
200
19.0M
          (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) {
201
107k
        result.emplace_back(begin, 0);  // add empty string piece.
202
107k
        in_ws_sequence = true;
203
107k
      }
204
205
19.0M
      if (in_ws_sequence && !is_ws) in_ws_sequence = false;
206
207
19.0M
      result.back() =
208
19.0M
          absl::string_view(result.back().data(), result.back().size() + mblen);
209
19.0M
      begin += mblen;
210
19.0M
    }
211
761
  }
212
213
761
  return result;
214
761
}
215
216
1.62k
const std::string &ByteToPiece(unsigned char c) {
217
1.62k
  static const std::vector<std::string> *const kBytePieces = []() {
218
1
    auto *v = new std::vector<std::string>(256);
219
257
    for (int i = 0; i < 256; ++i) {
220
256
      (*v)[i] = absl::StrFormat("<0x%02X>", i);
221
256
    }
222
1
    return v;
223
1
  }();
224
1.62k
  return (*kBytePieces)[c];
225
1.62k
}
226
227
15.3k
// Inverse of ByteToPiece(): returns the byte value (0-255) whose piece string
// equals `piece`, or -1 when `piece` is not one of the 256 byte pieces.
int PieceToByte(absl::string_view piece) {
  using ByteLookup = absl::flat_hash_map<absl::string_view, unsigned char>;

  // Built on first use and never freed. The keys are views into the strings
  // returned by ByteToPiece().
  static const ByteLookup *const kLookup = []() -> ByteLookup * {
    auto *lookup = new ByteLookup();
    for (int byte = 0; byte < 256; ++byte) {
      (*lookup)[ByteToPiece(byte)] = byte;
    }
    return lookup;
  }();

  const auto hit = kLookup->find(piece);
  return hit != kLookup->end() ? hit->second : -1;
}
243
244
}  // namespace sentencepiece