Coverage Report

Created: 2025-11-24 06:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-kv-cache.h
Line
Count
Source
1
#pragma once
2
3
#include "llama-batch.h"
4
#include "llama-graph.h"
5
#include "llama-kv-cells.h"
6
#include "llama-memory.h"
7
8
#include <unordered_map>
9
#include <vector>
10
11
struct llama_cparams;
12
struct llama_hparams;
13
struct llama_model;
14
struct llama_context;
15
16
//
17
// llama_kv_cache
18
//
19
20
class llama_kv_cache : public llama_memory_i {
21
public:
22
    struct stream_copy_info {
23
0
        bool empty() const {
24
0
            assert(ssrc.size() == sdst.size());
25
0
            return ssrc.empty();
26
0
        }
27
28
        std::vector<uint32_t> ssrc;
29
        std::vector<uint32_t> sdst;
30
    };
31
32
    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
33
    //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
34
    struct slot_info {
35
        // data for ggml_set_rows
36
        using idx_vec_t = std::vector<uint32_t>;
37
38
        // number of streams: ns = s1 - s0 + 1
39
        uint32_t s0;
40
        uint32_t s1;
41
42
        std::vector<llama_seq_id> strm; // [ns]
43
        std::vector<idx_vec_t>    idxs; // [ns]
44
45
0
        uint32_t head() const {
46
0
            GGML_ASSERT(idxs.size() == 1);
47
0
            GGML_ASSERT(!idxs[0].empty());
48
49
0
            return idxs[0][0];
50
0
        }
51
52
0
        void resize(size_t n) {
53
0
            strm.resize(n);
54
0
            idxs.resize(n);
55
0
        }
56
57
0
        size_t size() const {
58
0
            GGML_ASSERT(idxs.size() == strm.size());
59
0
            GGML_ASSERT(!idxs.empty());
60
61
0
            return idxs[0].size();
62
0
        }
63
64
0
        size_t n_stream() const {
65
0
            return strm.size();
66
0
        }
67
68
0
        bool empty() const {
69
0
            return idxs.empty();
70
0
        }
71
72
0
        void clear() {
73
0
            idxs.clear();
74
0
        }
75
    };
76
77
    using slot_info_vec_t = std::vector<slot_info>;
78
79
    llama_kv_cache(
80
            const llama_model & model,
81
                    ggml_type   type_k,
82
                    ggml_type   type_v,
83
                         bool   v_trans,
84
                         bool   offload,
85
                         bool   unified,
86
                     uint32_t   kv_size,
87
                     uint32_t   n_seq_max,
88
                     uint32_t   n_pad,
89
                     uint32_t   n_swa,
90
               llama_swa_type   swa_type,
91
        const layer_filter_cb & filter,
92
        const  layer_reuse_cb & reuse);
93
94
0
    ~llama_kv_cache() = default;
95
96
    //
97
    // llama_memory_i
98
    //
99
100
    llama_memory_context_ptr init_batch(
101
            llama_batch_allocr & balloc,
102
            uint32_t n_ubatch,
103
            bool embd_all) override;
104
105
    llama_memory_context_ptr init_full() override;
106
107
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
108
109
    bool get_can_shift() const override;
110
111
    void clear(bool data) override;
112
113
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
114
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
115
    void seq_keep(llama_seq_id seq_id)                                                          override;
116
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
117
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
118
119
    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
120
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
121
122
    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
123
124
    // state write/load
125
126
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
127
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
128
129
    //
130
    // llama_kv_cache specific API
131
    //
132
133
    uint32_t get_size()     const;
134
    uint32_t get_n_stream() const;
135
136
    bool get_has_shift() const;
137
138
    //
139
    // graph_build API
140
    //
141
142
    uint32_t get_n_kv(const slot_info & sinfo) const;
143
144
    // get views of the current state of the cache
145
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
146
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
147
148
    // store k_cur and v_cur in the cache based on the provided head location
149
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
150
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
151
152
    //
153
    // preparation API
154
    //
155
156
    // find places for the provided ubatches in the cache, returns the slot infos
157
    // return empty vector on failure
158
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
159
160
    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
161
162
    // find a slot of kv cells that can hold the ubatch
163
    // if cont == true, then the slot must be continuous
164
    // return empty slot_info on failure
165
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
166
167
    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
168
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
169
170
    //
171
    // input API
172
    //
173
174
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
175
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
176
177
    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
178
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
179
180
    void set_input_k_shift(ggml_tensor * dst) const;
181
182
    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
183
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
184
185
private:
186
    const llama_model & model;
187
    const llama_hparams & hparams;
188
189
    struct kv_layer {
190
        // layer index in the model
191
        // note: can be different from the layer index in the KV cache
192
        uint32_t il;
193
194
        ggml_tensor * k;
195
        ggml_tensor * v;
196
197
        std::vector<ggml_tensor *> k_stream;
198
        std::vector<ggml_tensor *> v_stream;
199
    };
200
201
    bool v_trans = true;  // the value tensor is transposed
202
203
    const uint32_t n_seq_max = 1;
204
    const uint32_t n_stream  = 1;
205
206
    // required padding
207
    const uint32_t n_pad = 1;
208
209
    // SWA
210
    const uint32_t n_swa = 0;
211
212
    // env: LLAMA_KV_CACHE_DEBUG
213
    int debug = 0;
214
215
    // this is the SWA type of the cache - not to be confused with the model SWA type
216
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
217
218
    // ggml contexts for the KV cache along with the allocated backend buffers:
219
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
220
221
    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
222
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
223
    std::vector<uint32_t> v_heads;
224
225
    std::vector<llama_kv_cells> v_cells;
226
227
    // maps from a sequence id to a stream id
228
    std::vector<uint32_t> seq_to_stream;
229
230
    // pending stream copies that will be applied during the next update
231
    stream_copy_info sc_info;
232
233
    std::vector<kv_layer> layers;
234
235
    // model layer id -> KV cache layer id
236
    std::unordered_map<int32_t, int32_t> map_layer_ids;
237
238
    size_t total_size() const;
239
240
    size_t size_k_bytes() const;
241
    size_t size_v_bytes() const;
242
243
    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
244
245
    ggml_tensor * build_rope_shift(
246
            const llama_cparams & cparams,
247
                   ggml_context * ctx,
248
                    ggml_tensor * cur,
249
                    ggml_tensor * shift,
250
                    ggml_tensor * factors,
251
                          float   freq_base,
252
                          float   freq_scale) const;
253
254
    ggml_cgraph * build_graph_shift(
255
               llm_graph_result * res,
256
                  llama_context * lctx) const;
257
258
    struct cell_ranges_t {
259
        uint32_t strm;
260
261
        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
262
    };
263
264
    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
265
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
266
267
    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
268
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
269
};
270
271
class llama_kv_cache_context : public llama_memory_context_i {
272
public:
273
    // some shorthands
274
    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
275
    using stream_copy_info = llama_kv_cache::stream_copy_info;
276
277
    // used for errors
278
    llama_kv_cache_context(llama_memory_status status);
279
280
    // used to create a full-cache context
281
    llama_kv_cache_context(
282
            llama_kv_cache * kv);
283
284
    // used to create an update context
285
    llama_kv_cache_context(
286
            llama_kv_cache * kv,
287
            llama_context * lctx,
288
            bool do_shift,
289
            stream_copy_info sc_info);
290
291
    // used to create a batch procesing context from a batch
292
    llama_kv_cache_context(
293
            llama_kv_cache * kv,
294
            slot_info_vec_t sinfos,
295
            std::vector<llama_ubatch> ubatches);
296
297
    virtual ~llama_kv_cache_context();
298
299
    //
300
    // llama_memory_context_i
301
    //
302
303
    bool next()  override;
304
    bool apply() override;
305
306
    llama_memory_status  get_status() const override;
307
    const llama_ubatch & get_ubatch() const override;
308
309
    //
310
    // llama_kv_cache_context specific API
311
    //
312
313
    uint32_t get_n_kv() const;
314
315
    // get views of the current state of the cache
316
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
317
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
318
319
    // store k_cur and v_cur in the cache based on the provided head location
320
    // note: the heads in k_cur and v_cur should be layed out contiguously in memory
321
    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
322
    //   - k_idxs [n_tokens]
323
    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
324
    //   - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
325
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
326
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
327
328
    // create destination indices for each head of the current batch for where it would be written in the KV cache
329
    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
330
    //   helps understand the implementation logic of cpy_k and cpy_v
331
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
332
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
333
334
    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
335
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
336
337
    void set_input_k_shift   (ggml_tensor * dst) const;
338
    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
339
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
340
341
private:
342
    llama_memory_status status;
343
344
    llama_kv_cache * kv;
345
    llama_context * lctx;
346
347
    //
348
    // update context
349
    //
350
351
    bool do_shift = false;
352
353
    stream_copy_info sc_info;
354
355
    //
356
    // batch processing context
357
    //
358
359
    // the index of the cur ubatch to process
360
    size_t i_cur = 0;
361
362
    slot_info_vec_t sinfos;
363
364
    std::vector<llama_ubatch> ubatches;
365
366
    //
367
    // data needed for building the compute graph for the current ubatch:
368
    //
369
370
    // a heuristic, to avoid attending the full cache if it is not yet utilized
371
    // as the cache gets filled, the benefit from this heuristic disappears
372
    int32_t n_kv;
373
};