Coverage Report

Created: 2026-01-17 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/sentencepiece/src/model_interface.cc
Line
Count
Source
1
// Copyright 2016 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "model_interface.h"
16
17
#include <algorithm>
18
19
#include "sentencepiece_model.pb.h"
20
#include "third_party/absl/strings/str_format.h"
21
#include "util.h"
22
23
namespace sentencepiece {
24
25
ModelInterface::ModelInterface(const ModelProto &model_proto)
26
0
    : model_proto_(&model_proto), status_(util::OkStatus()) {}
27
0
// Nothing to release explicitly; members clean up after themselves.
ModelInterface::~ModelInterface() = default;
28
29
#define RETURN_PIECE(name, default_value)                                \
30
0
  if (model_proto_->trainer_spec().name().empty()) return default_value; \
31
0
  return model_proto_->trainer_spec().name();
32
33
0
absl::string_view ModelInterface::unk_piece() const {
34
0
  RETURN_PIECE(unk_piece, "<unk>");
35
0
}
36
37
0
absl::string_view ModelInterface::bos_piece() const {
38
0
  RETURN_PIECE(bos_piece, "<s>");
39
0
}
40
41
0
absl::string_view ModelInterface::eos_piece() const {
42
0
  RETURN_PIECE(eos_piece, "</s>");
43
0
}
44
45
0
absl::string_view ModelInterface::pad_piece() const {
46
0
  RETURN_PIECE(pad_piece, "<pad>");
47
0
}
48
49
#undef RETURN_PIECE
50
51
0
// Maps `piece` to its vocabulary id.
//
// The reserved-piece table is consulted first, then the regular piece table.
// When neither contains `piece`, the id of the unknown piece is returned.
int ModelInterface::PieceToId(absl::string_view piece) const {
  const auto reserved = reserved_id_map_.find(piece);
  if (reserved != reserved_id_map_.end()) return reserved->second;

  const auto regular = pieces_.find(piece);
  return regular != pieces_.end() ? regular->second : unk_id_;
}
62
63
0
void ModelInterface::InitializePieces() {
64
0
  pieces_.clear();
65
0
  reserved_id_map_.clear();
66
0
  unk_id_ = -1;
67
68
0
  std::set<absl::string_view> user_defined_symbols;
69
0
  std::vector<bool> byte_found(256, false);
70
71
0
  int pieces_size = 0;
72
0
  int reserved_id_map_size = 0;
73
0
  for (int i = 0; i < model_proto_->pieces_size(); ++i) {
74
0
    const auto &sp = model_proto_->pieces(i);
75
0
    const bool is_normal_piece =
76
0
        (sp.type() == ModelProto::SentencePiece::NORMAL ||
77
0
         sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
78
0
         sp.type() == ModelProto::SentencePiece::UNUSED);
79
0
    if (is_normal_piece) {
80
0
      ++pieces_size;
81
0
    } else {
82
0
      ++reserved_id_map_size;
83
0
    }
84
0
  }
85
0
  pieces_.reserve(pieces_size);
86
0
  reserved_id_map_.reserve(reserved_id_map_size);
87
88
0
  for (int i = 0; i < model_proto_->pieces_size(); ++i) {
89
0
    const auto &sp = model_proto_->pieces(i);
90
0
    if (sp.piece().empty()) {
91
0
      status_ = util::InternalError("piece must not be empty.");
92
0
      return;
93
0
    }
94
95
0
    const bool is_normal_piece =
96
0
        (sp.type() == ModelProto::SentencePiece::NORMAL ||
97
0
         sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
98
0
         sp.type() == ModelProto::SentencePiece::UNUSED);
99
0
    if (!port::InsertIfNotPresent(
100
0
            is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) {
101
0
      status_ = util::InternalError(sp.piece() + " is already defined.");
102
0
      return;
103
0
    }
104
105
0
    if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
106
0
      user_defined_symbols.insert(sp.piece());
107
0
    }
108
109
0
    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) {
110
0
      if (unk_id_ >= 0) {
111
0
        status_ = util::InternalError("unk is already defined.");
112
0
        return;
113
0
      }
114
0
      unk_id_ = i;
115
0
    }
116
117
0
    if (sp.type() == ModelProto::SentencePiece::BYTE) {
118
0
      if (!model_proto_->trainer_spec().byte_fallback()) {
119
0
        status_ =
120
0
            util::InternalError("byte piece " + sp.piece() +
121
0
                                " is found although `byte_fallback` is false.");
122
0
        return;
123
0
      }
124
0
      const int byte = PieceToByte(sp.piece());
125
0
      if (0 <= byte && byte < 256) {
126
0
        byte_found[byte] = true;
127
0
      } else {
128
0
        status_ =
129
0
            util::InternalError("byte piece " + sp.piece() + " is invalid.");
130
0
        return;
131
0
      }
132
0
    }
133
0
  }
134
135
0
  if (unk_id_ == -1) {
136
0
    status_ = util::InternalError("unk is not defined.");
137
0
    return;
138
0
  }
139
140
0
  if (model_proto_->trainer_spec().byte_fallback()) {
141
    // Checks that there are 256 byte pieces.
142
0
    if (std::find(byte_found.begin(), byte_found.end(), false) !=
143
0
        byte_found.end()) {
144
0
      status_ = util::InternalError(
145
0
          "there are not 256 byte pieces although `byte_fallback` is true.");
146
0
      return;
147
0
    }
148
0
  }
149
150
0
  matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols);
