Coverage Report

Created: 2025-06-24 06:43

/src/icu/source/common/lstmbe.cpp
Line
Count
Source (jump to first uncovered line)
1
// © 2021 and later: Unicode, Inc. and others.
2
// License & terms of use: http://www.unicode.org/copyright.html
3
4
#include <utility>
5
#include <ctgmath>
6
7
#include "unicode/utypes.h"
8
9
#if !UCONFIG_NO_BREAK_ITERATION
10
11
#include "brkeng.h"
12
#include "charstr.h"
13
#include "cmemory.h"
14
#include "lstmbe.h"
15
#include "putilimp.h"
16
#include "uassert.h"
17
#include "ubrkimpl.h"
18
#include "uresimp.h"
19
#include "uvectr32.h"
20
#include "uvector.h"
21
22
#include "unicode/brkiter.h"
23
#include "unicode/resbund.h"
24
#include "unicode/ubrk.h"
25
#include "unicode/uniset.h"
26
#include "unicode/ustring.h"
27
#include "unicode/utf.h"
28
29
U_NAMESPACE_BEGIN
30
31
// Uncomment the following #define to debug.
32
// #define LSTM_DEBUG 1
33
// #define LSTM_VECTORIZER_DEBUG 1
34
35
/**
 * Read-only view of a one-dimensional array of floats.
 *
 * Implementations report their length through d1() and expose individual
 * elements through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    // Number of elements in the array.
    virtual int32_t d1() const = 0;

    // Value of the element at index i.
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug helper: writes every element to stdout, starting a new output
    // line after each fourth value.
    void print() const {
        printf("\n[");
        int32_t pos = 0;
        while (pos < d1()) {
            printf("%0.8e ", get(pos));
            if ((pos % 4) == 3) {
                printf("\n");
            }
            pos++;
        }
        printf("]\n");
    }
#endif
};

// Out-of-line definition of the virtual destructor; nothing to release.
ReadArray1D::~ReadArray1D()
{
}
59
60
/**
 * Read-only view of a two-dimensional array of floats.
 *
 * d1() and d2() report the extent of the first and second index
 * respectively; get(i, j) returns a single element.
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    // Extent of the first index.
    virtual int32_t d1() const = 0;

    // Extent of the second index.
    virtual int32_t d2() const = 0;

    // Value of the element at (i, j).
    virtual float get(int32_t i, int32_t j) const = 0;
};

// Out-of-line definition of the virtual destructor; nothing to release.
ReadArray2D::~ReadArray2D()
{
}
74
75
/**
76
 * A class to index a float array as a 1D Array without owning the pointer or
77
 * copy the data.
78
 */
79
class ConstArray1D : public ReadArray1D {
80
public:
81
0
    ConstArray1D() : data_(nullptr), d1_(0) {}
82
83
0
    ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
84
85
    virtual ~ConstArray1D();
86
87
    // Init the object, the object does not own the data nor copy.
88
    // It is designed to directly use data from memory mapped resources.
89
0
    void init(const int32_t* data, int32_t d1) {
90
0
        U_ASSERT(IEEE_754 == 1);
91
0
        data_ = reinterpret_cast<const float*>(data);
92
0
        d1_ = d1;
93
0
    }
94
95
    // ReadArray1D methods.
96
0
    virtual int32_t d1() const { return d1_; }
97
0
    virtual float get(int32_t i) const {
98
0
        U_ASSERT(i < d1_);
99
0
        return data_[i];
100
0
    }
101
102
private:
103
    const float* data_;
104
    int32_t d1_;
105
};
106
107
ConstArray1D::~ConstArray1D()
108
{
109
}
110
111
/**
112
 * A class to index a float array as a 2D Array without owning the pointer or
113
 * copy the data.
114
 */
