Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-graph.cpp
Line
Count
Source
1
#include "llama-graph.h"
2
3
#include "llama-impl.h"
4
#include "llama-model.h"
5
#include "llama-batch.h"
6
#include "llama-cparams.h"
7
8
#include "llama-kv-cache.h"
9
#include "llama-kv-cache-iswa.h"
10
#include "llama-kv-cache-dsa.h"
11
#include "llama-memory-hybrid.h"
12
#include "llama-memory-hybrid-iswa.h"
13
#include "llama-memory-recurrent.h"
14
15
#include <cassert>
16
#include <cmath>
17
#include <cstring>
18
#include <numeric>
19
#include <sstream>
20
#include <unordered_set>
21
22
// dedup helpers
23
24
// Creates the KQ-mask input tensor used by the KV-cache-backed attention inputs.
//
// The tensor has shape [n_kv, n_tokens/n_stream, 1, n_stream], where n_stream is 1
// when cparams.kv_unified is set and ubatch.n_seqs_unq otherwise. It is flagged as a
// graph input and named "attn_inp_kq_mask".
static ggml_tensor * build_attn_inp_kq_mask(
        ggml_context * ctx,
        const llama_kv_cache_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams) {
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    // flash attention requires an f16 mask
    const auto mask_type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;

    ggml_tensor * mask = ggml_new_tensor_4d(ctx, mask_type, mctx->get_n_kv(), ubatch.n_tokens/n_stream, 1, n_stream);

    ggml_set_input(mask);
    ggml_set_name(mask, "attn_inp_kq_mask");

    return mask;
}
42
43
static bool can_reuse_kq_mask(
44
        ggml_tensor * kq_mask,
45
        const llama_kv_cache_context * mctx,
46
        const llama_ubatch & ubatch,
47
0
        const llama_cparams & cparams) {
48
0
    const auto n_kv     = mctx->get_n_kv();
49
0
    const auto n_tokens = ubatch.n_tokens;
50
0
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
51
52
0
    bool res = true;
53
54
0
    res &= (kq_mask->ne[0] == n_kv);
55
0
    res &= (kq_mask->ne[1] == n_tokens/n_stream);
56
0
    res &= (kq_mask->ne[2] == 1);
57
0
    res &= (kq_mask->ne[3] == n_stream);
58
59
0
    return res;
60
0
}
61
62
// impl
63
64
// Multiplies `cur` by the matrix `rot` along rows of length rot->ne[0].
//
// `cur` is first viewed as a 2D tensor [n, nelements/n] (n = rot->ne[0]), multiplied
// with `rot` (tagged with the GGML_HINT_SRC0_IS_HADAMARD hint), and the product is
// reshaped back to the original 4D shape of `cur`.
static ggml_tensor * ggml_mul_mat_aux(
        ggml_context * ctx,
        ggml_tensor * cur,
        ggml_tensor * rot) {
    const auto n_cols = rot->ne[0];
    const auto n_rows = ggml_nelements(cur)/n_cols;

    // a contiguous tensor can simply be reshaped; otherwise it has to be made contiguous
    ggml_tensor * flat = ggml_is_contiguous(cur)
        ? ggml_reshape_2d(ctx, cur, n_cols, n_rows)
        : ggml_cont_2d   (ctx, cur, n_cols, n_rows);

    ggml_tensor * prod = ggml_mul_mat(ctx, rot, flat);
    ggml_mul_mat_set_hint(prod, GGML_HINT_SRC0_IS_HADAMARD);

    return ggml_reshape_4d(ctx, prod, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
}
83
84
0
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
85
0
    if (ubatch->token) {
86
0
        const int64_t n_tokens = ubatch->n_tokens;
87
88
0
        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
89
0
    }
90
91
0
    if (ubatch->embd) {
92
0
        GGML_ASSERT(n_embd == embd->ne[0]);
93
94
0
        const int64_t n_tokens = ubatch->n_tokens;
95
96
0
        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
97
0
    }
98
0
}
99
100
0
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
101
0
    bool res = true;
102
103
0
    res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
104
0
    res &= (!params.ubatch.embd)  || (embd   &&   embd->ne[1] == params.ubatch.n_tokens);
105
106
0
    return res;
107
0
}
108
109
0
void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) {
110
0
    const int64_t n_tokens = ubatch->n_tokens;
111
112
0
    if (ubatch->token) {
113
0
        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
114
0
    } else {
115
        // note: mtmd embedding input goes through here
116
0
        GGML_ASSERT(ubatch->embd);
117
0
        GGML_ASSERT(n_embd == embd->ne[0]);
118
119
0
        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
120
0
    }
121
122
    // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states
123
    //       for now, we assume that the hidden state is always provided as an embedding
124
    //       ref: https://github.com/ggml-org/llama.cpp/pull/23643
125
0
    if (ubatch->embd) {
126
0
        GGML_ASSERT(n_embd == h->ne[0]);
127
128
0
        ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
129
0
    }
130
0
}
131
132
0
bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) {
133
0
    bool res = true;
134
135
0
    res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
136
0
    res &= (!params.ubatch.embd)  || (embd   && embd->ne[1]   == params.ubatch.n_tokens);
137
0
    res &= (!params.ubatch.embd)  || (h      && h->ne[1]      == params.ubatch.n_tokens);
138
139
0
    return res;
140
0
}
141
142
0
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
143
0
    if (ubatch->pos && pos) {
144
0
        const int64_t n_tokens = ubatch->n_tokens;
145
146
0
        if (ubatch->token && n_pos_per_embd == 4) {
147
            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
148
            // the 3 first dims are the same, and 4th dim is all 0
149
0
            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
150
            // copy the first dimension
151
0
            for (int i = 0; i < n_tokens; ++i) {
152
0
                pos_data[               i] = ubatch->pos[i];
153
0
                pos_data[    n_tokens + i] = ubatch->pos[i];
154
0
                pos_data[2 * n_tokens + i] = ubatch->pos[i];
155
0
                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
156
0
            }
157
0
            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
158
0
        } else {
159
0
            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
160
0
        }
161
0
    }
162
0
}
163
164
0
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
165
0
    bool res = true;
166
167
0
    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
168
169
0
    return res;
170
0
}
171
172
0
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
173
0
    if (ubatch->pos && attn_scale) {
174
0
        const int64_t n_tokens = ubatch->n_tokens;
175
176
0
        GGML_ASSERT(f_attn_temp_scale != 0.0f);
177
0
        GGML_ASSERT(n_attn_temp_floor_scale != 0);
178
179
0
        std::vector<float> attn_scale_data(n_tokens, 0.0f);
180
0
        for (int i = 0; i < n_tokens; ++i) {
181
0
            const float pos = ubatch->pos[i];
182
0
            attn_scale_data[i] = std::log(
183
0
                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
184
0
            ) * f_attn_temp_scale + 1.0;
185
0
        }
186
187
0
        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
188
0
    }
189
0
}
190
191
0
void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
192
0
    if (pos_bucket) {
193
0
        const int64_t n_tokens = ubatch->n_tokens;
194
195
0
        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
196
0
        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
197
198
0
        int32_t * data = (int32_t *) pos_bucket->data;
199
200
0
        for (int j = 0; j < n_tokens; ++j) {
201
0
            for (int i = 0; i < n_tokens; ++i) {
202
0
                data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
203
0
            }
204
0
        }
205
0
    }
206
0
}
207
208
0
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
209
0
    if (pos_bucket) {
210
0
        mctx->set_input_pos_bucket(pos_bucket, ubatch);
211
0
    }
212
0
}
213
214
0
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
215
0
    GGML_ASSERT(out_ids);
216
217
0
    const int64_t n_tokens = ubatch->n_tokens;
218
219
0
    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
220
0
    int32_t * data = (int32_t *) out_ids->data;
221
222
0
    if (n_outputs == n_tokens) {
223
0
        for (int i = 0; i < n_tokens; ++i) {
224
0
            data[i] = i;
225
0
        }
226
227
0
        return;
228
0
    }
229
230
0
    GGML_ASSERT(ubatch->output);
231
232
0
    int n_outputs = 0;
233
234
0
    for (int i = 0; i < n_tokens; ++i) {
235
0
        if (ubatch->output[i]) {
236
0
            data[n_outputs++] = i;
237
0
        }
238
0
    }
239
0
}
240
241
0
bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
242
0
    bool res = true;
243
244
0
    res &= n_outputs == params.n_outputs;
245
246
0
    return res;
247
0
}
248
249
0
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
250
0
    if (cparams.embeddings   &&
251
0
       (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
252
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
253
254
0
        const int64_t n_tokens     = ubatch->n_tokens;
255
0
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
256
0
        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
257
258
0
        GGML_ASSERT(mean);
259
0
        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
260
261
0
        float * data = (float *) mean->data;
262
0
        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
263
264
0
        std::vector<uint64_t> sums(n_seqs_unq, 0);
265
0
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
266
0
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
267
0
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
268
0
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
269
270
0
                sums[seq_idx] += ubatch->n_seq_tokens;
271
0
            }
272
0
        }
273
274
0
        std::vector<float> div(n_seqs_unq, 0.0f);
275
0
        for (int s = 0; s < n_seqs_unq; ++s) {
276
0
            const uint64_t sum = sums[s];
277
0
            if (sum > 0) {
278
0
                div[s] = 1.0f/float(sum);
279
0
            }
280
0
        }
281
282
0
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
283
0
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
284
0
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
285
0
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
286
287
0
                for (int j = 0; j < n_seq_tokens; ++j) {
288
0
                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
289
0
                }
290
0
            }
291
0
        }
292
0
    }
293
0
}
294
295
0
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
296
0
    const int64_t n_tokens     = ubatch->n_tokens;
297
0
    const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
298
299
0
    if (cparams.embeddings && (
300
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  ||
301
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
302
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
303
0
    )) {
304
0
        GGML_ASSERT(cls);
305
0
        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
306
307
0
        uint32_t * data = (uint32_t *) cls->data;
308
0
        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
309
310
0
        std::vector<int> target_pos(n_seqs_unq, -1);
311
0
        std::vector<int> target_row(n_seqs_unq, -1);
312
313
0
        const bool last = (
314
0
             cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
315
0
            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
316
0
        );
317
318
0
        for (int i = 0; i < n_tokens; ++i) {
319
0
            const llama_pos pos = ubatch->pos[i];
320
321
0
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
322
0
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
323
0
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
324
325
0
                if (
326
0
                    (target_pos[seq_idx] == -1) ||
327
0
                    ( last && pos >= target_pos[seq_idx]) ||
328
0
                    (!last && pos <  target_pos[seq_idx])
329
0
                ) {
330
0
                    target_pos[seq_idx] = pos;
331
0
                    target_row[seq_idx] = i;
332
0
                }
333
0
            }
334
0
        }
335
336
0
        for (int s = 0; s < n_seqs_unq; ++s) {
337
0
            if (target_row[s] >= 0) {
338
0
                data[s] = target_row[s];
339
0
            }
340
0
        }
341
0
    }
342
0
}
343
344
0
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
345
0
    GGML_UNUSED(ubatch);
346
347
0
    const int64_t n_rs = mctx->get_n_rs();
348
349
0
    if (s_copy) {
350
0
        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
351
0
        int32_t * data = (int32_t *) s_copy->data;
352
353
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
354
0
        for (uint32_t i = 0; i < n_rs; ++i) {
355
0
            data[i] = mctx->s_copy(i);
356
0
        }
357
0
    }
358
0
}
359
360
0
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
361
0
    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
362
363
0
    this->mctx = mctx;
364
365
0
    bool res = true;
366
367
0
    res &= s_copy->ne[0] == mctx->get_n_rs();
368
369
0
    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
370
0
    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
371
372
0
    res &= head == mctx->get_head();
373
0
    res &= rs_z == mctx->get_rs_z();
374
375
0
    return res;
376
0
}
377
378
0
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
379
0
    GGML_UNUSED(ubatch);
380
381
0
    if (cross_embd && !cross->v_embd.empty()) {
382
0
        assert(cross_embd->type == GGML_TYPE_F32);
383
384
0
        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
385
0
    }
386
0
}
387
388
template <typename T>
389
0
static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
390
0
    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
391
0
    const char * swa_type_str = "unknown";
392
393
0
    switch (swa_type) {
394
0
        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
395
0
        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
396
0
        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
397
0
        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
398
0
    };
399
400
0
    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
401
0
    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
402
0
    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
403
404
0
    LLAMA_LOG_DEBUG("    ");
405
0
    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
406
0
        LLAMA_LOG_DEBUG("%2d", j);
407
0
    }
408
0
    LLAMA_LOG_DEBUG("\n");
409
410
0
    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
411
0
        LLAMA_LOG_DEBUG(" %2d ", i);
412
0
        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
413
0
            float val = llama_cast<float>(data[i * n_kv + j]);
414
0
            if (val == -INFINITY) {
415
0
                LLAMA_LOG_DEBUG(" ∞");
416
0
            } else {
417
0
                LLAMA_LOG_DEBUG(" 0");
418
0
            }
419
0
        }
