/src/icu/source/common/lstmbe.cpp
Line  | Count  | Source (jump to first uncovered line)  | 
1  |  | // © 2021 and later: Unicode, Inc. and others.  | 
2  |  | // License & terms of use: http://www.unicode.org/copyright.html  | 
3  |  |  | 
4  |  | #include <utility>  | 
5  |  | #include <ctgmath>  | 
6  |  |  | 
7  |  | #include "unicode/utypes.h"  | 
8  |  |  | 
9  |  | #if !UCONFIG_NO_BREAK_ITERATION  | 
10  |  |  | 
11  |  | #include "brkeng.h"  | 
12  |  | #include "charstr.h"  | 
13  |  | #include "cmemory.h"  | 
14  |  | #include "lstmbe.h"  | 
15  |  | #include "putilimp.h"  | 
16  |  | #include "uassert.h"  | 
17  |  | #include "ubrkimpl.h"  | 
18  |  | #include "uresimp.h"  | 
19  |  | #include "uvectr32.h"  | 
20  |  | #include "uvector.h"  | 
21  |  |  | 
22  |  | #include "unicode/brkiter.h"  | 
23  |  | #include "unicode/resbund.h"  | 
24  |  | #include "unicode/ubrk.h"  | 
25  |  | #include "unicode/uniset.h"  | 
26  |  | #include "unicode/ustring.h"  | 
27  |  | #include "unicode/utf.h"  | 
28  |  |  | 
29  |  | U_NAMESPACE_BEGIN  | 
30  |  |  | 
31  |  | // Uncomment the following #define to debug.  | 
32  |  | // #define LSTM_DEBUG 1  | 
33  |  | // #define LSTM_VECTORIZER_DEBUG 1  | 
34  |  |  | 
/**
 * Read-only view of a one-dimensional float array.
 *
 * Implementations report their length through d1() and return individual
 * elements through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    // Number of elements in the array.
    virtual int32_t d1() const = 0;

    // Element at index i.
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Dumps every element, four per output line.
    void print() const {
        printf("\n[");
        const int32_t count = d1();
        int32_t column = 0;
        for (int32_t k = 0; k < count; ++k) {
            printf("%0.8e ", get(k));
            if (++column == 4) {
                printf("\n");
                column = 0;
            }
        }
        printf("]\n");
    }
#endif
};

ReadArray1D::~ReadArray1D() = default;
59  |  |  | 
/**
 * Read-only view of a two-dimensional float array with d1() rows and d2()
 * columns.
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    // Size of the first dimension (rows).
    virtual int32_t d1() const = 0;

    // Size of the second dimension (columns).
    virtual int32_t d2() const = 0;

    // Element at row i, column j.
    virtual float get(int32_t i, int32_t j) const = 0;
};

ReadArray2D::~ReadArray2D() = default;
74  |  |  | 
75  |  | /**  | 
76  |  |  * A class to index a float array as a 1D Array without owning the pointer or  | 
77  |  |  * copy the data.  | 
78  |  |  */  | 
79  |  | class ConstArray1D : public ReadArray1D { | 
80  |  | public:  | 
81  | 0  |     ConstArray1D() : data_(nullptr), d1_(0) {} | 
82  |  |  | 
83  | 0  |     ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {} | 
84  |  |  | 
85  |  |     virtual ~ConstArray1D();  | 
86  |  |  | 
87  |  |     // Init the object, the object does not own the data nor copy.  | 
88  |  |     // It is designed to directly use data from memory mapped resources.  | 
89  | 0  |     void init(const int32_t* data, int32_t d1) { | 
90  | 0  |         U_ASSERT(IEEE_754 == 1);  | 
91  | 0  |         data_ = reinterpret_cast<const float*>(data);  | 
92  | 0  |         d1_ = d1;  | 
93  | 0  |     }  | 
94  |  |  | 
95  |  |     // ReadArray1D methods.  | 
96  | 0  |     virtual int32_t d1() const { return d1_; } | 
97  | 0  |     virtual float get(int32_t i) const { | 
98  | 0  |         U_ASSERT(i < d1_);  | 
99  | 0  |         return data_[i];  | 
100  | 0  |     }  | 
101  |  |  | 
102  |  | private:  | 
103  |  |     const float* data_;  | 
104  |  |     int32_t d1_;  | 
105  |  | };  | 
106  |  |  | 
107  |  | ConstArray1D::~ConstArray1D()  | 
108  |  | { | 
109  |  | }  | 
110  |  |  | 
111  |  | /**  | 
112  |  |  * A class to index a float array as a 2D Array without owning the pointer or  | 
113  |  |  * copy the data.  | 
114  |  |  */  | 
115  |  | class ConstArray2D : public ReadArray2D { | 
116  |  | public:  | 
117  | 0  |     ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {} | 
118  |  |  | 
119  |  |     ConstArray2D(const float* data, int32_t d1, int32_t d2)  | 
120  | 0  |         : data_(data), d1_(d1), d2_(d2) {} | 
121  |  |  | 
122  |  |     virtual ~ConstArray2D();  | 
123  |  |  | 
124  |  |     // Init the object, the object does not own the data nor copy.  | 
125  |  |     // It is designed to directly use data from memory mapped resources.  | 
126  | 0  |     void init(const int32_t* data, int32_t d1, int32_t d2) { | 
127  | 0  |         U_ASSERT(IEEE_754 == 1);  | 
128  | 0  |         data_ = reinterpret_cast<const float*>(data);  | 
129  | 0  |         d1_ = d1;  | 
130  | 0  |         d2_ = d2;  | 
131  | 0  |     }  | 
132  |  |  | 
133  |  |     // ReadArray2D methods.  | 
134  | 0  |     inline int32_t d1() const { return d1_; } | 
135  | 0  |     inline int32_t d2() const { return d2_; } | 
136  | 0  |     float get(int32_t i, int32_t j) const { | 
137  | 0  |         U_ASSERT(i < d1_);  | 
138  | 0  |         U_ASSERT(j < d2_);  | 
139  | 0  |         return data_[i * d2_ + j];  | 
140  | 0  |     }  | 
141  |  |  | 
142  |  |     // Expose the ith row as a ConstArray1D  | 
143  | 0  |     inline ConstArray1D row(int32_t i) const { | 
144  | 0  |         U_ASSERT(i < d1_);  | 
145  | 0  |         return ConstArray1D(data_ + i * d2_, d2_);  | 
146  | 0  |     }  | 
147  |  |  | 
148  |  | private:  | 
149  |  |     const float* data_;  | 
150  |  |     int32_t d1_;  | 
151  |  |     int32_t d2_;  | 
152  |  | };  | 
153  |  |  | 
154  |  | ConstArray2D::~ConstArray2D()  | 
155  |  | { | 
156  |  | }  | 
157  |  |  | 
158  |  | /**  | 
159  |  |  * A class to allocate data as a writable 1D array.  | 
160  |  |  * This is the main class implement matrix operation.  | 
161  |  |  */  | 
162  |  | class Array1D : public ReadArray1D { | 
163  |  | public:  | 
164  | 0  |     Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {} | 
165  |  |     Array1D(int32_t d1, UErrorCode &status)  | 
166  | 0  |         : memory_(uprv_malloc(d1 * sizeof(float))),  | 
167  | 0  |           data_((float*)memory_), d1_(d1) { | 
168  | 0  |         if (U_SUCCESS(status)) { | 
169  | 0  |             if (memory_ == nullptr) { | 
170  | 0  |                 status = U_MEMORY_ALLOCATION_ERROR;  | 
171  | 0  |                 return;  | 
172  | 0  |             }  | 
173  | 0  |             clear();  | 
174  | 0  |         }  | 
175  | 0  |     }  | 
176  |  |  | 
177  |  |     virtual ~Array1D();  | 
178  |  |  | 
179  |  |     // A special constructor which does not own the memory but writeable  | 
180  |  |     // as a slice of an array.  | 
181  |  |     Array1D(float* data, int32_t d1)  | 
182  | 0  |         : memory_(nullptr), data_(data), d1_(d1) {} | 
183  |  |  | 
184  |  |     // ReadArray1D methods.  | 
185  | 0  |     virtual int32_t d1() const { return d1_; } | 
186  | 0  |     virtual float get(int32_t i) const { | 
187  | 0  |         U_ASSERT(i < d1_);  | 
188  | 0  |         return data_[i];  | 
189  | 0  |     }  | 
190  |  |  | 
191  |  |     // Return the index which point to the max data in the array.  | 
192  | 0  |     inline int32_t maxIndex() const { | 
193  | 0  |         int32_t index = 0;  | 
194  | 0  |         float max = data_[0];  | 
195  | 0  |         for (int32_t i = 1; i < d1_; i++) { | 
196  | 0  |             if (data_[i] > max) { | 
197  | 0  |                 max = data_[i];  | 
198  | 0  |                 index = i;  | 
199  | 0  |             }  | 
200  | 0  |         }  | 
201  | 0  |         return index;  | 
202  | 0  |     }  | 
203  |  |  | 
204  |  |     // Slice part of the array to a new one.  | 
205  | 0  |     inline Array1D slice(int32_t from, int32_t size) const { | 
206  | 0  |         U_ASSERT(from >= 0);  | 
207  | 0  |         U_ASSERT(from < d1_);  | 
208  | 0  |         U_ASSERT(from + size <= d1_);  | 
209  | 0  |         return Array1D(data_ + from, size);  | 
210  | 0  |     }  | 
211  |  |  | 
212  |  |     // Add dot product of a 1D array and a 2D array into this one.  | 
213  | 0  |     inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) { | 
214  | 0  |         U_ASSERT(a.d1() == b.d1());  | 
215  | 0  |         U_ASSERT(b.d2() == d1());  | 
216  | 0  |         for (int32_t i = 0; i < d1(); i++) { | 
217  | 0  |             for (int32_t j = 0; j < a.d1(); j++) { | 
218  | 0  |                 data_[i] += a.get(j) * b.get(j, i);  | 
219  | 0  |             }  | 
220  | 0  |         }  | 
221  | 0  |         return *this;  | 
222  | 0  |     }  | 
223  |  |  | 
224  |  |     // Hadamard Product the values of another array of the same size into this one.  | 
225  | 0  |     inline Array1D& hadamardProduct(const ReadArray1D& a) { | 
226  | 0  |         U_ASSERT(a.d1() == d1());  | 
227  | 0  |         for (int32_t i = 0; i < d1(); i++) { | 
228  | 0  |             data_[i] *= a.get(i);  | 
229  | 0  |         }  | 
230  | 0  |         return *this;  | 
231  | 0  |     }  | 
232  |  |  | 
233  |  |     // Add the Hadamard Product of two arrays of the same size into this one.  | 
234  | 0  |     inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) { | 
235  | 0  |         U_ASSERT(a.d1() == d1());  | 
236  | 0  |         U_ASSERT(b.d1() == d1());  | 
237  | 0  |         for (int32_t i = 0; i < d1(); i++) { | 
238  | 0  |             data_[i] += a.get(i) * b.get(i);  | 
239  | 0  |         }  | 
240  | 0  |         return *this;  | 
241  | 0  |     }  | 
242  |  |  | 
243  |  |     // Add the values of another array of the same size into this one.  | 
244  | 0  |     inline Array1D& add(const ReadArray1D& a) { | 
245  | 0  |         U_ASSERT(a.d1() == d1());  | 
246  | 0  |         for (int32_t i = 0; i < d1(); i++) { | 
247  | 0  |             data_[i] += a.get(i);  | 
248  | 0  |         }  | 
249  | 0  |         return *this;  | 
250  | 0  |     }  | 
251  |  |  | 
252  |  |     // Assign the values of another array of the same size into this one.  | 
253  | 0  |     inline Array1D& assign(const ReadArray1D& a) { | 
254  | 0  |         U_ASSERT(a.d1() == d1());  | 
255  | 0  |         for (int32_t i = 0; i < d1(); i++) { | 
256  | 0  |             data_[i] = a.get(i);  | 
257  | 0  |         }  | 
258  | 0  |         return *this;  | 
259  | 0  |     }  | 
260  |  |  | 
261  |  |     // Apply tanh to all the elements in the array.  | 
262  | 0  |     inline Array1D& tanh() { | 
263  | 0  |         return tanh(*this);  | 
264  | 0  |     }  | 
265  |  |  | 
266  |  |     // Apply tanh of a and store into this array.  | 
267  | 0  |     inline Array1D& tanh(const Array1D& a) { | 
268  | 0  |         U_ASSERT(a.d1() == d1());  | 
269  | 0  |         for (int32_t i = 0; i < d1_; i++) { | 
270  | 0  |             data_[i] = std::tanh(a.get(i));  | 
271  | 0  |         }  | 
272  | 0  |         return *this;  | 
273  | 0  |     }  | 
274  |  |  | 
275  |  |     // Apply sigmoid to all the elements in the array.  | 
276  | 0  |     inline Array1D& sigmoid() { | 
277  | 0  |         for (int32_t i = 0; i < d1_; i++) { | 
278  | 0  |             data_[i] = 1.0f/(1.0f + expf(-data_[i]));  | 
279  | 0  |         }  | 
280  | 0  |         return *this;  | 
281  | 0  |     }  | 
282  |  |  | 
283  | 0  |     inline Array1D& clear() { | 
284  | 0  |         uprv_memset(data_, 0, d1_ * sizeof(float));  | 
285  | 0  |         return *this;  | 
286  | 0  |     }  | 
287  |  |  | 
288  |  | private:  | 
289  |  |     void* memory_;  | 
290  |  |     float* data_;  | 
291  |  |     int32_t d1_;  | 
292  |  | };  | 
293  |  |  | 
294  |  | Array1D::~Array1D()  | 
295  | 0  | { | 
296  | 0  |     uprv_free(memory_);  | 
297  | 0  | }  | 
298  |  |  | 
299  |  | class Array2D : public ReadArray2D { | 
300  |  | public:  | 
301  | 0  |     Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {} | 
302  |  |     Array2D(int32_t d1, int32_t d2, UErrorCode &status)  | 
303  | 0  |         : memory_(uprv_malloc(d1 * d2 * sizeof(float))),  | 
304  | 0  |           data_((float*)memory_), d1_(d1), d2_(d2) { | 
305  | 0  |         if (U_SUCCESS(status)) { | 
306  | 0  |             if (memory_ == nullptr) { | 
307  | 0  |                 status = U_MEMORY_ALLOCATION_ERROR;  | 
308  | 0  |                 return;  | 
309  | 0  |             }  | 
310  | 0  |             clear();  | 
311  | 0  |         }  | 
312  | 0  |     }  | 
313  |  |     virtual ~Array2D();  | 
314  |  |  | 
315  |  |     // ReadArray2D methods.  | 
316  | 0  |     virtual int32_t d1() const { return d1_; } | 
317  | 0  |     virtual int32_t d2() const { return d2_; } | 
318  | 0  |     virtual float get(int32_t i, int32_t j) const { | 
319  | 0  |         U_ASSERT(i < d1_);  | 
320  | 0  |         U_ASSERT(j < d2_);  | 
321  | 0  |         return data_[i * d2_ + j];  | 
322  | 0  |     }  | 
323  |  |  | 
324  | 0  |     inline Array1D row(int32_t i) const { | 
325  | 0  |         U_ASSERT(i < d1_);  | 
326  | 0  |         return Array1D(data_ + i * d2_, d2_);  | 
327  | 0  |     }  | 
328  |  |  | 
329  | 0  |     inline Array2D& clear() { | 
330  | 0  |         uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));  | 
331  | 0  |         return *this;  | 
332  | 0  |     }  | 
333  |  |  | 
334  |  | private:  | 
335  |  |     void* memory_;  | 
336  |  |     float* data_;  | 
337  |  |     int32_t d1_;  | 
338  |  |     int32_t d2_;  | 
339  |  | };  | 
340  |  |  | 
341  |  | Array2D::~Array2D()  | 
342  | 0  | { | 
343  | 0  |     uprv_free(memory_);  | 
344  | 0  | }  | 
345  |  |  | 
// Per-unit output classes of the model (BIES tagging). The values are the
// output-layer positions that the argmax over the 4 scores is cast to.
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};

