Coverage Report

Created: 2025-12-28 06:26

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-kv-cache-iswa.cpp
Line
Count
Source
1
#include "llama-kv-cache-iswa.h"
2
3
#include "llama-impl.h"
4
#include "llama-batch.h"
5
#include "llama-model.h"
6
7
#include <algorithm>
8
#include <cassert>
9
10
//
11
// llama_kv_cache_iswa
12
//
13
14
llama_kv_cache_iswa::llama_kv_cache_iswa(
15
        const llama_model & model,
16
                ggml_type   type_k,
17
                ggml_type   type_v,
18
                     bool   v_trans,
19
                     bool   offload,
20
                     bool   swa_full,
21
                     bool   unified,
22
                 uint32_t   kv_size,
23
                 uint32_t   n_seq_max,
24
                 uint32_t   n_ubatch,
25
                 uint32_t   n_pad,
26
    const layer_filter_cb & filter,
27
0
    const  layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
28
29
    // chain filters
30
0
    const layer_filter_cb filter_base = [&](int32_t il) {
31
0
        if (filter && !filter(il)) {
32
0
            return false;
33
0
        }
34
35
0
        return !model.hparams.is_swa(il);
36
0
    };
37
38
0
    const layer_filter_cb filter_swa  = [&](int32_t il) {
39
0
        if (filter && !filter(il)) {
40
0
            return false;
41
0
        }
42
43
0
        return  model.hparams.is_swa(il);
44
0
    };
45
46
0
    const uint32_t size_base = kv_size;
47
48
    // note: the SWA cache is always padded to 256 for performance
49
    //       https://github.com/ggml-org/llama.cpp/issues/17037
50
0
    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
51
52
    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
53
0
    if (swa_full) {
54
0
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
55
0
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
56
57
0
        size_swa = size_base;
58
0
    }
59
60
0
    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
61
62
0
    kv_base = std::make_unique<llama_kv_cache>(
63
0
            model, type_k, type_v,
64
0
            v_trans, offload, unified, size_base, n_seq_max, n_pad,
65
0
            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
66
67
0
    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
68
69
0
    kv_swa = std::make_unique<llama_kv_cache>(
70
0
            model, type_k, type_v,
71
0
            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
72
0
            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
73
0
}
74
75
0
void llama_kv_cache_iswa::clear(bool data) {
76
0
    kv_base->clear(data);
77
0
    kv_swa ->clear(data);
78
0
}
79
80
0
bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
81
0
    bool res = true;
82
83
0
    res = res & kv_base->seq_rm(seq_id, p0, p1);
84
0
    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
85
86
0
    return res;
87
0
}
88
89
0
void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
90
0
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
91
0
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
92
0
}
93
94
0
void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
95
0
    kv_base->seq_keep(seq_id);
96
0
    kv_swa ->seq_keep(seq_id);
97
0
}
98
99
0
void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
100
0
    kv_base->seq_add(seq_id, p0, p1, shift);
101
0
    kv_swa ->seq_add(seq_id, p0, p1, shift);
102
0
}
103
104
0
void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
105
0
    kv_base->seq_div(seq_id, p0, p1, d);
106
0
    kv_swa ->seq_div(seq_id, p0, p1, d);
107
0
}
108
109
0
llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
110
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
111
0
    return kv_swa->seq_pos_min(seq_id);
112
0
}
113
114
0
llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
115
0
    return kv_swa->seq_pos_max(seq_id);
116
0
}
117
118
0
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
119
0
    std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
120
0
    for (const auto & buft_size : kv_swa->memory_breakdown()) {
121
0
        mb[buft_size.first] += buft_size.second;
122
0
    }
123
0
    return mb;
124
0
}
125
126
0
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
127
0
    GGML_UNUSED(embd_all);
128
129
    // first try simple split
