/src/llama.cpp/common/reasoning-budget.cpp
Line | Count | Source |
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
11 | | |
12 | | struct token_matcher { |
13 | | std::vector<llama_token> tokens; |
14 | | size_t pos = 0; |
15 | | |
16 | 0 | bool advance(llama_token token) { |
17 | 0 | if (tokens.empty()) { |
18 | 0 | return false; |
19 | 0 | } |
20 | | |
21 | 0 | if (token == tokens[pos]) { |
22 | 0 | pos++; |
23 | 0 | if (pos >= tokens.size()) { |
24 | 0 | pos = 0; |
25 | 0 | return true; |
26 | 0 | } |
27 | 0 | } else { |
28 | 0 | pos = 0; |
29 | 0 | if (token == tokens[0]) { |
30 | 0 | pos = 1; |
31 | 0 | } |
32 | 0 | } |
33 | 0 | return false; |
34 | 0 | } |
35 | | |
36 | 0 | void reset() { pos = 0; } |
37 | | }; |
38 | | |
39 | | struct common_reasoning_budget_ctx { |
40 | | const llama_vocab * vocab; |
41 | | |
42 | | token_matcher start_matcher; |
43 | | token_matcher end_matcher; |
44 | | std::vector<llama_token> forced_tokens; |
45 | | |
46 | | int32_t budget; // maximum tokens in reasoning block |
47 | | int32_t remaining; // tokens remaining in budget |
48 | | |
49 | | common_reasoning_budget_state state; |
50 | | |
51 | | // for forcing |
52 | | size_t force_pos; // next position in forced_tokens to force |
53 | | }; |
54 | | |
// Sampler `name` hook: the fixed identifier this sampler reports.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "reasoning-budget";
    return sampler_name;
}
58 | | |
59 | 0 | static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { |
60 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
61 | |
|
62 | 0 | switch (ctx->state) { |
63 | 0 | case REASONING_BUDGET_IDLE: |
64 | 0 | { |
65 | 0 | if (ctx->start_matcher.advance(token)) { |
66 | 0 | ctx->state = REASONING_BUDGET_COUNTING; |
67 | 0 | ctx->remaining = ctx->budget; |
68 | 0 | LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); |
69 | |
|
70 | 0 | if (ctx->remaining <= 0) { |
71 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
72 | 0 | ctx->force_pos = 0; |
73 | 0 | LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); |
74 | 0 | } |
75 | 0 | } |
76 | 0 | break; |
77 | 0 | } |
78 | 0 | case REASONING_BUDGET_COUNTING: |
79 | 0 | case REASONING_BUDGET_WAITING_UTF8: |
80 | 0 | { |
81 | 0 | if (ctx->end_matcher.advance(token)) { |
82 | 0 | ctx->state = REASONING_BUDGET_DONE; |
83 | 0 | LOG_INF("reasoning-budget: deactivated (natural end)\n"); |
84 | 0 | break; |
85 | 0 | } |
86 | | |
87 | 0 | bool utf8_complete = true; |
88 | 0 | if (ctx->vocab != nullptr) { |
89 | 0 | const std::string piece = common_token_to_piece(ctx->vocab, token, false); |
90 | 0 | utf8_complete = common_utf8_is_complete(piece); |
91 | 0 | } |
92 | |
|
93 | 0 | if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { |
94 | 0 | if (utf8_complete) { |
95 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
96 | 0 | ctx->force_pos = 0; |
97 | 0 | ctx->end_matcher.reset(); |
98 | 0 | LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); |
99 | 0 | } |
100 | 0 | } else if (ctx->state == REASONING_BUDGET_COUNTING) { |
101 | 0 | ctx->remaining--; |
102 | 0 | if (ctx->remaining <= 0) { |
103 | 0 | if (utf8_complete) { |
104 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
105 | 0 | ctx->force_pos = 0; |
106 | 0 | ctx->end_matcher.reset(); |
107 | 0 | LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n"); |
108 | 0 | } else { |
109 | 0 | ctx->state = REASONING_BUDGET_WAITING_UTF8; |
110 | 0 | ctx->end_matcher.reset(); |
111 | 0 | LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n"); |
112 | 0 | } |
113 | 0 | } |
114 | 0 | } |
115 | 0 | break; |
116 | 0 | } |
117 | 0 | case REASONING_BUDGET_FORCING: |
118 | 0 | ctx->force_pos++; |
119 | 0 | if (ctx->force_pos >= ctx->forced_tokens.size()) { |
120 | 0 | ctx->state = REASONING_BUDGET_DONE; |
121 | 0 | LOG_INF("reasoning-budget: forced sequence complete, done\n"); |
122 | 0 | } |
123 | 0 | break; |
124 | 0 | case REASONING_BUDGET_DONE: |
125 | | // Re-arm on a new start tag: some models emit multiple <think> blocks |
126 | | // per response, and each should get a fresh budget window. |
127 | 0 | if (ctx->start_matcher.advance(token)) { |
128 | 0 | ctx->state = REASONING_BUDGET_COUNTING; |
129 | 0 | ctx->remaining = ctx->budget; |
130 | 0 | ctx->end_matcher.reset(); |
131 | 0 | LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget); |
132 | |
|
133 | 0 | if (ctx->remaining <= 0) { |
134 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
135 | 0 | ctx->force_pos = 0; |
136 | 0 | LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); |
137 | 0 | } |
138 | 0 | } |
139 | 0 | break; |
140 | 0 | } |
141 | 0 | } |
142 | | |
143 | 0 | static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
144 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
145 | |
|
146 | 0 | if (ctx->state != REASONING_BUDGET_FORCING) { |
147 | | // passthrough — don't modify logits |
148 | 0 | return; |
149 | 0 | } |
150 | | |
151 | 0 | if (ctx->force_pos >= ctx->forced_tokens.size()) { |
152 | 0 | return; |
153 | 0 | } |
154 | | |
155 | 0 | const llama_token forced = ctx->forced_tokens[ctx->force_pos]; |
156 | | |
157 | | // set all logits to -inf except the forced token |
158 | 0 | for (size_t i = 0; i < cur_p->size; i++) { |
159 | 0 | if (cur_p->data[i].id != forced) { |
160 | 0 | cur_p->data[i].logit = -INFINITY; |
161 | 0 | } |
162 | 0 | } |
163 | 0 | } |
164 | | |
165 | 0 | static void common_reasoning_budget_reset(struct llama_sampler * smpl) { |
166 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
167 | 0 | ctx->state = REASONING_BUDGET_IDLE; |
168 | 0 | ctx->remaining = ctx->budget; |
169 | 0 | ctx->start_matcher.reset(); |
170 | 0 | ctx->end_matcher.reset(); |
171 | 0 | ctx->force_pos = 0; |
172 | 0 | } |
173 | | |
// Forward declarations. The callback table below refers to clone(), which is
// defined after that table; init_state() is likewise declared ahead of its
// definition further down.
static struct llama_sampler * common_reasoning_budget_init_state(
        const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
        const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
        int32_t budget, common_reasoning_budget_state initial_state);