// How text is split into units before dictionary lookup. The kind is chosen
// from the "type" string of the model data; UNKNOWN means the string was not
// recognized.
enum EmbeddingType {
    UNKNOWN,
    CODE_POINTS,
    GRAPHEME_CLUSTER
};
358  |  |  | 
359  |  | struct LSTMData : public UMemory { | 
360  |  |     LSTMData(UResourceBundle* rb, UErrorCode &status);  | 
361  |  |     ~LSTMData();  | 
362  |  |     UHashtable* fDict;  | 
363  |  |     EmbeddingType fType;  | 
364  |  |     const UChar* fName;  | 
365  |  |     ConstArray2D fEmbedding;  | 
366  |  |     ConstArray2D fForwardW;  | 
367  |  |     ConstArray2D fForwardU;  | 
368  |  |     ConstArray1D fForwardB;  | 
369  |  |     ConstArray2D fBackwardW;  | 
370  |  |     ConstArray2D fBackwardU;  | 
371  |  |     ConstArray1D fBackwardB;  | 
372  |  |     ConstArray2D fOutputW;  | 
373  |  |     ConstArray1D fOutputB;  | 
374  |  |  | 
375  |  | private:  | 
376  |  |     UResourceBundle* fBundle;  | 
377  |  | };  | 
378  |  |  | 
379  |  | LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)  | 
380  | 0  |     : fDict(nullptr), fType(UNKNOWN), fName(nullptr),  | 
381  | 0  |       fBundle(rb)  | 
382  | 0  | { | 
383  | 0  |     if (U_FAILURE(status)) { | 
384  | 0  |         return;  | 
385  | 0  |     }  | 
386  | 0  |     if (IEEE_754 != 1) { | 
387  | 0  |         status = U_UNSUPPORTED_ERROR;  | 
388  | 0  |         return;  | 
389  | 0  |     }  | 
390  | 0  |     LocalUResourceBundlePointer embeddings_res(  | 
391  | 0  |         ures_getByKey(rb, "embeddings", nullptr, &status));  | 
392  | 0  |     int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);  | 
393  | 0  |     LocalUResourceBundlePointer hunits_res(  | 
394  | 0  |         ures_getByKey(rb, "hunits", nullptr, &status));  | 
395  | 0  |     if (U_FAILURE(status)) return;  | 
396  | 0  |     int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);  | 
397  | 0  |     const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);  | 
398  | 0  |     if (U_FAILURE(status)) return;  | 
399  | 0  |     if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) { | 
400  | 0  |         fType = CODE_POINTS;  | 
401  | 0  |     } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) { | 
402  | 0  |         fType = GRAPHEME_CLUSTER;  | 
403  | 0  |     }  | 
404  | 0  |     fName = ures_getStringByKey(rb, "model", nullptr, &status);  | 
405  | 0  |     LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));  | 
406  | 0  |     if (U_FAILURE(status)) return;  | 
407  | 0  |     int32_t data_len = 0;  | 
408  | 0  |     const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);  | 
409  | 0  |     fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);  | 
410  |  | 
  | 
