/src/myanmar-tools/clients/cpp/zawgyi_detector.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2017 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include <cmath> |
16 | | #include <cstddef> |
17 | | #include <cstdint> |
18 | | #include <cstring> |
19 | | #include <limits> |
20 | | #include <glog/logging.h> |
21 | | #include <unicode/utf8.h> |
22 | | |
23 | | #include "public/myanmartools.h" |
24 | | #include "zawgyi_detector-impl.h" |
25 | | |
26 | | namespace { |
27 | | const uint8_t kModelData[] = { |
28 | | #include "zawgyi_model_data.inc" |
29 | | }; |
30 | | constexpr size_t kModelSize = sizeof kModelData; |
31 | | } // namespace |
32 | | |
33 | | using namespace google_myanmar_tools; |
34 | | |
// BSWAP(value, bits) converts a |bits|-wide unsigned integer between
// big-endian and host byte order. It expands to an expression (no trailing
// semicolon) so it can be used on the right-hand side of an assignment.
//
// Host byte order is taken from the compiler-predefined __BYTE_ORDER__ when
// available, falling back to the <endian.h>-style __BYTE_ORDER. If neither is
// defined the build fails instead of silently assuming little-endian (an
// undefined macro evaluates to 0 in #if, so "0 == 0" would otherwise pick the
// byte-swapping branch on every platform).
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
#  define GMT_HOST_IS_LITTLE_ENDIAN (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN)
#  define GMT_HOST_IS_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN)
#else
#  error "Unable to determine host byte order"
#endif

#if GMT_HOST_IS_LITTLE_ENDIAN
#  define BSWAP(dest, bits) __builtin_bswap##bits(dest)
#else
#  define BSWAP(dest, bits) (dest)
#endif
40 | | |
/**
 * Loads a big-endian value of width |bits| from |ptr| into |dest|.
 * Static-asserts that sizeof(dest) matches the expected size.
 *
 * The bytes are copied through an unsigned integer of the same width so the
 * byte swap is well defined for signed and floating-point destinations.
 * The body is wrapped in do { } while (0) so the macro behaves as a single
 * statement and its temporary does not leak into the caller's scope.
 */