420
0
        LLAMA_LOG_DEBUG("\n");
421
0
    }
422
0
}
Unexecuted instantiation: llama-graph.cpp:void print_mask<unsigned short>(unsigned short const*, long, long, long, llama_swa_type)
Unexecuted instantiation: llama-graph.cpp:void print_mask<float>(float const*, long, long, long, llama_swa_type)
423
424
0
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
425
0
    const int64_t n_kv     = ubatch->n_tokens;
426
0
    const int64_t n_tokens = ubatch->n_tokens;
427
428
0
    const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
429
0
        using T = std::remove_reference_t<decltype(*data)>;
430
0
        std::fill(data, data + ne, llama_cast<T>(-INFINITY));
431
432
0
        for (int i1 = 0; i1 < n_tokens; ++i1) {
433
0
            const llama_seq_id s1 = ubatch->seq_id[i1][0];
434
0
            const llama_pos    p1 = ubatch->pos[i1];
435
436
0
            const uint64_t idst = i1*n_kv;
437
438
0
            for (int i0 = 0; i0 < n_tokens; ++i0) {
439
0
                const llama_seq_id s0 = ubatch->seq_id[i0][0];
440
0
                const llama_pos p0    = ubatch->pos[i0];
441
442
                // mask different sequences
443
0
                if (s0 != s1) {
444
0
                    continue;
445
0
                }
446
447
                // mask future tokens
448
0
                if (cparams.causal_attn && p0 > p1) {
449
0
                    continue;
450
0
                }
451
452
                // apply SWA if any
453
0
                if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
454
0
                    continue;
455
0
                }
456
457
0
                data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
458
0
            }
459
0
        }
460
461
0
        if (debug) {
462
0
            print_mask(data, n_tokens, n_kv, n_swa, swa_type);
463
0
        }
464
0
    };
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_no_cache::set_input(llama_ubatch const*)::$_0::operator()<unsigned short>(unsigned short*, long, int, llama_swa_type) const
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_no_cache::set_input(llama_ubatch const*)::$_0::operator()<float>(float*, long, int, llama_swa_type) const
465
466
0
    GGML_ASSERT(self_kq_mask);
467
0
    GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
468
0
    if (self_kq_mask->type == GGML_TYPE_F16) {
469
0
        fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
470
0
    } else {
471
0
        fill_mask((float       *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
472
0
    }
473
474
0
    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
475
0
        GGML_ASSERT(self_kq_mask_swa);
476
0
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
477
0
        if (self_kq_mask_swa->type == GGML_TYPE_F16) {
478
0
            fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
479
0
        } else {
480
0
            fill_mask((float       *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
481
0
        }
482
0
    }
483
0
}
484
485
0
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
486
0
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
487
0
    mctx->set_input_v_idxs(self_v_idxs, ubatch);
488
489
0
    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
490
491
0
    if (self_k_rot) {
492
0
        mctx->set_input_k_rot(self_k_rot);
493
0
    }
494
495
0
    if (self_v_rot) {
496
0
        mctx->set_input_v_rot(self_v_rot);
497
0
    }
498
0
}
499
500
0
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
501
0
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
502
503
0
    this->mctx = mctx;
504
505
0
    bool res = true;
506
507
0
    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
508
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
509
510
0
    res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
511
512
0
    return res;
513
0
}
514
515
0
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
516
0
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
517
518
0
    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
519
0
}
520
521
0
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
522
0
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
523
524
0
    this->mctx = mctx;
525
526
0
    bool res = true;
527
528
0
    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
529
530
0
    res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
531
532
0
    return res;
533
0
}
534
535
0
void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) {
536
0
    mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch);
537
538
0
    mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn);
539
540
0
    mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch);
541
542
0
    mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn);
543
544
0
    mctx->get_lid()->set_input_k_rot(self_k_rot_lid);
545
0
}
546
547
0
bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) {
548
0
    const auto * mctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx);
549
550
0
    this->mctx = mctx;
551
552
0
    bool res = true;
553
554
0
    res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens;
555
0
    res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens;
556
557
0
    res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams);
558
0
    res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams);
559
560
0
    return res;
561
0
}
562
563
0
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
564
    // base tensors may not be allocated if there are no non-SWA attention layers
565
0
    if (self_k_idxs && self_k_idxs->buffer) {
566
0
        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
567
0
        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
568
0
    }
569
570
    // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
571
0
    if (self_kq_mask && self_kq_mask->buffer) {
572
0
        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
573
0
    }
574
575
    // swa tensors may not be allocated if there are no SWA attention layers
576
0
    if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
577
0
        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
578
0
        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
579
0
    }
580
581
0
    if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
582
0
        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
583
0
    }
584
585
0
    if (self_k_rot) {
586
0
        mctx->get_base()->set_input_k_rot(self_k_rot);
587
0
    }
588
589
0
    if (self_v_rot) {
590
0
        mctx->get_base()->set_input_v_rot(self_v_rot);
591
0
    }
592
593
0
    if (self_k_rot_swa) {
594
0
        mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
595
0
    }
596
597
0
    if (self_v_rot_swa) {
598
0
        mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
599
0
    }
600
0
}
601
602
0
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
603
0
    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
604
605
0
    this->mctx = mctx;
606
607
0
    bool res = true;
608
609
    // base tensors may not be allocated if there are no non-SWA attention layers
610
0
    if (self_k_idxs && self_k_idxs->buffer) {
611
0
        res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
612
      //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
613
0
    }
614
615
0
    if (self_kq_mask && self_kq_mask->buffer) {
616
0
        res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
617
0
    }
618
619
    // swa tensors may not be allocated if there are no SWA attention layers
620
0
    if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
621
0
        res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
622
      //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
623
0
    }
624
625
0
    if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
626
0
        res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
627
0
    }
628
629
0
    return res;
630
0
}
631
632
0
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
633
0
    GGML_ASSERT(cross_kq_mask);
634
635
0
    const int64_t n_enc    = cross_kq_mask->ne[0];
636
0
    const int64_t n_tokens = ubatch->n_tokens;
637
638
0
    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
639
0
    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
640
641
0
    const auto fill_mask = [&](auto * data) {
642
0
        using T = std::remove_reference_t<decltype(*data)>;
643
0
        for (int i = 0; i < n_tokens; ++i) {
644
0
            GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
645
0
            for (int j = 0; j < n_enc; ++j) {
646
0
                float f = -INFINITY;
647
648
0
                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
649
0
                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
650
651
0
                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
652
0
                        f = 0.0f;
653
0
                    }
654
0
                }
655
656
0
                data[i*n_enc + j] = llama_cast<T>(f);
657
0
            }
658
0
        }
659
0
    };
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_cross::set_input(llama_ubatch const*)::$_0::operator()<unsigned short>(unsigned short*) const
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_cross::set_input(llama_ubatch const*)::$_0::operator()<float>(float*) const
660
661
0
    if (cross_kq_mask->type == GGML_TYPE_F16) {
662
0
        fill_mask((ggml_fp16_t *) cross_kq_mask->data);
663
0
    } else {
664
0
        fill_mask((float *) cross_kq_mask->data);
665
0
    }
666
0
}
667
668
0
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
669
0
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
670
0
    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
671
672
0
    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
673
674
0
    if (inp_attn->self_k_rot) {
675
0
        mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
676
0
    }
677
678
0
    if (inp_attn->self_v_rot) {
679
0
        mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
680
0
    }
681
682
0
    const int64_t n_rs = mctx->get_recr()->get_n_rs();
683
684
0
    if (inp_rs->s_copy) {
685
0
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
686
0
        int32_t * data = (int32_t *) inp_rs->s_copy->data;
687
688
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
689
0
        for (uint32_t i = 0; i < n_rs; ++i) {
690
0
            data[i] = mctx->get_recr()->s_copy(i);
691
0
        }
692
0
    }
693
0
}
694
695
0
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
696
0
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
697
698
0
    this->mctx = mctx;
699
700
0
    bool res = true;
701
702
0
    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
703
  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
704
705
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
706
707
0
    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
708
709
0
    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
710
0
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
711
712
0
    res &= inp_rs->head == mctx->get_recr()->get_head();
713
0
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
714
715
0
    return res;
716
0
}
717
718
// TODO: Hybrid input classes are a bit redundant.
719
// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
720
// Refactoring is required in the future.
721
0
void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
722
0
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
723
724
0
    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
725
726
0
    const int64_t n_rs = mctx->get_recr()->get_n_rs();
727
728
0
    if (inp_rs->s_copy) {
729
0
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
730
0
        int32_t * data = (int32_t *) inp_rs->s_copy->data;
731
732
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
733
0
        for (uint32_t i = 0; i < n_rs; ++i) {
734
0
            data[i] = mctx->get_recr()->s_copy(i);
735
0
        }
736
0
    }
737
0
}
738
739
0
bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
740
0
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
741
742
0
    this->mctx = mctx;
743
744
0
    bool res = true;
745
746
0
    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
747
748
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
749
750
0
    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
751
752
0
    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
753
0
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
754
755
0
    res &= inp_rs->head == mctx->get_recr()->get_head();
756
0
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
757
758
0
    return res;
759
0
}
760
761
0
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
762
0
    const auto * attn_ctx = mctx->get_attn();
763
764
    // base tensors may not be allocated if there are no non-SWA attention layers
765
0
    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
766
0
        attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
767
0
        attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
768
0
    }
769
770
0
    if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) {
771
0
        attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
772
0
    }
773
774
    // swa tensors may not be allocated if there are no SWA attention layers
775
0
    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
776
0
        attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
777
0
        attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
778
0
    }
779
780
0
    if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) {
781
0
        attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
782
0
    }
783
784
0
    if (inp_attn->self_k_rot) {
785
0
        attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
786
0
    }
787
788
0
    if (inp_attn->self_v_rot) {
789
0
        attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
790
0
    }
791
792
0
    if (inp_attn->self_k_rot_swa) {
793
0
        attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
794
0
    }
795
796
0
    if (inp_attn->self_v_rot_swa) {
797
0
        attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
798
0
    }
799
800
0
    const int64_t n_rs = mctx->get_recr()->get_n_rs();
801
802
0
    if (inp_rs->s_copy) {
803
0
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
804
0
        int32_t * data = (int32_t *) inp_rs->s_copy->data;
805
806
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
807
0
        for (uint32_t i = 0; i < n_rs; ++i) {
808
0
            data[i] = mctx->get_recr()->s_copy(i);
809
0
        }
810
0
    }
811
0
}
812
813
0
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
814
0
    const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
815
816
0
    this->mctx = mctx;
817
818
0
    bool res = true;
819
820
0
    const auto * attn_ctx = mctx->get_attn();
821
822
    // base tensors may not be allocated if there are no non-SWA attention layers
823
0
    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
824
0
        res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
825
      //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
826
0
    }
827
828
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
829
830
    // swa tensors may not be allocated if there are no SWA attention layers
831
0
    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
832
0
        res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
833
      //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
834
0
    }
835
836
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
837
838
0
    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
839
840
0
    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
841
0
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
842
843
0
    res &= inp_rs->head == mctx->get_recr()->get_head();
844
0
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
845
846
0
    return res;
847
0
}
848
849
0
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
850
    // set the inputs only for the active samplers in the current ubatch
851
0
    std::unordered_set<llama_seq_id> active_samplers;
852
0
    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
853
0
        if (ubatch->output[i]) {
854
0
            llama_seq_id seq_id = ubatch->seq_id[i][0];
855
0
            active_samplers.insert(seq_id);
856
0
        }
857
0
    }
858
859
0
    for (auto seq_id : active_samplers) {
860
0
        if (samplers.find(seq_id) == samplers.end()) {
861
0
            continue;
862
0
        }
863
864
0
        auto & sampler = samplers[seq_id];
865
866
0
        if (sampler->iface->backend_set_input) {
867
0
            sampler->iface->backend_set_input(sampler);
868
0
        }
869
0
    }
870
0
}
871
872
0
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
873
0
    if (samplers.size() != params.samplers.size()) {
874
0
        return false;
875
0
    }
876
877
0
    for (const auto & [seq_id, sampler] : params.samplers) {
878
0
        if (samplers[seq_id] != sampler) {
879
0
            return false;
880
0
        }
881
0
    }
882
883
0
    return true;
884
0
}
885
886
//
887
// llm_graph_result
888
//
889
890
0
// Creates a graph result able to hold up to `max_nodes` nodes and reads the
// debug verbosity from the LLAMA_GRAPH_RESULT_DEBUG environment variable
// (0 when the variable is not set).
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
    reset();

    debug = 0;
    if (const char * env = getenv("LLAMA_GRAPH_RESULT_DEBUG")) {
        debug = atoi(env);
    }
}
896
897
0
int64_t llm_graph_result::get_max_nodes() const {
898
0
    return max_nodes;
899
0
}
900
901
0
void llm_graph_result::reset() {
902
0
    t_inp_tokens  = nullptr;
903
0
    t_inp_embd    = nullptr;
904
0
    t_logits      = nullptr;
905
0
    t_embd        = nullptr;
906
0
    t_embd_pooled = nullptr;
907
908
0
    t_layer_inp.resize(LLAMA_MAX_LAYERS);
909
0
    std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);