411  | 0  |     StackUResourceBundle stackTempBundle;  | 
412  | 0  |     ResourceDataValue value;  | 
413  | 0  |     ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);  | 
414  | 0  |     ResourceArray stringArray = value.getArray(status);  | 
415  | 0  |     int32_t num_index = stringArray.getSize();  | 
416  | 0  |     if (U_FAILURE(status)) { return; } | 
417  |  |  | 
418  |  |     // put dict into hash  | 
419  | 0  |     int32_t stringLength;  | 
420  | 0  |     for (int32_t idx = 0; idx < num_index; idx++) { | 
421  | 0  |         stringArray.getValue(idx, value);  | 
422  | 0  |         const UChar* str = value.getString(stringLength, status);  | 
423  | 0  |         uhash_putiAllowZero(fDict, (void*)str, idx, &status);  | 
424  | 0  |         if (U_FAILURE(status)) return;  | 
425  |  | #ifdef LSTM_VECTORIZER_DEBUG  | 
426  |  |         printf("Assign ["); | 
427  |  |         while (*str != 0x0000) { | 
428  |  |             printf("U+%04x ", *str); | 
429  |  |             str++;  | 
430  |  |         }  | 
431  |  |         printf("] map to %d\n", idx-1); | 
432  |  | #endif  | 
433  | 0  |     }  | 
434  | 0  |     int32_t mat1_size = (num_index + 1) * embedding_size;  | 
435  | 0  |     int32_t mat2_size = embedding_size * 4 * hunits;  | 
436  | 0  |     int32_t mat3_size = hunits * 4 * hunits;  | 
437  | 0  |     int32_t mat4_size = 4 * hunits;  | 
438  | 0  |     int32_t mat5_size = mat2_size;  | 
439  | 0  |     int32_t mat6_size = mat3_size;  | 
440  | 0  |     int32_t mat7_size = mat4_size;  | 
441  | 0  |     int32_t mat8_size = 2 * hunits * 4;  | 
442  |  | #if U_DEBUG  | 
443  |  |     int32_t mat9_size = 4;  | 
444  |  |     U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +  | 
445  |  |         mat6_size + mat7_size + mat8_size + mat9_size);  | 
446  |  | #endif  | 
447  |  | 
  | 
