Coverage Report

Created: 2026-04-12 06:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-grammar.cpp
Line
Count
Source
1
#include "llama-grammar.h"
2
3
#include "llama-impl.h"
4
#include "llama-vocab.h"
5
#include "llama-sampler.h"
6
7
#include <cmath>
8
#include <algorithm>
9
#include <cstdint>
10
#include <set>
11
#include <stdexcept>
12
13
0
// Sanity limit used by parse_sequence: explicit `{m,n}` repetition counts above this value, and
// repetitions whose (repeated rules x repetition count) product reaches it, are rejected with an error.
#define MAX_REPETITION_THRESHOLD 2000
14
//
15
// helpers
16
//
17
18
// Decodes the single UTF-8 sequence starting at `src`.
// Returns the code point together with a pointer to the first byte after the bytes consumed.
// NOTE: the input is assumed to be valid UTF-8; the only safeguard is that decoding stops at a
//       NUL byte, so a truncated sequence never reads past the terminator.
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // sequence length in bytes, indexed by the high nibble of the lead byte
    static const int seq_len_by_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead    = static_cast<uint8_t>(*src);
    const int     n_bytes = seq_len_by_nibble[lead >> 4];

    // keep the low (8 - n_bytes) bits of the lead byte as the initial payload
    const uint8_t payload_mask = (1 << (8 - n_bytes)) - 1;
    uint32_t      code_point   = lead & payload_mask;

    // fold in up to (n_bytes - 1) trailing bytes, 6 payload bits each
    const char * cur = src + 1;
    for (int consumed = 1; consumed < n_bytes && *cur; ++consumed, ++cur) {
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cur) & 0x3F);
    }

    return std::make_pair(code_point, cur);
}
33
34
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
35
        const std::string & src,
36
0
        llama_partial_utf8 partial_start) {
37
0
    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
38
0
    const char          * pos      = src.c_str();
39
0
    std::vector<uint32_t> code_points;
40
41
    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
42
0
    code_points.reserve(src.size() + 1);
43
0
    uint32_t value    = partial_start.value;
44
0
    int      n_remain = partial_start.n_remain;
45
46
    // continue previous decode, if applicable
47
0
    while (*pos != 0 && n_remain > 0) {
48
0
        uint8_t next_byte = static_cast<uint8_t>(*pos);
49
0
        if ((next_byte >> 6) != 2) {
50
            // invalid sequence, abort
51
0
            code_points.push_back(0);
52
0
            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
53
0
        }
54
0
        value = (value << 6) + (next_byte & 0x3F);
55
0
        ++pos;
56
0
        --n_remain;
57
0
    }
58
59
0
    if (partial_start.n_remain > 0 && n_remain == 0) {
60
0
        code_points.push_back(value);
61
0
    }
62
63
    // decode any subsequent utf-8 sequences, which may end in an incomplete one
64
0
    while (*pos != 0) {
65
0
        uint8_t first_byte = static_cast<uint8_t>(*pos);
66
0
        uint8_t highbits   = first_byte >> 4;
67
0
        n_remain   = lookup[highbits] - 1;
68
69
0
        if (n_remain < 0) {
70
            // invalid sequence, abort
71
0
            code_points.clear();
72
0
            code_points.push_back(0);
73
0
            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
74
0
        }
75
76
0
        uint8_t mask  = (1 << (7 - n_remain)) - 1;
77
0
        value = first_byte & mask;
78
79
0
        ++pos;
80
0
        while (*pos != 0 && n_remain > 0) {
81
0
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
82
0
            ++pos;
83
0
            --n_remain;
84
0
        }
85
0
        if (n_remain == 0) {
86
0
            code_points.push_back(value);
87
0
        }
88
0
    }
89
0
    code_points.push_back(0);
90
91
0
    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
92
0
}
93
94
0
// True iff `c` is one of the ASCII decimal digits '0'..'9'.
static bool is_digit_char(char c) {
    const bool below_zero = c < '0';
    const bool above_nine = c > '9';
    return !(below_zero || above_nine);
}
97
98
0
static bool is_word_char(char c) {
99
0
    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
100
0
}
101
102
0
// Parses exactly `size` hexadecimal digits starting at `src`.
// Returns the parsed value and a pointer just past the last digit.
// Throws std::runtime_error when fewer than `size` hex digits are available.
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    uint32_t     acc    = 0;
    const char * cur    = src;
    int          n_read = 0;

    while (n_read < size && *cur) {
        const char ch = *cur;
        uint32_t   nibble;
        if (ch >= '0' && ch <= '9') {
            nibble = ch - '0';
        } else if (ch >= 'a' && ch <= 'f') {
            nibble = ch - 'a' + 10;
        } else if (ch >= 'A' && ch <= 'F') {
            nibble = ch - 'A' + 10;
        } else {
            break; // not a hex digit
        }
        acc = (acc << 4) + nibble;
        ++cur;
        ++n_read;
    }

    if (n_read != size) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(acc, cur);
}
124
125
0
// Skips spaces, tabs and `#` comments starting at `src`; line breaks are skipped as well
// when `newline_ok` is set. Returns a pointer to the first character that was not skipped.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cur = src;
    for (;;) {
        const char ch = *cur;

        if (ch == '#') {
            // a comment runs up to (but not including) the next line break or the NUL terminator
            do {
                ++cur;
            } while (*cur != '\0' && *cur != '\r' && *cur != '\n');
            continue;
        }

        const bool is_blank     = ch == ' '  || ch == '\t';
        const bool is_linebreak = ch == '\r' || ch == '\n';
        if (is_blank || (newline_ok && is_linebreak)) {
            ++cur;
            continue;
        }

        return cur;
    }
}
139
140
0
static const char * parse_name(const char * src) {
141
0
    const char * pos = src;
142
0
    while (is_word_char(*pos)) {
143
0
        pos++;
144
0
    }
145
0
    if (pos == src) {
146
0
        throw std::runtime_error(std::string("expecting name at ") + src);
147
0
    }
148
0
    return pos;
149
0
}
150
151
0
static const char * parse_int(const char * src) {
152
0
    const char * pos = src;
153
0
    while (is_digit_char(*pos)) {
154
0
        pos++;
155
0
    }
156
0
    if (pos == src) {
157
0
        throw std::runtime_error(std::string("expecting integer at ") + src);
158
0
    }
159
0
    return pos;
160
0
}
161
162
0
static std::pair<uint32_t, const char *> parse_char(const char * src) {
163
0
    if (*src == '\\') {
164
0
        switch (src[1]) {
165
0
            case 'x': return parse_hex(src + 2, 2);
166
0
            case 'u': return parse_hex(src + 2, 4);
167
0
            case 'U': return parse_hex(src + 2, 8);
168
0
            case 't': return std::make_pair('\t', src + 2);
169
0
            case 'r': return std::make_pair('\r', src + 2);
170
0
            case 'n': return std::make_pair('\n', src + 2);
171
0
            case '\\':
172
0
            case '"':
173
0
            case '[':
174
0
            case ']':
175
0
                      return std::make_pair(src[1], src + 2);
176
0
            default:
177
0
                      throw std::runtime_error(std::string("unknown escape at ") + src);
178
0
        }
179
0
    } else if (*src) {
180
0
        return decode_utf8(src);
181
0
    }
182
0
    throw std::runtime_error("unexpected end of input");
183
0
}
184
185
0
static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
186
0
    const char * pos = src;