115
class ConstArray2D : public ReadArray2D {
116
public:
117
0
    ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
118
119
    ConstArray2D(const float* data, int32_t d1, int32_t d2)
120
0
        : data_(data), d1_(d1), d2_(d2) {}
121
122
    virtual ~ConstArray2D();
123
124
    // Init the object, the object does not own the data nor copy.
125
    // It is designed to directly use data from memory mapped resources.
126
0
    void init(const int32_t* data, int32_t d1, int32_t d2) {
127
0
        U_ASSERT(IEEE_754 == 1);
128
0
        data_ = reinterpret_cast<const float*>(data);
129
0
        d1_ = d1;
130
0
        d2_ = d2;
131
0
    }
132
133
    // ReadArray2D methods.
134
0
    inline int32_t d1() const { return d1_; }
135
0
    inline int32_t d2() const { return d2_; }
136
0
    float get(int32_t i, int32_t j) const {
137
0
        U_ASSERT(i < d1_);
138
0
        U_ASSERT(j < d2_);
139
0
        return data_[i * d2_ + j];
140
0
    }
141
142
    // Expose the ith row as a ConstArray1D
143
0
    inline ConstArray1D row(int32_t i) const {
144
0
        U_ASSERT(i < d1_);
145
0
        return ConstArray1D(data_ + i * d2_, d2_);
146
0
    }
147
148
private:
149
    const float* data_;
150
    int32_t d1_;
151
    int32_t d2_;
152
};
153
154
ConstArray2D::~ConstArray2D()
155
{
156
}
157
158
/**
159
 * A class to allocate data as a writable 1D array.
160
 * This is the main class implement matrix operation.
161
 */
162
class Array1D : public ReadArray1D {
163
public:
164
0
    Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
165
    Array1D(int32_t d1, UErrorCode &status)
166
0
        : memory_(uprv_malloc(d1 * sizeof(float))),
167
0
          data_((float*)memory_), d1_(d1) {
168
0
        if (U_SUCCESS(status)) {
169
0
            if (memory_ == nullptr) {
170
0
                status = U_MEMORY_ALLOCATION_ERROR;
171
0
                return;
172
0
            }
173
0
            clear();
174
0
        }
175
0
    }
176
177
    virtual ~Array1D();
178
179
    // A special constructor which does not own the memory but writeable
180
    // as a slice of an array.
181
    Array1D(float* data, int32_t d1)
182
0
        : memory_(nullptr), data_(data), d1_(d1) {}
183
184
    // ReadArray1D methods.
185
0
    virtual int32_t d1() const { return d1_; }
186
0
    virtual float get(int32_t i) const {
187
0
        U_ASSERT(i < d1_);
188
0
        return data_[i];
189
0
    }
190
191
    // Return the index which point to the max data in the array.
192
0
    inline int32_t maxIndex() const {
193
0
        int32_t index = 0;
194
0
        float max = data_[0];
195
0
        for (int32_t i = 1; i < d1_; i++) {
196
0
            if (data_[i] > max) {
197
0
                max = data_[i];
198
0
                index = i;
199
0
            }
200
0
        }
201
0
        return index;
202
0
    }
203
204
    // Slice part of the array to a new one.
205
0
    inline Array1D slice(int32_t from, int32_t size) const {
206
0
        U_ASSERT(from >= 0);
207
0
        U_ASSERT(from < d1_);
208
0
        U_ASSERT(from + size <= d1_);
209
0
        return Array1D(data_ + from, size);
210
0
    }
211
212
    // Add dot product of a 1D array and a 2D array into this one.
213
0
    inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
214
0
        U_ASSERT(a.d1() == b.d1());
215
0
        U_ASSERT(b.d2() == d1());
216
0
        for (int32_t i = 0; i < d1(); i++) {
217
0
            for (int32_t j = 0; j < a.d1(); j++) {
218
0
                data_[i] += a.get(j) * b.get(j, i);
219
0
            }
220
0
        }
221
0
        return *this;
222
0
    }
223
224
    // Hadamard Product the values of another array of the same size into this one.
225
0
    inline Array1D& hadamardProduct(const ReadArray1D& a) {
226
0
        U_ASSERT(a.d1() == d1());
227
0
        for (int32_t i = 0; i < d1(); i++) {
228
0
            data_[i] *= a.get(i);
229
0
        }
230
0
        return *this;
231
0
    }
232
233
    // Add the Hadamard Product of two arrays of the same size into this one.
234
0
    inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
235
0
        U_ASSERT(a.d1() == d1());
236
0
        U_ASSERT(b.d1() == d1());
