Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-kv-cache-dsa.h
Line
Count
Source
1
#pragma once
2
3
#include "llama-kv-cache.h"
4
5
#include <vector>
6
7
//
8
// llama_kv_cache_dsa
9
//
10
11
// utilizes two instances of llama_kv_cache:
12
// - the first instance is for caching key tensors of the model,
13
// - the second instance is for caching lightning indexer key tensors
14
15
class llama_kv_cache_dsa : public llama_memory_i {
16
public:
17
    llama_kv_cache_dsa(
18
            const llama_model & model,
19
                    ggml_type   type_k,
20
                    ggml_type   type_v,
21
                         bool   v_trans,
22
                         bool   offload,
23
                         bool   unified,
24
                     uint32_t   kv_size,
25
                     uint32_t   n_seq_max,
26
                     uint32_t   n_pad,
27
                     uint32_t   n_swa,
28
               llama_swa_type   swa_type,
29
        const layer_filter_cb & filter,
30
        const  layer_reuse_cb & reuse);
31
32
0
    ~llama_kv_cache_dsa() = default;
33
34
    //
35
    // llama_memory_i
36
    //
37
38
    llama_memory_context_ptr init_batch(
39
            llama_batch_allocr & balloc,
40
            uint32_t n_ubatch,
41
            bool embd_all) override;
42
43
    llama_memory_context_ptr init_full() override;
44
45
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
46
47
    bool get_can_shift() const override;
48
49
    void clear(bool data) override;
50
51
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
52
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
53
    void seq_keep(llama_seq_id seq_id)                                                          override;
54
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
55
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
56
57
    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
58
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
59
60
    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
61
62
    // state write/load
63
64
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
65
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
66
67
    //
68
    // llama_kv_cache_dsa specific API
69
    //
70
71
    llama_kv_cache * get_mla() const;
72
    llama_kv_cache * get_lid() const;
73
74
private:
75
    // we keep indexer KV cache hparams instance here as llama_kv_cache stores only reference to it
76
    llama_hparams hparams_lid;
77
    const uint32_t n_stream  = 1;
78
79
    std::unique_ptr<llama_kv_cache> kv_mla;
80
    std::unique_ptr<llama_kv_cache> kv_lid;
81
};
82
83
class llama_kv_cache_dsa_context : public llama_memory_context_i {
84
public:
85
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
86
87
    // used for errors
88
    llama_kv_cache_dsa_context(llama_memory_status status);
89
90
    // used to create a full-cache context
91
    llama_kv_cache_dsa_context(
92
            llama_kv_cache_dsa * kv);
93
94
    // used to create an update context
95
    llama_kv_cache_dsa_context(
96
            llama_kv_cache_dsa * kv,
97
            llama_context * lctx,
98
            bool optimize);
99
100
    // used to create a batch processing context from a batch
101
    llama_kv_cache_dsa_context(
102
            llama_kv_cache_dsa * kv,
103
            slot_info_vec_t sinfos_base,
104
            slot_info_vec_t sinfos_ik,
105
            std::vector<llama_ubatch> ubatches);
106
107
    virtual ~llama_kv_cache_dsa_context();
108
109
    //
110
    // llama_memory_context_i
111
    //
112
113
    bool next()  override;
114
    bool apply() override;
115
116
    llama_memory_status  get_status() const override;
117
    const llama_ubatch & get_ubatch() const override;
118
119
    //
120
    // llama_kv_cache_dsa_context specific API
121
    //
122
123
    const llama_kv_cache_context * get_mla() const;
124
    const llama_kv_cache_context * get_lid()  const;
125
126
private:
127
    //llama_kv_cache_dsa * kv;
128
129
    // the index of the next ubatch to process
130
    size_t i_next = 0;
131
132
    std::vector<llama_ubatch> ubatches;
133
134
    const llama_memory_context_ptr ctx_mla;
135
    const llama_memory_context_ptr ctx_lid;
136
137
    const llama_memory_status status;
138
};