/src/llama.cpp/src/llama-grammar.h
#pragma once

#include "llama.h"

#include <map>
#include <regex>
#include <string>
#include <vector>

struct llama_vocab;

// grammar element type
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END            = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule
    LLAMA_GRETYPE_RULE_REF       = 2,

    // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR           = 3,

    // inverse char(s) ([^a], [^a-b], [^abc])
    LLAMA_GRETYPE_CHAR_NOT       = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 6,

    // any character (.)
    LLAMA_GRETYPE_CHAR_ANY       = 7,

    // terminal element: token (<[token-id]>)
    LLAMA_GRETYPE_TOKEN          = 8,

    // inverse token (!<[token-id]>)
    LLAMA_GRETYPE_TOKEN_NOT      = 9,
};

typedef struct llama_grammar_element {
    enum llama_gretype type;
    uint32_t           value; // Unicode code point, rule ID, or token ID
} llama_grammar_element;
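
// Illustrative sketch (not part of the API): assuming the GBNF syntax accepted by
// llama_grammar_parser below, the rule
//
//   root ::= "ab" | [a-z]
//
// is stored as a single flat llama_grammar_rule, with alternates separated by
// LLAMA_GRETYPE_ALT and the rule terminated by LLAMA_GRETYPE_END:
//
//   {LLAMA_GRETYPE_CHAR, 'a'}, {LLAMA_GRETYPE_CHAR, 'b'},
//   {LLAMA_GRETYPE_ALT,   0 },
//   {LLAMA_GRETYPE_CHAR, 'a'}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
//   {LLAMA_GRETYPE_END,   0 }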

struct llama_partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};
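
// Illustrative sketch: after consuming only the first byte 0xE2 of the 3-byte UTF-8
// sequence 0xE2 0x82 0xAC (U+20AC), the partial state would be { value = 0x2, n_remain = 2 };
// each remaining byte shifts in 6 more bits until n_remain reaches 0 and the
// code point 0x20AC is complete.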

struct llama_grammar_candidate {
    size_t             index;
    const uint32_t   * code_points;
    llama_partial_utf8 partial_utf8;
    llama_token        id;
};

using llama_grammar_rule  = std::vector<      llama_grammar_element>;
using llama_grammar_stack = std::vector<const llama_grammar_element *>;

using llama_grammar_rules      = std::vector<llama_grammar_rule>;
using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;

// TODO: remove, needed for tests atm
const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);

// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);

std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules      & rules,
        const llama_grammar_stack      & stack,
        const llama_grammar_candidates & candidates);
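
// Illustrative sketch of the test-only helpers above (assumes `grammar` was created
// with one of the init functions declared at the bottom of this header):
//
//   llama_grammar_stacks & stacks = llama_grammar_get_stacks(grammar);
//   llama_grammar_accept(grammar, 'h'); // advance every stack by the code point 'h'
//   if (stacks.empty()) {
//       // no stack accepted 'h' -> the text so far cannot match the grammar
//   }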

struct llama_grammar_parser {
    const llama_vocab * vocab;
    std::map<std::string, uint32_t> symbol_ids;

    llama_grammar_rules rules;

    llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}

    llama_grammar_stack c_rules() const;

    uint32_t get_symbol_id(const char * src, size_t len);
    uint32_t generate_symbol_id(const std::string & base_name);

    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);

    const char * parse_alternates(
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested);

    const char * parse_sequence(
            const char         * src,
            const std::string  & rule_name,
            llama_grammar_rule & rule,
            bool                 is_nested);

    const char * parse_rule(const char * src);

    bool parse(const char * src);
    void print(FILE * file);
};
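
// Illustrative sketch (assumes a well-formed GBNF string; minimal error handling):
//
//   llama_grammar_parser parser;
//   if (parser.parse("root ::= \"yes\" | \"no\"")) {
//       parser.print(stderr);                         // dump the parsed rules
//       llama_grammar_stack start = parser.c_rules(); // per-rule start pointers,
//                                                     // e.g. for llama_grammar_init_impl below
//   }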

struct llama_grammar_trigger_pattern {
    std::string pattern;
    std::regex  regex;

    size_t find(const std::string & input) const;
};

struct llama_grammar {
    // maintain a list of llama_tokens and their positions in the trigger_buffer
    using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;

    // note: allow null vocab for testing (not great)
    const llama_vocab * vocab;

    const llama_grammar_rules  rules;  // TODO: shared ptr
          llama_grammar_stacks stacks;

    // buffer for partially generated UTF-8 sequence from accepted tokens
    llama_partial_utf8 partial_utf8;

    // lazy grammars wait for trigger words or tokens before constraining the sampling.
    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
    // (useful e.g. for tool_choice=required)
    bool                     lazy             = false;
    bool                     awaiting_trigger = false; // initialized to true for lazy grammars only
    std::string              trigger_buffer;           // output buffered by lazy grammar; cleared once a trigger is found
    std::vector<token_pos>   trigger_buffer_positions; // tokens buffered by lazy grammar, used to replay when a trigger is found
    std::vector<llama_token> trigger_tokens;           // tokens that trigger a lazy grammar, or tokens to force printing of (even if special)

    // regular expressions that trigger a lazy grammar; must fully match the entire generated
    // string, and the grammar is given the string from the first match group onwards
    std::vector<llama_grammar_trigger_pattern> trigger_patterns;
};
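
// Illustrative sketch of a lazy trigger (hypothetical pattern, not taken from this file):
// with trigger_patterns = { "[\\s\\S]*?(<tool_call>[\\s\\S]*)" }, sampling is unconstrained
// until the generated text contains "<tool_call>"; from then on, everything starting at
// match group 1 is replayed into the grammar and awaiting_trigger is cleared.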

//
// internal API
//

// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
        const llama_grammar_element ** rules,
        size_t n_rules,
        size_t start_rule_index);

struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens);
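
// Illustrative sketch (eager grammar, no triggers; returns nullptr if the GBNF string
// fails to parse):
//
//   llama_grammar * g = llama_grammar_init_impl(vocab, "root ::= \"yes\" | \"no\"", "root",
//                                               /*lazy=*/false, nullptr, 0, nullptr, 0);
//   ...
//   llama_grammar_free_impl(g);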

void llama_grammar_free_impl(struct llama_grammar * grammar);

struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);

// TODO: move the API below as member functions of llama_grammar
void llama_grammar_apply_impl(
        const struct llama_grammar & grammar,
            llama_token_data_array * cur_p);

void llama_grammar_accept_impl(
              struct llama_grammar & grammar,
                       llama_token   token);

void llama_grammar_accept_str(
              struct llama_grammar & grammar,
                 const std::string & piece);

void llama_grammar_accept_token(
              struct llama_grammar & grammar,
                       llama_token   token,
                 const std::string & piece);
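
// Illustrative sketch of the per-token flow (cur_p holds this step's candidate tokens;
// the actual selection is done by the sampler and is only hinted at here):
//
//   llama_grammar_apply_impl(*grammar, &cur_p); // mask out tokens that would violate the grammar
//   llama_token tok = pick(cur_p);              // hypothetical selection step
//   llama_grammar_accept_impl(*grammar, tok);   // advance the grammar with the accepted token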