237
0
        for (int32_t i = 0; i < d1(); i++) {
238
0
            data_[i] += a.get(i) * b.get(i);
239
0
        }
240
0
        return *this;
241
0
    }
242
243
    // Add the values of another array of the same size into this one.
244
0
    inline Array1D& add(const ReadArray1D& a) {
245
0
        U_ASSERT(a.d1() == d1());
246
0
        for (int32_t i = 0; i < d1(); i++) {
247
0
            data_[i] += a.get(i);
248
0
        }
249
0
        return *this;
250
0
    }
251
252
    // Assign the values of another array of the same size into this one.
253
0
    inline Array1D& assign(const ReadArray1D& a) {
254
0
        U_ASSERT(a.d1() == d1());
255
0
        for (int32_t i = 0; i < d1(); i++) {
256
0
            data_[i] = a.get(i);
257
0
        }
258
0
        return *this;
259
0
    }
260
261
    // Apply tanh to all the elements in the array.
262
0
    inline Array1D& tanh() {
263
0
        return tanh(*this);
264
0
    }
265
266
    // Apply tanh of a and store into this array.
267
0
    inline Array1D& tanh(const Array1D& a) {
268
0
        U_ASSERT(a.d1() == d1());
269
0
        for (int32_t i = 0; i < d1_; i++) {
270
0
            data_[i] = std::tanh(a.get(i));
271
0
        }
272
0
        return *this;
273
0
    }
274
275
    // Apply sigmoid to all the elements in the array.
276
0
    inline Array1D& sigmoid() {
277
0
        for (int32_t i = 0; i < d1_; i++) {
278
0
            data_[i] = 1.0f/(1.0f + expf(-data_[i]));
279
0
        }
280
0
        return *this;
281
0
    }
282
283
0
    inline Array1D& clear() {
284
0
        uprv_memset(data_, 0, d1_ * sizeof(float));
285
0
        return *this;
286
0
    }
287
288
private:
289
    void* memory_;
290
    float* data_;
291
    int32_t d1_;
292
};
293
294
Array1D::~Array1D()
295
0
{
296
0
    uprv_free(memory_);
297
0
}
298
299
class Array2D : public ReadArray2D {
300
public:
301
0
    Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
302
    Array2D(int32_t d1, int32_t d2, UErrorCode &status)
303
0
        : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
304
0
          data_((float*)memory_), d1_(d1), d2_(d2) {
305
0
        if (U_SUCCESS(status)) {
306
0
            if (memory_ == nullptr) {
307
0
                status = U_MEMORY_ALLOCATION_ERROR;
308
0
                return;
309
0
            }
310
0
            clear();
311
0
        }
312
0
    }
313
    virtual ~Array2D();
314
315
    // ReadArray2D methods.
316
0
    virtual int32_t d1() const { return d1_; }
317
0
    virtual int32_t d2() const { return d2_; }
318
0
    virtual float get(int32_t i, int32_t j) const {
319
0
        U_ASSERT(i < d1_);
320
0
        U_ASSERT(j < d2_);
321
0
        return data_[i * d2_ + j];
322
0
    }
323
324
0
    inline Array1D row(int32_t i) const {
325
0
        U_ASSERT(i < d1_);
326
0
        return Array1D(data_ + i * d2_, d2_);
327
0
    }
328
329
0
    inline Array2D& clear() {
330
0
        uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
331
0
        return *this;
332
0
    }
333
334
private:
335
    void* memory_;
336
    float* data_;
337
    int32_t d1_;
338
    int32_t d2_;
339
};
340
341
Array2D::~Array2D()
342
0
{
343
0
    uprv_free(memory_);
344
0
}
345
346
// Output classes of the model. divideUpDictionaryRange() casts the index of
// the largest output value to this type, so the numeric values matter.
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};
352
353
// Kind of text unit that the model's dictionary is keyed on.
// UNKNOWN is the initial value before the resource's "type" string is read.
enum EmbeddingType {
    UNKNOWN = 0,
    CODE_POINTS = 1,
    GRAPHEME_CLUSTER = 2
};
358
359
struct LSTMData : public UMemory {
360
    LSTMData(UResourceBundle* rb, UErrorCode &status);
361
    ~LSTMData();
362
    UHashtable* fDict;
363
    EmbeddingType fType;
364
    const UChar* fName;
365
    ConstArray2D fEmbedding;
366
    ConstArray2D fForwardW;
367
    ConstArray2D fForwardU;
368
    ConstArray1D fForwardB;
369
    ConstArray2D fBackwardW;
370
    ConstArray2D fBackwardU;
371
    ConstArray1D fBackwardB;
372
    ConstArray2D fOutputW;
373
    ConstArray1D fOutputB;
374
375
private:
376
    UResourceBundle* fBundle;
377
};
378
379
LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
380
0
    : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
