/src/llama.cpp/src/llama-kv-cache.h
Line | Count | Source |
1 | | #pragma once |
2 | | |
3 | | #include "llama-batch.h" |
4 | | #include "llama-graph.h" |
5 | | #include "llama-kv-cells.h" |
6 | | #include "llama-memory.h" |
7 | | |
8 | | #include <unordered_map> |
9 | | #include <vector> |
10 | | |
11 | | struct llama_cparams; |
12 | | struct llama_hparams; |
13 | | struct llama_model; |
14 | | struct llama_context; |
15 | | |
16 | | // |
17 | | // llama_kv_cache |
18 | | // |
19 | | |
20 | | class llama_kv_cache : public llama_memory_i { |
21 | | public: |
// list of stream copies, stored as two parallel arrays of stream ids
// presumably entry i describes a copy from stream ssrc[i] to stream sdst[i] - TODO confirm against update()
struct stream_copy_info {
    // true when no copies are recorded
    // note: the two arrays are expected to always have the same length (checked in debug builds only)
    bool empty() const {
        assert(sdst.size() == ssrc.size());

        return ssrc.size() == 0;
    }

    std::vector<uint32_t> ssrc;
    std::vector<uint32_t> sdst;
};
31 | | |
32 | | // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the |
33 | | // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] |
34 | | struct slot_info { |
35 | | // data for ggml_set_rows |
36 | | using idx_vec_t = std::vector<uint32_t>; |
37 | | |
38 | | // number of streams: ns = s1 - s0 + 1 |
39 | | uint32_t s0; |
40 | | uint32_t s1; |
41 | | |
42 | | std::vector<llama_seq_id> strm; // [ns] |
43 | | std::vector<idx_vec_t> idxs; // [ns] |
44 | | |
45 | 0 | uint32_t head() const { |
46 | 0 | GGML_ASSERT(idxs.size() == 1); |
47 | 0 | GGML_ASSERT(!idxs[0].empty()); |
48 | |
|
49 | 0 | return idxs[0][0]; |
50 | 0 | } |
51 | | |
52 | 0 | void resize(size_t n) { |
53 | 0 | strm.resize(n); |
54 | 0 | idxs.resize(n); |
55 | 0 | } |
56 | | |
57 | 0 | size_t size() const { |
58 | 0 | GGML_ASSERT(idxs.size() == strm.size()); |
59 | 0 | GGML_ASSERT(!idxs.empty()); |
60 | |
|
61 | 0 | return idxs[0].size(); |
62 | 0 | } |
63 | | |
64 | 0 | size_t n_stream() const { |
65 | 0 | return strm.size(); |
66 | 0 | } |
67 | | |
68 | 0 | bool empty() const { |
69 | 0 | return idxs.empty(); |
70 | 0 | } |
71 | | |
72 | 0 | void clear() { |
73 | 0 | idxs.clear(); |
74 | 0 | } |
75 | | |
76 | | // check if indices are contiguous starting from head() |
77 | 0 | bool is_contiguous() const { |
78 | 0 | if (idxs.empty() || idxs[0].empty()) { |
79 | 0 | return true; |
80 | 0 | } |
81 | 0 | if (idxs.size() > 1) { |
82 | 0 | return false; |
83 | 0 | } |
84 | 0 | const uint32_t h = idxs[0][0]; |
85 | 0 | for (size_t i = 0; i < idxs[0].size(); ++i) { |
86 | 0 | if (idxs[0][i] != h + i) { |
87 | 0 | return false; |
88 | 0 | } |
89 | 0 | } |
90 | 0 | return true; |
91 | 0 | } |
92 | | }; |
93 | | |
94 | | using slot_info_vec_t = std::vector<slot_info>; |
95 | | |
96 | | // TODO: refactor the memory instances to not depend on `llama_model` |
97 | | // instead pass all necessary info (e.g. hparams, dev layers, arch, etc.) directly |
98 | | // likely through `struct llama_memory_params` |
99 | | llama_kv_cache( |
100 | | const llama_model & model, |
101 | | const llama_hparams & hparams, |
102 | | ggml_type type_k, |
103 | | ggml_type type_v, |
104 | | bool v_trans, |
105 | | bool offload, |
106 | | bool unified, |
107 | | uint32_t kv_size, |
108 | | uint32_t n_seq_max, |
109 | | uint32_t n_pad, |
110 | | uint32_t n_swa, |
111 | | llama_swa_type swa_type, |
112 | | llama_memory_t mem_other, |
113 | | const layer_filter_cb & filter, |
114 | | const layer_reuse_cb & reuse, |
115 | | const layer_share_cb & share); |
116 | | |
117 | 0 | ~llama_kv_cache() = default; |
118 | | |
119 | | // |
120 | | // llama_memory_i |
121 | | // |
122 | | |
123 | | llama_memory_context_ptr init_batch( |
124 | | llama_batch_allocr & balloc, |
125 | | uint32_t n_ubatch, |
126 | | bool embd_all) override; |
127 | | |
128 | | llama_memory_context_ptr init_full() override; |
129 | | |
130 | | llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; |
131 | | |
132 | | bool get_can_shift() const override; |
133 | | |
134 | | void clear(bool data) override; |
135 | | |
136 | | bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; |
137 | | void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; |
138 | | void seq_keep(llama_seq_id seq_id) override; |
139 | | void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; |
140 | | void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; |
141 | | |
142 | | llama_pos seq_pos_min(llama_seq_id seq_id) const override; |
143 | | llama_pos seq_pos_max(llama_seq_id seq_id) const override; |
144 | | |
145 | | std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; |
146 | | |
147 | | // state write/load |
148 | | |
149 | | void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; |
150 | | void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; |
151 | | |
152 | | // |
153 | | // llama_kv_cache specific API |
154 | | // |
155 | | |
156 | | uint32_t get_size() const; |
157 | | uint32_t get_n_stream() const; |
158 | | |
159 | | bool get_has_shift() const; |
160 | | |
161 | | ggml_type type_k() const; |
162 | | ggml_type type_v() const; |
163 | | |
164 | | // |
165 | | // graph_build API |
166 | | // |
167 | | |
168 | | uint32_t get_n_kv(const slot_info & sinfo) const; |
169 | | |
170 | | // get views of the current state of the cache |
171 | | ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; |
172 | | ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; |
173 | | |
174 | | // store k_cur and v_cur in the cache based on the provided head location |
175 | | ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; |
176 | | ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const; |
177 | | |
178 | | // |
179 | | // preparation API |
180 | | // |
181 | | |
182 | | // find places for the provided ubatches in the cache, returns the slot infos |
183 | | // return empty vector on failure |
184 | | slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches); |
185 | | |
186 | | bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info); |
187 | | |
188 | | // find a slot of kv cells that can hold the ubatch |
189 | | // if cont == true, then the slot must be contiguous |
190 | | // return empty slot_info on failure |
191 | | slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; |
192 | | |
193 | | // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] |
194 | | void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); |
195 | | |
196 | | // |
197 | | // input API |
198 | | // |
199 | | |
200 | | ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; |
201 | | ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; |
202 | | |
203 | | ggml_tensor * build_input_k_rot(ggml_context * ctx) const; |
204 | | ggml_tensor * build_input_v_rot(ggml_context * ctx) const; |
205 | | |
206 | | void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; |
207 | | void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; |
208 | | |
209 | | void set_input_k_shift(ggml_tensor * dst) const; |
210 | | |
211 | | void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; |
212 | | void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; |
213 | | |
214 | | void set_input_k_rot(ggml_tensor * dst) const; |
215 | | void set_input_v_rot(ggml_tensor * dst) const; |
216 | | |
217 | | private: |
218 | | const llama_model & model; |
219 | | const llama_hparams & hparams; |
220 | | |
// the K and V tensors that the cache keeps for a single layer
struct kv_layer {
    // layer index in the model
    // note: can be different from the layer index in the KV cache
    uint32_t il;