151
0
}
152
153
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
154
                                              bool treat_ws_as_suffix,
155
0
                                              bool allow_ws_only_pieces) {
156
0
  const char *begin = text.data();
157
0
  const char *end = text.data() + text.size();
158
159
  // Space symbol (U+2581)
160
0
  const absl::string_view kSpaceSymbol = "\xe2\x96\x81";
161
0
  bool in_ws_sequence = false;
162
163
0
  std::vector<absl::string_view> result;
164
0
  if (treat_ws_as_suffix) {  // put ws tokens at the end of non-ws sequences.
165
0
    if (begin < end) result.emplace_back(begin, 0);
166
0
    while (begin < end) {
167
0
      const int mblen =
168
0
          std::min<int>(string_util::OneCharLen(begin), end - begin);
169
0
      const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
170
171
0
      if (is_ws) {  // keep track of sequences consecutive ws tokens.
172
0
        in_ws_sequence = true;
173
0
      } else if (in_ws_sequence) {
174
0
        if (allow_ws_only_pieces) result.emplace_back(begin, 0);
175
176
0
        in_ws_sequence = false;
177
0
      }
178
179
0
      result.back() =
180
0
          absl::string_view(result.back().data(), result.back().size() + mblen);
181
0
      begin += mblen;
182
183
0
      if (begin < end && is_ws && !allow_ws_only_pieces)
184
0
        result.emplace_back(begin, 0);
185
0
    }
186
0
  } else {
187
0
    while (begin < end) {
188
0
      const int mblen =
189
0
          std::min<int>(string_util::OneCharLen(begin), end - begin);
190
0
      bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
191
192
      // if is whitespace (and not in sequence if allow_ws_only_pieces is True)
193
0
      if (begin == text.data() ||
194
0
          (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) {
195
0
        result.emplace_back(begin, 0);  // add empty string piece.
196
0
        in_ws_sequence = true;
197
0
      }
198
199
0
      if (in_ws_sequence && !is_ws) in_ws_sequence = false;
200
201
0
      result.back() =
202
0
          absl::string_view(result.back().data(), result.back().size() + mblen);
203
0
      begin += mblen;
204
0
    }
205
0
  }
206
207
0
  return result;
208
0
}
209
210
0
// Renders byte `c` as the six-character piece "<0xNN>", where NN is the
// value of `c` as two upper-case, zero-padded hexadecimal digits.
std::string ByteToPiece(unsigned char c) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  std::string piece = "<0x00>";
  piece[3] = kHexDigits[c >> 4];    // high nibble
  piece[4] = kHexDigits[c & 0x0F];  // low nibble
  return piece;
}
213
214
0
// Inverse of ByteToPiece: returns the byte value (0..255) whose piece string
// equals `piece`, or -1 when `piece` is not one of the 256 byte pieces.
int PieceToByte(absl::string_view piece) {
  using PieceToByteMap = absl::flat_hash_map<std::string, unsigned char>;

  // Reverse table, built from ByteToPiece on the first call. It is allocated
  // with `new` and nothing in this function deletes it.
  static const PieceToByteMap *const kMap = [] {
    auto *table = new PieceToByteMap();
    for (int value = 0; value < 256; ++value) {
      const auto byte = static_cast<unsigned char>(value);
      (*table)[ByteToPiece(byte)] = byte;
    }
    return table;
  }();

  const auto found = kMap->find(std::string(piece));
  return found == kMap->end() ? -1 : static_cast<int>(found->second);
}
230
231
}  // namespace sentencepiece