/src/llama.cpp/common/speculative.h
Line | Count | Source |
1 | | #pragma once |
2 | | |
3 | | #include "llama.h" |
4 | | #include "common.h" |
5 | | |
6 | | struct common_speculative; /* forward declaration only: the type is incomplete in this header and is used below exclusively through pointers */ |
7 | | |
8 | | // comma separated list of the provided types |
9 | | std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types); |
10 | | |
11 | | // comma separated list of all types |
12 | | const char * common_speculative_all_types_str(); |
13 | | |
14 | | // parse user provided type names into types |
15 | | std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names); |
16 | | |
17 | | // convert a single type name (string) to its type |
18 | | enum common_speculative_type common_speculative_type_from_name(const std::string & name); |
19 | | |
20 | | // convert type to string |
21 | | std::string common_speculative_type_to_str(enum common_speculative_type type); |
22 | | |
23 | | // return the max number of draft tokens based on the speculative parameters (note: `spec` here is the parameter struct, not a common_speculative context) |
24 | | int32_t common_speculative_n_max(const common_params_speculative * spec); |
25 | | |
26 | | common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq); /* NOTE(review): params is taken by non-const reference - confirm whether init modifies it; n_seq is presumably the number of sequences - confirm */ |
27 | | |
28 | | void common_speculative_free(common_speculative * spec); /* this is the function common_speculative_deleter forwards to */ |
29 | | |
30 | | struct common_speculative_draft_params { /* draft request/result record; a reference to it is obtained with common_speculative_get_draft_params(spec, seq_id) */ |
31 | | // this flag is used to chain the drafts through all the available implementations |
32 | | // after the first successful draft from an implementation, we set it |
33 | | // to false to prevent further drafts for that sequence |
34 | | // at the end of the draft() call, all drafting flags will be reset to false |
35 | | bool drafting = false; |
36 | | |
37 | | // overrides individual configurations (-1 disabled) |
38 | | // can be used to constrain the max draft based on the remaining context size |
39 | | int32_t n_max = -1; |
40 | | |
41 | | llama_pos n_past = 0; /* default added so a default-initialized object never holds an indeterminate value; presumably the number of tokens already processed - confirm */ |
42 | | llama_token id_last = -1; /* -1 = not set yet; presumably the most recent token of the sequence - confirm */ |
43 | | |
44 | | // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls |
45 | | const llama_tokens * prompt = nullptr; /* non-owning pointer type; null until the caller sets it */ |
46 | | |
47 | | // the generated draft from the last _draft() call |
48 | | llama_tokens * result = nullptr; /* null until set; NOTE(review): who provides the pointed-to storage is not visible here - confirm */ |
49 | | }; |
50 | | |
51 | | common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id); /* returns a mutable reference to the draft parameters selected by seq_id; these are what common_speculative_draft() consumes */ |
52 | | |
53 | | // optionally call once at the beginning of a new generation |
54 | | void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); |
55 | | |
56 | | // process the batch and update the internal state of the speculative context |
57 | | bool common_speculative_process(common_speculative * spec, const llama_batch & batch); /* NOTE(review): the meaning of the bool result is not documented in this header - confirm against the implementation */ |
58 | | |
59 | | // true if any implementation requires target post-norm embeddings to be extracted |
60 | | bool common_speculative_need_embd(common_speculative * spec); |
61 | | |
62 | | // true if any implementation requires target nextn embeddings to be extracted |
63 | | bool common_speculative_need_embd_nextn(common_speculative * spec); |
64 | | |
65 | | // generate drafts for the sequences specified with `common_speculative_get_draft_params`; the generated draft is exposed through common_speculative_draft_params::result |
66 | | void common_speculative_draft(common_speculative * spec); |
67 | | |
68 | | // informs the speculative context that n_accepted tokens were accepted by the target model (seq_id: presumably the sequence the accepted tokens belong to - confirm) |
69 | | void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted); |
70 | | |
71 | | // (optional) get/set internal state; the state is exchanged as a std::vector<uint8_t> byte buffer selected by seq_id |
72 | | bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data); /* data is an output buffer; NOTE(review): meaning of the bool result is not documented here - confirm */ |
73 | | void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data); |
74 | | |
75 | | // print statistics about the speculative decoding |
76 | | void common_speculative_print_stats(const common_speculative * spec); |
77 | | |
78 | | struct common_speculative_deleter { /* deleter functor used as the second template argument of common_speculative_ptr */ |
79 | 0 | void operator()(common_speculative * s) { common_speculative_free(s); } /* forwards the pointer to common_speculative_free() */ |
80 | | }; |
81 | | |
82 | | using common_speculative_ptr = std::unique_ptr<common_speculative, common_speculative_deleter>; /* owning handle: destruction goes through common_speculative_deleter */ |