448  | 0  |     fEmbedding.init(data, (num_index + 1), embedding_size);  | 
449  | 0  |     data += mat1_size;  | 
450  | 0  |     fForwardW.init(data, embedding_size, 4 * hunits);  | 
451  | 0  |     data += mat2_size;  | 
452  | 0  |     fForwardU.init(data, hunits, 4 * hunits);  | 
453  | 0  |     data += mat3_size;  | 
454  | 0  |     fForwardB.init(data, 4 * hunits);  | 
455  | 0  |     data += mat4_size;  | 
456  | 0  |     fBackwardW.init(data, embedding_size, 4 * hunits);  | 
457  | 0  |     data += mat5_size;  | 
458  | 0  |     fBackwardU.init(data, hunits, 4 * hunits);  | 
459  | 0  |     data += mat6_size;  | 
460  | 0  |     fBackwardB.init(data, 4 * hunits);  | 
461  | 0  |     data += mat7_size;  | 
462  | 0  |     fOutputW.init(data, 2 * hunits, 4);  | 
463  | 0  |     data += mat8_size;  | 
464  | 0  |     fOutputB.init(data, 4);  | 
465  | 0  | }  | 
466  |  |  | 
467  | 0  | LSTMData::~LSTMData() { | 
468  | 0  |     uhash_close(fDict);  | 
469  | 0  |     ures_close(fBundle);  | 
470  | 0  | }  | 
471  |  |  | 
472  |  | class Vectorizer : public UMemory { | 
473  |  | public:  | 
474  | 0  |     Vectorizer(UHashtable* dict) : fDict(dict) {} | 
475  |  |     virtual ~Vectorizer();  | 
476  |  |     virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,  | 
477  |  |                            UVector32 &offsets, UVector32 &indices,  | 
478  |  |                            UErrorCode &status) const = 0;  | 
479  |  | protected:  | 
480  | 0  |     int32_t stringToIndex(const UChar* str) const { | 
481  | 0  |         UBool found = false;  | 
482  | 0  |         int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);  | 
483  | 0  |         if (!found) { | 
484  | 0  |             ret = fDict->count;  | 
485  | 0  |         }  | 
486  |  | #ifdef LSTM_VECTORIZER_DEBUG  | 
487  |  |         printf("["); | 
488  |  |         while (*str != 0x0000) { | 
489  |  |             printf("U+%04x ", *str); | 
490  |  |             str++;  | 
491  |  |         }  | 
492  |  |         printf("] map to %d\n", ret); | 
493  |  | #endif  | 
494  | 0  |         return ret;  | 
495  | 0  |     }  | 
496  |  |  | 
497  |  | private:  | 
498  |  |     UHashtable* fDict;  | 
499  |  | };  | 
500  |  |  | 
501  |  | Vectorizer::~Vectorizer()  | 
502  | 0  | { | 
503  | 0  | }  | 
504  |  |  | 
505  |  | class CodePointsVectorizer : public Vectorizer { | 
506  |  | public:  | 
507  | 0  |     CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {} | 
508  |  |     virtual ~CodePointsVectorizer();  | 
509  |  |     virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,  | 
510  |  |                            UVector32 &offsets, UVector32 &indices,  | 
511  |  |                            UErrorCode &status) const;  | 
512  |  | };  | 
513  |  |  | 
514  |  | CodePointsVectorizer::~CodePointsVectorizer()  | 
515  |  | { | 
516  |  | }  | 
517  |  |  | 
518  |  | void CodePointsVectorizer::vectorize(  | 
519  |  |     UText *text, int32_t startPos, int32_t endPos,  | 
520  |  |     UVector32 &offsets, UVector32 &indices, UErrorCode &status) const  | 
521  | 0  | { | 
522  | 0  |     if (offsets.ensureCapacity(endPos - startPos, status) &&  | 
523  | 0  |             indices.ensureCapacity(endPos - startPos, status)) { | 
524  | 0  |         if (U_FAILURE(status)) return;  | 
525  | 0  |         utext_setNativeIndex(text, startPos);  | 
526  | 0  |         int32_t current;  | 
527  | 0  |         UChar str[2] = {0, 0}; | 
528  | 0  |         while (U_SUCCESS(status) &&  | 
529  | 0  |                (current = (int32_t)utext_getNativeIndex(text)) < endPos) { | 
530  |  |             // Since the LSTMBreakEngine is currently only accept chars in BMP,  | 
531  |  |             // we can ignore the possibility of hitting supplementary code  | 
532  |  |             // point.  | 
533  | 0  |             str[0] = (UChar) utext_next32(text);  | 
534  | 0  |             U_ASSERT(!U_IS_SURROGATE(str[0]));  | 
535  | 0  |             offsets.addElement(current, status);  | 
536  | 0  |             indices.addElement(stringToIndex(str), status);  | 
537  | 0  |         }  | 
538  | 0  |     }  | 
539  | 0  | }  | 
540  |  |  | 
541  |  | class GraphemeClusterVectorizer : public Vectorizer { | 
542  |  | public:  | 
543  |  |     GraphemeClusterVectorizer(UHashtable* dict)  | 
544  | 0  |         : Vectorizer(dict)  | 
545  | 0  |     { | 
546  | 0  |     }  | 
547  |  |     virtual ~GraphemeClusterVectorizer();  | 
548  |  |     virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,  | 
549  |  |                            UVector32 &offsets, UVector32 &indices,  | 
550  |  |                            UErrorCode &status) const;  | 
551  |  | };  | 
552  |  |  | 
553  |  | GraphemeClusterVectorizer::~GraphemeClusterVectorizer()  | 
554  |  | { | 
555  |  | }  | 
556  |  |  | 
557  |  | constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;  | 
558  |  |  | 
559  |  | void GraphemeClusterVectorizer::vectorize(  | 
560  |  |     UText *text, int32_t startPos, int32_t endPos,  | 
561  |  |     UVector32 &offsets, UVector32 &indices, UErrorCode &status) const  | 
562  | 0  | { | 
563  | 0  |     if (U_FAILURE(status)) return;  | 
564  | 0  |     if (!offsets.ensureCapacity(endPos - startPos, status) ||  | 
565  | 0  |             !indices.ensureCapacity(endPos - startPos, status)) { | 
566  | 0  |         return;  | 
567  | 0  |     }  | 
568  | 0  |     if (U_FAILURE(status)) return;  | 
569  | 0  |     LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));  | 
570  | 0  |     if (U_FAILURE(status)) return;  | 
571  | 0  |     graphemeIter->setText(text, status);  | 
572  | 0  |     if (U_FAILURE(status)) return;  | 
573  |  |  | 
574  | 0  |     if (startPos != 0) { | 
575  | 0  |         graphemeIter->preceding(startPos);  | 
576  | 0  |     }  | 
577  | 0  |     int32_t last = startPos;  | 
578  | 0  |     int32_t current = startPos;  | 
579  | 0  |     UChar str[MAX_GRAPHEME_CLSTER_LENGTH];  | 
580  | 0  |     while ((current = graphemeIter->next()) != BreakIterator::DONE) { | 
581  | 0  |         if (current >= endPos) { | 
582  | 0  |             break;  | 
583  | 0  |         }  | 
584  | 0  |         if (current > startPos) { | 
585  | 0  |             utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);  | 
586  | 0  |             if (U_FAILURE(status)) return;  | 
587  | 0  |             offsets.addElement(last, status);  | 
588  | 0  |             indices.addElement(stringToIndex(str), status);  | 
589  | 0  |             if (U_FAILURE(status)) return;  | 
590  | 0  |         }  | 
591  | 0  |         last = current;  | 
592  | 0  |     }  | 
593  | 0  |     if (U_FAILURE(status) || last >= endPos) { | 
594  | 0  |         return;  | 
595  | 0  |     }  | 
596  | 0  |     utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);  | 
597  | 0  |     if (U_SUCCESS(status)) { | 
598  | 0  |         offsets.addElement(last, status);  | 
599  | 0  |         indices.addElement(stringToIndex(str), status);  | 
600  | 0  |     }  | 
601  | 0  | }  | 
602  |  |  | 
603  |  | // Computing LSTM as stated in  | 
604  |  | // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate  | 
605  |  | // ifco is temp array allocate outside which does not need to be  | 
606  |  | // input/output value but could avoid unnecessary memory alloc/free if passing  | 
607  |  | // in.  | 
608  |  | void compute(  | 
609  |  |     int32_t hunits,  | 
610  |  |     const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,  | 
611  |  |     const ReadArray1D& x, Array1D& h, Array1D& c,  | 
612  |  |     Array1D& ifco)  | 
613  | 0  | { | 
614  |  |     // ifco = x * W + h * U + b  | 
615  | 0  |     ifco.assign(b)  | 
616  | 0  |         .addDotProduct(x, W)  | 
617  | 0  |         .addDotProduct(h, U);  | 
618  |  | 
  | 
619  | 0  |     ifco.slice(0*hunits, hunits).sigmoid();  // i: sigmod  | 
620  | 0  |     ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid  | 
621  | 0  |     ifco.slice(2*hunits, hunits).tanh(); // c_: tanh  | 
622  | 0  |     ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod  | 
623  |  | 
  | 
624  | 0  |     c.hadamardProduct(ifco.slice(hunits, hunits))  | 
625  | 0  |         .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));  | 
626  |  | 
  | 
