Coverage Report

Created: 2025-12-28 06:25

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-memory-recurrent.h
Line
Count
Source
1
#pragma once
2
3
#include "llama-batch.h"
4
#include "llama-graph.h"
5
#include "llama-memory.h"
6
7
#include <map>
8
#include <set>
9
#include <vector>
10
11
//
12
// llama_memory_recurrent
13
//
14
15
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
16
//       see the implementation of llama_kv_cache_context_i for an example of how to do it
17
// Implementation of llama_memory_i that keeps a vector of `mem_cell` entries and
// per-layer `r_l` / `s_l` tensors. Only declarations are visible in this header;
// apart from the small inline helpers of `mem_cell`, the bodies are defined elsewhere.
class llama_memory_recurrent : public llama_memory_i {
18
public:
19
    // NOTE(review): parameter semantics are not visible in this header. Presumably
    // `type_r` / `type_s` select the ggml types of the `r_l` / `s_l` tensors,
    // `mem_size` the number of cells and `filter` the layers that get state —
    // TODO confirm against the constructor definition.
    llama_memory_recurrent(
20
            const llama_model & model,
21
                    ggml_type   type_r,
22
                    ggml_type   type_s,
23
                         bool   offload,
24
                     uint32_t   mem_size,
25
                     uint32_t   n_seq_max,
26
        const layer_filter_cb & filter);
27
28
0
    ~llama_memory_recurrent() = default;
29
30
    //
31
    // llama_memory_i
32
    //
33
34
    // Overrides of the llama_memory_i context factories; each returns a
    // llama_memory_context_ptr.
    llama_memory_context_ptr init_batch(
35
            llama_batch_allocr & balloc,
36
            uint32_t n_ubatch,
37
            bool embd_all) override;
38
39
    llama_memory_context_ptr init_full() override;
40
41
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
42
43
    void clear(bool data) override;
44
45
    // Per-sequence operations (all overrides of llama_memory_i).
    // NOTE(review): p0 / p1 presumably delimit a position range — confirm the
    // exact bounds convention against llama-memory.h.
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
46
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
47
    void seq_keep(llama_seq_id seq_id)                                                          override;
48
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
49
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
50
51
    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
52
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
53
54
    // Returns a map from backend buffer type to a size_t value
    // (presumably bytes allocated per buffer type — TODO confirm).
    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
55
56
    // Not an override: specific to this class. Takes the full list of ubatches and
    // reports success as a bool; the exact checks are defined outside this header.
    bool prepare(const std::vector<llama_ubatch> & ubatches);
57
58
    // find a contiguous slot of memory cells and emplace the ubatch there
59
    bool find_slot(const llama_ubatch & ubatch);
60
61
    bool get_can_shift() const override;
62
63
    // state write/load
64
65
    // Both default to seq_id = -1 and flags = 0.
    // NOTE(review): presumably seq_id == -1 means "all sequences" — confirm.
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
66
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
67
68
    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
69
    uint32_t size = 0; // total number of cells, shared across all sequences
70
    uint32_t used = 0; // used cells (i.e. at least one seq_id)
71
72
    // computed before each graph build
73
    uint32_t n = 0;
74
75
    // first zero-ed state
    // (initialized to -1; presumably -1 means "none" — TODO confirm)
76
    int32_t rs_z = -1;
77
78
    // TODO: optimize for recurrent state needs
    // One entry of `cells`: a position, copy-source indices and the set of
    // sequence ids attached to the cell. All integer fields start at -1.
79
    struct mem_cell {
80
        llama_pos pos  = -1;
81
        int32_t   src  = -1; // used to know where states should be copied from
82
        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
83
        // NOTE(review): the meaning of `tail` is not visible in this header — confirm in the .cpp
        int32_t   tail = -1;
84
85
        std::set<llama_seq_id> seq_id;
86
87
0
        // true if `id` is present in this cell's seq_id set
        bool has_seq_id(const llama_seq_id & id) const {
88
0
            return seq_id.find(id) != seq_id.end();
89
0
        }
90
91
0
        // true if no sequence id is attached to this cell
        bool is_empty() const {
92
0
            return seq_id.empty();
93
0
        }
94
95
0
        // true if both cells hold exactly the same set of sequence ids
        bool is_same_seq(const mem_cell & other) const {
96
0
            return seq_id == other.seq_id;
97
0
        }
98
    };
99
100
    std::vector<mem_cell> cells;
101
102
    // per layer
103
    std::vector<ggml_tensor *> r_l;
104
    std::vector<ggml_tensor *> s_l;
105
106
private:
107
    //const llama_model & model;
108
    // Stored as a reference, so the referenced llama_hparams must outlive this object.
    const llama_hparams & hparams;
109
110
    const uint32_t n_seq_max = 1;
111
112
    // ggml contexts for the KV cache along with the allocated backend buffers:
    // NOTE(review): the "KV cache" wording looks inherited from the KV-cache class;
    // presumably these hold the r_l / s_l tensors of this memory — confirm.
113
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
114
115
    size_t total_size() const;
116
117
    size_t size_r_bytes() const;
118
    size_t size_s_bytes() const;
119
120
    // Helpers for state_write / state_read, split into metadata and data parts.
    // `cell_ranges` is a list of (uint32_t, uint32_t) pairs
    // (presumably cell index ranges; bounds convention not visible here — TODO confirm).
    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
121
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
122
123
    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
124
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
125
};
126
127
// Implementation of llama_memory_context_i tied to a llama_memory_recurrent instance.
// It holds a status, a pointer to the memory object and the list of ubatches to process.
class llama_memory_recurrent_context : public llama_memory_context_i {
128
public:
129
    // used for errors
    // NOTE(review): single-argument constructor is not `explicit`, so a
    // llama_memory_status converts implicitly to this type — confirm this is intended.
130
    llama_memory_recurrent_context(llama_memory_status status);
131
132
    // used to create a full-cache or update context
    // NOTE(review): also a non-explicit single-argument constructor (from a raw pointer).
133
    llama_memory_recurrent_context(
134
            llama_memory_recurrent * mem);
135
136
    // used to create a batch processing context from a batch
    // (`ubatches` is taken by value)
137
    llama_memory_recurrent_context(
138
            llama_memory_recurrent * mem,
139
            std::vector<llama_ubatch> ubatches);
140
141
    virtual ~llama_memory_recurrent_context();
142
143
    //
144
    // llama_memory_context_i
145
    //
146
147
    bool next()  override;
148
    bool apply() override;
149
150
    llama_memory_status  get_status() const override;
151
    const llama_ubatch & get_ubatch() const override;
152
153
    //
154
    // llama_memory_recurrent_context specific API
155
    //
156
157
    // Read-only accessors.
    // NOTE(review): presumably these forward to the `n`, `head`, `rs_z` and `size`
    // fields of the underlying llama_memory_recurrent — confirm in the .cpp.
    uint32_t get_n_rs() const;
158
    uint32_t get_head() const;
159
    int32_t  get_rs_z() const;
160
    uint32_t get_size() const;
161
162
    // Per-layer tensor accessors; `il` is presumably the layer index — TODO confirm.
    ggml_tensor * get_r_l(int32_t il) const;
163
    ggml_tensor * get_s_l(int32_t il) const;
164
165
    int32_t s_copy(int i) const;
166
167
private:
168
    const llama_memory_status status;
169
170
    // Raw pointer to the memory object.
    // NOTE(review): presumably non-owning, and not set by the status-only
    // constructor — confirm lifetime expectations against the callers.
    llama_memory_recurrent * mem;
171
172
    // NOTE(review): presumably the index of the current ubatch in `ubatches` — confirm.
    size_t i_next = 0;
173
174
    std::vector<llama_ubatch> ubatches;
175
176
    //
177
    // data needed for building the compute graph for the current ubatch:
178
    // TODO: extract all the state like `head` and `n` here
179
    //
180
181
    // const, with an in-class default of false.
    // NOTE(review): presumably set to true by the full-cache constructor — confirm.
    const bool is_full = false;
182
};