Coverage Report

Created: 2026-01-10 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-grammar.cpp
Line
Count
Source
1
#include "llama-grammar.h"
2
3
#include "llama-impl.h"
4
#include "llama-vocab.h"
5
#include "llama-sampling.h"
6
7
#include <cmath>
8
#include <algorithm>
9
#include <cstdint>
10
#include <stdexcept>
11
12
0
#define MAX_REPETITION_THRESHOLD 2000
13
//
14
// helpers
15
//
16
17
// NOTE: assumes valid utf8 (but checks for overrun)
18
0
// Decodes the single UTF-8 sequence that starts at `src`.
// Returns the decoded code point and a pointer to the first byte after the bytes consumed.
// The input is not validated; decoding only stops early if a NUL byte is reached before
// the sequence length announced by the lead byte has been consumed.
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // total sequence length, indexed by the top four bits of the lead byte
    static const int seq_len_by_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead         = static_cast<uint8_t>(src[0]);
    const int     n_bytes      = seq_len_by_nibble[lead >> 4];
    const uint8_t payload_mask = static_cast<uint8_t>((1 << (8 - n_bytes)) - 1);

    uint32_t code_point = lead & payload_mask;

    // fold in the continuation bytes, 6 payload bits each
    int consumed = 1;
    while (consumed < n_bytes && src[consumed] != '\0') {
        code_point = (code_point << 6) + (static_cast<uint8_t>(src[consumed]) & 0x3F);
        ++consumed;
    }

    return std::make_pair(code_point, src + consumed);
}
32
33
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
34
        const std::string & src,
35
0
        llama_partial_utf8 partial_start) {
36
0
    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
37
0
    const char          * pos      = src.c_str();
38
0
    std::vector<uint32_t> code_points;
39
40
    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
41
0
    code_points.reserve(src.size() + 1);
42
0
    uint32_t value    = partial_start.value;
43
0
    int      n_remain = partial_start.n_remain;
44
45
    // continue previous decode, if applicable
46
0
    while (*pos != 0 && n_remain > 0) {
47
0
        uint8_t next_byte = static_cast<uint8_t>(*pos);
48
0
        if ((next_byte >> 6) != 2) {
49
            // invalid sequence, abort
50
0
            code_points.push_back(0);
51
0
            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
52
0
        }
53
0
        value = (value << 6) + (next_byte & 0x3F);
54
0
        ++pos;
55
0
        --n_remain;
56
0
    }
57
58
0
    if (partial_start.n_remain > 0 && n_remain == 0) {
59
0
        code_points.push_back(value);
60
0
    }
61
62
    // decode any subsequent utf-8 sequences, which may end in an incomplete one
63
0
    while (*pos != 0) {
64
0
        uint8_t first_byte = static_cast<uint8_t>(*pos);
65
0
        uint8_t highbits   = first_byte >> 4;
66
0
        n_remain   = lookup[highbits] - 1;
67
68
0
        if (n_remain < 0) {
69
            // invalid sequence, abort
70
0
            code_points.clear();
71
0
            code_points.push_back(0);
72
0
            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
73
0
        }
74
75
0
        uint8_t mask  = (1 << (7 - n_remain)) - 1;
76
0
        value = first_byte & mask;
77
78
0
        ++pos;
79
0
        while (*pos != 0 && n_remain > 0) {
80
0
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
81
0
            ++pos;
82
0
            --n_remain;
83
0
        }
84
0
        if (n_remain == 0) {
85
0
            code_points.push_back(value);
86
0
        }
87
0
    }
88
0
    code_points.push_back(0);
89
90
0
    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
91
0
}
92
93
0
// True when `c` is an ASCII decimal digit.
static bool is_digit_char(char c) {
    return c >= '0' && c <= '9';
}
96
97
0
static bool is_word_char(char c) {
98
0
    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
99
0
}
100
101
0
// Reads exactly `size` hexadecimal digits starting at `src`.
// Returns the parsed value and a pointer just past the last digit.
// Throws std::runtime_error if fewer than `size` hex digits are available.
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    uint32_t result = 0;
    int      n_read = 0;

    while (n_read < size && src[n_read] != '\0') {
        const char ch = src[n_read];

        uint32_t digit = 0;
        if (ch >= '0' && ch <= '9') {
            digit = ch - '0';
        } else if (ch >= 'a' && ch <= 'f') {
            digit = ch - 'a' + 10;
        } else if (ch >= 'A' && ch <= 'F') {
            digit = ch - 'A' + 10;
        } else {
            break; // not a hex digit; reported below because n_read != size
        }

        result = (result << 4) + digit;
        ++n_read;
    }

    if (n_read != size) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(result, src + n_read);
}
123
124
0
// Skips spaces, tabs and '#' comments starting at `src` and returns the first position
// that is none of those. Line terminators are skipped only when `newline_ok` is true;
// a comment ends at (and does not consume) the line terminator.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cur = src;
    for (;;) {
        const char ch = *cur;

        if (ch == '#') {
            // comment: advance up to, but not past, the end of the line
            do {
                ++cur;
            } while (*cur != '\0' && *cur != '\r' && *cur != '\n');
            continue;
        }

        const bool is_blank   = ch == ' '  || ch == '\t';
        const bool is_newline = ch == '\r' || ch == '\n';
        if (is_blank || (newline_ok && is_newline)) {
            ++cur;
            continue;
        }

        return cur;
    }
}
138
139
0
static const char * parse_name(const char * src) {
140
0
    const char * pos = src;
141
0
    while (is_word_char(*pos)) {
142
0
        pos++;
143
0
    }
144
0
    if (pos == src) {
145
0
        throw std::runtime_error(std::string("expecting name at ") + src);
146
0
    }
147
0
    return pos;
148
0
}
149
150
0
static const char * parse_int(const char * src) {
151
0
    const char * pos = src;
152
0
    while (is_digit_char(*pos)) {
153
0
        pos++;
154
0
    }
155
0
    if (pos == src) {
156
0
        throw std::runtime_error(std::string("expecting integer at ") + src);
157
0
    }
158
0
    return pos;
159
0
}
160
161
0
static std::pair<uint32_t, const char *> parse_char(const char * src) {
162
0
    if (*src == '\\') {
163
0
        switch (src[1]) {
164
0
            case 'x': return parse_hex(src + 2, 2);
165
0
            case 'u': return parse_hex(src + 2, 4);
166
0
            case 'U': return parse_hex(src + 2, 8);
167
0
            case 't': return std::make_pair('\t', src + 2);
168
0
            case 'r': return std::make_pair('\r', src + 2);
169
0
            case 'n': return std::make_pair('\n', src + 2);
170
0
            case '\\':
171
0
            case '"':
172
0
            case '[':
173
0
            case ']':
174
0
                      return std::make_pair(src[1], src + 2);
175
0
            default:
176
0
                      throw std::runtime_error(std::string("unknown escape at ") + src);
177
0
        }
178
0
    } else if (*src) {
179
0
        return decode_utf8(src);
180
0
    }
181
0
    throw std::runtime_error("unexpected end of input");
182
0
}
183
184
0
static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
185
0
    const char * pos = src;