187
0
    if (*pos != '<') {
188
0
        throw std::runtime_error(std::string("expecting '<' at ") + pos);
189
0
    }
190
0
    pos++;
191
192
    // Parse <[id]>
193
0
    if (*pos == '[') {
194
0
        pos++;
195
0
        const char * int_end = parse_int(pos);
196
0
        uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
197
0
        pos = int_end;
198
0
        if (*pos != ']') {
199
0
            throw std::runtime_error(std::string("expecting ']' at ") + pos);
200
0
        }
201
0
        pos++;
202
0
        if (*pos != '>') {
203
0
            throw std::runtime_error(std::string("expecting '>' at ") + pos);
204
0
        }
205
0
        pos++;
206
0
        return std::make_pair(token_id, pos);
207
0
    }
208
209
0
    if (vocab == nullptr) {
210
0
        throw std::runtime_error(std::string("no vocab to parse token at ") + src);
211
0
    }
212
213
    // Parse <token> and tokenize to obtain the token id
214
0
    while (*pos != 0 && *pos != '>') {
215
0
        pos++;
216
0
    }
217
0
    if (*pos != '>') {
218
0
        throw std::runtime_error(std::string("expecting '>' at ") + pos);
219
0
    }
220
0
    pos++;
221
222
0
    llama_token tokens[2];
223
0
    int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
224
0
    if (n_tokens != 1) {
225
        // must tokenize to exactly 1 token
226
0
        throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
227
0
    }
228
0
    return std::make_pair(tokens[0], pos);
229
0
}
230
231
0
// Writes one grammar character to `file`: printable ASCII (0x20..0x7E) is emitted verbatim,
// every other code point as a `<U+XXXX>` placeholder (no UTF-8 encoding is attempted).
static void print_grammar_char(FILE * file, uint32_t c) {
    // 0x7F (DEL) is a control character, so it must not be written raw
    const bool printable_ascii = 0x20 <= c && c < 0x7f;
    if (printable_ascii) {
        fprintf(file, "%c", static_cast<char>(c));
    } else {
        // cop out of encoding UTF-8
        fprintf(file, "<U+%04X>", c);
    }
}
239
240
0
// True iff `elem` is one of the character-matching element types
// (CHAR, CHAR_NOT, CHAR_ALT, CHAR_RNG_UPPER or CHAR_ANY).
static bool is_char_element(llama_grammar_element elem) {
    const auto type = elem.type;
    return type == LLAMA_GRETYPE_CHAR
        || type == LLAMA_GRETYPE_CHAR_NOT
        || type == LLAMA_GRETYPE_CHAR_ALT
        || type == LLAMA_GRETYPE_CHAR_RNG_UPPER
        || type == LLAMA_GRETYPE_CHAR_ANY;
}
250
251
0
static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
252
0
    for (auto elem : rule) {
253
0
        switch (elem.type) {
254
0
            case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
255
0
            case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
256
0
            case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
257
0
            case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
258
0
            case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
259
0
            case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
260
0
            case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
261
0
            case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
262
0
            case LLAMA_GRETYPE_TOKEN:          fprintf(file, "TOKEN");          break;
263
0
            case LLAMA_GRETYPE_TOKEN_NOT:      fprintf(file, "TOKEN_NOT");      break;
264
0
        }
265
0
        switch (elem.type) {
266
0
            case LLAMA_GRETYPE_END:
267
0
            case LLAMA_GRETYPE_ALT:
268
0
            case LLAMA_GRETYPE_RULE_REF:
269
0
                fprintf(file, "(%u) ", elem.value);
270
0
                break;
271
0
            case LLAMA_GRETYPE_CHAR:
272
0
            case LLAMA_GRETYPE_CHAR_NOT:
273
0
            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
274
0
            case LLAMA_GRETYPE_CHAR_ALT:
275
0
            case LLAMA_GRETYPE_CHAR_ANY:
276
0
                fprintf(file, "(\"");
277
0
                print_grammar_char(file, elem.value);
278
0
                fprintf(file, "\") ");
279
0
                break;
280
0
            case LLAMA_GRETYPE_TOKEN:
281
0
                fprintf(file, "<[");
282
0
                fprintf(file, "%u", elem.value);
283
0
                fprintf(file, "]> ");
284
0
                break;
285
0
            case LLAMA_GRETYPE_TOKEN_NOT:
286
0
                fprintf(file, "!");
287
0
                fprintf(file, "<[");
288
0
                fprintf(file, "%u", elem.value);
289
0
                fprintf(file, "]> ");
290
0
                break;
291
0
        }
292
0
    }
293
0
    fprintf(file, "\n");
294
0
}
295
296
static void print_rule(
297
        FILE     * file,
298
        uint32_t   rule_id,
299
        const llama_grammar_rule & rule,
300
0
        const std::map<uint32_t, std::string> & symbol_id_names) {
301
0
    if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
302
0
        throw std::runtime_error(
303
0
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
304
0
    }
305
0
    fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
306
0
    for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
307
0
        llama_grammar_element elem = rule[i];
308
0
        switch (elem.type) {
309
0
            case LLAMA_GRETYPE_END:
310
0
                throw std::runtime_error(
311
0
                    "unexpected end of rule: " + std::to_string(rule_id) + "," +
312
0
                    std::to_string(i));
313
0
            case LLAMA_GRETYPE_ALT:
314
0
                fprintf(file, "| ");
315
0
                break;
316
0
            case LLAMA_GRETYPE_RULE_REF:
317
0
                fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
318
0
                break;
319
0
            case LLAMA_GRETYPE_CHAR:
320
0
                fprintf(file, "[");
321
0
                print_grammar_char(file, elem.value);
322
0
                break;
323
0
            case LLAMA_GRETYPE_CHAR_NOT:
324
0
                fprintf(file, "[^");
325
0
                print_grammar_char(file, elem.value);
326
0
                break;
327
0
            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
328
0
                if (i == 0 || !is_char_element(rule[i - 1])) {
329
0
                    throw std::runtime_error(
330
0
                        "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
331
0
                        std::to_string(rule_id) + "," + std::to_string(i));
332
0
                }
333
0
                fprintf(file, "-");
334
0
                print_grammar_char(file, elem.value);
335
0
                break;
336
0
            case LLAMA_GRETYPE_CHAR_ALT:
337
0
                if (i == 0 || !is_char_element(rule[i - 1])) {
338
0
                    throw std::runtime_error(
339
0
                        "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
340
0
                        std::to_string(rule_id) + "," + std::to_string(i));
341
0
                }
342
0
                print_grammar_char(file, elem.value);
343
0
                break;
344
0
            case LLAMA_GRETYPE_CHAR_ANY:
345
0
                fprintf(file, ".");
346
0
                break;
347
0
            case LLAMA_GRETYPE_TOKEN:
348
0
                fprintf(file, "<[");
349
0
                fprintf(file, "%u", elem.value);
350
0
                fprintf(file, "]> ");
351
0
                break;
352
0
            case LLAMA_GRETYPE_TOKEN_NOT:
353
0
                fprintf(file, "!");
354
0
                fprintf(file, "<[");
355
0
                fprintf(file, "%u", elem.value);
356
0
                fprintf(file, "]> ");
357
0
                break;
358
0
        }
359
0
        if (is_char_element(elem)) {
360
0
            switch (rule[i + 1].type) {
361
0
                case LLAMA_GRETYPE_CHAR_ALT:
362
0
                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
363
0
                case LLAMA_GRETYPE_CHAR_ANY:
364
0
                    break;
365
0
                default:
366
0
                    fprintf(file, "] ");
367
0
            }
368
0
        }
369
0
    }