910
911
0
    t_sampled.clear();
912
0
    t_sampled_probs.clear();
913
0
    t_sampled_logits.clear();
914
0
    t_candidates.clear();
915
916
0
    params = {};
917
918
0
    inputs.clear();
919
920
0
    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
921
922
0
    ggml_init_params params = {
923
0
        /*.mem_size   =*/ buf_compute_meta.size(),
924
0
        /*.mem_buffer =*/ buf_compute_meta.data(),
925
0
        /*.no_alloc   =*/ true,
926
0
    };
927
928
0
    ctx_compute.reset(ggml_init(params));
929
930
0
    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
931
0
}
932
933
0
void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
934
0
    for (auto & input : inputs) {
935
0
        input->set_input(ubatch);
936
0
    }
937
0
}
938
939
0
void llm_graph_result::set_outputs(const llm_graph_params & params) {
940
0
    if (t_logits != nullptr) {
941
0
        ggml_set_output(t_logits);
942
0
    }
943
0
    if (t_embd != nullptr) {
944
0
        ggml_set_output(t_embd);
945
0
    }
946
0
    if (t_embd_pooled != nullptr) {
947
0
        ggml_set_output(t_embd_pooled);
948
0
    }
949
0
    if (t_h_nextn != nullptr) {
950
0
        ggml_set_output(t_h_nextn);
951
0
    }
952
0
    {
953
0
        const auto & embeddings_layer_inp = params.cparams.embeddings_layer_inp;
954
0
        for (size_t il = 0; il < embeddings_layer_inp.size(); ++il) {
955
0
            if (embeddings_layer_inp[il]) {
956
0
                GGML_ASSERT(t_layer_inp[il] != nullptr && "layer input tensor is null");
957
0
                ggml_set_output(t_layer_inp[il]);
958
0
            }
959
0
        }
960
0
    }
961
0
    for (auto & [seq_id, t] : t_sampled) {
962
0
        if (t != nullptr) {
963
0
            ggml_set_output(t);
964
0
        }
965
0
    }
966
0
    for (auto & [seq_id, t] : t_sampled_probs) {
967
0
        if (t != nullptr) {
968
0
            ggml_set_output(t);
969
0
        }
970
0
    }
971
0
    for (auto & [seq_id, t] : t_sampled_logits) {
972
0
        if (t != nullptr) {
973
0
            ggml_set_output(t);
974
0
        }
975
0
    }
976
0
    for (auto & [seq_id, t] : t_candidates) {
977
0
        if (t != nullptr) {
978
0
            ggml_set_output(t);
979
0
        }
980
0
    }
981
0
}
982
983
0
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
984
0
    if (!this->params.allow_reuse(params)) {
985
0
        if (debug > 1) {
986
0
            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
987
0
        }
988
989
0
        return false;
990
0
    }
991
992
0
    if (debug > 1) {
993
0
        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
994
0
    }
995
996
0
    bool res = true;
997
998
0
    for (auto & input : inputs) {
999
0
        const bool cur = input->can_reuse(params);
1000
1001
0
        if (debug > 1) {
1002
0
            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
1003
0
        }
1004
1005
0
        res = res && cur;
1006
0
    }
1007
1008
0
    if (debug > 0) {
1009
0
        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
1010
0
    }
1011
1012
0
    return res;
