/src/llama.cpp/src/llama-kv-cache-dsa.cpp
Line | Count | Source |
1 | | #include "llama-kv-cache-dsa.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-batch.h" |
5 | | #include "llama-model.h" |
6 | | |
7 | | #include <algorithm> |
8 | | #include <cassert> |
9 | | |
10 | | // |
11 | | // llama_kv_cache_dsa |
12 | | // |
13 | | |
14 | | llama_kv_cache_dsa::llama_kv_cache_dsa( |
15 | | const llama_model & model, |
16 | | ggml_type type_k, |
17 | | ggml_type type_v, |
18 | | bool v_trans, |
19 | | bool offload, |
20 | | bool unified, |
21 | | uint32_t kv_size, |
22 | | uint32_t n_seq_max, |
23 | | uint32_t n_pad, |
24 | | uint32_t n_swa, |
25 | | llama_swa_type swa_type, |
26 | | const layer_filter_cb & filter, |
27 | | const layer_reuse_cb & reuse) : |
28 | 0 | hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) { |
29 | |
|
30 | 0 | LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); |
31 | |
|
32 | 0 | kv_mla = std::make_unique<llama_kv_cache>( |
33 | 0 | model, model.hparams, type_k, type_v, |
34 | 0 | v_trans, offload, unified, kv_size, n_seq_max, n_pad, |
35 | 0 | n_swa, swa_type, nullptr, filter, reuse, nullptr); |
36 | | |
37 | | // we use llama_kv_cache for caching indexer keys |
38 | | // by hand-tweaking some hparams we fool it to create |
39 | | // indexer key cache tensors with correct dimensions |
40 | | // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 |
41 | | |
42 | | // DSA lightning indexer uses MQA with single key head |
43 | 0 | std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); |
44 | 0 | hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; |
45 | 0 | hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; |
46 | |
|
47 | 0 | LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); |
48 | |
|
49 | 0 | kv_lid = std::make_unique<llama_kv_cache>( |
50 | 0 | model, hparams_lid, type_k, type_v, |
51 | 0 | v_trans, offload, unified, kv_size, n_seq_max, n_pad, |
52 | 0 | n_swa, swa_type, nullptr, filter, reuse, nullptr); |
53 | 0 | } |
54 | | |
55 | 0 | void llama_kv_cache_dsa::clear(bool data) { |
56 | 0 | kv_mla->clear(data); |
57 | 0 | kv_lid->clear(data); |
58 | 0 | } |
59 | | |
60 | 0 | bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { |
61 | 0 | bool res = true; |
62 | |
|
63 | 0 | res = res & kv_mla->seq_rm(seq_id, p0, p1); |
64 | 0 | res = res & kv_lid->seq_rm(seq_id, p0, p1); |
65 | |
|
66 | 0 | return res; |
67 | 0 | } |
68 | | |
69 | 0 | void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { |
70 | 0 | kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
71 | 0 | kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
72 | 0 | } |
73 | | |
74 | 0 | void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { |
75 | 0 | kv_mla->seq_keep(seq_id); |
76 | 0 | kv_lid->seq_keep(seq_id); |
77 | 0 | } |
78 | | |
79 | 0 | void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { |
80 | 0 | kv_mla->seq_add(seq_id, p0, p1, shift); |
81 | 0 | kv_lid->seq_add(seq_id, p0, p1, shift); |
82 | 0 | } |
83 | | |
84 | 0 | void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { |
85 | 0 | kv_mla->seq_div(seq_id, p0, p1, d); |
86 | 0 | kv_lid->seq_div(seq_id, p0, p1, d); |
87 | 0 | } |
88 | | |
89 | 0 | llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { |
90 | 0 | return kv_mla->seq_pos_min(seq_id); |
91 | 0 | } |
92 | | |
93 | 0 | llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { |
94 | 0 | return kv_mla->seq_pos_max(seq_id); |
95 | 0 | } |
96 | | |
97 | 0 | std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const { |
98 | 0 | std::map<ggml_backend_buffer_type_t, size_t> mb = kv_mla->memory_breakdown(); |
99 | 0 | for (const auto & buft_size : kv_lid->memory_breakdown()) { |
100 | 0 | mb[buft_size.first] += buft_size.second; |
101 | 0 | } |
102 | 0 | return mb; |
103 | 0 | } |
104 | | |
105 | | llama_memory_context_ptr llama_kv_cache_dsa::init_batch( |
106 | | llama_batch_allocr & balloc, |
107 | | uint32_t n_ubatch, |
108 | 0 | bool embd_all) { |
109 | 0 | GGML_UNUSED(embd_all); |
110 | |
|
111 | 0 | do { |
112 | 0 | balloc.split_reset(); |
113 | |
|
114 | 0 | std::vector<llama_ubatch> ubatches; |
115 | 0 | while (true) { |
116 | 0 | auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); |
117 | |
|
118 | 0 | if (ubatch.n_tokens == 0) { |
119 | 0 | break; |
120 | 0 | } |
121 | | |
122 | 0 | ubatches.push_back(std::move(ubatch)); // NOLINT |
123 | 0 | } |
124 | |
|
125 | 0 | if (balloc.get_n_used() < balloc.get_n_tokens()) { |
126 | | // failed to find a suitable split |
127 | 0 | break; |
128 | 0 | } |
129 | | |
130 | 0 | auto sinfos_mla = kv_mla->prepare(ubatches); |
131 | 0 | if (sinfos_mla.empty()) { |
132 | 0 | break; |
133 | 0 | } |
134 | | |
135 | 0 | auto sinfos_lid = kv_lid->prepare(ubatches); |
136 | 0 | if (sinfos_lid.empty()) { |
137 | 0 | break; |
138 | 0 | } |
139 | | |
140 | 0 | assert(sinfos_mla.size() == sinfos_lid.size()); |
141 | |
|
142 | 0 | return std::make_unique<llama_kv_cache_dsa_context>( |
143 | 0 | this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); |
144 | 0 | } while (false); |
145 | | |
146 | 0 | return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); |
147 | 0 | } |
148 | | |
149 | 0 | llama_memory_context_ptr llama_kv_cache_dsa::init_full() { |
150 | 0 | return std::make_unique<llama_kv_cache_dsa_context>(this); |
151 | 0 | } |
152 | | |
153 | 0 | llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { |
154 | 0 | return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize); |
155 | 0 | } |
156 | | |
157 | 0 | bool llama_kv_cache_dsa::get_can_shift() const { |
158 | 0 | return kv_mla->get_can_shift() && |
159 | 0 | kv_lid->get_can_shift() && |
160 | 0 | kv_mla->get_size() == kv_lid->get_size(); |
161 | 0 | } |
162 | | |
163 | 0 | void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { |
164 | 0 | kv_mla->state_write(io, seq_id, flags); |
165 | 0 | kv_lid->state_write(io, seq_id, flags); |
166 | 0 | } |
167 | | |
168 | 0 | void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
169 | 0 | kv_mla->state_read(io, seq_id, flags); |
170 | 0 | kv_lid->state_read(io, seq_id, flags); |
171 | 0 | } |
172 | | |
173 | 0 | llama_kv_cache * llama_kv_cache_dsa::get_mla() const { |
174 | 0 | return kv_mla.get(); |
175 | 0 | } |
176 | | |
177 | 0 | llama_kv_cache * llama_kv_cache_dsa::get_lid() const { |
178 | 0 | return kv_lid.get(); |
179 | 0 | } |
180 | | |
181 | | // |
182 | | // llama_kv_cache_dsa_context |
183 | | // |
184 | | |
185 | 0 | llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} |
186 | | |
187 | | llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( |
188 | | llama_kv_cache_dsa * kv) : |
189 | 0 | ctx_mla(kv->get_mla()->init_full()), |
190 | 0 | ctx_lid(kv->get_lid()->init_full()), |
191 | 0 | status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { |
192 | 0 | } |
193 | | |
194 | | llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( |
195 | | llama_kv_cache_dsa * kv, |
196 | | llama_context * lctx, |
197 | | bool optimize) : |
198 | 0 | ctx_mla(kv->get_mla()->init_update(lctx, optimize)), |
199 | 0 | ctx_lid(kv->get_lid()->init_update(lctx, optimize)), |
200 | 0 | status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { |
201 | 0 | } |
202 | | |
203 | | llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( |
204 | | llama_kv_cache_dsa * kv, |
205 | | slot_info_vec_t sinfos_mla, |
206 | | slot_info_vec_t sinfos_lid, |
207 | | std::vector<llama_ubatch> ubatches) : |
208 | 0 | ubatches(std::move(ubatches)), |
209 | | // note: here we copy the ubatches. not sure if this is ideal |
210 | 0 | ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), |
211 | 0 | ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), |
212 | 0 | status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { |
213 | 0 | } |
214 | | |
215 | 0 | llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; |
216 | | |
217 | 0 | bool llama_kv_cache_dsa_context::next() { |
218 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
219 | |
|
220 | 0 | ctx_mla->next(); |
221 | 0 | ctx_lid->next(); |
222 | |
|
223 | 0 | if (++i_next >= ubatches.size()) { |
224 | 0 | return false; |
225 | 0 | } |
226 | | |
227 | 0 | return true; |
228 | 0 | } |
229 | | |
230 | 0 | bool llama_kv_cache_dsa_context::apply() { |
231 | 0 | assert(!llama_memory_status_is_fail(status)); |
232 | |
|
233 | 0 | bool res = true; |
234 | |
|
235 | 0 | res = res & ctx_mla->apply(); |
236 | 0 | res = res & ctx_lid->apply(); |
237 | |
|
238 | 0 | return res; |
239 | 0 | } |
240 | | |
241 | 0 | llama_memory_status llama_kv_cache_dsa_context::get_status() const { |
242 | 0 | return status; |
243 | 0 | } |
244 | | |
245 | 0 | const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { |
246 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
247 | |
|
248 | 0 | return ubatches[i_next]; |
249 | 0 | } |
250 | | |
251 | 0 | const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { |
252 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
253 | |
|
254 | 0 | return static_cast<const llama_kv_cache_context *>(ctx_mla.get()); |
255 | 0 | } |
256 | | |
257 | 0 | const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { |
258 | 0 | assert(status == LLAMA_MEMORY_STATUS_SUCCESS); |
259 | |
|
260 | 0 | return static_cast<const llama_kv_cache_context *>(ctx_lid.get()); |
261 | 0 | } |