370
0
    fprintf(file, "\n");
371
0
}
372
373
//
374
// Regex utilities
375
//
376
377
0
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
378
0
    auto find_start_pos = [](const std::smatch & match) {
379
        // get from the first matched capturing group to the end of the string
380
0
        size_t start = std::string::npos;
381
0
        for (auto i = 1u; i < match.size(); i++) {
382
0
            if (match.length(i) > 0) {
383
0
                start = match.position(i);
384
0
                break;
385
0
            }
386
0
        }
387
0
        if (start == std::string::npos) {
388
0
            start = match.position(0);
389
0
        }
390
0
        return start;
391
0
    };
392
393
0
    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
394
        // match against the entire input
395
0
        std::smatch match;
396
0
        if (std::regex_match(input, match, regex)) {
397
0
            return find_start_pos(match);
398
0
        }
399
0
    }
400
401
    // search anywhere
402
0
    std::smatch match;
403
0
    if (std::regex_search(input, match, regex)) {
404
0
        return find_start_pos(match);
405
0
    }
406
407
0
    return std::string::npos;
408
0
}
409
410
411
//
412
// implementation
413
//
414
415
0
uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
416
0
    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
417
0
    auto result = symbol_ids.emplace(std::string(src, len), next_id);
418
0
    return result.first->second;
419
0
}
420
421
0
uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
422
0
    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
423
0
    symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
424
0
    return next_id;
425
0
}
426
427
0
void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
428
0
    if (rules.size() <= rule_id) {
429
0
        rules.resize(rule_id + 1);
430
0
    }
431
0
    rules[rule_id] = rule;
432
0
}
433
434
const char * llama_grammar_parser::parse_alternates(
435
        const char        * src,
436
        const std::string & rule_name,
437
        uint32_t            rule_id,
438
0
        bool                is_nested) {
439
0
    llama_grammar_rule rule;
440
0
    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
441
0
    while (*pos == '|') {
442
0
        rule.push_back({LLAMA_GRETYPE_ALT, 0});
443
0
        pos = parse_space(pos + 1, true);
444
0
        pos = parse_sequence(pos, rule_name, rule, is_nested);
445
0
    }
446
0
    rule.push_back({LLAMA_GRETYPE_END, 0});
447
0
    add_rule(rule_id, rule);
448
0
    return pos;
449
0
}
450
451
const char * llama_grammar_parser::parse_sequence(
452
        const char         * src,
453
        const std::string  & rule_name,
454
        llama_grammar_rule & rule,
455
0
        bool               is_nested) {
456
0
    size_t last_sym_start = rule.size();
457
0
    const char * pos = src;
458
0
    uint64_t n_prev_rules = 1;
459
460
    // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
461
    // (though it's technically the same as -1 now)
462
0
    auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
463
0
        bool no_max = max_times == UINT64_MAX;
464
0
        if (last_sym_start == rule.size()) {
465
0
            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
466
0
        }
467
468
        // apply transformation to previous symbol (last_sym_start to end) according to
469
        // the following rewrite rules:
470
        // S{m,n} --> S S S (m times) S'(n-m)
471
        //            S'(x)   ::= S S'(x-1) |
472
        //            (... n-m definitions of these S' rules ...)
473
        //            S'(1)   ::= S |
474
        // S{m,} -->  S S S (m times) S'
475
        //            S'     ::= S S' |
476
        // S*     --> S{0,}
477
        //        --> S'     ::= S S' |
478
        // S+     --> S{1,}
479
        //        --> S S'
480
        //            S'     ::= S S' |
481
        // S?     --> S{0,1}
482
        //        --> S'
483
        //            S'     ::= S |
484
485
0
        llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
486
        // Calculate the total number of rules that will be generated by this repetition
487
0
        uint64_t total_rules = 1; // Start with 1 for the original rule
488
0
        if (!no_max && max_times > 0) {
489
0
            total_rules = max_times;
490
0
        } else if (min_times > 0) {
491
0
            total_rules = min_times;
492
0
        }
493
494
0
        if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) {
495
0
            throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity");
496
0
        }
497
498
0
        if (min_times == 0) {
499
0
            rule.resize(last_sym_start);
500
0
        } else {
501
            // Repeat the previous elements (min_times - 1) times
502
0
            for (uint64_t i = 1; i < min_times; i++) {
503
0
                rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
504
0
            }
505
0
        }
506
507
0
        uint32_t last_rec_rule_id = 0;
508
0
        auto n_opt = no_max ? 1 : max_times - min_times;
509
510
0
        llama_grammar_rule rec_rule(prev_rule);
511
0
        for (uint64_t i = 0; i < n_opt; i++) {
512
0
            rec_rule.resize(prev_rule.size());
513
0
            uint32_t rec_rule_id = generate_symbol_id( rule_name);
514
0
            if (i > 0 || no_max) {
515
0
                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
516
0
            }
517
0
            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
518
0
            rec_rule.push_back({LLAMA_GRETYPE_END, 0});
519
0
            add_rule( rec_rule_id, rec_rule);
520
0
            last_rec_rule_id = rec_rule_id;
521
0
        }
522
0
        if (n_opt > 0) {
523
0
            rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
524
0
        }
525
0
        n_prev_rules *= total_rules;
526
0
        GGML_ASSERT(n_prev_rules >= 1);
527
0
    };
