/src/llama.cpp/common/sampling.h
Line | Count | Source |
1 | | #pragma once |
2 | | |
3 | | #include "llama.h" |
4 | | |
5 | | #include "common.h" |
6 | | |
7 | | #include <string> |
8 | | #include <vector> |
9 | | |
10 | | // common_sampler extends llama_sampler with additional functionality: |
11 | | // |
12 | | // - grammar support |
13 | | // - custom sampler logic based on the parameters |
14 | | // - history of the last accepted tokens |
15 | | // - performance metrics |
16 | | // |
17 | | // The goal is to have a common implementation of the sampling logic shared across the examples. |
18 | | // For example, depending on the temperature, the sampling chain can be very simple (greedy) or more |
19 | | // complex (top-k, top-p, etc). |
20 | | // |
21 | | // Another example is related to the grammar. In general, the grammar constraints applied on the full |
22 | | // vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled |
23 | | // token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the |
24 | | // grammar constraints are applied to the full vocabulary and the token is resampled. |
25 | | // |
26 | | // The common_sampler also maintains a container with the last accepted tokens. In the future, this can |
27 | | // be moved into the core llama library. |
28 | | // |
29 | | // For convenience, the common_sampler also maintains a container with the current candidate tokens. |
30 | | // This can be used to access the probabilities of the rest of the non-sampled tokens. |
31 | | // |
32 | | // TODO: measure grammar performance |
33 | | // |
34 | | |
35 | | struct common_sampler; |
36 | | |
37 | | // llama_sampler API overloads |
38 | | |
39 | | struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); |
40 | | |
41 | | void common_sampler_free(struct common_sampler * gsmpl); |
42 | | |
43 | | // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar |
44 | | void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar); |
45 | | void common_sampler_reset (struct common_sampler * gsmpl); |
46 | | struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); |
47 | | |
48 | | // arguments can be nullptr to skip printing |
49 | | void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); |
50 | | |
51 | | struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); |
52 | | |
53 | | // extended sampling implementation: |
54 | | // |
55 | | // - set logits |
56 | | // - apply the configured sampler chain |
57 | | // - check if the token fits the grammar (if any) |
58 | | // - if not: resample by first applying the grammar constraints and then sampling again (slower path) |
59 | | // |
60 | | // if grammar_first is true, the grammar is applied before the samplers (slower) |
61 | | // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar |
62 | | // |
63 | | llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); |
64 | | |
65 | | // generalized version of common_sampler_sample |
66 | | // |
67 | | // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match |
68 | | // if the sampler disagrees at some point, we stop and return the accepted tokens up to now |
69 | | // |
70 | | // common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {}); |
71 | | // |
72 | | // is equivalent to |
73 | | // |
74 | | // llama_token token = common_sampler_sample(gsmpl, ctx, idx); |
75 | | // common_sampler_accept(gsmpl, token, true); |
76 | | // |
77 | | // requires: idxs.size() == draft.size() + 1 |
78 | | // |
79 | | // returns at least 1 token, up to idxs.size() |
80 | | // |
81 | | std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false); |
82 | | |
83 | | // assume idxs == [ 0, 1, 2, ..., draft.size() ] |
84 | | std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); |
85 | | |
86 | | uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); |
87 | | |
88 | | // helpers |
89 | | |
90 | | // access the internal list of current candidate tokens |
91 | | // if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability) |
92 | | // the .sorted flag of the result indicates whether the returned candidates are sorted |
93 | | llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort); |
94 | | |
95 | | // get the last accepted token |
96 | | llama_token common_sampler_last(const struct common_sampler * gsmpl); |
97 | | |
98 | | // print the sampler chain into a string |
99 | | std::string common_sampler_print(const struct common_sampler * gsmpl); |
100 | | |
101 | | // get a string representation of the last accepted tokens |
102 | | std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n); |
103 | | |
104 | | char common_sampler_type_to_chr(enum common_sampler_type cnstr); |
105 | | std::string common_sampler_type_to_str(enum common_sampler_type cnstr); |
106 | | |
107 | | std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); |
108 | | std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars); |
109 | | |
110 | | llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, |
111 | | const char * grammar_kind, const char * grammar_data); |
112 | | |
113 | | struct common_sampler_deleter { |
114 | 0 | void operator()(common_sampler * s) { common_sampler_free(s); } |
115 | | }; |
116 | | |
117 | | typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr; |