/src/llama.cpp/src/llama-kv-cache.h
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <cassert>
#include <map>
#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;

//
// llama_kv_cache
//

class llama_kv_cache : public llama_memory_i {
public:
    struct stream_copy_info {
        bool empty() const {
            assert(ssrc.size() == sdst.size());
            return ssrc.empty();
        }

        // pending copies: stream ssrc[i] is copied into stream sdst[i]
        std::vector<uint32_t> ssrc;
        std::vector<uint32_t> sdst;
    };

    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
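    //
    // a hypothetical example (values for illustration only): a ubatch of 4 tokens placed
    // into a single stream could produce:
    //
    //     slot_info sinfo;
    //     sinfo.s0   = 0;   sinfo.s1 = 0;    // one stream: ns = 1
    //     sinfo.strm = { 0 };
    //     sinfo.idxs = { { 5, 6, 7, 10 } };  // token[3] goes to cells[10] -> not contiguous
    //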
    struct slot_info {
        // data for ggml_set_rows
        using idx_vec_t = std::vector<uint32_t>;

        // number of streams: ns = s1 - s0 + 1
        uint32_t s0;
        uint32_t s1;

        std::vector<llama_seq_id> strm; // [ns]
        std::vector<idx_vec_t>    idxs; // [ns]

        uint32_t head() const {
            GGML_ASSERT(idxs.size() == 1);
            GGML_ASSERT(!idxs[0].empty());

            return idxs[0][0];
        }

        void resize(size_t n) {
            strm.resize(n);
            idxs.resize(n);
        }

        size_t size() const {
            GGML_ASSERT(idxs.size() == strm.size());
            GGML_ASSERT(!idxs.empty());

            return idxs[0].size();
        }

        size_t n_stream() const {
            return strm.size();
        }

        bool empty() const {
            return idxs.empty();
        }

        void clear() {
            idxs.clear();
        }

        // check if indices are contiguous starting from head()
        bool is_contiguous() const {
            if (idxs.empty() || idxs[0].empty()) {
                return true;
            }
            if (idxs.size() > 1) {
                return false;
            }
            const uint32_t h = idxs[0][0];
            for (size_t i = 0; i < idxs[0].size(); ++i) {
                if (idxs[0][i] != h + i) {
                    return false;
                }
            }
            return true;
        }
    };

    using slot_info_vec_t = std::vector<slot_info>;

    llama_kv_cache(
            const llama_model &     model,
            ggml_type               type_k,
            ggml_type               type_v,
            bool                    v_trans,
            bool                    offload,
            bool                    unified,
            uint32_t                kv_size,
            uint32_t                n_seq_max,
            uint32_t                n_pad,
            uint32_t                n_swa,
            llama_swa_type          swa_type,
            const layer_filter_cb & filter,
            const layer_reuse_cb  & reuse);

    ~llama_kv_cache() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t             n_ubatch,
            bool                 embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache specific API
    //

    uint32_t get_size()     const;
    uint32_t get_n_stream() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv(const slot_info & sinfo) const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache and return the slot infos
    // returns an empty vector on failure
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

    // find a slot of kv cells that can hold the ubatch
    // if cont == true, then the slot must be contiguous
    // returns an empty slot_info on failure
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
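
    // a minimal usage sketch (illustrative only - not a verbatim call site):
    //
    //     auto sinfos = kv.prepare(ubatches);
    //     if (!sinfos.empty()) {
    //         for (size_t i = 0; i < ubatches.size(); ++i) {
    //             kv.apply_ubatch(sinfos[i], ubatches[i]); // commit ubatch i into its slot
    //         }
    //     }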

    //
    // input API
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
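
    // typical two-phase usage (sketch): the idx tensors are created at graph-build time
    // and filled right before evaluation, once the slot info is known:
    //
    //     ggml_tensor * k_idxs = kv.build_input_k_idxs(ctx, ubatch); // graph build
    //     ggml_tensor * v_idxs = kv.build_input_v_idxs(ctx, ubatch);
    //     ...
    //     kv.set_input_k_idxs(k_idxs, &ubatch, sinfo);               // before eval
    //     kv.set_input_v_idxs(v_idxs, &ubatch, sinfo);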

private:
    const llama_model &   model;
    const llama_hparams & hparams;

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;

        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
    };

    bool v_trans = true; // the value tensor is transposed

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream  = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

    // this is the SWA type of the cache - not to be confused with the model SWA type
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts for the KV cache along with the allocated backend buffers:
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and is only used to speed up the find_slot() method
    std::vector<uint32_t> v_heads;

    std::vector<llama_kv_cells> v_cells;

    // maps from a sequence id to a stream id
    std::vector<uint32_t> seq_to_stream;

    // pending stream copies that will be applied during the next update
    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
            ggml_context        * ctx,
            ggml_tensor         * cur,
            ggml_tensor         * shift,
            ggml_tensor         * factors,
            float                 freq_base,
            float                 freq_scale) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
            llama_context    * lctx) const;

    struct cell_ranges_t {
        uint32_t strm;

        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
    };

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
};

class llama_kv_cache_context : public llama_memory_context_i {
public:
    // some shorthands
    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
    using stream_copy_info = llama_kv_cache::stream_copy_info;

    // used for errors
    llama_kv_cache_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_context(
            llama_kv_cache * kv);

    // used to create an update context
    llama_kv_cache_context(
            llama_kv_cache * kv,
            llama_context * lctx,
            bool do_shift,
            stream_copy_info sc_info);

    // used to create a batch processing context from a batch
    llama_kv_cache_context(
            llama_kv_cache * kv,
            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_context();

    //
    // llama_memory_context_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;
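
    // rough sketch of the intended iteration over the prepared ubatches (assuming the
    // usual llama_memory_context_i consumption pattern):
    //
    //     do {
    //         const llama_ubatch & ubatch = mctx->get_ubatch();
    //         mctx->apply(); // write the current ubatch into the cache
    //         // ... build and evaluate the graph for ubatch ...
    //     } while (mctx->next());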

    //
    // llama_kv_cache_context specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
    //   - k_idxs [n_tokens]
    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
    //   - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending on whether the V cache is transposed
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
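
    // for example (illustrative numbers): with n_tokens = 4 and n_embd_v_gqa = 1024,
    // v_idxs holds 4 indices when the V cache is not transposed (one row per token),
    // and 4*1024 indices when it is transposed, since each value component is then
    // written to a different row of the destination tensor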

    // create destination indices for each head of the current batch, indicating where it will be written in the KV cache
    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
    // helps understand the implementation logic of cpy_k and cpy_v
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift   (ggml_tensor * dst) const;
    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    llama_memory_status status;

    llama_kv_cache * kv;
    llama_context  * lctx;

    //
    // update context
    //

    bool do_shift = false;

    stream_copy_info sc_info;

    //
    // batch processing context
    //

    // the index of the current ubatch to process
    size_t i_cur = 0;

    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic to avoid attending to the full cache if it is not yet utilized
    // as the cache fills up, the benefit of this heuristic disappears
    int32_t n_kv;
};