/src/icu/source/common/lstmbe.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // © 2021 and later: Unicode, Inc. and others. |
2 | | // License & terms of use: http://www.unicode.org/copyright.html |
3 | | |
4 | | #include <utility> |
5 | | #include <ctgmath> |
6 | | |
7 | | #include "unicode/utypes.h" |
8 | | |
9 | | #if !UCONFIG_NO_BREAK_ITERATION |
10 | | |
11 | | #include "brkeng.h" |
12 | | #include "charstr.h" |
13 | | #include "cmemory.h" |
14 | | #include "lstmbe.h" |
15 | | #include "putilimp.h" |
16 | | #include "uassert.h" |
17 | | #include "ubrkimpl.h" |
18 | | #include "uresimp.h" |
19 | | #include "uvectr32.h" |
20 | | #include "uvector.h" |
21 | | |
22 | | #include "unicode/brkiter.h" |
23 | | #include "unicode/resbund.h" |
24 | | #include "unicode/ubrk.h" |
25 | | #include "unicode/uniset.h" |
26 | | #include "unicode/ustring.h" |
27 | | #include "unicode/utf.h" |
28 | | |
29 | | U_NAMESPACE_BEGIN |
30 | | |
31 | | // Uncomment the following #define to debug. |
32 | | // #define LSTM_DEBUG 1 |
33 | | // #define LSTM_VECTORIZER_DEBUG 1 |
34 | | |
/**
 * Read-only view of a one-dimensional array of floats.
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    /** Number of elements in the array. */
    virtual int32_t d1() const = 0;

    /** Value stored at position i. */
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug helper: dump every element to stdout, four values per line.
    void print() const {
        printf("\n[");
        for (int32_t pos = 0; pos < d1(); ++pos) {
            printf("%0.8e ", get(pos));
            const bool endOfGroup = (pos % 4 == 3);
            if (endOfGroup) {
                printf("\n");
            }
        }
        printf("]\n");
    }
#endif
};

// Out-of-line definition of the virtual destructor.
ReadArray1D::~ReadArray1D() = default;
59 | | |
/**
 * Read-only view of a two-dimensional array of floats.
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    /** Size of the first dimension. */
    virtual int32_t d1() const = 0;

    /** Size of the second dimension. */
    virtual int32_t d2() const = 0;

    /** Value stored at position (i, j). */
    virtual float get(int32_t i, int32_t j) const = 0;
};

