Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/common/reasoning-budget.cpp
Line
Count
Source
1
#include "reasoning-budget.h"
2
#include "common.h"
3
#include "unicode.h"
4
5
#include "log.h"
6
7
#include <cmath>
8
#include <cstdint>
9
#include <string>
10
#include <vector>
11
12
struct token_matcher {
13
    std::vector<llama_token> tokens;
14
    size_t pos = 0;
15
16
0
    bool advance(llama_token token) {
17
0
        if (tokens.empty()) {
18
0
            return false;
19
0
        }
20
21
0
        if (token == tokens[pos]) {
22
0
            pos++;
23
0
            if (pos >= tokens.size()) {
24
0
                pos = 0;
25
0
                return true;
26
0
            }
27
0
        } else {
28
0
            pos = 0;
29
0
            if (token == tokens[0]) {
30
0
                pos = 1;
31
0
            }
32
0
        }
33
0
        return false;
34
0
    }
35
36
0
    void reset() { pos = 0; }
37
};
38
39
struct common_reasoning_budget_ctx {
40
    const llama_vocab * vocab;
41
42
    token_matcher start_matcher;
43
    token_matcher end_matcher;
44
    std::vector<llama_token> forced_tokens;
45
46
    int32_t budget;           // maximum tokens in reasoning block
47
    int32_t remaining;        // tokens remaining in budget
48
49
    common_reasoning_budget_state state;
50
51
    // for forcing
52
    size_t force_pos;         // next position in forced_tokens to force
53
};
54
55
0
// Sampler `name` hook: the fixed identifier of this sampler type.
// The instance is not consulted.
static const char * common_reasoning_budget_name(const struct llama_sampler * smpl) {
    (void) smpl;
    return "reasoning-budget";
}
58
59
0
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
60
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
61
62
0
    switch (ctx->state) {
63
0
        case REASONING_BUDGET_IDLE:
64
0
        {
65
0
            if (ctx->start_matcher.advance(token)) {
66
0
                ctx->state = REASONING_BUDGET_COUNTING;
67
0
                ctx->remaining = ctx->budget;
68
0
                LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
69
70
0
                if (ctx->remaining <= 0) {
71
0
                    ctx->state = REASONING_BUDGET_FORCING;
72
0
                    ctx->force_pos = 0;
73
0
                    LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
74
0
                }
75
0
            }
76
0
            break;
77
0
        }
78
0
        case REASONING_BUDGET_COUNTING:
79
0
        case REASONING_BUDGET_WAITING_UTF8:
80
0
        {
81
0
            if (ctx->end_matcher.advance(token)) {
82
0
                ctx->state = REASONING_BUDGET_DONE;
83
0
                LOG_INF("reasoning-budget: deactivated (natural end)\n");
84
0
                break;
85
0
            }
86
87
0
            bool utf8_complete = true;
88
0
            if (ctx->vocab != nullptr) {
89
0
                const std::string piece = common_token_to_piece(ctx->vocab, token, false);
90
0
                utf8_complete = common_utf8_is_complete(piece);
91
0
            }
92
93
0
            if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
94
0
                if (utf8_complete) {
95
0
                    ctx->state = REASONING_BUDGET_FORCING;
96
0
                    ctx->force_pos = 0;
97
0
                    ctx->end_matcher.reset();
98
0
                    LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
99
0
                }
100
0
            } else if (ctx->state == REASONING_BUDGET_COUNTING) {
101
0
                ctx->remaining--;
102
0
                if (ctx->remaining <= 0) {
103
0
                    if (utf8_complete) {
104
0
                        ctx->state = REASONING_BUDGET_FORCING;
105
0
                        ctx->force_pos = 0;
106
0
                        ctx->end_matcher.reset();
107
0
                        LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
108
0
                    } else {
109
0
                        ctx->state = REASONING_BUDGET_WAITING_UTF8;
110
0
                        ctx->end_matcher.reset();
111
0
                        LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
112
0
                    }
113
0
                }
114
0
            }
115
0
            break;
116
0
        }
117
0
        case REASONING_BUDGET_FORCING:
118
0
            ctx->force_pos++;
119
0
            if (ctx->force_pos >= ctx->forced_tokens.size()) {
120
0
                ctx->state = REASONING_BUDGET_DONE;
121
0
                LOG_INF("reasoning-budget: forced sequence complete, done\n");
122
0
            }
123
0
            break;
124
0
        case REASONING_BUDGET_DONE:
125
            // Re-arm on a new start tag: some models emit multiple <think> blocks
126
            // per response, and each should get a fresh budget window.
127
0
            if (ctx->start_matcher.advance(token)) {
128
0
                ctx->state = REASONING_BUDGET_COUNTING;
129
0
                ctx->remaining = ctx->budget;
130
0
                ctx->end_matcher.reset();
131
0
                LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
132
133
0
                if (ctx->remaining <= 0) {
134
0
                    ctx->state = REASONING_BUDGET_FORCING;
135
0
                    ctx->force_pos = 0;
136
0
                    LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
137
0
                }
138
0
            }