186
0
    if (*pos != '<') {
187
0
        throw std::runtime_error(std::string("expecting '<' at ") + pos);
188
0
    }
189
0
    pos++;
190
191
    // Parse <[id]>
192
0
    if (*pos == '[') {
193
0
        pos++;
194
0
        const char * int_end = parse_int(pos);
195
0
        uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
196
0
        pos = int_end;
197
0
        if (*pos != ']') {
198
0
            throw std::runtime_error(std::string("expecting ']' at ") + pos);
199
0
        }
200
0
        pos++;
201
0
        if (*pos != '>') {
202
0
            throw std::runtime_error(std::string("expecting '>' at ") + pos);
203
0
        }
204
0
        pos++;
205
0
        return std::make_pair(token_id, pos);
206
0
    }
207
208
0
    if (vocab == nullptr) {
209
0
        throw std::runtime_error(std::string("no vocab to parse token at ") + src);
210
0
    }
211
212
    // Parse <token> and tokenize to obtain the token id
213
0
    while (*pos != 0 && *pos != '>') {
214
0
        pos++;
215
0
    }
216
0
    if (*pos != '>') {
217
0
        throw std::runtime_error(std::string("expecting '>' at ") + pos);
218
0
    }
219
0
    pos++;
220
221
0
    llama_token tokens[2];
222
0
    int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
223
0
    if (n_tokens != 1) {
224
        // must tokenize to exactly 1 token
225
0
        throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
226
0
    }
227
0
    return std::make_pair(tokens[0], pos);
228
0
}
229
230
0
// Writes code point `c` to `file`: printable ASCII (0x20..0x7E) is written as-is, anything
// else (control characters including DEL, and all non-ASCII code points) is written as a
// <U+XXXX> escape.
static void print_grammar_char(FILE * file, uint32_t c) {
    // 0x7F (DEL) is a control character, so the printable range stops at 0x7E
    const bool printable_ascii = 0x20 <= c && c < 0x7f;
    if (printable_ascii) {
        fprintf(file, "%c", static_cast<char>(c));
    } else {
        // cop out of encoding UTF-8
        fprintf(file, "<U+%04X>", c);
    }
}
238
239
0
// True when `elem` is one of the character-matching element types
// (CHAR, CHAR_NOT, CHAR_ALT, CHAR_RNG_UPPER or CHAR_ANY).
static bool is_char_element(llama_grammar_element elem) {
    const auto type = elem.type;
    return type == LLAMA_GRETYPE_CHAR
        || type == LLAMA_GRETYPE_CHAR_NOT
        || type == LLAMA_GRETYPE_CHAR_ALT
        || type == LLAMA_GRETYPE_CHAR_RNG_UPPER
        || type == LLAMA_GRETYPE_CHAR_ANY;
}
249
250
0
static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
251
0
    for (auto elem : rule) {
252
0
        switch (elem.type) {
253
0
            case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
254
0
            case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
255
0
            case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
256
0
            case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
257
0
            case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
258
0
            case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
259
0
            case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
260
0
            case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
261
0
            case LLAMA_GRETYPE_TOKEN:          fprintf(file, "TOKEN");          break;
262
0
            case LLAMA_GRETYPE_TOKEN_NOT:      fprintf(file, "TOKEN_NOT");      break;
263
0
        }
264
0
        switch (elem.type) {
265
0
            case LLAMA_GRETYPE_END:
266
0
            case LLAMA_GRETYPE_ALT:
267
0
            case LLAMA_GRETYPE_RULE_REF:
268
0
                fprintf(file, "(%u) ", elem.value);
269
0
                break;
270
0
            case LLAMA_GRETYPE_CHAR:
271
0
            case LLAMA_GRETYPE_CHAR_NOT:
272
0
            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
273
0
            case LLAMA_GRETYPE_CHAR_ALT:
274
0
            case LLAMA_GRETYPE_CHAR_ANY:
275
0
                fprintf(file, "(\"");
276
0
                print_grammar_char(file, elem.value);
277
0
                fprintf(file, "\") ");
278
0
                break;
279
0
            case LLAMA_GRETYPE_TOKEN:
280
0
                fprintf(file, "<[");
281
0
                fprintf(file, "%u", elem.value);
282
0
                fprintf(file, "]> ");
283
0
                break;
284
0
            case LLAMA_GRETYPE_TOKEN_NOT:
285
0
                fprintf(file, "!");
286
0
                fprintf(file, "<[");
287
0
                fprintf(file, "%u", elem.value);
288
0
                fprintf(file, "]> ");
289
0
                break;
290
0
        }
291
0
    }
292
0
    fprintf(file, "\n");
293
0
}
294
295
static void print_rule(
296
        FILE     * file,
297
        uint32_t   rule_id,
298
        const llama_grammar_rule & rule,
299
0
        const std::map<uint32_t, std::string> & symbol_id_names) {
300
0
    if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
301
0
        throw std::runtime_error(
302
0
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
303
0
    }
304
0
    fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
305
0
    for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
306
0
        llama_grammar_element elem = rule[i];
307
0
        switch (elem.type) {
308
0
            case LLAMA_GRETYPE_END:
309
0
                throw std::runtime_error(
310
0
                    "unexpected end of rule: " + std::to_string(rule_id) + "," +
311
0
                    std::to_string(i));
312
0
            case LLAMA_GRETYPE_ALT:
313
0
                fprintf(file, "| ");
314
0
                break;
315
0
            case LLAMA_GRETYPE_RULE_REF:
316
0
                fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
317
0
                break;
318
0
            case LLAMA_GRETYPE_CHAR:
319
0
                fprintf(file, "[");
320
0
                print_grammar_char(file, elem.value);
321
0
                break;
322
0
            case LLAMA_GRETYPE_CHAR_NOT:
323
0
                fprintf(file, "[^");
324
0
                print_grammar_char(file, elem.value);
325
0
                break;
326
0
            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
327
0
                if (i == 0 || !is_char_element(rule[i - 1])) {
328
0
                    throw std::runtime_error(
329
0
                        "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
330
0
                        std::to_string(rule_id) + "," + std::to_string(i));
331
0
                }
332
0
                fprintf(file, "-");
333
0
                print_grammar_char(file, elem.value);
334
0
                break;
335
0
            case LLAMA_GRETYPE_CHAR_ALT:
336
0
                if (i == 0 || !is_char_element(rule[i - 1])) {
337
0
                    throw std::runtime_error(
338
0
                        "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
339
0
                        std::to_string(rule_id) + "," + std::to_string(i));
340
0
                }
341
0
                print_grammar_char(file, elem.value);
342
0
                break;
343
0
            case LLAMA_GRETYPE_CHAR_ANY:
344
0
                fprintf(file, ".");
345
0
                break;
346
0
            case LLAMA_GRETYPE_TOKEN:
347
0
                fprintf(file, "<[");
348
0
                fprintf(file, "%u", elem.value);
349
0
                fprintf(file, "]> ");
350
0
                break;
351
0
            case LLAMA_GRETYPE_TOKEN_NOT:
352
0
                fprintf(file, "!");
353
0
                fprintf(file, "<[");
354
0
                fprintf(file, "%u", elem.value);
355
0
                fprintf(file, "]> ");
356
0
                break;
357
0
        }
358
0
        if (is_char_element(elem)) {
359
0
            switch (rule[i + 1].type) {
360
0
                case LLAMA_GRETYPE_CHAR_ALT:
361
0
                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
362
0
                case LLAMA_GRETYPE_CHAR_ANY:
363
0
                    break;
364
0
                default:
365
0
                    fprintf(file, "] ");
366
0
            }
367
0
        }
368
0
    }
