/src/llama.cpp/src/llama-memory-recurrent.h
Line | Count | Source |
1 | | #pragma once |
2 | | |
3 | | #include "llama-batch.h" |
4 | | #include "llama-graph.h" |
5 | | #include "llama-memory.h" |
6 | | |
7 | | #include <map> |
8 | | #include <set> |
9 | | #include <vector> |
10 | | |
11 | | // |
12 | | // llama_memory_recurrent |
13 | | // |
14 | | |
15 | | // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i |
16 | | // see the implementation of llama_kv_cache_context_i for an example how to do it |
17 | | class llama_memory_recurrent : public llama_memory_i { |
18 | | public: |
19 | | llama_memory_recurrent( |
20 | | const llama_model & model, |
21 | | ggml_type type_r, |
22 | | ggml_type type_s, |
23 | | bool offload, |
24 | | uint32_t mem_size, |
25 | | uint32_t n_seq_max, |
26 | | uint32_t n_rs_seq, |
27 | | const layer_filter_cb & filter); |
28 | | |
29 | 0 | ~llama_memory_recurrent() = default; |
30 | | |
31 | | // |
32 | | // llama_memory_i |
33 | | // |
34 | | |
35 | | llama_memory_context_ptr init_batch( |
36 | | llama_batch_allocr & balloc, |
37 | | uint32_t n_ubatch, |
38 | | bool embd_all) override; |
39 | | |
40 | | llama_memory_context_ptr init_full() override; |
41 | | |
42 | | llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; |
43 | | |
44 | | void clear(bool data) override; |
45 | | |
46 | | bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; |
47 | | void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; |
48 | | void seq_keep(llama_seq_id seq_id) override; |
49 | | void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; |
50 | | void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; |
51 | | |
52 | | llama_pos seq_pos_min(llama_seq_id seq_id) const override; |
53 | | llama_pos seq_pos_max(llama_seq_id seq_id) const override; |
54 | | |
55 | | std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; |
56 | | |
57 | | bool prepare(const std::vector<llama_ubatch> & ubatches); |
58 | | |
59 | | // find a contiguous slot of memory cells and emplace the ubatch there |
60 | | bool find_slot(const llama_ubatch & ubatch); |
61 | | |
62 | | bool get_can_shift() const override; |
63 | | |
64 | | // state write/load |
65 | | |
66 | | void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; |
67 | | void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; |
68 | | |
69 | | uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) |
70 | | uint32_t size = 0; // total number of cells, shared across all sequences |
71 | | uint32_t used = 0; // used cells (i.e. at least one seq_id) |
72 | | |
73 | | // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups |
74 | | uint32_t n_rs_seq = 0; |
75 | | |
76 | | // per-seq rollback index |
77 | | std::vector<uint32_t> rs_idx; |
78 | | |
79 | | void set_rs_idx(llama_seq_id seq_id, uint32_t idx); |
80 | | |
81 | | // computed before each graph build |
82 | | uint32_t n = 0; |
83 | | |
84 | | // first zero-ed state |
85 | | int32_t rs_z = -1; |
86 | | |
87 | | // TODO: optimize for recurrent state needs |
88 | | struct mem_cell { |
89 | | llama_pos pos = -1; |
90 | | int32_t src = -1; // used to know where states should be copied from |
91 | | int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once) |
92 | | int32_t tail = -1; |
93 | | |
94 | | std::set<llama_seq_id> seq_id; |
95 | | |
96 | 0 | bool has_seq_id(const llama_seq_id & id) const { |
97 | 0 | return seq_id.find(id) != seq_id.end(); |
98 | 0 | } |
99 | | |
100 | 0 | bool is_empty() const { |
101 | 0 | return seq_id.empty(); |
102 | 0 | } |
103 | | |
104 | 0 | bool is_same_seq(const mem_cell & other) const { |
105 | 0 | return seq_id == other.seq_id; |
106 | 0 | } |
107 | | }; |
108 | | |
109 | | std::vector<mem_cell> cells; |
110 | | |
111 | | // per layer |
112 | | std::vector<ggml_tensor *> r_l; |
113 | | std::vector<ggml_tensor *> s_l; |
114 | | |
115 | | private: |
116 | | //const llama_model & model; |
117 | | const llama_hparams & hparams; |
118 | | |
119 | | const uint32_t n_seq_max = 1; |
120 | | |
121 | | // ggml contexts for the KV cache along with the allocated backend buffers: |
122 | | std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs; |
123 | | |
124 | | size_t total_size() const; |
125 | | |
126 | | size_t size_r_bytes() const; |
127 | | size_t size_s_bytes() const; |
128 | | |
129 | | void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; |
130 | | void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; |
131 | | |
132 | | bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); |
133 | | bool state_read_data(llama_io_read_i & io, uint32_t cell_count); |
134 | | }; |
135 | | |
136 | | class llama_memory_recurrent_context : public llama_memory_context_i { |
137 | | public: |
138 | | // used for errors |
139 | | llama_memory_recurrent_context(llama_memory_status status); |
140 | | |
141 | | // used to create a full-cache or update context |
142 | | llama_memory_recurrent_context( |
143 | | llama_memory_recurrent * mem); |
144 | | |
145 | | // used to create a batch processing context from a batch |
146 | | llama_memory_recurrent_context( |
147 | | llama_memory_recurrent * mem, |
148 | | std::vector<llama_ubatch> ubatches); |
149 | | |
150 | | virtual ~llama_memory_recurrent_context(); |
151 | | |
152 | | // |
153 | | // llama_memory_context_i |
154 | | // |
155 | | |
156 | | bool next() override; |
157 | | bool apply() override; |
158 | | |
159 | | llama_memory_status get_status() const override; |
160 | | const llama_ubatch & get_ubatch() const override; |
161 | | |
162 | | // |
163 | | // llama_memory_recurrent_context specific API |
164 | | // |
165 | | |
166 | | uint32_t get_n_rs() const; |
167 | | uint32_t get_head() const; |
168 | | int32_t get_rs_z() const; |
169 | | uint32_t get_size() const; |
170 | | |
171 | | ggml_tensor * get_r_l(int32_t il) const; |
172 | | ggml_tensor * get_s_l(int32_t il) const; |
173 | | |
174 | | int32_t s_copy(int i) const; |
175 | | |
176 | | private: |
177 | | const llama_memory_status status; |
178 | | |
179 | | llama_memory_recurrent * mem; |
180 | | |
181 | | size_t i_next = 0; |
182 | | |
183 | | std::vector<llama_ubatch> ubatches; |
184 | | |
185 | | // |
186 | | // data needed for building the compute graph for the current ubatch: |
187 | | // TODO: extract all the state like `head` and `n` here |
188 | | // |
189 | | |
190 | | const bool is_full = false; |
191 | | }; |