627  | 0  |     h.tanh(c)  | 
628  | 0  |         .hadamardProduct(ifco.slice(3*hunits, hunits));  | 
629  | 0  | }  | 
630  |  |  | 
// Minimum word length.
static constexpr int32_t MIN_WORD = 2;

// Minimum number of code points needed for two words. Shorter ranges are not
// divided.
static constexpr int32_t MIN_WORD_SPAN = MIN_WORD * 2;
636  |  |  | 
637  |  | int32_t  | 
638  |  | LSTMBreakEngine::divideUpDictionaryRange( UText *text,  | 
639  |  |                                                 int32_t startPos,  | 
640  |  |                                                 int32_t endPos,  | 
641  |  |                                                 UVector32 &foundBreaks,  | 
642  | 0  |                                                 UErrorCode& status) const { | 
643  | 0  |     if (U_FAILURE(status)) return 0;  | 
644  | 0  |     int32_t beginFoundBreakSize = foundBreaks.size();  | 
645  | 0  |     utext_setNativeIndex(text, startPos);  | 
646  | 0  |     utext_moveIndex32(text, MIN_WORD_SPAN);  | 
647  | 0  |     if (utext_getNativeIndex(text) >= endPos) { | 
648  | 0  |         return 0;       // Not enough characters for two words  | 
649  | 0  |     }  | 
650  | 0  |     utext_setNativeIndex(text, startPos);  | 
651  |  | 
  | 
652  | 0  |     UVector32 offsets(status);  | 
653  | 0  |     UVector32 indices(status);  | 
654  | 0  |     if (U_FAILURE(status)) return 0;  | 
655  | 0  |     fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);  | 
656  | 0  |     if (U_FAILURE(status)) return 0;  | 
657  | 0  |     int32_t* offsetsBuf = offsets.getBuffer();  | 
658  | 0  |     int32_t* indicesBuf = indices.getBuffer();  | 
659  |  | 
  | 
