Coverage Report

Created: 2026-05-30 06:49

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/src/model_interface.cc
Line
Count
Source
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.!
14
15
#include "model_interface.h"
16
17
#include <algorithm>
18
19
#include "sentencepiece_model.pb.h"
20
#include "third_party/absl/strings/str_format.h"
21
#include "util.h"
22
23
namespace sentencepiece {
24
25
ModelInterface::ModelInterface(const ModelProto &model_proto)
26
0
    : model_proto_(&model_proto), status_(util::OkStatus()) {}
27
3.17k
ModelInterface::~ModelInterface() {}
28
29
#define RETURN_PIECE(name, default_value)                                \
30
8.13k
  if (model_proto_->trainer_spec().name().empty()) return default_value; \
31
8.12k
  return model_proto_->trainer_spec().name();
32
33
2.03k
absl::string_view ModelInterface::unk_piece() const {
34
2.03k
  RETURN_PIECE(unk_piece, "<unk>");
35
0
}
36
37
2.03k
absl::string_view ModelInterface::bos_piece() const {
38
2.03k
  RETURN_PIECE(bos_piece, "<s>");
39
0
}
40
41
2.03k
absl::string_view ModelInterface::eos_piece() const {
42
2.03k
  RETURN_PIECE(eos_piece, "</s>");
43
0
}
44
45
2.03k
absl::string_view ModelInterface::pad_piece() const {
46
2.03k
  RETURN_PIECE(pad_piece, "<pad>");
47
0
}
48
49
#undef RETURN_PIECE
50
51
4.42M
int ModelInterface::PieceToId(absl::string_view piece) const {
52
4.42M
  if (auto it = reserved_id_map_.find(piece); it != reserved_id_map_.end()) {
53
20.4k
    return it->second;
54
20.4k
  }
55
4.40M
  if (auto it = pieces_.find(piece); it != pieces_.end()) {
56
109k
    return it->second;
57
109k
  }
58
4.29M
  return unk_id_;
59
4.40M
}
60
61
3.17k
void ModelInterface::InitializePieces() {
62
3.17k
  pieces_.clear();
63
3.17k
  reserved_id_map_.clear();
64
3.17k
  unk_id_ = -1;
65
66
3.17k
  std::set<absl::string_view> user_defined_symbols;
67
3.17k
  std::vector<bool> byte_found(256, false);
68
69
3.17k
  int pieces_size = 0;
70
3.17k
  int reserved_id_map_size = 0;
71
1.31M
  for (int i = 0; i < model_proto_->pieces_size(); ++i) {
72
1.31M
    const auto &sp = model_proto_->pieces(i);
73
1.31M
    const bool is_normal_piece =
74
1.31M
        (sp.type() == ModelProto::SentencePiece::NORMAL ||
75
26.3k
         sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
76
25.6k
         sp.type() == ModelProto::SentencePiece::UNUSED);
77
1.31M
    if (is_normal_piece) {
78
1.28M
      ++pieces_size;
79
1.28M
    } else {
80
25.6k
      ++reserved_id_map_size;
81
25.6k
    }
82
1.31M
  }
83
3.17k
  pieces_.reserve(pieces_size);
84
3.17k
  reserved_id_map_.reserve(reserved_id_map_size);
85
86
47.0k
  for (int i = 0; i < model_proto_->pieces_size(); ++i) {
87
44.8k
    const auto &sp = model_proto_->pieces(i);
88
44.8k
    if (sp.piece().empty()) {
89
48
      status_ = util::InternalError("piece must not be empty.");
90
48
      return;
91
48
    }
92
44.7k
    if (sp.piece().find('\0') != absl::string_view::npos) {
93
850
      status_ = util::InternalError("piece must not include null character.");
94
850
      return;
95
850
    }
96
43.9k
    const bool is_normal_piece =
97
43.9k
        (sp.type() == ModelProto::SentencePiece::NORMAL ||
98
21.2k
         sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
99
20.6k
         sp.type() == ModelProto::SentencePiece::UNUSED);
100
43.9k
    if (!port::InsertIfNotPresent(
101
43.9k
            is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) {
102
10
      status_ = util::InternalError(sp.piece() + " is already defined.");
103
10
      return;
104
10
    }
105
106
43.9k
    if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
107
602
      user_defined_symbols.insert(sp.piece());
108
602
    }
109
110
43.9k
    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) {
111
2.96k
      if (unk_id_ >= 0) {
112
2
        status_ = util::InternalError("unk is already defined.");
113
2
        return;
114
2
      }
115
2.96k
      unk_id_ = i;
116
2.96k
    }
117
118
43.9k
    if (sp.type() == ModelProto::SentencePiece::BYTE) {
119
14.7k
      if (!model_proto_->trainer_spec().byte_fallback()) {
120
6
        status_ =
121
6
            util::InternalError("byte piece " + sp.piece() +
122
6
                                " is found although `byte_fallback` is false.");
123
6
        return;
124
6
      }
125
14.7k
      const int byte = PieceToByte(sp.piece());
126
14.7k
      if (0 <= byte && byte < 256) {
127
14.7k
        byte_found[byte] = true;
128
14.7k
      } else {
129
32
        status_ =
130
32
            util::InternalError("byte piece " + sp.piece() + " is invalid.");
131
32
        return;
132
32
      }
133
14.7k
    }
134
43.9k
  }
135
136
2.23k
  if (unk_id_ == -1) {
137
107
    status_ = util::InternalError("unk is not defined.");
138
107
    return;
139
107
  }
140
141
2.12k
  if (model_proto_->trainer_spec().byte_fallback()) {
142
    // Checks that there are 256 byte pieces.
143
67
    if (std::find(byte_found.begin(), byte_found.end(), false) !=
144
67
        byte_found.end()) {
145
39
      status_ = util::InternalError(
146
39
          "there are not 256 byte pieces although `byte_fallback` is true.");
147
39
      return;
148
39
    }
149
67
  }
150
151
2.08k
  matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols);