// Out-of-line definition of the virtual destructor.
ReadArray2D::~ReadArray2D() = default;
74 | | |
75 | | /** |
76 | | * A class to index a float array as a 1D Array without owning the pointer or |
77 | | * copy the data. |
78 | | */ |
79 | | class ConstArray1D : public ReadArray1D { |
80 | | public: |
81 | 0 | ConstArray1D() : data_(nullptr), d1_(0) {} |
82 | | |
83 | 0 | ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {} |
84 | | |
85 | | virtual ~ConstArray1D(); |
86 | | |
87 | | // Init the object, the object does not own the data nor copy. |
88 | | // It is designed to directly use data from memory mapped resources. |
89 | 0 | void init(const int32_t* data, int32_t d1) { |
90 | 0 | U_ASSERT(IEEE_754 == 1); |
91 | 0 | data_ = reinterpret_cast<const float*>(data); |
92 | 0 | d1_ = d1; |
93 | 0 | } |
94 | | |
95 | | // ReadArray1D methods. |
96 | 0 | virtual int32_t d1() const { return d1_; } |
97 | 0 | virtual float get(int32_t i) const { |
98 | 0 | U_ASSERT(i < d1_); |
99 | 0 | return data_[i]; |
100 | 0 | } |
101 | | |
102 | | private: |
103 | | const float* data_; |
104 | | int32_t d1_; |
105 | | }; |
106 | | |
107 | | ConstArray1D::~ConstArray1D() |
108 | | { |
109 | | } |
110 | | |
111 | | /** |
112 | | * A class to index a float array as a 2D Array without owning the pointer or |
113 | | * copy the data. |
114 | | */ |
115 | | class ConstArray2D : public ReadArray2D { |
116 | | public: |
117 | 0 | ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {} |
118 | | |
119 | | ConstArray2D(const float* data, int32_t d1, int32_t d2) |
120 | 0 | : data_(data), d1_(d1), d2_(d2) {} |
121 | | |
122 | | virtual ~ConstArray2D(); |
123 | | |
124 | | // Init the object, the object does not own the data nor copy. |
125 | | // It is designed to directly use data from memory mapped resources. |
126 | 0 | void init(const int32_t* data, int32_t d1, int32_t d2) { |
127 | 0 | U_ASSERT(IEEE_754 == 1); |
128 | 0 | data_ = reinterpret_cast<const float*>(data); |
129 | 0 | d1_ = d1; |
130 | 0 | d2_ = d2; |
131 | 0 | } |
132 | | |
133 | | // ReadArray2D methods. |
134 | 0 | inline int32_t d1() const { return d1_; } |
135 | 0 | inline int32_t d2() const { return d2_; } |
136 | 0 | float get(int32_t i, int32_t j) const { |
137 | 0 | U_ASSERT(i < d1_); |
138 | 0 | U_ASSERT(j < d2_); |
139 | 0 | return data_[i * d2_ + j]; |
140 | 0 | } |
141 | | |
142 | | // Expose the ith row as a ConstArray1D |
143 | 0 | inline ConstArray1D row(int32_t i) const { |
144 | 0 | U_ASSERT(i < d1_); |
145 | 0 | return ConstArray1D(data_ + i * d2_, d2_); |
146 | 0 | } |
147 | | |
148 | | private: |
149 | | const float* data_; |
150 | | int32_t d1_; |
151 | | int32_t d2_; |
152 | | }; |
153 | | |
154 | | ConstArray2D::~ConstArray2D() |
155 | | { |
156 | | } |
157 | | |
158 | | /** |
159 | | * A class to allocate data as a writable 1D array. |
160 | | * This is the main class implement matrix operation. |
161 | | */ |
162 | | class Array1D : public ReadArray1D { |
163 | | public: |
164 | 0 | Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {} |
165 | | Array1D(int32_t d1, UErrorCode &status) |
166 | | : memory_(uprv_malloc(d1 * sizeof(float))), |
167 | 0 | data_((float*)memory_), d1_(d1) { |
168 | 0 | if (U_SUCCESS(status)) { |
169 | 0 | if (memory_ == nullptr) { |
170 | 0 | status = U_MEMORY_ALLOCATION_ERROR; |
171 | 0 | return; |
172 | 0 | } |
173 | 0 | clear(); |
174 | 0 | } |
175 | 0 | } |
176 | | |
177 | | virtual ~Array1D(); |
178 | | |
179 | | // A special constructor which does not own the memory but writeable |
180 | | // as a slice of an array. |
181 | | Array1D(float* data, int32_t d1) |
182 | 0 | : memory_(nullptr), data_(data), d1_(d1) {} |
183 | | |
184 | | // ReadArray1D methods. |
185 | 0 | virtual int32_t d1() const { return d1_; } |
186 | 0 | virtual float get(int32_t i) const { |
187 | 0 | U_ASSERT(i < d1_); |
188 | 0 | return data_[i]; |
189 | 0 | } |
190 | | |
191 | | // Return the index which point to the max data in the array. |
192 | 0 | inline int32_t maxIndex() const { |
193 | 0 | int32_t index = 0; |
194 | 0 | float max = data_[0]; |
195 | 0 | for (int32_t i = 1; i < d1_; i++) { |
196 | 0 | if (data_[i] > max) { |
197 | 0 | max = data_[i]; |
198 | 0 | index = i; |
199 | 0 | } |
200 | 0 | } |
201 | 0 | return index; |
202 | 0 | } |
203 | | |
204 | | // Slice part of the array to a new one. |
205 | 0 | inline Array1D slice(int32_t from, int32_t size) const { |
206 | 0 | U_ASSERT(from >= 0); |
207 | 0 | U_ASSERT(from < d1_); |
208 | 0 | U_ASSERT(from + size <= d1_); |
209 | 0 | return Array1D(data_ + from, size); |
210 | 0 | } |
211 | | |
212 | | // Add dot product of a 1D array and a 2D array into this one. |
213 | 0 | inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) { |
214 | 0 | U_ASSERT(a.d1() == b.d1()); |
215 | 0 | U_ASSERT(b.d2() == d1()); |
216 | 0 | for (int32_t i = 0; i < d1(); i++) { |
217 | 0 | for (int32_t j = 0; j < a.d1(); j++) { |
218 | 0 | data_[i] += a.get(j) * b.get(j, i); |
219 | 0 | } |
220 | 0 | } |
221 | 0 | return *this; |
222 | 0 | } |
223 | | |
224 | | // Hadamard Product the values of another array of the same size into this one. |
225 | 0 | inline Array1D& hadamardProduct(const ReadArray1D& a) { |
226 | 0 | U_ASSERT(a.d1() == d1()); |
227 | 0 | for (int32_t i = 0; i < d1(); i++) { |
228 | 0 | data_[i] *= a.get(i); |
229 | 0 | } |
230 | 0 | return *this; |
231 | 0 | } |
232 | | |
233 | | // Add the Hadamard Product of two arrays of the same size into this one. |
234 | 0 | inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) { |
235 | 0 | U_ASSERT(a.d1() == d1()); |
236 | 0 | U_ASSERT(b.d1() == d1()); |
237 | 0 | for (int32_t i = 0; i < d1(); i++) { |
238 | 0 | data_[i] += a.get(i) * b.get(i); |
239 | 0 | } |
240 | 0 | return *this; |
241 | 0 | } |
242 | | |
243 | | // Add the values of another array of the same size into this one. |
244 | 0 | inline Array1D& add(const ReadArray1D& a) { |
245 | 0 | U_ASSERT(a.d1() == d1()); |
246 | 0 | for (int32_t i = 0; i < d1(); i++) { |
247 | 0 | data_[i] += a.get(i); |
248 | 0 | } |
249 | 0 | return *this; |
250 | 0 | } |
251 | | |
252 | | // Assign the values of another array of the same size into this one. |
253 | 0 | inline Array1D& assign(const ReadArray1D& a) { |
254 | 0 | U_ASSERT(a.d1() == d1()); |
255 | 0 | for (int32_t i = 0; i < d1(); i++) { |
256 | 0 | data_[i] = a.get(i); |
257 | 0 | } |
258 | 0 | return *this; |
259 | 0 | } |
260 | | |
261 | | // Apply tanh to all the elements in the array. |
262 | 0 | inline Array1D& tanh() { |
263 | 0 | return tanh(*this); |
264 | 0 | } |
265 | | |
266 | | // Apply tanh of a and store into this array. |
267 | 0 | inline Array1D& tanh(const Array1D& a) { |
268 | 0 | U_ASSERT(a.d1() == d1()); |
269 | 0 | for (int32_t i = 0; i < d1_; i++) { |
270 | 0 | data_[i] = std::tanh(a.get(i)); |
271 | 0 | } |
272 | 0 | return *this; |
273 | 0 | } |
274 | | |
275 | | // Apply sigmoid to all the elements in the array. |
276 | 0 | inline Array1D& sigmoid() { |
277 | 0 | for (int32_t i = 0; i < d1_; i++) { |
278 | 0 | data_[i] = 1.0f/(1.0f + expf(-data_[i])); |
279 | 0 | } |
280 | 0 | return *this; |
281 | 0 | } |
282 | | |
283 | 0 | inline Array1D& clear() { |
284 | 0 | uprv_memset(data_, 0, d1_ * sizeof(float)); |
285 | 0 | return *this; |
286 | 0 | } |
287 | | |
288 | | private: |
289 | | void* memory_; |
290 | | float* data_; |
291 | | int32_t d1_; |
292 | | }; |
293 | | |
294 | | Array1D::~Array1D() |
295 | 0 | { |
296 | 0 | uprv_free(memory_); |
297 | 0 | } |
298 | | |
299 | | class Array2D : public ReadArray2D { |
300 | | public: |
301 | 0 | Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {} |
302 | | Array2D(int32_t d1, int32_t d2, UErrorCode &status) |
303 | | : memory_(uprv_malloc(d1 * d2 * sizeof(float))), |
304 | 0 | data_((float*)memory_), d1_(d1), d2_(d2) { |
305 | 0 | if (U_SUCCESS(status)) { |
306 | 0 | if (memory_ == nullptr) { |
307 | 0 | status = U_MEMORY_ALLOCATION_ERROR; |
308 | 0 | return; |
309 | 0 | } |
310 | 0 | clear(); |
311 | 0 | } |
312 | 0 | } |
313 | | virtual ~Array2D(); |
314 | | |
315 | | // ReadArray2D methods. |
316 | 0 | virtual int32_t d1() const { return d1_; } |
317 | 0 | virtual int32_t d2() const { return d2_; } |
318 | 0 | virtual float get(int32_t i, int32_t j) const { |
319 | 0 | U_ASSERT(i < d1_); |
320 | 0 | U_ASSERT(j < d2_); |
321 | 0 | return data_[i * d2_ + j]; |
322 | 0 | } |
323 | | |
324 | 0 | inline Array1D row(int32_t i) const { |
325 | 0 | U_ASSERT(i < d1_); |
326 | 0 | return Array1D(data_ + i * d2_, d2_); |
327 | 0 | } |
328 | | |
329 | 0 | inline Array2D& clear() { |
330 | 0 | uprv_memset(data_, 0, d1_ * d2_ * sizeof(float)); |
331 | 0 | return *this; |
332 | 0 | } |
333 | | |
334 | | private: |
335 | | void* memory_; |
336 | | float* data_; |
337 | | int32_t d1_; |
338 | | int32_t d2_; |
339 | | }; |
340 | | |
341 | | Array2D::~Array2D() |
342 | 0 | { |
343 | 0 | uprv_free(memory_); |
344 | 0 | } |
345 | | |
// Labels produced by the output layer, in the order of its four outputs.
// The numeric values are used as array indices (see the cast from
// maxIndex()), so they must stay 0..3.
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};
352 | | |
// How input text is split into units before the embedding lookup.
enum EmbeddingType {
    UNKNOWN = 0,           // "type" was neither of the recognised values
    CODE_POINTS = 1,       // "codepoints"
    GRAPHEME_CLUSTER = 2,  // "graphclust"
};
358 | | |
359 | | struct LSTMData : public UMemory { |
360 | | LSTMData(UResourceBundle* rb, UErrorCode &status); |
361 | | ~LSTMData(); |
362 | | UHashtable* fDict; |
363 | | EmbeddingType fType; |
364 | | const UChar* fName; |
365 | | ConstArray2D fEmbedding; |
366 | | ConstArray2D fForwardW; |
367 | | ConstArray2D fForwardU; |
368 | | ConstArray1D fForwardB; |
369 | | ConstArray2D fBackwardW; |
370 | | ConstArray2D fBackwardU; |
371 | | ConstArray1D fBackwardB; |
372 | | ConstArray2D fOutputW; |
373 | | ConstArray1D fOutputB; |
374 | | |
375 | | private: |
376 | | UResourceBundle* fBundle; |
377 | | }; |
378 | | |
379 | | LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status) |
380 | | : fDict(nullptr), fType(UNKNOWN), fName(nullptr), |
381 | | fBundle(rb) |
382 | 0 | { |
383 | 0 | if (U_FAILURE(status)) { |
384 | 0 | return; |
385 | 0 | } |
386 | 0 | if (IEEE_754 != 1) { |
387 | 0 | status = U_UNSUPPORTED_ERROR; |
388 | 0 | return; |
389 | 0 | } |
390 | 0 | LocalUResourceBundlePointer embeddings_res( |
391 | 0 | ures_getByKey(rb, "embeddings", nullptr, &status)); |
392 | 0 | int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status); |
393 | 0 | LocalUResourceBundlePointer hunits_res( |
394 | 0 | ures_getByKey(rb, "hunits", nullptr, &status)); |
395 | 0 | if (U_FAILURE(status)) return; |
396 | 0 | int32_t hunits = ures_getInt(hunits_res.getAlias(), &status); |
397 | 0 | const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status); |
398 | 0 | if (U_FAILURE(status)) return; |
399 | 0 | if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) { |
400 | 0 | fType = CODE_POINTS; |
401 | 0 | } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) { |
402 | 0 | fType = GRAPHEME_CLUSTER; |
403 | 0 | } |
404 | 0 | fName = ures_getStringByKey(rb, "model", nullptr, &status); |
405 | 0 | LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status)); |
406 | 0 | if (U_FAILURE(status)) return; |
407 | 0 | int32_t data_len = 0; |
408 | 0 | const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status); |
409 | 0 | fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status); |
410 | |
|
411 | 0 | StackUResourceBundle stackTempBundle; |
412 | 0 | ResourceDataValue value; |
413 | 0 | ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status); |
414 | 0 | ResourceArray stringArray = value.getArray(status); |
415 | 0 | int32_t num_index = stringArray.getSize(); |
416 | 0 | if (U_FAILURE(status)) { return; } |
417 | | |
418 | | // put dict into hash |
419 | 0 | int32_t stringLength; |
420 | 0 | for (int32_t idx = 0; idx < num_index; idx++) { |
421 | 0 | stringArray.getValue(idx, value); |
422 | 0 | const UChar* str = value.getString(stringLength, status); |
423 | 0 | uhash_putiAllowZero(fDict, (void*)str, idx, &status); |
424 | 0 | if (U_FAILURE(status)) return; |
425 | | #ifdef LSTM_VECTORIZER_DEBUG |
426 | | printf("Assign ["); |
427 | | while (*str != 0x0000) { |
428 | | printf("U+%04x ", *str); |
429 | | str++; |
430 | | } |
431 | | printf("] map to %d\n", idx-1); |
432 | | #endif |
433 | 0 | } |
434 | 0 | int32_t mat1_size = (num_index + 1) * embedding_size; |
435 | 0 | int32_t mat2_size = embedding_size * 4 * hunits; |
436 | 0 | int32_t mat3_size = hunits * 4 * hunits; |
437 | 0 | int32_t mat4_size = 4 * hunits; |
438 | 0 | int32_t mat5_size = mat2_size; |
439 | 0 | int32_t mat6_size = mat3_size; |
440 | 0 | int32_t mat7_size = mat4_size; |
441 | 0 | int32_t mat8_size = 2 * hunits * 4; |
442 | | #if U_DEBUG |
443 | | int32_t mat9_size = 4; |
444 | | U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size + |
445 | | mat6_size + mat7_size + mat8_size + mat9_size); |
446 | | #endif |
447 | |
|
448 | 0 | fEmbedding.init(data, (num_index + 1), embedding_size); |
449 | 0 | data += mat1_size; |
450 | 0 | fForwardW.init(data, embedding_size, 4 * hunits); |
451 | 0 | data += mat2_size; |
452 | 0 | fForwardU.init(data, hunits, 4 * hunits); |
453 | 0 | data += mat3_size; |
454 | 0 | fForwardB.init(data, 4 * hunits); |
455 | 0 | data += mat4_size; |
456 | 0 | fBackwardW.init(data, embedding_size, 4 * hunits); |
457 | 0 | data += mat5_size; |
458 | 0 | fBackwardU.init(data, hunits, 4 * hunits); |
459 | 0 | data += mat6_size; |
460 | 0 | fBackwardB.init(data, 4 * hunits); |
461 | 0 | data += mat7_size; |
462 | 0 | fOutputW.init(data, 2 * hunits, 4); |
463 | 0 | data += mat8_size; |
464 | 0 | fOutputB.init(data, 4); |
465 | 0 | } |
466 | | |
467 | 0 | LSTMData::~LSTMData() { |
468 | 0 | uhash_close(fDict); |
469 | 0 | ures_close(fBundle); |
470 | 0 | } |
471 | | |
472 | | class Vectorizer : public UMemory { |
473 | | public: |
474 | 0 | Vectorizer(UHashtable* dict) : fDict(dict) {} |
475 | | virtual ~Vectorizer(); |
476 | | virtual void vectorize(UText *text, int32_t startPos, int32_t endPos, |
477 | | UVector32 &offsets, UVector32 &indices, |
478 | | UErrorCode &status) const = 0; |
479 | | protected: |
480 | 0 | int32_t stringToIndex(const UChar* str) const { |
481 | 0 | UBool found = false; |
482 | 0 | int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found); |
483 | 0 | if (!found) { |
484 | 0 | ret = fDict->count; |
485 | 0 | } |
486 | | #ifdef LSTM_VECTORIZER_DEBUG |
487 | | printf("["); |
488 | | while (*str != 0x0000) { |
489 | | printf("U+%04x ", *str); |
490 | | str++; |
491 | | } |
492 | | printf("] map to %d\n", ret); |
493 | | #endif |
494 | 0 | return ret; |
495 | 0 | } |
496 | | |
497 | | private: |
498 | | UHashtable* fDict; |
499 | | }; |
500 | | |
501 | | Vectorizer::~Vectorizer() |
502 | 0 | { |
503 | 0 | } |
504 | | |
505 | | class CodePointsVectorizer : public Vectorizer { |
506 | | public: |
507 | 0 | CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {} |
508 | | virtual ~CodePointsVectorizer(); |
509 | | virtual void vectorize(UText *text, int32_t startPos, int32_t endPos, |
510 | | UVector32 &offsets, UVector32 &indices, |
511 | | UErrorCode &status) const; |
512 | | }; |
513 | | |
514 | | CodePointsVectorizer::~CodePointsVectorizer() |
515 | | { |
516 | | } |
517 | | |
518 | | void CodePointsVectorizer::vectorize( |
519 | | UText *text, int32_t startPos, int32_t endPos, |
520 | | UVector32 &offsets, UVector32 &indices, UErrorCode &status) const |
521 | 0 | { |
522 | 0 | if (offsets.ensureCapacity(endPos - startPos, status) && |
523 | 0 | indices.ensureCapacity(endPos - startPos, status)) { |
524 | 0 | if (U_FAILURE(status)) return; |
525 | 0 | utext_setNativeIndex(text, startPos); |
526 | 0 | int32_t current; |
527 | 0 | UChar str[2] = {0, 0}; |
528 | 0 | while (U_SUCCESS(status) && |
529 | 0 | (current = (int32_t)utext_getNativeIndex(text)) < endPos) { |
530 | | // Since the LSTMBreakEngine is currently only accept chars in BMP, |
531 | | // we can ignore the possibility of hitting supplementary code |
532 | | // point. |
533 | 0 | str[0] = (UChar) utext_next32(text); |
534 | 0 | U_ASSERT(!U_IS_SURROGATE(str[0])); |
535 | 0 | offsets.addElement(current, status); |
536 | 0 | indices.addElement(stringToIndex(str), status); |
537 | 0 | } |
538 | 0 | } |
539 | 0 | } |
540 | | |
541 | | class GraphemeClusterVectorizer : public Vectorizer { |
542 | | public: |
543 | | GraphemeClusterVectorizer(UHashtable* dict) |
544 | | : Vectorizer(dict) |
545 | 0 | { |
546 | 0 | } |
547 | | virtual ~GraphemeClusterVectorizer(); |
548 | | virtual void vectorize(UText *text, int32_t startPos, int32_t endPos, |
549 | | UVector32 &offsets, UVector32 &indices, |
550 | | UErrorCode &status) const; |
551 | | }; |
552 | | |
553 | | GraphemeClusterVectorizer::~GraphemeClusterVectorizer() |
554 | | { |
555 | | } |
556 | | |
557 | | constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10; |
558 | | |
559 | | void GraphemeClusterVectorizer::vectorize( |
560 | | UText *text, int32_t startPos, int32_t endPos, |
561 | | UVector32 &offsets, UVector32 &indices, UErrorCode &status) const |
562 | 0 | { |
563 | 0 | if (U_FAILURE(status)) return; |
564 | 0 | if (!offsets.ensureCapacity(endPos - startPos, status) || |
565 | 0 | !indices.ensureCapacity(endPos - startPos, status)) { |
566 | 0 | return; |
567 | 0 | } |
568 | 0 | if (U_FAILURE(status)) return; |
569 | 0 | LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status)); |
570 | 0 | if (U_FAILURE(status)) return; |
571 | 0 | graphemeIter->setText(text, status); |
572 | 0 | if (U_FAILURE(status)) return; |
573 | | |
574 | 0 | if (startPos != 0) { |
575 | 0 | graphemeIter->preceding(startPos); |
576 | 0 | } |
577 | 0 | int32_t last = startPos; |
578 | 0 | int32_t current = startPos; |
579 | 0 | UChar str[MAX_GRAPHEME_CLSTER_LENGTH]; |
580 | 0 | while ((current = graphemeIter->next()) != BreakIterator::DONE) { |
581 | 0 | if (current >= endPos) { |
582 | 0 | break; |
583 | 0 | } |
584 | 0 | if (current > startPos) { |
585 | 0 | utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status); |
586 | 0 | if (U_FAILURE(status)) return; |
587 | 0 | offsets.addElement(last, status); |
588 | 0 | indices.addElement(stringToIndex(str), status); |
589 | 0 | if (U_FAILURE(status)) return; |
590 | 0 | } |
591 | 0 | last = current; |
592 | 0 | } |
593 | 0 | if (U_FAILURE(status) || last >= endPos) { |
594 | 0 | return; |
595 | 0 | } |
596 | 0 | utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status); |
597 | 0 | if (U_SUCCESS(status)) { |
598 | 0 | offsets.addElement(last, status); |
599 | 0 | indices.addElement(stringToIndex(str), status); |
600 | 0 | } |
601 | 0 | } |
602 | | |
603 | | // Computing LSTM as stated in |
604 | | // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate |
605 | | // ifco is temp array allocate outside which does not need to be |
606 | | // input/output value but could avoid unnecessary memory alloc/free if passing |
607 | | // in. |
608 | | void compute( |
609 | | int32_t hunits, |
610 | | const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b, |
611 | | const ReadArray1D& x, Array1D& h, Array1D& c, |
612 | | Array1D& ifco) |
613 | 0 | { |
614 | | // ifco = x * W + h * U + b |
615 | 0 | ifco.assign(b) |
616 | 0 | .addDotProduct(x, W) |
617 | 0 | .addDotProduct(h, U); |
618 | |
|
619 | 0 | ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod |
620 | 0 | ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid |
621 | 0 | ifco.slice(2*hunits, hunits).tanh(); // c_: tanh |
622 | 0 | ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod |
623 | |
|
624 | 0 | c.hadamardProduct(ifco.slice(hunits, hunits)) |
625 | 0 | .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits)); |
626 | |
|
627 | 0 | h.tanh(c) |
628 | 0 | .hadamardProduct(ifco.slice(3*hunits, hunits)); |
629 | 0 | } |
630 | | |
// Smallest number of characters that counts as a word.
static constexpr int32_t MIN_WORD = 2;

