Coverage Report

Created: 2025-11-28 06:56

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-kv-cells.h
Line
Count
Source
1
#pragma once
2
3
#include "llama.h"
4
#include "llama-cparams.h"
5
6
#include <bitset>
7
#include <cassert>
8
#include <cstring>
9
#include <map>
10
#include <set>
11
#include <vector>
12
13
struct llama_kv_cell_ext {
14
    // 2D spatial positions, typically used for M-RoPE
15
    llama_pos x = 0;
16
    llama_pos y = 0;
17
18
    // return true if the current 2D spatial position is greater than other
19
0
    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
20
0
        return (y > oy) || (y == oy && x > ox);
21
0
    }
22
23
0
    void reset() {
24
0
        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
25
26
0
        memset(this, 0, sizeof(*this));
27
0
    }
28
};
29
30
// meta information about KV cells that can be part of multiple sequences at the same time
31
// TODO: add unit tests
32
class llama_kv_cells {
33
public:
34
0
    void reset() {
35
0
        for (uint32_t i = 0; i < pos.size(); ++i) {
36
0
            pos[i]   = -1;
37
0
            ext[i].reset();
38
0
            shift[i] =  0;
39
0
            seq[i].reset();
40
0
        }
41
42
0
        has_shift = false;
43
44
0
        used.clear();
45
46
0
        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
47
0
            seq_pos[s].clear();
48
0
        }
49
0
    }
50
51
0
    void reset_shift() {
52
0
        has_shift = false;
53
54
0
        for (uint32_t i = 0; i < shift.size(); ++i) {
55
0
            shift[i] = 0;
56
0
        }
57
0
    }
58
59
0
    uint32_t size() const {
60
0
        return pos.size();
61
0
    }
62
63
0
    void resize(uint32_t n) {
64
0
        pos.resize(n);
65
0
        ext.resize(n);
66
0
        shift.resize(n);
67
0
        seq.resize(n);
68
69
0
        reset();
70
0
    }
71
72
0
    bool is_empty(uint32_t i) const {
73
0
        assert(i < pos.size());
74
0
        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
75
76
0
        return pos[i] == -1;
77
0
    }
78
79
0
    uint32_t get_used() const {
80
0
        return used.size();
81
0
    }
82
83
    // the index of the first cell that is used
84
    // return 0 if no cells are used
85
0
    uint32_t used_min() const {
86
0
        return used.empty() ? 0 : *used.begin();
87
0
    }
88
89
    // the index of the last cell that is used + 1
90
    // return 0 if no cells are used
91
0
    uint32_t used_max_p1() const {
92
0
        return used.empty() ? 0 : *used.rbegin() + 1;
93
0
    }
94
95
0
    bool get_has_shift() const {
96
0
        return has_shift;
97
0
    }
98
99
    // move cell isrc to idst (used during defrag)
100
    //void mv(uint32_t isrc, uint32_t idst) {
101
    //    assert(isrc < pos.size());
102
    //    assert(idst < pos.size());
103
104
    //    assert(pos[idst] == -1);
105
    //    assert(pos[isrc] != -1);
106
107
    //    pos  [idst] = pos  [isrc];
108
    //    shift[idst] = shift[isrc];
109
    //    seq  [idst] = seq  [isrc];
110
111
    //    pos  [isrc] = -1;
112
    //    shift[isrc] =  0;
113
    //    seq  [isrc].reset();
114
115
    //    used.erase (isrc);
116
    //    used.insert(idst);
117
    //}
118
119
    // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
120
0
    llama_kv_cells cp(uint32_t i, uint32_t n) const {
121
0
        assert(i + n <= pos.size());
122
0
123
0
        llama_kv_cells res;
124
0
125
0
        res.resize(n);
126
0
127
0
        for (uint32_t j = 0; j < n; ++j) {
128
0
            const auto idx = i + j;
129
0
130
0
            res.pos[j] = pos[idx];
131
0
            res.ext[j] = ext[idx];
132
0
            res.seq[j] = seq[idx];
133
0
134
0
            assert(shift[idx] == 0);
135
0
        }
136
0
137
0
        return res;
138
0
    }
