/src/llama.cpp/src/llama-kv-cache-iswa.cpp
Line | Count | Source |
1 | | #include "llama-kv-cache-iswa.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-batch.h" |
5 | | #include "llama-model.h" |
6 | | |
7 | | #include <algorithm> |
8 | | #include <cassert> |
9 | | |
10 | | // |
11 | | // llama_kv_cache_iswa |
12 | | // |
13 | | |
14 | | llama_kv_cache_iswa::llama_kv_cache_iswa( |
15 | | const llama_model & model, |
16 | | ggml_type type_k, |
17 | | ggml_type type_v, |
18 | | bool v_trans, |
19 | | bool offload, |
20 | | bool swa_full, |
21 | | bool unified, |
22 | | uint32_t kv_size, |
23 | | uint32_t n_seq_max, |
24 | | uint32_t n_ubatch, |
25 | | uint32_t n_pad, |
26 | | const layer_filter_cb & filter, |
27 | 0 | const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { |
28 | | |
29 | | // chain filters |
30 | 0 | const layer_filter_cb filter_base = [&](int32_t il) { |
31 | 0 | if (filter && !filter(il)) { |
32 | 0 | return false; |
33 | 0 | } |
34 | | |
35 | 0 | return !model.hparams.is_swa(il); |
36 | 0 | }; |
37 | |
|
38 | 0 | const layer_filter_cb filter_swa = [&](int32_t il) { |
39 | 0 | if (filter && !filter(il)) { |
40 | 0 | return false; |
41 | 0 | } |
42 | | |
43 | 0 | return model.hparams.is_swa(il); |
44 | 0 | }; |
45 | |
|
46 | 0 | const uint32_t size_base = kv_size; |
47 | | |
48 | | // note: the SWA cache is always padded to 256 for performance |
49 | | // https://github.com/ggml-org/llama.cpp/issues/17037 |
50 | 0 | uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256); |
51 | | |
52 | | // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size |
53 | 0 | if (swa_full) { |
54 | 0 | LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", |
55 | 0 | __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); |
56 | |
|
57 | 0 | size_swa = size_base; |
58 | 0 | } |
59 | |
|
60 | 0 | LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); |
61 | |
|
62 | 0 | kv_base = std::make_unique<llama_kv_cache>( |
63 | 0 | model, type_k, type_v, |
64 | 0 | v_trans, offload, unified, size_base, n_seq_max, n_pad, |
65 | 0 | 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); |
66 | |
|
67 | 0 | LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); |
68 | |
|
69 | 0 | kv_swa = std::make_unique<llama_kv_cache>( |
70 | 0 | model, type_k, type_v, |
71 | 0 | v_trans, offload, unified, size_swa, n_seq_max, n_pad, |
72 | 0 | hparams.n_swa, hparams.swa_type, filter_swa, reuse); |
73 | 0 | } |
74 | | |
75 | 0 | void llama_kv_cache_iswa::clear(bool data) { |
76 | 0 | kv_base->clear(data); |
77 | 0 | kv_swa ->clear(data); |
78 | 0 | } |
79 | | |
80 | 0 | bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { |
81 | 0 | bool res = true; |
82 | |
|
83 | 0 | res = res & kv_base->seq_rm(seq_id, p0, p1); |
84 | 0 | res = res & kv_swa ->seq_rm(seq_id, p0, p1); |
85 | |
|
86 | 0 | return res; |
87 | 0 | } |
88 | | |
89 | 0 | void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { |
90 | 0 | kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
91 | 0 | kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
92 | 0 | } |
93 | | |
94 | 0 | void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) { |
95 | 0 | kv_base->seq_keep(seq_id); |
96 | 0 | kv_swa ->seq_keep(seq_id); |
97 | 0 | } |
98 | | |
99 | 0 | void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { |
100 | 0 | kv_base->seq_add(seq_id, p0, p1, shift); |
101 | 0 | kv_swa ->seq_add(seq_id, p0, p1, shift); |
102 | 0 | } |
103 | | |
104 | 0 | void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { |
105 | 0 | kv_base->seq_div(seq_id, p0, p1, d); |
106 | 0 | kv_swa ->seq_div(seq_id, p0, p1, d); |
107 | 0 | } |
108 | | |
109 | 0 | llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const { |
110 | | // the base cache is a superset of the SWA cache, so we can just check the SWA cache |
111 | 0 | return kv_swa->seq_pos_min(seq_id); |
112 | 0 | } |
113 | | |
114 | 0 | llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const { |
115 | 0 | return kv_swa->seq_pos_max(seq_id); |
116 | 0 | } |
117 | | |
118 | 0 | std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const { |
119 | 0 | std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown(); |
120 | 0 | for (const auto & buft_size : kv_swa->memory_breakdown()) { |
121 | 0 | mb[buft_size.first] += buft_size.second; |
122 | 0 | } |
123 | 0 | return mb; |
124 | 0 | } |
125 | | |
126 | 0 | llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { |
127 | 0 | GGML_UNUSED(embd_all); |
128 | | |
129 | | // first try simple split |
130 | 0 | do { |
131 | 0 | if (!unified) { |
132 | | // requires equal splits, so we skip the simple split |
133 | 0 | break; |
134 | 0 | } |
135 | | |
136 | 0 | balloc.split_reset(); |
137 | |
|
138 | 0 | std::vector<llama_ubatch> ubatches; |
139 | 0 | while (true) { |
140 | 0 | auto ubatch = balloc.split_simple(n_ubatch); |
141 | |
|
142 | 0 | if (ubatch.n_tokens == 0) { |
143 | 0 | break; |
144 | 0 | } |
145 | | |
146 | 0 | ubatches.push_back(std::move(ubatch)); // NOLINT |
147 | 0 | } |
148 | |
|
149 | 0 | if (balloc.get_n_used() < balloc.get_n_tokens()) { |
150 | | // failed to find a suitable split |
151 | 0 | break; |
152 | 0 | } |
153 | | |
154 | 0 | auto sinfos_base = kv_base->prepare(ubatches); |
155 | 0 | if (sinfos_base.empty()) { |
156 | 0 | break; |
157 | 0 | } |
158 | | |
159 | 0 | auto sinfos_swa = kv_swa->prepare(ubatches); |
160 | 0 | if (sinfos_swa.empty()) { |
161 | 0 | break; |
162 | 0 | } |
163 | | |
164 | 0 | assert(sinfos_base.size() == sinfos_swa.size()); |
165 | |
|
166 | 0 | return std::make_unique<llama_kv_cache_iswa_context>( |
167 | 0 | this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); |
168 | 0 | } while (false); |
169 | | |
170 | | // if it fails, try equal split |
171 | 0 | do { |
172 | 0 | balloc.split_reset(); |
173 | |
|
174 | 0 | std::vector<llama_ubatch> ubatches; |
175 | 0 | while (true) { |
176 | 0 | auto ubatch = balloc.split_equal(n_ubatch, !unified); |
177 | |
|
178 | 0 | if (ubatch.n_tokens == 0) { |
179 | 0 | break; |
180 | 0 | } |
181 | | |
182 | 0 | ubatches.push_back(std::move(ubatch)); // NOLINT |
183 | 0 | } |
184 | |
|
185 | 0 | if (balloc.get_n_used() < balloc.get_n_tokens()) { |
186 | | // failed to find a suitable split |
187 | 0 | break; |
188 | 0 | } |
189 | | |
190 | 0 | auto sinfos_base = kv_base->prepare(ubatches); |
191 | 0 | if (sinfos_base.empty()) { |
192 | 0 | break; |
193 | 0 | } |
194 | | |
195 | 0 | auto sinfos_swa = kv_swa->prepare(ubatches); |
196 | 0 | if (sinfos_swa.empty()) { |
197 | 0 | break; |
198 | 0 | } |
199 | | |
200 | 0 | assert(sinfos_base.size() == sinfos_swa.size()); |
201 | |
|
202 | 0 | return std::make_unique<llama_kv_cache_iswa_context>( |
203 | 0 | this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); |
204 | 0 | } while (false); |
205 | | |
206 | | // TODO: if we fail again, we should attempt different splitting strategies |
207 | | // but to do that properly, we first have to refactor the batches to be more flexible |
208 | | |
209 | 0 | return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
210 | 0 | } |
211 | | |
212 | 0 | llama_memory_context_ptr llama_kv_cache_iswa::init_full() { |
213 | 0 | return std::make_unique<llama_kv_cache_iswa_context>(this); |
214 | 0 | } |
215 | | |
216 | 0 | llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) { |
217 | 0 | return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize); |
218 | 0 | } |
219 | | |
220 | 0 | bool llama_kv_cache_iswa::get_can_shift() const { |
221 | 0 | return kv_base->get_size() == kv_swa->get_size(); |
222 | 0 | } |
223 | | |
224 | 0 | void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { |
225 | 0 | if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { |
226 | 0 | kv_base->state_write(io, seq_id, flags); |
227 | 0 | } |
228 | |
|
229 | 0 | kv_swa->state_write(io, seq_id, flags); |
230 | 0 | } |
231 | | |
232 | 0 | void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
233 | 0 | if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { |
234 | 0 | kv_base->state_read(io, seq_id, flags); |
235 | 0 | } |
236 | |
|
237 | 0 | kv_swa->state_read(io, seq_id, flags); |
238 | 0 | } |
239 | | |
240 | 0 | llama_kv_cache * llama_kv_cache_iswa::get_base() const { |
241 | 0 | return kv_base.get(); |
242 | 0 | } |
243 | | |
244 | 0 | llama_kv_cache * llama_kv_cache_iswa::get_swa() const { |
245 | 0 | return kv_swa.get(); |
246 | 0 | } |
247 | | |
248 | | // |
249 | | // llama_kv_cache_iswa_context |
250 | | // |
251 | | |
252 | 0 | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {} |
253 | | |
254 | | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( |
255 | | llama_kv_cache_iswa * kv) : |
256 | 0 | ctx_base(kv->get_base()->init_full()), |
257 | 0 | ctx_swa (kv->get_swa ()->init_full()), |
258 | 0 | status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { |
259 | 0 | } |
260 | | |
261 | | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( |
262 | | llama_kv_cache_iswa * kv, |
263 | | llama_context * lctx, |
264 | | bool optimize) : |
265 | 0 | ctx_base(kv->get_base()->init_update(lctx, optimize)), |
266 | 0 | ctx_swa (kv->get_swa ()->init_update(lctx, optimize)), |
267 | 0 | status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { |
268 | 0 | } |
269 | | |
270 | | llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( |
271 | | llama_kv_cache_iswa * kv, |
272 | | slot_info_vec_t sinfos_base, |
273 | | slot_info_vec_t sinfos_swa, |
274 | | std::vector<llama_ubatch> ubatches) : |
275 | 0 | ubatches(std::move(ubatches)), |
276 | | // note: here we copy the ubatches. not sure if this is ideal |
277 | 0 | ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), |
278 | 0 | ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), |
279 | 0 | status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { |
280 | 0 | } |
281 | | |
282 | 0 | llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default; |
283 | | |
284 | 0 | bool llama_kv_cache_iswa_context::next() { |
285 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
286 | |
|
287 | 0 | ctx_base->next(); |
288 | 0 | ctx_swa ->next(); |
289 | |
|
290 | 0 | if (++i_next >= ubatches.size()) { |
291 | 0 | return false; |
292 | 0 | } |
293 | | |
294 | 0 | return true; |
295 | 0 | } |
296 | | |
297 | 0 | bool llama_kv_cache_iswa_context::apply() { |
298 | 0 | assert(!llama_memory_status_is_fail(status)); |
299 | |
|
300 | 0 | bool res = true; |
301 | |
|
302 | 0 | res = res & ctx_base->apply(); |
303 | 0 | res = res & ctx_swa ->apply(); |
304 | |
|
305 | 0 | return res; |
306 | 0 | } |
307 | | |
308 | 0 | llama_memory_status llama_kv_cache_iswa_context::get_status() const { |
309 | 0 | return status; |
310 | 0 | } |
311 | | |
312 | 0 | const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const { |
313 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
314 | |
|
315 | 0 | return ubatches[i_next]; |
316 | 0 | } |
317 | | |
318 | 0 | const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const { |
319 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
320 | |
|
321 | 0 | return static_cast<const llama_kv_cache_context *>(ctx_base.get()); |
322 | 0 | } |
323 | | |
324 | 0 | const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const { |
325 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
326 | |
|
327 | 0 | return static_cast<const llama_kv_cache_context *>(ctx_swa.get()); |
328 | 0 | } |