/src/llama.cpp/src/llama-kv-cache.h
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <cassert>
#include <map>
#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;

//
// llama_kv_cache
//

class llama_kv_cache : public llama_memory_i {
public:
    struct stream_copy_info {
        bool empty() const {
            assert(ssrc.size() == sdst.size());
            return ssrc.empty();
        }

        std::vector<uint32_t> ssrc;
        std::vector<uint32_t> sdst;
    };
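
    // illustrative sketch (not from the original source): a pending copy of
    // stream 0 into stream 2 is recorded as parallel entries in ssrc/sdst:
    //
    //     stream_copy_info sc;
    //     sc.ssrc.push_back(0); // copy from stream 0 ...
    //     sc.sdst.push_back(2); // ... into stream 2
    //     assert(!sc.empty());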

    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
    struct slot_info {
        // data for ggml_set_rows
        using idx_vec_t = std::vector<uint32_t>;

        // number of streams: ns = s1 - s0 + 1
        uint32_t s0;
        uint32_t s1;

        std::vector<llama_seq_id> strm; // [ns]
        std::vector<idx_vec_t>    idxs; // [ns]

        uint32_t head() const {
            GGML_ASSERT(idxs.size() == 1);
            GGML_ASSERT(!idxs[0].empty());

            return idxs[0][0];
        }

        void resize(size_t n) {
            strm.resize(n);
            idxs.resize(n);
        }

        size_t size() const {
            GGML_ASSERT(idxs.size() == strm.size());
            GGML_ASSERT(!idxs.empty());

            return idxs[0].size();
        }

        size_t n_stream() const {
            return strm.size();
        }

        bool empty() const {
            return idxs.empty();
        }

        void clear() {
            idxs.clear();
        }
    };
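
    // illustrative sketch (not from the original source): a single-stream
    // ubatch of 3 tokens placed at cells 5, 6 and 9 would be described as:
    //
    //     slot_info sinfo;
    //     sinfo.s0   = 0;
    //     sinfo.s1   = 0;               // ns = s1 - s0 + 1 = 1
    //     sinfo.strm = { 0 };
    //     sinfo.idxs = { { 5, 6, 9 } }; // token[i] -> cells[idxs[0][i]]
    //
    //     // sinfo.head() == 5, sinfo.size() == 3, sinfo.n_stream() == 1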

    using slot_info_vec_t = std::vector<slot_info>;

    llama_kv_cache(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            bool offload,
            bool unified,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_pad,
            uint32_t n_swa,
            llama_swa_type swa_type,
            const layer_filter_cb & filter,
            const layer_reuse_cb & reuse);

    ~llama_kv_cache() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache specific API
    //

    uint32_t get_size() const;
    uint32_t get_n_stream() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv(const slot_info & sinfo) const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache; returns the slot infos
    // returns an empty vector on failure
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
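
    // illustrative sketch (assumed call site, not part of this header): apply a
    // pending K-shift together with any queued stream copies:
    //
    //     if (kv.get_has_shift()) {
    //         kv.update(lctx, /*do_shift =*/ true, sc_info);
    //     }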

    // find a slot of kv cells that can hold the ubatch
    // if cont == true, then the slot must be contiguous
    // returns an empty slot_info on failure
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
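
    // illustrative sketch (assumed flow, error handling omitted): place a single
    // ubatch without requiring a contiguous slot:
    //
    //     slot_info sinfo = kv.find_slot(ubatch, /*cont =*/ false);
    //     if (!sinfo.empty()) {
    //         kv.apply_ubatch(sinfo, ubatch); // reserve the cells for this ubatch
    //     }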

    //
    // input API
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_model & model;
    const llama_hparams & hparams;

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;

        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
    };

    bool v_trans = true; // the value tensor is transposed

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream  = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

    // this is the SWA type of the cache - not to be confused with the model SWA type
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts for the KV cache along with the allocated backend buffers:
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed up the find_slot() method
    std::vector<uint32_t> v_heads;

    std::vector<llama_kv_cells> v_cells;

    // maps from a sequence id to a stream id
    std::vector<uint32_t> seq_to_stream;

    // pending stream copies that will be applied during the next update
    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_tensor * cur,
            ggml_tensor * shift,
            ggml_tensor * factors,
            float freq_base,
            float freq_scale) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
            llama_context * lctx) const;

    struct cell_ranges_t {
        uint32_t strm;

        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
    };

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};

class llama_kv_cache_context : public llama_memory_context_i {
public:
    // some shorthands
    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
    using stream_copy_info = llama_kv_cache::stream_copy_info;

    // used for errors
    llama_kv_cache_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_context(
            llama_kv_cache * kv);

    // used to create an update context
    llama_kv_cache_context(
            llama_kv_cache * kv,
            llama_context * lctx,
            bool do_shift,
            stream_copy_info sc_info);

    // used to create a batch processing context from a batch
    llama_kv_cache_context(
            llama_kv_cache * kv,
            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_context();

    //
    // llama_memory_context_i
    //

    bool next() override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_context specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
    // - k_cur  [n_embd_head_k, n_head_k, n_tokens]
    // - k_idxs [n_tokens]
    // - v_cur  [n_embd_head_v, n_head_v, n_tokens]
    // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa], depending on whether the V cache is transposed
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

    // create destination indices for each head of the current batch for where it would be written in the KV cache
    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
    // helps understand the implementation logic of cpy_k and cpy_v
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
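
    // illustrative sketch (assumed graph-build flow; mctx is a context obtained
    // from init_batch(), while ctx, gf, ubatch, il, k_cur and v_cur come from
    // the caller's graph-construction code):
    //
    //     ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx, ubatch);
    //     ggml_tensor * v_idxs = mctx->build_input_v_idxs(ctx, ubatch);
    //
    //     ggml_build_forward_expand(gf, mctx->cpy_k(ctx, k_cur, k_idxs, il));
    //     ggml_build_forward_expand(gf, mctx->cpy_v(ctx, v_cur, v_idxs, il));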

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift (ggml_tensor * dst) const;
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    llama_memory_status status;

    llama_kv_cache * kv;
    llama_context * lctx;

    //
    // update context
    //

    bool do_shift = false;

    stream_copy_info sc_info;

    //
    // batch processing context
    //

    // the index of the cur ubatch to process
    size_t i_cur = 0;

    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    int32_t n_kv;
};