381
0
      fBundle(rb)
382
0
{
383
0
    if (U_FAILURE(status)) {
384
0
        return;
385
0
    }
386
0
    if (IEEE_754 != 1) {
387
0
        status = U_UNSUPPORTED_ERROR;
388
0
        return;
389
0
    }
390
0
    LocalUResourceBundlePointer embeddings_res(
391
0
        ures_getByKey(rb, "embeddings", nullptr, &status));
392
0
    int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
393
0
    LocalUResourceBundlePointer hunits_res(
394
0
        ures_getByKey(rb, "hunits", nullptr, &status));
395
0
    if (U_FAILURE(status)) return;
396
0
    int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
397
0
    const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);
398
0
    if (U_FAILURE(status)) return;
399
0
    if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
400
0
        fType = CODE_POINTS;
401
0
    } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
402
0
        fType = GRAPHEME_CLUSTER;
403
0
    }
404
0
    fName = ures_getStringByKey(rb, "model", nullptr, &status);
405
0
    LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
406
0
    if (U_FAILURE(status)) return;
407
0
    int32_t data_len = 0;
408
0
    const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
409
0
    fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
410
411
0
    StackUResourceBundle stackTempBundle;
412
0
    ResourceDataValue value;
413
0
    ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
414
0
    ResourceArray stringArray = value.getArray(status);
415
0
    int32_t num_index = stringArray.getSize();
416
0
    if (U_FAILURE(status)) { return; }
417
418
    // put dict into hash
419
0
    int32_t stringLength;
420
0
    for (int32_t idx = 0; idx < num_index; idx++) {
421
0
        stringArray.getValue(idx, value);
422
0
        const UChar* str = value.getString(stringLength, status);
423
0
        uhash_putiAllowZero(fDict, (void*)str, idx, &status);
424
0
        if (U_FAILURE(status)) return;
425
#ifdef LSTM_VECTORIZER_DEBUG
426
        printf("Assign [");
427
        while (*str != 0x0000) {
428
            printf("U+%04x ", *str);
429
            str++;
430
        }
431
        printf("] map to %d\n", idx-1);
432
#endif
433
0
    }
434
0
    int32_t mat1_size = (num_index + 1) * embedding_size;
435
0
    int32_t mat2_size = embedding_size * 4 * hunits;
436
0
    int32_t mat3_size = hunits * 4 * hunits;
437
0
    int32_t mat4_size = 4 * hunits;
438
0
    int32_t mat5_size = mat2_size;
439
0
    int32_t mat6_size = mat3_size;
440
0
    int32_t mat7_size = mat4_size;
441
0
    int32_t mat8_size = 2 * hunits * 4;
442
#if U_DEBUG
443
    int32_t mat9_size = 4;
444
    U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
445
        mat6_size + mat7_size + mat8_size + mat9_size);
446
#endif
447
448
0
    fEmbedding.init(data, (num_index + 1), embedding_size);
449
0
    data += mat1_size;
450
0
    fForwardW.init(data, embedding_size, 4 * hunits);
451
0
    data += mat2_size;
452
0
    fForwardU.init(data, hunits, 4 * hunits);
453
0
    data += mat3_size;
454
0
    fForwardB.init(data, 4 * hunits);
455
0
    data += mat4_size;
456
0
    fBackwardW.init(data, embedding_size, 4 * hunits);
457
0
    data += mat5_size;
458
0
    fBackwardU.init(data, hunits, 4 * hunits);
459
0
    data += mat6_size;
460
0
    fBackwardB.init(data, 4 * hunits);
461
0
    data += mat7_size;