    // K and V tensors of this layer
    ggml_tensor * k;
    ggml_tensor * v;

    // per-stream tensors - presumably one view into k / v for each stream (TODO confirm against the constructor)
    std::vector<ggml_tensor *> k_stream;
    std::vector<ggml_tensor *> v_stream;
};
232 | | |
233 | | bool v_trans = true; // the value tensor is transposed |
234 | | |
235 | | const uint32_t n_seq_max = 1; |
236 | | const uint32_t n_stream = 1; |
237 | | |
238 | | // required padding |
239 | | const uint32_t n_pad = 1; |
240 | | |
241 | | // SWA |
242 | | const uint32_t n_swa = 0; |
243 | | |
244 | | // env: LLAMA_ATTN_ROT_DISABLE |
245 | | bool attn_rot_k = false; |
246 | | bool attn_rot_v = false; |
247 | | |
248 | | // if all layers participating in the cache have constant head size, the value is stored here |
249 | | // otherwise the value is -1 |
250 | | int32_t n_embd_head_k_all = 0; |
251 | | int32_t n_embd_head_v_all = 0; |
252 | | |
253 | | // pre-computed Hadamard matrices |
254 | | std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard; |
255 | | |
256 | | // env: LLAMA_KV_CACHE_DEBUG |
257 | | int debug = 0; |
258 | | |
259 | | // this is the SWA type of the cache - not to be confused with the model SWA type |
260 | | const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; |
261 | | |
262 | | // ggml contexts for the KV cache along with the allocated backend buffers: |
263 | | std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs; |
264 | | |
265 | | // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) |
266 | | // note: this is not part of the KV state and it's only used to speed-up the find_slot() method |
267 | | std::vector<uint32_t> v_heads; |
268 | | |
269 | | // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] |
270 | | llama_kv_cache * other; |
271 | | |
272 | | std::shared_ptr<llama_kv_cells_vec> v_cells_impl; |
273 | | |
274 | | llama_kv_cells_vec & v_cells; |
275 | | |
276 | | // maps from a sequence id to a stream id |
277 | | std::vector<uint32_t> seq_to_stream; |
278 | | |
279 | | // pending stream copies that will be applied during the next update |
280 | | stream_copy_info sc_info; |
281 | | |
282 | | std::vector<kv_layer> layers; |
283 | | |
284 | | // model layer id -> KV cache layer id |
285 | | std::unordered_map<int32_t, int32_t> map_layer_ids; |
286 | | |
287 | | size_t total_size() const; |
288 | | |
289 | | size_t size_k_bytes() const; |
290 | | size_t size_v_bytes() const; |
291 | | |
292 | | ggml_tensor * build_rope_shift( |
293 | | const llama_cparams & cparams, |
294 | | ggml_context * ctx, |
295 | | ggml_tensor * cur, |
296 | | ggml_tensor * shift, |
297 | | ggml_tensor * rot, |
298 | | ggml_tensor * factors, |
299 | | float freq_base, |
300 | | float freq_scale, |
301 | | uint32_t il) const; |
302 | | |
303 | | ggml_cgraph * build_graph_shift( |
304 | | llm_graph_result * res, |
305 | | llama_context * lctx) const; |
306 | | |
// a list of cell ranges together with a stream id
// consumed by state_write_meta() and state_write_data()
struct cell_ranges_t {
    // stream id - presumably the stream that the ranges in `data` refer to (TODO confirm against state_write())
    uint32_t strm;

