Coverage Report

Created: 2026-03-21 06:50

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/common/reasoning-budget.cpp
Line
Count
Source
1
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
11
12
struct token_matcher {
13
    std::vector<llama_token> tokens;
14
    size_t pos = 0;
15
16
0
    bool advance(llama_token token) {
17
0
        if (tokens.empty()) {
18
0
            return false;
19
0
        }
20
21
0
        if (token == tokens[pos]) {
22
0
            pos++;
23
0
            if (pos >= tokens.size()) {
24
0
                pos = 0;
25
0
                return true;
26
0
            }
27
0
        } else {
28
0
            pos = 0;
29
0
            if (token == tokens[0]) {
30
0
                pos = 1;
31
0
            }
32
0
        }
33
0
        return false;
34
0
    }
35
36
0
    void reset() { pos = 0; }
37
};
38
39
struct common_reasoning_budget_ctx {
40
    const llama_vocab * vocab;
41
42
    token_matcher start_matcher;
43
    token_matcher end_matcher;
44
    std::vector<llama_token> forced_tokens;
45
46
    int32_t budget;           // maximum tokens in reasoning block
47
    int32_t remaining;        // tokens remaining in budget
48
49
    common_reasoning_budget_state state;
50
51
    // for forcing
52
    size_t force_pos;         // next position in forced_tokens to force
53
};
54
55
0
// Sampler-interface hook: the fixed display name of this sampler. The sampler
// argument is not needed, so it is left unnamed.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static constexpr const char * k_sampler_name = "reasoning-budget";
    return k_sampler_name;
}
58
59
0
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
60
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
61
62
0
    switch (ctx->state) {
63
0
        case REASONING_BUDGET_IDLE:
64
0
        {
65
0
            if (ctx->start_matcher.advance(token)) {
66
0
                ctx->state = REASONING_BUDGET_COUNTING;
67
0
                ctx->remaining = ctx->budget;
68
0
                LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
69
70
0
                if (ctx->remaining <= 0) {
71
0
                    ctx->state = REASONING_BUDGET_FORCING;
72
0
                    ctx->force_pos = 0;
73
0
                    LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
74
0
                }
75
0
            }
76
0
            break;
77
0
        }
78
0
        case REASONING_BUDGET_COUNTING:
79
0
        case REASONING_BUDGET_WAITING_UTF8:
80
0
        {
81
0
            if (ctx->end_matcher.advance(token)) {
82
0
                ctx->state = REASONING_BUDGET_DONE;
83
0
                LOG_INF("reasoning-budget: deactivated (natural end)\n");
84
0
                break;
85
0
            }
86
87
0
            bool utf8_complete = true;
88
0
            if (ctx->vocab != nullptr) {
89
0
                const std::string piece = common_token_to_piece(ctx->vocab, token, false);
90
0
                utf8_complete = common_utf8_is_complete(piece);
91
0
            }
92
93
0
            if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
94
0
                if (utf8_complete) {
95
0
                    ctx->state = REASONING_BUDGET_FORCING;
96
0
                    ctx->force_pos = 0;
97
0
                    ctx->end_matcher.reset();
98
0
                    LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
99
0
                }
100
0
            } else if (ctx->state == REASONING_BUDGET_COUNTING) {
101
0
                ctx->remaining--;
102
0
                if (ctx->remaining <= 0) {
103
0
                    if (utf8_complete) {
104
0
                        ctx->state = REASONING_BUDGET_FORCING;
105
0
                        ctx->force_pos = 0;
106
0
                        ctx->end_matcher.reset();
107
0
                        LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
108
0
                    } else {
109
0
                        ctx->state = REASONING_BUDGET_WAITING_UTF8;
110
0
                        ctx->end_matcher.reset();
111
0
                        LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
112
0
                    }
113
0
                }
114
0
            }
115
0
            break;
116
0
        }
117
0
        case REASONING_BUDGET_FORCING:
118
            // force_pos is advanced in apply(), not here.
119
            // This ensures the first forced token isn't skipped when the sampler
120
            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
121
0
            break;
122
0
        case REASONING_BUDGET_DONE:
123
0
            break;
124
0
    }
125
0
}
126
127
0
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
128
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
129
130
0
    if (ctx->state != REASONING_BUDGET_FORCING) {
131
        // passthrough — don't modify logits
132
0
        return;
133
0
    }
134
135
0
    if (ctx->force_pos >= ctx->forced_tokens.size()) {
136
0
        return;
137
0
    }
138
139
0
    const llama_token forced = ctx->forced_tokens[ctx->force_pos];
140
141
    // set all logits to -inf except the forced token
142
0
    for (size_t i = 0; i < cur_p->size; i++) {
143
0
        if (cur_p->data[i].id != forced) {
144
0
            cur_p->data[i].logit = -INFINITY;
145
0
        }
146
0
    }
147
148
    // advance to next forced token (done here rather than in accept so that