462
0
    fOutputW.init(data, 2 * hunits, 4);
463
0
    data += mat8_size;
464
0
    fOutputB.init(data, 4);
465
0
}
466
467
0
LSTMData::~LSTMData() {
468
0
    uhash_close(fDict);
469
0
    ures_close(fBundle);
470
0
}
471
472
class Vectorizer : public UMemory {
473
public:
474
0
    Vectorizer(UHashtable* dict) : fDict(dict) {}
475
    virtual ~Vectorizer();
476
    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
477
                           UVector32 &offsets, UVector32 &indices,
478
                           UErrorCode &status) const = 0;
479
protected:
480
0
    int32_t stringToIndex(const UChar* str) const {
481
0
        UBool found = false;
482
0
        int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
483
0
        if (!found) {
484
0
            ret = fDict->count;
485
0
        }
486
#ifdef LSTM_VECTORIZER_DEBUG
487
        printf("[");
488
        while (*str != 0x0000) {
489
            printf("U+%04x ", *str);
490
            str++;
491
        }
492
        printf("] map to %d\n", ret);
493
#endif
494
0
        return ret;
495
0
    }
496
497
private:
498
    UHashtable* fDict;
499
};
500
501
Vectorizer::~Vectorizer()
502
0
{
503
0
}
504
505
class CodePointsVectorizer : public Vectorizer {
506
public:
507
0
    CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
508
    virtual ~CodePointsVectorizer();
509
    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
510
                           UVector32 &offsets, UVector32 &indices,
511
                           UErrorCode &status) const;
