/src/llama.cpp/common/reasoning-budget.cpp
Line | Count | Source |
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
11 | | |
12 | | struct token_matcher { |
13 | | std::vector<llama_token> tokens; |
14 | | size_t pos = 0; |
15 | | |
16 | 0 | bool advance(llama_token token) { |
17 | 0 | if (tokens.empty()) { |
18 | 0 | return false; |
19 | 0 | } |
20 | | |
21 | 0 | if (token == tokens[pos]) { |
22 | 0 | pos++; |
23 | 0 | if (pos >= tokens.size()) { |
24 | 0 | pos = 0; |
25 | 0 | return true; |
26 | 0 | } |
27 | 0 | } else { |
28 | 0 | pos = 0; |
29 | 0 | if (token == tokens[0]) { |
30 | 0 | pos = 1; |
31 | 0 | } |
32 | 0 | } |
33 | 0 | return false; |
34 | 0 | } |
35 | | |
36 | 0 | void reset() { pos = 0; } |
37 | | }; |
38 | | |
39 | | struct common_reasoning_budget_ctx { |
40 | | const llama_vocab * vocab; |
41 | | |
42 | | token_matcher start_matcher; |
43 | | token_matcher end_matcher; |
44 | | std::vector<llama_token> forced_tokens; |
45 | | |
46 | | int32_t budget; // maximum tokens in reasoning block |
47 | | int32_t remaining; // tokens remaining in budget |
48 | | |
49 | | common_reasoning_budget_state state; |
50 | | |
51 | | // for forcing |
52 | | size_t force_pos; // next position in forced_tokens to force |
53 | | }; |
54 | | |
// Display name reported for this sampler; the instance itself is not consulted.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static constexpr const char * k_sampler_name = "reasoning-budget";
    return k_sampler_name;
}
58 | | |
59 | 0 | static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { |
60 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
61 | |
|
62 | 0 | switch (ctx->state) { |
63 | 0 | case REASONING_BUDGET_IDLE: |
64 | 0 | { |
65 | 0 | if (ctx->start_matcher.advance(token)) { |
66 | 0 | ctx->state = REASONING_BUDGET_COUNTING; |
67 | 0 | ctx->remaining = ctx->budget; |
68 | 0 | LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); |
69 | |
|
70 | 0 | if (ctx->remaining <= 0) { |
71 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
72 | 0 | ctx->force_pos = 0; |
73 | 0 | LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); |
74 | 0 | } |
75 | 0 | } |
76 | 0 | break; |
77 | 0 | } |
78 | 0 | case REASONING_BUDGET_COUNTING: |
79 | 0 | case REASONING_BUDGET_WAITING_UTF8: |
80 | 0 | { |
81 | 0 | if (ctx->end_matcher.advance(token)) { |
82 | 0 | ctx->state = REASONING_BUDGET_DONE; |
83 | 0 | LOG_INF("reasoning-budget: deactivated (natural end)\n"); |
84 | 0 | break; |
85 | 0 | } |
86 | | |
87 | 0 | bool utf8_complete = true; |
88 | 0 | if (ctx->vocab != nullptr) { |
89 | 0 | const std::string piece = common_token_to_piece(ctx->vocab, token, false); |
90 | 0 | utf8_complete = common_utf8_is_complete(piece); |
91 | 0 | } |
92 | |
|
93 | 0 | if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { |
94 | 0 | if (utf8_complete) { |
95 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
96 | 0 | ctx->force_pos = 0; |
97 | 0 | ctx->end_matcher.reset(); |
98 | 0 | LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); |
99 | 0 | } |
100 | 0 | } else if (ctx->state == REASONING_BUDGET_COUNTING) { |
101 | 0 | ctx->remaining--; |
102 | 0 | if (ctx->remaining <= 0) { |
103 | 0 | if (utf8_complete) { |
104 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
105 | 0 | ctx->force_pos = 0; |
106 | 0 | ctx->end_matcher.reset(); |
107 | 0 | LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n"); |
108 | 0 | } else { |
109 | 0 | ctx->state = REASONING_BUDGET_WAITING_UTF8; |
110 | 0 | ctx->end_matcher.reset(); |
111 | 0 | LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n"); |
112 | 0 | } |
113 | 0 | } |
114 | 0 | } |
115 | 0 | break; |
116 | 0 | } |
117 | 0 | case REASONING_BUDGET_FORCING: |
118 | 0 | ctx->force_pos++; |
119 | 0 | if (ctx->force_pos >= ctx->forced_tokens.size()) { |
120 | 0 | ctx->state = REASONING_BUDGET_DONE; |
121 | 0 | LOG_INF("reasoning-budget: forced sequence complete, done\n"); |
122 | 0 | } |
123 | 0 | break; |
124 | 0 | case REASONING_BUDGET_DONE: |
125 | 0 | break; |
126 | 0 | } |
127 | 0 | } |
128 | | |
129 | 0 | static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
130 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
131 | |
|
132 | 0 | if (ctx->state != REASONING_BUDGET_FORCING) { |
133 | | // passthrough — don't modify logits |
134 | 0 | return; |
135 | 0 | } |
136 | | |
137 | 0 | if (ctx->force_pos >= ctx->forced_tokens.size()) { |
138 | 0 | return; |
139 | 0 | } |
140 | | |
141 | 0 | const llama_token forced = ctx->forced_tokens[ctx->force_pos]; |
142 | | |
143 | | // set all logits to -inf except the forced token |
144 | 0 | for (size_t i = 0; i < cur_p->size; i++) { |
145 | 0 | if (cur_p->data[i].id != forced) { |
146 | 0 | cur_p->data[i].logit = -INFINITY; |
147 | 0 | } |
148 | 0 | } |
149 | 0 | } |
150 | | |
151 | 0 | static void common_reasoning_budget_reset(struct llama_sampler * smpl) { |
152 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
153 | 0 | ctx->state = REASONING_BUDGET_IDLE; |
154 | 0 | ctx->remaining = ctx->budget; |
155 | 0 | ctx->start_matcher.reset(); |
156 | 0 | ctx->end_matcher.reset(); |
157 | 0 | ctx->force_pos = 0; |
158 | 0 | } |
159 | | |
160 | | // forward declaration for use in clone |
161 | | static struct llama_sampler * common_reasoning_budget_init_state( |
162 | | const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens, |
163 | | const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens, |
164 | | int32_t budget, common_reasoning_budget_state initial_state); |
165 | | |
166 | 0 | static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { |
167 | 0 | const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; |
168 | 0 | return common_reasoning_budget_init_state( |
169 | 0 | ctx->vocab, |
170 | 0 | ctx->start_matcher.tokens, |
171 | 0 | ctx->end_matcher.tokens, |
172 | 0 | ctx->forced_tokens, |
173 | 0 | ctx->budget, |
174 | 0 | ctx->state); |
175 | 0 | } |
176 | | |
177 | 0 | static void common_reasoning_budget_free(struct llama_sampler * smpl) { |
178 | 0 | delete (common_reasoning_budget_ctx *) smpl->ctx; |
179 | 0 | } |
180 | | |
181 | | static struct llama_sampler_i common_reasoning_budget_i = { |
182 | | /* .name = */ common_reasoning_budget_name, |
183 | | /* .accept = */ common_reasoning_budget_accept, |
184 | | /* .apply = */ common_reasoning_budget_apply, |
185 | | /* .reset = */ common_reasoning_budget_reset, |
186 | | /* .clone = */ common_reasoning_budget_clone, |
187 | | /* .free = */ common_reasoning_budget_free, |
188 | | /* .backend_init = */ nullptr, |
189 | | /* .backend_accept = */ nullptr, |
190 | | /* .backend_apply = */ nullptr, |
191 | | /* .backend_set_input = */ nullptr, |
192 | | }; |
193 | | |
194 | | static struct llama_sampler * common_reasoning_budget_init_state( |
195 | | const struct llama_vocab * vocab, |
196 | | const std::vector<llama_token> & start_tokens, |
197 | | const std::vector<llama_token> & end_tokens, |
198 | | const std::vector<llama_token> & forced_tokens, |
199 | | int32_t budget, |
200 | 0 | common_reasoning_budget_state initial_state) { |
201 | | // promote COUNTING with budget <= 0 to FORCING |
202 | 0 | if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { |
203 | 0 | initial_state = REASONING_BUDGET_FORCING; |
204 | 0 | } |
205 | |
|
206 | 0 | return llama_sampler_init( |
207 | 0 | /* .iface = */ &common_reasoning_budget_i, |
208 | 0 | /* .ctx = */ new common_reasoning_budget_ctx { |
209 | 0 | /* .vocab = */ vocab, |
210 | 0 | /* .start_matcher = */ { start_tokens, 0 }, |
211 | 0 | /* .end_matcher = */ { end_tokens, 0 }, |
212 | 0 | /* .forced_tokens = */ forced_tokens, |
213 | 0 | /* .budget = */ budget, |
214 | 0 | /* .remaining = */ budget, |
215 | 0 | /* .state = */ initial_state, |
216 | 0 | /* .force_pos = */ 0, |
217 | 0 | } |
218 | 0 | ); |
219 | 0 | } |
220 | | |
221 | | struct llama_sampler * common_reasoning_budget_init( |
222 | | const struct llama_vocab * vocab, |
223 | | const std::vector<llama_token> & start_tokens, |
224 | | const std::vector<llama_token> & end_tokens, |
225 | | const std::vector<llama_token> & forced_tokens, |
226 | | int32_t budget, |
227 | 0 | const std::vector<llama_token> & prefill_tokens) { |
228 | | // Determine initial state from prefill: COUNTING if the prefill begins with |
229 | | // the start sequence but does not also contain the end sequence after it. |
230 | 0 | common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE; |
231 | 0 | if (!prefill_tokens.empty() && !start_tokens.empty() && |
232 | 0 | prefill_tokens.size() >= start_tokens.size() && |
233 | 0 | std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) { |
234 | 0 | initial_state = REASONING_BUDGET_COUNTING; |
235 | | // If the end sequence also follows the start in the prefill, reasoning |
236 | | // was opened and immediately closed — stay IDLE. |
237 | 0 | if (!end_tokens.empty() && |
238 | 0 | prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) { |
239 | 0 | auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size(); |
240 | 0 | if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() && |
241 | 0 | std::equal(end_tokens.begin(), end_tokens.end(), end_start)) { |
242 | 0 | initial_state = REASONING_BUDGET_IDLE; |
243 | 0 | } |
244 | 0 | } |
245 | 0 | } |
246 | 0 | return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); |
247 | 0 | } |
248 | | |
249 | | struct llama_sampler * common_reasoning_budget_init( |
250 | | const struct llama_vocab * vocab, |
251 | | const std::vector<llama_token> & start_tokens, |
252 | | const std::vector<llama_token> & end_tokens, |
253 | | const std::vector<llama_token> & forced_tokens, |
254 | | int32_t budget, |
255 | 0 | common_reasoning_budget_state initial_state) { |
256 | 0 | return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); |
257 | 0 | } |
258 | | |
259 | 0 | common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) { |
260 | 0 | if (!smpl) { |
261 | 0 | return REASONING_BUDGET_IDLE; |
262 | 0 | } |
263 | 0 | return ((const common_reasoning_budget_ctx *)smpl->ctx)->state; |
264 | 0 | } |