1013
0
}
1014
1015
0
// Takes ownership of `input`, appends it to the list of graph inputs and
// returns a non-owning pointer to the stored object.
llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
    auto & stored = inputs.emplace_back(std::move(input));
    return stored.get();
}
1019
1020
0
void llm_graph_result::set_params(const llm_graph_params & params) {
1021
0
    this->params = params;
1022
0
}
1023
1024
//
1025
// llm_graph_context
1026
//
1027
1028
// Captures everything needed to build a graph from `params`: the source
// parameter sets, values derived from the hyper/context parameters, and the
// handles of the result object that receives the graph. The parameters are
// also recorded on the result object.
//
// NOTE: the initializer order below is kept as in the original; several
// initializers read members set earlier in the list (hparams, cparams, ubatch, res).
llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    // source parameter sets
    arch(params.arch),
    hparams(params.hparams),
    cparams(params.cparams),
    ubatch(params.ubatch),
    // model dimensions
    n_embd(hparams.n_embd),
    n_layer(hparams.n_layer()),
    n_layer_nextn(hparams.n_layer_nextn),
    n_rot(hparams.n_rot()),
    n_ctx(cparams.n_ctx),
    n_head(hparams.n_head()),
    n_head_kv(hparams.n_head_kv()),
    n_embd_head_k(hparams.n_embd_head_k()),
    n_embd_k_gqa(hparams.n_embd_k_gqa()),
    n_embd_head_v(hparams.n_embd_head_v()),
    n_embd_v_gqa(hparams.n_embd_v_gqa()),
    // experts: when cparams.warmup is set, all experts are used
    n_expert(hparams.n_expert),
    n_expert_used(cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    // rope / yarn
    freq_base(cparams.rope_freq_base),
    freq_scale(cparams.rope_freq_scale),
    ext_factor(cparams.yarn_ext_factor),
    attn_factor(cparams.yarn_attn_factor),
    beta_fast(cparams.yarn_beta_fast),
    beta_slow(cparams.yarn_beta_slow),
    // normalization
    norm_eps(hparams.f_norm_eps),
    norm_rms_eps(hparams.f_norm_rms_eps),
    // batch
    n_tokens(ubatch.n_tokens),
    n_outputs(params.n_outputs),
    n_ctx_orig(cparams.n_ctx_orig_yarn),
    pooling_type(cparams.pooling_type),
    rope_type(hparams.rope_type),
    // runtime handles
    sched(params.sched),
    backend_cpu(params.backend_cpu),
    cvec(params.cvec),
    loras(params.loras),
    mctx(params.mctx),
    cross(params.cross),
    samplers(params.samplers),
    cb_func(params.cb),
    // result object and the context/graph it owns
    res(params.res),
    ctx0(res->get_ctx()),
    gf(res->get_gf()) {
    res->set_params(params);
}
1072
1073
0
// Forwards a freshly built tensor to the user callback, if one is installed.
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
    if (!cb_func) {
        return;
    }

    cb_func(ubatch, cur, name, il);
}
1078
1079
// Applies the control vector to `cur` for layer `il` and returns the result.
ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
                 int   il) const {
    ggml_tensor * out = cvec->apply_to(ctx0, cur, il);

    return out;
}
1084
1085
// Builds `w * cur`, optionally multiplied by the scale tensor `w_s`, and adds
// the contribution `scale * (B * (A * cur))` of every LoRA adapter that carries
// a weight for `w`.
ggml_tensor * llm_graph_context::build_lora_mm(
          ggml_tensor * w,
          ggml_tensor * cur,
          ggml_tensor * w_s) const {
    ggml_tensor * out = ggml_mul_mat(ctx0, w, cur);

    if (w_s) {
        out = ggml_mul(ctx0, out, w_s);
    }

    for (const auto & adapter : *loras) {
        llama_adapter_lora_weight * lw = adapter.first->get_weight(w);
        if (lw == nullptr) {
            // this adapter does not touch w
            continue;
        }

        const float adapter_scale = adapter.second;
        const float scale         = lw->get_scale(adapter.first->alpha, adapter_scale);

        // low-rank update: first A, then B
        ggml_tensor * a_cur  = ggml_mul_mat(ctx0, lw->a, cur);
        ggml_tensor * ab_cur = ggml_mul_mat(ctx0, lw->b, a_cur);

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        out    = ggml_add(ctx0, out, ab_cur);
    }

    return out;
}
1115
1116
// Indexed (per-expert) variant of build_lora_mm: builds
// `ggml_mul_mat_id(w, cur, ids)`, optionally multiplied by per-expert scales
// gathered from `w_s`, and adds the contribution of every LoRA adapter that
// carries a weight for `w`.
ggml_tensor * llm_graph_context::build_lora_mm_id(
          ggml_tensor * w,   // ggml_tensor * as
          ggml_tensor * cur, // ggml_tensor * b
          ggml_tensor * ids,
          ggml_tensor * w_s) const {
    ggml_tensor * out = ggml_mul_mat_id(ctx0, w, cur, ids);

    if (w_s) {
        const int64_t n_exp = w_s->ne[0];
        const int64_t n_tok = cur->ne[2];

        // replicate the scales along the token dimension, then gather the
        // entries of the selected experts
        ggml_tensor * scales = ggml_reshape_3d(ctx0, w_s, 1, n_exp, 1);
        scales = ggml_repeat_4d(ctx0, scales, 1, n_exp, n_tok, 1);
        scales = ggml_get_rows(ctx0, scales, ids);

        out = ggml_mul(ctx0, out, scales);
    }

    for (const auto & adapter : *loras) {
        llama_adapter_lora_weight * lw = adapter.first->get_weight(w);
        if (lw == nullptr) {
            // this adapter does not touch w
            continue;
        }

        const float alpha = adapter.first->alpha;
        const float rank  = (float) lw->b->ne[0];

        // a zero alpha means: use the adapter scale as-is
        float scale = adapter.second;
        if (alpha) {
            scale = adapter.second * alpha / rank;
        }

        // low-rank update: first A, then B
        ggml_tensor * a_cur  = ggml_mul_mat_id(ctx0, lw->a, cur, ids);
        ggml_tensor * ab_cur = ggml_mul_mat_id(ctx0, lw->b, a_cur, ids);

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        out    = ggml_add(ctx0, out, ab_cur);
    }

    return out;
}
1153
1154
// Normalizes `cur` with the requested norm type, then applies the optional
// weight `mw` (multiply) and bias `mb` (add). Intermediate results are reported
// through cb() as "norm" and "norm_w" only when a later step follows them.
ggml_tensor * llm_graph_context::build_norm(
         ggml_tensor * cur,
         ggml_tensor * mw,
         ggml_tensor * mb,
       llm_norm_type   type,
                 int   il) const {
    switch (type) {
        case LLM_NORM:
            cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
            break;
        case LLM_NORM_RMS:
            cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps);
            break;
        case LLM_NORM_GROUP:
            {
                // insert a unit dimension for the group norm, drop it afterwards
                const int64_t ne0 = cur->ne[0];
                const int64_t ne1 = cur->ne[1];

                cur = ggml_reshape_3d(ctx0, cur, ne0, 1, ne1);
                cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
                cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
            } break;
    }

    const bool has_w = mw != nullptr;
    const bool has_b = mb != nullptr;

    if (has_w || has_b) {
        cb(cur, "norm", il);
    }

    if (has_w) {
        cur = ggml_mul(ctx0, cur, mw);
        if (has_b) {
            cb(cur, "norm_w", il);
        }
    }

    if (has_b) {
        cur = ggml_add(ctx0, cur, mb);
    }

    return cur;
}
1188
1189
1190
// Builds the Q, K and V tensors of an attention layer as
// [n_embd_head, n_head or n_head_kv, n_tokens] tensors.
//
// With a fused weight (layer.wqkv) a single projection is computed and Q/K/V
// are views into it; otherwise three separate projections are built. In both
// paths the optional bias is added first and the result is then clamped to
// +-hparams.f_clamp_kqv when that value is positive.
llm_graph_qkv llm_graph_context::build_qkv(
        const llama_layer & layer,
              ggml_tensor * cur,
                  int64_t   n_embd_head,
                  int64_t   n_head,
                  int64_t   n_head_kv,
                      int   il) const {
    const int64_t n_embd_q  = n_embd_head * n_head;
    const int64_t n_embd_kv = n_embd_head * n_head_kv;

    const float clamp_kqv = hparams.f_clamp_kqv;
    const bool  do_clamp  = clamp_kqv > 0.0f;

    ggml_tensor * Qcur = nullptr;
    ggml_tensor * Kcur = nullptr;
    ggml_tensor * Vcur = nullptr;

    if (layer.wqkv) {
        // fused QKV path
        ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s);
        cb(qkv, "wqkv", il);

        if (layer.wqkv_b) {
            qkv = ggml_add(ctx0, qkv, layer.wqkv_b);
            cb(qkv, "wqkv_b", il);
        }
        if (do_clamp) {
            qkv = ggml_clamp(ctx0, qkv, -clamp_kqv, clamp_kqv);
            cb(qkv, "wqkv_clamped", il);
        }

        // each row of qkv is laid out as [ Q | K | V ]
        const auto head_stride = ggml_row_size(qkv->type, n_embd_head);
        const auto k_offset    = ggml_row_size(qkv->type, n_embd_q);
        const auto v_offset    = ggml_row_size(qkv->type, n_embd_q + n_embd_kv);

        Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head,    n_tokens, head_stride, qkv->nb[1], 0);
        Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, head_stride, qkv->nb[1], k_offset);
        Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, head_stride, qkv->nb[1], v_offset);
    } else {
        // separate Q/K/V path: projection -> optional bias -> optional clamp
        const auto project = [&](ggml_tensor * w, ggml_tensor * w_b, ggml_tensor * w_s,
                                 const char * name, const char * name_clamped) {
            ggml_tensor * out = build_lora_mm(w, cur, w_s);
            cb(out, name, il);
            if (w_b) {
                out = ggml_add(ctx0, out, w_b);
                cb(out, name, il);
            }
            if (do_clamp) {
                out = ggml_clamp(ctx0, out, -clamp_kqv, clamp_kqv);
                cb(out, name_clamped, il);
            }
            return out;
        };

        Qcur = project(layer.wq, layer.wq_b, layer.wq_s, "Qcur", "Qcur_clamped");
        Kcur = project(layer.wk, layer.wk_b, layer.wk_s, "Kcur", "Kcur_clamped");
        Vcur = project(layer.wv, layer.wv_b, layer.wv_s, "Vcur", "Vcur_clamped");

        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
    }

    cb(Qcur, "Qcur", il);
    cb(Kcur, "Kcur", il);
    cb(Vcur, "Vcur", il);

    return { Qcur, Kcur, Vcur };
}
1265
1266
1267
// Builds a feed-forward block:
//   1. up projection (+ optional bias, + optional scale)
//   2. optional gate projection, fed either by the up output (LLM_FFN_SEQ) or
//      by the block input (LLM_FFN_PAR), again with optional bias and scale
//   3. activation selected by `type_op`; for SILU/GELU/RELU with a parallel
//      gate the fused *_split ops combine gate and up output in one step
//   4. multiplication with the up output when a parallel gate has not been
//      combined by the activation step
//   5. optional down projection (+ optional bias, + optional scale)
// Intermediate tensors are reported through cb().
ggml_tensor * llm_graph_context::build_ffn(
         ggml_tensor * cur,
         ggml_tensor * up,
         ggml_tensor * up_b,
         ggml_tensor * up_s,
         ggml_tensor * gate,
         ggml_tensor * gate_b,
         ggml_tensor * gate_s,
         ggml_tensor * down,
         ggml_tensor * down_b,
         ggml_tensor * down_s,
         ggml_tensor * act_scales,
     llm_ffn_op_type   type_op,
   llm_ffn_gate_type   type_gate,
                 int   il) const {
    // NVFP4 support is currently restricted to
    // 1) LORA absence (*_s would be applied after LORA residual, which is incorrect)
    // 2) bias absence (*_s would be applied after bias addition, which is incorrect)
    // TODO: disambiguate LLM-architectural scales (which use *_s) from NVFP4 scale_2 (which also uses *_s currently)
    auto has_lora = [this](ggml_tensor * w) {
        if (w == nullptr) {
            return false;
        }
        for (const auto & adapter : *loras) {
            if (adapter.first->get_weight(w)) {
                return true;
            }
        }
        return false;
    };

    GGML_ASSERT(!up_s   || !up_b   || !up   || up->type   != GGML_TYPE_NVFP4);
    GGML_ASSERT(!gate_s || !gate_b || !gate || gate->type != GGML_TYPE_NVFP4);
    GGML_ASSERT(!down_s || !down_b || !down || down->type != GGML_TYPE_NVFP4);
    GGML_ASSERT(!up_s   || !up   || up->type   != GGML_TYPE_NVFP4 || !has_lora(up));
    GGML_ASSERT(!gate_s || !gate || gate->type != GGML_TYPE_NVFP4 || !has_lora(gate));
    GGML_ASSERT(!down_s || !down || down->type != GGML_TYPE_NVFP4 || !has_lora(down));

    // --- up projection (pass-through when there is no up weight) ---
    ggml_tensor * up_out = cur;
    if (up) {
        up_out = build_lora_mm(up, cur);
    }
    cb(up_out, "ffn_up", il);

    if (up_b) {
        up_out = ggml_add(ctx0, up_out, up_b);
        cb(up_out, "ffn_up_b", il);
    }

    if (up_s) {
        up_out = ggml_mul(ctx0, up_out, up_s);
        cb(up_out, "ffn_up_s", il);
    }

    // --- gate projection ---
    if (!gate) {
        cur = up_out;
    } else {
        if (type_gate == LLM_FFN_SEQ) {
            // sequential gate: consumes the up output
            cur = build_lora_mm(gate, up_out);
            cb(cur, "ffn_gate", il);
        } else if (type_gate == LLM_FFN_PAR) {
            // parallel gate: consumes the block input
            cur = build_lora_mm(gate, cur);
            cb(cur, "ffn_gate", il);
        }

        if (gate_b) {
            cur = ggml_add(ctx0, cur, gate_b);
            cb(cur, "ffn_gate_b", il);
        }

        if (gate_s) {
            cur = ggml_mul(ctx0, cur, gate_s);
            cb(cur, "ffn_gate_s", il);
        }
    }

    // --- activation ---
    const bool par_gate = gate && type_gate == LLM_FFN_PAR;

    // set once the activation has already combined gate and up output,
    // so that the explicit multiplication further down is skipped
    bool gate_combined = false;

    switch (type_op) {
        case LLM_FFN_SILU:
            {
                if (!par_gate) {
                    cur = ggml_silu(ctx0, cur);
                    cb(cur, "ffn_silu", il);
                    break;
                }

                // Step35: HF clamps gate (after SiLU) and up before multiplication
                constexpr float eps = 1e-6f;

                const bool  is_step35 = arch == LLM_ARCH_STEP35 && il >= 0;
                const float limit     = is_step35 ? hparams.swiglu_clamp_shexp[il] : 0.0f;

                if (limit > eps) {
                    ggml_tensor * gate_act = ggml_silu(ctx0, cur);
                    cb(gate_act, "ffn_silu", il);
                    gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
                    cb(gate_act, "ffn_silu_clamped", il);

                    up_out = ggml_clamp(ctx0, up_out, -limit, limit);
                    cb(up_out, "ffn_up_clamped", il);

                    cur = ggml_mul(ctx0, gate_act, up_out);
                    cb(cur, "ffn_swiglu_limited", il);
                } else {
                    cur = ggml_swiglu_split(ctx0, cur, up_out);
                    cb(cur, "ffn_swiglu", il);
                }

                gate_combined = true;
            } break;
        case LLM_FFN_GELU:
            {
                if (par_gate) {
                    cur = ggml_geglu_split(ctx0, cur, up_out);
                    cb(cur, "ffn_geglu", il);
                    gate_combined = true;
                    break;
                }

                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_gelu", il);
                if (act_scales) {
                    cur = ggml_div(ctx0, cur, act_scales);
                    cb(cur, "ffn_act", il);
                }
            } break;
        case LLM_FFN_RELU:
            {
                if (par_gate) {
                    cur = ggml_reglu_split(ctx0, cur, up_out);
                    cb(cur, "ffn_reglu", il);
                    gate_combined = true;
                    break;
                }

                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);
            } break;
        case LLM_FFN_RELU_SQR:
            {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);

                cur = ggml_sqr(ctx0, cur);
                cb(cur, "ffn_sqr(relu)", il);
            } break;
        case LLM_FFN_SWIGLU:
            {
                cur = ggml_swiglu(ctx0, cur);
                cb(cur, "ffn_swiglu", il);
            } break;
        case LLM_FFN_GEGLU:
            {
                cur = ggml_geglu(ctx0, cur);
                cb(cur, "ffn_geglu", il);
            } break;
        case LLM_FFN_REGLU:
            {
                cur = ggml_reglu(ctx0, cur);
                cb(cur, "ffn_reglu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    // --- parallel gate not yet combined by the activation ---
    if (par_gate && !gate_combined) {
        cur = ggml_mul(ctx0, cur, up_out);
        cb(cur, "ffn_gate_par", il);
    }

    // --- down projection ---
    if (down) {
        cur = build_lora_mm(down, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
            // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (down_b) {
        // report the pre-bias value, then add the bias
        cb(cur, "ffn_down", il);
        cur = ggml_add(ctx0, cur, down_b);
    }

    if (down_s) {
        cur = ggml_mul(ctx0, cur, down_s);
        cb(cur, "ffn_down_s", il);
    }

    return cur;
}
1453
1454
// Convenience overload for MoE layers without bias tensors: forwards to the
// full build_moe_ffn overload with every *_b argument set to null.
ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * up_exps,
         ggml_tensor * gate_exps,
         ggml_tensor * down_exps,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
               float   w_scale,
         llama_expert_gating_func_type gating_op,
                 int   il,
         ggml_tensor * probs_in,
         ggml_tensor * gate_up_exps,
         ggml_tensor * up_exps_s,
         ggml_tensor * gate_exps_s,
         ggml_tensor * down_exps_s) const {
    // stands in for every bias tensor of the full overload
    ggml_tensor * const no_bias = nullptr;

    return build_moe_ffn(
        cur,
        gate_inp,     /* gate_inp_b     */ no_bias,
        up_exps,      /* up_exps_b      */ no_bias,
        gate_exps,    /* gate_exps_b    */ no_bias,
        down_exps,    /* down_exps_b    */ no_bias,
        exp_probs_b,
        n_expert,
        n_expert_used,
        type_op,
        norm_w,
        w_scale,
        gating_op,
        il,
        probs_in,
        gate_up_exps, /* gate_up_exps_b */ no_bias,
        up_exps_s,
        gate_exps_s,
        down_exps_s
    );
}
1495
1496
ggml_tensor * llm_graph_context::build_moe_ffn(
1497
         ggml_tensor * cur,
1498
         ggml_tensor * gate_inp,
1499
         ggml_tensor * gate_inp_b,
1500
         ggml_tensor * up_exps,
1501
         ggml_tensor * up_exps_b,
1502
         ggml_tensor * gate_exps,
1503
         ggml_tensor * gate_exps_b,
1504
         ggml_tensor * down_exps,
1505
         ggml_tensor * down_exps_b,
1506
         ggml_tensor * exp_probs_b,
1507
             int64_t   n_expert,
1508
             int64_t   n_expert_used,
1509
     llm_ffn_op_type   type_op,
1510
                bool   norm_w,
1511
               float   w_scale,
1512
        llama_expert_gating_func_type gating_op,
1513
                 int   il,
1514
         ggml_tensor * probs_in,
1515
         ggml_tensor * gate_up_exps,
1516
         ggml_tensor * gate_up_exps_b,
1517
         ggml_tensor * up_exps_s,
1518
         ggml_tensor * gate_exps_s,
1519
0
         ggml_tensor * down_exps_s) const {
1520
0
    const int64_t n_embd   = cur->ne[0];
1521
0
    const int64_t n_tokens = cur->ne[1];
1522
0
    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
1523
1524
0
    ggml_tensor * logits = nullptr;
1525
1526
0
    if (probs_in == nullptr) {
1527
0
        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
1528
0
        cb(logits, "ffn_moe_logits", il);
1529
0
    } else {
1530
0
        logits = probs_in;
1531
0
    }
1532
1533
0
    if (gate_inp_b) {
1534
0
        logits = ggml_add(ctx0, logits, gate_inp_b);
1535
0
        cb(logits, "ffn_moe_logits_biased", il);
1536
0
    }
1537
1538
0
    ggml_tensor * probs = nullptr;
1539
0
    switch (gating_op) {
1540
0
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
1541
0
            {
1542
0
                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
1543
0
            } break;
1544
0
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
1545
0
            {
1546
0
                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1547
0
            } break;
1548
0
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
1549
0
            {
1550
0
                probs = logits; // [n_expert, n_tokens]
1551
0
            } break;
1552
0
        default:
1553
0
            GGML_ABORT("fatal error");
1554
0
    }
1555
0
    cb(probs, "ffn_moe_probs", il);
1556
1557
    // add experts selection bias - introduced in DeepSeek V3
1558
    // leave probs unbiased as it's later used to get expert weights
1559
0
    ggml_tensor * selection_probs = probs;
1560
0
    if (exp_probs_b != nullptr) {
1561
0
        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
1562
0
        cb(selection_probs, "ffn_moe_probs_biased", il);
1563
0
    }
1564
1565
    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
1566
    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
1567
0
    if (arch == LLM_ARCH_LLAMA4) {
1568
0
        selection_probs = logits;
1569
0
    }
1570
1571
0
    if (arch == LLM_ARCH_GROVEMOE) {
1572
0
        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1573
0
        cb(selection_probs, "ffn_moe_probs_biased", il);
1574
0
    }
1575
1576
    // select top n_group_used expert groups
1577
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1578
0
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1579
0
        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1580
1581
        // organize experts into n_expert_groups
1582
0
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1583
1584
0
        ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1585
0
        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1586
1587
        // get top n_group_used expert groups
1588
0
        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1589
0
        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1590
1591
0
        ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1592
0
        cb(expert_groups, "ffn_moe_group_topk", il);
1593
1594
        // mask out the other groups
1595
0
        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1596
0
        selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1597
0
        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1598
0
        cb(selection_probs, "ffn_moe_probs_masked", il);
1599
0
    }
1600
1601
    // select experts
1602
0
    ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1603
0
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
1604
0
    cb(selected_experts, "ffn_moe_topk", il);
1605
1606
0
    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
1607
        // TODO: Use scalar div instead when/if implemented
1608
0
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
1609
0
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
1610
0
        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
1611
0
    } else {
1612
0
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
1613
0
    }