139
140
    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
141
0
    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
142
0
        llama_kv_cells res;
143
144
0
        res.resize(idxs.size());
145
146
0
        for (uint32_t j = 0; j < idxs.size(); ++j) {
147
0
            const auto idx = idxs[j];
148
149
0
            res.pos[j] = pos[idx];
150
0
            res.ext[j] = ext[idx];
151
0
            res.seq[j] = seq[idx];
152
153
0
            assert(shift[idx] == 0);
154
0
        }
155
156
0
        return res;
157
0
    }
158
159
    // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
160
0
    void set(uint32_t i, const llama_kv_cells & other) {
161
0
        assert(i + other.pos.size() <= pos.size());
162
0
163
0
        for (uint32_t j = 0; j < other.pos.size(); ++j) {
164
0
            const auto idx = i + j;
165
0
166
0
            if (pos[idx] == -1 && other.pos[j] != -1) {
167
0
                used.insert(i + j);
168
0
            }
169
0
170
0
            if (pos[idx] != -1 && other.pos[j] == -1) {
171
0
                used.erase(i + j);
172
0
            }
173
0
174
0
            if (pos[idx] != -1) {
175
0
                seq_pos_rm(i + j);
176
0
            }
177
0
178
0
            pos[idx] = other.pos[j];
179
0
            ext[idx] = other.ext[j];
180
0
            seq[idx] = other.seq[j];
181
0
182
0
            if (pos[idx] != -1) {
183
0
                seq_pos_add(i + j);
184
0
            }
185
0
186
0
            assert(shift[idx] == 0);
187
0
        }
188
0
    }
189
190
    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
191
0
    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
192
0
        assert(idxs.size() == other.pos.size());
193
194
0
        for (uint32_t j = 0; j < other.pos.size(); ++j) {
195
0
            const auto idx = idxs[j];
196
197
0
            if (pos[idx] == -1 && other.pos[j] != -1) {
198
0
                used.insert(idx);
199
0
            }
200
201
0
            if (pos[idx] != -1 && other.pos[j] == -1) {
202
0
                used.erase(idx);
203
0
            }
204
205
0
            if (pos[idx] != -1) {
206
0
                seq_pos_rm(idx);
207
0
            }
208
209
0
            pos[idx] = other.pos[j];
210
0
            ext[idx] = other.ext[j];
211
0
            seq[idx] = other.seq[j];
212
213
0
            if (pos[idx] != -1) {
214
0
                seq_pos_add(idx);
215
0
            }
216
217
0
            assert(shift[idx] == 0);
218
0
        }
219
0
    }
220
221
    // clear a non-empty cell
222
0
    void rm(uint32_t i) {
223
0
        assert(i < pos.size());
224
0
        assert(pos[i] != -1);
225
226
0
        seq_pos_rm(i);
227
0
        seq[i].reset();
228
229
0
        pos[i] = -1;
230
0
        ext[i].reset();
231
0
        shift[i] = 0;
232
233
0
        used.erase(i);
234
0
    }
235
236
    // note: call only if the cell has seq_id
237
    // return true if the cell becomes empty
238
0
    bool seq_rm(uint32_t i, llama_seq_id seq_id) {
239
0
        assert(i < pos.size());
240
0
        assert(seq[i].test(seq_id));
241
0
        assert(pos[i] != -1);
242
0
        assert(seq_id >= 0);
243
244
0
        seq[i].reset(seq_id);
245
0
        seq_pos_dec(seq_id, pos[i]);
246
247
0
        if (seq[i].none()) {
248
0
            pos[i] = -1;
249
0
            ext[i].reset();
250
0
            shift[i] = 0;
251
252
0
            used.erase(i);
253
254
0
            return true;
255
0
        }
256
257
0
        return false;
258
0
    }
259
260
    // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