369
0
    fprintf(file, "\n");
370
0
}
371
372
//
373
// Regex utilities
374
//
375
376
0
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
377
0
    auto find_start_pos = [](const std::smatch & match) {
378
        // get from the first matched capturing group to the end of the string
379
0
        size_t start = std::string::npos;
380
0
        for (auto i = 1u; i < match.size(); i++) {
381
0
            if (match.length(i) > 0) {
382
0
                start = match.position(i);
383
0
                break;
384
0
            }
385
0
        }
386
0
        if (start == std::string::npos) {
387
0
            start = match.position(0);
388
0
        }
389
0
        return start;
390
0
    };
391
392
0
    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
393
        // match against the entire input
394
0
        std::smatch match;
395
0
        if (std::regex_match(input, match, regex)) {
396
0
            return find_start_pos(match);
397
0
        }
398
0
    }
399
400
    // search anywhere
401
0
    std::smatch match;
402
0
    if (std::regex_search(input, match, regex)) {
403
0
        return find_start_pos(match);
404
0
    }
405
406
0
    return std::string::npos;
407
0
}
408
409
410
//
411
// implementation
412
//
413
414
0
uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
415
0
    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
416
0
    auto result = symbol_ids.emplace(std::string(src, len), next_id);
417
0
    return result.first->second;
418
0
}
419
420
0
uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
421
0
    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
422
0
    symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
423
0
    return next_id;
424
0
}
425
426
0
void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
427
0
    if (rules.size() <= rule_id) {
428
0
        rules.resize(rule_id + 1);
429
0
    }
430
0
    rules[rule_id] = rule;
431
0
}
432
433
const char * llama_grammar_parser::parse_alternates(
434
        const char        * src,
435
        const std::string & rule_name,
436
        uint32_t            rule_id,
437
0
        bool                is_nested) {
438
0
    llama_grammar_rule rule;
439
0
    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
440
0
    while (*pos == '|') {
441
0
        rule.push_back({LLAMA_GRETYPE_ALT, 0});
442
0
        pos = parse_space(pos + 1, true);
443
0
        pos = parse_sequence(pos, rule_name, rule, is_nested);
444
0
    }
445
0
    rule.push_back({LLAMA_GRETYPE_END, 0});
446
0
    add_rule(rule_id, rule);
447
0
    return pos;
448
0
}
449
450
const char * llama_grammar_parser::parse_sequence(
451
        const char         * src,
452
        const std::string  & rule_name,
453
        llama_grammar_rule & rule,
454
0
        bool               is_nested) {
455
0
    size_t last_sym_start = rule.size();
456
0
    const char * pos = src;
457
458
    // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
459
    // (though it's technically the same as -1 now)
460
0
    auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
461
0
        bool no_max = max_times == UINT64_MAX;
462
0
        if (last_sym_start == rule.size()) {
463
0
            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
464
0
        }
465
466
        // apply transformation to previous symbol (last_sym_start to end) according to
467
        // the following rewrite rules:
468
        // S{m,n} --> S S S (m times) S'(n-m)
469
        //            S'(x)   ::= S S'(x-1) |
470
        //            (... n-m definitions of these S' rules ...)
471
        //            S'(1)   ::= S |
472
        // S{m,} -->  S S S (m times) S'
473
        //            S'     ::= S S' |
474
        // S*     --> S{0,}
475
        //        --> S'     ::= S S' |
476
        // S+     --> S{1,}
477
        //        --> S S'
478
        //            S'     ::= S S' |
479
        // S?     --> S{0,1}
480
        //        --> S'
481
        //            S'     ::= S |
482
483
0
        llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
484
0
        if (min_times == 0) {
485
0
            rule.resize(last_sym_start);
486
0
        } else {
487
            // Repeat the previous elements (min_times - 1) times
488
0
            for (uint64_t i = 1; i < min_times; i++) {
489
0
                rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
490
0
            }
491
0
        }
492
493
0
        uint32_t last_rec_rule_id = 0;
494
0
        auto n_opt = no_max ? 1 : max_times - min_times;
495
496
0
        llama_grammar_rule rec_rule(prev_rule);
497
0
        for (uint64_t i = 0; i < n_opt; i++) {
498
0
            rec_rule.resize(prev_rule.size());
499
0
            uint32_t rec_rule_id = generate_symbol_id( rule_name);
500
0
            if (i > 0 || no_max) {
501
0
                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
502
0
            }
503
0
            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
504
0
            rec_rule.push_back({LLAMA_GRETYPE_END, 0});
505
0
            add_rule( rec_rule_id, rec_rule);
506
0
            last_rec_rule_id = rec_rule_id;
507
0
        }
508
0
        if (n_opt > 0) {
509
0
            rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
510
0
        }
511
0
    };