1614
1615
0
    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
1616
0
    cb(weights, "ffn_moe_weights", il);
1617
1618
1619
0
    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
1620
0
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1621
0
        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
1622
0
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1623
0
        cb(weights, "ffn_moe_weights_softmax", il);
1624
0
    }
1625
1626
0
    if (norm_w) {
1627
0
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1628
1629
0
        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
1630
0
        cb(weights_sum, "ffn_moe_weights_sum", il);
1631
1632
        // Avoid division by zero, clamp to smallest number representable by F16
1633
0
        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1634
0
        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1635
1636
0
        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
1637
0
        cb(weights, "ffn_moe_weights_norm", il);
1638
1639
0
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1640
0
    }
1641
0
    if (w_scale != 0.0f && w_scale != 1.0f) {
1642
0
        weights = ggml_scale(ctx0, weights, w_scale);
1643
0
        cb(weights, "ffn_moe_weights_scaled", il);
1644
0
    }
1645
1646
    //call early so that topk-moe can be used
1647
0
    ggml_build_forward_expand(gf, weights);
1648
1649
0
    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
1650
1651
0
    if (weight_before_ffn) {
1652
        // repeat cur to [n_embd, n_expert_used, n_tokens]
1653
0
        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
1654
0
        cur = ggml_mul(ctx0, repeated, weights);
1655
0
        cb(cur, "ffn_moe_weighted", il);
1656
0
    }
1657
1658
0
    ggml_tensor * up = nullptr;
1659
0
    ggml_tensor * experts = nullptr;
1660
1661
0
    if (gate_up_exps) {
1662
        // merged gate_up path: one mul_mat_id, then split into gate and up views
1663
0
        ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts, up_exps_s); // [n_ff*2, n_expert_used, n_tokens]
1664
0
        cb(gate_up, "ffn_moe_gate_up", il);
1665
1666
0
        if (up_exps_s) {
1667
0
            cb(gate_up, "ffn_moe_gate_up_scaled", il);
1668
0
        }
1669
1670
0
        if (gate_up_exps_b) {
1671
0
            gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
1672
0
            cb(gate_up, "ffn_moe_gate_up_biased", il);
1673
0
        }
1674
1675
0
        const int64_t n_ff = gate_up->ne[0] / 2;
1676
0
        cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
1677
0
        cb(cur, "ffn_moe_gate", il);
1678
0
        up  = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
1679
0
        cb(up, "ffn_moe_up", il);
1680
0
    } else {
1681
        // separate gate and up path
1682
0
        up = build_lora_mm_id(up_exps, cur, selected_experts, up_exps_s); // [n_ff, n_expert_used, n_tokens]
1683
0
        cb(up, "ffn_moe_up", il);
1684
1685
0
        if (up_exps_s) {
1686
0
            cb(up, "ffn_moe_up_scaled", il);
1687
0
        }
1688
1689
0
        if (up_exps_b) {
1690
0
            up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1691
0
            cb(up, "ffn_moe_up_biased", il);
1692
0
        }
1693
1694
0
        if (gate_exps) {
1695
0
            cur = build_lora_mm_id(gate_exps, cur, selected_experts, gate_exps_s); // [n_ff, n_expert_used, n_tokens]
1696
0
            cb(cur, "ffn_moe_gate", il);
1697
0
        } else {
1698
0
            cur = up;
1699
0
        }
1700
1701
0
        if (gate_exps_s) {
1702
0
            cb(cur, "ffn_moe_gate_scaled", il);
1703
0
        }
1704
1705
0
        if (gate_exps_b) {
1706
0
            cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1707
0
            cb(cur, "ffn_moe_gate_biased", il);
1708
0
        }
1709
0
    }
1710
1711
0
    const bool has_gate = gate_exps || gate_up_exps;
1712
1713
0
    switch (type_op) {
1714
0
        case LLM_FFN_SILU:
1715
0
            if (gate_exps) {
1716
                // Step35: per-layer clamp for routed experts
1717
0
                if (arch == LLM_ARCH_STEP35 && il >= 0) {
1718
0
                    const float limit = hparams.swiglu_clamp_exp[il];
1719
0
                    constexpr float eps = 1e-6f;
1720
0
                    if (limit > eps) {
1721
0
                        ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1722
0
                        cb(gate_act, "ffn_moe_silu", il);
1723
0
                        gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1724
0
                        cb(gate_act, "ffn_moe_silu_clamped", il);
1725
1726
0
                        up = ggml_clamp(ctx0, up, -limit, limit);
1727
0
                        cb(up, "ffn_moe_up_clamped", il);
1728
1729
0
                        cur = ggml_mul(ctx0, gate_act, up);
1730
0
                        cb(cur, "ffn_moe_swiglu_limited", il);
1731
0
                        break;
1732
0
                    }
1733
0
                }
1734
0
            }
1735
1736
0
            if (has_gate) {
1737
0
                cur = ggml_swiglu_split(ctx0, cur, up);
1738
0
                cb(cur, "ffn_moe_swiglu", il);
1739
0
            } else {
1740
0
                cur = ggml_silu(ctx0, cur);
1741
0
                cb(cur, "ffn_moe_silu", il);
1742
0
            } break;
1743
0
        case LLM_FFN_GELU:
1744
0
            if (has_gate) {
1745
0
                cur = ggml_geglu_split(ctx0, cur, up);
1746
0
                cb(cur, "ffn_moe_geglu", il);
1747
0
            } else {
1748
0
                cur = ggml_gelu(ctx0, cur);
1749
0
                cb(cur, "ffn_moe_gelu", il);
1750
0
            } break;
1751
0
        case LLM_FFN_SWIGLU_OAI_MOE:
1752
0
            {
1753
                // TODO: move to hparams?
1754
0
                constexpr float alpha = 1.702f;
1755
0
                constexpr float limit = 7.0f;
1756
0
                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1757
0
                cb(cur, "ffn_moe_swiglu_oai", il);
1758
0
            } break;
1759
0
        case LLM_FFN_RELU:
1760
0
            if (has_gate) {
1761
0
                cur = ggml_reglu_split(ctx0, cur, up);
1762
0
                cb(cur, "ffn_moe_reglu", il);
1763
0
            } else {
1764
0
                cur = ggml_relu(ctx0, cur);
1765
0
                cb(cur, "ffn_moe_relu", il);
1766
0
            } break;
1767
0
        case LLM_FFN_RELU_SQR:
1768
0
            if (has_gate) {
1769
                // TODO: add support for gated squared relu
1770
0
                GGML_ABORT("fatal error: gated squared relu not implemented");
1771
0
            } else {
1772
0
                cur = ggml_relu(ctx0, cur);
1773
0
                cur = ggml_sqr(ctx0, cur);
1774
0
                cb(cur, "ffn_moe_relu_sqr", il);
1775
0
            } break;
1776
0
        default:
1777
0
            GGML_ABORT("fatal error");
1778
0
    }
1779
1780
0
    experts = build_lora_mm_id(down_exps, cur, selected_experts, down_exps_s); // [n_embd, n_expert_used, n_tokens]
1781
0
    cb(experts, "ffn_moe_down", il);
1782
1783
0
    if (down_exps_s) {
1784
0
        cb(experts, "ffn_moe_down_scaled", il);
1785
0
    }
1786
1787
0
    if (down_exps_b) {
1788
0
        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1789
0
        cb(experts, "ffn_moe_down_biased", il);
1790
0
    }
1791
1792
0
    if (!weight_before_ffn) {
1793
0
        experts = ggml_mul(ctx0, experts, weights);
1794
0
        cb(experts, "ffn_moe_weighted", il);
1795
0
    }
1796
1797
0
    ggml_build_forward_expand(gf, experts);
1798
1799
0
    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1800
1801
0
    assert(n_expert_used > 0);
1802
1803
    // order the views before the adds
1804
0
    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1805
0
        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1806
1807
0
        ggml_build_forward_expand(gf, cur_experts[i]);
1808
0
    }
1809
1810
    // aggregate experts
1811
    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
1812
    //       to avoid potentially a large number of add nodes during warmup
1813
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
1814
0
    ggml_tensor * moe_out = cur_experts[0];
1815
1816
0
    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1817
0
        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
1818
1819
0
        ggml_build_forward_expand(gf, moe_out);
1820
0
    }
1821
1822
0
    if (hparams.n_expert_used == 1) {
1823
        // avoid returning a non-contiguous tensor
1824
0
        moe_out = ggml_cont(ctx0, moe_out);
1825
0
    }
1826
1827
0
    cb(moe_out, "ffn_moe_out", il);
1828
1829
0
    return moe_out;
