Coverage Report

Created: 2025-12-28 06:26

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/common/sampling.h
Line
Count
Source
1
#pragma once
2
3
#include "llama.h"
4
5
#include "common.h"
6
7
#include <memory>
#include <string>
8
#include <vector>
9
10
// common_sampler extends llama_sampler with additional functionality:
11
//
12
//  - grammar support
13
//  - custom sampler logic based on the parameters
14
//  - history of the last accepted tokens
15
//  - performance metrics
16
//
17
// The goal is to have a common implementation of the sampling logic shared across the examples.
18
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
19
// complex (top-k, top-p, etc.).
20
//
21
// Another example is related to the grammar. In general, the grammar constraints applied on the full
22
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
23
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
24
// grammar constraints are applied to the full vocabulary and the token is resampled.
25
//
26
// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
27
// be moved into the core llama library.
28
//
29
// For convenience, the common_sampler also maintains a container with the current candidate tokens.
30
// This can be used to access the probabilities of the rest of the non-sampled tokens.
31
//
32
// TODO: measure grammar performance
33
//
34
35
// Opaque sampler type: this header only forward-declares it, so callers hold it by pointer
// and operate on it exclusively through the common_sampler_* functions declared below.
struct common_sampler;
36
37
// llama_sampler API overloads
38
39
// Create a sampler for `model`, configured from the sampling parameters in `params`.
// Release it with common_sampler_free(), or wrap it in common_sampler_ptr (declared at the
// end of this header) so it is freed automatically.
// NOTE(review): the failure behavior (e.g. whether nullptr is returned for an invalid
// grammar) is not visible from this header — confirm in the implementation.
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
40
41
// Destroy a sampler; this is the function common_sampler_deleter invokes for common_sampler_ptr.
void common_sampler_free(struct common_sampler * gsmpl);
42
43
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
44
// NOTE(review): when accept_grammar is false the grammar state is presumably left untouched — confirm
void                    common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
45
// reset the sampler state
// NOTE(review): exactly which state is reset (chain, grammar, accepted-token history) is not visible here
void                    common_sampler_reset (struct common_sampler * gsmpl);
46
// clone gsmpl into a new sampler
// NOTE(review): the result is presumably owned by the caller and released with common_sampler_free() — confirm
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
47
48
// print performance metrics for the context and/or the sampler
// arguments can be nullptr to skip printing
49
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
50
51
// access the underlying llama_sampler
// NOTE(review): presumably still owned by gsmpl (do not free it separately) — confirm
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
52
53
// extended sampling implementation:
54
//
55
// - set logits
56
// - apply the configured sampler chain
57
// - check if the token fits the grammar (if any)
58
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
59
//
60
// if grammar_first is true, the grammar is applied before the samplers (slower)
61
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
62
//
63
// NOTE(review): idx presumably selects which output position of ctx the logits are taken from — confirm
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
64
65
// generalized version of common_sampler_sample
66
//
67
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
68
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
69
//
70
//      common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
71
//
72
// is equivalent to (and returns { token }):
73
//
74
//      llama_token token = common_sampler_sample(gsmpl, ctx, idx);
75
//      common_sampler_accept(gsmpl, token, true);
76
//
77
// requires: idxs.size() == draft.size() + 1
78
//
79
// returns at least 1 token, up to idxs.size()
80
//
81
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
82
83
// same as above, with idxs implicitly set to [ 0, 1, 2, ..., draft.size() ]
84
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
85
86
// get the seed used by the sampler
// NOTE(review): presumably the resolved value when params requested a random/default seed — confirm
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
87
88
// helpers
89
90
// access the internal list of current candidate tokens
91
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
92
// the .sorted flag of the result indicates whether the returned candidates are sorted
93
// NOTE(review): the returned array points into gsmpl's internal state; it is presumably
// invalidated by the next sample/accept/reset call — confirm before holding on to it
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
94
95
// get the last accepted token
96
llama_token common_sampler_last(const struct common_sampler * gsmpl);
97
98
// print the sampler chain into a string
99
std::string common_sampler_print(const struct common_sampler * gsmpl);
100
101
// get a string representation of the last accepted tokens
102
// NOTE(review): n presumably bounds how many recent tokens are included, and ctx is used to
// turn tokens into text — confirm against the implementation
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
103
104
// convert a sampler type to its single-character / string name
// NOTE(review): presumably the inverses of common_sampler_types_from_chars / _from_names — confirm
char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
105
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
106
107
// parse a list of sampler types from names or from a string of single-character codes
// NOTE(review): allow_alt_names presumably enables alternative spellings of the names — confirm
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
108
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
109
110
// create a grammar-constrained sampler from (grammar_kind, grammar_data)
// NOTE(review): "llg" presumably refers to the llguidance backend; the accepted grammar_kind
// values are defined by the implementation — confirm
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
111
                const char * grammar_kind, const char * grammar_data);
112
113
struct common_sampler_deleter {
114
0
    void operator()(common_sampler * s) { common_sampler_free(s); }
115
};
116
117
// Owning smart pointer for a common_sampler: when it is destroyed or reset, the sampler is
// released via common_sampler_deleter, i.e. common_sampler_free().
// Requires <memory> to be included directly by this header (not only transitively).
using common_sampler_ptr = std::unique_ptr<common_sampler, common_sampler_deleter>;