512
};
513
514
CodePointsVectorizer::~CodePointsVectorizer()
515
{
516
}
517
518
void CodePointsVectorizer::vectorize(
519
    UText *text, int32_t startPos, int32_t endPos,
520
    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
521
0
{
522
0
    if (offsets.ensureCapacity(endPos - startPos, status) &&
523
0
            indices.ensureCapacity(endPos - startPos, status)) {
524
0
        if (U_FAILURE(status)) return;
525
0
        utext_setNativeIndex(text, startPos);
526
0
        int32_t current;
527
0
        UChar str[2] = {0, 0};
528
0
        while (U_SUCCESS(status) &&
529
0
               (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
530
            // Since the LSTMBreakEngine is currently only accept chars in BMP,
531
            // we can ignore the possibility of hitting supplementary code
532
            // point.
533
0
            str[0] = (UChar) utext_next32(text);
534
0
            U_ASSERT(!U_IS_SURROGATE(str[0]));
535
0
            offsets.addElement(current, status);
536
0
            indices.addElement(stringToIndex(str), status);
537
0
        }
538
0
    }
539
0
}
540
541
class GraphemeClusterVectorizer : public Vectorizer {
542
public:
543
    GraphemeClusterVectorizer(UHashtable* dict)
544
0
        : Vectorizer(dict)
545
0
    {
546
0
    }
547
    virtual ~GraphemeClusterVectorizer();
548
    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
549
                           UVector32 &offsets, UVector32 &indices,
550
                           UErrorCode &status) const;
551
};
552
553
GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
554
{
555
}
556
557
constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
558
559
void GraphemeClusterVectorizer::vectorize(
560
    UText *text, int32_t startPos, int32_t endPos,
561
    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
562
0
{
563
0
    if (U_FAILURE(status)) return;
564
0
    if (!offsets.ensureCapacity(endPos - startPos, status) ||
565
0
            !indices.ensureCapacity(endPos - startPos, status)) {
566
0
        return;
567
0
    }
568
0
    if (U_FAILURE(status)) return;
569
0
    LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
570
0
    if (U_FAILURE(status)) return;
571
0
    graphemeIter->setText(text, status);
572
0
    if (U_FAILURE(status)) return;
573
574
0
    if (startPos != 0) {
575
0
        graphemeIter->preceding(startPos);
576
0
    }
577
0
    int32_t last = startPos;
578
0
    int32_t current = startPos;
579
0
    UChar str[MAX_GRAPHEME_CLSTER_LENGTH];
580
0
    while ((current = graphemeIter->next()) != BreakIterator::DONE) {
581
0
        if (current >= endPos) {
582
0
            break;
583
0
        }
584
0
        if (current > startPos) {
585
0
            utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
586
0
            if (U_FAILURE(status)) return;
587
0
            offsets.addElement(last, status);
588
0
            indices.addElement(stringToIndex(str), status);
589
0
            if (U_FAILURE(status)) return;
590
0
        }
591
0
        last = current;
592
0
    }
593
0
    if (U_FAILURE(status) || last >= endPos) {
594
0
        return;
595
0
    }
596
0
    utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
597
0
    if (U_SUCCESS(status)) {
598
0
        offsets.addElement(last, status);
599
0
        indices.addElement(stringToIndex(str), status);
600
0
    }
601
0
}
602
603
// Computing LSTM as stated in
604
// https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
605
// ifco is temp array allocate outside which does not need to be
606
// input/output value but could avoid unnecessary memory alloc/free if passing
607
// in.
608
void compute(
609
    int32_t hunits,
610
    const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
611
    const ReadArray1D& x, Array1D& h, Array1D& c,
612
    Array1D& ifco)
613
0
{
614
    // ifco = x * W + h * U + b
615
0
    ifco.assign(b)
616
0
        .addDotProduct(x, W)
617
0
        .addDotProduct(h, U);
618
619
0
    ifco.slice(0*hunits, hunits).sigmoid();  // i: sigmod
620
0
    ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
621
0
    ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
622
0
    ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
623
624
0
    c.hadamardProduct(ifco.slice(hunits, hunits))
625
0
        .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
626
627
0
    h.tanh(c)
628
0
        .hadamardProduct(ifco.slice(3*hunits, hunits));
629
0
}
630
631
// Smallest word size.
static constexpr int32_t MIN_WORD = 2;

// Smallest number of characters that can hold two words.
static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
636
637
int32_t
638
LSTMBreakEngine::divideUpDictionaryRange( UText *text,
639
                                                int32_t startPos,
640
                                                int32_t endPos,
641
                                                UVector32 &foundBreaks,
642
0
                                                UErrorCode& status) const {
643
0
    if (U_FAILURE(status)) return 0;
644
0
    int32_t beginFoundBreakSize = foundBreaks.size();
645
0
    utext_setNativeIndex(text, startPos);
646
0
    utext_moveIndex32(text, MIN_WORD_SPAN);
647
0
    if (utext_getNativeIndex(text) >= endPos) {
648
0
        return 0;       // Not enough characters for two words
649
0
    }
650
0
    utext_setNativeIndex(text, startPos);
651
652
0
    UVector32 offsets(status);
653
0
    UVector32 indices(status);
654
0
    if (U_FAILURE(status)) return 0;
655
0
    fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
656
0
    if (U_FAILURE(status)) return 0;
657
0
    int32_t* offsetsBuf = offsets.getBuffer();
658
0
    int32_t* indicesBuf = indices.getBuffer();
659
660
0
    int32_t input_seq_len = indices.size();
661
0
    int32_t hunits = fData->fForwardU.d1();
662
663
    // ----- Begin of all the Array memory allocation needed for this function
664
    // Allocate temp array used inside compute()
665
0
    Array1D ifco(4 * hunits, status);
666
667
0
    Array1D c(hunits, status);
668
0
    Array1D logp(4, status);
669
670
    // TODO: limit size of hBackward. If input_seq_len is too big, we could
671
    // run out of memory.
672
    // Backward LSTM
673
0
    Array2D hBackward(input_seq_len, hunits, status);
674
675
    // Allocate fbRow and slice the internal array in two.
676
0
    Array1D fbRow(2 * hunits, status);
677
678
    // ----- End of all the Array memory allocation needed for this function
679
0
    if (U_FAILURE(status)) return 0;
680
681
    // To save the needed memory usage, the following is different from the
682
    // Python or ICU4X implementation. We first perform the Backward LSTM
683
    // and then merge the iteration of the forward LSTM and the output layer
684
    // together because we only neetdto remember the h[t-1] for Forward LSTM.
685
0
    for (int32_t i = input_seq_len - 1; i >= 0; i--) {
686
0
        Array1D hRow = hBackward.row(i);
687
0
        if (i != input_seq_len - 1) {
688
0
            hRow.assign(hBackward.row(i+1));
689
0
        }
690
#ifdef LSTM_DEBUG
691
        printf("hRow %d\n", i);
692
        hRow.print();
693
        printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
694
        printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
695
        fData->fEmbedding.row(indicesBuf[i]).print();
696
#endif  // LSTM_DEBUG
697
0
        compute(hunits,
698
0
                fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
699
0
                fData->fEmbedding.row(indicesBuf[i]),
700
0
                hRow, c, ifco);
701
0
    }
702
703
704
0
    Array1D forwardRow = fbRow.slice(0, hunits);  // point to first half of data in fbRow.
705
0
    Array1D backwardRow = fbRow.slice(hunits, hunits);  // point to second half of data n fbRow.
706
707
    // The following iteration merge the forward LSTM and the output layer
708
    // together.
709
0
    c.clear();  // reuse c since it is the same size.
710
0
    for (int32_t i = 0; i < input_seq_len; i++) {
711
#ifdef LSTM_DEBUG
712
        printf("forwardRow %d\n", i);
713
        forwardRow.print();
714
#endif  // LSTM_DEBUG
715
        // Forward LSTM
716
        // Calculate the result into forwardRow, which point to the data in the first half
717
        // of fbRow.
718
0
        compute(hunits,
719
0
                fData->fForwardW, fData->fForwardU, fData->fForwardB,
720
0
                fData->fEmbedding.row(indicesBuf[i]),
721
0
                forwardRow, c, ifco);
722
723
        // assign the data from hBackward.row(i) to second half of fbRowa.
724
0
        backwardRow.assign(hBackward.row(i));
725
726
0
        logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
727
#ifdef LSTM_DEBUG
728
        printf("backwardRow %d\n", i);
729
        backwardRow.print();
730
        printf("logp %d\n", i);
731
        logp.print();
732
#endif  // LSTM_DEBUG
733
734
        // current = argmax(logp)
735
0
        LSTMClass current = (LSTMClass)logp.maxIndex();
736
        // BIES logic.
737
0
        if (current == BEGIN || current == SINGLE) {
738
0
            if (i != 0) {
739
0
                foundBreaks.addElement(offsetsBuf[i], status);
740
0
                if (U_FAILURE(status)) return 0;
741
0
            }
742
0
        }
743
0
    }
744
0
    return foundBreaks.size() - beginFoundBreakSize;
745
0
}
746
747
0
// Creates the Vectorizer matching data->fType. The caller owns the result.
// Returns nullptr when status is already a failure on entry, or when the
// allocation fails (status is then set to U_MEMORY_ALLOCATION_ERROR).
// An fType other than CODE_POINTS or GRAPHEME_CLUSTER is treated as
// unreachable.
Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
    if (U_FAILURE(status)) {
        return nullptr;
    }
    Vectorizer* vectorizer = nullptr;
    switch (data->fType) {
        case CODE_POINTS:
            vectorizer = new CodePointsVectorizer(data->fDict);
            break;
        case GRAPHEME_CLUSTER:
            vectorizer = new GraphemeClusterVectorizer(data->fDict);
            break;
        default:
            UPRV_UNREACHABLE;
    }
    // Report an out-of-memory condition instead of silently handing back a
    // null vectorizer together with a success status.
    if (vectorizer == nullptr) {
        status = U_MEMORY_ALLOCATION_ERROR;
    }
    return vectorizer;
}
763
764
// Builds an engine over `data` that handles the characters in `set`.
// On success the engine takes over `data` (the destructor deletes it).
// On failure fData is reset so that the destructor leaves `data` alone;
// the caller remains responsible for it.
LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
    : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
{
    if (U_SUCCESS(status)) {
        setCharacters(set);
        return;
    }
    // If failure, we should not delete fData in destructor because the caller will do so.
    fData = nullptr;
}
773
774
0
// Deletes the model data and the vectorizer held by this engine.
// fData is nullptr here when construction failed, so nothing is deleted
// for it in that case.
LSTMBreakEngine::~LSTMBreakEngine()
{
    delete fData;
    delete fVectorizer;
}
778
779
0
// Returns the model name stored in the engine's LSTMData (fName).
const UChar* LSTMBreakEngine::name() const
{
    const UChar* modelName = fData->fName;
    return modelName;
}
782
783
0
// Returns the string stored under "lstm"/<short name of script> in the root
// bundle of the brkitr tree. Errors are reported through status.
UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
    // open root from brkitr tree.
    UResourceBundle *bundle = ures_open(U_ICUDATA_BRKITR, "", &status);
    // Reuse the same bundle object as the fill-in for the "lstm" table.
    bundle = ures_getByKeyWithFallback(bundle, "lstm", bundle, &status);
    const char* scriptKey = uscript_getShortName(script);
    UnicodeString modelName = ures_getUnicodeStringByKey(bundle, scriptKey, &status);
    ures_close(bundle);
    return modelName;
}
791
792
// Loads the default LSTM model data for `script`.
// Returns nullptr without touching status for scripts other than Khmer,
// Lao, Myanmar and Thai, and nullptr with a failure status when the model
// name or its resource bundle cannot be loaded. Otherwise returns whatever
// CreateLSTMData() produces for the opened bundle.
U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
{
    if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
        return nullptr;
    }
    UnicodeString name = defaultLSTM(script, status);
    if (U_FAILURE(status)) return nullptr;

    // Strip the file extension from the model name. The append and the
    // search for '.' are separate statements on purpose: chained as
    // append(...).truncate(lastIndexOf('.')), the order of the two calls is
    // unspecified before C++17, so the search could run on the still-empty
    // buffer. A name without any '.' is kept whole rather than truncated to
    // an empty string.
    CharString namebuf;
    namebuf.appendInvariantChars(name, status);
    const int32_t extensionStart = namebuf.lastIndexOf('.');
    if (extensionStart >= 0) {
        namebuf.truncate(extensionStart);
    }

    LocalUResourceBundlePointer rb(
        ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
    if (U_FAILURE(status)) return nullptr;

    // Ownership of the bundle moves to the LSTMData.
    return CreateLSTMData(rb.orphan(), status);
}
808
809
// Creates an LSTMData from `rb`. Ownership of `rb` is taken in every case:
// it is closed either by the returned object's destructor or, on failure,
// before this function returns.
// Returns nullptr and leaves a failure in status when the data cannot be
// created, so callers never receive a partially initialized object.
U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
{
    LSTMData* data = new LSTMData(rb, status);
    if (data == nullptr) {
        // No LSTMData exists to close the bundle, so close it here.
        ures_close(rb);
        if (U_SUCCESS(status)) {
            status = U_MEMORY_ALLOCATION_ERROR;
        }
        return nullptr;
    }
    if (U_FAILURE(status)) {
        // The destructor closes rb and any hash table already opened.
        delete data;
        return nullptr;
    }
    return data;
}
813
814
// Creates an LSTM break engine for `script` over `data`.
// Only Thai and Myanmar are handled; for any other script `data` is deleted
// and nullptr is returned with status untouched.
// Returns nullptr with a failure status when the engine cannot be allocated
// or constructed; `data` is not deleted on those paths.
U_CAPI const LanguageBreakEngine* U_EXPORT2
CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
{
    UnicodeString unicodeSetString;
    if (script == USCRIPT_THAI) {
        unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
    } else if (script == USCRIPT_MYANMAR) {
        unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
    } else {
        delete data;
        return nullptr;
    }

    UnicodeSet unicodeSet;
    unicodeSet.applyPattern(unicodeSetString, status);
    const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
    if (engine == nullptr) {
        status = U_MEMORY_ALLOCATION_ERROR;
        return nullptr;
    }
    if (U_FAILURE(status)) {
        delete engine;
        return nullptr;
    }
    return engine;
}
842
843
// Deletes an LSTMData; its destructor closes the dictionary hash table and
// the resource bundle it holds.
U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
{
    delete data;
}
847
848
// Returns the model name (fName) stored in `data`; `data` must not be null.
U_CAPI const UChar* U_EXPORT2 LSTMDataName(const LSTMData* data)
{
    const UChar* modelName = data->fName;
    return modelName;
}
852
853
U_NAMESPACE_END
854
855
#endif /* #if !UCONFIG_NO_BREAK_ITERATION */