512
513
0
    while (*pos) {
514
0
        if (*pos == '"') { // literal string
515
0
            pos++;
516
0
            last_sym_start = rule.size();
517
0
            while (*pos != '"') {
518
0
                if (!*pos) {
519
0
                    throw std::runtime_error("unexpected end of input");
520
0
                }
521
0
                auto char_pair = parse_char(pos);
522
0
                     pos       = char_pair.second;
523
0
                rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
524
0
            }
525
0
            pos = parse_space(pos + 1, is_nested);
526
0
        } else if (*pos == '[') { // char range(s)
527
0
            pos++;
528
0
            enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
529
0
            if (*pos == '^') {
530
0
                pos++;
531
0
                start_type = LLAMA_GRETYPE_CHAR_NOT;
532
0
            }
533
0
            last_sym_start = rule.size();
534
0
            while (*pos != ']') {
535
0
                if (!*pos) {
536
0
                    throw std::runtime_error("unexpected end of input");
537
0
                }
538
0
                auto char_pair = parse_char(pos);
539
0
                     pos       = char_pair.second;
540
0
                enum llama_gretype type = last_sym_start < rule.size()
541
0
                    ? LLAMA_GRETYPE_CHAR_ALT
542
0
                    : start_type;
543
544
0
                rule.push_back({type, char_pair.first});
545
0
                if (pos[0] == '-' && pos[1] != ']') {
546
0
                    if (!pos[1]) {
547
0
                        throw std::runtime_error("unexpected end of input");
548
0
                    }
549
0
                    auto endchar_pair = parse_char(pos + 1);
550
0
                         pos          = endchar_pair.second;
551
0
                    rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
552
0
                }
553
0
            }
554
0
            pos = parse_space(pos + 1, is_nested);
555
0
        } else if (*pos == '<' || *pos == '!') { // token
556
0
            auto type = LLAMA_GRETYPE_TOKEN;
557
0
            if (*pos == '!') { // token inverse
558
0
                type = LLAMA_GRETYPE_TOKEN_NOT;
559
0
                pos++;
560
0
            }
561
0
            auto token_pair = parse_token(vocab, pos);
562
0
            const char * token_end  = token_pair.second;
563
0
            last_sym_start = rule.size();
564
0
            rule.push_back({type, token_pair.first});
565
0
            pos = parse_space(token_end, is_nested);
566
0
        } else if (is_word_char(*pos)) { // rule reference
567
0
            const char * name_end    = parse_name(pos);
568
0
            uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
569
0
            pos = parse_space(name_end, is_nested);
570
0
            last_sym_start = rule.size();
571
0
            rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
572
0
        } else if (*pos == '(') { // grouping
573
            // parse nested alternates into synthesized rule
574
0
            pos = parse_space(pos + 1, true);
575
0
            uint32_t sub_rule_id = generate_symbol_id(rule_name);
576
0
            pos = parse_alternates(pos, rule_name, sub_rule_id, true);
577
0
            last_sym_start = rule.size();
578
            // output reference to synthesized rule
579
0
            rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
580
0
            if (*pos != ')') {
581
0
                throw std::runtime_error(std::string("expecting ')' at ") + pos);
582
0
            }
583
0
            pos = parse_space(pos + 1, is_nested);
584
0
        } else if (*pos == '.') { // any char
585
0
            last_sym_start = rule.size();
586
0
            rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
587
0
            pos = parse_space(pos + 1, is_nested);
588
0
        } else if (*pos == '*') {
589
0
            pos = parse_space(pos + 1, is_nested);
590
0
            handle_repetitions(0, -1);
591
0
        } else if (*pos == '+') {
592
0
            pos = parse_space(pos + 1, is_nested);
593
0
            handle_repetitions(1, -1);
594
0
        } else if (*pos == '?') {
595
0
            pos = parse_space(pos + 1, is_nested);
596
0
            handle_repetitions(0, 1);
597
0
        } else if (*pos == '{') {
598
0
            pos = parse_space(pos + 1, is_nested);
599
600
0
            if (!is_digit_char(*pos)) {
601
0
                throw std::runtime_error(std::string("expecting an int at ") + pos);
602
0
            }
603
0
            const char * int_end = parse_int(pos);
604
0
            uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
605
0
            pos = parse_space(int_end, is_nested);
606
607
0
            uint64_t max_times = UINT64_MAX; // default: no max limit
608
609
0
            if (*pos == '}') {
610
0
                max_times = min_times;
611
0
                pos = parse_space(pos + 1, is_nested);
612
0
            } else if (*pos == ',') {
613
0
                pos = parse_space(pos + 1, is_nested);
614
615
0
                if (is_digit_char(*pos)) {
616
0
                    const char * int_end = parse_int(pos);
617
0
                    max_times = std::stoul(std::string(pos, int_end - pos));
618
0
                    pos = parse_space(int_end, is_nested);
619
0
                }
620
621
0
                if (*pos != '}') {
622
0
                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
623
0
                }
624
0
                pos = parse_space(pos + 1, is_nested);
625
0
            } else {
626
0
                throw std::runtime_error(std::string("expecting ',' at ") + pos);
627
0
            }
628
0
            bool has_max = max_times != UINT64_MAX;
629
0
            if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
630
0
                throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
631
0
            }
632
0
            handle_repetitions(min_times, max_times);
633
0
        } else {
634
0
            break;
635
0
        }
636
0
    }
637
0
    return pos;
638
0
}
639
640
0
const char * llama_grammar_parser::parse_rule(const char * src) {
641
0
    const char * name_end = parse_name(src);
642
0
    const char * pos      = parse_space(name_end, false);
643
0
    size_t       name_len = name_end - src;
644
0
    uint32_t     rule_id  = get_symbol_id(src, name_len);
645
0
    const std::string name(src, name_len);
646
647
0
    if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
648
0
        throw std::runtime_error(std::string("expecting ::= at ") + pos);
649
0
    }
650
0
    pos = parse_space(pos + 3, true);
651
652
0
    pos = parse_alternates(pos, name, rule_id, false);
653
654
0
    if (*pos == '\r') {
655
0
        pos += pos[1] == '\n' ? 2 : 1;
656
0
    } else if (*pos == '\n') {
657
0
        pos++;
658
0
    } else if (*pos) {
659
0
        throw std::runtime_error(std::string("expecting newline or end at ") + pos);
660
0
    }
661
0
    return parse_space(pos, true);
662
0
}
663
664
0
bool llama_grammar_parser::parse(const char * src) {
665
0
    try {
666
0
        const char * pos = parse_space(src, true);
667
0
        while (*pos) {
668
0
            pos = parse_rule(pos);
669
0
        }
670
        // Validate the state to ensure that all rules are defined
671
0
        for (const auto & rule : rules) {
672
0
            if (rule.empty()) {
673
0
                throw std::runtime_error("Undefined rule");
674
0
            }
675
0
            for (const auto & elem : rule) {
676
0
                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
677
                    // Ensure that the rule at that location exists
678
0
                    if (elem.value >= rules.size() || rules[elem.value].empty()) {
679
                        // Get the name of the rule that is missing
680
0
                        for (const auto & kv : symbol_ids) {
681
0
                            if (kv.second == elem.value) {
682
0
                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
683
0
                            }
684
0
                        }
685
0
                    }
686
0
                }
687
0
            }
688
0
        }
689
0
    } catch (const std::exception & err) {
690
0
        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
691
0
        rules.clear();
692
0
        return false;
693
0
    }
694
695
0
    return true;