660  | 0  |     int32_t input_seq_len = indices.size();  | 
661  | 0  |     int32_t hunits = fData->fForwardU.d1();  | 
662  |  |  | 
663  |  |     // ----- Begin of all the Array memory allocation needed for this function  | 
664  |  |     // Allocate temp array used inside compute()  | 
665  | 0  |     Array1D ifco(4 * hunits, status);  | 
666  |  | 
  | 
667  | 0  |     Array1D c(hunits, status);  | 
668  | 0  |     Array1D logp(4, status);  | 
669  |  |  | 
670  |  |     // TODO: limit size of hBackward. If input_seq_len is too big, we could  | 
671  |  |     // run out of memory.  | 
672  |  |     // Backward LSTM  | 
673  | 0  |     Array2D hBackward(input_seq_len, hunits, status);  | 
674  |  |  | 
675  |  |     // Allocate fbRow and slice the internal array in two.  | 
676  | 0  |     Array1D fbRow(2 * hunits, status);  | 
677  |  |  | 
678  |  |     // ----- End of all the Array memory allocation needed for this function  | 
679  | 0  |     if (U_FAILURE(status)) return 0;  | 
680  |  |  | 
681  |  |     // To save the needed memory usage, the following is different from the  | 
682  |  |     // Python or ICU4X implementation. We first perform the Backward LSTM  | 
683  |  |     // and then merge the iteration of the forward LSTM and the output layer  | 
684  |  |     // together because we only neetdto remember the h[t-1] for Forward LSTM.  | 
685  | 0  |     for (int32_t i = input_seq_len - 1; i >= 0; i--) { | 
686  | 0  |         Array1D hRow = hBackward.row(i);  | 
687  | 0  |         if (i != input_seq_len - 1) { | 
688  | 0  |             hRow.assign(hBackward.row(i+1));  | 
689  | 0  |         }  | 
690  |  | #ifdef LSTM_DEBUG  | 
691  |  |         printf("hRow %d\n", i); | 
692  |  |         hRow.print();  | 
693  |  |         printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]); | 
694  |  |         printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i); | 
695  |  |         fData->fEmbedding.row(indicesBuf[i]).print();  | 
696  |  | #endif  // LSTM_DEBUG  | 
697  | 0  |         compute(hunits,  | 
698  | 0  |                 fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,  | 
699  | 0  |                 fData->fEmbedding.row(indicesBuf[i]),  | 
700  | 0  |                 hRow, c, ifco);  | 
701  | 0  |     }  | 
702  |  |  | 
703  |  | 
  | 
