/src/llama.cpp/src/llama-memory-hybrid-iswa.cpp
Line | Count | Source |
1 | | #include "llama-memory-hybrid-iswa.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-model.h" |
5 | | #include "llama-context.h" |
6 | | |
7 | | // |
8 | | // llama_memory_hybrid_iswa |
9 | | // |
10 | | |
11 | | llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( |
12 | | const llama_model & model, |
13 | | /* attn */ |
14 | | ggml_type type_k, |
15 | | ggml_type type_v, |
16 | | bool v_trans, |
17 | | bool swa_full, |
18 | | uint32_t kv_size, |
19 | | uint32_t n_ubatch, |
20 | | uint32_t n_pad, |
21 | | /* recurrent */ |
22 | | ggml_type type_r, |
23 | | ggml_type type_s, |
24 | | uint32_t rs_size, |
25 | | /* common */ |
26 | | uint32_t n_seq_max, |
27 | | uint32_t n_rs_seq, |
28 | | bool offload, |
29 | | bool unified, |
30 | | /* layer filters */ |
31 | | const layer_filter_cb & filter_attn, |
32 | | const layer_filter_cb & filter_recr) : |
33 | 0 | hparams(model.hparams), |
34 | 0 | mem_attn(new llama_kv_cache_iswa( |
35 | 0 | model, |
36 | 0 | type_k, |
37 | 0 | type_v, |
38 | 0 | v_trans, |
39 | 0 | offload, |
40 | 0 | swa_full, |
41 | 0 | unified, |
42 | 0 | kv_size, |
43 | 0 | n_seq_max, |
44 | 0 | n_ubatch, |
45 | 0 | n_pad, |
46 | 0 | nullptr, |
47 | 0 | filter_attn == nullptr ? |
48 | 0 | [&](int32_t il) { return !hparams.is_recr(il); } |
49 | 0 | : filter_attn, |
50 | 0 | nullptr, |
51 | 0 | nullptr |
52 | 0 | )), |
53 | 0 | mem_recr(new llama_memory_recurrent( |
54 | 0 | model, |
55 | 0 | type_r, |
56 | 0 | type_s, |
57 | 0 | offload, |
58 | 0 | rs_size, |
59 | 0 | n_seq_max, |
60 | 0 | n_rs_seq, |
61 | 0 | filter_recr == nullptr ? |
62 | 0 | [&](int32_t il) { return hparams.is_recr(il); } |
63 | 0 | : filter_recr |
64 | 0 | )) {} |
65 | | |
66 | 0 | llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { |
67 | 0 | do { |
68 | 0 | balloc.split_reset(); |
69 | | |
70 | | // follow the recurrent pattern for creating the ubatch splits |
71 | 0 | std::vector<llama_ubatch> ubatches; |
72 | |
|
73 | 0 | while (true) { |
74 | 0 | llama_ubatch ubatch; |
75 | |
|
76 | 0 | if (embd_all) { |
77 | | // if all tokens are output, split by sequence |
78 | 0 | ubatch = balloc.split_seq(n_ubatch); |
79 | 0 | } else { |
80 | 0 | if (mem_recr->n_rs_seq > 0) { |
81 | | // [TAG_RECURRENT_ROLLBACK_SPLITS] |
82 | | // TODO: recurrent state rollback does not support equal splits |
83 | 0 | ubatch = balloc.split_seq(n_ubatch); |
84 | 0 | } else { |
85 | | // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) |
86 | 0 | const bool unified = (mem_attn->get_base()->get_n_stream() == 1); |
87 | 0 | ubatch = balloc.split_equal(n_ubatch, !unified); |
88 | 0 | } |
89 | 0 | } |
90 | |
|
91 | 0 | if (ubatch.n_tokens == 0) { |
92 | 0 | break; |
93 | 0 | } |
94 | | |
95 | 0 | ubatches.push_back(std::move(ubatch)); // NOLINT |
96 | 0 | } |
97 | |
|
98 | 0 | if (balloc.get_n_used() < balloc.get_n_tokens()) { |
99 | | // failed to find a suitable split |
100 | 0 | break; |
101 | 0 | } |
102 | | |
103 | | // prepare the recurrent batches first |
104 | 0 | if (!mem_recr->prepare(ubatches)) { |
105 | | // TODO: will the recurrent cache be in an undefined context at this point? |
106 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); |
107 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
108 | 0 | } |
109 | | |
110 | | // prepare the attention cache (iswa version returns both base and swa slot infos) |
111 | 0 | auto sinfos_base = mem_attn->get_base()->prepare(ubatches); |
112 | 0 | if (sinfos_base.empty()) { |
113 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); |
114 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
115 | 0 | } |
116 | | |
117 | 0 | auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); |
118 | 0 | if (sinfos_swa.empty()) { |
119 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); |
120 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
121 | 0 | } |
122 | | |
123 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>( |
124 | 0 | this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); |
125 | 0 | } while(false); |
126 | | |
127 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
128 | 0 | } |
129 | | |
130 | 0 | llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { |
131 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>(this); |
132 | 0 | } |
133 | | |
134 | 0 | llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { |
135 | 0 | return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize); |
136 | 0 | } |
137 | | |
138 | 0 | bool llama_memory_hybrid_iswa::get_can_shift() const { |
139 | | // Shifting is trivially supported for recurrent |
140 | 0 | return mem_attn->get_can_shift(); |
141 | 0 | } |
142 | | |
143 | 0 | void llama_memory_hybrid_iswa::clear(bool data) { |
144 | 0 | mem_attn->clear(data); |
145 | 0 | mem_recr->clear(data); |
146 | 0 | } |
147 | | |
148 | 0 | bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { |
149 | | // Try removing from the recurrent cache first since it may fail. If it does |
150 | | // fail, the cache will not have been mutated. |
151 | 0 | if (!mem_recr->seq_rm(seq_id, p0, p1)) { |
152 | 0 | return false; |
153 | 0 | } |
154 | 0 | return mem_attn->seq_rm(seq_id, p0, p1); |
155 | 0 | } |
156 | | |
157 | 0 | void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { |
158 | 0 | mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
159 | 0 | mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
160 | 0 | } |
161 | | |
162 | 0 | void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { |
163 | 0 | mem_attn->seq_keep(seq_id); |
164 | 0 | mem_recr->seq_keep(seq_id); |
165 | 0 | } |
166 | | |
167 | 0 | void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { |
168 | 0 | mem_attn->seq_add(seq_id, p0, p1, shift); |
169 | 0 | mem_recr->seq_add(seq_id, p0, p1, shift); |
170 | 0 | } |
171 | | |
172 | 0 | void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { |
173 | 0 | mem_attn->seq_div(seq_id, p0, p1, d); |
174 | 0 | mem_recr->seq_div(seq_id, p0, p1, d); |
175 | 0 | } |
176 | | |
177 | 0 | llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { |
178 | | // the min of the total cache is the max of the two caches' min values |
179 | 0 | return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); |
180 | 0 | } |
181 | | |
182 | 0 | llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { |
183 | | // the max of the total cache is the min of the two caches' max values |
184 | 0 | return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); |
185 | 0 | } |
186 | | |
187 | 0 | std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const { |
188 | 0 | std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown(); |
189 | 0 | for (const auto & buft_size : mem_recr->memory_breakdown()) { |
190 | 0 | mb[buft_size.first] += buft_size.second; |
191 | 0 | } |
192 | 0 | return mb; |
193 | 0 | } |
194 | | |
195 | 0 | void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { |
196 | 0 | mem_attn->state_write(io, seq_id, flags); |
197 | 0 | mem_recr->state_write(io, seq_id, flags); |
198 | 0 | } |
199 | | |
200 | 0 | void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
201 | 0 | mem_attn->state_read(io, seq_id, flags); |
202 | 0 | mem_recr->state_read(io, seq_id, flags); |
203 | 0 | } |
204 | | |
205 | 0 | llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { |
206 | 0 | return mem_attn.get(); |
207 | 0 | } |
208 | | |
209 | 0 | llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { |
210 | 0 | return mem_recr.get(); |
211 | 0 | } |
212 | | |
213 | | // |
214 | | // llama_memory_hybrid_iswa_context |
215 | | // |
216 | | |
217 | 0 | llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} |
218 | | |
219 | | llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : |
220 | 0 | ctx_attn(mem->get_mem_attn()->init_full()), |
221 | 0 | ctx_recr(mem->get_mem_recr()->init_full()), |
222 | 0 | status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { |
223 | 0 | } |
224 | | |
225 | | llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( |
226 | | llama_memory_hybrid_iswa * mem, |
227 | | llama_context * lctx, |
228 | | bool optimize) : |
229 | 0 | ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), |
230 | 0 | ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), |
231 | 0 | status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { |
232 | 0 | } |
233 | | |
234 | | llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( |
235 | | llama_memory_hybrid_iswa * mem, |
236 | | slot_info_vec_t sinfos_base, |
237 | | slot_info_vec_t sinfos_swa, |
238 | | std::vector<llama_ubatch> ubatches) : |
239 | 0 | ubatches(std::move(ubatches)), |
240 | | // note: here we copy the ubatches. not sure if this is ideal |
241 | 0 | ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), |
242 | 0 | ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), |
243 | 0 | status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { |
244 | 0 | } |
245 | | |
246 | 0 | bool llama_memory_hybrid_iswa_context::next() { |
247 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
248 | |
|
249 | 0 | ctx_attn->next(); |
250 | 0 | ctx_recr->next(); |
251 | |
|
252 | 0 | if (++i_next >= ubatches.size()) { |
253 | 0 | return false; |
254 | 0 | } |
255 | | |
256 | 0 | return true; |
257 | 0 | } |
258 | | |
259 | 0 | bool llama_memory_hybrid_iswa_context::apply() { |
260 | 0 | assert(!llama_memory_status_is_fail(status)); |
261 | |
|
262 | 0 | bool res = true; |
263 | |
|
264 | 0 | res = res & ctx_attn->apply(); |
265 | 0 | res = res & ctx_recr->apply(); |
266 | |
|
267 | 0 | return res; |
268 | 0 | } |
269 | | |
270 | 0 | llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { |
271 | 0 | return status; |
272 | 0 | } |
273 | | |
274 | 0 | const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { |
275 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
276 | 0 | return ubatches[i_next]; |
277 | 0 | } |
278 | | |
279 | 0 | const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { |
280 | 0 | return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get()); |
281 | 0 | } |
282 | | |
283 | 0 | const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { |
284 | 0 | return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get()); |
285 | 0 | } |