696
0
}
697
698
0
void llama_grammar_parser::print(FILE * file) {
699
0
    try {
700
0
        std::map<uint32_t, std::string> symbol_id_names;
701
0
        for (const auto & kv : symbol_ids) {
702
0
            symbol_id_names[kv.second] = kv.first;
703
0
        }
704
0
        for (size_t i = 0, end = rules.size(); i < end; i++) {
705
            // fprintf(file, "%zu: ", i);
706
            // print_rule_binary(file, rules[i]);
707
0
            print_rule(file, uint32_t(i), rules[i], symbol_id_names);
708
            // fprintf(file, "\n");
709
0
        }
710
0
    } catch (const std::exception & err) {
711
0
        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
712
0
    }
713
0
}
714
715
0
// Returns one pointer per rule, each pointing at the first element of that rule's storage.
// The pointers refer into `rules`, so they are only usable while `rules` is left unmodified.
llama_grammar_stack llama_grammar_parser::c_rules() const {
    llama_grammar_stack rule_ptrs;
    rule_ptrs.reserve(rules.size());
    for (size_t idx = 0; idx < rules.size(); ++idx) {
        rule_ptrs.push_back(rules[idx].data());
    }
    return rule_ptrs;
}
723
724
// returns true iff pos points to the end of one of the definitions of a rule
725
0
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
726
0
    switch (pos->type) {
727
0
        case LLAMA_GRETYPE_END: return true;  // NOLINT
728
0
        case LLAMA_GRETYPE_ALT: return true;  // NOLINT
729
0
        default:                return false;
730
0
    }
731
0
}
732
733
// returns true iff chr satisfies the char range at pos (regular or inverse range)
734
// asserts that pos is pointing to a char range element
735
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
736
        const llama_grammar_element * pos,
737
0
        const uint32_t                chr) {
738
0
    bool found            = false;
739
0
    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
740
741
0
    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
742
743
0
    do {
744
0
        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
745
            // inclusive range, e.g. [a-z]
746
0
            found = found || (pos->value <= chr && chr <= pos[1].value);
747
0
            pos += 2;
748
0
        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
749
            // Any character matches "."
750
0
            found = true;
751
0
            pos += 1;
752
0
        } else {
753
            // exact char match, e.g. [a] or "a"
754
0
            found = found || pos->value == chr;
755
0
            pos += 1;
756
0
        }
757
0
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
758
759
0
    return std::make_pair(found == is_positive_char, pos);
760
0
}
761
762
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
763
// range at pos (regular or inverse range)
764
// asserts that pos is pointing to a char range element
765
static bool llama_grammar_match_partial_char(
766
        const llama_grammar_element * pos,
767
0
        const llama_partial_utf8      partial_utf8) {
768
0
    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
769
0
    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
770
771
0
    uint32_t partial_value = partial_utf8.value;
772
0
    int      n_remain      = partial_utf8.n_remain;
773
774
    // invalid sequence or 7-bit char split across 2 bytes (overlong)
775
0
    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
776
0
        return false;
777
0
    }
778
779
    // range of possible code points this partial UTF-8 sequence could complete to
780
0
    uint32_t low  = partial_value << (n_remain * 6);
781
0
    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
782
783
0
    if (low == 0) {
784
0
        if (n_remain == 2) {
785
0
            low = 1 << 11;
786
0
        } else if (n_remain == 3) {
787
0
            low = 1 << 16;
788
0
        }
789
0
    }
790
791
0
    do {
792
0
        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
793
            // inclusive range, e.g. [a-z]
794
0
            if (pos->value <= high && low <= pos[1].value) {
795
0
                return is_positive_char;
796
0
            }
797
0
            pos += 2;
798
0
        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
799
            // Any character matches "."
800
0
            return true;
801
0
        } else {
802
            // exact char match, e.g. [a] or "a"
803
0
            if (low <= pos->value && pos->value <= high) {
804
0
                return is_positive_char;
805
0
            }
806
0
            pos += 1;
807
0
        }
808
0
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
809
810
0
    return !is_positive_char;
811
0
}
812
813
// returns true iff token matches the rule at pos (regular or inverse)
814
// asserts that pos is pointing to a token element
815
static bool llama_grammar_match_token(
816
    const llama_grammar_element * pos,
817
0
    const llama_token             token) {
818
0
    GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
819
0
    if (pos->type == LLAMA_GRETYPE_TOKEN) {
820
0
        return pos->value == static_cast<uint32_t>(token);
821
0
    }
822
0
    if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
823
0
        return pos->value != static_cast<uint32_t>(token);
824
0
    }
825
0
    return false;
826
0
}
827
828
// transforms a grammar pushdown stack into N possible stacks, all ending
829
// at a character range (terminal element)
830
static void llama_grammar_advance_stack(
831
        const llama_grammar_rules  & rules,
832
        const llama_grammar_stack  & stack,
833
0
              llama_grammar_stacks & new_stacks) {
834
0
    if (stack.empty()) {
835
0
        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
836
0
            new_stacks.emplace_back(stack);
837
0
        }
838
0
        return;
839
0
    }
840
841
0
    const llama_grammar_element * pos = stack.back();
842
843
0
    switch (pos->type) {
844
0
        case LLAMA_GRETYPE_RULE_REF: {
845
0
            const size_t                  rule_id = static_cast<size_t>(pos->value);
846
0
            const llama_grammar_element * subpos  = rules[rule_id].data();
847
0
            do {
848
                // init new stack without the top (pos)
849
0
                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
850
0
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
851
                    // if this rule ref is followed by another element, add that to stack
852
0
                    new_stack.push_back(pos + 1);
853
0
                }
854
0
                if (!llama_grammar_is_end_of_sequence(subpos)) {
855
                    // if alternate is nonempty, add to stack
856
0
                    new_stack.push_back(subpos);
857
0
                }
858
0
                llama_grammar_advance_stack(rules, new_stack, new_stacks);
859
0
                while (!llama_grammar_is_end_of_sequence(subpos)) {
860
                    // scan to end of alternate def
861
0
                    subpos++;
862
0
                }
863
0
                if (subpos->type == LLAMA_GRETYPE_ALT) {
864
                    // there's another alternate def of this rule to process
865
0
                    subpos++;
866
0
                } else {
867
0
                    break;
868
0
                }
869
0
            } while (true);
870
0
            break;
871
0
        }
872
0
        case LLAMA_GRETYPE_CHAR:
873
0
        case LLAMA_GRETYPE_CHAR_NOT:
874
0
        case LLAMA_GRETYPE_CHAR_ANY:
875
0
        case LLAMA_GRETYPE_TOKEN:
876
0
        case LLAMA_GRETYPE_TOKEN_NOT:
877
0
            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
878
                // only add the stack if it's not a duplicate of one we already have
879
0
                new_stacks.emplace_back(stack);
880
0
            }
881
0
            break;
882
0
        default:
883
            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
884
            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
885
            // those
886
0
            GGML_ABORT("fatal error");
887
0
    }
