/src/llama.cpp/common/sampling.cpp
Line | Count | Source |
1 | | #include "sampling.h" |
2 | | |
3 | | #include "common.h" |
4 | | #include "fit.h" |
5 | | #include "log.h" |
6 | | #include "reasoning-budget.h" |
7 | | |
8 | | #include "ggml.h" |
9 | | |
10 | | #include <algorithm> |
11 | | #include <cctype> |
12 | | #include <climits> |
13 | | #include <cmath> |
14 | | #include <cstring> |
15 | | #include <unordered_map> |
16 | | #include <vector> |
17 | | |
18 | | // the ring buffer works similarly to std::deque, but with a fixed capacity |
19 | | // TODO: deduplicate with llama-impl.h |
// Fixed-capacity FIFO buffer: once `capacity` elements are stored, pushing a
// new element overwrites the oldest one.
//
// Layout: `first` is the index of the oldest element, `pos` is the index of
// the slot the NEXT push_back() will write to, and `sz` is the number of live
// elements. The newest element therefore lives one slot *before* `pos`.
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // Oldest element. Throws std::runtime_error when the buffer is empty.
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // Most recently pushed element. Throws std::runtime_error when the buffer is empty.
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[last_index()];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[last_index()];
    }

    // Append `value`, evicting the oldest element when the buffer is full.
    // Throws std::runtime_error for a zero-capacity buffer (nothing can be stored,
    // and the index arithmetic below would divide by zero).
    void push_back(const T & value) {
        if (capacity == 0) {
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // Remove and return the oldest element. Throws std::runtime_error when empty.
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // Reverse access: rat(0) is the newest element, rat(size() - 1) the oldest.
    // Throws std::runtime_error when `i` is out of range.
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // Copy of the live elements, ordered oldest to newest.
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    // Index of the newest element. Only meaningful when sz > 0 (which implies capacity > 0).
    size_t last_index() const {
        return (pos + capacity - 1) % capacity;
    }

    size_t capacity = 0;
    size_t sz = 0;
    size_t first = 0;
    size_t pos = 0;
    std::vector<T> data;
};
110 | | |
111 | | struct common_sampler { |
112 | | common_params_sampling params; |
113 | | |
114 | | struct llama_sampler * grmr; |
115 | | struct llama_sampler * rbudget; |
116 | | struct llama_sampler * chain; |
117 | | |
118 | | ring_buffer<llama_token> prev; |
119 | | |
120 | | std::vector<llama_token_data> cur; |
121 | | |
122 | | llama_token_data_array cur_p; |
123 | | |
124 | 0 | void reset() { |
125 | 0 | prev.clear(); |
126 | |
|
127 | 0 | llama_sampler_reset(chain); |
128 | 0 | } |
129 | | |
130 | 0 | void set_logits(struct llama_context * ctx, int idx) { |
131 | 0 | const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); |
132 | 0 | const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); |
133 | 0 | const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); |
134 | |
|
135 | 0 | const llama_model * model = llama_get_model(ctx); |
136 | 0 | const llama_vocab * vocab = llama_model_get_vocab(model); |
137 | |
|
138 | 0 | const int n_vocab = llama_vocab_n_tokens(vocab); |
139 | |
|
140 | 0 | if (sampled_probs) { |
141 | 0 | const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); |
142 | 0 | cur.resize(sampled_probs_count); |
143 | 0 | for (uint32_t i = 0; i < sampled_probs_count; ++i) { |
144 | 0 | cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; |
145 | 0 | } |
146 | 0 | } else if (sampled_logits) { |
147 | 0 | const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); |
148 | 0 | cur.resize(sampled_logits_count); |
149 | 0 | for (uint32_t i = 0; i < sampled_logits_count; i++) { |
150 | 0 | cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; |
151 | 0 | } |
152 | 0 | } else { |
153 | 0 | const auto * logits = llama_get_logits_ith(ctx, idx); |
154 | 0 | GGML_ASSERT(logits != nullptr); |
155 | 0 | cur.resize(n_vocab); |
156 | 0 | for (llama_token token_id = 0; token_id < n_vocab; token_id++) { |
157 | 0 | cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; |
158 | 0 | } |
159 | 0 | } |
160 | |
|
161 | 0 | cur_p = { cur.data(), cur.size(), -1, false }; |
162 | 0 | } |
163 | | |
164 | 0 | common_time_meas tm() { |
165 | 0 | return common_time_meas(t_total_us, params.no_perf); |
166 | 0 | } |
167 | | |
168 | | mutable int64_t t_total_us = 0; |
169 | | }; |
170 | | |
171 | 0 | std::string common_params_sampling::print() const { |
172 | 0 | char result[1024]; |
173 | |
|
174 | 0 | snprintf(result, sizeof(result), |
175 | 0 | "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" |
176 | 0 | "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" |
177 | 0 | "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" |
178 | 0 | "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f", |
179 | 0 | penalty_last_n, penalty_repeat, penalty_freq, penalty_present, |
180 | 0 | dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, |
181 | 0 | top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, |
182 | 0 | mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay); |
183 | |
|
184 | 0 | return std::string(result); |
185 | 0 | } |
186 | | |
187 | 0 | struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) { |
188 | 0 | const llama_vocab * vocab = llama_model_get_vocab(model); |
189 | |
|
190 | 0 | llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); |
191 | |
|
192 | 0 | lparams.no_perf = params.no_perf; |
193 | |
|
194 | 0 | llama_sampler * grmr = nullptr; |
195 | 0 | llama_sampler * rbudget = nullptr; |
196 | 0 | llama_sampler * chain = llama_sampler_chain_init(lparams); |
197 | |
|
198 | 0 | std::vector<llama_sampler *> samplers; |
199 | |
|
200 | 0 | const std::string & grammar_str = common_grammar_value(params.grammar); |
201 | 0 | if (grammar_str.compare(0, 11, "%llguidance") == 0) { |
202 | | #ifdef LLAMA_USE_LLGUIDANCE |
203 | | grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str()); |
204 | | #else |
205 | 0 | GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); |
206 | 0 | #endif // LLAMA_USE_LLGUIDANCE |
207 | 0 | } else { |
208 | 0 | std::vector<std::string> trigger_patterns; |
209 | 0 | std::vector<llama_token> trigger_tokens; |
210 | 0 | for (const auto & trigger : params.grammar_triggers) { |
211 | 0 | switch (trigger.type) { |
212 | 0 | case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: |
213 | 0 | { |
214 | 0 | const auto & word = trigger.value; |
215 | 0 | trigger_patterns.push_back(regex_escape(word)); |
216 | 0 | break; |
217 | 0 | } |
218 | 0 | case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: |
219 | 0 | { |
220 | 0 | trigger_patterns.push_back(trigger.value); |
221 | 0 | break; |
222 | 0 | } |
223 | 0 | case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: |
224 | 0 | { |
225 | 0 | const auto & pattern = trigger.value; |
226 | 0 | std::string anchored = "^$"; |
227 | 0 | if (!pattern.empty()) { |
228 | 0 | anchored = (pattern.front() != '^' ? "^" : "") |
229 | 0 | + pattern |
230 | 0 | + (pattern.back() != '$' ? "$" : ""); |
231 | 0 | } |
232 | 0 | trigger_patterns.push_back(anchored); |
233 | 0 | break; |
234 | 0 | } |
235 | 0 | case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: |
236 | 0 | { |
237 | 0 | const auto token = trigger.token; |
238 | 0 | trigger_tokens.push_back(token); |
239 | 0 | break; |
240 | 0 | } |
241 | 0 | default: |
242 | 0 | GGML_ASSERT(false && "unknown trigger type"); |
243 | 0 | } |
244 | 0 | } |
245 | | |
246 | 0 | std::vector<const char *> trigger_patterns_c; |
247 | 0 | trigger_patterns_c.reserve(trigger_patterns.size()); |
248 | 0 | for (const auto & regex : trigger_patterns) { |
249 | 0 | trigger_patterns_c.push_back(regex.c_str()); |
250 | 0 | } |
251 | |
|
252 | 0 | if (!grammar_str.empty()) { |
253 | 0 | if (params.grammar_lazy) { |
254 | 0 | grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root", |
255 | 0 | trigger_patterns_c.data(), trigger_patterns_c.size(), |
256 | 0 | trigger_tokens.data(), trigger_tokens.size()); |
257 | 0 | } else { |
258 | 0 | grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root"); |
259 | 0 | } |
260 | 0 | } |
261 | 0 | } |
262 | 0 | if (!grmr && !grammar_str.empty()) { |
263 | 0 | throw std::runtime_error("failed to parse grammar"); |
264 | 0 | } |
265 | | |
266 | | // Compute prefill tokens from the generation prompt |
267 | 0 | std::vector<llama_token> prefill_tokens; |
268 | 0 | if (!params.generation_prompt.empty()) { |
269 | 0 | GGML_ASSERT(vocab != nullptr); |
270 | 0 | auto tokens = common_tokenize(vocab, params.generation_prompt, false, true); |
271 | 0 | for (size_t i = 0; i < tokens.size(); i++) { |
272 | 0 | std::string piece = common_token_to_piece(vocab, tokens[i], true); |
273 | 0 | if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) { |
274 | | // Some tokenizers will add a space before the first special token, need to exclude |
275 | 0 | continue; |
276 | 0 | } |
277 | 0 | LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str()); |
278 | 0 | prefill_tokens.push_back(tokens[i]); |
279 | 0 | } |
280 | 0 | } |
281 | | |
282 | | // Feed generation prompt tokens to the grammar sampler so it advances past |
283 | | // tokens the template already placed in the prompt. |
284 | | // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled. |
285 | 0 | if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) { |
286 | 0 | try { |
287 | 0 | for (const auto & token : prefill_tokens) { |
288 | 0 | llama_sampler_accept(grmr, token); |
289 | 0 | LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token); |
290 | 0 | } |
291 | 0 | } catch (std::exception &e) { |
292 | 0 | LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__, |
293 | 0 | common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str()); |
294 | 0 | throw e; |
295 | 0 | } |
296 | 0 | } |
297 | | |
298 | | // reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for thinking-block suppression) |
299 | 0 | if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0 || params.reasoning_control)) { |
300 | 0 | rbudget = common_reasoning_budget_init( |
301 | 0 | vocab, |
302 | 0 | params.reasoning_budget_start, |
303 | 0 | params.reasoning_budget_end, |
304 | 0 | params.reasoning_budget_forced, |
305 | 0 | params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens); |
306 | |
|
307 | 0 | for (const auto & token : prefill_tokens) { |
308 | 0 | llama_sampler_accept(rbudget, token); |
309 | 0 | LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token); |
310 | 0 | } |
311 | 0 | } |
312 | |
|
313 | 0 | if (params.has_logit_bias()) { |
314 | 0 | samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); |
315 | 0 | } |
316 | |
|
317 | 0 | if (params.mirostat == 0) { |
318 | |
|
319 | 0 | bool use_adaptive_p = false; // see below |
320 | |
|
321 | 0 | for (const auto & cnstr : params.samplers) { |
322 | 0 | switch (cnstr) { |
323 | 0 | case COMMON_SAMPLER_TYPE_DRY: |
324 | 0 | { |
325 | 0 | std::vector<const char *> c_breakers; |
326 | 0 | c_breakers.reserve(params.dry_sequence_breakers.size()); |
327 | 0 | for (const auto & str : params.dry_sequence_breakers) { |
328 | 0 | c_breakers.push_back(str.c_str()); |
329 | 0 | } |
330 | 0 | samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); |
331 | 0 | } |
332 | 0 | break; |
333 | 0 | case COMMON_SAMPLER_TYPE_TOP_K: |
334 | 0 | samplers.push_back(llama_sampler_init_top_k(params.top_k)); |
335 | 0 | break; |
336 | 0 | case COMMON_SAMPLER_TYPE_TOP_P: |
337 | 0 | samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep)); |
338 | 0 | break; |
339 | 0 | case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: |
340 | 0 | samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); |
341 | 0 | break; |
342 | 0 | case COMMON_SAMPLER_TYPE_MIN_P: |
343 | 0 | samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep)); |
344 | 0 | break; |
345 | 0 | case COMMON_SAMPLER_TYPE_XTC: |
346 | 0 | samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); |
347 | 0 | break; |
348 | 0 | case COMMON_SAMPLER_TYPE_TYPICAL_P: |
349 | 0 | samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep)); |
350 | 0 | break; |
351 | 0 | case COMMON_SAMPLER_TYPE_TEMPERATURE: |
352 | 0 | samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent)); |
353 | 0 | break; |
354 | 0 | case COMMON_SAMPLER_TYPE_INFILL: |
355 | 0 | samplers.push_back(llama_sampler_init_infill(vocab)); |
356 | 0 | break; |
357 | 0 | case COMMON_SAMPLER_TYPE_PENALTIES: |
358 | 0 | samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); |
359 | 0 | break; |
360 | 0 | case COMMON_SAMPLER_TYPE_ADAPTIVE_P: |
361 | | // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects |
362 | | // a single token, so we will add `dist` at the end of the chain by default, |
363 | | // unless the user specifically included `adaptive-p`. we set this flag here |
364 | | // so we know to add the sampler at the very end. |
365 | 0 | use_adaptive_p = true; |
366 | 0 | break; |
367 | 0 | default: |
368 | 0 | GGML_ASSERT(false && "unknown sampler type"); |
369 | 0 | } |
370 | 0 | } |
371 | 0 | if (use_adaptive_p) { |
372 | | // only if user explicitly included adaptive-p sampler |
373 | 0 | samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed)); |
374 | 0 | } else { |
375 | | // default: sample from distribution |
376 | 0 | samplers.push_back(llama_sampler_init_dist(params.seed)); |
377 | 0 | } |
378 | 0 | } else if (params.mirostat == 1) { |
379 | 0 | samplers.push_back(llama_sampler_init_temp(params.temp)); |
380 | 0 | samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); |
381 | 0 | } else if (params.mirostat == 2) { |
382 | 0 | samplers.push_back(llama_sampler_init_temp(params.temp)); |
383 | 0 | samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); |
384 | 0 | } else { |
385 | 0 | GGML_ASSERT(false && "unknown mirostat version"); |
386 | 0 | } |
387 | | |
388 | 0 | for (auto * smpl : samplers) { |
389 | 0 | llama_sampler_chain_add(chain, smpl); |
390 | 0 | } |
391 | |
|
392 | 0 | if (grmr && params.backend_sampling) { |
393 | 0 | LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__); |
394 | |
|
395 | 0 | params.backend_sampling = false; |
396 | 0 | } |
397 | |
|
398 | 0 | if (rbudget && params.backend_sampling) { |
399 | 0 | LOG_WRN("%s: backend sampling is not compatible with reasoning budget, disabling\n", __func__); |
400 | |
|
401 | 0 | params.backend_sampling = false; |
402 | 0 | } |
403 | |
|
404 | 0 | auto * result = new common_sampler { |
405 | 0 | /* .params = */ params, |
406 | 0 | /* .grmr = */ grmr, |
407 | 0 | /* .rbudget = */ rbudget, |
408 | 0 | /* .chain = */ chain, |
409 | 0 | /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)), |
410 | 0 | /* .cur = */ {}, |
411 | 0 | /* .cur_p = */ {}, |
412 | 0 | }; |
413 | |
|
414 | 0 | return result; |
415 | 0 | } |
416 | | |
417 | 0 | void common_sampler_free(struct common_sampler * gsmpl) { |
418 | 0 | if (!gsmpl) { |
419 | 0 | return; |
420 | 0 | } |
421 | | |
422 | 0 | llama_sampler_free(gsmpl->grmr); |
423 | 0 | llama_sampler_free(gsmpl->rbudget); |
424 | 0 | llama_sampler_free(gsmpl->chain); |
425 | |
|
426 | 0 | delete gsmpl; |
427 | 0 | } |
428 | | |
429 | 0 | static bool grammar_should_apply(struct common_sampler * gsmpl) { |
430 | 0 | if (!gsmpl->grmr) { |
431 | 0 | return false; |
432 | 0 | } |
433 | 0 | if (!gsmpl->rbudget) { |
434 | 0 | return true; |
435 | 0 | } |
436 | 0 | if (gsmpl->params.grammar_lazy) { |
437 | | // if grammar is lazy, only apply when reasoning budget is not active |
438 | 0 | const auto state = common_reasoning_budget_get_state(gsmpl->rbudget); |
439 | 0 | return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE; |
440 | 0 | } |
441 | 0 | return true; |
442 | 0 | } |
443 | | |
444 | 0 | void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) { |
445 | 0 | if (!gsmpl) { |
446 | 0 | return; |
447 | 0 | } |
448 | | |
449 | 0 | const auto tm = gsmpl->tm(); |
450 | | |
451 | | // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept |
452 | 0 | const auto accept_grammar = is_generated && grammar_should_apply(gsmpl); |
453 | |
|
454 | 0 | if (gsmpl->rbudget && is_generated) { |
455 | 0 | llama_sampler_accept(gsmpl->rbudget, token); |
456 | 0 | } |
457 | |
|
458 | 0 | if (gsmpl->grmr && accept_grammar) { |
459 | 0 | llama_sampler_accept(gsmpl->grmr, token); |
460 | 0 | } |
461 | |
|
462 | 0 | llama_sampler_accept(gsmpl->chain, token); |
463 | |
|
464 | 0 | gsmpl->prev.push_back(token); |
465 | 0 | } |
466 | | |
467 | 0 | void common_sampler_reset(struct common_sampler * gsmpl) { |
468 | 0 | if (!gsmpl) { |
469 | 0 | return; |
470 | 0 | } |
471 | | |
472 | 0 | gsmpl->reset(); |
473 | 0 | } |
474 | | |
475 | 0 | struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { |
476 | 0 | return new common_sampler { |
477 | 0 | /* .params = */ gsmpl->params, |
478 | 0 | /* .grmr = */ llama_sampler_clone(gsmpl->grmr), |
479 | 0 | /* .rbudget = */ llama_sampler_clone(gsmpl->rbudget), |
480 | 0 | /* .chain = */ llama_sampler_clone(gsmpl->chain), |
481 | 0 | /* .prev = */ gsmpl->prev, |
482 | 0 | /* .cur = */ gsmpl->cur, |
483 | 0 | /* .cur_p = */ gsmpl->cur_p, |
484 | 0 | }; |
485 | 0 | } |
486 | | |
487 | 0 | void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) { |
488 | | // TODO: measure grammar performance |
489 | |
|
490 | 0 | const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0; |
491 | |
|
492 | 0 | llama_perf_sampler_data data_smpl; |
493 | 0 | llama_perf_context_data data_ctx; |
494 | |
|
495 | 0 | memset(&data_smpl, 0, sizeof(data_smpl)); |
496 | 0 | memset(&data_ctx, 0, sizeof(data_ctx)); |
497 | |
|
498 | 0 | if (gsmpl) { |
499 | 0 | auto & data = data_smpl; |
500 | |
|
501 | 0 | data = llama_perf_sampler(gsmpl->chain); |
502 | | |
503 | | // note: the sampling time includes the samplers time + extra time spent in common/sampling |
504 | 0 | LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms); |
505 | 0 | LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample); |
506 | 0 | } |
507 | |
|
508 | 0 | if (ctx) { |
509 | 0 | auto & data = data_ctx; |
510 | |
|
511 | 0 | data = llama_perf_context(ctx); |
512 | |
|
513 | 0 | const double t_end_ms = 1e-3 * ggml_time_us(); |
514 | |
|
515 | 0 | const double t_total_ms = t_end_ms - data.t_start_ms; |
516 | 0 | const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms); |
517 | 0 | const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms; |
518 | |
|
519 | 0 | LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); |
520 | 0 | LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", |
521 | 0 | __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); |
522 | 0 | LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", |
523 | 0 | __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); |
524 | 0 | LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); |
525 | 0 | LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc); |
526 | 0 | LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused); |
527 | |
|
528 | 0 | common_memory_breakdown_print(ctx); |
529 | 0 | } |
530 | 0 | } |
531 | | |
532 | 0 | struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { |
533 | 0 | if (!gsmpl) { |
534 | 0 | return nullptr; |
535 | 0 | } |
536 | | |
537 | 0 | return gsmpl->chain; |
538 | 0 | } |
539 | | |
540 | 0 | llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { |
541 | 0 | llama_synchronize(ctx); |
542 | | |
543 | | // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations |
544 | 0 | const auto tm = gsmpl->tm(); |
545 | |
|
546 | 0 | llama_token id = LLAMA_TOKEN_NULL; |
547 | |
|
548 | 0 | auto & grmr = gsmpl->grmr; |
549 | 0 | auto & rbudget = gsmpl->rbudget; |
550 | 0 | auto & chain = gsmpl->chain; |
551 | 0 | auto & cur_p = gsmpl->cur_p; // initialized by set_logits |
552 | |
|
553 | 0 | gsmpl->set_logits(ctx, idx); |
554 | | |
555 | | // Check if a backend sampler has already sampled a token in which case we |
556 | | // return that token id directly. |
557 | 0 | { |
558 | 0 | id = llama_get_sampled_token_ith(ctx, idx); |
559 | |
|
560 | 0 | if (id != LLAMA_TOKEN_NULL) { |
561 | 0 | LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id); |
562 | |
|
563 | 0 | GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported"); |
564 | 0 | GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported"); |
565 | |
|
566 | 0 | for (size_t i = 0; i < cur_p.size; ++i) { |
567 | 0 | if (cur_p.data[i].id == id) { |
568 | 0 | cur_p.selected = i; |
569 | 0 | break; |
570 | 0 | } |
571 | 0 | } |
572 | |
|
573 | 0 | return id; |
574 | 0 | } |
575 | 0 | } |
576 | | |
577 | | // apply reasoning budget first |
578 | 0 | llama_sampler_apply(rbudget, &cur_p); |
579 | |
|
580 | 0 | if (grammar_first && grammar_should_apply(gsmpl)) { |
581 | 0 | llama_sampler_apply(grmr, &cur_p); |
582 | 0 | } |
583 | |
|
584 | 0 | llama_sampler_apply(chain, &cur_p); |
585 | |
|
586 | 0 | id = cur_p.data[cur_p.selected].id; |
587 | |
|
588 | 0 | if (grammar_first || !grammar_should_apply(gsmpl)) { |
589 | 0 | return id; |
590 | 0 | } |
591 | | |
592 | | // check if it the sampled token fits the grammar (grammar-based rejection sampling) |
593 | 0 | { |
594 | 0 | llama_token_data single_token_data = { id, 1.0f, 0.0f }; |
595 | 0 | llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; |
596 | |
|
597 | 0 | llama_sampler_apply(grmr, &single_token_data_array); |
598 | |
|
599 | 0 | const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; |
600 | 0 | if (is_valid) { |
601 | 0 | return id; |
602 | 0 | } |
603 | 0 | } |
604 | | |
605 | | // resampling: |
606 | | // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain |
607 | 0 | gsmpl->set_logits(ctx, idx); |
608 | |
|
609 | 0 | llama_sampler_apply(rbudget, &cur_p); |
610 | |
|
611 | 0 | if (grammar_should_apply(gsmpl)) { |
612 | 0 | llama_sampler_apply(grmr, &cur_p); |
613 | 0 | } |
614 | |
|
615 | 0 | llama_sampler_apply(chain, &cur_p); |
616 | |
|
617 | 0 | GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); |
618 | |
|
619 | 0 | id = cur_p.data[cur_p.selected].id; |
620 | |
|
621 | 0 | return id; |
622 | 0 | } |
623 | | |
624 | 0 | std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) { |
625 | 0 | GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); |
626 | |
|
627 | 0 | std::vector<llama_token> result; |
628 | 0 | result.reserve(idxs.size()); |
629 | |
|
630 | 0 | size_t i = 0; |
631 | 0 | for (; i < draft.size(); i++) { |
632 | 0 | const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); |
633 | |
|
634 | 0 | common_sampler_accept(gsmpl, id, true); |
635 | |
|
636 | 0 | result.push_back(id); |
637 | |
|
638 | 0 | if (draft[i] != id) { |
639 | 0 | break; |
640 | 0 | } |
641 | 0 | } |
642 | |
|
643 | 0 | if (i == draft.size()) { |
644 | 0 | const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); |
645 | |
|
646 | 0 | common_sampler_accept(gsmpl, id, true); |
647 | |
|
648 | 0 | result.push_back(id); |
649 | 0 | } |
650 | |
|
651 | 0 | return result; |
652 | 0 | } |
653 | | |
654 | 0 | std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { |
655 | 0 | std::vector<int> idxs(draft.size() + 1); |
656 | 0 | for (size_t i = 0; i < idxs.size(); ++i) { |
657 | 0 | idxs[i] = i; |
658 | 0 | } |
659 | |
|
660 | 0 | return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); |
661 | 0 | } |
662 | | |
663 | 0 | uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { |
664 | 0 | return llama_sampler_get_seed(gsmpl->chain); |
665 | 0 | } |
666 | | |
667 | 0 | bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl) { |
668 | 0 | if (!gsmpl) { |
669 | 0 | return false; |
670 | 0 | } |
671 | | |
672 | 0 | return common_reasoning_budget_force(gsmpl->rbudget); |
673 | 0 | } |
674 | | |
675 | | // helpers |
676 | | |
677 | 0 | llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { |
678 | 0 | const auto tm = gsmpl->tm(); |
679 | |
|
680 | 0 | auto * res = &gsmpl->cur_p; |
681 | |
|
682 | 0 | if (do_sort && !res->sorted) { |
683 | | // remember the selected token before sorting |
684 | 0 | const llama_token id = res->data[res->selected].id; |
685 | |
|
686 | 0 | std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) { |
687 | 0 | return a.p > b.p; |
688 | 0 | }); |
689 | | |
690 | | // restore the selected token after sorting |
691 | 0 | for (size_t i = 0; i < res->size; ++i) { |
692 | 0 | if (res->data[i].id == id) { |
693 | 0 | res->selected = i; |
694 | 0 | break; |
695 | 0 | } |
696 | 0 | } |
697 | |
|
698 | 0 | res->sorted = true; |
699 | 0 | } |
700 | |
|
701 | 0 | return res; |
702 | 0 | } |
703 | | |
704 | 0 | llama_token common_sampler_last(const struct common_sampler * gsmpl) { |
705 | 0 | return gsmpl->prev.rat(0); |
706 | 0 | } |
707 | | |
708 | 0 | std::string common_sampler_print(const struct common_sampler * gsmpl) { |
709 | 0 | std::string result = "logits "; |
710 | |
|
711 | 0 | for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { |
712 | 0 | const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); |
713 | 0 | result += std::string("-> "); |
714 | 0 | result += std::string(llama_sampler_name(smpl)) + " "; |
715 | 0 | } |
716 | |
|
717 | 0 | return result; |
718 | 0 | } |
719 | | |
720 | 0 | std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) { |
721 | 0 | n = std::min(n, (int) gsmpl->prev.size()); |
722 | |
|
723 | 0 | if (n <= 0) { |
724 | 0 | return ""; |
725 | 0 | } |
726 | | |
727 | 0 | std::string result; |
728 | 0 | result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab |
729 | |
|
730 | 0 | for (int i = n - 1; i >= 0; i--) { |
731 | 0 | const llama_token id = gsmpl->prev.rat(i); |
732 | |
|
733 | 0 | GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); |
734 | |
|
735 | 0 | result += common_token_to_piece(ctx_main, id); |
736 | 0 | } |
737 | |
|
738 | 0 | return result; |
739 | 0 | } |
740 | | |
741 | 0 | char common_sampler_type_to_chr(enum common_sampler_type cnstr) { |
742 | 0 | switch (cnstr) { |
743 | 0 | case COMMON_SAMPLER_TYPE_DRY: return 'd'; |
744 | 0 | case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; |
745 | 0 | case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; |
746 | 0 | case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; |
747 | 0 | case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; |
748 | 0 | case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; |
749 | 0 | case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; |
750 | 0 | case COMMON_SAMPLER_TYPE_XTC: return 'x'; |
751 | 0 | case COMMON_SAMPLER_TYPE_INFILL: return 'i'; |
752 | 0 | case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; |
753 | 0 | case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a'; |
754 | 0 | default : return '?'; |
755 | 0 | } |
756 | 0 | } |
757 | | |
758 | 0 | std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { |
759 | 0 | switch (cnstr) { |
760 | 0 | case COMMON_SAMPLER_TYPE_DRY: return "dry"; |
761 | 0 | case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; |
762 | 0 | case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; |
763 | 0 | case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; |
764 | 0 | case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; |
765 | 0 | case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; |
766 | 0 | case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; |
767 | 0 | case COMMON_SAMPLER_TYPE_XTC: return "xtc"; |
768 | 0 | case COMMON_SAMPLER_TYPE_INFILL: return "infill"; |
769 | 0 | case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; |
770 | 0 | case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p"; |
771 | 0 | default : return ""; |
772 | 0 | } |
773 | 0 | } |
774 | | |
775 | 0 | std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names) { |
776 | | // sampler names can be written multiple ways; generate aliases from canonical names |
777 | 0 | static const auto sampler_name_map = []{ |
778 | | // canonical sampler name mapping |
779 | 0 | std::unordered_map<std::string, common_sampler_type> canonical_name_map { |
780 | 0 | { "dry", COMMON_SAMPLER_TYPE_DRY }, |
781 | 0 | { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, |
782 | 0 | { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, |
783 | 0 | { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, |
784 | 0 | { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, |
785 | 0 | { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, |
786 | 0 | { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, |
787 | 0 | { "xtc", COMMON_SAMPLER_TYPE_XTC }, |
788 | 0 | { "infill", COMMON_SAMPLER_TYPE_INFILL }, |
789 | 0 | { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, |
790 | 0 | { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P } |
791 | 0 | }; |
792 | 0 | std::unordered_map<std::string, common_sampler_type> alias_name_map; |
793 | 0 | for (const auto & entry : canonical_name_map) { |
794 | 0 | const std::string & canonical = entry.first; |
795 | 0 | if (canonical.find('_') == std::string::npos) { |
796 | 0 | continue; |
797 | 0 | } |
798 | | // kebab-case: "top-k", "min-p", etc. |
799 | 0 | { |
800 | 0 | std::string kebab_case = canonical; |
801 | 0 | std::replace(kebab_case.begin(), kebab_case.end(), '_', '-'); |
802 | 0 | alias_name_map.insert({kebab_case, entry.second}); |
803 | 0 | } |
804 | | // no dash: "topk", "minp", etc. |
805 | 0 | { |
806 | 0 | std::string no_dash = canonical; |
807 | 0 | no_dash.erase(std::remove(no_dash.begin(), no_dash.end(), '_'), no_dash.end()); |
808 | 0 | alias_name_map.insert({no_dash, entry.second}); |
809 | 0 | } |
810 | 0 | } |
811 | | // misc. aliases |
812 | 0 | alias_name_map.insert({"nucleus", COMMON_SAMPLER_TYPE_TOP_P}); |
813 | 0 | alias_name_map.insert({"temp", COMMON_SAMPLER_TYPE_TEMPERATURE}); |
814 | 0 | alias_name_map.insert({"typ", COMMON_SAMPLER_TYPE_TYPICAL_P}); |
815 | | // include aliases + canonical names in the complete mapping |
816 | 0 | alias_name_map.merge(canonical_name_map); |
817 | 0 | return alias_name_map; |
818 | 0 | }(); |
819 | |
|
820 | 0 | std::vector<common_sampler_type> samplers; |
821 | 0 | samplers.reserve(names.size()); |
822 | |
|
823 | 0 | for (const auto & name : names) { |
824 | 0 | std::string name_lower = name; |
825 | 0 | std::transform(name_lower.begin(), name_lower.end(), name_lower.begin(), ::tolower); |
826 | 0 | auto sampler = sampler_name_map.find(name_lower); |
827 | 0 | if (sampler != sampler_name_map.end()) { |
828 | 0 | samplers.push_back(sampler->second); |
829 | 0 | continue; |
830 | 0 | } |
831 | 0 | LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name_lower.c_str()); |
832 | 0 | } |
833 | |
|
834 | 0 | return samplers; |
835 | 0 | } |
836 | | |
837 | 0 | std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) { |
838 | 0 | std::unordered_map<char, common_sampler_type> sampler_name_map = { |
839 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY }, |
840 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, |
841 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, |
842 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, |
843 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, |
844 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, |
845 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, |
846 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, |
847 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, |
848 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, |
849 | 0 | { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P }, |
850 | 0 | }; |
851 | |
|
852 | 0 | std::vector<common_sampler_type> samplers; |
853 | 0 | samplers.reserve(chars.size()); |
854 | |
|
855 | 0 | for (const auto & c : chars) { |
856 | 0 | const auto sampler = sampler_name_map.find(c); |
857 | 0 | if (sampler != sampler_name_map.end()) { |
858 | 0 | samplers.push_back(sampler->second); |
859 | 0 | } else { |
860 | 0 | LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c); |
861 | 0 | } |
862 | 0 | } |
863 | |
|
864 | 0 | return samplers; |
865 | 0 | } |