Coverage Report

Created: 2026-06-13 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-memory-recurrent.h
Line
Count
Source
1
#pragma once
2
3
#include "llama-batch.h"
4
#include "llama-graph.h"
5
#include "llama-memory.h"
6
7
#include <map>
8
#include <set>
9
#include <vector>
10
11
//
12
// llama_memory_recurrent
13
//
14
15
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
16
//       see the implementation of llama_kv_cache_context_i for an example how to do it
17
class llama_memory_recurrent : public llama_memory_i {
18
public:
19
    llama_memory_recurrent(
20
            const llama_model & model,
21
                    ggml_type   type_r,
22
                    ggml_type   type_s,
23
                         bool   offload,
24
                     uint32_t   mem_size,
25
                     uint32_t   n_seq_max,
26
                     uint32_t   n_rs_seq,
27
        const layer_filter_cb & filter);
28
29
0
    ~llama_memory_recurrent() = default;
30
31
    //
32
    // llama_memory_i
33
    //
34
35
    llama_memory_context_ptr init_batch(
36
            llama_batch_allocr & balloc,
37
            uint32_t n_ubatch,
38
            bool embd_all) override;
39
40
    llama_memory_context_ptr init_full() override;
41
42
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
43
44
    void clear(bool data) override;
45
46
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
47
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
48
    void seq_keep(llama_seq_id seq_id)                                                          override;
49
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
50
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
51
52
    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
53
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
54
55
    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
56
57
    bool prepare(const std::vector<llama_ubatch> & ubatches);
58
59
    // find a contiguous slot of memory cells and emplace the ubatch there
60
    bool find_slot(const llama_ubatch & ubatch);
61
62
    bool get_can_shift() const override;
63
64
    // state write/load
65
66
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
67
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
68
69
    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
70
    uint32_t size = 0; // total number of cells, shared across all sequences
71
    uint32_t used = 0; // used cells (i.e. at least one seq_id)
72
73
    // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
74
    uint32_t n_rs_seq = 0;
75
76
    // per-seq rollback index
77
    std::vector<uint32_t> rs_idx;
78
79
    void set_rs_idx(llama_seq_id seq_id, uint32_t idx);
80
81
    // computed before each graph build
82
    uint32_t n = 0;
83
84
    // first zero-ed state
85
    int32_t rs_z = -1;
86
87
    // TODO: optimize for recurrent state needs
88
    struct mem_cell {
89
        llama_pos pos  = -1;
90
        int32_t   src  = -1; // used to know where states should be copied from
91
        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
92
        int32_t   tail = -1;
93
94
        std::set<llama_seq_id> seq_id;
95
96
0
        bool has_seq_id(const llama_seq_id & id) const {
97
0
            return seq_id.find(id) != seq_id.end();
98
0
        }
99
100
0
        bool is_empty() const {
101
0
            return seq_id.empty();
102
0
        }
103
104
0
        bool is_same_seq(const mem_cell & other) const {
105
0
            return seq_id == other.seq_id;
106
0
        }
107
    };
108
109
    std::vector<mem_cell> cells;
110
111
    // per layer
112
    std::vector<ggml_tensor *> r_l;
113
    std::vector<ggml_tensor *> s_l;
114
115
private:
116
    //const llama_model & model;
117
    const llama_hparams & hparams;
118
119
    const uint32_t n_seq_max = 1;
120
121
    // ggml contexts for the KV cache along with the allocated backend buffers:
122
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
123
124
    size_t total_size() const;
125
126
    size_t size_r_bytes() const;
127
    size_t size_s_bytes() const;
128
129
    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
130
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
131
132
    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
133
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
134
};
135
136
class llama_memory_recurrent_context : public llama_memory_context_i {
137
public:
138
    // used for errors
139
    llama_memory_recurrent_context(llama_memory_status status);
140
141
    // used to create a full-cache or update context
142
    llama_memory_recurrent_context(
143
            llama_memory_recurrent * mem);
144
145
    // used to create a batch processing context from a batch
146
    llama_memory_recurrent_context(
147
            llama_memory_recurrent * mem,
148
            std::vector<llama_ubatch> ubatches);
149
150
    virtual ~llama_memory_recurrent_context();
151
152
    //
153
    // llama_memory_context_i
154
    //
155
156
    bool next()  override;
157
    bool apply() override;
158
159
    llama_memory_status  get_status() const override;
160
    const llama_ubatch & get_ubatch() const override;
161
162
    //
163
    // llama_memory_recurrent_context specific API
164
    //
165
166
    uint32_t get_n_rs() const;
167
    uint32_t get_head() const;
168
    int32_t  get_rs_z() const;
169
    uint32_t get_size() const;
170
171
    ggml_tensor * get_r_l(int32_t il) const;
172
    ggml_tensor * get_s_l(int32_t il) const;
173
174
    int32_t s_copy(int i) const;
175
176
private:
177
    const llama_memory_status status;
178
179
    llama_memory_recurrent * mem;
180
181
    size_t i_next = 0;
182
183
    std::vector<llama_ubatch> ubatches;
184
185
    //
186
    // data needed for building the compute graph for the current ubatch:
187
    // TODO: extract all the state like `head` and `n` here
188
    //
189
190
    const bool is_full = false;
191
};