// Smallest number of characters that can hold two words.
static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
636 | | |
637 | | int32_t |
638 | | LSTMBreakEngine::divideUpDictionaryRange( UText *text, |
639 | | int32_t startPos, |
640 | | int32_t endPos, |
641 | | UVector32 &foundBreaks, |
642 | 0 | UErrorCode& status) const { |
643 | 0 | if (U_FAILURE(status)) return 0; |
644 | 0 | int32_t beginFoundBreakSize = foundBreaks.size(); |
645 | 0 | utext_setNativeIndex(text, startPos); |
646 | 0 | utext_moveIndex32(text, MIN_WORD_SPAN); |
647 | 0 | if (utext_getNativeIndex(text) >= endPos) { |
648 | 0 | return 0; // Not enough characters for two words |
649 | 0 | } |
650 | 0 | utext_setNativeIndex(text, startPos); |
651 | |
|
652 | 0 | UVector32 offsets(status); |
653 | 0 | UVector32 indices(status); |
654 | 0 | if (U_FAILURE(status)) return 0; |
655 | 0 | fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status); |
656 | 0 | if (U_FAILURE(status)) return 0; |
657 | 0 | int32_t* offsetsBuf = offsets.getBuffer(); |
658 | 0 | int32_t* indicesBuf = indices.getBuffer(); |
659 | |
|
660 | 0 | int32_t input_seq_len = indices.size(); |
661 | 0 | int32_t hunits = fData->fForwardU.d1(); |
662 | | |
663 | | // ----- Begin of all the Array memory allocation needed for this function |
664 | | // Allocate temp array used inside compute() |
665 | 0 | Array1D ifco(4 * hunits, status); |
666 | |
|
667 | 0 | Array1D c(hunits, status); |
668 | 0 | Array1D logp(4, status); |
669 | | |
670 | | // TODO: limit size of hBackward. If input_seq_len is too big, we could |
671 | | // run out of memory. |
672 | | // Backward LSTM |
673 | 0 | Array2D hBackward(input_seq_len, hunits, status); |
674 | | |
675 | | // Allocate fbRow and slice the internal array in two. |
676 | 0 | Array1D fbRow(2 * hunits, status); |
677 | | |
678 | | // ----- End of all the Array memory allocation needed for this function |
679 | 0 | if (U_FAILURE(status)) return 0; |
680 | | |
681 | | // To save the needed memory usage, the following is different from the |
682 | | // Python or ICU4X implementation. We first perform the Backward LSTM |
683 | | // and then merge the iteration of the forward LSTM and the output layer |
684 | | // together because we only neetdto remember the h[t-1] for Forward LSTM. |
685 | 0 | for (int32_t i = input_seq_len - 1; i >= 0; i--) { |
686 | 0 | Array1D hRow = hBackward.row(i); |
687 | 0 | if (i != input_seq_len - 1) { |
688 | 0 | hRow.assign(hBackward.row(i+1)); |
689 | 0 | } |
690 | | #ifdef LSTM_DEBUG |
691 | | printf("hRow %d\n", i); |
692 | | hRow.print(); |
693 | | printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]); |
694 | | printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i); |
695 | | fData->fEmbedding.row(indicesBuf[i]).print(); |
696 | | #endif // LSTM_DEBUG |
697 | 0 | compute(hunits, |
698 | 0 | fData->fBackwardW, fData->fBackwardU, fData->fBackwardB, |
699 | 0 | fData->fEmbedding.row(indicesBuf[i]), |
700 | 0 | hRow, c, ifco); |
701 | 0 | } |
702 | | |
703 | |
|
704 | 0 | Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow. |
705 | 0 | Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow. |
706 | | |
707 | | // The following iteration merge the forward LSTM and the output layer |
708 | | // together. |
709 | 0 | c.clear(); // reuse c since it is the same size. |
710 | 0 | for (int32_t i = 0; i < input_seq_len; i++) { |
711 | | #ifdef LSTM_DEBUG |
712 | | printf("forwardRow %d\n", i); |
713 | | forwardRow.print(); |
714 | | #endif // LSTM_DEBUG |
715 | | // Forward LSTM |
716 | | // Calculate the result into forwardRow, which point to the data in the first half |
717 | | // of fbRow. |
718 | 0 | compute(hunits, |
719 | 0 | fData->fForwardW, fData->fForwardU, fData->fForwardB, |
720 | 0 | fData->fEmbedding.row(indicesBuf[i]), |
721 | 0 | forwardRow, c, ifco); |
722 | | |
723 | | // assign the data from hBackward.row(i) to second half of fbRowa. |
724 | 0 | backwardRow.assign(hBackward.row(i)); |
725 | |
|
726 | 0 | logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW); |
727 | | #ifdef LSTM_DEBUG |
728 | | printf("backwardRow %d\n", i); |
729 | | backwardRow.print(); |
730 | | printf("logp %d\n", i); |
731 | | logp.print(); |
732 | | #endif // LSTM_DEBUG |
733 | | |
734 | | // current = argmax(logp) |
735 | 0 | LSTMClass current = (LSTMClass)logp.maxIndex(); |
736 | | // BIES logic. |
737 | 0 | if (current == BEGIN || current == SINGLE) { |
738 | 0 | if (i != 0) { |
739 | 0 | foundBreaks.addElement(offsetsBuf[i], status); |
740 | 0 | if (U_FAILURE(status)) return 0; |
741 | 0 | } |
742 | 0 | } |
743 | 0 | } |
744 | 0 | return foundBreaks.size() - beginFoundBreakSize; |
745 | 0 | } |
746 | | |
747 | 0 | Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) { |
748 | 0 | if (U_FAILURE(status)) { |
749 | 0 | return nullptr; |
750 | 0 | } |
751 | 0 | switch (data->fType) { |
752 | 0 | case CODE_POINTS: |
753 | 0 | return new CodePointsVectorizer(data->fDict); |
754 | 0 | break; |
755 | 0 | case GRAPHEME_CLUSTER: |
756 | 0 | return new GraphemeClusterVectorizer(data->fDict); |
757 | 0 | break; |
758 | 0 | default: |
759 | 0 | break; |
760 | 0 | } |
761 | 0 | UPRV_UNREACHABLE; |
762 | 0 | } |
763 | | |
764 | | LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status) |
765 | | : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status)) |
766 | 0 | { |
767 | 0 | if (U_FAILURE(status)) { |
768 | 0 | fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so. |
769 | 0 | return; |
770 | 0 | } |
771 | 0 | setCharacters(set); |
772 | 0 | } |
773 | | |
774 | 0 | LSTMBreakEngine::~LSTMBreakEngine() { |
775 | 0 | delete fData; |
776 | 0 | delete fVectorizer; |
777 | 0 | } |
778 | | |
779 | 0 | const UChar* LSTMBreakEngine::name() const { |
780 | 0 | return fData->fName; |
781 | 0 | } |
782 | | |
783 | 0 | UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) { |
784 | | // open root from brkitr tree. |
785 | 0 | UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status); |
786 | 0 | b = ures_getByKeyWithFallback(b, "lstm", b, &status); |
787 | 0 | UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status); |
788 | 0 | ures_close(b); |
789 | 0 | return result; |
790 | 0 | } |
791 | | |
792 | | U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status) |
793 | 0 | { |
794 | 0 | if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) { |
795 | 0 | return nullptr; |
796 | 0 | } |
797 | 0 | UnicodeString name = defaultLSTM(script, status); |
798 | 0 | if (U_FAILURE(status)) return nullptr; |
799 | 0 | CharString namebuf; |
800 | 0 | namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.')); |
801 | |
|
802 | 0 | LocalUResourceBundlePointer rb( |
803 | 0 | ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status)); |
804 | 0 | if (U_FAILURE(status)) return nullptr; |
805 | | |
806 | 0 | return CreateLSTMData(rb.orphan(), status); |
807 | 0 | } |
808 | | |
809 | | U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status) |
810 | 0 | { |
811 | 0 | return new LSTMData(rb, status); |
812 | 0 | } |
813 | | |
814 | | U_CAPI const LanguageBreakEngine* U_EXPORT2 |
815 | | CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status) |
816 | 0 | { |
817 | 0 | UnicodeString unicodeSetString; |
818 | 0 | switch(script) { |
819 | 0 | case USCRIPT_THAI: |
820 | 0 | unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]"); |
821 | 0 | break; |
822 | 0 | case USCRIPT_MYANMAR: |
823 | 0 | unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]"); |
824 | 0 | break; |
825 | 0 | default: |
826 | 0 | delete data; |
827 | 0 | return nullptr; |
828 | 0 | } |
829 | 0 | UnicodeSet unicodeSet; |
830 | 0 | unicodeSet.applyPattern(unicodeSetString, status); |
831 | 0 | const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status); |
832 | 0 | if (U_FAILURE(status) || engine == nullptr) { |
833 | 0 | if (engine != nullptr) { |
834 | 0 | delete engine; |
835 | 0 | } else { |
836 | 0 | status = U_MEMORY_ALLOCATION_ERROR; |
837 | 0 | } |
838 | 0 | return nullptr; |
839 | 0 | } |
840 | 0 | return engine; |
841 | 0 | } |
842 | | |
843 | | U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data) |
844 | 0 | { |
845 | 0 | delete data; |
846 | 0 | } |
847 | | |
848 | | U_CAPI const UChar* U_EXPORT2 LSTMDataName(const LSTMData* data) |
849 | 0 | { |
850 | 0 | return data->fName; |
851 | 0 | } |
852 | | |
853 | | U_NAMESPACE_END |
854 | | |
855 | | #endif /* #if !UCONFIG_NO_BREAK_ITERATION */ |