    std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
};
312 | | |
313 | | void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; |
314 | | void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; |
315 | | |
316 | | bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1); |
317 | | bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo); |
318 | | }; |
319 | | |
320 | | class llama_kv_cache_context : public llama_memory_context_i { |
321 | | public: |
322 | | // some shorthands |
323 | | using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; |
324 | | using stream_copy_info = llama_kv_cache::stream_copy_info; |
325 | | |
326 | | // used for errors |
327 | | llama_kv_cache_context(llama_memory_status status); |
328 | | |
329 | | // used to create a full-cache context |
330 | | llama_kv_cache_context( |
331 | | llama_kv_cache * kv); |
332 | | |
333 | | // used to create an update context |
334 | | llama_kv_cache_context( |
335 | | llama_kv_cache * kv, |
336 | | llama_context * lctx, |
337 | | bool do_shift, |
338 | | stream_copy_info sc_info); |
339 | | |
340 | | // used to create a batch processing context from a batch |
341 | | llama_kv_cache_context( |
342 | | llama_kv_cache * kv, |
343 | | slot_info_vec_t sinfos, |
344 | | std::vector<llama_ubatch> ubatches); |
345 | | |
346 | | virtual ~llama_kv_cache_context(); |
347 | | |
348 | | // |
349 | | // llama_memory_context_i |
350 | | // |
351 | | |
352 | | bool next() override; |
353 | | bool apply() override; |
354 | | |
355 | | llama_memory_status get_status() const override; |
356 | | const llama_ubatch & get_ubatch() const override; |
357 | | |
358 | | // |
359 | | // llama_kv_cache_context specific API |
360 | | // |
361 | | |
362 | | uint32_t get_n_kv() const; |
363 | | |
364 | | ggml_type type_k() const; |
365 | | ggml_type type_v() const; |
366 | | |
367 | | // get views of the current state of the cache |
368 | | ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; |
369 | | ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; |
370 | | |
371 | | // store k_cur and v_cur in the cache based on the provided head location |
372 | | // note: the heads in k_cur and v_cur should be laid out contiguously in memory |
373 | | // - k_cur [n_embd_head_k, n_head_k, n_tokens] |
374 | | // - k_idxs [n_tokens] |
375 | | // - v_cur [n_embd_head_v, n_head_v, n_tokens] |
376 | | // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed |
377 | | ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; |
378 | | ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; |
379 | | |
380 | | // create destination indices for each head of the current batch for where it would be written in the KV cache |
381 | | // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but |
382 | | // helps understand the implementation logic of cpy_k and cpy_v |
383 | | ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; |
384 | | ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; |
385 | | |
386 | | ggml_tensor * build_input_k_rot(ggml_context * ctx) const; |
387 | | ggml_tensor * build_input_v_rot(ggml_context * ctx) const; |
388 | | |
389 | | void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; |
390 | | void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; |
391 | | |
392 | | void set_input_k_shift (ggml_tensor * dst) const; |
393 | | void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; |
394 | | void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; |
395 | | |
396 | | void set_input_k_rot(ggml_tensor * dst) const; |
397 | | void set_input_v_rot(ggml_tensor * dst) const; |
398 | | |
399 | | private: |
400 | | llama_memory_status status; |
401 | | |
402 | | llama_kv_cache * kv; |
403 | | llama_context * lctx; |
404 | | |
405 | | // |
406 | | // update context |
407 | | // |
408 | | |
409 | | bool do_shift = false; |
410 | | |
411 | | stream_copy_info sc_info; |
412 | | |
413 | | // |
414 | | // batch processing context |
415 | | // |
416 | | |
417 | | // the index of the cur ubatch to process |
418 | | size_t i_cur = 0; |
419 | | |
420 | | slot_info_vec_t sinfos; |
421 | | |
422 | | std::vector<llama_ubatch> ubatches; |
423 | | |
424 | | // |
425 | | // data needed for building the compute graph for the current ubatch: |
426 | | // |
427 | | |
428 | | // a heuristic, to avoid attending the full cache if it is not yet utilized |
429 | | // as the cache gets filled, the benefit from this heuristic disappears |
430 | | int32_t n_kv; |
431 | | }; |