/src/llama.cpp/src/llama-memory-recurrent.h
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include <map>
#include <set>
#include <vector>

//
// llama_memory_recurrent
//

// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
//       see the implementation of llama_kv_cache_context_i for an example of how to do it
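//
// rough usage sketch (an assumption based on the llama_memory_i interface,
// not a verbatim call site):
//
//   auto mctx = mem.init_batch(balloc, n_ubatch, /*embd_all =*/ false);
//   if (mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//       do {
//           mctx->apply(); // place the current ubatch into the cache
//           // ... build and compute the graph for mctx->get_ubatch() ...
//       } while (mctx->next());
//   }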
class llama_memory_recurrent : public llama_memory_i {
public:
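    // note (a summary, partly an assumption): type_r/type_s are the ggml types
    // of the per-layer r_l/s_l state tensors, offload selects device vs. host
    // buffers, mem_size is the total number of cells and filter restricts
    // which layers allocate recurrent state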
    llama_memory_recurrent(
            const llama_model & model,
                        ggml_type type_r,
                        ggml_type type_s,
                             bool offload,
                         uint32_t mem_size,
                         uint32_t n_seq_max,
            const layer_filter_cb & filter);

    ~llama_memory_recurrent() = default;

    //
    // llama_memory_i
    //
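
    // presumably splits the batch into ubatches and checks that a state cell
    // can be reserved for every sequence before any processing starts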
    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

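    // check that a slot can be found for each ubatch; presumably a dry run of
    // find_slot() that restores the cache state afterwards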
    bool prepare(const std::vector<llama_ubatch> & ubatches);

    // find a contiguous slot of memory cells and emplace the ubatch there
    bool find_slot(const llama_ubatch & ubatch);

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // number of cells active in the current ubatch (presumably set by find_slot());
    // computed before each graph build
    uint32_t n = 0;

    // index of the first zero-ed state (-1 if none)
    int32_t rs_z = -1;

    // TODO: optimize for recurrent state needs
    struct mem_cell {
        llama_pos pos = -1;
        int32_t   src  = -1; // used to know where states should be copied from
        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
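        // presumably: cells[seq_id].tail is the index of the cell holding the
        // most recent state of sequence seq_id (-1 if none)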
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const mem_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    std::vector<mem_cell> cells;

    // per-layer state tensors: r_l/s_l correspond to the type_r/type_s
    // constructor arguments (e.g. conv/token-shift state in r_l and ssm/wkv
    // state in s_l; an assumption based on the models that use this cache)
    std::vector<ggml_tensor *> r_l;
    std::vector<ggml_tensor *> s_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    const uint32_t n_seq_max = 1;

    // ggml contexts for the recurrent state tensors along with their allocated backend buffers:
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    size_t total_size() const;

    size_t size_r_bytes() const;
    size_t size_s_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

class llama_memory_recurrent_context : public llama_memory_context_i {
public:
    // used for errors
    llama_memory_recurrent_context(llama_memory_status status);

    // used to create a full-cache or update context
    llama_memory_recurrent_context(
            llama_memory_recurrent * mem);

    // used to create a batch processing context from a batch
    llama_memory_recurrent_context(
            llama_memory_recurrent * mem,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_memory_recurrent_context();

    //
    // llama_memory_context_i
    //

    bool next() override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_recurrent_context specific API
    //

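    // these presumably expose the slot chosen by find_slot() (head, n, size)
    // and the first zero-ed state (rs_z) for building the graph inputs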
    uint32_t get_n_rs() const;
    uint32_t get_head() const;
    int32_t  get_rs_z() const;
    uint32_t get_size() const;

    ggml_tensor * get_r_l(int32_t il) const;
    ggml_tensor * get_s_l(int32_t il) const;

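    // presumably the source cell index that state i should be copied from
    // (derived from mem_cell::src), used to fill the s_copy graph input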
    int32_t s_copy(int i) const;

private:
    const llama_memory_status status;

    llama_memory_recurrent * mem;

    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    // TODO: extract all the state like `head` and `n` here
    //

    const bool is_full = false;
};