/src/llama.cpp/src/llama-kv-cells.h
Line | Count | Source |
#pragma once

#include "llama.h"
#include "llama-cparams.h"

#include <bitset>
#include <cassert>
#include <cstring>
#include <map>
#include <set>
#include <type_traits>
#include <vector>

13 | | struct llama_kv_cell_ext { |
14 | | // 2D spatial positions, typically used for M-RoPE |
15 | | llama_pos x = 0; |
16 | | llama_pos y = 0; |
17 | | |
18 | | // return true if the current 2D spatial position is greater than other |
19 | 0 | bool is_2d_gt(llama_pos ox, llama_pos oy) const { |
20 | 0 | return (y > oy) || (y == oy && x > ox); |
21 | 0 | } |
22 | | |
23 | 0 | void reset() { |
24 | 0 | static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>); |
25 | |
|
26 | 0 | memset(this, 0, sizeof(*this)); |
27 | 0 | } |
28 | | }; |
29 | | |
30 | | // meta information about KV cells that can be part of multiple sequences at the same time |
31 | | // TODO: add unit tests |
32 | | class llama_kv_cells { |
33 | | public: |
34 | 0 | void reset() { |
35 | 0 | for (uint32_t i = 0; i < pos.size(); ++i) { |
36 | 0 | pos[i] = -1; |
37 | 0 | ext[i].reset(); |
38 | 0 | shift[i] = 0; |
39 | 0 | seq[i].reset(); |
40 | 0 | } |
41 | |
|
42 | 0 | has_shift = false; |
43 | |
|
44 | 0 | used.clear(); |
45 | |
|
46 | 0 | for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { |
47 | 0 | seq_pos[s].clear(); |
48 | 0 | } |
49 | 0 | } |
50 | | |
51 | 0 | void reset_shift() { |
52 | 0 | has_shift = false; |
53 | |
|
54 | 0 | for (uint32_t i = 0; i < shift.size(); ++i) { |
55 | 0 | shift[i] = 0; |
56 | 0 | } |
57 | 0 | } |
58 | | |
59 | 0 | uint32_t size() const { |
60 | 0 | return pos.size(); |
61 | 0 | } |
62 | | |
63 | 0 | void resize(uint32_t n) { |
64 | 0 | pos.resize(n); |
65 | 0 | ext.resize(n); |
66 | 0 | shift.resize(n); |
67 | 0 | seq.resize(n); |
68 | |
|
69 | 0 | reset(); |
70 | 0 | } |
71 | | |
72 | 0 | bool is_empty(uint32_t i) const { |
73 | 0 | assert(i < pos.size()); |
74 | 0 | assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); |
75 | |
|
76 | 0 | return pos[i] == -1; |
77 | 0 | } |
78 | | |
79 | 0 | uint32_t get_used() const { |
80 | 0 | return used.size(); |
81 | 0 | } |
82 | | |
83 | | // the index of the first cell that is used |
84 | | // return 0 if no cells are used |
85 | 0 | uint32_t used_min() const { |
86 | 0 | return used.empty() ? 0 : *used.begin(); |
87 | 0 | } |
88 | | |
89 | | // the index of the last cell that is used + 1 |
90 | | // return 0 if no cells are used |
91 | 0 | uint32_t used_max_p1() const { |
92 | 0 | return used.empty() ? 0 : *used.rbegin() + 1; |
93 | 0 | } |
94 | | |
95 | 0 | bool get_has_shift() const { |
96 | 0 | return has_shift; |
97 | 0 | } |
98 | | |
99 | | // move cell isrc to idst (used during defrag) |
100 | | //void mv(uint32_t isrc, uint32_t idst) { |
101 | | // assert(isrc < pos.size()); |
102 | | // assert(idst < pos.size()); |
103 | | |
104 | | // assert(pos[idst] == -1); |
105 | | // assert(pos[isrc] != -1); |
106 | | |
107 | | // pos [idst] = pos [isrc]; |
108 | | // shift[idst] = shift[isrc]; |
109 | | // seq [idst] = seq [isrc]; |
110 | | |
111 | | // pos [isrc] = -1; |
112 | | // shift[isrc] = 0; |
113 | | // seq [isrc].reset(); |
114 | | |
115 | | // used.erase (isrc); |
116 | | // used.insert(idst); |
117 | | //} |
118 | | |
119 | | // copy the state of cells [i, i + n) (used for save/restore the state of the cells) |
120 | 0 | llama_kv_cells cp(uint32_t i, uint32_t n) const { |
121 | 0 | assert(i + n <= pos.size()); |
122 | 0 |
|
123 | 0 | llama_kv_cells res; |
124 | 0 |
|
125 | 0 | res.resize(n); |
126 | 0 |
|
127 | 0 | for (uint32_t j = 0; j < n; ++j) { |
128 | 0 | const auto idx = i + j; |
129 | 0 |
|
130 | 0 | res.pos[j] = pos[idx]; |
131 | 0 | res.ext[j] = ext[idx]; |
132 | 0 | res.seq[j] = seq[idx]; |
133 | 0 |
|
134 | 0 | assert(shift[idx] == 0); |
135 | 0 | } |
136 | 0 |
|
137 | 0 | return res; |
138 | 0 | } |
139 | | |
140 | | // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) |
141 | 0 | llama_kv_cells cp(const std::vector<uint32_t> & idxs) const { |
142 | 0 | llama_kv_cells res; |
143 | |
|
144 | 0 | res.resize(idxs.size()); |
145 | |
|
146 | 0 | for (uint32_t j = 0; j < idxs.size(); ++j) { |
147 | 0 | const auto idx = idxs[j]; |
148 | |
|
149 | 0 | res.pos[j] = pos[idx]; |
150 | 0 | res.ext[j] = ext[idx]; |
151 | 0 | res.seq[j] = seq[idx]; |
152 | |
|
153 | 0 | assert(shift[idx] == 0); |
154 | 0 | } |
155 | |
|
156 | 0 | return res; |
157 | 0 | } |
158 | | |
159 | | // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) |
160 | 0 | void set(uint32_t i, const llama_kv_cells & other) { |
161 | 0 | assert(i + other.pos.size() <= pos.size()); |
162 | 0 |
|
163 | 0 | for (uint32_t j = 0; j < other.pos.size(); ++j) { |
164 | 0 | const auto idx = i + j; |
165 | 0 |
|
166 | 0 | if (pos[idx] == -1 && other.pos[j] != -1) { |
167 | 0 | used.insert(i + j); |
168 | 0 | } |
169 | 0 |
|
170 | 0 | if (pos[idx] != -1 && other.pos[j] == -1) { |
171 | 0 | used.erase(i + j); |
172 | 0 | } |
173 | 0 |
|
174 | 0 | if (pos[idx] != -1) { |
175 | 0 | seq_pos_rm(i + j); |
176 | 0 | } |
177 | 0 |
|
178 | 0 | pos[idx] = other.pos[j]; |
179 | 0 | ext[idx] = other.ext[j]; |
180 | 0 | seq[idx] = other.seq[j]; |
181 | 0 |
|
182 | 0 | if (pos[idx] != -1) { |
183 | 0 | seq_pos_add(i + j); |
184 | 0 | } |
185 | 0 |
|
186 | 0 | assert(shift[idx] == 0); |
187 | 0 | } |
188 | 0 | } |
189 | | |
190 | | // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) |
191 | 0 | void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) { |
192 | 0 | assert(idxs.size() == other.pos.size()); |
193 | |
|
194 | 0 | for (uint32_t j = 0; j < other.pos.size(); ++j) { |
195 | 0 | const auto idx = idxs[j]; |
196 | |
|
197 | 0 | if (pos[idx] == -1 && other.pos[j] != -1) { |
198 | 0 | used.insert(idx); |
199 | 0 | } |
200 | |
|
201 | 0 | if (pos[idx] != -1 && other.pos[j] == -1) { |
202 | 0 | used.erase(idx); |
203 | 0 | } |
204 | |
|
205 | 0 | if (pos[idx] != -1) { |
206 | 0 | seq_pos_rm(idx); |
207 | 0 | } |
208 | |
|
209 | 0 | pos[idx] = other.pos[j]; |
210 | 0 | ext[idx] = other.ext[j]; |
211 | 0 | seq[idx] = other.seq[j]; |
212 | |
|
213 | 0 | if (pos[idx] != -1) { |
214 | 0 | seq_pos_add(idx); |
215 | 0 | } |
216 | |
|
217 | 0 | assert(shift[idx] == 0); |
218 | 0 | } |
219 | 0 | } |
220 | | |
221 | | // clear a non-empty cell |
222 | 0 | void rm(uint32_t i) { |
223 | 0 | assert(i < pos.size()); |
224 | 0 | assert(pos[i] != -1); |
225 | |
|
226 | 0 | seq_pos_rm(i); |
227 | 0 | seq[i].reset(); |
228 | |
|
229 | 0 | pos[i] = -1; |
230 | 0 | ext[i].reset(); |
231 | 0 | shift[i] = 0; |
232 | |
|
233 | 0 | used.erase(i); |
234 | 0 | } |
235 | | |
236 | | // note: call only if the cell has seq_id |
237 | | // return true if the cell becomes empty |
238 | 0 | bool seq_rm(uint32_t i, llama_seq_id seq_id) { |
239 | 0 | assert(i < pos.size()); |
240 | 0 | assert(seq[i].test(seq_id)); |
241 | 0 | assert(pos[i] != -1); |
242 | 0 | assert(seq_id >= 0); |
243 | |
|
244 | 0 | seq[i].reset(seq_id); |
245 | 0 | seq_pos_dec(seq_id, pos[i]); |
246 | |
|
247 | 0 | if (seq[i].none()) { |
248 | 0 | pos[i] = -1; |
249 | 0 | ext[i].reset(); |
250 | 0 | shift[i] = 0; |
251 | |
|
252 | 0 | used.erase(i); |
253 | |
|
254 | 0 | return true; |
255 | 0 | } |
256 | | |
257 | 0 | return false; |
258 | 0 | } |
259 | | |
260 | | // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) |
261 | 0 | bool seq_keep(uint32_t i, llama_seq_id seq_id) { |
262 | 0 | assert(i < pos.size()); |
263 | |
|
264 | 0 | if (seq[i].test(seq_id)) { |
265 | 0 | seq_pos_rm(i); |
266 | 0 | seq[i].reset(); |
267 | |
|
268 | 0 | seq[i].set(seq_id); |
269 | 0 | seq_pos_inc(seq_id, pos[i]); |
270 | |
|
271 | 0 | return false; |
272 | 0 | } |
273 | | |
274 | 0 | if (seq[i].any()) { |
275 | 0 | seq_pos_rm(i); |
276 | 0 | seq[i].reset(); |
277 | |
|
278 | 0 | pos[i] = -1; |
279 | 0 | ext[i].reset(); |
280 | 0 | shift[i] = 0; |
281 | |
|
282 | 0 | used.erase(i); |
283 | |
|
284 | 0 | return true; |
285 | 0 | } |
286 | | |
287 | 0 | assert(pos[i] == -1); |
288 | |
|
289 | 0 | return false; |
290 | 0 | } |
291 | | |
292 | | // number of different sequences in the cell |
293 | 0 | int seq_count(uint32_t i) const { |
294 | 0 | assert(i < pos.size()); |
295 | 0 | assert(pos[i] != -1); |
296 | |
|
297 | 0 | return seq[i].count(); |
298 | 0 | } |
299 | | |
300 | | // check if the cell contains seq_id |
301 | 0 | bool seq_has(uint32_t i, llama_seq_id seq_id) const { |
302 | 0 | assert(i < pos.size()); |
303 | 0 | assert(seq_id >= 0); |
304 | |
|
305 | 0 | return seq[i].test(seq_id); |
306 | 0 | } |
307 | | |
308 | | // note: call only if the cell is not empty and the seq_id is not in the cell |
309 | 0 | void seq_add(uint32_t i, llama_seq_id seq_id) { |
310 | 0 | assert(i < pos.size()); |
311 | 0 | assert(pos[i] != -1); |
312 | 0 | assert(!seq[i].test(seq_id)); |
313 | |
|
314 | 0 | seq[i].set(seq_id); |
315 | 0 | seq_pos_inc(seq_id, pos[i]); |
316 | 0 | } |
317 | | |
318 | | // return the sequence id of this cell |
319 | | // note: call only for cells with exactly one sequence |
320 | 0 | llama_seq_id seq_get(uint32_t i) const { |
321 | 0 | assert(seq[i].count() == 1); |
322 | |
|
323 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
324 | 0 | if (seq[i].test(s)) { |
325 | 0 | return s; |
326 | 0 | } |
327 | 0 | } |
328 | | |
329 | 0 | return -1; |
330 | 0 | } |
331 | | |
332 | | // the minimum position of sequence seq_id currently present in any of the cells |
333 | | // return -1 if the sequence is not present |
334 | 0 | llama_pos seq_pos_min(llama_seq_id seq_id) const { |
335 | 0 | assert(seq_id >= 0); |
336 | 0 | assert(seq_id < LLAMA_MAX_SEQ); |
337 | |
|
338 | 0 | if (seq_pos[seq_id].empty()) { |
339 | 0 | return -1; |
340 | 0 | } |
341 | | |
342 | 0 | assert(seq_pos[seq_id].begin()->second > 0); |
343 | |
|
344 | 0 | return seq_pos[seq_id].begin()->first; |
345 | 0 | } |
346 | | |
347 | | // the maximum position of sequence seq_id currently present in any of the cells |
348 | | // return -1 if the sequence is not present |
349 | 0 | llama_pos seq_pos_max(llama_seq_id seq_id) const { |
350 | 0 | assert(seq_id >= 0); |
351 | 0 | assert(seq_id < LLAMA_MAX_SEQ); |
352 | |
|
353 | 0 | if (seq_pos[seq_id].empty()) { |
354 | 0 | return -1; |
355 | 0 | } |
356 | | |
357 | 0 | assert(seq_pos[seq_id].rbegin()->second > 0); |
358 | |
|
359 | 0 | return seq_pos[seq_id].rbegin()->first; |
360 | 0 | } |
361 | | |
362 | | // note: call only if the cell is not empty |
363 | 0 | llama_pos pos_get(uint32_t i) const { |
364 | 0 | assert(i < pos.size()); |
365 | 0 | assert(pos[i] != -1); |
366 | |
|
367 | 0 | return pos[i]; |
368 | 0 | } |
369 | | |
370 | 0 | const llama_kv_cell_ext & ext_get(uint32_t i) const { |
371 | 0 | assert(i < pos.size()); |
372 | 0 | assert(pos[i] != -1); |
373 | |
|
374 | 0 | return ext[i]; |
375 | 0 | } |
376 | | |
377 | | // note: call only if the cell is not empty |
378 | 0 | llama_pos get_shift(uint32_t i) const { |
379 | 0 | assert(i < pos.size()); |
380 | 0 | assert(pos[i] != -1); |
381 | |
|
382 | 0 | return shift[i]; |
383 | 0 | } |
384 | | |
385 | | // check if a cell is not empty and its position is within [p0, p1) |
386 | 0 | bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { |
387 | 0 | assert(i < pos.size()); |
388 | |
|
389 | 0 | return pos[i] >= p0 && pos[i] < p1; |
390 | 0 | } |
391 | | |
392 | | // set the position of an empty cell |
393 | | // does not modify "has_shift" |
394 | | // note: call only if the cell is empty |
395 | 0 | void pos_set(uint32_t i, llama_pos p) { |
396 | 0 | assert(i < pos.size()); |
397 | 0 | assert(pos[i] == -1); |
398 | 0 | assert(seq[i].none()); |
399 | |
|
400 | 0 | pos[i] = p; |
401 | |
|
402 | 0 | used.insert(i); |
403 | 0 | } |
404 | | |
405 | 0 | void ext_set(uint32_t i, llama_kv_cell_ext p) { |
406 | 0 | assert(i < ext.size()); |
407 | 0 | ext[i] = p; |
408 | 0 | } |
409 | | |
410 | | // pos[i] = pos[i] + d |
411 | | // sets "has_shift" to true |
412 | | // note: call only if the cell is not empty |
413 | 0 | bool pos_add(uint32_t i, llama_pos d) { |
414 | 0 | assert(i < pos.size()); |
415 | 0 | assert(pos[i] != -1); |
416 | |
|
417 | 0 | seq_pos_rm(i); |
418 | |
|
419 | 0 | pos[i] += d; |
420 | 0 | shift[i] += d; |
421 | |
|
422 | 0 | has_shift = true; |
423 | |
|
424 | 0 | if (pos[i] < 0) { |
425 | 0 | seq[i].reset(); |
426 | 0 | pos[i] = -1; |
427 | 0 | shift[i] = 0; |
428 | |
|
429 | 0 | used.erase(i); |
430 | |
|
431 | 0 | return true; |
432 | 0 | } |
433 | | |
434 | 0 | seq_pos_add(i); |
435 | |
|
436 | 0 | return false; |
437 | 0 | } |
438 | | |
439 | | // pos[i] = pos[i] / d |
440 | | // sets "has_shift" to true |
441 | | // note: call only if the cell is not empty |
442 | 0 | void pos_div(uint32_t i, int d) { |
443 | 0 | assert(i < pos.size()); |
444 | 0 | assert(pos[i] != -1); |
445 | |
|
446 | 0 | const llama_pos p_old = pos[i]; |
447 | |
|
448 | 0 | seq_pos_rm(i); |
449 | |
|
450 | 0 | pos[i] /= d; |
451 | 0 | shift[i] += p_old - pos[i]; |
452 | |
|
453 | 0 | seq_pos_add(i); |
454 | |
|
455 | 0 | has_shift = true; |
456 | 0 | } |
457 | | |
458 | | private: |
459 | | bool has_shift = false; |
460 | | |
461 | | // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id) |
462 | | std::set<uint32_t> used; |
463 | | |
464 | | std::vector<llama_pos> pos; |
465 | | |
466 | | // stores extra info per cell |
467 | | std::vector<llama_kv_cell_ext> ext; |
468 | | |
469 | | // this array accumulates any applied shifts to the pos array since the last reset_shift() call |
470 | | // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: |
471 | | // |
472 | | // cells.pos_add(x, shift_x); |
473 | | // cells.pos_div(y, shift_y); |
474 | | // ... |
475 | | // |
476 | | // if (cells.has_shift()) { |
477 | | // for (int i = 0; i < n; ++i) { |
478 | | // auto shift_i = cells.get_shift(i); |
479 | | // ... |
480 | | // } |
481 | | // cells.reset_shift(); |
482 | | // } |
483 | | // |
484 | | std::vector<llama_pos> shift; |
485 | | |
486 | | using seq_set_t = std::bitset<LLAMA_MAX_SEQ>; |
487 | | |
488 | | // the bitset seq[i] tells us which sequences are currently occupying the i-th cell |
489 | | std::vector<seq_set_t> seq; |
490 | | |
491 | | // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s |
492 | | // if the position p is not present, seq_pos[s][p] is not set |
493 | | // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache |
494 | | // |
495 | | // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq: |
496 | | // - during performing a cache reuse via (rm + add) |
497 | | // - some vision models have input embeddings with repeating positions |
498 | | // |
499 | | std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ]; |
500 | | |
501 | | // helper functions for updating `seq_pos`, once cell at a time: |
502 | | |
503 | 0 | void seq_pos_dec(llama_seq_id s, llama_pos p) { |
504 | 0 | auto it = seq_pos[s].find(p); |
505 | 0 | assert(it != seq_pos[s].end()); |
506 | |
|
507 | 0 | if (--it->second == 0) { |
508 | 0 | seq_pos[s].erase(it); |
509 | 0 | } |
510 | 0 | } |
511 | | |
512 | 0 | void seq_pos_inc(llama_seq_id s, llama_pos p) { |
513 | 0 | seq_pos[s][p]++; |
514 | 0 | } |
515 | | |
516 | | // remove cell i |
517 | 0 | void seq_pos_rm(uint32_t i) { |
518 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
519 | 0 | if (seq[i].test(s)) { |
520 | 0 | seq_pos_dec(s, pos[i]); |
521 | 0 | } |
522 | 0 | } |
523 | 0 | } |
524 | | |
525 | | // add cell i |
526 | 0 | void seq_pos_add(uint32_t i) { |
527 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
528 | 0 | if (seq[i].test(s)) { |
529 | 0 | seq_pos_inc(s, pos[i]); |
530 | 0 | } |
531 | 0 | } |
532 | 0 | } |
533 | | }; |