139
0
            break;
140
0
    }
141
0
}
142
143
0
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
144
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
145
146
0
    if (ctx->state != REASONING_BUDGET_FORCING) {
147
        // passthrough — don't modify logits
148
0
        return;
149
0
    }
150
151
0
    if (ctx->force_pos >= ctx->forced_tokens.size()) {
152
0
        return;
153
0
    }
154
155
0
    const llama_token forced = ctx->forced_tokens[ctx->force_pos];
156
157
    // set all logits to -inf except the forced token
158
0
    for (size_t i = 0; i < cur_p->size; i++) {
159
0
        if (cur_p->data[i].id != forced) {
160
0
            cur_p->data[i].logit = -INFINITY;
161
0
        }
162
0
    }
163
0
}
164
165
0
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
166
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
167
0
    ctx->state = REASONING_BUDGET_IDLE;
168
0
    ctx->remaining = ctx->budget;
169
0
    ctx->start_matcher.reset();
170
0
    ctx->end_matcher.reset();
171
0
    ctx->force_pos = 0;
172
0
}
173
174
static struct llama_sampler * common_reasoning_budget_init_state(
175
        const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
176
        const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
177
        int32_t budget, common_reasoning_budget_state initial_state);
178
179
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl);
180
181
0
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
182
0
    delete (common_reasoning_budget_ctx *) smpl->ctx;
183
0
}
184
185
static struct llama_sampler_i common_reasoning_budget_i = {
186
    /* .name              = */ common_reasoning_budget_name,
187
    /* .accept            = */ common_reasoning_budget_accept,
188
    /* .apply             = */ common_reasoning_budget_apply,
189
    /* .reset             = */ common_reasoning_budget_reset,
190
    /* .clone             = */ common_reasoning_budget_clone,
191
    /* .free              = */ common_reasoning_budget_free,
192
    /* .backend_init      = */ nullptr,
193
    /* .backend_accept    = */ nullptr,
194
    /* .backend_apply     = */ nullptr,
195
    /* .backend_set_input = */ nullptr,
196
};
197
198
0
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
199
0
    const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
200
201
0
    return llama_sampler_init(
202
0
        /* .iface = */ &common_reasoning_budget_i,
203
0
        /* .ctx   = */ new common_reasoning_budget_ctx(*ctx)
204
0
    );
205
0
}
206
207
static struct llama_sampler * common_reasoning_budget_init_state(
208
        const struct llama_vocab             * vocab,
209
        const std::vector<llama_token>       & start_tokens,
210
        const std::vector<llama_token>       & end_tokens,
211
        const std::vector<llama_token>       & forced_tokens,
212
        int32_t                                budget,
213
0
        common_reasoning_budget_state          initial_state) {
214
    // promote COUNTING with budget <= 0 to FORCING
215
0
    if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
216
0
        initial_state = REASONING_BUDGET_FORCING;
217
0
    }
218
219
0
    return llama_sampler_init(
220
0
        /* .iface = */ &common_reasoning_budget_i,
221
0
        /* .ctx   = */ new common_reasoning_budget_ctx {
222
0
            /* .vocab         = */ vocab,
223
0
            /* .start_matcher = */ { start_tokens, 0 },
224
0
            /* .end_matcher   = */ { end_tokens, 0 },
225
0
            /* .forced_tokens = */ forced_tokens,
226
0
            /* .budget        = */ budget,
227
0
            /* .remaining     = */ budget,
228
0
            /* .state         = */ initial_state,
229
0
            /* .force_pos     = */ 0,
230
0
        }
231
0
    );
232
0
}
233
234
struct llama_sampler * common_reasoning_budget_init(
235
        const struct llama_vocab       * vocab,
236
        const std::vector<llama_token> & start_tokens,
237
        const std::vector<llama_token> & end_tokens,
238
        const std::vector<llama_token> & forced_tokens,
239
        int32_t                          budget,
240
0
        common_reasoning_budget_state    initial_state) {
241
0
    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
242
0
}
243
244
0
// Current state of a reasoning-budget sampler; a null sampler reports IDLE.
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
    return smpl != nullptr
        ? ((const common_reasoning_budget_ctx *) smpl->ctx)->state
        : REASONING_BUDGET_IDLE;
}
250
251
0
bool common_reasoning_budget_force(struct llama_sampler * smpl) {
252
0
    if (!smpl) {
253
0
        return false;
254
0
    }
255
256
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
257
258
    // only a sampler that is actively counting down the budget may be forced;
259
    // any other state (idle, already forcing/waiting, or done) is left untouched
260
0
    if (ctx->state != REASONING_BUDGET_COUNTING) {
261
0
        return false;
262
0
    }
263
264
0
    ctx->state = REASONING_BUDGET_FORCING;
265
0
    ctx->force_pos = 0;
266
0
    ctx->end_matcher.reset();
267
0
    LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
268
269
0
    return true;
270
0
}