528
529
0
    while (*pos) {
530
0
        if (*pos == '"') { // literal string
531
0
            pos++;
532
0
            last_sym_start = rule.size();
533
0
            n_prev_rules = 1;
534
0
            while (*pos != '"') {
535
0
                if (!*pos) {
536
0
                    throw std::runtime_error("unexpected end of input");
537
0
                }
538
0
                auto char_pair = parse_char(pos);
539
0
                     pos       = char_pair.second;
540
0
                rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
541
0
            }
542
0
            pos = parse_space(pos + 1, is_nested);
543
0
        } else if (*pos == '[') { // char range(s)
544
0
            pos++;
545
0
            enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
546
0
            if (*pos == '^') {
547
0
                pos++;
548
0
                start_type = LLAMA_GRETYPE_CHAR_NOT;
549
0
            }
550
0
            last_sym_start = rule.size();
551
0
            n_prev_rules = 1;
552
0
            while (*pos != ']') {
553
0
                if (!*pos) {
554
0
                    throw std::runtime_error("unexpected end of input");
555
0
                }
556
0
                auto char_pair = parse_char(pos);
557
0
                     pos       = char_pair.second;
558
0
                enum llama_gretype type = last_sym_start < rule.size()
559
0
                    ? LLAMA_GRETYPE_CHAR_ALT
560
0
                    : start_type;
561
562
0
                rule.push_back({type, char_pair.first});
563
0
                if (pos[0] == '-' && pos[1] != ']') {
564
0
                    if (!pos[1]) {
565
0
                        throw std::runtime_error("unexpected end of input");
566
0
                    }
567
0
                    auto endchar_pair = parse_char(pos + 1);
568
0
                         pos          = endchar_pair.second;
569
0
                    rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
570
0
                }
571
0
            }
572
0
            pos = parse_space(pos + 1, is_nested);
573
0
        } else if (*pos == '<' || *pos == '!') { // token
574
0
            auto type = LLAMA_GRETYPE_TOKEN;
575
0
            if (*pos == '!') { // token inverse
576
0
                type = LLAMA_GRETYPE_TOKEN_NOT;
577
0
                pos++;
578
0
            }
579
0
            auto token_pair = parse_token(vocab, pos);
580
0
            const char * token_end  = token_pair.second;
581
0
            last_sym_start = rule.size();
582
0
            n_prev_rules = 1;
583
0
            rule.push_back({type, token_pair.first});
584
0
            pos = parse_space(token_end, is_nested);
585
0
        } else if (is_word_char(*pos)) { // rule reference
586
0
            const char * name_end    = parse_name(pos);
587
0
            uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
588
0
            pos = parse_space(name_end, is_nested);
589
0
            last_sym_start = rule.size();
590
0
            n_prev_rules = 1;
591
0
            rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
592
0
        } else if (*pos == '(') { // grouping
593
            // parse nested alternates into synthesized rule
594
0
            pos = parse_space(pos + 1, true);
595
0
            uint32_t n_rules_before = symbol_ids.size();
596
0
            uint32_t sub_rule_id = generate_symbol_id(rule_name);
597
0
            pos = parse_alternates(pos, rule_name, sub_rule_id, true);
598
0
            n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before);
599
0
            last_sym_start = rule.size();
600
            // output reference to synthesized rule
601
0
            rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
602
0
            if (*pos != ')') {
603
0
                throw std::runtime_error(std::string("expecting ')' at ") + pos);
604
0
            }
605
0
            pos = parse_space(pos + 1, is_nested);
606
0
        } else if (*pos == '.') { // any char
607
0
            last_sym_start = rule.size();
608
0
            n_prev_rules = 1;
609
0
            rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
610
0
            pos = parse_space(pos + 1, is_nested);
611
0
        } else if (*pos == '*') {
612
0
            pos = parse_space(pos + 1, is_nested);
613
0
            handle_repetitions(0, -1);
614
0
        } else if (*pos == '+') {
615
0
            pos = parse_space(pos + 1, is_nested);
616
0
            handle_repetitions(1, -1);
617
0
        } else if (*pos == '?') {
618
0
            pos = parse_space(pos + 1, is_nested);
619
0
            handle_repetitions(0, 1);
620
0
        } else if (*pos == '{') {
621
0
            pos = parse_space(pos + 1, is_nested);
622
623
0
            if (!is_digit_char(*pos)) {
624
0
                throw std::runtime_error(std::string("expecting an int at ") + pos);
625
0
            }
626
0
            const char * int_end = parse_int(pos);
627
0
            uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
628
0
            pos = parse_space(int_end, is_nested);
629
630
0
            uint64_t max_times = UINT64_MAX; // default: no max limit
631
632
0
            if (*pos == '}') {
633
0
                max_times = min_times;
634
0
                pos = parse_space(pos + 1, is_nested);
635
0
            } else if (*pos == ',') {
636
0
                pos = parse_space(pos + 1, is_nested);
637
638
0
                if (is_digit_char(*pos)) {
639
0
                    const char * int_end = parse_int(pos);
640
0
                    max_times = std::stoull(std::string(pos, int_end - pos));
641
0
                    pos = parse_space(int_end, is_nested);
642
0
                }
643
644
0
                if (*pos != '}') {
645
0
                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
646
0
                }
647
0
                pos = parse_space(pos + 1, is_nested);
648
0
            } else {
649
0
                throw std::runtime_error(std::string("expecting ',' at ") + pos);
650
0
            }
651
0
            bool has_max = max_times != UINT64_MAX;
652
0
            if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
653
0
                throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
654
0
            }
655
0
            handle_repetitions(min_times, max_times);
656
0
        } else {
657
0
            break;
658
0
        }
659
0
    }
660
0
    return pos;
661
0
}
662
663
0
const char * llama_grammar_parser::parse_rule(const char * src) {
664
0
    const char * name_end = parse_name(src);
665
0
    const char * pos      = parse_space(name_end, false);
666
0
    size_t       name_len = name_end - src;
667
0
    uint32_t     rule_id  = get_symbol_id(src, name_len);
668
0
    const std::string name(src, name_len);
669
670
0
    if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
671
0
        throw std::runtime_error(std::string("expecting ::= at ") + pos);
672
0
    }
673
0
    pos = parse_space(pos + 3, true);
674
675
0
    pos = parse_alternates(pos, name, rule_id, false);
676
677
0
    if (*pos == '\r') {
678
0
        pos += pos[1] == '\n' ? 2 : 1;
679
0
    } else if (*pos == '\n') {
680
0
        pos++;
681
0
    } else if (*pos) {
682
0
        throw std::runtime_error(std::string("expecting newline or end at ") + pos);
683
0
    }
684
0
    return parse_space(pos, true);
685
0
}
686
687
0
bool llama_grammar_parser::parse(const char * src) {
688
0
    try {
689
0
        const char * pos = parse_space(src, true);
690
0
        while (*pos) {
691
0
            pos = parse_rule(pos);
692
0
        }
693
        // Validate the state to ensure that all rules are defined
694
0
        for (const auto & rule : rules) {
695
0
            if (rule.empty()) {
696
0
                throw std::runtime_error("Undefined rule");
697
0
            }
698
0
            for (const auto & elem : rule) {
699
0
                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
700
                    // Ensure that the rule at that location exists
701
0
                    if (elem.value >= rules.size() || rules[elem.value].empty()) {
702
                        // Get the name of the rule that is missing
703
0
                        for (const auto & kv : symbol_ids) {
704
0
                            if (kv.second == elem.value) {
705
0
                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
706
0
                            }
707
0
                        }
708
0
                    }
709
0
                }
710
0
            }
711
0
        }
712
0
    } catch (const std::exception & err) {
713
0
        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
714
0
        rules.clear();
715
0
        return false;
716
0
    }
717
718
0
    return true;