152
2.08k
}
153
154
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
155
                                              bool treat_ws_as_suffix,
156
761
                                              bool allow_ws_only_pieces) {
157
761
  const char *begin = text.data();
158
761
  const char *end = text.data() + text.size();
159
160
  // Space symbol (U+2581)
161
761
  constexpr absl::string_view kSpaceSymbol = "\xe2\x96\x81";
162
761
  bool in_ws_sequence = false;
163
164
761
  std::vector<absl::string_view> result;
165
761
  if (treat_ws_as_suffix) {  // put ws tokens at the end of non-ws sequences.
166
0
    if (begin < end) result.emplace_back(begin, 0);
167
0
    while (begin < end) {
168
0
      const int mblen =
169
0
          std::min<int>(string_util::OneCharLen(begin), end - begin);
170
0
      const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
171
172
0
      if (is_ws) {  // keep track of sequences consecutive ws tokens.
173
0
        in_ws_sequence = true;
174
0
      } else if (in_ws_sequence) {
175
0
        if (allow_ws_only_pieces) result.emplace_back(begin, 0);
176
177
0
        in_ws_sequence = false;
178
0
      }
179
180
0
      result.back() =
181
0
          absl::string_view(result.back().data(), result.back().size() + mblen);
182
0
      begin += mblen;
183
184
0
      if (begin < end && is_ws && !allow_ws_only_pieces)
185
0
        result.emplace_back(begin, 0);
186
0
    }
187
761
  } else {
188
19.0M
    while (begin < end) {
189
19.0M
      const int mblen =
190
19.0M
          std::min<int>(string_util::OneCharLen(begin), end - begin);
191
19.0M
      bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
192
193
      // if is whitespace (and not in sequence if allow_ws_only_pieces is True)
194
19.0M
      if (begin == text.data() ||
195
19.0M
          (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) {
196
107k
        result.emplace_back(begin, 0);  // add empty string piece.
197
107k
        in_ws_sequence = true;
198
107k
      }
199
200
19.0M
      if (in_ws_sequence && !is_ws) in_ws_sequence = false;
201
202
19.0M
      result.back() =
203
19.0M
          absl::string_view(result.back().data(), result.back().size() + mblen);
204
19.0M
      begin += mblen;
205
19.0M
    }
206
761
  }
207
208
761
  return result;
209
761
}
210
211
1.62k
std::string ByteToPiece(unsigned char c) {
212
1.62k
  return absl::StrFormat("<0x%02X>", c);
213
1.62k
}
214
215
15.3k
int PieceToByte(absl::string_view piece) {
216
15.3k
  using PieceToByteMap = absl::flat_hash_map<std::string, unsigned char>;
217
15.3k
  static const auto *const kMap = []() -> PieceToByteMap * {
218
1
    auto *m = new PieceToByteMap();
219
257
    for (int i = 0; i < 256; ++i) {
220
256
      (*m)[ByteToPiece(i)] = i;
221
256
    }
222
1
    return m;
223
1
  }();
224
225
15.3k
  if (const auto it = kMap->find(piece); it != kMap->end()) {
226
15.3k
    return it->second;
227
15.3k
  }
228
229
32
  return -1;
230
15.3k
}
231
232
}  // namespace sentencepiece