704  | 0  |     Array1D forwardRow = fbRow.slice(0, hunits);  // point to first half of data in fbRow.  | 
705  | 0  |     Array1D backwardRow = fbRow.slice(hunits, hunits);  // point to second half of data n fbRow.  | 
706  |  |  | 
707  |  |     // The following iteration merge the forward LSTM and the output layer  | 
708  |  |     // together.  | 
709  | 0  |     c.clear();  // reuse c since it is the same size.  | 
710  | 0  |     for (int32_t i = 0; i < input_seq_len; i++) { | 
711  |  | #ifdef LSTM_DEBUG  | 
712  |  |         printf("forwardRow %d\n", i); | 
713  |  |         forwardRow.print();  | 
714  |  | #endif  // LSTM_DEBUG  | 
715  |  |         // Forward LSTM  | 
716  |  |         // Calculate the result into forwardRow, which point to the data in the first half  | 
717  |  |         // of fbRow.  | 
718  | 0  |         compute(hunits,  | 
719  | 0  |                 fData->fForwardW, fData->fForwardU, fData->fForwardB,  | 
720  | 0  |                 fData->fEmbedding.row(indicesBuf[i]),  | 
721  | 0  |                 forwardRow, c, ifco);  | 
722  |  |  | 
723  |  |         // assign the data from hBackward.row(i) to second half of fbRowa.  | 
724  | 0  |         backwardRow.assign(hBackward.row(i));  | 
725  |  | 
  | 
726  | 0  |         logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);  | 
727  |  | #ifdef LSTM_DEBUG  | 
728  |  |         printf("backwardRow %d\n", i); | 
729  |  |         backwardRow.print();  | 
730  |  |         printf("logp %d\n", i); | 
731  |  |         logp.print();  | 
732  |  | #endif  // LSTM_DEBUG  | 
733  |  |  | 
734  |  |         // current = argmax(logp)  | 
735  | 0  |         LSTMClass current = (LSTMClass)logp.maxIndex();  | 
736  |  |         // BIES logic.  | 
737  | 0  |         if (current == BEGIN || current == SINGLE) { | 
738  | 0  |             if (i != 0) { | 
739  | 0  |                 foundBreaks.addElement(offsetsBuf[i], status);  | 
740  | 0  |                 if (U_FAILURE(status)) return 0;  | 
741  | 0  |             }  | 
742  | 0  |         }  | 
743  | 0  |     }  | 
744  | 0  |     return foundBreaks.size() - beginFoundBreakSize;  | 
745  | 0  | }  | 
746  |  |  | 
747  | 0  | Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) { | 
748  | 0  |     if (U_FAILURE(status)) { | 
749  | 0  |         return nullptr;  | 
750  | 0  |     }  | 
751  | 0  |     switch (data->fType) { | 
752  | 0  |         case CODE_POINTS:  | 
753  | 0  |             return new CodePointsVectorizer(data->fDict);  | 
754  | 0  |             break;  | 
755  | 0  |         case GRAPHEME_CLUSTER:  | 
756  | 0  |             return new GraphemeClusterVectorizer(data->fDict);  | 
757  | 0  |             break;  | 
758  | 0  |         default:  | 
759  | 0  |             break;  | 
760  | 0  |     }  | 
761  | 0  |     UPRV_UNREACHABLE;  | 
762  | 0  | }  | 
763  |  |  | 
764  |  | LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)  | 
765  | 0  |     : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))  | 
766  | 0  | { | 
767  | 0  |     if (U_FAILURE(status)) { | 
768  | 0  |       fData = nullptr;  // If failure, we should not delete fData in destructor because the caller will do so.  | 
769  | 0  |       return;  | 
770  | 0  |     }  | 
771  | 0  |     setCharacters(set);  | 
772  | 0  | }  | 
773  |  |  | 
774  | 0  | LSTMBreakEngine::~LSTMBreakEngine() { | 
775  | 0  |     delete fData;  | 
776  | 0  |     delete fVectorizer;  | 
777  | 0  | }  | 
778  |  |  | 
779  | 0  | const UChar* LSTMBreakEngine::name() const { | 
780  | 0  |     return fData->fName;  | 
781  | 0  | }  | 
782  |  |  | 
783  | 0  | UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) { | 
784  |  |     // open root from brkitr tree.  | 
785  | 0  |     UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);  | 
786  | 0  |     b = ures_getByKeyWithFallback(b, "lstm", b, &status);  | 
787  | 0  |     UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);  | 
788  | 0  |     ures_close(b);  | 
789  | 0  |     return result;  | 
790  | 0  | }  | 
791  |  |  | 
792  |  | U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)  | 
793  | 0  | { | 
794  | 0  |     if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) { | 
795  | 0  |         return nullptr;  | 
796  | 0  |     }  | 
797  | 0  |     UnicodeString name = defaultLSTM(script, status);  | 
798  | 0  |     if (U_FAILURE(status)) return nullptr;  | 
799  | 0  |     CharString namebuf;  | 
800  | 0  |     namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.')); | 
801  |  | 
  | 