261
0
    bool seq_keep(uint32_t i, llama_seq_id seq_id) {
262
0
        assert(i < pos.size());
263
264
0
        if (seq[i].test(seq_id)) {
265
0
            seq_pos_rm(i);
266
0
            seq[i].reset();
267
268
0
            seq[i].set(seq_id);
269
0
            seq_pos_inc(seq_id, pos[i]);
270
271
0
            return false;
272
0
        }
273
274
0
        if (seq[i].any()) {
275
0
            seq_pos_rm(i);
276
0
            seq[i].reset();
277
278
0
            pos[i] = -1;
279
0
            ext[i].reset();
280
0
            shift[i] = 0;
281
282
0
            used.erase(i);
283
284
0
            return true;
285
0
        }
286
287
0
        assert(pos[i] == -1);
288
289
0
        return false;
290
0
    }
291
292
    // number of different sequences in the cell
293
0
    int seq_count(uint32_t i) const {
294
0
        assert(i < pos.size());
295
0
        assert(pos[i] != -1);
296
297
0
        return seq[i].count();
298
0
    }
299
300
    // check if the cell contains seq_id
301
0
    bool seq_has(uint32_t i, llama_seq_id seq_id) const {
302
0
        assert(i < pos.size());
303
0
        assert(seq_id >= 0);
304
305
0
        return seq[i].test(seq_id);
306
0
    }
307
308
    // note: call only if the cell is not empty and the seq_id is not in the cell
309
0
    void seq_add(uint32_t i, llama_seq_id seq_id) {
310
0
        assert(i < pos.size());
311
0
        assert(pos[i] != -1);
312
0
        assert(!seq[i].test(seq_id));
313
314
0
        seq[i].set(seq_id);
315
0
        seq_pos_inc(seq_id, pos[i]);
316
0
    }
317
318
    // return the sequence id of this cell
319
    // note: call only for cells with exactly one sequence
320
0
    llama_seq_id seq_get(uint32_t i) const {
321
0
        assert(seq[i].count() == 1);
322
323
0
        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
324
0
            if (seq[i].test(s)) {
325
0
                return s;
326
0
            }
327
0
        }
328
329
0
        return -1;
330
0
    }
331
332
    // the minimum position of sequence seq_id currently present in any of the cells
333
    // return -1 if the sequence is not present
334
0
    llama_pos seq_pos_min(llama_seq_id seq_id) const {
335
0
        assert(seq_id >= 0);
336
0
        assert(seq_id < LLAMA_MAX_SEQ);
337
338
0
        if (seq_pos[seq_id].empty()) {
339
0
            return -1;
340
0
        }
341
342
0
        assert(seq_pos[seq_id].begin()->second > 0);
343
344
0
        return seq_pos[seq_id].begin()->first;
345
0
    }
346
347
    // the maximum position of sequence seq_id currently present in any of the cells
348
    // return -1 if the sequence is not present
349
0
    llama_pos seq_pos_max(llama_seq_id seq_id) const {
350
0
        assert(seq_id >= 0);
351
0
        assert(seq_id < LLAMA_MAX_SEQ);
352
353
0
        if (seq_pos[seq_id].empty()) {
354
0
            return -1;
355
0
        }
356
357
0
        assert(seq_pos[seq_id].rbegin()->second > 0);
358
359
0
        return seq_pos[seq_id].rbegin()->first;
360
0
    }
361
362
    // note: call only if the cell is not empty
363
0
    llama_pos pos_get(uint32_t i) const {
364
0
        assert(i < pos.size());
365
0
        assert(pos[i] != -1);
366
367
0
        return pos[i];
368
0
    }
369
370
0
    const llama_kv_cell_ext & ext_get(uint32_t i) const {
371
0
        assert(i < pos.size());
372
0
        assert(pos[i] != -1);
373
374
0
        return ext[i];
375
0
    }
376
377
    // note: call only if the cell is not empty
378
0
    llama_pos get_shift(uint32_t i) const {
379
0
        assert(i < pos.size());
380
0
        assert(pos[i] != -1);
381
382
0
        return shift[i];
383
0
    }
384
385
    // check if a cell is not empty and its position is within [p0, p1)
386
0
    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
387
0
        assert(i < pos.size());
388
389
0
        return pos[i] >= p0 && pos[i] < p1;
390
0
    }
391
392
    // set the position of an empty cell