#define BIG_ENDIAN_LOAD(ptr, dest, bits)                              \
  do {                                                                \
    static_assert(sizeof(dest) == (bits) / 8,                         \
                  "Expected type to be " #bits " bits");              \
    uint##bits##_t big_endian_load_tmp_;                              \
    memcpy(&big_endian_load_tmp_, (ptr), (bits) / 8);                 \
    big_endian_load_tmp_ = BSWAP(big_endian_load_tmp_, bits);         \
    memcpy(&(dest), &big_endian_load_tmp_, (bits) / 8);               \
  } while (0)
52 | | |
53 | 904 | int64_t BigEndian::loadInt64(const void* ptr) { |
54 | 904 | int64_t dest; |
55 | 904 | BIG_ENDIAN_LOAD(ptr, dest, 64); |
56 | 904 | return dest; |
57 | 904 | } |
58 | | |
59 | 1.35k | int32_t BigEndian::loadInt32(const void* ptr) { |
60 | 1.35k | int32_t dest; |
61 | 1.35k | BIG_ENDIAN_LOAD(ptr, dest, 32); |
62 | 1.35k | return dest; |
63 | 1.35k | } |
64 | | |
65 | 2.13M | int16_t BigEndian::loadInt16(const void* ptr) { |
66 | 2.13M | int16_t dest; |
67 | 2.13M | BIG_ENDIAN_LOAD(ptr, dest, 16); |
68 | 2.13M | return dest; |
69 | 2.13M | } |
70 | | |
71 | 2.09M | float BigEndian::loadFloat(const void* ptr) { |
72 | 2.09M | float dest; |
73 | 2.09M | BIG_ENDIAN_LOAD(ptr, dest, 32); |
74 | 2.09M | return dest; |
75 | 2.09M | } |
76 | | |
77 | | // Implement Markov Chain processing. |
78 | 452 | BinaryMarkovClassifier::BinaryMarkovClassifier(const uint8_t* binary_ptr) { |
79 | | // Binary formatted file: |
80 | | // magic number: int64 |
81 | | // version: int32 |
82 | | // int16 size of model N |
83 | | // N entries of form: |
84 | | // int16 entry_count |
85 | | // float default_log_value for row unless entry_count is zero |
86 | | // entry count items of: |
87 | | // byte: index |
88 | | // float: log_value |
89 | | |
90 | 452 | const uint8_t* data_ptr = binary_ptr; |
91 | | |
92 | 452 | model_size_ = 0; |
93 | 452 | model_array_ = nullptr; |
94 | | |
95 | 452 | int64_t magic_number; |
96 | 452 | int32_t version; |
97 | | |
98 | 452 | magic_number = BigEndian::loadInt64(data_ptr); |
99 | 452 | data_ptr += sizeof(magic_number); |
100 | | |
101 | 452 | CHECK_EQ(BINARY_TAG, magic_number); |
102 | | |
103 | 452 | version = BigEndian::loadInt32(data_ptr); |
104 | 452 | data_ptr += sizeof(version); |
105 | | |
106 | 452 | CHECK_EQ(VERSION, version); |
107 | | |
108 | 452 | model_size_ = BigEndian::loadInt16(data_ptr); |
109 | 452 | data_ptr += sizeof(model_size_); |
110 | 452 | VLOG(2) << "BinaryMarkovClassifier size = " << model_size_; |
111 | | |
112 | 452 | model_array_ = new float[model_size_ * model_size_]; |
113 | | |
114 | 452 | float row_default_value; |
115 | | // Read each "row". |
116 | 103k | for (int row = 0; row < model_size_; ++row) { |
117 | 102k | int16_t row_entry_count; |
118 | 102k | row_entry_count = BigEndian::loadInt16(data_ptr); |
119 | 102k | data_ptr += sizeof(row_entry_count); |
120 | | |
121 | 102k | if (row_entry_count != 0) { |
122 | 55.5k | row_default_value = BigEndian::loadFloat(data_ptr); |
123 | 55.5k | data_ptr += sizeof(row_default_value); |
124 | 55.5k | } else { |
125 | 47.0k | row_default_value = 0.0f; |
126 | 47.0k | } |
127 | | |
128 | 102k | int index; |
129 | | // Set all the entries in the row to default. |
130 | 23.3M | for (int col = 0; col < model_size_; ++col) { |
131 | 23.2M | index = row * model_size_ + col; |
132 | 23.2M | model_array_[index] = row_default_value; |
133 | 23.2M | } |
134 | | |
135 | 102k | int16_t column; |
136 | | // Set non-default values. |
137 | 2.13M | for (int entry = 0; entry < row_entry_count; ++entry) { |
138 | 2.03M | column = BigEndian::loadInt16(data_ptr); |
139 | 2.03M | data_ptr += sizeof(column); |
140 | | |
141 | 2.03M | index = row * model_size_ + column; |
142 | | |
143 | 2.03M | model_array_[index] = BigEndian::loadFloat(data_ptr); |
144 | 2.03M | data_ptr += sizeof(float); |
145 | 2.03M | } |
146 | 102k | } |
147 | 452 | } |
148 | | |
149 | 452 | BinaryMarkovClassifier::~BinaryMarkovClassifier() { |
150 | 452 | delete[] model_array_; |
151 | 452 | } |
152 | | |
153 | 120k | float BinaryMarkovClassifier::GetLogProbabilityDifference(int i1, int i2) { |
154 | 120k | return model_array_[i1 * model_size_ + i2]; |
155 | 120k | } |
156 | | |
157 | | //---------------------------------------------------------------------------- |
158 | | |
159 | | // Initialize ZawgyiUnicode models from the stream |
160 | 452 | ZawgyiUnicodeMarkovModel::ZawgyiUnicodeMarkovModel(const uint8_t* data_models) { |
161 | 452 | int64_t magic_number; |
162 | 452 | const uint8_t* input_ptr = data_models; |
163 | | |
164 | 452 | magic_number = BigEndian::loadInt64(input_ptr); |
165 | 452 | input_ptr += sizeof(magic_number); |
166 | | |
167 | 452 | CHECK_EQ(BINARY_TAG, magic_number); |
168 | | |
169 | 452 | int32_t version = BigEndian::loadInt32(input_ptr); |
170 | 452 | input_ptr += sizeof(version); |
171 | | |
172 | 452 | if (version == 1) { |
173 | | // No SSV field |
174 | 0 | ssv_ = 0; |
175 | 452 | } else { |
176 | 452 | CHECK_EQ(2, version); |
177 | 452 | ssv_ = BigEndian::loadInt32(input_ptr); |
178 | 452 | input_ptr += sizeof(ssv_); |
179 | 452 | CHECK_GE(ssv_, 0); |
180 | 452 | CHECK_LT(ssv_, SSV_COUNT); |
181 | 452 | } |
182 | | |
183 | 452 | classifier_ = new BinaryMarkovClassifier(input_ptr); |
184 | 452 | } |
185 | | |
186 | 452 | ZawgyiUnicodeMarkovModel::~ZawgyiUnicodeMarkovModel() { |
187 | 452 | delete classifier_; |
188 | 452 | } |
189 | | |
190 | | double |
191 | | ZawgyiUnicodeMarkovModel::Predict(const char* input_utf8, |
192 | 452 | int32_t length) const { |
193 | 452 | if (length < 0) { |
194 | 0 | size_t length_size = strlen(input_utf8); |
195 | 0 | if (length_size > __INT32_MAX__) { |
196 | 0 | return -std::numeric_limits<double>::infinity(); |
197 | 0 | } |
198 | 0 | length = static_cast<int32_t>(length_size); |
199 | 0 | } |
200 | | |
201 | | // Start at the base state |
202 | 452 | int prevState = 0; |
203 | | |
204 | 452 | double totalDelta = 0.0; |
205 | 452 | bool seenTransition = false; |
206 | 13.9M | for (int32_t i = 0; i <= length;) { |
207 | 13.9M | int currState; |
208 | 13.9M | if (i >= length) { |
209 | 452 | currState = 0; |
210 | 452 | i++; |
211 | 13.9M | } else { |
212 | 13.9M | char32_t cp; |
213 | 13.9M | U8_NEXT(input_utf8, i, length, cp); |
214 | 13.9M | currState = GetIndexForCodePoint(cp); |
215 | 13.9M | } |
216 | | // Ignore 0-to-0 transitions |
217 | 13.9M | if (prevState != 0 || currState != 0) { |
218 | 120k | float delta = |
219 | 120k | classifier_->GetLogProbabilityDifference(prevState, currState); |
220 | 120k | totalDelta += delta; |
221 | 120k | seenTransition = true; |
222 | 120k | } |
223 | 13.9M | prevState = currState; |
224 | 13.9M | } |
225 | | |
226 | | // Special case: if there is no signal (both log probabilities are zero), |
227 | | // return -Infinity, which will get interpreted by users as strong Unicode. |
228 | | // This happens when the input string contains no Myanmar-range code points. |
229 | 452 | if (!seenTransition) { |
230 | 318 | return -std::numeric_limits<double>::infinity(); |
231 | 318 | } |
232 | | |
233 | | // result = Pz/(Pu+Pz) |
234 | | // = exp(logPz)/(exp(logPu)+exp(logPz)) |
235 | | // = 1/(1+exp(logPu-logPz)) |
236 | 134 | return 1.0 / (1.0 + exp(totalDelta)); |
237 | 452 | } |
238 | | |
239 | 13.9M | int16_t ZawgyiUnicodeMarkovModel::GetIndexForCodePoint(char32_t cp) const { |
240 | 13.9M | if (STD_CP0 <= cp && cp <= STD_CP1) { |
241 | 41.6k | return cp - STD_CP0 + STD_OFFSET; |
242 | 41.6k | } |
243 | 13.9M | if (AFT_CP0 <= cp && cp <= AFT_CP1) { |
244 | 9.10k | return cp - AFT_CP0 + AFT_OFFSET; |
245 | 9.10k | } |
246 | 13.9M | if (EXA_CP0 <= cp && cp <= EXA_CP1) { |
247 | 18.8k | return cp - EXA_CP0 + EXA_OFFSET; |
248 | 18.8k | } |
249 | 13.9M | if (EXB_CP0 <= cp && cp <= EXB_CP1) { |
250 | 1.86k | return cp - EXB_CP0 + EXB_OFFSET; |
251 | 1.86k | } |
252 | 13.9M | if (ssv_ == SSV_STD_EXA_EXB_SPC && SPC_CP0 <= cp && cp <= SPC_CP1) { |
253 | 3.08k | return cp - SPC_CP0 + SPC_OFFSET; |
254 | 3.08k | } |
255 | 13.9M | return 0; |
256 | 13.9M | } |
257 | | |
258 | | |
259 | | //---------------------------------------------------------------------------- |
260 | | |
261 | | // Reads standard detection modes from embedded data. |
262 | 452 | ZawgyiDetector::ZawgyiDetector() { |
263 | 452 | CHECK(kModelData) << " null model_data loaded"; |
264 | 452 | CHECK(kModelSize > 0) << " model size = " << kModelSize; |
265 | 452 | VLOG(2) << "model_data size = " << kModelSize; |
266 | | // TODO: Check kModelSize when reading the model? |
267 | 452 | model_ = new ZawgyiUnicodeMarkovModel(kModelData); |
268 | 452 | } |
269 | | |
270 | 452 | ZawgyiDetector::~ZawgyiDetector() { |
271 | 452 | delete model_; |
272 | 452 | } |
273 | | |
274 | | double ZawgyiDetector::GetZawgyiProbability(const char* input_utf8, |
275 | 452 | int32_t length) const { |
276 | 452 | return model_->Predict(input_utf8, length); |
277 | 452 | } |
278 | | |
279 | | // C bindings (declared with extern "C"). |
280 | 0 | GMTZawgyiDetector* GMTOpenZawgyiDetector(void) { |
281 | 0 | return reinterpret_cast<GMTZawgyiDetector*>(new ZawgyiDetector()); |
282 | 0 | } |
283 | | |
284 | 0 | void GMTCloseZawgyiDetector(GMTZawgyiDetector* detector) { |
285 | 0 | ZawgyiDetector* cppDetector = reinterpret_cast<ZawgyiDetector*>(detector); |
286 | 0 | delete cppDetector; |
287 | 0 | } |
288 | | |
289 | 0 | double GMTGetZawgyiProbability(GMTZawgyiDetector* detector, const char* input_utf8) { |
290 | 0 | return GMTGetZawgyiProbabilityWithLength(detector, input_utf8, -1); |
291 | 0 | } |
292 | | |
293 | 0 | double GMTGetZawgyiProbabilityWithLength(GMTZawgyiDetector* detector, const char* input_utf8, int32_t length) { |
294 | 0 | ZawgyiDetector* cppDetector = reinterpret_cast<ZawgyiDetector*>(detector); |
295 | 0 | return cppDetector->GetZawgyiProbability(input_utf8, length); |
296 | 0 | } |