802  | 0  |     LocalUResourceBundlePointer rb(  | 
803  | 0  |         ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));  | 
804  | 0  |     if (U_FAILURE(status)) return nullptr;  | 
805  |  |  | 
806  | 0  |     return CreateLSTMData(rb.orphan(), status);  | 
807  | 0  | }  | 
808  |  |  | 
809  |  | U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)  | 
810  | 0  | { | 
811  | 0  |     return new LSTMData(rb, status);  | 
812  | 0  | }  | 
813  |  |  | 
814  |  | U_CAPI const LanguageBreakEngine* U_EXPORT2  | 
815  |  | CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)  | 
816  | 0  | { | 
817  | 0  |     UnicodeString unicodeSetString;  | 
818  | 0  |     switch(script) { | 
819  | 0  |         case USCRIPT_THAI:  | 
820  | 0  |             unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");  | 
821  | 0  |             break;  | 
822  | 0  |         case USCRIPT_MYANMAR:  | 
823  | 0  |             unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");  | 
824  | 0  |             break;  | 
825  | 0  |         default:  | 
826  | 0  |             delete data;  | 
827  | 0  |             return nullptr;  | 
828  | 0  |     }  | 
829  | 0  |     UnicodeSet unicodeSet;  | 
830  | 0  |     unicodeSet.applyPattern(unicodeSetString, status);  | 
831  | 0  |     const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);  | 
832  | 0  |     if (U_FAILURE(status) || engine == nullptr) { | 
833  | 0  |         if (engine != nullptr) { | 
834  | 0  |             delete engine;  | 
835  | 0  |         } else { | 
836  | 0  |             status = U_MEMORY_ALLOCATION_ERROR;  | 
837  | 0  |         }  | 
838  | 0  |         return nullptr;  | 
839  | 0  |     }  | 
840  | 0  |     return engine;  | 
841  | 0  | }  | 
842  |  |  | 
843  |  | U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)  | 
844  | 0  | { | 
845  | 0  |     delete data;  | 
846  | 0  | }  | 
847  |  |  | 
848  |  | U_CAPI const UChar* U_EXPORT2 LSTMDataName(const LSTMData* data)  | 
849  | 0  | { | 
850  | 0  |     return data->fName;  | 
851  | 0  | }  | 
852  |  |  | 
853  |  | U_NAMESPACE_END  | 
854  |  |  | 
855  |  | #endif /* #if !UCONFIG_NO_BREAK_ITERATION */  |