static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl);
180 | | |
181 | 0 | static void common_reasoning_budget_free(struct llama_sampler * smpl) { |
182 | 0 | delete (common_reasoning_budget_ctx *) smpl->ctx; |
183 | 0 | } |
184 | | |
// Callback table shared by every reasoning-budget sampler; clone() and
// init_state() pass it to llama_sampler_init(). The initializer is positional,
// so entries must stay in the field order of llama_sampler_i. The backend_*
// hooks are not implemented (null).
static struct llama_sampler_i common_reasoning_budget_i = {
    /* .name              = */ common_reasoning_budget_name,
    /* .accept            = */ common_reasoning_budget_accept,
    /* .apply             = */ common_reasoning_budget_apply,
    /* .reset             = */ common_reasoning_budget_reset,
    /* .clone             = */ common_reasoning_budget_clone,
    /* .free              = */ common_reasoning_budget_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
197 | | |
198 | 0 | static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { |
199 | 0 | const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; |
200 | |
|
201 | 0 | return llama_sampler_init( |
202 | 0 | /* .iface = */ &common_reasoning_budget_i, |
203 | 0 | /* .ctx = */ new common_reasoning_budget_ctx(*ctx) |
204 | 0 | ); |
205 | 0 | } |
206 | | |
207 | | static struct llama_sampler * common_reasoning_budget_init_state( |
208 | | const struct llama_vocab * vocab, |
209 | | const std::vector<llama_token> & start_tokens, |
210 | | const std::vector<llama_token> & end_tokens, |
211 | | const std::vector<llama_token> & forced_tokens, |
212 | | int32_t budget, |
213 | 0 | common_reasoning_budget_state initial_state) { |
214 | | // promote COUNTING with budget <= 0 to FORCING |
215 | 0 | if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { |
216 | 0 | initial_state = REASONING_BUDGET_FORCING; |
217 | 0 | } |
218 | |
|
219 | 0 | return llama_sampler_init( |
220 | 0 | /* .iface = */ &common_reasoning_budget_i, |
221 | 0 | /* .ctx = */ new common_reasoning_budget_ctx { |
222 | 0 | /* .vocab = */ vocab, |
223 | 0 | /* .start_matcher = */ { start_tokens, 0 }, |
224 | 0 | /* .end_matcher = */ { end_tokens, 0 }, |
225 | 0 | /* .forced_tokens = */ forced_tokens, |
226 | 0 | /* .budget = */ budget, |
227 | 0 | /* .remaining = */ budget, |
228 | 0 | /* .state = */ initial_state, |
229 | 0 | /* .force_pos = */ 0, |
230 | 0 | } |
231 | 0 | ); |
232 | 0 | } |
233 | | |
234 | | struct llama_sampler * common_reasoning_budget_init( |
235 | | const struct llama_vocab * vocab, |
236 | | const std::vector<llama_token> & start_tokens, |
237 | | const std::vector<llama_token> & end_tokens, |
238 | | const std::vector<llama_token> & forced_tokens, |
239 | | int32_t budget, |
240 | 0 | common_reasoning_budget_state initial_state) { |
241 | 0 | return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); |
242 | 0 | } |
243 | | |
244 | 0 | common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) { |
245 | 0 | if (!smpl) { |
246 | 0 | return REASONING_BUDGET_IDLE; |
247 | 0 | } |
248 | 0 | return ((const common_reasoning_budget_ctx *)smpl->ctx)->state; |
249 | 0 | } |
250 | | |
251 | 0 | bool common_reasoning_budget_force(struct llama_sampler * smpl) { |
252 | 0 | if (!smpl) { |
253 | 0 | return false; |
254 | 0 | } |
255 | | |
256 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
257 | | |
258 | | // only a sampler that is actively counting down the budget may be forced; |
259 | | // any other state (idle, already forcing/waiting, or done) is left untouched |
260 | 0 | if (ctx->state != REASONING_BUDGET_COUNTING) { |
261 | 0 | return false; |
262 | 0 | } |
263 | | |
264 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
265 | 0 | ctx->force_pos = 0; |
266 | 0 | ctx->end_matcher.reset(); |
267 | 0 | LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n"); |
268 | |
|
269 | 0 | return true; |
270 | 0 | } |