719
0
}
720
721
0
void llama_grammar_parser::print(FILE * file) {
722
0
    try {
723
0
        std::map<uint32_t, std::string> symbol_id_names;
724
0
        for (const auto & kv : symbol_ids) {
725
0
            symbol_id_names[kv.second] = kv.first;
726
0
        }
727
0
        for (size_t i = 0, end = rules.size(); i < end; i++) {
728
            // fprintf(file, "%zu: ", i);
729
            // print_rule_binary(file, rules[i]);
730
0
            print_rule(file, uint32_t(i), rules[i], symbol_id_names);
731
            // fprintf(file, "\n");
732
0
        }
733
0
    } catch (const std::exception & err) {
734
0
        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
735
0
    }
736
0
}
737
738
0
// Returns one pointer per parsed rule, each pointing at that rule's first element.
// The pointers refer to storage owned by `rules`.
llama_grammar_stack llama_grammar_parser::c_rules() const {
    const size_t n_rules = rules.size();

    llama_grammar_stack rule_ptrs;
    rule_ptrs.reserve(n_rules);
    for (size_t i = 0; i < n_rules; ++i) {
        rule_ptrs.push_back(rules[i].data());
    }
    return rule_ptrs;
}
746
747
// returns true iff pos points to the end of one of the definitions of a rule
748
0
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
749
0
    switch (pos->type) {
750
0
        case LLAMA_GRETYPE_END: return true;  // NOLINT
751
0
        case LLAMA_GRETYPE_ALT: return true;  // NOLINT
752
0
        default:                return false;
753
0
    }
754
0
}
755
756
// returns true iff chr satisfies the char range at pos (regular or inverse range)
757
// asserts that pos is pointing to a char range element
758
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
759
        const llama_grammar_element * pos,
760
0
        const uint32_t                chr) {
761
0
    bool found            = false;
762
0
    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
763
764
0
    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
765
766
0
    do {
767
0
        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
768
            // inclusive range, e.g. [a-z]
769
0
            found = found || (pos->value <= chr && chr <= pos[1].value);
770
0
            pos += 2;
771
0
        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
772
            // Any character matches "."
773
0
            found = true;
774
0
            pos += 1;
775
0
        } else {
776
            // exact char match, e.g. [a] or "a"
777
0
            found = found || pos->value == chr;
778
0
            pos += 1;
779
0
        }
780
0
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
781
782
0
    return std::make_pair(found == is_positive_char, pos);
783
0
}
784
785
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
786
// range at pos (regular or inverse range)
787
// asserts that pos is pointing to a char range element
788
static bool llama_grammar_match_partial_char(
789
        const llama_grammar_element * pos,
790
0
        const llama_partial_utf8      partial_utf8) {
791
0
    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
792
0
    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
793
794
0
    uint32_t partial_value = partial_utf8.value;
795
0
    int      n_remain      = partial_utf8.n_remain;
796
797
    // invalid sequence or 7-bit char split across 2 bytes (overlong)
798
0
    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
799
0
        return false;
800
0
    }
801
802
    // range of possible code points this partial UTF-8 sequence could complete to
803
0
    uint32_t low  = partial_value << (n_remain * 6);
804
0
    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
805
806
0
    if (low == 0) {
807
0
        if (n_remain == 2) {
808
0
            low = 1 << 11;
809
0
        } else if (n_remain == 3) {
810
0
            low = 1 << 16;
811
0
        }
812
0
    }
813
814
0
    do {
815
0
        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
816
            // inclusive range, e.g. [a-z]
817
0
            if (pos->value <= high && low <= pos[1].value) {
818
0
                return is_positive_char;
819
0
            }
820
0
            pos += 2;
821
0
        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
822
            // Any character matches "."
823
0
            return true;
824
0
        } else {
825
            // exact char match, e.g. [a] or "a"
826
0
            if (low <= pos->value && pos->value <= high) {
827
0
                return is_positive_char;
828
0
            }
829
0
            pos += 1;
830
0
        }
831
0
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
832
833
0
    return !is_positive_char;
834
0
}
835
836
// returns true iff token matches the rule at pos (regular or inverse)
837
// asserts that pos is pointing to a token element
838
static bool llama_grammar_match_token(
839
    const llama_grammar_element * pos,
840
0
    const llama_token             token) {
841
0
    GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
842
0
    if (pos->type == LLAMA_GRETYPE_TOKEN) {
843
0
        return pos->value == static_cast<uint32_t>(token);
844
0
    }
845
0
    if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
846
0
        return pos->value != static_cast<uint32_t>(token);
847
0
    }
848
0
    return false;
849
0
}
850
851
// transforms a grammar pushdown stack into N possible stacks, all ending
852
// at a character range (terminal element)
853
static void llama_grammar_advance_stack(
854
        const llama_grammar_rules  & rules,
855
        const llama_grammar_stack  & stack,
856
0
        llama_grammar_stacks & new_stacks) {
857
0
    std::vector<llama_grammar_stack> todo;
858
0
    todo.push_back(stack);
859
860
0
    auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) {
861
0
        return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(),
862
0
            [](const llama_grammar_element * pa, const llama_grammar_element * pb) {
863
0
                return pa < pb;  // Compare pointer addresses
864
0
            }
865
0
        );
866
0
    };
867
868
0
    std::set<llama_grammar_stack, decltype(stack_cmp)> seen(stack_cmp);
869
870
0
    while (!todo.empty()) {
871
0
        llama_grammar_stack curr_stack = std::move(todo.back());
872
0
        todo.pop_back();
873
874
0
        if (seen.find( curr_stack) != seen.end()) {
875
0
            continue;
876
0
        }
877
0
        seen.insert(curr_stack);
878
879
0
        if (curr_stack.empty()) {
880
0
            if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
881
0
                new_stacks.emplace_back(std::move(curr_stack));
882
0
            }
883
0
            continue;
884
0
        }
885
886
0
        const llama_grammar_element * pos = curr_stack.back();
887
888
0
        switch (pos->type) {
889
0
        case LLAMA_GRETYPE_RULE_REF: {
890
0
            const size_t                  rule_id = static_cast<size_t>(pos->value);
891
0
            const llama_grammar_element * subpos  = rules[rule_id].data();
892
0
            do {
893
                // init new stack without the top (pos)
894
0
                llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1);
895
0
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
896
                    // if this rule ref is followed by another element, add that to stack
897
0
                    next_stack.push_back(pos + 1);
898
0
                }
899
0
                if (!llama_grammar_is_end_of_sequence(subpos)) {
900
                    // if alternate is nonempty, add to stack
901
0
                    next_stack.push_back(subpos);
902
0
                }
903
0
                todo.push_back(std::move(next_stack));
904
0
                while (!llama_grammar_is_end_of_sequence(subpos)) {
905
                    // scan to end of alternate def
906
0
                    subpos++;
907
0
                }
908
0
                if (subpos->type == LLAMA_GRETYPE_ALT) {
909
                    // there's another alternate def of this rule to process
910
0
                    subpos++;
911
0
                } else {
912
0
                    break;
913
0
                }
914
0
            } while (true);
915
0
            break;
916
0
        }
917
0
        case LLAMA_GRETYPE_CHAR:
918
0
        case LLAMA_GRETYPE_CHAR_NOT:
919
0
        case LLAMA_GRETYPE_CHAR_ANY:
920
0
        case LLAMA_GRETYPE_TOKEN:
921
0
        case LLAMA_GRETYPE_TOKEN_NOT:
922
0
            if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
923
                // only add the stack if it's not a duplicate of one we already have
924
0
                new_stacks.emplace_back(std::move(curr_stack));
925
0
            }