888
0
}
889
890
// returns the candidates that are rejected by every stack in `stacks`: the rejects of the
// first stack are filtered through each remaining stack in turn
// asserts that `stacks` is not empty
static llama_grammar_candidates llama_grammar_reject_candidates(
        const llama_grammar_rules      & rules,
        const llama_grammar_stacks     & stacks,
        const llama_grammar_candidates & candidates) {
    GGML_ASSERT(!stacks.empty()); // REVIEW

    if (candidates.empty()) {
        return {};
    }

    auto stack_it = stacks.begin();

    llama_grammar_candidates still_rejected =
        llama_grammar_reject_candidates_for_stack(rules, *stack_it, candidates);

    for (++stack_it; stack_it != stacks.end(); ++stack_it) {
        still_rejected = llama_grammar_reject_candidates_for_stack(rules, *stack_it, still_rejected);
    }

    return still_rejected;
}
908
909
static bool llama_grammar_detect_left_recursion(
910
        const llama_grammar_rules & rules,
911
        size_t rule_index,
912
        std::vector<bool> * rules_visited,
913
        std::vector<bool> * rules_in_progress,
914
0
        std::vector<bool> * rules_may_be_empty) {
915
0
    if ((*rules_in_progress)[rule_index]) {
916
0
        return true;
917
0
    }
918
919
0
    (*rules_in_progress)[rule_index] = true;
920
921
0
    const llama_grammar_rule & rule = rules[rule_index];
922
923
    // First check if the rule might produce the empty string. This could be done combined with the second
924
    // step but it's more readable as two steps.
925
0
    bool at_rule_start = true;
926
0
    for (size_t i = 0; i < rule.size(); i++) {
927
0
        if (llama_grammar_is_end_of_sequence(&rule[i])) {
928
0
            if (at_rule_start) {
929
0
                (*rules_may_be_empty)[rule_index] = true;
930
0
                break;
931
0
            }
932
0
            at_rule_start = true;
933
0
        } else {
934
0
            at_rule_start = false;
935
0
        }
936
0
    }
937
938
    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
939
    // be empty)
940
0
    bool recurse_into_nonterminal = true;
941
0
    for (size_t i = 0; i < rule.size(); i++) {
942
0
        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
943
0
            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
944
0
                return true;
945
0
            }
946
0
            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
947
0
                recurse_into_nonterminal = false;
948
0
            }
949
0
        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
950
0
            recurse_into_nonterminal = true;
951
0
        } else {
952
0
            recurse_into_nonterminal = false;
953
0
        }
954
0
    }
955
956
0
    (*rules_in_progress)[rule_index] = false;
957
0
    (*rules_visited)[rule_index] = true;
958
959
0
    return false;
960
0
}
961
962
0
// read-only access to the rule table stored in the grammar
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
    const llama_grammar_rules & rule_table = grammar->rules;
    return rule_table;
}
965
966
0
// mutable access to the current set of pushdown stacks stored in the grammar
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
    llama_grammar_stacks & live_stacks = grammar->stacks;
    return live_stacks;
}
969
970
static void llama_grammar_accept_chr(
971
        struct llama_grammar       & grammar,
972
        const llama_grammar_stack  & stack,
973
              uint32_t               chr,
974
0
              llama_grammar_stacks & new_stacks) {
975
0
    if (stack.empty()) {
976
0
        return;
977
0
    }
978
979
0
    const llama_grammar_element * pos = stack.back();
980
981
    // ignore if this turns into a token
982
0
    if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
983
0
        return;
984
0
    }
985
986
0
    auto match = llama_grammar_match_char(pos, chr);
987
0
    if (match.first) {
988
0
        llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
989
0
        if (!llama_grammar_is_end_of_sequence(match.second)) {
990
0
            new_stack.push_back(match.second);
991
0
        }
992
0
        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
993
0
    }
994
0
}
995
996
0
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
997
0
    llama_grammar_stacks stacks_new;
998
0
    stacks_new.reserve(grammar->stacks.size());
999
1000
0
    for (const auto & stack : grammar->stacks) {
1001
0
        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
1002
0
    }
1003
1004
0
    grammar->stacks = std::move(stacks_new);
1005
0
}
1006
1007
// returns the subset of `candidates` that cannot be accepted by the given stack
// each candidate carries a 0-terminated code point sequence plus a trailing partial UTF-8
// state; candidates are matched one code point at a time, recursing with the advanced stacks
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules      & rules,
        const llama_grammar_stack      & stack,
        const llama_grammar_candidates & candidates) {

    llama_grammar_candidates rejects;
    rejects.reserve(candidates.size());

    // empty stack: only a candidate with nothing left to consume is acceptable
    if (stack.empty()) {
        for (const auto & cand : candidates) {
            const bool nothing_left = *cand.code_points == 0 && cand.partial_utf8.n_remain == 0;
            if (!nothing_left) {
                rejects.push_back(cand);
            }
        }
        return rejects;
    }

    const llama_grammar_element * top = stack.back();

    // if the top of the stack is a token rule, then we only need to check the token id
    const bool top_is_token = top->type == LLAMA_GRETYPE_TOKEN || top->type == LLAMA_GRETYPE_TOKEN_NOT;
    if (top_is_token) {
        for (const auto & cand : candidates) {
            // a candidate whose code points were all consumed by char rules is rejected iff
            // it ended in a partial sequence; any other candidate is judged by its token id
            const bool rejected = *cand.code_points == 0
                ? cand.partial_utf8.n_remain != 0
                : !llama_grammar_match_token(top, cand.id);
            if (rejected) {
                rejects.push_back(cand);
            }
        }
        return rejects;
    }

    // candidates whose next code point is accepted here, advanced by one code point
    llama_grammar_candidates survivors;
    survivors.reserve(candidates.size());

    for (const auto & cand : candidates) {
        const uint32_t next_cp = *cand.code_points;

        if (next_cp == 0) {
            // all full code points consumed: reject iff the trailing partial sequence
            // cannot satisfy this position in the grammar
            if (cand.partial_utf8.n_remain != 0 &&
                    !llama_grammar_match_partial_char(top, cand.partial_utf8)) {
                rejects.push_back(cand);
            }
            continue;
        }

        if (!llama_grammar_match_char(top, next_cp).first) {
            rejects.push_back(cand);
            continue;
        }

        survivors.push_back({ cand.index, cand.code_points + 1, cand.partial_utf8, cand.id });
    }

    // only the position after the char range is needed here, the match result is ignored
    const llama_grammar_element * after_range = llama_grammar_match_char(top, 0).second;

    // update top of stack to next element, if any
    llama_grammar_stack popped(stack.begin(), stack.end() - 1);
    if (!llama_grammar_is_end_of_sequence(after_range)) {
        popped.push_back(after_range);
    }

    llama_grammar_stacks next_stacks;
    llama_grammar_advance_stack(rules, popped, next_stacks);

    // survivors rejected further down are reported with their code point pointer moved
    // back to where it was on entry
    const llama_grammar_candidates deeper_rejects = llama_grammar_reject_candidates(rules, next_stacks, survivors);
    for (const auto & cand : deeper_rejects) {
        rejects.push_back({ cand.index, cand.code_points - 1, cand.partial_utf8, cand.id });
    }

    return rejects;
}
1077
1078
////////////////////
1079
1080
struct llama_grammar * llama_grammar_init_impl(
1081
        const struct llama_vocab * vocab,
1082
        const llama_grammar_element ** rules,
1083
        size_t n_rules,
1084
0
        size_t start_rule_index) {
1085
0
    const llama_grammar_element * pos;
1086
1087
    // copy rule definitions into vectors
1088
0
    llama_grammar_rules vec_rules(n_rules);
1089
0
    for (size_t i = 0; i < n_rules; i++) {
1090
0
        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1091
0
            vec_rules[i].push_back(*pos);
1092
0
        }
1093
0
        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1094
0
    }