393
    // does not modify "has_shift"
394
    // note: call only if the cell is empty
395
0
    void pos_set(uint32_t i, llama_pos p) {
396
0
        assert(i < pos.size());
397
0
        assert(pos[i] == -1);
398
0
        assert(seq[i].none());
399
400
0
        pos[i] = p;
401
402
0
        used.insert(i);
403
0
    }
404
405
0
    void ext_set(uint32_t i, llama_kv_cell_ext p) {
406
0
        assert(i < ext.size());
407
0
        ext[i] = p;
408
0
    }
409
410
    // pos[i] = pos[i] + d
411
    // sets "has_shift" to true
412
    // note: call only if the cell is not empty
413
0
    bool pos_add(uint32_t i, llama_pos d) {
414
0
        assert(i < pos.size());
415
0
        assert(pos[i] != -1);
416
417
0
        seq_pos_rm(i);
418
419
0
        pos[i]   += d;
420
0
        shift[i] += d;
421
422
0
        has_shift = true;
423
424
0
        if (pos[i] < 0) {
425
0
            seq[i].reset();
426
0
            pos[i] = -1;
427
0
            shift[i] = 0;
428
429
0
            used.erase(i);
430
431
0
            return true;
432
0
        }
433
434
0
        seq_pos_add(i);
435
436
0
        return false;
437
0
    }
438
439
    // pos[i] = pos[i] / d
440
    // sets "has_shift" to true
441
    // note: call only if the cell is not empty
442
0
    void pos_div(uint32_t i, int d) {
443
0
        assert(i < pos.size());
444
0
        assert(pos[i] != -1);
445
446
0
        const llama_pos p_old = pos[i];
447
448
0
        seq_pos_rm(i);
449
450
0
        pos[i]   /= d;
451
0
        shift[i] += p_old - pos[i];
452
453
0
        seq_pos_add(i);
454
455
0
        has_shift = true;
456
0
    }
457
458
private:
459
    bool has_shift = false;
460
461
    // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
462
    std::set<uint32_t> used;
463
464
    std::vector<llama_pos> pos;
465
466
    // stores extra info per cell
467
    std::vector<llama_kv_cell_ext> ext;
468
469
    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
470
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
471
    //
472
    //   cells.pos_add(x, shift_x);
473
    //   cells.pos_div(y, shift_y);
474
    //   ...
475
    //
476
    //   if (cells.has_shift()) {
477
    //      for (int i = 0; i < n; ++i) {
478
    //          auto shift_i = cells.get_shift(i);
479
    //          ...
480
    //      }
481
    //      cells.reset_shift();
482
    //   }
483
    //
484
    std::vector<llama_pos> shift;
485
486
    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
487
488
    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
489
    std::vector<seq_set_t> seq;
490
491
    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
492
    // if the position p is not present, seq_pos[s][p] is not set
493
    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
494
    //
495
    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
496
    //  - during performing a cache reuse via (rm + add)
497
    //  - some vision models have input embeddings with repeating positions
498
    //
499
    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
500
501
    // helper functions for updating `seq_pos`, once cell at a time:
502
503
0
    void seq_pos_dec(llama_seq_id s, llama_pos p) {
504
0
        auto it = seq_pos[s].find(p);
505
0
        assert(it != seq_pos[s].end());
506
507
0
        if (--it->second == 0) {
508
0
            seq_pos[s].erase(it);
509
0
        }
510
0
    }
511
512
0
    void seq_pos_inc(llama_seq_id s, llama_pos p) {
513
0
        seq_pos[s][p]++;
514
0
    }
515
516
    // remove cell i
517
0
    void seq_pos_rm(uint32_t i) {
518
0
        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
519
0
            if (seq[i].test(s)) {
520
0
                seq_pos_dec(s, pos[i]);
521
0
            }
522
0
        }
523
0
    }
524
525
    // add cell i
526
0
    void seq_pos_add(uint32_t i) {
527
0
        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
528
0
            if (seq[i].test(s)) {
529
0
                seq_pos_inc(s, pos[i]);
530
0
            }
531
0
        }
532
0
    }
533
};