926
0
            break;
927
0
        default:
928
            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
929
            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
930
            // those
931
0
            GGML_ABORT("fatal error");
932
0
        }
933
0
    }
934
0
}
935
936
// Returns the candidates that no stack in `stacks` accepts: the candidates rejected for
// the first stack are filtered again through every following stack, so only candidates
// rejected by all of them remain.
// Asserts that stacks is not empty.
static llama_grammar_candidates llama_grammar_reject_candidates(
        const llama_grammar_rules      & rules,
        const llama_grammar_stacks     & stacks,
        const llama_grammar_candidates & candidates) {
    GGML_ASSERT(!stacks.empty()); // REVIEW

    if (candidates.empty()) {
        return {};
    }

    auto stack_it = stacks.begin();
    auto rejects  = llama_grammar_reject_candidates_for_stack(rules, *stack_it, candidates);

    for (++stack_it; stack_it != stacks.end(); ++stack_it) {
        rejects = llama_grammar_reject_candidates_for_stack(rules, *stack_it, rejects);
    }

    return rejects;
}
954
955
static bool llama_grammar_detect_left_recursion(
956
        const llama_grammar_rules & rules,
957
        size_t rule_index,
958
        std::vector<bool> * rules_visited,
959
        std::vector<bool> * rules_in_progress,
960
0
        std::vector<bool> * rules_may_be_empty) {
961
0
    if ((*rules_in_progress)[rule_index]) {
962
0
        return true;
963
0
    }
964
965
0
    (*rules_in_progress)[rule_index] = true;
966
967
0
    const llama_grammar_rule & rule = rules[rule_index];
968
969
    // First check if the rule might produce the empty string. This could be done combined with the second
970
    // step but it's more readable as two steps.
971
0
    bool at_rule_start = true;
972
0
    for (size_t i = 0; i < rule.size(); i++) {
973
0
        if (llama_grammar_is_end_of_sequence(&rule[i])) {
974
0
            if (at_rule_start) {
975
0
                (*rules_may_be_empty)[rule_index] = true;
976
0
                break;
977
0
            }
978
0
            at_rule_start = true;
979
0
        } else {
980
0
            at_rule_start = false;
981
0
        }
982
0
    }
983
984
    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
985
    // be empty)
986
0
    bool recurse_into_nonterminal = true;
987
0
    for (size_t i = 0; i < rule.size(); i++) {
988
0
        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
989
0
            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
990
0
                return true;
991
0
            }
992
0
            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
993
0
                recurse_into_nonterminal = false;
994
0
            }
995
0
        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
996
0
            recurse_into_nonterminal = true;
997
0
        } else {
998
0
            recurse_into_nonterminal = false;
999
0
        }
1000
0
    }
1001
1002
0
    (*rules_in_progress)[rule_index] = false;
1003
0
    (*rules_visited)[rule_index] = true;
1004
1005
0
    return false;
1006
0
}
1007
1008
0
// Read-only access to the rule table stored in the grammar.
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
    const llama_grammar_rules & rule_table = (*grammar).rules;
    return rule_table;
}
1011
1012
0
// Mutable access to the stacks stored in the grammar.
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
    llama_grammar_stacks & current_stacks = (*grammar).stacks;
    return current_stacks;
}
1015
1016
static void llama_grammar_accept_chr(
1017
        struct llama_grammar       & grammar,
1018
        const llama_grammar_stack  & stack,
1019
              uint32_t               chr,
1020
0
              llama_grammar_stacks & new_stacks) {
1021
0
    if (stack.empty()) {
1022
0
        return;
1023
0
    }
1024
1025
0
    const llama_grammar_element * pos = stack.back();
1026
1027
    // ignore if this turns into a token
1028
0
    if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1029
0
        return;
1030
0
    }
1031
1032
0
    auto match = llama_grammar_match_char(pos, chr);
1033
0
    if (match.first) {
1034
0
        llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
1035
0
        if (!llama_grammar_is_end_of_sequence(match.second)) {
1036
0
            new_stack.push_back(match.second);
1037
0
        }
1038
0
        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
1039
0
    }
1040
0
}
1041
1042
0
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
1043
0
    llama_grammar_stacks stacks_new;
1044
0
    stacks_new.reserve(grammar->stacks.size());
1045
1046
0
    for (const auto & stack : grammar->stacks) {
1047
0
        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
1048
0
    }
1049
1050
0
    grammar->stacks = std::move(stacks_new);
1051
0
}
1052
1053
// Returns the candidates that cannot be accepted when parsing continues from `stack`.
// Each candidate's code_points points at its remaining, 0-terminated code points.
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules      & rules,
        const llama_grammar_stack      & stack,
        const llama_grammar_candidates & candidates) {

    llama_grammar_candidates rejects;
    rejects.reserve(candidates.size());

    // an empty stack only accepts candidates with nothing left to consume
    if (stack.empty()) {
        for (const auto & cand : candidates) {
            const bool fully_consumed = *cand.code_points == 0 && cand.partial_utf8.n_remain == 0;
            if (!fully_consumed) {
                rejects.push_back(cand);
            }
        }
        return rejects;
    }

    const llama_grammar_element * stack_pos = stack.back();

    // if the top of the stack is a token rule, then we only need to check the token id
    const bool top_is_token = stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT;
    if (top_is_token) {
        for (const auto & cand : candidates) {
            // a candidate already consumed by char rules is rejected iff it ended in a
            // partial sequence; any other candidate is decided by its token id
            const bool reject = (*cand.code_points == 0)
                ? cand.partial_utf8.n_remain != 0
                : !llama_grammar_match_token(stack_pos, cand.id);
            if (reject) {
                rejects.push_back(cand);
            }
        }
        return rejects;
    }

    // candidates whose next code point is accepted here, advanced by one code point
    llama_grammar_candidates next_candidates;
    next_candidates.reserve(candidates.size());

    for (const auto & cand : candidates) {
        if (*cand.code_points == 0) {
            // reached end of full codepoints in token, reject iff it ended in a partial sequence
            // that cannot satisfy this position in grammar
            if (cand.partial_utf8.n_remain != 0 &&
                    !llama_grammar_match_partial_char(stack_pos, cand.partial_utf8)) {
                rejects.push_back(cand);
            }
            continue;
        }

        if (!llama_grammar_match_char(stack_pos, *cand.code_points).first) {
            rejects.push_back(cand);
            continue;
        }

        next_candidates.push_back({ cand.index, cand.code_points + 1, cand.partial_utf8, cand.id });
    }

    // the match result is irrelevant here, the call is only used to find the element that
    // follows the char range
    const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;

    // update top of stack to next element, if any
    llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
        stack_after.push_back(stack_pos_after);
    }
    llama_grammar_stacks next_stacks;
    llama_grammar_advance_stack(rules, stack_after, next_stacks);

    // candidates rejected further down are reported with code_points moved back by one, so
    // they again refer to the position this call was given
    const auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
    for (const auto & cand : next_rejects) {
        rejects.push_back({ cand.index, cand.code_points - 1, cand.partial_utf8, cand.id });
    }

    return rejects;
}
1123
1124
////////////////////
1125
1126
struct llama_grammar * llama_grammar_init_impl(
1127
        const struct llama_vocab * vocab,
1128
        const llama_grammar_element ** rules,
1129
        size_t n_rules,
1130
0
        size_t start_rule_index) {
1131
0
    const llama_grammar_element * pos;
1132
1133
    // copy rule definitions into vectors
1134
0
    llama_grammar_rules vec_rules(n_rules);
1135
0
    for (size_t i = 0; i < n_rules; i++) {
1136
0
        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1137
0
            vec_rules[i].push_back(*pos);
1138
0
        }
1139
0
        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1140
0
    }