1830
0
}
1831
1832
// input embeddings with optional lora
1833
0
// Builds the embedding input of the graph.
//
// Two candidate inputs are always created so that the graph topology does not depend on the batch:
//   - a token-id tensor that is looked up in tok_embd (with the LoRA deltas of tok_embd added on top)
//   - a raw F32 embedding tensor
// The candidate matching the batch contents (ubatch.token set or not) is picked with
// ggml_build_forward_select(). The returned tensor is also expanded into the graph right away.
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
    const int64_t n_embd_inp = hparams.n_embd_inp();
    const int64_t n_embd     = hparams.n_embd;

    assert(n_embd_inp >= n_embd);

    auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);

    // input: token ids
    inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
    cb(inp->tokens, "inp_tokens", -1);
    ggml_set_input(inp->tokens);
    res->t_inp_tokens = inp->tokens;

    // input: raw embeddings
    inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
    cb(inp->embd, "inp_embd", -1);
    ggml_set_input(inp->embd);

    // candidate 0 - token embeddings path (ubatch.token != nullptr)
    ggml_tensor * embd_tok = ggml_get_rows(ctx0, tok_embd, inp->tokens);

    // add the contribution of every LoRA adapter that has a weight for tok_embd
    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
        if (lw != nullptr) {
            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            ggml_tensor * a_rows = ggml_get_rows(ctx0, lw->a, inp->tokens);
            ggml_tensor * ba     = ggml_mul_mat (ctx0, lw->b, a_rows); // non-transposed lora_b
            ggml_tensor * delta  = ggml_scale   (ctx0, ba, scale);

            embd_tok = ggml_add(ctx0, embd_tok, delta);
        }
    }

    // pad the token embeddings up to the input embedding width so both candidates have the same shape
    if (n_embd_inp != n_embd) {
        embd_tok = ggml_pad(ctx0, embd_tok, n_embd_inp - n_embd, 0, 0, 0);
    }

    // candidate 1 - vector embeddings path (ubatch.embd != nullptr)
    ggml_tensor * embd_raw = inp->embd;

    // select one of the 2 inputs, based on the batch contents
    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
    std::array<ggml_tensor *, 2> inps = { embd_tok, embd_raw };

    assert(ggml_are_same_shape (inps[0], inps[1]));
    assert(ggml_are_same_stride(inps[0], inps[1]));

    const int idx_sel = ubatch.token ? 0 : 1;

    ggml_tensor * out = ggml_build_forward_select(gf, inps.data(), inps.size(), idx_sel);

    // keep only the first n_embd values of each row
    if (n_embd_inp != n_embd) {
        out = ggml_view_2d(ctx0, out, n_embd, n_tokens, out->nb[1], 0);
    }

    res->t_inp_embd = out;

    // For Granite architecture
    // NOTE: For deepstack models, only apply scale to token inputs (ie text-only input).
    //  Raw embeddings are assumed to be multimodal inputs that should not be scaled.
    const bool apply_scale =
        hparams.f_embedding_scale != 0.0f &&
        (ubatch.token || hparams.n_deepstack_layers == 0);

    if (apply_scale) {
        // the view above may have produced a non-contiguous tensor
        if (!ggml_is_contiguous(out)) {
            out = ggml_cont(ctx0, out);
        }
        out = ggml_scale(ctx0, out, hparams.f_embedding_scale);
    }

    cb(out, "embd", -1);

    res->add_input(std::move(inp));

    // make sure the produced embeddings are immediately materialized in the ggml graph
    // ref: https://github.com/ggml-org/llama.cpp/pull/18599
    ggml_build_forward_expand(gf, out);

    return out;
}
1921
1922
0
// Creates the I32 positions input: n_tokens * n_pos_per_embd entries in a 1D tensor.
// The input object is registered with the graph result; the tensor itself is returned.
ggml_tensor * llm_graph_context::build_inp_pos() const {
    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

    ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, static_cast<int64_t>(n_tokens)*hparams.n_pos_per_embd());
    ggml_set_input(pos);

    inp->pos = pos;
    res->add_input(std::move(inp));

    return pos;
}
1934
1935
0
// Creates the "attn_scale" input tensor (F32, one value per token) and registers its input object.
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

    // the shape has to be 1x1xN so that the tensor can be broadcast
    ggml_tensor * attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
    ggml_set_input(attn_scale);
    ggml_set_name(attn_scale, "attn_scale");

    inp->attn_scale = attn_scale;
    res->add_input(std::move(inp));

    return attn_scale;
}
1949
1950
0
// Creates the I32 input holding the ids of the output rows (n_outputs entries).
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
    // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
    //       but this would make the graph topology depend on the number of output tokens, which can interfere with
    //       features that require constant topology such as pipeline parallelism
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
    //
    //       the early-out is therefore kept disabled:
    //if (n_outputs == n_tokens) {
    //    return nullptr;
    //}

    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

    ggml_tensor * out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
    ggml_set_input(out_ids);

    inp->out_ids = out_ids;
    res->add_input(std::move(inp));

    return out_ids;
}
1970
1971
0
// Creates the F32 "mean" input of shape [n_tokens, n_seqs_unq] and registers its input object.
ggml_tensor * llm_graph_context::build_inp_mean() const {
    auto inp = std::make_unique<llm_graph_input_mean>(cparams);

    ggml_tensor * mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
    ggml_set_input(mean);

    inp->mean = mean;
    res->add_input(std::move(inp));

    return mean;
}
1983
1984
0
// Creates the I32 "cls" input with one entry per unique sequence in the ubatch.
ggml_tensor * llm_graph_context::build_inp_cls() const {
    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

    ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
    ggml_set_input(cls);

    inp->cls = cls;
    res->add_input(std::move(inp));

    return cls;
}
1996
1997
0
// Creates the F32 input that receives the cross-attention (encoder) embeddings.
// Its shape comes from the cross state when that already holds embeddings (v_embd not empty),
// otherwise from the hparams (n_embd_inp x n_ctx_train).
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

    // if we have the output embeddings from the encoder, use them directly
    // TODO: needs more work to be correct, for now just use the tensor shape
    //if (cross->t_embd) {
    //    cur = ggml_view_tensor(ctx0, cross->t_embd);

    //    return cur;
    //}

    const bool has_enc_embd = !cross->v_embd.empty();

    const auto n_embd_cross = has_enc_embd ? cross->n_embd : hparams.n_embd_inp();
    const auto n_enc        = has_enc_embd ? cross->n_enc  : hparams.n_ctx_train;

    ggml_tensor * cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_cross, n_enc);
    ggml_set_input(cross_embd);

    inp->cross_embd = cross_embd;
    res->add_input(std::move(inp));

    return cross_embd;
}
2020
2021
0
// Creates the I32 position-bucket input for the no-KV-cache case: an [n_tokens, n_tokens] tensor.
ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);

    ggml_tensor * pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
    ggml_set_input(pos_bucket);

    inp->pos_bucket = pos_bucket;
    res->add_input(std::move(inp));

    return pos_bucket;
}
2033
2034
0
// Creates the I32 position-bucket input for the KV-cache case: an [n_kv, n_tokens] tensor,
// where n_kv is queried from the KV cache context.
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
    // the memory context is used as a KV cache context here
    const auto * kv_ctx = static_cast<const llama_kv_cache_context *>(mctx);

    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_ctx);

    const auto n_kv = kv_ctx->get_n_kv();

    ggml_tensor * pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
    ggml_set_input(pos_bucket);

    inp->pos_bucket = pos_bucket;
    res->add_input(std::move(inp));

    return pos_bucket;
}
2050
2051
0
// Gathers the rows of attn_rel_b selected by the bucket indices in pos_bucket and rearranges
// the result into a contiguous tensor named "pos_bias".
ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
    const int64_t n_b0 = pos_bucket->ne[0];
    const int64_t n_b1 = pos_bucket->ne[1];

    // flatten the 2D bucket indices so they can be used as row ids
    ggml_tensor * bucket_flat = ggml_reshape_1d(ctx0, pos_bucket, n_b0 * n_b1);
    cb(bucket_flat, "pos_bucket_1d", -1);

    ggml_tensor * bias = ggml_get_rows(ctx0, attn_rel_b, bucket_flat);

    // restore the 2D bucket layout, move the gathered row dimension last and make the data contiguous
    bias = ggml_reshape_3d(ctx0, bias, bias->ne[0], n_b0, n_b1);
    bias = ggml_permute   (ctx0, bias, 2, 0, 1, 3);
    bias = ggml_cont      (ctx0, bias);

    cb(bias, "pos_bias", -1);

    return bias;
}
2065
2066
// Builds the core attention computation for the given Q/K/V tensors and returns a 2D result.
//
// Two paths exist:
//   - flash attention (cparams.flash_attn set and no KQ bias): a single ggml_flash_attn_ext node
//   - regular attention: KQ mat-mul -> optional soft-capping / bias -> masked soft-max -> KQV mat-mul
//
// kq_b, sinks and v_mla are optional; v_mla, when present, is applied to the attention output.
// The result is expanded into the graph before it is returned.
ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * sinks,
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
    // V counts as transposed when its dim-1 stride is larger than its dim-2 stride
    const bool v_is_transposed = v->nb[1] > v->nb[2];

    // split the batch into streams if needed
    const auto n_stream = k->ne[3];

    q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
    v = ggml_permute(ctx0, v, 0, 2, 1, 3);

    ggml_tensor * out = nullptr;

    // the flash attention path is only taken when there is no KQ bias
    const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;

    if (use_flash_attn) {
        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

        if (v_is_transposed) {
            v = ggml_transpose(ctx0, v);
        }

        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
        if (k->type == GGML_TYPE_F32) {
            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
        }

        if (v->type == GGML_TYPE_F32) {
            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
        }

        // 0.0f is passed as the soft-cap value when soft-capping is not enabled
        const float logit_softcap = hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f;

        out = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, logit_softcap);
        cb(out, LLAMA_TENSOR_NAME_FATTN, il);

        ggml_flash_attn_ext_add_sinks(out, sinks);
        ggml_flash_attn_ext_set_prec (out, GGML_PREC_F32);

        if (v_mla) {
            // v_mla could be applied as a matrix-vector multiplication with broadcasting across
            // dimension 3 == n_tokens, but the code is optimized for dimensions 0 and 1 being large,
            // so that would be inefficient.
            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
            // The permutations are noops and only change how the tensor data is interpreted.
            out = ggml_permute(ctx0, out, 0, 2, 1, 3);
            out = ggml_mul_mat(ctx0, v_mla, out);
            cb(out, "fattn_mla", il);
            out = ggml_permute(ctx0, out, 0, 2, 1, 3);
            out = ggml_cont(ctx0, out); // Needed because ggml_reshape_2d expects contiguous inputs.
        }

        out = ggml_reshape_2d(ctx0, out, out->ne[0]*out->ne[1], out->ne[2]*out->ne[3]);
    } else {
        ggml_tensor * scores = ggml_mul_mat(ctx0, k, q);
        cb(scores, "kq", il);

        // note: this op tends to require high floating point range
        //       while for some models F16 is enough, for others it is not, so we default to F32 here
        ggml_mul_mat_set_prec(scores, GGML_PREC_F32);

        if (arch == LLM_ARCH_GROK) {
            // need to do the following:
            // multiply by attn_output_multiplier
            // and then :
            // kq = 30 * tanh(kq / 30)
            // before the softmax below
            const float pre_scale = hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping;

            scores = ggml_scale(ctx0, scores, pre_scale);
            scores = ggml_tanh (ctx0, scores);
            cb(scores, "kq_tanh", il);
            scores = ggml_scale(ctx0, scores, hparams.f_attn_logit_softcapping);
            cb(scores, "kq_scaled", il);
        }

        if (hparams.attn_soft_cap) {
            // soft-capping: cap * tanh(kq / cap)
            scores = ggml_scale(ctx0, scores, 1.0f / hparams.f_attn_logit_softcapping);
            cb(scores, "kq_scaled_1", il);
            scores = ggml_tanh (ctx0, scores);
            cb(scores, "kq_tanh", il);
            scores = ggml_scale(ctx0, scores, hparams.f_attn_logit_softcapping);
            cb(scores, "kq_scaled_2", il);
        }

        if (kq_b) {
            scores = ggml_add(ctx0, scores, kq_b);
            cb(scores, "kq_plus_kq_b", il);
        }

        scores = ggml_soft_max_ext(ctx0, scores, kq_mask, kq_scale, hparams.f_max_alibi_bias);
        ggml_soft_max_add_sinks(scores, sinks);
        cb(scores, "kq_soft_max", il);

        if (!v_is_transposed) {
            // note: avoid this branch
            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
            cb(v, "v_cont", il);
        }

        ggml_tensor * attn = ggml_mul_mat(ctx0, v, scores);
        cb(attn, "kqv", il);

        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
        if (v_mla) {
            attn = ggml_mul_mat(ctx0, v_mla, attn);
            cb(attn, "kqv_mla", il);
        }

        out = ggml_permute(ctx0, attn, 0, 2, 1, 3);

        // recombine streams
        out = ggml_cont_2d(ctx0, out, out->ne[0]*out->ne[1], out->ne[2]*out->ne[3]);

        if (!cparams.offload_kqv) {
            // all nodes between the KV store and the attention output are run on the CPU
            ggml_backend_sched_set_tensor_backend(sched, out, backend_cpu);
        }
    }

    ggml_build_forward_expand(gf, out);

    return out;
}
2200
2201
0
// Creates the attention inputs for the case where no KV cache is used: a KQ mask and,
// when the model uses sliding-window attention, a second mask for the SWA layers.
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

    // flash attention requires an f16 mask
    const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;

    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
    ggml_set_input(inp->self_kq_mask);

    inp->self_kq_mask_cnv = inp->self_kq_mask;

    const bool has_swa = hparams.swa_type != LLAMA_SWA_TYPE_NONE;

    if (!has_swa) {
        inp->self_kq_mask_swa     = nullptr;
        inp->self_kq_mask_swa_cnv = nullptr;
    } else {
        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
        ggml_set_input(inp->self_kq_mask_swa);

        inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
    }

    return static_cast<llm_graph_input_attn_no_cache *>(res->add_input(std::move(inp)));
}
2225
2226
// Attention without a KV cache: runs build_attn_mha() directly on the current Q/K/V and then
// applies the optional output projection (wo, with optional scale wo_s) and output bias (wo_b).
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_no_cache * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * wo_s,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
            float     kq_scale,
            int       il) const {
    GGML_UNUSED(n_tokens);

    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

    // SWA layers use their own mask
    const bool is_swa = hparams.is_swa(il);

    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

    // [TAG_NO_CACHE_PAD]
    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
    //       but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
    //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));

    ggml_tensor * out = build_attn_mha(q_cur, k_cur, v_cur, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(out, "kqv_out", il);

    // output projection
    if (wo) {
        out = build_lora_mm(wo, out, wo_s);
    }

    // output bias
    if (wo_b) {
        out = ggml_add(ctx0, out, wo_b);
    }

    return out;
}
2277
2278
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
2279
           ggml_context * ctx0,
2280
     const llama_ubatch & ubatch,
2281
    const llama_hparams & hparams,
2282
    const llama_cparams & cparams,
2283
0
    const llama_kv_cache_context * mctx_cur) {
2284
2285
0
    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
2286
2287
0
    {
2288
0
        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2289
2290
0
        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2291
0
        inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
2292
2293
0
        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2294
0
        inp->self_kq_mask_cnv = inp->self_kq_mask;
2295
0
    }
2296
2297
0
    inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
2298
0
    inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
2299
2300
0
    return inp;
2301
0
}
2302
2303
0
// Builds the KV-cache attention inputs for this graph and registers them with the graph result.
// Returns a non-owning pointer to the registered input.
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
    // the memory context is used as a KV cache context here
    const auto * kv_ctx = static_cast<const llama_kv_cache_context *>(mctx);

    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, kv_ctx);

    return static_cast<llm_graph_input_attn_kv *>(res->add_input(std::move(inp)));
}
2310
2311
// Attention with a KV cache: the current K/V are stored into the cache, attention is computed
// against the cached K/V and the optional output projection (wo / wo_s) and bias (wo_b) are applied.
// When the input provides rotation tensors, Q/K/V are multiplied by them first and the attention
// output is multiplied by the V rotation afterwards.
// v_mla is not supported by this overload and must be null.
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_kv * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * wo_s,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla, // TODO: remove
            float     kq_scale,
            int       il) const {
    GGML_ASSERT(v_mla == nullptr);

    // Q and K share the same rotation
    if (inp->self_k_rot) {
        q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
        k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
    }

    if (inp->self_v_rot) {
        v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
    }

    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    // expand k later to enable rope fusion which directly writes into k-v cache
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, v_cur);
    ggml_build_forward_expand(gf, k_cur);

    const auto * kv_ctx = inp->mctx;

    // store to KV cache
    {
        const auto & k_idxs = inp->get_k_idxs();
        const auto & v_idxs = inp->get_v_idxs();

        ggml_tensor * k_store = kv_ctx->cpy_k(ctx0, k_cur, k_idxs, il);
        ggml_build_forward_expand(gf, k_store);

        ggml_tensor * v_store = kv_ctx->cpy_v(ctx0, v_cur, v_idxs, il);
        ggml_build_forward_expand(gf, v_store);
    }

    const auto & kq_mask = inp->get_kq_mask();

    // attend over the cached K/V of this layer
    ggml_tensor * k = kv_ctx->get_k(ctx0, il);
    ggml_tensor * v = kv_ctx->get_v(ctx0, il);

    ggml_tensor * out = build_attn_mha(q_cur, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(out, "kqv_out", il);

    if (inp->self_v_rot) {
        out = ggml_mul_mat_aux(ctx0, out, inp->self_v_rot);
    }

    if (wo) {
        // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
        const bool force_f32_prec =
            arch == LLM_ARCH_GLM4     ||
            arch == LLM_ARCH_GLM4_MOE ||
            arch == LLM_ARCH_JAIS2;

        if (!force_f32_prec) {
            out = build_lora_mm(wo, out, wo_s);
        } else {
            // the scale is applied separately, after the precision has been set on the mat-mul node
            out = build_lora_mm(wo, out);
            ggml_mul_mat_set_prec(out, GGML_PREC_F32);
            if (wo_s) {
                out = ggml_mul(ctx0, out, wo_s);
            }
        }
    }

    if (wo_b) {
        out = ggml_add(ctx0, out, wo_b);
    }

    return out;
}
2385
2386
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
2387
           ggml_context * ctx0,
