/src/llama.cpp/src/llama-batch.h
#pragma once

#include "llama.h"

#include "llama-cparams.h"

#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>

// keep this struct lightweight
struct llama_ubatch {
    bool equal_seqs() const {
        return b_equal_seqs != 0;
    }

    // typical for M-RoPE cases:
    //   0 - sequential position of the tokens/embeddings in the sequence
    //   1 - y position in the image
    //   2 - x position in the image
    //   3 - other
    bool is_pos_2d() const {
        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
        return n_pos >= 3;
    }
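
    // illustration (based on the comment above): an M-RoPE vision ubatch typically carries
    // n_pos >= 3 position values per token (sequential position + image y + image x [+ other]),
    // so is_pos_2d() returns true, while a plain text ubatch typically uses a single position
    // per token and is_pos_2d() returns false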

    uint32_t b_equal_seqs; // note: this is a boolean, but we use a uint32_t for alignment
                           //       otherwise address sanitizer complains
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence set
    uint32_t n_seqs;       // sequence sets in the ubatch
    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
    uint32_t n_pos;        // number of position inputs for each token/embedding

    // seq_id_unq: unique sequence ids in the ubatch
    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
    //             used for extracting sequence pooled embeddings

    //                          // size               | idx | val
    llama_token  *  token;      // [n_tokens]         | i   | id, token
    float        *  embd;       // [n_embd, n_tokens] | i   | embd
    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
    int8_t       *  output;     // [n_tokens]         | i   | -

    struct data_t {
        std::vector<llama_token>    token;
        std::vector<float>          embd;
        std::vector<llama_pos>      pos;
        std::vector<int32_t>        n_seq_id;
        std::vector<llama_seq_id *> seq_id; // these point into the seq_id_data below
        std::vector<llama_seq_id>   seq_id_unq;
        std::vector<int32_t>        seq_idx;
        std::vector<int8_t>         output;

        std::vector<llama_seq_id>   seq_id_data;
    };

    // if set, the llama_ubatch pointers above point into this data; otherwise they point to external, non-owning data
    std::shared_ptr<data_t> data;
};
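
// illustrative sketch (not part of the header itself): reading the per-token data of a
// ubatch `ub` produced by one of the split_* methods of llama_batch_allocr below
//
//     for (uint32_t i = 0; i < ub.n_tokens; ++i) {
//         // when ub.token is set, the ubatch carries token ids; otherwise ub.embd carries embeddings
//         const llama_token id = ub.token[i];
//
//         // the sequences that token i belongs to
//         for (int32_t s = 0; s < ub.n_seq_id[i]; ++s) {
//             const llama_seq_id sid = ub.seq_id[i][s];
//             (void) sid;
//         }
//
//         // whether output (logits/embeddings) is requested for token i
//         const bool out = ub.output[i] != 0;
//         (void) id; (void) out;
//     }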

// a helper for sanitizing, filling in missing data and splitting a batch
class llama_batch_allocr {
public:
    llama_batch_allocr(uint32_t n_pos_per_embd);

    // sanitize and auto-generate missing data in the input batch
    // memory is optional; if provided, it is used to check for sequence continuity and to determine the positions
    bool init(
            const llama_batch & batch_inp,
            const llama_vocab & vocab,
            const llama_memory_i * memory,
            uint32_t n_embd,
            uint32_t n_seq_max,
            bool output_all);

    const llama_batch & get_batch() const;

    uint32_t get_n_tokens()  const;
    uint32_t get_n_outputs() const;
    uint32_t get_n_used()    const;

    // the array of output indices in the order they were encountered during the ubatch splitting
    std::vector<int32_t> & get_out_ids();

    // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
    llama_pos seq_pos_max(llama_seq_id seq_id) const;

    // call once before splitting the batch to reset the internal state
    void split_reset();

    // simple split, unknown number of sequence sets of unequal lengths
    llama_ubatch split_simple(uint32_t n_ubatch);

    // make ubatches of equal-length sequence sets
    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

    // sequence-set-wise split - each ubatch contains a single sequence set
    llama_ubatch split_seq(uint32_t n_ubatch);

    // a helper method for creating a well-defined ubatch of tokens
    // TODO: support embeddings if needed in the future
    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);

private:
    void clear();

    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);

    // for debugging, start with LLAMA_BATCH_DEBUG=2
    void ubatch_print(const llama_ubatch & ubatch, int debug);

    llama_batch batch;

    // only for debugging purposes
    const llama_vocab * vocab;

    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
    const uint32_t n_pos_per_embd;

    uint32_t n_embd;
    uint32_t n_seq_max;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id

    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<llama_seq_id>   seq_id_unq;
    std::vector<int32_t>        seq_idx;
    std::vector<int8_t>         output;

    using pos_set_t = std::set<llama_pos>;
    using seq_cpl_t = std::vector<bool>;

    // helper flag to quickly determine if there are any coupled sequences in the batch
    bool has_cpl = false;

    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

    using idx_vec_t = std::vector<int32_t>;
    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i

    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears

    // batch indices of the output
    std::vector<int32_t> out_ids;

    uint32_t n_used;

    // used[i] indicates if token i has already been used in a previous ubatch
    std::vector<bool> used;

    int debug;
};
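
// typical usage (illustrative sketch only, based on the declarations above):
//
//     llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);
//
//     if (!balloc.init(batch, vocab, memory, n_embd, n_seq_max, /*output_all =*/ false)) {
//         // the input batch failed validation
//     }
//
//     balloc.split_reset();
//
//     while (true) {
//         llama_ubatch ub = balloc.split_simple(n_ubatch);
//         if (ub.n_tokens == 0) {
//             break; // the entire batch has been consumed
//         }
//
//         // ... process ub ...
//     }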