130
0
    do {
131
0
        if (!unified) {
132
            // requires equal splits, so we skip the simple split
133
0
            break;
134
0
        }
135
136
0
        balloc.split_reset();
137
138
0
        std::vector<llama_ubatch> ubatches;
139
0
        while (true) {
140
0
            auto ubatch = balloc.split_simple(n_ubatch);
141
142
0
            if (ubatch.n_tokens == 0) {
143
0
                break;
144
0
            }
145
146
0
            ubatches.push_back(std::move(ubatch)); // NOLINT
147
0
        }
148
149
0
        if (balloc.get_n_used() < balloc.get_n_tokens()) {
150
            // failed to find a suitable split
151
0
            break;
152
0
        }
153
154
0
        auto sinfos_base = kv_base->prepare(ubatches);
155
0
        if (sinfos_base.empty()) {
156
0
            break;
157
0
        }
158
159
0
        auto sinfos_swa = kv_swa->prepare(ubatches);
160
0
        if (sinfos_swa.empty()) {
161
0
            break;
162
0
        }
163
164
0
        assert(sinfos_base.size() == sinfos_swa.size());
165
166
0
        return std::make_unique<llama_kv_cache_iswa_context>(
167
0
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
168
0
    } while (false);
169
170
    // if it fails, try equal split
171
0
    do {
172
0
        balloc.split_reset();
173
174
0
        std::vector<llama_ubatch> ubatches;
175
0
        while (true) {
176
0
            auto ubatch = balloc.split_equal(n_ubatch, !unified);
177
178
0
            if (ubatch.n_tokens == 0) {
179
0
                break;
180
0
            }
181
182
0
            ubatches.push_back(std::move(ubatch)); // NOLINT
183
0
        }
184
185
0
        if (balloc.get_n_used() < balloc.get_n_tokens()) {
186
            // failed to find a suitable split
187
0
            break;
188
0
        }
189
190
0
        auto sinfos_base = kv_base->prepare(ubatches);
191
0
        if (sinfos_base.empty()) {
192
0
            break;
193
0
        }
194
195
0
        auto sinfos_swa = kv_swa->prepare(ubatches);
196
0
        if (sinfos_swa.empty()) {
197
0
            break;
198
0
        }
199
200
0
        assert(sinfos_base.size() == sinfos_swa.size());
201
202
0
        return std::make_unique<llama_kv_cache_iswa_context>(
203
0
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
204
0
    } while (false);
205
206
    // TODO: if we fail again, we should attempt different splitting strategies
207
    //       but to do that properly, we first have to refactor the batches to be more flexible
208
209
0
    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
210
0
}
211
212
0
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
213
0
    return std::make_unique<llama_kv_cache_iswa_context>(this);
214
0
}
215
216
0
llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
217
0
    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
218
0
}
219
220
0
bool llama_kv_cache_iswa::get_can_shift() const {
221
0
    return kv_base->get_size() == kv_swa->get_size();
222
0
}
223
224
0
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
225
0
    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
226
0
        kv_base->state_write(io, seq_id, flags);
227
0
    }
228
229
0
    kv_swa->state_write(io, seq_id, flags);
230
0
}
231
232
0
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
233
0
    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
234
0
        kv_base->state_read(io, seq_id, flags);
235
0
    }
236
237
0
    kv_swa->state_read(io, seq_id, flags);
238
0
}
239
240
0
llama_kv_cache * llama_kv_cache_iswa::get_base() const {
241
0
    return kv_base.get();
242
0
}
243
244
0
llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
245
0
    return kv_swa.get();
246
0
}
247
248
//
249
// llama_kv_cache_iswa_context
250
//
251
252
0
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
253
254
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
255
        llama_kv_cache_iswa * kv) :
256
0
    ctx_base(kv->get_base()->init_full()),
257
0
    ctx_swa (kv->get_swa ()->init_full()),
258
0
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
259
0
}
260
261
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
262
        llama_kv_cache_iswa * kv,
263
        llama_context * lctx,
264
        bool optimize) :
265
0
    ctx_base(kv->get_base()->init_update(lctx, optimize)),
266
0
    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
267
0
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
268
0
}
269
270
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
271
        llama_kv_cache_iswa * kv,
272
        slot_info_vec_t sinfos_base,
273
        slot_info_vec_t sinfos_swa,
274
        std::vector<llama_ubatch> ubatches) :
275
0
    ubatches(std::move(ubatches)),
276
    // note: here we copy the ubatches. not sure if this is ideal
277
0
    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
278
0
    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
279
0
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
280
0
}
281
282
0
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
283
284
0
bool llama_kv_cache_iswa_context::next() {
285
0
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
286
287
0
    ctx_base->next();
288
0
    ctx_swa ->next();
289
290
0
    if (++i_next >= ubatches.size()) {
291
0
        return false;
292
0
    }
293
294
0
    return true;
295
0
}
296
297
0
bool llama_kv_cache_iswa_context::apply() {
298
0
    assert(!llama_memory_status_is_fail(status));
299
300
0
    bool res = true;
301
302
0
    res = res & ctx_base->apply();
303
0
    res = res & ctx_swa ->apply();
304
305
0
    return res;
306
0
}
307
308
0
llama_memory_status llama_kv_cache_iswa_context::get_status() const {
309
0
    return status;
310
0
}
311
312
0
const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
313
0
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
314
315
0
    return ubatches[i_next];
316
0
}
317
318
0
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
319
0
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
320
321
0
    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
322
0
}
323
324
0
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
325
0
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
326
327
0
    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
328
0
}