2388
     const llama_ubatch & ubatch,
2389
    const llama_hparams & hparams,
2390
    const llama_cparams & cparams,
2391
0
    const llama_kv_cache_context * mctx_cur) {
2392
2393
0
    auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
2394
2395
0
    {
2396
0
        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2397
2398
0
        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2399
2400
0
        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2401
0
        inp->self_kq_mask_cnv = inp->self_kq_mask;
2402
0
    }
2403
2404
0
    return inp;
2405
0
}
2406
2407
0
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2408
0
    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2409
2410
0
    auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2411
2412
0
    return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2413
0
}
2414
2415
ggml_tensor * llm_graph_context::build_attn(
2416
        llm_graph_input_attn_k * inp,
2417
        ggml_tensor * wo,
2418
        ggml_tensor * wo_b,
2419
        ggml_tensor * wo_s,
2420
        ggml_tensor * q_cur,
2421
        ggml_tensor * k_cur,
2422
        ggml_tensor * v_cur,
2423
        ggml_tensor * kq_b,
2424
        ggml_tensor * sinks,
2425
        ggml_tensor * v_mla,
2426
            float     kq_scale,
2427
0
            int       il) const {
2428
    // these nodes are added to the graph together so that they are not reordered
2429
    // by doing so, the number of splits in the graph is reduced
2430
    // expand k later to enable rope fusion which directly writes into k-v cache
2431
0
    ggml_build_forward_expand(gf, q_cur);
2432
0
    ggml_build_forward_expand(gf, v_cur);
2433
0
    ggml_build_forward_expand(gf, k_cur);
2434
2435
0
    const auto * mctx_cur = inp->mctx;
2436
2437
    // store to KV cache
2438
0
    {
2439
0
        const auto & k_idxs = inp->get_k_idxs();
2440
2441
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2442
0
    }
2443
2444
0
    const auto & kq_mask = inp->get_kq_mask();
2445
2446
0
    ggml_tensor * q = q_cur;
2447
0
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2448
0
    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2449
2450
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2451
0
    cb(cur, "kqv_out", il);
2452
2453
0
    if (wo) {
2454
0
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
2455
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2456
0
            cur = build_lora_mm(wo, cur);
2457
0
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2458
0
            if (wo_s) {
2459
0
                cur = ggml_mul(ctx0, cur, wo_s);
2460
0
            }
2461
0
        } else {
2462
0
            cur = build_lora_mm(wo, cur, wo_s);
2463
0
        }
2464
0
    }
2465
2466
0
    if (wo_b) {
2467
0
        cur = ggml_add(ctx0, cur, wo_b);
2468
0
    }
2469
2470
0
    return cur;
2471
0
}
2472
2473
ggml_tensor * llm_graph_context::build_attn(
2474
        llm_graph_input_attn_k_dsa * inp,
2475
        ggml_tensor * wo,
2476
        ggml_tensor * wo_b,
2477
        ggml_tensor * wo_s,
2478
        ggml_tensor * q_cur,
2479
        ggml_tensor * k_cur,
2480
        ggml_tensor * v_cur,
2481
        ggml_tensor * kq_b,
2482
        ggml_tensor * sinks,
2483
        ggml_tensor * v_mla,
2484
        ggml_tensor * top_k,
2485
            float     kq_scale,
2486
0
            int       il) const {
2487
    // these nodes are added to the graph together so that they are not reordered
2488
    // by doing so, the number of splits in the graph is reduced
2489
    // expand k later to enable rope fusion which directly writes into k-v cache
2490
0
    ggml_build_forward_expand(gf, q_cur);
2491
0
    ggml_build_forward_expand(gf, v_cur);
2492
0
    ggml_build_forward_expand(gf, k_cur);
2493
2494
0
    const auto * mctx_cur = inp->mctx->get_mla();
2495
2496
    // store to KV cache
2497
0
    {
2498
0
        const auto & k_idxs = inp->get_k_idxs_mla();
2499
2500
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2501
0
    }
2502
2503
0
    const auto & kq_mask = inp->get_kq_mask_mla();
2504
2505
    // prepare new kq mask - starts filled with -INFINITY
2506
0
    ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY);
2507
2508
    // reshape KQ mask into tensor with rows of size 1:
2509
    // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream]
2510
0
    kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0);
2511
2512
    // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1]
2513
0
    ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0);
2514
2515
    // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream]
2516
    // this will be our source of zero values for unmasking top k mask elements
2517
0
    ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]);
2518
0
    zeros = ggml_fill(ctx0, zeros, 0.0f);
2519
2520
    // modify KQ mask by unmasking elements that are in top_k indices
2521
    // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1])
2522
0
    ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d);
2523
2524
    // reshape to restore the original shape of KQ mask:
2525
    // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream]
2526
0
    kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0);
2527
2528
    // combine with the original kq mask
2529
0
    kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask);
2530
2531
0
    ggml_tensor * q = q_cur;
2532
0
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2533
0
    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2534
2535
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il);
2536
0
    cb(cur, "kqv_out", il);
2537
2538
0
    if (wo) {
2539
0
        cur = build_lora_mm(wo, cur, wo_s);
2540
0
    }
2541
2542
0
    if (wo_b) {
2543
0
        cur = ggml_add(ctx0, cur, wo_b);
2544
0
    }
2545
2546
0
    return cur;
2547
0
}
2548
2549
ggml_tensor * llm_graph_context::build_attn(
2550
        llm_graph_input_attn_kv_iswa * inp,
2551
        ggml_tensor * wo,
2552
        ggml_tensor * wo_b,
2553
        ggml_tensor * wo_s,
2554
        ggml_tensor * q_cur,
2555
        ggml_tensor * k_cur,
2556
        ggml_tensor * v_cur,
2557
        ggml_tensor * kq_b,
2558
        ggml_tensor * sinks,
2559
        ggml_tensor * v_mla,
2560
            float     kq_scale,
2561
0
            int       il) const {
2562
0
    const bool is_swa = hparams.is_swa(il);
2563
2564
0
    auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
2565
0
    auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
2566
2567
0
    if (k_rot) {
2568
0
        q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
2569
0
        if (k_cur) {
2570
0
            k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
2571
0
        }
2572
0
    }
2573
0
    if (v_rot) {
2574
0
        if (v_cur) {
2575
0
            v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
2576
0
        }
2577
0
    }
2578
2579
    // these nodes are added to the graph together so that they are not reordered
2580
    // by doing so, the number of splits in the graph is reduced
2581
0
    ggml_build_forward_expand(gf, q_cur);
2582
2583
0
    if (k_cur) {
2584
0
        ggml_build_forward_expand(gf, k_cur);
2585
0
    }
2586
2587
0
    if (v_cur) {
2588
0
        ggml_build_forward_expand(gf, v_cur);
2589
0
    }
2590
2591
0
    const auto * mctx_iswa = inp->mctx;
2592
2593
0
    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
2594
2595
    // optionally store to KV cache
2596
0
    if (k_cur) {
2597
0
        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
2598
2599
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2600
0
    }
2601
2602
0
    if (v_cur) {
2603
0
        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
2604
2605
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2606
0
    }
2607
2608
0
    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
2609
2610
0
    ggml_tensor * q = q_cur;
2611
0
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2612
0
    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2613
2614
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2615
0
    cb(cur, "kqv_out", il);
2616
2617
0
    if (v_rot) {
2618
0
        cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
2619
0
    }
2620
2621
0
    if (wo) {
2622
0
        cur = build_lora_mm(wo, cur, wo_s);
2623
0
    }
2624
2625
0
    if (wo_b) {
2626
        //cb(cur, "kqv_wo", il);
2627
0
    }
2628
2629
0
    if (wo_b) {
2630
0
        cur = ggml_add(ctx0, cur, wo_b);
2631
0
    }
2632
2633
0
    return cur;
2634
0
}
2635
2636
0
llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
2637
0
    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
2638
2639
0
    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
2640
2641
    // flash attention requires an f16 mask
2642
0
    const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
2643
2644
0
    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1);
2645
0
    ggml_set_input(inp->cross_kq_mask);
2646
2647
0
    inp->cross_kq_mask_cnv = inp->cross_kq_mask;
2648
2649
0
    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
2650
0
}
2651
2652
ggml_tensor * llm_graph_context::build_attn(
2653
        llm_graph_input_attn_cross * inp,
2654
        ggml_tensor * wo,
2655
        ggml_tensor * wo_b,
2656
        ggml_tensor * wo_s,
2657
        ggml_tensor * q_cur,
2658
        ggml_tensor * k_cur,
2659
        ggml_tensor * v_cur,
2660
        ggml_tensor * kq_b,
2661
        ggml_tensor * sinks,
2662
        ggml_tensor * v_mla,
2663
            float     kq_scale,
2664
0
            int       il) const {
2665
    // these nodes are added to the graph together so that they are not reordered
2666
    // by doing so, the number of splits in the graph is reduced
2667
0
    ggml_build_forward_expand(gf, q_cur);
2668
0
    ggml_build_forward_expand(gf, k_cur);
2669
0
    ggml_build_forward_expand(gf, v_cur);
2670
2671
0
    const auto & kq_mask = inp->get_kq_mask_cross();
2672
2673
0
    ggml_tensor * q = q_cur;
2674
0
    ggml_tensor * k = k_cur;
2675
0
    ggml_tensor * v = v_cur;
2676
2677
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2678
0
    cb(cur, "kqv_out", il);
2679
2680
0
    if (wo) {
2681
0
        cur = build_lora_mm(wo, cur, wo_s);
2682
0
    }
2683
2684
0
    if (wo_b) {
2685
        //cb(cur, "kqv_wo", il);
2686
0
    }
2687
2688
0
    if (wo_b) {
2689
0
        cur = ggml_add(ctx0, cur, wo_b);
2690
0
    }
2691
2692
0
    return cur;
2693
0
}
2694
2695
0
llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const {
2696
0
    const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx);
2697
2698
0
    auto inp = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, mctx_cur);
2699
2700
0
    {
2701
0
        inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch);
2702
2703
0
        inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams);
2704
0
        inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla;
2705
0
    }
2706
2707
0
    {
2708
0
        inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch);
2709
2710
        // ensure F32 mask
2711
0
        auto cparams_copy = cparams;
2712
0
        cparams_copy.flash_attn = false;
2713
2714
0
        inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy);
2715
0
        inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid;
2716
2717
0
        inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0);
2718
0
    }
2719
2720
0
    return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp));
2721
0
}
2722
2723
// TODO: maybe separate the inner implementation into a separate function
2724
//       like with the non-sliding window equivalent
2725
//       once sliding-window hybrid caches are a thing.
2726
0
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
2727
0
    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
2728
2729
0
    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
2730
2731
0
    {
2732
0
        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
2733
0
        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
2734
2735
0
        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
2736
0
        inp->self_kq_mask_cnv = inp->self_kq_mask;
2737
0
    }
2738
2739
0
    {
2740
0
        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
2741
2742
0
        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
2743
0
        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
2744
2745
0
        inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
2746
0
        inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
2747
0
    }
2748
2749
0
    inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
2750
0
    inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
2751
2752
0
    inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
2753
0
    inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
2754
2755
0
    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
2756
0
}
2757
2758
ggml_tensor * llm_graph_context::build_rs(
2759
        ggml_tensor * s,
2760
        ggml_tensor * state_copy_main,
2761
        ggml_tensor * state_copy_extra,
2762
            int32_t   state_size,
2763
            int32_t   n_seqs,
2764
           uint32_t   n_rs,
2765
           uint32_t   rs_head,
2766
           uint32_t   rs_size,
2767
            int32_t   rs_zero,
2768
0
        const llm_graph_get_rows_fn & get_state_rows) const {
2769
2770
0
    GGML_UNUSED(rs_size);
2771
0
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]);
2772
2773
    // Clear a single state which will then be copied to the other cleared states.
2774
    // Note that this is a no-op when the view is zero-sized.
2775
0
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
2776
0
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
2777
2778
    // copy states