1141
1142
    // Check for left recursion
1143
0
    std::vector<bool> rules_visited(n_rules);
1144
0
    std::vector<bool> rules_in_progress(n_rules);
1145
0
    std::vector<bool> rules_may_be_empty(n_rules);
1146
0
    for (size_t i = 0; i < n_rules; i++) {
1147
0
        if (rules_visited[i]) {
1148
0
            continue;
1149
0
        }
1150
0
        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1151
0
            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1152
0
            return nullptr;
1153
0
        }
1154
0
    }
1155
1156
    // loop over alternates of start rule to build initial stacks
1157
0
    llama_grammar_stacks stacks;
1158
0
    pos = vec_rules[start_rule_index].data();
1159
0
    do {
1160
0
        llama_grammar_stack stack;
1161
0
        if (!llama_grammar_is_end_of_sequence(pos)) {
1162
            // if alternate is nonempty, add to stack
1163
0
            stack.push_back(pos);
1164
0
        }
1165
0
        llama_grammar_advance_stack(vec_rules, stack, stacks);
1166
0
        while (!llama_grammar_is_end_of_sequence(pos)) {
1167
            // scan to end of alternate def
1168
0
            pos++;
1169
0
        }
1170
0
        if (pos->type == LLAMA_GRETYPE_ALT) {
1171
            // there's another alternate def of this rule to process
1172
0
            pos++;
1173
0
        } else {
1174
0
            break;
1175
0
        }
1176
0
    } while (true);
1177
1178
    // Important: vec_rules has to be moved here, not copied, because stacks contains
1179
    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1180
    // then the pointers would be invalidated when the local vec_rules goes out of scope.
1181
0
    return new llama_grammar {
1182
0
        vocab,
1183
0
        std::move(vec_rules),
1184
0
        std::move(stacks),
1185
0
        /* .partial_utf8 = */             {},
1186
0
        /* .lazy = */                     false,
1187
0
        /* .awaiting_trigger = */         false,
1188
0
        /* .trigger_buffer = */           "",
1189
0
        /* .trigger_buffer_positions = */ {},
1190
0
        /* .trigger_tokens = */           {},
1191
0
        /* .trigger_patterns = */         {},
1192
0
    };
1193
0
}
1194
1195
struct llama_grammar * llama_grammar_init_impl(
1196
        const struct llama_vocab * vocab,
1197
                      const char * grammar_str,
1198
                      const char * grammar_root,
1199
                              bool lazy,
1200
                     const char ** trigger_patterns,
1201
                            size_t num_trigger_patterns,
1202
               const llama_token * trigger_tokens,
1203
0
                            size_t num_trigger_tokens) {
1204
0
    llama_grammar_parser parser(vocab);
1205
1206
    // if there is a grammar, parse it
1207
    // rules will be empty (default) if there are parse errors
1208
0
    if (!parser.parse(grammar_str) || parser.rules.empty()) {
1209
0
        LLAMA_LOG_ERROR("failed to parse grammar\n");
1210
0
        return nullptr;
1211
0
    }
1212
1213
    // Ensure that the grammar contains the start symbol
1214
0
    if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) {
1215
0
        LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root);
1216
0
        return nullptr;
1217
0
    }
1218
1219
0
    std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
1220
1221
0
    const size_t n_rules = grammar_rules.size();
1222
0
    const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
1223
1224
0
    const llama_grammar_element * pos;
1225
1226
    // copy rule definitions into vectors
1227
0
    llama_grammar_rules vec_rules(n_rules);
1228
0
    for (size_t i = 0; i < n_rules; i++) {
1229
0
        for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1230
0
            vec_rules[i].push_back(*pos);
1231
0
        }
1232
0
        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1233
0
    }
1234
1235
    // Check for left recursion
1236
0
    std::vector<bool> rules_visited(n_rules);
1237
0
    std::vector<bool> rules_in_progress(n_rules);
1238
0
    std::vector<bool> rules_may_be_empty(n_rules);
1239
0
    for (size_t i = 0; i < n_rules; i++) {
1240
0
        if (rules_visited[i]) {
1241
0
            continue;
1242
0
        }
1243
0
        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1244
0
            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i);
1245
0
            return nullptr;
1246
0
        }
1247
0
    }
1248
1249
    // loop over alternates of start rule to build initial stacks
1250
0
    llama_grammar_stacks stacks;
1251
0
    pos = vec_rules[start_rule_index].data();
1252
0
    do {
1253
0
        llama_grammar_stack stack;
1254
0
        if (!llama_grammar_is_end_of_sequence(pos)) {
1255
            // if alternate is nonempty, add to stack
1256
0
            stack.push_back(pos);
1257
0
        }
1258
0
        llama_grammar_advance_stack(vec_rules, stack, stacks);
1259
0
        while (!llama_grammar_is_end_of_sequence(pos)) {
1260
            // scan to end of alternate def
1261
0
            pos++;
1262
0
        }
1263
0
        if (pos->type == LLAMA_GRETYPE_ALT) {
1264
            // there's another alternate def of this rule to process
1265
0
            pos++;
1266
0
        } else {
1267
0
            break;
1268
0
        }
1269
0
    } while (true);
1270
1271
0
    std::vector<llama_token>    vec_trigger_tokens;
1272
0
    std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
1273
0
    for (size_t i = 0; i < num_trigger_tokens; i++) {
1274
0
        GGML_ASSERT(trigger_tokens != nullptr);
1275
0
        vec_trigger_tokens.push_back(trigger_tokens[i]);
1276
0
    }
1277
0
    for (size_t i = 0; i < num_trigger_patterns; i++) {
1278
0
        GGML_ASSERT(trigger_patterns != nullptr);
1279
0
        auto & trigger = vec_trigger_patterns.emplace_back();
1280
0
        trigger.pattern = trigger_patterns[i];
1281
0
        trigger.regex = std::regex(trigger.pattern);
1282
0
    }
1283
1284
    // Important: vec_rules has to be moved here, not copied, because stacks contains
1285
    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1286
    // then the pointers would be invalidated when the local vec_rules goes out of scope.
1287
0
    return new llama_grammar {
1288
0
        vocab,
1289
0
        std::move(vec_rules),
1290
0
        std::move(stacks),
1291
0
        /* .partial_utf8 = */             {},
1292
0
        /* .lazy = */                     lazy,
1293
0
        /* .awaiting_trigger = */         lazy,
1294
0
        /* .trigger_buffer = */           "",
1295
0
        /* .trigger_buffer_positions = */ {},
1296
0
        std::move(vec_trigger_tokens),
1297
0
        std::move(vec_trigger_patterns),
1298
0
    };
