/src/llama.cpp/common/reasoning-budget.cpp
Line | Count | Source |
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
11 | | |
12 | | struct token_matcher { |
13 | | std::vector<llama_token> tokens; |
14 | | size_t pos = 0; |
15 | | |
16 | 0 | bool advance(llama_token token) { |
17 | 0 | if (tokens.empty()) { |
18 | 0 | return false; |
19 | 0 | } |
20 | | |
21 | 0 | if (token == tokens[pos]) { |
22 | 0 | pos++; |
23 | 0 | if (pos >= tokens.size()) { |
24 | 0 | pos = 0; |
25 | 0 | return true; |
26 | 0 | } |
27 | 0 | } else { |
28 | 0 | pos = 0; |
29 | 0 | if (token == tokens[0]) { |
30 | 0 | pos = 1; |
31 | 0 | } |
32 | 0 | } |
33 | 0 | return false; |
34 | 0 | } |
35 | | |
36 | 0 | void reset() { pos = 0; } |
37 | | }; |
38 | | |
39 | | struct common_reasoning_budget_ctx { |
40 | | const llama_vocab * vocab; |
41 | | |
42 | | token_matcher start_matcher; |
43 | | token_matcher end_matcher; |
44 | | std::vector<llama_token> forced_tokens; |
45 | | |
46 | | int32_t budget; // maximum tokens in reasoning block |
47 | | int32_t remaining; // tokens remaining in budget |
48 | | |
49 | | common_reasoning_budget_state state; |
50 | | |
51 | | // for forcing |
52 | | size_t force_pos; // next position in forced_tokens to force |
53 | | }; |
54 | | |
// name(): fixed identifier of this sampler; the sampler argument is not used.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    const char * const sampler_name = "reasoning-budget";
    return sampler_name;
}
58 | | |
59 | 0 | static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { |
60 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
61 | |
|
62 | 0 | switch (ctx->state) { |
63 | 0 | case REASONING_BUDGET_IDLE: |
64 | 0 | { |
65 | 0 | if (ctx->start_matcher.advance(token)) { |
66 | 0 | ctx->state = REASONING_BUDGET_COUNTING; |
67 | 0 | ctx->remaining = ctx->budget; |
68 | 0 | LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); |
69 | |
|
70 | 0 | if (ctx->remaining <= 0) { |
71 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
72 | 0 | ctx->force_pos = 0; |
73 | 0 | LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); |
74 | 0 | } |
75 | 0 | } |
76 | 0 | break; |
77 | 0 | } |
78 | 0 | case REASONING_BUDGET_COUNTING: |
79 | 0 | case REASONING_BUDGET_WAITING_UTF8: |
80 | 0 | { |
81 | 0 | if (ctx->end_matcher.advance(token)) { |
82 | 0 | ctx->state = REASONING_BUDGET_DONE; |
83 | 0 | LOG_INF("reasoning-budget: deactivated (natural end)\n"); |
84 | 0 | break; |
85 | 0 | } |
86 | | |
87 | 0 | bool utf8_complete = true; |
88 | 0 | if (ctx->vocab != nullptr) { |
89 | 0 | const std::string piece = common_token_to_piece(ctx->vocab, token, false); |
90 | 0 | utf8_complete = common_utf8_is_complete(piece); |
91 | 0 | } |
92 | |
|
93 | 0 | if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { |
94 | 0 | if (utf8_complete) { |
95 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
96 | 0 | ctx->force_pos = 0; |
97 | 0 | ctx->end_matcher.reset(); |
98 | 0 | LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); |
99 | 0 | } |
100 | 0 | } else if (ctx->state == REASONING_BUDGET_COUNTING) { |
101 | 0 | ctx->remaining--; |
102 | 0 | if (ctx->remaining <= 0) { |
103 | 0 | if (utf8_complete) { |
104 | 0 | ctx->state = REASONING_BUDGET_FORCING; |
105 | 0 | ctx->force_pos = 0; |
106 | 0 | ctx->end_matcher.reset(); |
107 | 0 | LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n"); |
108 | 0 | } else { |
109 | 0 | ctx->state = REASONING_BUDGET_WAITING_UTF8; |
110 | 0 | ctx->end_matcher.reset(); |
111 | 0 | LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n"); |
112 | 0 | } |
113 | 0 | } |
114 | 0 | } |
115 | 0 | break; |
116 | 0 | } |
117 | 0 | case REASONING_BUDGET_FORCING: |
118 | | // force_pos is advanced in apply(), not here. |
119 | | // This ensures the first forced token isn't skipped when the sampler |
120 | | // is initialized directly in FORCING state (e.g. COUNTING + budget=0) |
121 | 0 | break; |
122 | 0 | case REASONING_BUDGET_DONE: |
123 | 0 | break; |
124 | 0 | } |
125 | 0 | } |
126 | | |
127 | 0 | static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
128 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
129 | |
|
130 | 0 | if (ctx->state != REASONING_BUDGET_FORCING) { |
131 | | // passthrough — don't modify logits |
132 | 0 | return; |
133 | 0 | } |
134 | | |
135 | 0 | if (ctx->force_pos >= ctx->forced_tokens.size()) { |
136 | 0 | return; |
137 | 0 | } |
138 | | |
139 | 0 | const llama_token forced = ctx->forced_tokens[ctx->force_pos]; |
140 | | |
141 | | // set all logits to -inf except the forced token |
142 | 0 | for (size_t i = 0; i < cur_p->size; i++) { |
143 | 0 | if (cur_p->data[i].id != forced) { |
144 | 0 | cur_p->data[i].logit = -INFINITY; |
145 | 0 | } |
146 | 0 | } |
147 | | |
148 | | // advance to next forced token (done here rather than in accept so that |
149 | | // the first forced token isn't skipped when starting in FORCING state) |
150 | 0 | ctx->force_pos++; |
151 | 0 | if (ctx->force_pos >= ctx->forced_tokens.size()) { |
152 | 0 | ctx->state = REASONING_BUDGET_DONE; |
153 | 0 | LOG_INF("reasoning-budget: forced sequence complete, done\n"); |
154 | 0 | } |
155 | 0 | } |
156 | | |
157 | 0 | static void common_reasoning_budget_reset(struct llama_sampler * smpl) { |
158 | 0 | auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; |
159 | 0 | ctx->state = REASONING_BUDGET_IDLE; |
160 | 0 | ctx->remaining = ctx->budget; |
161 | 0 | ctx->start_matcher.reset(); |
162 | 0 | ctx->end_matcher.reset(); |
163 | 0 | ctx->force_pos = 0; |
164 | 0 | } |
165 | | |
166 | | // forward declaration for use in clone |
167 | | static struct llama_sampler * common_reasoning_budget_init_state( |
168 | | const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens, |
169 | | const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens, |
170 | | int32_t budget, common_reasoning_budget_state initial_state); |
171 | | |
172 | 0 | static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { |
173 | 0 | const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; |
174 | 0 | return common_reasoning_budget_init_state( |
175 | 0 | ctx->vocab, |
176 | 0 | ctx->start_matcher.tokens, |
177 | 0 | ctx->end_matcher.tokens, |
178 | 0 | ctx->forced_tokens, |
179 | 0 | ctx->budget, |
180 | 0 | ctx->state); |
181 | 0 | } |
182 | | |
183 | 0 | static void common_reasoning_budget_free(struct llama_sampler * smpl) { |
184 | 0 | delete (common_reasoning_budget_ctx *) smpl->ctx; |
185 | 0 | } |
186 | | |
// Sampler interface table wiring the reasoning-budget callbacks into
// llama_sampler. The initializer is positional, so the entries must stay in
// the member order of llama_sampler_i. No backend_* hooks are provided by this
// file, so they are left null.
static struct llama_sampler_i common_reasoning_budget_i = {
    /* .name              = */ common_reasoning_budget_name,
    /* .accept            = */ common_reasoning_budget_accept,
    /* .apply             = */ common_reasoning_budget_apply,
    /* .reset             = */ common_reasoning_budget_reset,
    /* .clone             = */ common_reasoning_budget_clone,
    /* .free              = */ common_reasoning_budget_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
199 | | |
200 | | static struct llama_sampler * common_reasoning_budget_init_state( |
201 | | const struct llama_vocab * vocab, |
202 | | const std::vector<llama_token> & start_tokens, |
203 | | const std::vector<llama_token> & end_tokens, |
204 | | const std::vector<llama_token> & forced_tokens, |
205 | | int32_t budget, |
206 | 0 | common_reasoning_budget_state initial_state) { |
207 | | // promote COUNTING with budget <= 0 to FORCING |
208 | 0 | if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { |
209 | 0 | initial_state = REASONING_BUDGET_FORCING; |
210 | 0 | } |
211 | |
|
212 | 0 | return llama_sampler_init( |
213 | 0 | /* .iface = */ &common_reasoning_budget_i, |
214 | 0 | /* .ctx = */ new common_reasoning_budget_ctx { |
215 | 0 | /* .vocab = */ vocab, |
216 | 0 | /* .start_matcher = */ { start_tokens, 0 }, |
217 | 0 | /* .end_matcher = */ { end_tokens, 0 }, |
218 | 0 | /* .forced_tokens = */ forced_tokens, |
219 | 0 | /* .budget = */ budget, |
220 | 0 | /* .remaining = */ budget, |
221 | 0 | /* .state = */ initial_state, |
222 | 0 | /* .force_pos = */ 0, |
223 | 0 | } |
224 | 0 | ); |
225 | 0 | } |
226 | | |
227 | | struct llama_sampler * common_reasoning_budget_init( |
228 | | const struct llama_vocab * vocab, |
229 | | const std::vector<llama_token> & start_tokens, |
230 | | const std::vector<llama_token> & end_tokens, |
231 | | const std::vector<llama_token> & forced_tokens, |
232 | | int32_t budget, |
233 | 0 | const std::vector<llama_token> & prefill_tokens) { |
234 | | // Determine initial state from prefill: COUNTING if the prefill begins with |
235 | | // the start sequence but does not also contain the end sequence after it. |
236 | 0 | common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE; |
237 | 0 | if (!prefill_tokens.empty() && !start_tokens.empty() && |
238 | 0 | prefill_tokens.size() >= start_tokens.size() && |
239 | 0 | std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) { |
240 | 0 | initial_state = REASONING_BUDGET_COUNTING; |
241 | | // If the end sequence also follows the start in the prefill, reasoning |
242 | | // was opened and immediately closed — stay IDLE. |
243 | 0 | if (!end_tokens.empty() && |
244 | 0 | prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) { |
245 | 0 | auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size(); |
246 | 0 | if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() && |
247 | 0 | std::equal(end_tokens.begin(), end_tokens.end(), end_start)) { |
248 | 0 | initial_state = REASONING_BUDGET_IDLE; |
249 | 0 | } |
250 | 0 | } |
251 | 0 | } |
252 | 0 | return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); |
253 | 0 | } |
254 | | |
255 | | struct llama_sampler * common_reasoning_budget_init( |
256 | | const struct llama_vocab * vocab, |
257 | | const std::vector<llama_token> & start_tokens, |
258 | | const std::vector<llama_token> & end_tokens, |
259 | | const std::vector<llama_token> & forced_tokens, |
260 | | int32_t budget, |
261 | 0 | common_reasoning_budget_state initial_state) { |
262 | 0 | return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); |
263 | 0 | } |