2779
    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
2780
    // {state_size, rs_size} -> {state_size, n_seqs}
2781
0
    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
2782
0
    ggml_build_forward_expand(gf, output_states);
2783
2784
    // copy extra states which won't be changed further (between n_seqs and n_rs)
2785
0
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
2786
0
    ggml_build_forward_expand(gf,
2787
0
        ggml_cpy(ctx0,
2788
0
            states_extra,
2789
0
            ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1])));
2790
2791
0
    return output_states;
2792
0
}
2793
2794
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
2795
           ggml_context * ctx0,
2796
     const llama_ubatch & ubatch,
2797
0
    const llama_memory_recurrent_context * mctx_cur) {
2798
2799
0
    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
2800
2801
0
    const int64_t n_rs   = mctx_cur->get_n_rs();
2802
0
    const int64_t n_seqs = ubatch.n_seqs;
2803
2804
0
    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
2805
0
    ggml_set_input(inp->s_copy);
2806
2807
0
    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
2808
0
    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
2809
2810
0
    inp->head = mctx_cur->get_head();
2811
0
    inp->rs_z = mctx_cur->get_rs_z();
2812
2813
0
    return inp;
2814
0
}
2815
2816
0
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
2817
0
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2818
2819
0
    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2820
2821
0
    return (llm_graph_input_rs *) res->add_input(std::move(inp));
2822
0
}
2823
2824
ggml_tensor * llm_graph_context::build_rs(
2825
        llm_graph_input_rs * inp,
2826
        ggml_tensor * s,
2827
            int32_t   state_size,
2828
            int32_t   n_seqs,
2829
0
        const llm_graph_get_rows_fn & get_state_rows) const {
2830
0
    const auto * kv_state = inp->mctx;
2831
2832
0
    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2833
0
                    kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2834
0
                    get_state_rows);
2835
0
}
2836
2837
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
2838
    llm_graph_input_rs * inp,
2839
    const llama_ubatch & ubatch,
2840
0
                   int   il) const {
2841
0
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2842
2843
0
    const auto token_shift_count = hparams.token_shift_count;
2844
2845
0
    const int64_t n_seqs  = ubatch.n_seqs;
2846
2847
0
    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
2848
2849
0
    ggml_tensor * token_shift = build_rs(
2850
0
            inp, token_shift_all,
2851
0
            hparams.n_embd_r(), n_seqs);
2852
2853
0
    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
2854
2855
0
    return token_shift;
2856
0
}
2857
2858
ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
2859
         ggml_tensor * token_shift,
2860
  const llama_ubatch & ubatch,
2861
0
                 int   il) const {
2862
0
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2863
2864
0
    const auto token_shift_count = hparams.token_shift_count;
2865
0
    const auto n_embd = hparams.n_embd;
2866
2867
0
    const int64_t n_seqs = ubatch.n_seqs;
2868
2869
0
    const auto kv_head = mctx_cur->get_head();
2870
2871
0
    return ggml_cpy(
2872
0
        ctx0,
2873
0
        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2874
0
        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
2875
0
    );
2876
0
}
2877
2878
0
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2879
0
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2880
2881
0
    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
2882
0
    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2883
2884
0
    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2885
2886
0
    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2887
0
}
2888
2889
0
llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2890
0
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2891
2892
0
    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
2893
0
    auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2894
2895
0
    auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2896
2897
0
    return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2898
0
}
2899
2900
0
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2901
0
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2902
2903
0
    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2904
2905
    // build iswa attention input
2906
0
    const auto * attn_ctx = mctx_cur->get_attn();
2907
2908
0
    auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2909
2910
0
    {
2911
0
        inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2912
0
        inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2913
2914
0
        inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
2915
0
        inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask;
2916
0
    }
2917
2918
0
    {
2919
0
        inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2920
0
        inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2921
2922
0
        inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
2923
0
        inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa;
2924
0
    }
2925
2926
0
    auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2927
2928
0
    return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2929
0
}
2930
2931
void llm_graph_context::build_dense_out(
2932
    ggml_tensor * dense_2,
2933
    ggml_tensor * dense_2_b,
2934
0
    ggml_tensor * dense_3) const {
2935
0
    if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
2936
0
        return;
2937
0
    }
2938
0
    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2939
0
    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2940
2941
0
    if (dense_2) {
2942
0
        cur = ggml_mul_mat(ctx0, dense_2, cur);
2943
0
    }
2944
0
    if (dense_2_b) {
2945
0
        cur = ggml_add(ctx0, cur, dense_2_b);
2946
0
    }
2947
0
    if (dense_3) {
2948
0
        cur = ggml_mul_mat(ctx0, dense_3, cur);
2949
0
    }
2950
0
    cb(cur, "result_embd_pooled", -1);
2951
0
    res->t_embd_pooled = cur;
2952
0
    ggml_build_forward_expand(gf, cur);
2953
0
}
2954
2955
2956
void llm_graph_context::build_pooling(
2957
        ggml_tensor * cls,
2958
        ggml_tensor * cls_b,
2959
        ggml_tensor * cls_out,
2960
        ggml_tensor * cls_out_b,
2961
0
        ggml_tensor * cls_norm) const {
2962
0
    if (!cparams.embeddings) {
2963
0
        return;
2964
0
    }
2965
2966
0
    ggml_tensor * inp = res->t_embd;
2967
2968
    //// find result_norm tensor for input
2969
    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
2970
    //    inp = ggml_graph_node(gf, i);
2971
    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
2972
    //        break;
2973
    //    }
2974
2975
    //    inp = nullptr;
2976
    //}
2977
2978
0
    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
2979
2980
0
    ggml_tensor * cur;
2981
2982
0
    switch (pooling_type) {
2983
0
        case LLAMA_POOLING_TYPE_NONE:
2984
0
            {
2985
0
                cur = inp;
2986
0
            } break;
2987
0
        case LLAMA_POOLING_TYPE_MEAN:
2988
0
            {
2989
0
                ggml_tensor * inp_mean = build_inp_mean();
2990
0
                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2991
0
            } break;
2992
0
        case LLAMA_POOLING_TYPE_CLS:
2993
0
        case LLAMA_POOLING_TYPE_LAST:
2994
0
            {
2995
0
                ggml_tensor * inp_cls = build_inp_cls();
2996
0
                cur = ggml_get_rows(ctx0, inp, inp_cls);
2997
0
            } break;
2998
0
        case LLAMA_POOLING_TYPE_RANK:
2999
0
            {
3000
0
                if (arch == LLM_ARCH_MODERN_BERT) {
3001
                    // modern bert gte reranker builds mean first then applies prediction head and classifier
3002
                    // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
3003
0
                    ggml_tensor * inp_mean = build_inp_mean();
3004
0
                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
3005
0
                } else {
3006
0
                    ggml_tensor * inp_cls = build_inp_cls();
3007
0
                    cur = ggml_get_rows(ctx0, inp, inp_cls);
3008
0
                }
3009
3010
                // classification head
3011
                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
3012
0
                if (cls) {
3013
0
                    cur = ggml_mul_mat(ctx0, cls, cur);
3014
0
                    if (cls_b) {
3015
0
                        cur = ggml_add(ctx0, cur, cls_b);
3016
0
                    }
3017
0
                    if (arch == LLM_ARCH_MODERN_BERT) {
3018
0
                        cur = ggml_gelu(ctx0, cur);
3019
0
                    } else {
3020
0
                        cur = ggml_tanh(ctx0, cur);
3021
0
                    }
3022
0
                    if (cls_norm) {
3023
                        // head norm
3024
0
                        cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
3025
0
                    }
3026
0
                }
3027
3028
                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
3029
                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
3030
                // Single layer classification head (direct projection)
3031
                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
3032
0
                if (cls_out) {
3033
0
                    cur = ggml_mul_mat(ctx0, cls_out, cur);
3034
0
                    if (cls_out_b) {
3035
0
                        cur = ggml_add(ctx0, cur, cls_out_b);
3036
0
                    }
3037
0
                }
3038
3039
                // softmax for qwen3 reranker
3040
0
                if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
3041
0
                    cur = ggml_soft_max(ctx0, cur);
3042
0
                }
3043
0
            } break;
3044
0
        default:
3045
0
            {
3046
0
                GGML_ABORT("unknown pooling type");
3047
0
            }
3048
0
    }
3049
3050
0
    cb(cur, "result_embd_pooled", -1);
3051
0
    res->t_embd_pooled = cur;
3052
3053
0
    ggml_build_forward_expand(gf, cur);
3054
0
}
3055
3056
0
void llm_graph_context::build_sampling() const {
3057
0
    if (samplers.empty() || !res->t_logits) {
3058
0
        return;
3059
0
    }
3060
3061
0
    std::array<ggml_tensor *, 2> outs;
3062
0
    outs[0] = res->t_logits;
3063
3064
0
    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
3065
0
    res->add_input(std::move(inp_sampling));
3066
3067
0
    std::map<llama_seq_id, int32_t> seq_to_logit_row;
3068
0
    int32_t logit_row_idx = 0;
3069
3070
0
    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
3071
0
        if (ubatch.output[i]) {
3072
0
            llama_seq_id seq_id = ubatch.seq_id[i][0];
3073
0
            seq_to_logit_row[seq_id] = logit_row_idx;
3074
0
            logit_row_idx++;
3075
0
        }
3076
0
    }
3077
3078
    // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
3079
0
    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
3080
3081
    // add a dummy row of logits
3082
    // this trick makes the graph static, regardless of which samplers are activated
3083
    // this is important in order to minimize graph reallocations
3084
0
    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
3085
3086
0
    for (const auto & [seq_id, sampler] : samplers) {
3087
0
        const auto it = seq_to_logit_row.find(seq_id);
3088
3089
        // inactive samplers always work on the first row
3090
0
        const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
3091
0
        const int i_out    = it != seq_to_logit_row.end() ? 1          : 0;
3092
3093
0
        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
3094
0
        ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
3095
3096
0
        struct llama_sampler_data data = {
3097
0
            /*.logits      =*/ logits_seq,
3098
0
            /*.probs       =*/ nullptr,
3099
0
            /*.sampled     =*/ nullptr,
3100
0
            /*.candidates  =*/ nullptr,
3101
0
        };
3102
3103
0
        assert(sampler->iface->backend_apply);
3104
0
        sampler->iface->backend_apply(sampler, ctx0, gf, &data);
3105
3106
0
        if (data.sampled != nullptr) {
3107
0
            res->t_sampled[seq_id] = data.sampled;
3108
0
            outs[1] = data.sampled;
3109
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3110
0
        }
3111
3112
0
        if (data.probs != nullptr) {
3113
0
            res->t_sampled_probs[seq_id] = data.probs;
3114
0
            outs[1] = data.probs;
3115
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3116
0
        }
3117
3118
0
        if (data.logits != nullptr) {
3119
0
            res->t_sampled_logits[seq_id] = data.logits;
3120
0
            outs[1] = data.logits;
3121
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3122
0
        }
3123
3124
0
        if (data.candidates != nullptr) {
3125
0
            res->t_candidates[seq_id] = data.candidates;
3126
0
            outs[1] = data.candidates;
3127
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3128
0
        }
3129
0
    }
3130
3131
    // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
3132
    /*
3133
    for (const auto & [seq_id, sampler] : samplers) {
3134
        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
3135
            ggml_tensor * selected_token = it->second;
3136
            if (selected_token != nullptr) {
3137
                llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
3138
            }
3139
        }
3140
    }
3141
    */
3142
0
}
3143
3144
0
// Maps the signed distance (x - y) between two positions to a relative-position bucket index.
//
//   x, y          - the two positions; only their difference is used
//   n_buckets     - total number of buckets; halved per direction when `bidirectional` is set
//   bidirectional - true:  positive and negative distances get separate halves of the bucket range
//                   false: positive distances are treated as distance 0, negative ones by magnitude
//
// Distances below `max_exact` (half of the per-direction bucket count) each get their own bucket.
// Larger distances share logarithmically sized buckets, clamped to the last bucket.
//
// The logarithmic formula is evaluated only for distances >= max_exact. For smaller distances
// (in particular distance 0, where logf(0) == -inf) the float -> int32_t conversion of the unused
// result would be undefined behaviour. If there are fewer than two buckets per direction
// (max_exact == 0), the formula would divide by zero, so the base bucket of the direction is returned.
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    // TODO move to hparams if a T5 variant appears that uses a different value
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;

    if (bidirectional) {
        // positive distances use the upper half of the bucket range
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = std::abs(relative_position);
    } else {
        // positive distances collapse to 0, negative ones are measured by magnitude
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    // degenerate bucket count: no room for an exact or a logarithmic range
    if (max_exact <= 0) {
        return relative_bucket;
    }

    // small distances: one bucket per distance
    if (relative_position < max_exact) {
        return relative_bucket + relative_position;
    }

    // large distances: logarithmic buckets up to max_distance, clamped to the last bucket
    // (relative_position >= max_exact >= 1 here, so the logarithm argument is >= 1)
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);

    return relative_bucket + relative_position_if_large;
}