149
    // the first forced token isn't skipped when starting in FORCING state)
150
0
    ctx->force_pos++;
151
0
    if (ctx->force_pos >= ctx->forced_tokens.size()) {
152
0
        ctx->state = REASONING_BUDGET_DONE;
153
0
        LOG_INF("reasoning-budget: forced sequence complete, done\n");
154
0
    }
155
0
}
156
157
0
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
158
0
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
159
0
    ctx->state = REASONING_BUDGET_IDLE;
160
0
    ctx->remaining = ctx->budget;
161
0
    ctx->start_matcher.reset();
162
0
    ctx->end_matcher.reset();
163
0
    ctx->force_pos = 0;
164
0
}
165
166
// forward declaration for use in clone
167
static struct llama_sampler * common_reasoning_budget_init_state(
168
        const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
169
        const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
170
        int32_t budget, common_reasoning_budget_state initial_state);
171
172
0
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
173
0
    const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
174
0
    return common_reasoning_budget_init_state(
175
0
        ctx->vocab,
176
0
        ctx->start_matcher.tokens,
177
0
        ctx->end_matcher.tokens,
178
0
        ctx->forced_tokens,
179
0
        ctx->budget,
180
0
        ctx->state);
181
0
}
182
183
0
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
184
0
    delete (common_reasoning_budget_ctx *) smpl->ctx;
185
0
}
186
187
static struct llama_sampler_i common_reasoning_budget_i = {
188
    /* .name              = */ common_reasoning_budget_name,
189
    /* .accept            = */ common_reasoning_budget_accept,
190
    /* .apply             = */ common_reasoning_budget_apply,
191
    /* .reset             = */ common_reasoning_budget_reset,
192
    /* .clone             = */ common_reasoning_budget_clone,
193
    /* .free              = */ common_reasoning_budget_free,
194
    /* .backend_init      = */ nullptr,
195
    /* .backend_accept    = */ nullptr,
196
    /* .backend_apply     = */ nullptr,
197
    /* .backend_set_input = */ nullptr,
198
};
199
200
static struct llama_sampler * common_reasoning_budget_init_state(
201
        const struct llama_vocab             * vocab,
202
        const std::vector<llama_token>       & start_tokens,
203
        const std::vector<llama_token>       & end_tokens,
204
        const std::vector<llama_token>       & forced_tokens,
205
        int32_t                                budget,
206
0
        common_reasoning_budget_state          initial_state) {
207
    // promote COUNTING with budget <= 0 to FORCING
208
0
    if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
209
0
        initial_state = REASONING_BUDGET_FORCING;
210
0
    }
211
212
0
    return llama_sampler_init(
213
0
        /* .iface = */ &common_reasoning_budget_i,
214
0
        /* .ctx   = */ new common_reasoning_budget_ctx {
215
0
            /* .vocab         = */ vocab,
216
0
            /* .start_matcher = */ { start_tokens, 0 },
217
0
            /* .end_matcher   = */ { end_tokens, 0 },
218
0
            /* .forced_tokens = */ forced_tokens,
219
0
            /* .budget        = */ budget,
220
0
            /* .remaining     = */ budget,
221
0
            /* .state         = */ initial_state,
222
0
            /* .force_pos     = */ 0,
223
0
        }
224
0
    );
225
0
}
226
227
struct llama_sampler * common_reasoning_budget_init(
228
        const struct llama_vocab       * vocab,
229
        const std::vector<llama_token> & start_tokens,
230
        const std::vector<llama_token> & end_tokens,
231
        const std::vector<llama_token> & forced_tokens,
232
        int32_t                          budget,
233
0
        const std::vector<llama_token> & prefill_tokens) {
234
    // Determine initial state from prefill: COUNTING if the prefill begins with
235
    // the start sequence but does not also contain the end sequence after it.
236
0
    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
237
0
    if (!prefill_tokens.empty() && !start_tokens.empty() &&
238
0
            prefill_tokens.size() >= start_tokens.size() &&
239
0
            std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
240
0
        initial_state = REASONING_BUDGET_COUNTING;
241
        // If the end sequence also follows the start in the prefill, reasoning
242
        // was opened and immediately closed — stay IDLE.
243
0
        if (!end_tokens.empty() &&
244
0
                prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
245
0
            auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
246
0
            if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
247
0
                    std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
248
0
                initial_state = REASONING_BUDGET_IDLE;
249
0
            }
250
0
        }
251
0
    }
252
0
    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
253
0
}
254
255
struct llama_sampler * common_reasoning_budget_init(
256
        const struct llama_vocab       * vocab,
257
        const std::vector<llama_token> & start_tokens,
258
        const std::vector<llama_token> & end_tokens,
259
        const std::vector<llama_token> & forced_tokens,
260
        int32_t                          budget,
261
0
        common_reasoning_budget_state    initial_state) {
262
0
    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
263
0
}