1095
1096
    // Check for left recursion
1097
0
    std::vector<bool> rules_visited(n_rules);
1098
0
    std::vector<bool> rules_in_progress(n_rules);
1099
0
    std::vector<bool> rules_may_be_empty(n_rules);
1100
0
    for (size_t i = 0; i < n_rules; i++) {
1101
0
        if (rules_visited[i]) {
1102
0
            continue;
1103
0
        }
1104
0
        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1105
0
            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1106
0
            return nullptr;
1107
0
        }
1108
0
    }
1109
1110
    // loop over alternates of start rule to build initial stacks
1111
0
    llama_grammar_stacks stacks;
1112
0
    pos = vec_rules[start_rule_index].data();
1113
0
    do {
1114
0
        llama_grammar_stack stack;
1115
0
        if (!llama_grammar_is_end_of_sequence(pos)) {
1116
            // if alternate is nonempty, add to stack
1117
0
            stack.push_back(pos);
1118
0
        }
1119
0
        llama_grammar_advance_stack(vec_rules, stack, stacks);
1120
0
        while (!llama_grammar_is_end_of_sequence(pos)) {
1121
            // scan to end of alternate def
1122
0
            pos++;
1123
0
        }
1124
0
        if (pos->type == LLAMA_GRETYPE_ALT) {
1125
            // there's another alternate def of this rule to process
1126
0
            pos++;
1127
0
        } else {
1128
0
            break;
1129
0
        }
1130
0
    } while (true);
1131
1132
    // Important: vec_rules has to be moved here, not copied, because stacks contains
1133
    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1134
    // then the pointers would be invalidated when the local vec_rules goes out of scope.
1135
0
    return new llama_grammar {
1136
0
        vocab,
1137
0
        std::move(vec_rules),
1138
0
        std::move(stacks),
1139
0
        /* .partial_utf8 = */             {},
1140
0
        /* .lazy = */                     false,
1141
0
        /* .awaiting_trigger = */         false,
1142
0
        /* .trigger_buffer = */           "",
1143
0
        /* .trigger_buffer_positions = */ {},
1144
0
        /* .trigger_tokens = */           {},
1145
0
        /* .trigger_patterns = */         {},
1146
0
    };
1147
0
}
1148
1149
struct llama_grammar * llama_grammar_init_impl(
1150
        const struct llama_vocab * vocab,
1151
                      const char * grammar_str,
1152
                      const char * grammar_root,
1153
                              bool lazy,
1154
                     const char ** trigger_patterns,
1155
                            size_t num_trigger_patterns,
1156
               const llama_token * trigger_tokens,
1157
0
                            size_t num_trigger_tokens) {
1158
0
    llama_grammar_parser parser(vocab);
1159
1160
    // if there is a grammar, parse it
1161
    // rules will be empty (default) if there are parse errors
1162
0
    if (!parser.parse(grammar_str) || parser.rules.empty()) {
1163
0
        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
1164
0
        return nullptr;
1165
0
    }
1166
1167
    // Ensure that there is a "root" node.
1168
0
    if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
1169
0
        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
1170
0
        return nullptr;
1171
0
    }
1172
1173
0
    std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
1174
1175
0
    const size_t n_rules = grammar_rules.size();
1176
0
    const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
1177
1178
0
    const llama_grammar_element * pos;
1179
1180
    // copy rule definitions into vectors
1181
0
    llama_grammar_rules vec_rules(n_rules);
1182
0
    for (size_t i = 0; i < n_rules; i++) {
1183
0
        for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1184
0
            vec_rules[i].push_back(*pos);
1185
0
        }
1186
0
        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1187
0
    }
1188
1189
    // Check for left recursion
1190
0
    std::vector<bool> rules_visited(n_rules);
1191
0
    std::vector<bool> rules_in_progress(n_rules);
1192
0
    std::vector<bool> rules_may_be_empty(n_rules);
1193
0
    for (size_t i = 0; i < n_rules; i++) {
1194
0
        if (rules_visited[i]) {
1195
0
            continue;
1196
0
        }
1197
0
        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1198
0
            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1199
0
            return nullptr;
1200
0
        }
1201
0
    }
1202
1203
    // loop over alternates of start rule to build initial stacks
1204
0
    llama_grammar_stacks stacks;
1205
0
    pos = vec_rules[start_rule_index].data();
1206
0
    do {
1207
0
        llama_grammar_stack stack;
1208
0
        if (!llama_grammar_is_end_of_sequence(pos)) {
1209
            // if alternate is nonempty, add to stack
1210
0
            stack.push_back(pos);
1211
0
        }
1212
0
        llama_grammar_advance_stack(vec_rules, stack, stacks);
1213
0
        while (!llama_grammar_is_end_of_sequence(pos)) {
1214
            // scan to end of alternate def
1215
0
            pos++;
1216
0
        }
1217
0
        if (pos->type == LLAMA_GRETYPE_ALT) {
1218
            // there's another alternate def of this rule to process
1219
0
            pos++;
1220
0
        } else {
1221
0
            break;
1222
0
        }
1223
0
    } while (true);
1224
1225
0
    std::vector<llama_token>    vec_trigger_tokens;
1226
0
    std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
1227
0
    for (size_t i = 0; i < num_trigger_tokens; i++) {
1228
0
        GGML_ASSERT(trigger_tokens != nullptr);
1229
0
        vec_trigger_tokens.push_back(trigger_tokens[i]);
1230
0
    }
1231
0
    for (size_t i = 0; i < num_trigger_patterns; i++) {
1232
0
        GGML_ASSERT(trigger_patterns != nullptr);
1233
0
        auto & trigger = vec_trigger_patterns.emplace_back();
1234
0
        trigger.pattern = trigger_patterns[i];
1235
0
        trigger.regex = std::regex(trigger.pattern);
1236
0
    }
1237
1238
    // Important: vec_rules has to be moved here, not copied, because stacks contains
1239
    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1240
    // then the pointers would be invalidated when the local vec_rules goes out of scope.
1241
0
    return new llama_grammar {
1242
0
        vocab,
1243
0
        std::move(vec_rules),
1244
0
        std::move(stacks),
1245
0
        /* .partial_utf8 = */             {},
1246
0
        /* .lazy = */                     lazy,
1247
0
        /* .awaiting_trigger = */         lazy,
1248
0
        /* .trigger_buffer = */           "",
1249
0
        /* .trigger_buffer_positions = */ {},
1250
0
        std::move(vec_trigger_tokens),
1251
0
        std::move(vec_trigger_patterns),
1252
0
    };