1299
0
}
1300
1301
0
// Destroys a grammar allocated with new. Passing nullptr is allowed and does nothing.
void llama_grammar_free_impl(struct llama_grammar * grammar) {
    // a delete-expression applied to a null pointer is a no-op, so no explicit guard is needed
    delete grammar;
}
1308
1309
0
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1310
0
    auto * result = new llama_grammar {
1311
0
        grammar.vocab,
1312
0
        grammar.rules,
1313
0
        grammar.stacks,
1314
0
        grammar.partial_utf8,
1315
0
        grammar.lazy,
1316
0
        grammar.awaiting_trigger,
1317
0
        grammar.trigger_buffer,
1318
0
        grammar.trigger_buffer_positions,
1319
0
        grammar.trigger_tokens,
1320
0
        grammar.trigger_patterns,
1321
0
    };
1322
1323
    // redirect elements in stacks to point to new rules
1324
0
    for (size_t is = 0; is < result->stacks.size(); is++) {
1325
0
        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
1326
0
            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1327
0
                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1328
0
                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1329
0
                        result->stacks[is][ie] =  &result->rules[ir0][ir1];
1330
0
                    }
1331
0
                }
1332
0
            }
1333
0
        }
1334
0
    }
1335
1336
0
    return result;
1337
0
}
1338
1339
0
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1340
0
    GGML_ASSERT(grammar.vocab != nullptr);
1341
1342
0
    if (grammar.awaiting_trigger) {
1343
0
        return;
1344
0
    }
1345
1346
0
    bool allow_eog = false;
1347
0
    for (const auto & stack : grammar.stacks) {
1348
0
        if (stack.empty()) {
1349
0
            allow_eog = true;
1350
0
            break;
1351
0
        }
1352
0
    }
1353
1354
0
    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
1355
0
    candidates_decoded.reserve(cur_p->size);
1356
1357
0
    llama_grammar_candidates candidates_grammar;
1358
0
    candidates_grammar.reserve(cur_p->size);
1359
1360
0
    for (size_t i = 0; i < cur_p->size; ++i) {
1361
0
        const llama_token id      = cur_p->data[i].id;
1362
0
        const std::string & piece = grammar.vocab->token_to_piece(id);
1363
1364
0
        if (grammar.vocab->is_eog(id)) {
1365
0
            if (!allow_eog) {
1366
0
                cur_p->data[i].logit = -INFINITY;
1367
0
            }
1368
0
        } else if (piece.empty() || piece[0] == 0) {
1369
0
            cur_p->data[i].logit = -INFINITY;
1370
0
        } else {
1371
0
            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
1372
0
            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
1373
0
        }
1374
0
    }
1375
1376
0
    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
1377
0
    for (const auto & reject : rejects) {
1378
0
        cur_p->data[reject.index].logit = -INFINITY;
1379
0
    }
1380
0
}
1381
1382
0
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
1383
0
    GGML_ASSERT(grammar.vocab != nullptr);
1384
1385
0
    const auto & piece = grammar.vocab->token_to_piece(token);
1386
1387
0
    if (grammar.awaiting_trigger) {
1388
0
        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
1389
0
            grammar.awaiting_trigger = false;
1390
0
            grammar.trigger_buffer.clear();
1391
0
            llama_grammar_accept_token(grammar, token, piece);
1392
0
            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1393
0
            return;
1394
0
        } else {
1395
0
            auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
1396
0
            grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
1397
0
            grammar.trigger_buffer += piece;
1398
1399
0
            for (const auto & trigger_pattern : grammar.trigger_patterns) {
1400
0
                auto start = trigger_pattern.find(grammar.trigger_buffer);
1401
0
                if (start != std::string::npos) {
1402
0
                    grammar.awaiting_trigger = false;
1403
1404
                    // replay tokens that overlap with [start, end)
1405
0
                    for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
1406
0
                        auto [tok_start, tok_end] = tok_pos;
1407
0
                        if (tok_end <= start) {
1408
0
                            continue;
1409
0
                        }
1410
1411
0
                        size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
1412
0
                        size_t piece_len = tok_end - piece_start;
1413
0
                        auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
1414
0
                        llama_grammar_accept_token(grammar, tok, tok_piece);
1415
0
                    }
1416
1417
0
                    auto constrained_str = grammar.trigger_buffer.substr(start);
1418
0
                    grammar.trigger_buffer.clear();
1419
0
                    grammar.trigger_buffer_positions.clear();
1420
0
                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
1421
0
                    return;
1422
0
                }
1423
0
            }
1424
0
            LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
1425
0
            return;
1426
0
        }
1427
0
    }
1428
1429
0
    if (grammar.vocab->is_eog(token)) {
1430
0
        for (const auto & stack : grammar.stacks) {
1431
0
            if (stack.empty()) {
1432
0
                return;
1433
0
            }
1434
0
        }
1435
0
        GGML_ABORT("fatal error");
1436
0
    }
1437
1438
0
    llama_grammar_accept_token(grammar, token, piece);
1439
0
}
1440
1441
0
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
1442
    // Note terminating 0 in decoded string
1443
0
    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
1444
0
    const auto & code_points = decoded.first;
1445
1446
0
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1447
0
        llama_grammar_accept(&grammar, *it);
1448
0
    }
1449
1450
0
    grammar.partial_utf8 = decoded.second;
1451
0
    if (grammar.stacks.empty()) {
1452
0
        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
1453
0
    }
1454
0
}
1455
1456
0
void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
1457
    // Note terminating 0 in decoded string
1458
0
    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
1459
0
    const auto & code_points = decoded.first;
1460
1461
0
    llama_grammar_stacks stacks_new;
1462
0
    stacks_new.reserve(grammar.stacks.size());
1463
1464
0
    for (const auto & stack : grammar.stacks) {
1465
0
        if (stack.empty()) {
1466
0
            continue;
1467
0
        }
1468
1469
0
        const llama_grammar_element * pos = stack.back();
1470
1471
0
        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1472
0
            if (llama_grammar_match_token(pos, token)) {
1473
0
                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
1474
0
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
1475
0
                    new_stack.push_back(pos + 1);
1476
0
                }
1477
0
                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
1478
0
            }
1479
0
        } else {
1480
0
            llama_grammar_stacks current_stacks = {stack};
1481
1482
0
            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1483
0
                llama_grammar_stacks next_stacks;
1484
1485
0
                for (const auto & cur_stack : current_stacks) {
1486
0
                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
1487
0
                }
1488
1489
0
                current_stacks = std::move(next_stacks);
1490
0
                if (current_stacks.empty()) {
1491
0
                    break;
1492
0
                }
1493
0
            }
1494
1495
0
            for (auto & surviving_stack : current_stacks) {
1496
0
                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
1497
0
                    stacks_new.emplace_back(surviving_stack);
1498
0
                }
1499
0
            }
1500
0
        }
1501
0
    }
1502
1503
0
    grammar.stacks = std::move(stacks_new);
1504
0
    grammar.partial_utf8 = decoded.second;
1505
1506
0
    if (grammar.stacks.empty()) {
1507
0
        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
1508
0
    }
1509
0
}
1510