/src/llama.cpp/src/llama-kv-cache-iswa.cpp
Line | Count | Source |
1 | | #include "llama-kv-cache-iswa.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-batch.h" |
5 | | #include "llama-model.h" |
6 | | |
7 | | #include <algorithm> |
8 | | #include <cassert> |
9 | | |
10 | | // |
11 | | // llama_kv_cache_iswa |
12 | | // |
13 | | |
14 | | llama_kv_cache_iswa::llama_kv_cache_iswa( |
15 | | const llama_model & model, |
16 | | ggml_type type_k, |
17 | | ggml_type type_v, |
18 | | bool v_trans, |
19 | | bool offload, |
20 | | bool swa_full, |
21 | | bool unified, |
22 | | uint32_t kv_size, |
23 | | uint32_t n_seq_max, |
24 | | uint32_t n_ubatch, |
25 | | uint32_t n_pad, |
26 | | llama_memory_t mem_other, |
27 | | const layer_filter_cb & filter, |
28 | | const layer_reuse_cb & reuse, |
29 | 0 | const layer_share_cb & share) : hparams(model.hparams), unified(unified) { |
30 | | |
31 | | // chain filters |
32 | 0 | const layer_filter_cb filter_base = [&](int32_t il) { |
33 | 0 | if (filter && !filter(il)) { |
34 | 0 | return false; |
35 | 0 | } |
36 | | |
37 | 0 | return !model.hparams.is_swa(il); |
38 | 0 | }; |
39 | |
|
40 | 0 | const layer_filter_cb filter_swa = [&](int32_t il) { |
41 | 0 | if (filter && !filter(il)) { |
42 | 0 | return false; |
43 | 0 | } |
44 | | |
45 | 0 | return model.hparams.is_swa(il); |
46 | 0 | }; |
47 | |
|
48 | 0 | const uint32_t size_base = kv_size; |
49 | | |
50 | | // note: the SWA cache is always padded to 256 for performance |
51 | | // https://github.com/ggml-org/llama.cpp/issues/17037 |
52 | 0 | uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256); |
53 | | |
54 | | // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size |
55 | 0 | if (swa_full) { |
56 | 0 | LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", |
57 | 0 | __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); |
58 | |
|
59 | 0 | size_swa = size_base; |
60 | 0 | } |
61 | |
|
62 | 0 | LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); |
63 | |
|
64 | 0 | llama_memory_t mem_other_base = nullptr; |
65 | 0 | if (mem_other) { |
66 | 0 | mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base(); |
67 | 0 | } |
68 | |
|
69 | 0 | llama_memory_t mem_other_swa = nullptr; |
70 | 0 | if (mem_other) { |
71 | 0 | mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa(); |
72 | 0 | } |
73 | |
|
74 | 0 | kv_base = std::make_unique<llama_kv_cache>( |
75 | 0 | model, hparams, type_k, type_v, |
76 | 0 | v_trans, offload, unified, size_base, n_seq_max, n_pad, |
77 | 0 | 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); |
78 | |
|
79 | 0 | LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); |
80 | |
|
81 | 0 | kv_swa = std::make_unique<llama_kv_cache>( |
82 | 0 | model, hparams, type_k, type_v, |
83 | 0 | v_trans, offload, unified, size_swa, n_seq_max, n_pad, |
84 | 0 | hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); |
85 | 0 | } |
86 | | |
87 | 0 | void llama_kv_cache_iswa::clear(bool data) { |
88 | 0 | kv_base->clear(data); |
89 | 0 | kv_swa ->clear(data); |
90 | 0 | } |
91 | | |
92 | 0 | bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { |
93 | 0 | bool res = true; |
94 | |
|
95 | 0 | res = res & kv_base->seq_rm(seq_id, p0, p1); |
96 | 0 | res = res & kv_swa ->seq_rm(seq_id, p0, p1); |
97 | |
|
98 | 0 | return res; |
99 | 0 | } |
100 | | |
101 | 0 | void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { |
102 | 0 | kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
103 | 0 | kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
104 | 0 | } |
105 | | |
106 | 0 | void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) { |
107 | 0 | kv_base->seq_keep(seq_id); |
108 | 0 | kv_swa ->seq_keep(seq_id); |
109 | 0 | } |
110 | | |
111 | 0 | void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { |
112 | 0 | kv_base->seq_add(seq_id, p0, p1, shift); |
113 | 0 | kv_swa ->seq_add(seq_id, p0, p1, shift); |
114 | 0 | } |
115 | | |
116 | 0 | void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { |
117 | 0 | kv_base->seq_div(seq_id, p0, p1, d); |
118 | 0 | kv_swa ->seq_div(seq_id, p0, p1, d); |
119 | 0 | } |
120 | | |
121 | 0 | llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const { |
122 | | // the base cache is a superset of the SWA cache, so we can just check the SWA cache |
123 | 0 | return kv_swa->seq_pos_min(seq_id); |
124 | 0 | } |
125 | | |
126 | 0 | llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const { |
127 | 0 | return kv_swa->seq_pos_max(seq_id); |
128 | 0 | } |
129 | | |
130 | 0 | std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const { |
131 | 0 | std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown(); |
132 | 0 | for (const auto & buft_size : kv_swa->memory_breakdown()) { |
133 | 0 | mb[buft_size.first] += buft_size.second; |
134 | 0 | } |
135 | 0 | return mb; |
136 | 0 | } |
137 | | |
138 | 0 | llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { |
139 | 0 | GGML_UNUSED(embd_all); |
140 | | |
141 | | // first try simple split |
142 | 0 | do { |
143 | 0 | if (!unified) { |
144 | | // requires equal splits, so we skip the simple split |
145 | 0 | break; |
146 | 0 | } |
147 | | |
148 | 0 | balloc.split_reset(); |
149 | |
|
150 | 0 | std::vector<llama_ubatch> ubatches; |
151 | 0 | while (true) { |
152 | 0 | auto ubatch = balloc.split_simple(n_ubatch); |
153 | |
|
154 | 0 | if (ubatch.n_tokens == 0) { |
155 | 0 | break; |
156 | 0 | } |
157 | | |
158 | 0 | ubatches.push_back(std::move(ubatch)); // NOLINT |
159 | 0 | } |
160 | |
|
161 | 0 | if (balloc.get_n_used() < balloc.get_n_tokens()) { |
162 | | // failed to find a suitable split |
163 | 0 | break; |
164 | 0 | } |
165 | | |
166 | 0 | auto sinfos_base = kv_base->prepare(ubatches); |
167 | 0 | if (sinfos_base.empty()) { |
168 | 0 | break; |
169 | 0 | } |
170 | | |
171 | 0 | auto sinfos_swa = kv_swa->prepare(ubatches); |
172 | 0 | if (sinfos_swa.empty()) { |
173 | 0 | break; |
174 | 0 | } |
175 | | |
176 | 0 | assert(sinfos_base.size() == sinfos_swa.size()); |
177 | |
|
178 | 0 | return std::make_unique<llama_kv_cache_iswa_context>( |
179 | 0 | this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); |
180 | 0 | } while (false); |
181 | | |
182 | | // if it fails, try equal split |
183 | 0 | do { |
184 | 0 | balloc.split_reset(); |
185 | |
|
186 | 0 | std::vector<llama_ubatch> ubatches; |
187 | 0 | while (true) { |
188 | 0 | auto ubatch = balloc.split_equal(n_ubatch, !unified); |
189 | |
|
190 | 0 | if (ubatch.n_tokens == 0) { |
191 | 0 | break; |
192 | 0 | } |
193 | | |
194 | 0 | ubatches.push_back(std::move(ubatch)); // NOLINT |
195 | 0 | } |
196 | |
|
197 | 0 | if (balloc.get_n_used() < balloc.get_n_tokens()) { |
198 | | // failed to find a suitable split |
199 | 0 | break; |
200 | 0 | } |
201 | | |
202 | 0 | auto sinfos_base = kv_base->prepare(ubatches); |
203 | 0 | if (sinfos_base.empty()) { |
204 | 0 | break; |
205 | 0 | } |
206 | | |
207 | 0 | auto sinfos_swa = kv_swa->prepare(ubatches); |
208 | 0 | if (sinfos_swa.empty()) { |
209 | 0 | break; |
210 | 0 | } |
211 | | |
212 | 0 | assert(sinfos_base.size() == sinfos_swa.size()); |
213 | |
|
214 | 0 | return std::make_unique<llama_kv_cache_iswa_context>( |
215 | 0 | this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); |
216 | 0 | } while (false); |
217 | | |
218 | | // TODO: if we fail again, we should attempt different splitting strategies |
219 | | // but to do that properly, we first have to refactor the batches to be more flexible |
220 | | |
221 | 0 | return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
222 | 0 | } |
223 | | |
224 | 0 | llama_memory_context_ptr llama_kv_cache_iswa::init_full() { |
225 | 0 | return std::make_unique<llama_kv_cache_iswa_context>(this); |
226 | 0 | } |
227 | | |
228 | 0 | llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) { |
229 | 0 | return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize); |
230 | 0 | } |
231 | | |
232 | 0 | bool llama_kv_cache_iswa::get_can_shift() const { |
233 | 0 | return kv_base->get_can_shift() && |
234 | 0 | kv_swa->get_can_shift() && |
235 | 0 | kv_base->get_size() == kv_swa->get_size(); |
236 | 0 | } |
237 | | |
238 | 0 | void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { |
239 | 0 | if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { |
240 | 0 | kv_base->state_write(io, seq_id, flags); |
241 | 0 | } |
242 | |
|
243 | 0 | kv_swa->state_write(io, seq_id, flags); |
244 | 0 | } |
245 | | |
246 | 0 | void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
247 | 0 | if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { |
248 | 0 | kv_base->state_read(io, seq_id, flags); |
249 | 0 | } |
250 | |
|
251 | 0 | kv_swa->state_read(io, seq_id, flags); |
252 | 0 | } |
253 | | |
254 | 0 | llama_kv_cache * llama_kv_cache_iswa::get_base() const { |
255 | 0 | return kv_base.get(); |
256 | 0 | } |
257 | | |
258 | 0 | llama_kv_cache * llama_kv_cache_iswa::get_swa() const { |
259 | 0 | return kv_swa.get(); |
260 | 0 | } |
261 | | |
262 | | // |
263 | | // llama_kv_cache_iswa_context |
264 | | // |
265 | | |
266 | 0 | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {} |
267 | | |
268 | | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( |
269 | | llama_kv_cache_iswa * kv) : |
270 | 0 | ctx_base(kv->get_base()->init_full()), |
271 | 0 | ctx_swa (kv->get_swa ()->init_full()), |
272 | 0 | status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { |
273 | 0 | } |
274 | | |
275 | | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( |
276 | | llama_kv_cache_iswa * kv, |
277 | | llama_context * lctx, |
278 | | bool optimize) : |
279 | 0 | ctx_base(kv->get_base()->init_update(lctx, optimize)), |
280 | 0 | ctx_swa (kv->get_swa ()->init_update(lctx, optimize)), |
281 | 0 | status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { |
282 | 0 | } |
283 | | |
284 | | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( |
285 | | llama_kv_cache_iswa * kv, |
286 | | slot_info_vec_t sinfos_base, |
287 | | slot_info_vec_t sinfos_swa, |
288 | | std::vector<llama_ubatch> ubatches) : |
289 | 0 | ubatches(std::move(ubatches)), |
290 | | // note: here we copy the ubatches. not sure if this is ideal |
291 | 0 | ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), |
292 | 0 | ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), |
293 | 0 | status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { |
294 | 0 | } |
295 | | |
296 | 0 | llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default; |
297 | | |
298 | 0 | bool llama_kv_cache_iswa_context::next() { |
299 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
300 | |
|
301 | 0 | ctx_base->next(); |
302 | 0 | ctx_swa ->next(); |
303 | |
|
304 | 0 | if (++i_next >= ubatches.size()) { |
305 | 0 | return false; |
306 | 0 | } |
307 | | |
308 | 0 | return true; |
309 | 0 | } |
310 | | |
311 | 0 | bool llama_kv_cache_iswa_context::apply() { |
312 | 0 | assert(!llama_memory_status_is_fail(status)); |
313 | |
|
314 | 0 | bool res = true; |
315 | |
|
316 | 0 | res = res & ctx_base->apply(); |
317 | 0 | res = res & ctx_swa ->apply(); |
318 | |
|
319 | 0 | return res; |
320 | 0 | } |
321 | | |
322 | 0 | llama_memory_status llama_kv_cache_iswa_context::get_status() const { |
323 | 0 | return status; |
324 | 0 | } |
325 | | |
326 | 0 | const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const { |
327 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
328 | |
|
329 | 0 | return ubatches[i_next]; |
330 | 0 | } |
331 | | |
332 | 0 | const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const { |
333 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
334 | |
|
335 | 0 | return static_cast<const llama_kv_cache_context *>(ctx_base.get()); |
336 | 0 | } |
337 | | |
338 | 0 | const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const { |
339 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
340 | |
|
341 | 0 | return static_cast<const llama_kv_cache_context *>(ctx_swa.get()); |
342 | 0 | } |