1253
0
}
1254
1255
0
// releases a grammar; passing nullptr is allowed and does nothing
void llama_grammar_free_impl(struct llama_grammar * grammar) {
    // `delete` on a null pointer is a no-op, so no separate null check is required
    delete grammar;
}
1262
1263
0
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1264
0
    auto * result = new llama_grammar {
1265
0
        grammar.vocab,
1266
0
        grammar.rules,
1267
0
        grammar.stacks,
1268
0
        grammar.partial_utf8,
1269
0
        grammar.lazy,
1270
0
        grammar.awaiting_trigger,
1271
0
        grammar.trigger_buffer,
1272
0
        grammar.trigger_buffer_positions,
1273
0
        grammar.trigger_tokens,
1274
0
        grammar.trigger_patterns,
1275
0
    };
1276
1277
    // redirect elements in stacks to point to new rules
1278
0
    for (size_t is = 0; is < result->stacks.size(); is++) {
1279
0
        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
1280
0
            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1281
0
                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1282
0
                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1283
0
                        result->stacks[is][ie] =  &result->rules[ir0][ir1];
1284
0
                    }
1285
0
                }
1286
0
            }
1287
0
        }
1288
0
    }
1289
1290
0
    return result;
1291
0
}
1292
1293
0
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1294
0
    GGML_ASSERT(grammar.vocab != nullptr);
1295
1296
0
    if (grammar.awaiting_trigger) {
1297
0
        return;
1298
0
    }
1299
1300
0
    bool allow_eog = false;
1301
0
    for (const auto & stack : grammar.stacks) {
1302
0
        if (stack.empty()) {
1303
0
            allow_eog = true;
1304
0
            break;
1305
0
        }
1306
0
    }
1307
1308
0
    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
1309
0
    candidates_decoded.reserve(cur_p->size);
1310
1311
0
    llama_grammar_candidates candidates_grammar;
1312
0
    candidates_grammar.reserve(cur_p->size);
1313
1314
0
    for (size_t i = 0; i < cur_p->size; ++i) {
1315
0
        const llama_token id      = cur_p->data[i].id;
1316
0
        const std::string & piece = grammar.vocab->token_to_piece(id);
1317
1318
0
        if (grammar.vocab->is_eog(id)) {
1319
0
            if (!allow_eog) {
1320
0
                cur_p->data[i].logit = -INFINITY;
1321
0
            }
1322
0
        } else if (piece.empty() || piece[0] == 0) {
1323
0
            cur_p->data[i].logit = -INFINITY;
1324
0
        } else {
1325
0
            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
1326
0
            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
1327
0
        }
1328
0
    }
1329
1330
0
    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
1331
0
    for (const auto & reject : rejects) {
1332
0
        cur_p->data[reject.index].logit = -INFINITY;
1333
0
    }
1334
0
}
1335
1336
0
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
1337
0
    GGML_ASSERT(grammar.vocab != nullptr);
1338
1339
0
    const auto & piece = grammar.vocab->token_to_piece(token);
1340
1341
0
    if (grammar.awaiting_trigger) {
1342
0
        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
1343
0
            grammar.awaiting_trigger = false;
1344
0
            grammar.trigger_buffer.clear();
1345
0
            llama_grammar_accept_token(grammar, token, piece);
1346
0
            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1347
0
            return;
1348
0
        } else {
1349
0
            auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
1350
0
            grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
1351
0
            grammar.trigger_buffer += piece;
1352
1353
0
            for (const auto & trigger_pattern : grammar.trigger_patterns) {
1354
0
                auto start = trigger_pattern.find(grammar.trigger_buffer);
1355
0
                if (start != std::string::npos) {
1356
0
                    grammar.awaiting_trigger = false;
1357
1358
                    // replay tokens that overlap with [start, end)
1359
0
                    for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
1360
0
                        auto [tok_start, tok_end] = tok_pos;
1361
0
                        if (tok_end <= start) {
1362
0
                            continue;
1363
0
                        }
1364
1365
0
                        size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
1366
0
                        size_t piece_len = tok_end - piece_start;
1367
0
                        auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
1368
0
                        llama_grammar_accept_token(grammar, tok, tok_piece);
1369
0
                    }
1370
1371
0
                    auto constrained_str = grammar.trigger_buffer.substr(start);
1372
0
                    grammar.trigger_buffer.clear();
1373
0
                    grammar.trigger_buffer_positions.clear();
1374
0
                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
1375
0
                    return;
1376
0
                }
1377
0
            }
1378
0
            LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
1379
0
            return;
1380
0
        }
1381
0
    }
1382
1383
0
    if (grammar.vocab->is_eog(token)) {
1384
0
        for (const auto & stack : grammar.stacks) {
1385
0
            if (stack.empty()) {
1386
0
                return;
1387
0
            }
1388
0
        }
1389
0
        GGML_ABORT("fatal error");
1390
0
    }
1391
1392
0
    llama_grammar_accept_token(grammar, token, piece);
1393
0
}
1394
1395
0
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
1396
    // Note terminating 0 in decoded string
1397
0
    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
1398
0
    const auto & code_points = decoded.first;
1399
1400
0
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1401
0
        llama_grammar_accept(&grammar, *it);
1402
0
    }
1403
1404
0
    grammar.partial_utf8 = decoded.second;
1405
0
    if (grammar.stacks.empty()) {
1406
0
        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
1407
0
    }
1408
0
}
1409
1410
0
void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
1411
    // Note terminating 0 in decoded string
1412
0
    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
1413
0
    const auto & code_points = decoded.first;
1414
1415
0
    llama_grammar_stacks stacks_new;
1416
0
    stacks_new.reserve(grammar.stacks.size());
1417
1418
0
    for (const auto & stack : grammar.stacks) {
1419
0
        if (stack.empty()) {
1420
0
            continue;
1421
0
        }
1422
1423
0
        const llama_grammar_element * pos = stack.back();
1424
1425
0
        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1426
0
            if (llama_grammar_match_token(pos, token)) {
1427
0
                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
1428
0
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
1429
0
                    new_stack.push_back(pos + 1);
1430
0
                }
1431
0
                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
1432
0
            }
1433
0
        } else {
1434
0
            llama_grammar_stacks current_stacks = {stack};
1435
1436
0
            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1437
0
                llama_grammar_stacks next_stacks;
1438
1439
0
                for (const auto & cur_stack : current_stacks) {
1440
0
                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
1441
0
                }
1442
1443
0
                current_stacks = std::move(next_stacks);
1444
0
                if (current_stacks.empty()) {
1445
0
                    break;
1446
0
                }
1447
0
            }
1448
1449
0
            for (auto & surviving_stack : current_stacks) {
1450
0
                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
1451
0
                    stacks_new.emplace_back(surviving_stack);
1452
0
                }
1453
0
            }
1454
0
        }
1455
0
    }
1456
1457
0
    grammar.stacks = std::move(stacks_new);
1458
0
    grammar.partial_utf8 = decoded.second;
1459
1460
0
    if (grammar.stacks.empty()) {
1461
0
        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
1462
0
    }
1463
0
}
1464