Coverage Report

Created: 2026-06-13 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-graph.cpp
Line
Count
Source
1
#include "llama-graph.h"
2
3
#include "llama-impl.h"
4
#include "llama-model.h"
5
#include "llama-batch.h"
6
#include "llama-cparams.h"
7
8
#include "llama-kv-cache.h"
9
#include "llama-kv-cache-iswa.h"
10
#include "llama-kv-cache-dsa.h"
11
#include "llama-memory-hybrid.h"
12
#include "llama-memory-hybrid-iswa.h"
13
#include "llama-memory-recurrent.h"
14
15
#include <cassert>
16
#include <cmath>
17
#include <cstring>
18
#include <numeric>
19
#include <sstream>
20
#include <unordered_set>
21
22
// dedup helpers
23
24
// Create the KQ mask input tensor for a KV-cache attention layer.
// Shape: ne = [n_kv, n_tokens/n_stream, 1, n_stream], where n_stream is 1 for a
// unified KV cache and the number of unique sequences in the ubatch otherwise.
// Flash attention requires an F16 mask; F32 is used otherwise.
static ggml_tensor * build_attn_inp_kq_mask(
        ggml_context * ctx,
        const llama_kv_cache_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams) {
    const auto n_streams = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
    const auto n_rows    = ubatch.n_tokens/n_streams;

    const ggml_type mask_type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;

    ggml_tensor * mask = ggml_new_tensor_4d(ctx, mask_type, mctx->get_n_kv(), n_rows, 1, n_streams);

    ggml_set_input(mask);
    ggml_set_name(mask, "attn_inp_kq_mask");

    return mask;
}
42
43
static bool can_reuse_kq_mask(
44
        ggml_tensor * kq_mask,
45
        const llama_kv_cache_context * mctx,
46
        const llama_ubatch & ubatch,
47
0
        const llama_cparams & cparams) {
48
0
    const auto n_kv     = mctx->get_n_kv();
49
0
    const auto n_tokens = ubatch.n_tokens;
50
0
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
51
52
0
    bool res = true;
53
54
0
    res &= (kq_mask->ne[0] == n_kv);
55
0
    res &= (kq_mask->ne[1] == n_tokens/n_stream);
56
0
    res &= (kq_mask->ne[2] == 1);
57
0
    res &= (kq_mask->ne[3] == n_stream);
58
59
0
    return res;
60
0
}
61
62
// impl
63
64
// Apply the matrix `rot` to `cur` in rows of width rot->ne[0].
// `cur` is viewed as 2D (materialized with ggml_cont_2d when it is not contiguous),
// multiplied as mul_mat(rot, cur2d) with the src0-is-Hadamard hint set, and the
// product is reshaped back to cur's original 4D shape.
static ggml_tensor * ggml_mul_mat_aux(
        ggml_context * ctx,
        ggml_tensor * cur,
        ggml_tensor * rot) {
    const auto row_len = rot->ne[0];
    const auto n_rows  = ggml_nelements(cur)/row_len;

    ggml_tensor * flat = ggml_is_contiguous(cur)
        ? ggml_reshape_2d(ctx, cur, row_len, n_rows)
        : ggml_cont_2d   (ctx, cur, row_len, n_rows);

    ggml_tensor * prod = ggml_mul_mat(ctx, rot, flat);
    ggml_mul_mat_set_hint(prod, GGML_HINT_SRC0_IS_HADAMARD);

    return ggml_reshape_4d(ctx, prod, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
}
83
84
0
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
85
0
    if (ubatch->token) {
86
0
        const int64_t n_tokens = ubatch->n_tokens;
87
88
0
        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
89
0
    }
90
91
0
    if (ubatch->embd) {
92
0
        GGML_ASSERT(n_embd == embd->ne[0]);
93
94
0
        const int64_t n_tokens = ubatch->n_tokens;
95
96
0
        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
97
0
    }
98
0
}
99
100
0
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
101
0
    bool res = true;
102
103
0
    res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
104
0
    res &= (!params.ubatch.embd)  || (embd   &&   embd->ne[1] == params.ubatch.n_tokens);
105
106
0
    return res;
107
0
}
108
109
0
void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) {
110
0
    const int64_t n_tokens = ubatch->n_tokens;
111
112
0
    if (ubatch->token) {
113
0
        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
114
0
    } else {
115
        // note: mtmd embedding input goes through here
116
0
        GGML_ASSERT(ubatch->embd);
117
0
        GGML_ASSERT(n_embd == embd->ne[0]);
118
119
0
        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
120
0
    }
121
122
    // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states
123
    //       for now, we assume that the hidden state is always provided as an embedding
124
    //       ref: https://github.com/ggml-org/llama.cpp/pull/23643
125
0
    if (ubatch->embd) {
126
0
        GGML_ASSERT(n_embd == h->ne[0]);
127
128
0
        ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
129
0
    }
130
0
}
131
132
0
bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) {
133
0
    bool res = true;
134
135
0
    res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
136
0
    res &= (!params.ubatch.embd)  || (embd   && embd->ne[1]   == params.ubatch.n_tokens);
137
0
    res &= (!params.ubatch.embd)  || (h      && h->ne[1]      == params.ubatch.n_tokens);
138
139
0
    return res;
140
0
}
141
142
0
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
143
0
    if (ubatch->pos && pos) {
144
0
        const int64_t n_tokens = ubatch->n_tokens;
145
146
0
        if (ubatch->token && n_pos_per_embd == 4) {
147
            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
148
            // the 3 first dims are the same, and 4th dim is all 0
149
0
            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
150
            // copy the first dimension
151
0
            for (int i = 0; i < n_tokens; ++i) {
152
0
                pos_data[               i] = ubatch->pos[i];
153
0
                pos_data[    n_tokens + i] = ubatch->pos[i];
154
0
                pos_data[2 * n_tokens + i] = ubatch->pos[i];
155
0
                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
156
0
            }
157
0
            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
158
0
        } else {
159
0
            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
160
0
        }
161
0
    }
162
0
}
163
164
0
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
165
0
    bool res = true;
166
167
0
    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
168
169
0
    return res;
170
0
}
171
172
0
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
173
0
    if (ubatch->pos && attn_scale) {
174
0
        const int64_t n_tokens = ubatch->n_tokens;
175
176
0
        GGML_ASSERT(f_attn_temp_scale != 0.0f);
177
0
        GGML_ASSERT(n_attn_temp_floor_scale != 0);
178
179
0
        std::vector<float> attn_scale_data(n_tokens, 0.0f);
180
0
        for (int i = 0; i < n_tokens; ++i) {
181
0
            const float pos = ubatch->pos[i];
182
0
            attn_scale_data[i] = std::log(
183
0
                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
184
0
            ) * f_attn_temp_scale + 1.0;
185
0
        }
186
187
0
        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
188
0
    }
189
0
}
190
191
0
void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
192
0
    if (pos_bucket) {
193
0
        const int64_t n_tokens = ubatch->n_tokens;
194
195
0
        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
196
0
        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
197
198
0
        int32_t * data = (int32_t *) pos_bucket->data;
199
200
0
        for (int j = 0; j < n_tokens; ++j) {
201
0
            for (int i = 0; i < n_tokens; ++i) {
202
0
                data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
203
0
            }
204
0
        }
205
0
    }
206
0
}
207
208
0
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
209
0
    if (pos_bucket) {
210
0
        mctx->set_input_pos_bucket(pos_bucket, ubatch);
211
0
    }
212
0
}
213
214
0
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
215
0
    GGML_ASSERT(out_ids);
216
217
0
    const int64_t n_tokens = ubatch->n_tokens;
218
219
0
    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
220
0
    int32_t * data = (int32_t *) out_ids->data;
221
222
0
    if (n_outputs == n_tokens) {
223
0
        for (int i = 0; i < n_tokens; ++i) {
224
0
            data[i] = i;
225
0
        }
226
227
0
        return;
228
0
    }
229
230
0
    GGML_ASSERT(ubatch->output);
231
232
0
    int n_outputs = 0;
233
234
0
    for (int i = 0; i < n_tokens; ++i) {
235
0
        if (ubatch->output[i]) {
236
0
            data[n_outputs++] = i;
237
0
        }
238
0
    }
239
0
}
240
241
0
bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
242
0
    bool res = true;
243
244
0
    res &= n_outputs == params.n_outputs;
245
246
0
    return res;
247
0
}
248
249
0
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
250
0
    if (cparams.embeddings   &&
251
0
       (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
252
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
253
254
0
        const int64_t n_tokens     = ubatch->n_tokens;
255
0
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
256
0
        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
257
258
0
        GGML_ASSERT(mean);
259
0
        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
260
261
0
        float * data = (float *) mean->data;
262
0
        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
263
264
0
        std::vector<uint64_t> sums(n_seqs_unq, 0);
265
0
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
266
0
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
267
0
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
268
0
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
269
270
0
                sums[seq_idx] += ubatch->n_seq_tokens;
271
0
            }
272
0
        }
273
274
0
        std::vector<float> div(n_seqs_unq, 0.0f);
275
0
        for (int s = 0; s < n_seqs_unq; ++s) {
276
0
            const uint64_t sum = sums[s];
277
0
            if (sum > 0) {
278
0
                div[s] = 1.0f/float(sum);
279
0
            }
280
0
        }
281
282
0
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
283
0
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
284
0
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
285
0
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
286
287
0
                for (int j = 0; j < n_seq_tokens; ++j) {
288
0
                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
289
0
                }
290
0
            }
291
0
        }
292
0
    }
293
0
}
294
295
0
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
296
0
    const int64_t n_tokens     = ubatch->n_tokens;
297
0
    const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
298
299
0
    if (cparams.embeddings && (
300
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  ||
301
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
302
0
        cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
303
0
    )) {
304
0
        GGML_ASSERT(cls);
305
0
        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
306
307
0
        uint32_t * data = (uint32_t *) cls->data;
308
0
        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
309
310
0
        std::vector<int> target_pos(n_seqs_unq, -1);
311
0
        std::vector<int> target_row(n_seqs_unq, -1);
312
313
0
        const bool last = (
314
0
             cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
315
0
            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
316
0
        );
317
318
0
        for (int i = 0; i < n_tokens; ++i) {
319
0
            const llama_pos pos = ubatch->pos[i];
320
321
0
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
322
0
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
323
0
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
324
325
0
                if (
326
0
                    (target_pos[seq_idx] == -1) ||
327
0
                    ( last && pos >= target_pos[seq_idx]) ||
328
0
                    (!last && pos <  target_pos[seq_idx])
329
0
                ) {
330
0
                    target_pos[seq_idx] = pos;
331
0
                    target_row[seq_idx] = i;
332
0
                }
333
0
            }
334
0
        }
335
336
0
        for (int s = 0; s < n_seqs_unq; ++s) {
337
0
            if (target_row[s] >= 0) {
338
0
                data[s] = target_row[s];
339
0
            }
340
0
        }
341
0
    }
342
0
}
343
344
0
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
345
0
    GGML_UNUSED(ubatch);
346
347
0
    const int64_t n_rs = mctx->get_n_rs();
348
349
0
    if (s_copy) {
350
0
        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
351
0
        int32_t * data = (int32_t *) s_copy->data;
352
353
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
354
0
        for (uint32_t i = 0; i < n_rs; ++i) {
355
0
            data[i] = mctx->s_copy(i);
356
0
        }
357
0
    }
358
0
}
359
360
0
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
361
0
    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
362
363
0
    this->mctx = mctx;
364
365
0
    bool res = true;
366
367
0
    res &= s_copy->ne[0] == mctx->get_n_rs();
368
369
0
    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
370
0
    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
371
372
0
    res &= head == mctx->get_head();
373
0
    res &= rs_z == mctx->get_rs_z();
374
375
0
    return res;
376
0
}
377
378
0
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
379
0
    GGML_UNUSED(ubatch);
380
381
0
    if (cross_embd && !cross->v_embd.empty()) {
382
0
        assert(cross_embd->type == GGML_TYPE_F32);
383
384
0
        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
385
0
    }
386
0
}
387
388
template <typename T>
389
0
static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
390
0
    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
391
0
    const char * swa_type_str = "unknown";
392
393
0
    switch (swa_type) {
394
0
        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
395
0
        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
396
0
        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
397
0
        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
398
0
    };
399
400
0
    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
401
0
    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
402
0
    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
403
404
0
    LLAMA_LOG_DEBUG("    ");
405
0
    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
406
0
        LLAMA_LOG_DEBUG("%2d", j);
407
0
    }
408
0
    LLAMA_LOG_DEBUG("\n");
409
410
0
    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
411
0
        LLAMA_LOG_DEBUG(" %2d ", i);
412
0
        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
413
0
            float val = llama_cast<float>(data[i * n_kv + j]);
414
0
            if (val == -INFINITY) {
415
0
                LLAMA_LOG_DEBUG(" ∞");
416
0
            } else {
417
0
                LLAMA_LOG_DEBUG(" 0");
418
0
            }
419
0
        }
420
0
        LLAMA_LOG_DEBUG("\n");
421
0
    }
422
0
}
Unexecuted instantiation: llama-graph.cpp:void print_mask<unsigned short>(unsigned short const*, long, long, long, llama_swa_type)
Unexecuted instantiation: llama-graph.cpp:void print_mask<float>(float const*, long, long, long, llama_swa_type)
423
424
0
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
425
0
    const int64_t n_kv     = ubatch->n_tokens;
426
0
    const int64_t n_tokens = ubatch->n_tokens;
427
428
0
    const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
429
0
        using T = std::remove_reference_t<decltype(*data)>;
430
0
        std::fill(data, data + ne, llama_cast<T>(-INFINITY));
431
432
0
        for (int i1 = 0; i1 < n_tokens; ++i1) {
433
0
            const llama_seq_id s1 = ubatch->seq_id[i1][0];
434
0
            const llama_pos    p1 = ubatch->pos[i1];
435
436
0
            const uint64_t idst = i1*n_kv;
437
438
0
            for (int i0 = 0; i0 < n_tokens; ++i0) {
439
0
                const llama_seq_id s0 = ubatch->seq_id[i0][0];
440
0
                const llama_pos p0    = ubatch->pos[i0];
441
442
                // mask different sequences
443
0
                if (s0 != s1) {
444
0
                    continue;
445
0
                }
446
447
                // mask future tokens
448
0
                if (cparams.causal_attn && p0 > p1) {
449
0
                    continue;
450
0
                }
451
452
                // apply SWA if any
453
0
                if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
454
0
                    continue;
455
0
                }
456
457
0
                data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
458
0
            }
459
0
        }
460
461
0
        if (debug) {
462
0
            print_mask(data, n_tokens, n_kv, n_swa, swa_type);
463
0
        }
464
0
    };
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_no_cache::set_input(llama_ubatch const*)::$_0::operator()<unsigned short>(unsigned short*, long, int, llama_swa_type) const
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_no_cache::set_input(llama_ubatch const*)::$_0::operator()<float>(float*, long, int, llama_swa_type) const
465
466
0
    GGML_ASSERT(self_kq_mask);
467
0
    GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
468
0
    if (self_kq_mask->type == GGML_TYPE_F16) {
469
0
        fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
470
0
    } else {
471
0
        fill_mask((float       *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
472
0
    }
473
474
0
    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
475
0
        GGML_ASSERT(self_kq_mask_swa);
476
0
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
477
0
        if (self_kq_mask_swa->type == GGML_TYPE_F16) {
478
0
            fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
479
0
        } else {
480
0
            fill_mask((float       *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
481
0
        }
482
0
    }
483
0
}
484
485
0
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
486
0
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
487
0
    mctx->set_input_v_idxs(self_v_idxs, ubatch);
488
489
0
    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
490
491
0
    if (self_k_rot) {
492
0
        mctx->set_input_k_rot(self_k_rot);
493
0
    }
494
495
0
    if (self_v_rot) {
496
0
        mctx->set_input_v_rot(self_v_rot);
497
0
    }
498
0
}
499
500
0
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
501
0
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
502
503
0
    this->mctx = mctx;
504
505
0
    bool res = true;
506
507
0
    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
508
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
509
510
0
    res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
511
512
0
    return res;
513
0
}
514
515
0
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
516
0
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
517
518
0
    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
519
0
}
520
521
0
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
522
0
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
523
524
0
    this->mctx = mctx;
525
526
0
    bool res = true;
527
528
0
    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
529
530
0
    res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
531
532
0
    return res;
533
0
}
534
535
0
void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) {
536
0
    mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch);
537
538
0
    mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn);
539
540
0
    mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch);
541
542
0
    mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn);
543
544
0
    mctx->get_lid()->set_input_k_rot(self_k_rot_lid);
545
0
}
546
547
0
bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) {
548
0
    const auto * mctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx);
549
550
0
    this->mctx = mctx;
551
552
0
    bool res = true;
553
554
0
    res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens;
555
0
    res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens;
556
557
0
    res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams);
558
0
    res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams);
559
560
0
    return res;
561
0
}
562
563
0
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
564
    // base tensors may not be allocated if there are no non-SWA attention layers
565
0
    if (self_k_idxs && self_k_idxs->buffer) {
566
0
        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
567
0
        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
568
0
    }
569
570
    // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
571
0
    if (self_kq_mask && self_kq_mask->buffer) {
572
0
        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
573
0
    }
574
575
    // swa tensors may not be allocated if there are no SWA attention layers
576
0
    if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
577
0
        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
578
0
        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
579
0
    }
580
581
0
    if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
582
0
        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
583
0
    }
584
585
0
    if (self_k_rot) {
586
0
        mctx->get_base()->set_input_k_rot(self_k_rot);
587
0
    }
588
589
0
    if (self_v_rot) {
590
0
        mctx->get_base()->set_input_v_rot(self_v_rot);
591
0
    }
592
593
0
    if (self_k_rot_swa) {
594
0
        mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
595
0
    }
596
597
0
    if (self_v_rot_swa) {
598
0
        mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
599
0
    }
600
0
}
601
602
0
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
603
0
    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
604
605
0
    this->mctx = mctx;
606
607
0
    bool res = true;
608
609
    // base tensors may not be allocated if there are no non-SWA attention layers
610
0
    if (self_k_idxs && self_k_idxs->buffer) {
611
0
        res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
612
      //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
613
0
    }
614
615
0
    if (self_kq_mask && self_kq_mask->buffer) {
616
0
        res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
617
0
    }
618
619
    // swa tensors may not be allocated if there are no SWA attention layers
620
0
    if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
621
0
        res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
622
      //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
623
0
    }
624
625
0
    if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
626
0
        res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
627
0
    }
628
629
0
    return res;
630
0
}
631
632
0
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
633
0
    GGML_ASSERT(cross_kq_mask);
634
635
0
    const int64_t n_enc    = cross_kq_mask->ne[0];
636
0
    const int64_t n_tokens = ubatch->n_tokens;
637
638
0
    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
639
0
    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
640
641
0
    const auto fill_mask = [&](auto * data) {
642
0
        using T = std::remove_reference_t<decltype(*data)>;
643
0
        for (int i = 0; i < n_tokens; ++i) {
644
0
            GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
645
0
            for (int j = 0; j < n_enc; ++j) {
646
0
                float f = -INFINITY;
647
648
0
                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
649
0
                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
650
651
0
                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
652
0
                        f = 0.0f;
653
0
                    }
654
0
                }
655
656
0
                data[i*n_enc + j] = llama_cast<T>(f);
657
0
            }
658
0
        }
659
0
    };
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_cross::set_input(llama_ubatch const*)::$_0::operator()<unsigned short>(unsigned short*) const
Unexecuted instantiation: llama-graph.cpp:auto llm_graph_input_attn_cross::set_input(llama_ubatch const*)::$_0::operator()<float>(float*) const
660
661
0
    if (cross_kq_mask->type == GGML_TYPE_F16) {
662
0
        fill_mask((ggml_fp16_t *) cross_kq_mask->data);
663
0
    } else {
664
0
        fill_mask((float *) cross_kq_mask->data);
665
0
    }
666
0
}
667
668
0
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
669
0
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
670
0
    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
671
672
0
    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
673
674
0
    if (inp_attn->self_k_rot) {
675
0
        mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
676
0
    }
677
678
0
    if (inp_attn->self_v_rot) {
679
0
        mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
680
0
    }
681
682
0
    const int64_t n_rs = mctx->get_recr()->get_n_rs();
683
684
0
    if (inp_rs->s_copy) {
685
0
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
686
0
        int32_t * data = (int32_t *) inp_rs->s_copy->data;
687
688
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
689
0
        for (uint32_t i = 0; i < n_rs; ++i) {
690
0
            data[i] = mctx->get_recr()->s_copy(i);
691
0
        }
692
0
    }
693
0
}
694
695
0
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
696
0
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
697
698
0
    this->mctx = mctx;
699
700
0
    bool res = true;
701
702
0
    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
703
  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
704
705
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
706
707
0
    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
708
709
0
    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
710
0
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
711
712
0
    res &= inp_rs->head == mctx->get_recr()->get_head();
713
0
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
714
715
0
    return res;
716
0
}
717
718
// TODO: Hybrid input classes are a bit redundant.
719
// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
720
// Refactoring is required in the future.
721
0
void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
722
0
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
723
724
0
    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
725
726
0
    const int64_t n_rs = mctx->get_recr()->get_n_rs();
727
728
0
    if (inp_rs->s_copy) {
729
0
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
730
0
        int32_t * data = (int32_t *) inp_rs->s_copy->data;
731
732
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
733
0
        for (uint32_t i = 0; i < n_rs; ++i) {
734
0
            data[i] = mctx->get_recr()->s_copy(i);
735
0
        }
736
0
    }
737
0
}
738
739
0
bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
740
0
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
741
742
0
    this->mctx = mctx;
743
744
0
    bool res = true;
745
746
0
    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
747
748
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
749
750
0
    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
751
752
0
    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
753
0
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
754
755
0
    res &= inp_rs->head == mctx->get_recr()->get_head();
756
0
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
757
758
0
    return res;
759
0
}
760
761
0
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
762
0
    const auto * attn_ctx = mctx->get_attn();
763
764
    // base tensors may not be allocated if there are no non-SWA attention layers
765
0
    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
766
0
        attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
767
0
        attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
768
0
    }
769
770
0
    if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) {
771
0
        attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
772
0
    }
773
774
    // swa tensors may not be allocated if there are no SWA attention layers
775
0
    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
776
0
        attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
777
0
        attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
778
0
    }
779
780
0
    if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) {
781
0
        attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
782
0
    }
783
784
0
    if (inp_attn->self_k_rot) {
785
0
        attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
786
0
    }
787
788
0
    if (inp_attn->self_v_rot) {
789
0
        attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
790
0
    }
791
792
0
    if (inp_attn->self_k_rot_swa) {
793
0
        attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
794
0
    }
795
796
0
    if (inp_attn->self_v_rot_swa) {
797
0
        attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
798
0
    }
799
800
0
    const int64_t n_rs = mctx->get_recr()->get_n_rs();
801
802
0
    if (inp_rs->s_copy) {
803
0
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
804
0
        int32_t * data = (int32_t *) inp_rs->s_copy->data;
805
806
        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
807
0
        for (uint32_t i = 0; i < n_rs; ++i) {
808
0
            data[i] = mctx->get_recr()->s_copy(i);
809
0
        }
810
0
    }
811
0
}
812
813
0
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
814
0
    const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
815
816
0
    this->mctx = mctx;
817
818
0
    bool res = true;
819
820
0
    const auto * attn_ctx = mctx->get_attn();
821
822
    // base tensors may not be allocated if there are no non-SWA attention layers
823
0
    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
824
0
        res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
825
      //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
826
0
    }
827
828
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
829
830
    // swa tensors may not be allocated if there are no SWA attention layers
831
0
    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
832
0
        res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
833
      //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
834
0
    }
835
836
0
    res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
837
838
0
    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
839
840
0
    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
841
0
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
842
843
0
    res &= inp_rs->head == mctx->get_recr()->get_head();
844
0
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
845
846
0
    return res;
847
0
}
848
849
0
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
850
    // set the inputs only for the active samplers in the current ubatch
851
0
    std::unordered_set<llama_seq_id> active_samplers;
852
0
    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
853
0
        if (ubatch->output[i]) {
854
0
            llama_seq_id seq_id = ubatch->seq_id[i][0];
855
0
            active_samplers.insert(seq_id);
856
0
        }
857
0
    }
858
859
0
    for (auto seq_id : active_samplers) {
860
0
        if (samplers.find(seq_id) == samplers.end()) {
861
0
            continue;
862
0
        }
863
864
0
        auto & sampler = samplers[seq_id];
865
866
0
        if (sampler->iface->backend_set_input) {
867
0
            sampler->iface->backend_set_input(sampler);
868
0
        }
869
0
    }
870
0
}
871
872
0
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
873
0
    if (samplers.size() != params.samplers.size()) {
874
0
        return false;
875
0
    }
876
877
0
    for (const auto & [seq_id, sampler] : params.samplers) {
878
0
        if (samplers[seq_id] != sampler) {
879
0
            return false;
880
0
        }
881
0
    }
882
883
0
    return true;
884
0
}
885
886
//
887
// llm_graph_result
888
//
889
890
0
// Create an empty graph result able to hold up to `max_nodes` nodes.
// The LLAMA_GRAPH_RESULT_DEBUG environment variable (integer, default 0) sets the
// verbosity of debug logging; can_reuse() logs at levels > 0 and > 1.
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
    reset();

    const char * env_debug = getenv("LLAMA_GRAPH_RESULT_DEBUG");
    if (env_debug != nullptr) {
        debug = atoi(env_debug);
    } else {
        debug = 0;
    }
}
896
897
0
int64_t llm_graph_result::get_max_nodes() const {
898
0
    return max_nodes;
899
0
}
900
901
0
void llm_graph_result::reset() {
902
0
    t_inp_tokens  = nullptr;
903
0
    t_inp_embd    = nullptr;
904
0
    t_logits      = nullptr;
905
0
    t_embd        = nullptr;
906
0
    t_embd_pooled = nullptr;
907
908
0
    t_layer_inp.resize(LLAMA_MAX_LAYERS);
909
0
    std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);
910
911
0
    t_sampled.clear();
912
0
    t_sampled_probs.clear();
913
0
    t_sampled_logits.clear();
914
0
    t_candidates.clear();
915
916
0
    params = {};
917
918
0
    inputs.clear();
919
920
0
    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
921
922
0
    ggml_init_params params = {
923
0
        /*.mem_size   =*/ buf_compute_meta.size(),
924
0
        /*.mem_buffer =*/ buf_compute_meta.data(),
925
0
        /*.no_alloc   =*/ true,
926
0
    };
927
928
0
    ctx_compute.reset(ggml_init(params));
929
930
0
    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
931
0
}
932
933
0
void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
934
0
    for (auto & input : inputs) {
935
0
        input->set_input(ubatch);
936
0
    }
937
0
}
938
939
0
void llm_graph_result::set_outputs(const llm_graph_params & params) {
940
0
    if (t_logits != nullptr) {
941
0
        ggml_set_output(t_logits);
942
0
    }
943
0
    if (t_embd != nullptr) {
944
0
        ggml_set_output(t_embd);
945
0
    }
946
0
    if (t_embd_pooled != nullptr) {
947
0
        ggml_set_output(t_embd_pooled);
948
0
    }
949
0
    if (t_h_nextn != nullptr) {
950
0
        ggml_set_output(t_h_nextn);
951
0
    }
952
0
    {
953
0
        const auto & embeddings_layer_inp = params.cparams.embeddings_layer_inp;
954
0
        for (size_t il = 0; il < embeddings_layer_inp.size(); ++il) {
955
0
            if (embeddings_layer_inp[il]) {
956
0
                GGML_ASSERT(t_layer_inp[il] != nullptr && "layer input tensor is null");
957
0
                ggml_set_output(t_layer_inp[il]);
958
0
            }
959
0
        }
960
0
    }
961
0
    for (auto & [seq_id, t] : t_sampled) {
962
0
        if (t != nullptr) {
963
0
            ggml_set_output(t);
964
0
        }
965
0
    }
966
0
    for (auto & [seq_id, t] : t_sampled_probs) {
967
0
        if (t != nullptr) {
968
0
            ggml_set_output(t);
969
0
        }
970
0
    }
971
0
    for (auto & [seq_id, t] : t_sampled_logits) {
972
0
        if (t != nullptr) {
973
0
            ggml_set_output(t);
974
0
        }
975
0
    }
976
0
    for (auto & [seq_id, t] : t_candidates) {
977
0
        if (t != nullptr) {
978
0
            ggml_set_output(t);
979
0
        }
980
0
    }
981
0
}
982
983
0
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
984
0
    if (!this->params.allow_reuse(params)) {
985
0
        if (debug > 1) {
986
0
            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
987
0
        }
988
989
0
        return false;
990
0
    }
991
992
0
    if (debug > 1) {
993
0
        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
994
0
    }
995
996
0
    bool res = true;
997
998
0
    for (auto & input : inputs) {
999
0
        const bool cur = input->can_reuse(params);
1000
1001
0
        if (debug > 1) {
1002
0
            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
1003
0
        }
1004
1005
0
        res = res && cur;
1006
0
    }
1007
1008
0
    if (debug > 0) {
1009
0
        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
1010
0
    }
1011
1012
0
    return res;
1013
0
}
1014
1015
0
// Take ownership of `input`, register it, and return a non-owning pointer to it.
llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
    llm_graph_input_i * raw = input.get();
    inputs.emplace_back(std::move(input));
    return raw;
}
1019
1020
0
void llm_graph_result::set_params(const llm_graph_params & params) {
1021
0
    this->params = params;
1022
0
}
1023
1024
//
1025
// llm_graph_context
1026
//
1027
1028
// Capture everything a graph builder needs from `params` into members. Most
// hyperparameters are copied out of hparams/cparams so builders can use short names.
// NOTE(review): members are initialized in their declaration order (in the header), not
// in the order written here. This list presumably mirrors that order; keep it so, to
// avoid -Wreorder.
llm_graph_context::llm_graph_context(const llm_graph_params & params)
    // model/context parameters and the ubatch being processed
    : arch          (params.arch)
    , hparams       (params.hparams)
    , cparams       (params.cparams)
    , ubatch        (params.ubatch)
    // model shape
    , n_embd        (hparams.n_embd)
    , n_layer       (hparams.n_layer())
    , n_layer_nextn (hparams.n_layer_nextn)
    , n_rot         (hparams.n_rot())
    , n_ctx         (cparams.n_ctx)
    , n_head        (hparams.n_head())
    , n_head_kv     (hparams.n_head_kv())
    , n_embd_head_k (hparams.n_embd_head_k())
    , n_embd_k_gqa  (hparams.n_embd_k_gqa())
    , n_embd_head_v (hparams.n_embd_head_v())
    , n_embd_v_gqa  (hparams.n_embd_v_gqa())
    , n_expert      (hparams.n_expert)
    // during warmup every expert is used
    , n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used)
    // RoPE / YaRN settings
    , freq_base     (cparams.rope_freq_base)
    , freq_scale    (cparams.rope_freq_scale)
    , ext_factor    (cparams.yarn_ext_factor)
    , attn_factor   (cparams.yarn_attn_factor)
    , beta_fast     (cparams.yarn_beta_fast)
    , beta_slow     (cparams.yarn_beta_slow)
    // normalization epsilons
    , norm_eps      (hparams.f_norm_eps)
    , norm_rms_eps  (hparams.f_norm_rms_eps)
    // batch-dependent sizes and misc settings
    , n_tokens      (ubatch.n_tokens)
    , n_outputs     (params.n_outputs)
    , n_ctx_orig    (cparams.n_ctx_orig_yarn)
    , pooling_type  (cparams.pooling_type)
    , rope_type     (hparams.rope_type)
    // backend, adapters, memory, sampling and callback hooks
    , sched         (params.sched)
    , backend_cpu   (params.backend_cpu)
    , cvec          (params.cvec)
    , loras         (params.loras)
    , mctx          (params.mctx)
    , cross         (params.cross)
    , samplers      (params.samplers)
    , cb_func       (params.cb)
    // result object, plus its compute context and graph
    , res           (params.res)
    , ctx0          (res->get_ctx())
    , gf            (res->get_gf())
{
    // remember the build parameters on the result (later consulted by its can_reuse())
    res->set_params(params);
}
1072
1073
0
// Forward (ubatch, tensor, name, layer) to the user callback, if one was provided.
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
    if (!cb_func) {
        return;
    }

    cb_func(ubatch, cur, name, il);
}
1078
1079
// Apply the control vector for layer `il` to `cur` by delegating to cvec->apply_to().
ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
                 int   il) const {
    ggml_tensor * out = cvec->apply_to(ctx0, cur, il);
    return out;
}
1084
1085
// Matrix product w*cur with every loaded LoRA adapter that patches `w` added on top:
//   out = w*cur + sum_adapters scale * (B * (A * cur))
// Finally, if `w_s` is given, the result is multiplied element-wise by it (ggml_mul).
// The adapter deltas are included in that multiplication.
ggml_tensor * llm_graph_context::build_lora_mm(
          ggml_tensor * w,
          ggml_tensor * cur,
          ggml_tensor * w_s) const {
    ggml_tensor * out = ggml_mul_mat(ctx0, w, cur);

    for (const auto & [adapter, adapter_scale] : *loras) {
        llama_adapter_lora_weight * lw = adapter->get_weight(w);
        if (lw == nullptr) {
            continue; // this adapter does not touch w
        }

        const float scale = lw->get_scale(adapter->alpha, adapter_scale);

        ggml_tensor * a_cur  = ggml_mul_mat(ctx0, lw->a, cur);
        ggml_tensor * ab_cur = ggml_mul_mat(ctx0, lw->b, a_cur);

        out = ggml_add(ctx0, out, ggml_scale(ctx0, ab_cur, scale));
    }

    if (w_s) {
        out = ggml_mul(ctx0, out, w_s);
    }

    return out;
}
1115
1116
// Expert-indexed variant of build_lora_mm, built on ggml_mul_mat_id. `w` holds the stacked
// expert weights ("as"), and `ids` selects which experts are applied to `cur`.
// Every adapter that patches `w` contributes scale * B(A*cur), using the same expert ids.
ggml_tensor * llm_graph_context::build_lora_mm_id(
          ggml_tensor * w,   // ggml_tensor * as
          ggml_tensor * cur, // ggml_tensor * b
          ggml_tensor * ids) const {
    ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        // Use the adapter weight's own alpha/rank scaling rule, exactly as build_lora_mm
        // does, instead of re-deriving the formula inline. This keeps both paths in sync.
        const float scale = lw->get_scale(lora.first->alpha, lora.second);

        ggml_tensor * ab_cur = ggml_mul_mat_id(
                ctx0, lw->b,
                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
                ids
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}
1143
1144
// Normalize `cur` with the selected norm type, then optionally multiply by `mw` and add `mb`.
// Callback names: "norm" once either weight or bias is present; "norm_w" only when both are.
ggml_tensor * llm_graph_context::build_norm(
         ggml_tensor * cur,
         ggml_tensor * mw,
         ggml_tensor * mb,
       llm_norm_type   type,
                 int   il) const {
    switch (type) {
        case LLM_NORM:
            cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
            break;
        case LLM_NORM_RMS:
            cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps);
            break;
        case LLM_NORM_GROUP: {
            // insert a unit dim for group norm, then drop it again
            ggml_tensor * cur3d  = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
            ggml_tensor * normed = ggml_group_norm(ctx0, cur3d, hparams.n_norm_groups, hparams.f_norm_group_eps);
            cur = ggml_reshape_2d(ctx0, normed, normed->ne[0], normed->ne[2]);
        } break;
    }

    const bool has_w = mw != nullptr;
    const bool has_b = mb != nullptr;

    if (has_w || has_b) {
        cb(cur, "norm", il);
    }

    if (has_w) {
        cur = ggml_mul(ctx0, cur, mw);
        if (has_b) {
            cb(cur, "norm_w", il);
        }
    }

    if (has_b) {
        cur = ggml_add(ctx0, cur, mb);
    }

    return cur;
}
1178
1179
1180
// Build the Q, K and V projections of an attention layer, each shaped
// [n_embd_head, n_head(_kv), n_tokens].
// - fused path (layer.wqkv): one matmul produces rows laid out as [Q | K | V]; the three
//   parts are extracted as strided 3D views at offsets 0, n_embd_q and n_embd_q + n_embd_kv.
// - separate path: independent wq/wk/wv matmuls, each followed by its optional bias.
// In both paths the values are clamped to +-f_clamp_kqv when f_clamp_kqv > 0.
llm_graph_qkv llm_graph_context::build_qkv(
        const llama_layer & layer,
              ggml_tensor * cur,
                  int64_t   n_embd_head,
                  int64_t   n_head,
                  int64_t   n_head_kv,
                      int   il) const {
    const int64_t n_embd_q  = n_embd_head * n_head;
    const int64_t n_embd_kv = n_embd_head * n_head_kv;

    const float clamp_kqv = hparams.f_clamp_kqv;
    const bool  do_clamp  = clamp_kqv > 0.0f;

    ggml_tensor * Qcur = nullptr;
    ggml_tensor * Kcur = nullptr;
    ggml_tensor * Vcur = nullptr;

    if (layer.wqkv) {
        ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s);
        cb(qkv, "wqkv", il);
        if (layer.wqkv_b) {
            qkv = ggml_add(ctx0, qkv, layer.wqkv_b);
            cb(qkv, "wqkv_b", il);
        }
        if (do_clamp) {
            qkv = ggml_clamp(ctx0, qkv, -clamp_kqv, clamp_kqv);
            cb(qkv, "wqkv_clamped", il);
        }

        const size_t head_nb  = ggml_row_size(qkv->type, n_embd_head);
        const size_t token_nb = qkv->nb[1];

        Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head,    n_tokens, head_nb, token_nb, 0);
        Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, head_nb, token_nb,
                            ggml_row_size(qkv->type, n_embd_q));
        Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, head_nb, token_nb,
                            ggml_row_size(qkv->type, n_embd_q + n_embd_kv));
    } else {
        // project -> optional bias -> optional clamp, reporting each step via cb()
        const auto project = [&](ggml_tensor * w, ggml_tensor * w_s, ggml_tensor * b,
                                 const char * name, const char * name_clamped) {
            ggml_tensor * t = build_lora_mm(w, cur, w_s);
            cb(t, name, il);
            if (b) {
                t = ggml_add(ctx0, t, b);
                cb(t, name, il);
            }
            if (do_clamp) {
                t = ggml_clamp(ctx0, t, -clamp_kqv, clamp_kqv);
                cb(t, name_clamped, il);
            }
            return t;
        };

        Qcur = project(layer.wq, layer.wq_s, layer.wq_b, "Qcur", "Qcur_clamped");
        Kcur = project(layer.wk, layer.wk_s, layer.wk_b, "Kcur", "Kcur_clamped");
        Vcur = project(layer.wv, layer.wv_s, layer.wv_b, "Vcur", "Vcur_clamped");

        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
    }

    cb(Qcur, "Qcur", il);
    cb(Kcur, "Kcur", il);
    cb(Vcur, "Vcur", il);

    return { Qcur, Kcur, Vcur };
}
1255
1256
1257
// Build a feed-forward block:
//   up projection (+bias, *scale)
//   -> optional gate projection (+bias, *scale)
//   -> activation `type_op`
//   -> optional down projection (+bias, *scale).
// Gate types:
// - LLM_FFN_SEQ: the gate projection is applied to the up output.
// - LLM_FFN_PAR: the gate projection is applied to the block input and later multiplied
//   with the up output.
// Where a fused op exists (swiglu/geglu/reglu split), that multiplication happens inside
// the activation; type_gate is then switched to SEQ so the trailing multiply is skipped.
ggml_tensor * llm_graph_context::build_ffn(
         ggml_tensor * cur,
         ggml_tensor * up,
         ggml_tensor * up_b,
         ggml_tensor * up_s,
         ggml_tensor * gate,
         ggml_tensor * gate_b,
         ggml_tensor * gate_s,
         ggml_tensor * down,
         ggml_tensor * down_b,
         ggml_tensor * down_s,
         ggml_tensor * act_scales,
     llm_ffn_op_type   type_op,
   llm_ffn_gate_type   type_gate,
                 int   il) const {
    // --- up projection (identity when `up` is null) ---
    ggml_tensor * up_out = up ? build_lora_mm(up, cur) : cur;
    cb(up_out, "ffn_up", il);

    if (up_b) {
        up_out = ggml_add(ctx0, up_out, up_b);
        cb(up_out, "ffn_up_b", il);
    }

    if (up_s) {
        up_out = ggml_mul(ctx0, up_out, up_s);
        cb(up_out, "ffn_up_s", il);
    }

    // --- gate projection ---
    if (gate) {
        switch (type_gate) {
            case LLM_FFN_SEQ:
                cur = build_lora_mm(gate, up_out);
                cb(cur, "ffn_gate", il);
                break;
            case LLM_FFN_PAR:
                cur = build_lora_mm(gate, cur);
                cb(cur, "ffn_gate", il);
                break;
        }

        if (gate_b) {
            cur = ggml_add(ctx0, cur, gate_b);
            cb(cur, "ffn_gate_b", il);
        }

        if (gate_s) {
            cur = ggml_mul(ctx0, cur, gate_s);
            cb(cur, "ffn_gate_s", il);
        }
    } else {
        cur = up_out;
    }

    // --- activation ---
    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate && type_gate == LLM_FFN_PAR) {
                // Step35: HF clamps gate (after SiLU) and up before multiplication
                if (arch == LLM_ARCH_STEP35 && il >= 0) {
                    const float limit = hparams.swiglu_clamp_shexp[il];
                    constexpr float eps = 1e-6f;
                    if (limit > eps) {
                        ggml_tensor * gate_act = ggml_silu(ctx0, cur);
                        cb(gate_act, "ffn_silu", il);
                        gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
                        cb(gate_act, "ffn_silu_clamped", il);

                        up_out = ggml_clamp(ctx0, up_out, -limit, limit);
                        cb(up_out, "ffn_up_clamped", il);

                        cur = ggml_mul(ctx0, gate_act, up_out);
                        cb(cur, "ffn_swiglu_limited", il);
                        type_gate = LLM_FFN_SEQ;
                        break; // leaves the switch: the gate has already been applied
                    }
                }

                cur = ggml_swiglu_split(ctx0, cur, up_out);
                cb(cur, "ffn_swiglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_silu", il);
            }
            break;
        case LLM_FFN_GELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_geglu_split(ctx0, cur, up_out);
                cb(cur, "ffn_geglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_gelu", il);
                // optional activation scales (divisor), only on this non-fused path
                if (act_scales != nullptr) {
                    cur = ggml_div(ctx0, cur, act_scales);
                    cb(cur, "ffn_act", il);
                }
            }
            break;
        case LLM_FFN_RELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_reglu_split(ctx0, cur, up_out);
                cb(cur, "ffn_reglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);
            }
            break;
        case LLM_FFN_RELU_SQR:
            cur = ggml_relu(ctx0, cur);
            cb(cur, "ffn_relu", il);

            cur = ggml_sqr(ctx0, cur);
            cb(cur, "ffn_sqr(relu)", il);
            break;
        case LLM_FFN_SWIGLU:
            cur = ggml_swiglu(ctx0, cur);
            cb(cur, "ffn_swiglu", il);
            break;
        case LLM_FFN_GEGLU:
            cur = ggml_geglu(ctx0, cur);
            cb(cur, "ffn_geglu", il);
            break;
        case LLM_FFN_REGLU:
            cur = ggml_reglu(ctx0, cur);
            cb(cur, "ffn_reglu", il);
            break;
        default:
            GGML_ABORT("fatal error");
    }

    // parallel gate not fused into the activation: multiply with the up output now
    if (gate && type_gate == LLM_FFN_PAR) {
        cur = ggml_mul(ctx0, cur, up_out);
        cb(cur, "ffn_gate_par", il);
    }

    // --- down projection ---
    if (down) {
        cur = build_lora_mm(down, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
            // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (down_b) {
        // the pre-bias result is only reported as "ffn_down" when a bias follows
        cb(cur, "ffn_down", il);
        cur = ggml_add(ctx0, cur, down_b);
    }

    if (down_s) {
        cur = ggml_mul(ctx0, cur, down_s);
        cb(cur, "ffn_down_s", il);
    }

    return cur;
}
1420
1421
// Convenience overload for MoE layers without biases. It forwards to the full
// build_moe_ffn, passing null for gate_inp_b, up/gate/down_exps_b and gate_up_exps_b.
ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * up_exps,
         ggml_tensor * gate_exps,
         ggml_tensor * down_exps,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
               float   w_scale,
         llama_expert_gating_func_type gating_op,
                 int   il,
         ggml_tensor * probs_in,
         ggml_tensor * gate_up_exps,
         ggml_tensor * up_exps_s,
         ggml_tensor * gate_exps_s,
         ggml_tensor * down_exps_s) const {
    ggml_tensor * const no_bias = nullptr;

    return build_moe_ffn(cur,
                         gate_inp,     no_bias,   // router + its bias
                         up_exps,      no_bias,   // up experts + bias
                         gate_exps,    no_bias,   // gate experts + bias
                         down_exps,    no_bias,   // down experts + bias
                         exp_probs_b,
                         n_expert, n_expert_used,
                         type_op, norm_w, w_scale, gating_op,
                         il,
                         probs_in,
                         gate_up_exps, no_bias,   // fused gate/up experts + bias
                         up_exps_s, gate_exps_s, down_exps_s);
}
1462
1463
ggml_tensor * llm_graph_context::build_moe_ffn(
1464
         ggml_tensor * cur,
1465
         ggml_tensor * gate_inp,
1466
         ggml_tensor * gate_inp_b,
1467
         ggml_tensor * up_exps,
1468
         ggml_tensor * up_exps_b,
1469
         ggml_tensor * gate_exps,
1470
         ggml_tensor * gate_exps_b,
1471
         ggml_tensor * down_exps,
1472
         ggml_tensor * down_exps_b,
1473
         ggml_tensor * exp_probs_b,
1474
             int64_t   n_expert,
1475
             int64_t   n_expert_used,
1476
     llm_ffn_op_type   type_op,
1477
                bool   norm_w,
1478
               float   w_scale,
1479
        llama_expert_gating_func_type gating_op,
1480
                 int   il,
1481
         ggml_tensor * probs_in,
1482
         ggml_tensor * gate_up_exps,
1483
         ggml_tensor * gate_up_exps_b,
1484
         ggml_tensor * up_exps_s,
1485
         ggml_tensor * gate_exps_s,
1486
0
         ggml_tensor * down_exps_s) const {
1487
0
    const int64_t n_embd   = cur->ne[0];
1488
0
    const int64_t n_tokens = cur->ne[1];
1489
0
    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
1490
1491
0
    ggml_tensor * logits = nullptr;
1492
1493
0
    if (probs_in == nullptr) {
1494
0
        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
1495
0
        cb(logits, "ffn_moe_logits", il);
1496
0
    } else {
1497
0
        logits = probs_in;
1498
0
    }
1499
1500
0
    if (gate_inp_b) {
1501
0
        logits = ggml_add(ctx0, logits, gate_inp_b);
1502
0
        cb(logits, "ffn_moe_logits_biased", il);
1503
0
    }
1504
1505
0
    ggml_tensor * probs = nullptr;
1506
0
    switch (gating_op) {
1507
0
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
1508
0
            {
1509
0
                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
1510
0
            } break;
1511
0
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
1512
0
            {
1513
0
                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1514
0
            } break;
1515
0
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
1516
0
            {
1517
0
                probs = logits; // [n_expert, n_tokens]
1518
0
            } break;
1519
0
        default:
1520
0
            GGML_ABORT("fatal error");
1521
0
    }
1522
0
    cb(probs, "ffn_moe_probs", il);
1523
1524
    // add experts selection bias - introduced in DeepSeek V3
1525
    // leave probs unbiased as it's later used to get expert weights
1526
0
    ggml_tensor * selection_probs = probs;
1527
0
    if (exp_probs_b != nullptr) {
1528
0
        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
1529
0
        cb(selection_probs, "ffn_moe_probs_biased", il);
1530
0
    }
1531
1532
    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
1533
    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
1534
0
    if (arch == LLM_ARCH_LLAMA4) {
1535
0
        selection_probs = logits;
1536
0
    }
1537
1538
0
    if (arch == LLM_ARCH_GROVEMOE) {
1539
0
        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1540
0
        cb(selection_probs, "ffn_moe_probs_biased", il);
1541
0
    }
1542
1543
    // select top n_group_used expert groups
1544
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1545
0
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1546
0
        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1547
1548
        // organize experts into n_expert_groups
1549
0
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1550
1551
0
        ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1552
0
        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1553
1554
        // get top n_group_used expert groups
1555
0
        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1556
0
        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1557
1558
0
        ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1559
0
        cb(expert_groups, "ffn_moe_group_topk", il);
1560
1561
        // mask out the other groups
1562
0
        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1563
0
        selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1564
0
        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1565
0
        cb(selection_probs, "ffn_moe_probs_masked", il);
1566
0
    }
1567
1568
    // select experts
1569
0
    ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1570
0
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
1571
0
    cb(selected_experts, "ffn_moe_topk", il);
1572
1573
0
    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
1574
        // TODO: Use scalar div instead when/if implemented
1575
0
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
1576
0
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
1577
0
        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
1578
0
    } else {
1579
0
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
1580
0
    }
1581
1582
0
    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
1583
0
    cb(weights, "ffn_moe_weights", il);
1584
1585
1586
0
    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
1587
0
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1588
0
        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
1589
0
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1590
0
        cb(weights, "ffn_moe_weights_softmax", il);
1591
0
    }
1592
1593
0
    if (norm_w) {
1594
0
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1595
1596
0
        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
1597
0
        cb(weights_sum, "ffn_moe_weights_sum", il);
1598
1599
        // Avoid division by zero, clamp to smallest number representable by F16
1600
0
        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1601
0
        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1602
1603
0
        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
1604
0
        cb(weights, "ffn_moe_weights_norm", il);
1605
1606
0
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1607
0
    }
1608
0
    if (w_scale != 0.0f && w_scale != 1.0f) {
1609
0
        weights = ggml_scale(ctx0, weights, w_scale);
1610
0
        cb(weights, "ffn_moe_weights_scaled", il);
1611
0
    }
1612
1613
    //call early so that topk-moe can be used
1614
0
    ggml_build_forward_expand(gf, weights);
1615
1616
0
    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
1617
1618
0
    if (weight_before_ffn) {
1619
        // repeat cur to [n_embd, n_expert_used, n_tokens]
1620
0
        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
1621
0
        cur = ggml_mul(ctx0, repeated, weights);
1622
0
        cb(cur, "ffn_moe_weighted", il);
1623
0
    }
1624
1625
0
    ggml_tensor * up = nullptr;
1626
0
    ggml_tensor * experts = nullptr;
1627
1628
0
    if (gate_up_exps) {
1629
        // merged gate_up path: one mul_mat_id, then split into gate and up views
1630
0
        ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
1631
0
        cb(gate_up, "ffn_moe_gate_up", il);
1632
1633
0
        if (gate_up_exps_b) {
1634
0
            gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
1635
0
            cb(gate_up, "ffn_moe_gate_up_biased", il);
1636
0
        }
1637
1638
        // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
1639
0
        if (up_exps_s) {
1640
0
            ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1641
0
            s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1642
0
            s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1643
0
            gate_up = ggml_mul(ctx0, gate_up, s);
1644
0
            cb(gate_up, "ffn_moe_gate_up_scaled", il);
1645
0
        }
1646
1647
0
        const int64_t n_ff = gate_up->ne[0] / 2;
1648
0
        cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
1649
0
        cb(cur, "ffn_moe_gate", il);
1650
0
        up  = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
1651
0
        cb(up, "ffn_moe_up", il);
1652
0
    } else {
1653
        // separate gate and up path
1654
0
        up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1655
0
        cb(up, "ffn_moe_up", il);
1656
1657
0
        if (up_exps_b) {
1658
0
            up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1659
0
            cb(up, "ffn_moe_up_biased", il);
1660
0
        }
1661
1662
        // apply per-expert scale2 to up
1663
0
        if (up_exps_s) {
1664
0
            ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1665
0
            s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1666
0
            s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1667
0
            up = ggml_mul(ctx0, up, s);
1668
0
            cb(up, "ffn_moe_up_scaled", il);
1669
0
        }
1670
1671
0
        if (gate_exps) {
1672
0
            cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1673
0
            cb(cur, "ffn_moe_gate", il);
1674
0
        } else {
1675
0
            cur = up;
1676
0
        }
1677
1678
0
        if (gate_exps_b) {
1679
0
            cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1680
0
            cb(cur, "ffn_moe_gate_biased", il);
1681
0
        }
1682
1683
        // apply per-expert scale2 to gate
1684
0
        if (gate_exps_s) {
1685
0
            ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
1686
0
            s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1687
0
            s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1688
0
            cur = ggml_mul(ctx0, cur, s);
1689
0
            cb(cur, "ffn_moe_gate_scaled", il);
1690
0
        }
1691
0
    }
1692
1693
0
    const bool has_gate = gate_exps || gate_up_exps;
1694
1695
0
    switch (type_op) {
1696
0
        case LLM_FFN_SILU:
1697
0
            if (gate_exps) {
1698
                // Step35: per-layer clamp for routed experts
1699
0
                if (arch == LLM_ARCH_STEP35 && il >= 0) {
1700
0
                    const float limit = hparams.swiglu_clamp_exp[il];
1701
0
                    constexpr float eps = 1e-6f;
1702
0
                    if (limit > eps) {
1703
0
                        ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1704
0
                        cb(gate_act, "ffn_moe_silu", il);
1705
0
                        gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1706
0
                        cb(gate_act, "ffn_moe_silu_clamped", il);
1707
1708
0
                        up = ggml_clamp(ctx0, up, -limit, limit);
1709
0
                        cb(up, "ffn_moe_up_clamped", il);
1710
1711
0
                        cur = ggml_mul(ctx0, gate_act, up);
1712
0
                        cb(cur, "ffn_moe_swiglu_limited", il);
1713
0
                        break;
1714
0
                    }
1715
0
                }
1716
0
            }
1717
1718
0
            if (has_gate) {
1719
0
                cur = ggml_swiglu_split(ctx0, cur, up);
1720
0
                cb(cur, "ffn_moe_swiglu", il);
1721
0
            } else {
1722
0
                cur = ggml_silu(ctx0, cur);
1723
0
                cb(cur, "ffn_moe_silu", il);
1724
0
            } break;
1725
0
        case LLM_FFN_GELU:
1726
0
            if (has_gate) {
1727
0
                cur = ggml_geglu_split(ctx0, cur, up);
1728
0
                cb(cur, "ffn_moe_geglu", il);
1729
0
            } else {
1730
0
                cur = ggml_gelu(ctx0, cur);
1731
0
                cb(cur, "ffn_moe_gelu", il);
1732
0
            } break;
1733
0
        case LLM_FFN_SWIGLU_OAI_MOE:
1734
0
            {
1735
                // TODO: move to hparams?
1736
0
                constexpr float alpha = 1.702f;
1737
0
                constexpr float limit = 7.0f;
1738
0
                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1739
0
                cb(cur, "ffn_moe_swiglu_oai", il);
1740
0
            } break;
1741
0
        case LLM_FFN_RELU:
1742
0
            if (has_gate) {
1743
0
                cur = ggml_reglu_split(ctx0, cur, up);
1744
0
                cb(cur, "ffn_moe_reglu", il);
1745
0
            } else {
1746
0
                cur = ggml_relu(ctx0, cur);
1747
0
                cb(cur, "ffn_moe_relu", il);
1748
0
            } break;
1749
0
        case LLM_FFN_RELU_SQR:
1750
0
            if (has_gate) {
1751
                // TODO: add support for gated squared relu
1752
0
                GGML_ABORT("fatal error: gated squared relu not implemented");
1753
0
            } else {
1754
0
                cur = ggml_relu(ctx0, cur);
1755
0
                cur = ggml_sqr(ctx0, cur);
1756
0
                cb(cur, "ffn_moe_relu_sqr", il);
1757
0
            } break;
1758
0
        default:
1759
0
            GGML_ABORT("fatal error");
1760
0
    }
1761
1762
0
    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
1763
0
    cb(experts, "ffn_moe_down", il);
1764
1765
0
    if (down_exps_b) {
1766
0
        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1767
0
        cb(experts, "ffn_moe_down_biased", il);
1768
0
    }
1769
1770
    // apply per-expert scale2 to down
1771
0
    if (down_exps_s) {
1772
0
        ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
1773
0
        s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1774
0
        s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1775
0
        experts = ggml_mul(ctx0, experts, s);
1776
0
        cb(experts, "ffn_moe_down_scaled", il);
1777
0
    }
1778
1779
0
    if (!weight_before_ffn) {
1780
0
        experts = ggml_mul(ctx0, experts, weights);
1781
0
        cb(experts, "ffn_moe_weighted", il);
1782
0
    }
1783
1784
0
    ggml_build_forward_expand(gf, experts);
1785
1786
0
    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1787
1788
0
    assert(n_expert_used > 0);
1789
1790
    // order the views before the adds
1791
0
    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1792
0
        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1793
1794
0
        ggml_build_forward_expand(gf, cur_experts[i]);
1795
0
    }
1796
1797
    // aggregate experts
1798
    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
1799
    //       to avoid potentially a large number of add nodes during warmup
1800
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
1801
0
    ggml_tensor * moe_out = cur_experts[0];
1802
1803
0
    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1804
0
        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
1805
1806
0
        ggml_build_forward_expand(gf, moe_out);
1807
0
    }
1808
1809
0
    if (hparams.n_expert_used == 1) {
1810
        // avoid returning a non-contiguous tensor
1811
0
        moe_out = ggml_cont(ctx0, moe_out);
1812
0
    }
1813
1814
0
    cb(moe_out, "ffn_moe_out", il);
1815
1816
0
    return moe_out;
1817
0
}
1818
1819
// input embeddings with optional lora
1820
0
// Build the input-embedding tensor for the current ubatch, with optional LoRA.
//
// Two alternative inputs are created, and ggml_build_forward_select() picks one
// based on the batch contents (ref: https://github.com/ggml-org/llama.cpp/pull/18550):
//   - inps[0]: rows of tok_embd gathered by the I32 token ids in inp->tokens, plus
//              the delta of every LoRA adapter that has a weight for tok_embd; padded
//              in dim 0 up to n_embd_inp when n_embd_inp != n_embd
//   - inps[1]: the F32 [n_embd_inp, ubatch.n_tokens] input tensor inp->embd
// Both alternatives must have identical shape and stride.
//
// When n_embd_inp != n_embd, the selected tensor is narrowed back to its first n_embd
// values per token with a 2D view. hparams.f_embedding_scale (when non-zero) is then
// applied for token input, and for raw-embedding input only when the model has no
// deepstack layers.
//
// Side effects: sets res->t_inp_tokens and res->t_inp_embd, registers the input on
// res, and expands the result into gf.
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
    const int64_t n_embd_inp = hparams.n_embd_inp();
    const int64_t n_embd     = hparams.n_embd;

    assert(n_embd_inp >= n_embd);

    auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);

    inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
    cb(inp->tokens, "inp_tokens", -1);
    ggml_set_input(inp->tokens);
    res->t_inp_tokens = inp->tokens;

    inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
    cb(inp->embd, "inp_embd", -1);
    ggml_set_input(inp->embd);

    // select one of the 2 inputs, based on the batch contents
    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
    std::array<ggml_tensor *, 2> inps;

    // token embeddings path (ubatch.token != nullptr)
    {
        ggml_tensor * cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);

        // apply lora for embedding tokens if needed
        for (const auto & lora : *loras) {
            llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
            if (lw == nullptr) {
                continue;
            }

            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            ggml_tensor * inpL_delta = ggml_scale(ctx0,
                    ggml_mul_mat(ctx0,
                        lw->b, // non-transposed lora_b
                        ggml_get_rows(ctx0, lw->a, inp->tokens)),
                    scale);

            cur = ggml_add(ctx0, cur, inpL_delta);
        }

        if (n_embd_inp != n_embd) {
            // pad dim 0 up to n_embd_inp so this path matches inp->embd's shape;
            // uses the cached n_embd_inp rather than re-querying hparams
            cur = ggml_pad(ctx0, cur, n_embd_inp - n_embd, 0, 0, 0);
        }

        inps[0] = cur;
    }

    // vector embeddings path (ubatch.embd != nullptr)
    inps[1] = inp->embd;

    assert(ggml_are_same_shape (inps[0], inps[1]));
    assert(ggml_are_same_stride(inps[0], inps[1]));

    ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);

    if (n_embd_inp != n_embd) {
        // keep only the first n_embd values of every token row
        cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
    }

    res->t_inp_embd = cur;

    // For Granite architecture
    // NOTE: For deepstack models, only apply scale to token inputs (ie text-only input).
    //  Raw embeddings are assumed to be multimodal inputs that should not be scaled.
    if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) {
        if (!ggml_is_contiguous(cur)) {
            cur = ggml_cont(ctx0, cur);
        }
        cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
    }

    cb(cur, "embd", -1);

    res->add_input(std::move(inp));

    // make sure the produced embeddings are immediately materialized in the ggml graph
    // ref: https://github.com/ggml-org/llama.cpp/pull/18599
    ggml_build_forward_expand(gf, cur);

    return cur;
}
1908
1909
0
// Create the I32 position input for the current ubatch.
// It is a 1D tensor holding hparams.n_pos_per_embd() positions for each of the
// n_tokens tokens. The input is registered on res and the tensor is returned.
ggml_tensor * llm_graph_context::build_inp_pos() const {
    const auto n_pos_per_tok = hparams.n_pos_per_embd();

    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_tok);

    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, static_cast<int64_t>(n_tokens)*n_pos_per_tok);
    ggml_set_input(positions);
    inp->pos = positions;

    res->add_input(std::move(inp));

    return positions;
}
1921
1922
0
// Create the per-token F32 attention temperature-scale input named "attn_scale".
// The input object is configured from the hparams attention-temperature settings.
// The tensor is shaped [1, 1, n_tokens] so it can be broadcast in later ops.
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
    auto inp = std::make_unique<llm_graph_input_attn_temp>(
            hparams.n_attn_temp_floor_scale,
            hparams.f_attn_temp_scale,
            hparams.f_attn_temp_offset);

    // 1x1xN layout is required for broadcasting
    ggml_tensor * scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
    ggml_set_input(scale);
    ggml_set_name(scale, "attn_scale");
    inp->attn_scale = scale;

    res->add_input(std::move(inp));

    return scale;
}
1936
1937
0
// Create the I32 input listing the n_outputs token indices whose outputs are kept.
//
// This input is always built, even when every token is an output. Skipping it in
// that case would save ggml_get_rows() calls, but it would make the graph topology
// depend on the number of outputs. That interferes with features that need a
// constant topology, such as pipeline parallelism.
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

    ggml_tensor * out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
    ggml_set_input(out_ids);
    inp->out_ids = out_ids;

    res->add_input(std::move(inp));

    return out_ids;
}
1957
1958
0
// Create the F32 mean-pooling input of shape [n_tokens, ubatch.n_seqs_unq].
// The input is registered on res and the tensor is returned.
ggml_tensor * llm_graph_context::build_inp_mean() const {
    auto inp = std::make_unique<llm_graph_input_mean>(cparams);

    ggml_tensor * mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
    ggml_set_input(mean);
    inp->mean = mean;

    res->add_input(std::move(inp));

    return mean;
}
1970
1971
0
// Create the I32 CLS-index input with one entry per unique sequence in the ubatch.
// The input object also receives the model architecture. It is registered on res
// and the tensor is returned.
ggml_tensor * llm_graph_context::build_inp_cls() const {
    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

    ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
    ggml_set_input(cls);
    inp->cls = cls;

    res->add_input(std::move(inp));

    return cls;
}
1983
1984
0
// Create the F32 cross-attention embedding input of shape [n_embd, n_enc].
//
// If cross->v_embd is non-empty, the shape comes from cross->n_embd / cross->n_enc.
// Presumably this means encoder output is available; TODO confirm. Otherwise the
// shape falls back to hparams.n_embd_inp() / hparams.n_ctx_train.
//
// TODO (carried over): using the encoder's cross->t_embd directly through a
// ggml_view_tensor() needs more work to be correct, so for now only the tensor
// shape is used.
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

    const bool have_cross_embd = !cross->v_embd.empty();

    const auto n_embd = have_cross_embd ? cross->n_embd : hparams.n_embd_inp();
    const auto n_enc  = have_cross_embd ? cross->n_enc  : hparams.n_ctx_train;

    ggml_tensor * embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
    ggml_set_input(embd);
    inp->cross_embd = embd;

    res->add_input(std::move(inp));

    return embd;
}
2007
2008
0
// Create the encoder-side I32 position-bucket input of shape [n_tokens, n_tokens].
// Without a KV cache, keys are the batch tokens themselves. The input is
// registered on res and the tensor is returned.
ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);

    ggml_tensor * buckets = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
    ggml_set_input(buckets);
    inp->pos_bucket = buckets;

    res->add_input(std::move(inp));

    return buckets;
}
2020
2021
0
// Create the decoder-side I32 position-bucket input of shape [n_kv, n_tokens].
// n_kv is the KV-cache length reported by the current memory context. That
// context is assumed to be a llama_kv_cache_context (it is static_cast here).
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
    const auto * kv_ctx = static_cast<const llama_kv_cache_context *>(mctx);

    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_ctx);

    const auto n_kv = kv_ctx->get_n_kv();

    ggml_tensor * buckets = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
    ggml_set_input(buckets);
    inp->pos_bucket = buckets;

    res->add_input(std::move(inp));

    return buckets;
}
2037
2038
0
// Turn a 2D tensor of bucket indices into a position-bias tensor.
//
// The bucket matrix is flattened, and each index selects a row of attn_rel_b.
// The gathered rows are reshaped back to the bucket layout, then permuted so the
// bias dimension moves to dim 2. The result is a contiguous tensor of shape
// [pos_bucket->ne[0], pos_bucket->ne[1], attn_rel_b->ne[0]].
ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
    const int64_t n_rows = pos_bucket->ne[0];
    const int64_t n_cols = pos_bucket->ne[1];

    ggml_tensor * flat_buckets = ggml_reshape_1d(ctx0, pos_bucket, n_rows * n_cols);
    cb(flat_buckets, "pos_bucket_1d", -1);

    ggml_tensor * bias = ggml_get_rows(ctx0, attn_rel_b, flat_buckets);

    bias = ggml_reshape_3d(ctx0, bias, bias->ne[0], n_rows, n_cols);
    bias = ggml_cont(ctx0, ggml_permute(ctx0, bias, 2, 0, 1, 3));

    cb(bias, "pos_bias", -1);

    return bias;
}
2052
2053
// Core multi-head attention shared by the build_attn() overloads.
//
// k->ne[3] is treated as the number of streams. q is viewed as that many streams
// along dim 3, then q, k and v are permuted (0, 2, 1, 3).
//
// Two computation paths:
//   - Flash attention (cparams.flash_attn and no kq_b): ggml_flash_attn_ext() with
//     F16 K/V (F32 K/V are cast), optional logit soft-capping, sinks and F32
//     precision. An optional v_mla projection is applied afterwards.
//   - Explicit attention: KQ = mul_mat(k, q) at F32 precision. Then the Grok-specific
//     tanh capping, generic logit soft-capping and the additive kq_b bias are
//     applied when enabled. This is followed by masked softmax (with sinks) and
//     mul_mat with V (made transposed-contiguous if needed). An optional v_mla
//     projection follows, and streams are recombined into 2D. The result is pinned
//     to the CPU backend when cparams.offload_kqv is false.
//
// V is considered "transposed" when v->nb[1] > v->nb[2].
// The 2D result is expanded into gf and returned.
ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * sinks,
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
    const bool v_is_transposed = v->nb[1] > v->nb[2];

    // split the batch into streams if needed
    const auto n_streams = k->ne[3];

    q = ggml_view_4d(ctx0, q,
            q->ne[0], q->ne[1], q->ne[2]/n_streams, n_streams,
            q->nb[1], q->nb[2], q->nb[3]/n_streams, 0);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
    v = ggml_permute(ctx0, v, 0, 2, 1, 3);

    ggml_tensor * out = nullptr;

    if (!cparams.flash_attn || kq_b != nullptr) {
        // ---- explicit attention ----
        ggml_tensor * scores = ggml_mul_mat(ctx0, k, q);
        cb(scores, "kq", il);

        // this op tends to need a high floating-point range; F16 is enough for some
        // models but not others, so default to F32
        ggml_mul_mat_set_prec(scores, GGML_PREC_F32);

        if (arch == LLM_ARCH_GROK) {
            // multiply by attn_output_multiplier, then kq = 30 * tanh(kq / 30),
            // before the softmax below
            scores = ggml_tanh(ctx0, ggml_scale(ctx0, scores, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
            cb(scores, "kq_tanh", il);
            scores = ggml_scale(ctx0, scores, hparams.f_attn_logit_softcapping);
            cb(scores, "kq_scaled", il);
        }

        if (hparams.attn_soft_cap) {
            scores = ggml_scale(ctx0, scores, 1.0f / hparams.f_attn_logit_softcapping);
            cb(scores, "kq_scaled_1", il);
            scores = ggml_tanh(ctx0, scores);
            cb(scores, "kq_tanh", il);
            scores = ggml_scale(ctx0, scores, hparams.f_attn_logit_softcapping);
            cb(scores, "kq_scaled_2", il);
        }

        if (kq_b) {
            scores = ggml_add(ctx0, scores, kq_b);
            cb(scores, "kq_plus_kq_b", il);
        }

        scores = ggml_soft_max_ext(ctx0, scores, kq_mask, kq_scale, hparams.f_max_alibi_bias);
        ggml_soft_max_add_sinks(scores, sinks);
        cb(scores, "kq_soft_max", il);

        if (!v_is_transposed) {
            // note: avoid this branch
            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
            cb(v, "v_cont", il);
        }

        ggml_tensor * weighted = ggml_mul_mat(ctx0, v, scores);
        cb(weighted, "kqv", il);

        // for MLA with the absorption optimization, "decompress" from MQA back to MHA
        if (v_mla) {
            weighted = ggml_mul_mat(ctx0, v_mla, weighted);
            cb(weighted, "kqv_mla", il);
        }

        out = ggml_permute(ctx0, weighted, 0, 2, 1, 3);

        // recombine streams
        out = ggml_cont_2d(ctx0, out, out->ne[0]*out->ne[1], out->ne[2]*out->ne[3]);

        if (!cparams.offload_kqv) {
            // all nodes between the KV store and the attention output are run on the CPU
            ggml_backend_sched_set_tensor_backend(sched, out, backend_cpu);
        }
    } else {
        // ---- flash attention ----
        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

        if (v_is_transposed) {
            v = ggml_transpose(ctx0, v);
        }

        // F32 K/V can happen when no KV cache is used (e.g. an embedding model with
        // non-causal attention)
        if (k->type == GGML_TYPE_F32) {
            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
        }

        if (v->type == GGML_TYPE_F32) {
            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
        }

        const float logit_softcap = hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f;

        out = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, logit_softcap);
        cb(out, LLAMA_TENSOR_NAME_FATTN, il);

        ggml_flash_attn_ext_add_sinks(out, sinks);
        ggml_flash_attn_ext_set_prec (out, GGML_PREC_F32);

        if (v_mla) {
            // v_mla could be applied as a matrix-vector product broadcast across
            // dim 3 (n_tokens). The kernels are optimized for large dims 0 and 1, so
            // that is inefficient. Instead, do a matrix-matrix product with n_tokens
            // in dim 1; the permutations only reinterpret the data.
            out = ggml_permute(ctx0, out, 0, 2, 1, 3);
            out = ggml_mul_mat(ctx0, v_mla, out);
            cb(out, "fattn_mla", il);
            out = ggml_permute(ctx0, out, 0, 2, 1, 3);
            out = ggml_cont(ctx0, out); // ggml_reshape_2d below expects a contiguous input
        }

        out = ggml_reshape_2d(ctx0, out, out->ne[0]*out->ne[1], out->ne[2]*out->ne[3]);
    }

    ggml_build_forward_expand(gf, out);

    return out;
}
2187
2188
0
// Create the attention input for graphs that run without a KV cache.
//
// With no cache, the number of KV entries equals the number of batch tokens, so each
// mask is [n_tokens, n_tokens, 1, 1]. Masks are F16 with flash attention and F32
// otherwise. A second, sliding-window mask is created only when the model uses SWA;
// otherwise the SWA pointers are null. The "_cnv" pointers alias the created masks.
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

    // flash attention requires an f16 mask
    const ggml_type mask_type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;

    const auto make_mask = [&]() {
        ggml_tensor * mask = ggml_new_tensor_4d(ctx0, mask_type, n_tokens, n_tokens, 1, 1);
        ggml_set_input(mask);
        return mask;
    };

    inp->self_kq_mask     = make_mask();
    inp->self_kq_mask_cnv = inp->self_kq_mask;

    const bool uses_swa = hparams.swa_type != LLAMA_SWA_TYPE_NONE;

    inp->self_kq_mask_swa     = uses_swa ? make_mask() : nullptr;
    inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;

    return static_cast<llm_graph_input_attn_no_cache *>(res->add_input(std::move(inp)));
}
2212
2213
// Attention over the current batch only (no KV cache).
//
// q_cur/k_cur/v_cur are expanded into gf back-to-back so they are not reordered,
// which reduces the number of graph splits. The SWA or regular mask is chosen per
// layer via hparams.is_swa(il). Attention is computed by build_attn_mha(), and the
// result optionally goes through the output projection wo (with scale wo_s) and
// the bias wo_b.
//
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, q/k/v could be split into ubatch.n_seqs_unq
//       streams, but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_no_cache * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * wo_s,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
            float     kq_scale,
            int       il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

    const bool is_swa = hparams.is_swa(il);

    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

    ggml_tensor * cur = build_attn_mha(q_cur, k_cur, v_cur, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur, wo_s);
    }

    if (wo_b) {
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}
2264
2265
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
2266
           ggml_context * ctx0,
2267
     const llama_ubatch & ubatch,
2268
    const llama_hparams & hparams,
2269
    const llama_cparams & cparams,
2270
0
    const llama_kv_cache_context * mctx_cur) {
2271
2272
0
    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
2273
2274
0
    {
2275
0
        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2276
2277
0
        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2278
0
        inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
2279
2280
0
        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2281
0
        inp->self_kq_mask_cnv = inp->self_kq_mask;
2282
0
    }
2283
2284
0
    inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
2285
0
    inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
2286
2287
0
    return inp;
2288
0
}
2289
2290
0
// Create and register the KV-cache attention input for this graph.
// The current memory context is assumed to be a llama_kv_cache_context.
// Returns a non-owning pointer; res keeps ownership.
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
    const auto * kv_ctx = static_cast<const llama_kv_cache_context *>(mctx);

    auto input = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, kv_ctx);

    return static_cast<llm_graph_input_attn_kv *>(res->add_input(std::move(input)));
}
2297
2298
// Attention backed by a regular (non-SWA) KV cache.
//
// Steps:
//   1. If the input provides rotation tensors, rotate q/k with self_k_rot and v
//      with self_v_rot.
//   2. Expand q, v, then k into gf. k goes last so rope fusion can write directly
//      into the KV cache.
//   3. Store k_cur/v_cur into the cache at the input's K/V indices.
//   4. Run build_attn_mha() over the full cached K/V with the KQ mask.
//   5. Undo the V rotation on the output, then apply the optional wo (+ wo_s)
//      projection and the bias wo_b.
// For GLM4, GLM4_MOE and JAIS2, the wo product is forced to F32 precision, and
// wo_s is applied as a separate multiply.
// v_mla is not supported here and must be null.
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_kv * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * wo_s,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla, // TODO: remove
            float     kq_scale,
            int       il) const {
    GGML_ASSERT(v_mla == nullptr);

    if (inp->self_k_rot) {
        q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
        k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
    }

    if (inp->self_v_rot) {
        v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
    }

    // expanding these together keeps them from being reordered (fewer graph splits);
    // k is expanded last to enable rope fusion that writes straight into the cache
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, v_cur);
    ggml_build_forward_expand(gf, k_cur);

    const auto * kv_ctx = inp->mctx;

    // store to KV cache
    {
        const auto & dst_k = inp->get_k_idxs();
        const auto & dst_v = inp->get_v_idxs();

        ggml_build_forward_expand(gf, kv_ctx->cpy_k(ctx0, k_cur, dst_k, il));
        ggml_build_forward_expand(gf, kv_ctx->cpy_v(ctx0, v_cur, dst_v, il));
    }

    const auto & kq_mask = inp->get_kq_mask();

    ggml_tensor * k_all = kv_ctx->get_k(ctx0, il);
    ggml_tensor * v_all = kv_ctx->get_v(ctx0, il);

    ggml_tensor * out = build_attn_mha(q_cur, k_all, v_all, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(out, "kqv_out", il);

    if (inp->self_v_rot) {
        out = ggml_mul_mat_aux(ctx0, out, inp->self_v_rot);
    }

    if (wo) {
        // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
        const bool force_f32_acc = arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2;

        if (!force_f32_acc) {
            out = build_lora_mm(wo, out, wo_s);
        } else {
            out = build_lora_mm(wo, out);
            ggml_mul_mat_set_prec(out, GGML_PREC_F32);
            if (wo_s) {
                out = ggml_mul(ctx0, out, wo_s);
            }
        }
    }

    if (wo_b) {
        out = ggml_add(ctx0, out, wo_b);
    }

    return out;
}
2372
2373
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
2374
           ggml_context * ctx0,
2375
     const llama_ubatch & ubatch,
2376
    const llama_hparams & hparams,
2377
    const llama_cparams & cparams,
2378
0
    const llama_kv_cache_context * mctx_cur) {
2379
2380
0
    auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
2381
2382
0
    {
2383
0
        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2384
2385
0
        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2386
2387
0
        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2388
0
        inp->self_kq_mask_cnv = inp->self_kq_mask;
2389
0
    }
2390
2391
0
    return inp;
2392
0
}
2393
2394
0
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2395
0
    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2396
2397
0
    auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2398
2399
0
    return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2400
0
}
2401
2402
ggml_tensor * llm_graph_context::build_attn(
2403
        llm_graph_input_attn_k * inp,
2404
        ggml_tensor * wo,
2405
        ggml_tensor * wo_b,
2406
        ggml_tensor * wo_s,
2407
        ggml_tensor * q_cur,
2408
        ggml_tensor * k_cur,
2409
        ggml_tensor * v_cur,
2410
        ggml_tensor * kq_b,
2411
        ggml_tensor * sinks,
2412
        ggml_tensor * v_mla,
2413
            float     kq_scale,
2414
0
            int       il) const {
2415
    // these nodes are added to the graph together so that they are not reordered
2416
    // by doing so, the number of splits in the graph is reduced
2417
    // expand k later to enable rope fusion which directly writes into k-v cache
2418
0
    ggml_build_forward_expand(gf, q_cur);
2419
0
    ggml_build_forward_expand(gf, v_cur);
2420
0
    ggml_build_forward_expand(gf, k_cur);
2421
2422
0
    const auto * mctx_cur = inp->mctx;
2423
2424
    // store to KV cache
2425
0
    {
2426
0
        const auto & k_idxs = inp->get_k_idxs();
2427
2428
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2429
0
    }
2430
2431
0
    const auto & kq_mask = inp->get_kq_mask();
2432
2433
0
    ggml_tensor * q = q_cur;
2434
0
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2435
0
    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2436
2437
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2438
0
    cb(cur, "kqv_out", il);
2439
2440
0
    if (wo) {
2441
0
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
2442
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2443
0
            cur = build_lora_mm(wo, cur);
2444
0
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2445
0
            if (wo_s) {
2446
0
                cur = ggml_mul(ctx0, cur, wo_s);
2447
0
            }
2448
0
        } else {
2449
0
            cur = build_lora_mm(wo, cur, wo_s);
2450
0
        }
2451
0
    }
2452
2453
0
    if (wo_b) {
2454
0
        cur = ggml_add(ctx0, cur, wo_b);
2455
0
    }
2456
2457
0
    return cur;
2458
0
}
2459
2460
ggml_tensor * llm_graph_context::build_attn(
2461
        llm_graph_input_attn_k_dsa * inp,
2462
        ggml_tensor * wo,
2463
        ggml_tensor * wo_b,
2464
        ggml_tensor * wo_s,
2465
        ggml_tensor * q_cur,
2466
        ggml_tensor * k_cur,
2467
        ggml_tensor * v_cur,
2468
        ggml_tensor * kq_b,
2469
        ggml_tensor * sinks,
2470
        ggml_tensor * v_mla,
2471
        ggml_tensor * top_k,
2472
            float     kq_scale,
2473
0
            int       il) const {
2474
    // these nodes are added to the graph together so that they are not reordered
2475
    // by doing so, the number of splits in the graph is reduced
2476
    // expand k later to enable rope fusion which directly writes into k-v cache
2477
0
    ggml_build_forward_expand(gf, q_cur);
2478
0
    ggml_build_forward_expand(gf, v_cur);
2479
0
    ggml_build_forward_expand(gf, k_cur);
2480
2481
0
    const auto * mctx_cur = inp->mctx->get_mla();
2482
2483
    // store to KV cache
2484
0
    {
2485
0
        const auto & k_idxs = inp->get_k_idxs_mla();
2486
2487
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2488
0
    }
2489
2490
0
    const auto & kq_mask = inp->get_kq_mask_mla();
2491
2492
    // prepare new kq mask - starts filled with -INFINITY
2493
0
    ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY);
2494
2495
    // reshape KQ mask into tensor with rows of size 1:
2496
    // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream]
2497
0
    kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0);
2498
2499
    // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1]
2500
0
    ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0);
2501
2502
    // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream]
2503
    // this will be our source of zero values for unmasking top k mask elements
2504
0
    ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]);
2505
0
    zeros = ggml_fill(ctx0, zeros, 0.0f);
2506
2507
    // modify KQ mask by unmasking elements that are in top_k indices
2508
    // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1])
2509
0
    ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d);
2510
2511
    // reshape to restore the original shape of KQ mask:
2512
    // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream]
2513
0
    kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0);
2514
2515
    // combine with the original kq mask
2516
0
    kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask);
2517
2518
0
    ggml_tensor * q = q_cur;
2519
0
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2520
0
    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2521
2522
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il);
2523
0
    cb(cur, "kqv_out", il);
2524
2525
0
    if (wo) {
2526
0
        cur = build_lora_mm(wo, cur, wo_s);
2527
0
    }
2528
2529
0
    if (wo_b) {
2530
0
        cur = ggml_add(ctx0, cur, wo_b);
2531
0
    }
2532
2533
0
    return cur;
2534
0
}
2535
2536
ggml_tensor * llm_graph_context::build_attn(
2537
        llm_graph_input_attn_kv_iswa * inp,
2538
        ggml_tensor * wo,
2539
        ggml_tensor * wo_b,
2540
        ggml_tensor * wo_s,
2541
        ggml_tensor * q_cur,
2542
        ggml_tensor * k_cur,
2543
        ggml_tensor * v_cur,
2544
        ggml_tensor * kq_b,
2545
        ggml_tensor * sinks,
2546
        ggml_tensor * v_mla,
2547
            float     kq_scale,
2548
0
            int       il) const {
2549
0
    const bool is_swa = hparams.is_swa(il);
2550
2551
0
    auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
2552
0
    auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
2553
2554
0
    if (k_rot) {
2555
0
        q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
2556
0
        if (k_cur) {
2557
0
            k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
2558
0
        }
2559
0
    }
2560
0
    if (v_rot) {
2561
0
        if (v_cur) {
2562
0
            v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
2563
0
        }
2564
0
    }
2565
2566
    // these nodes are added to the graph together so that they are not reordered
2567
    // by doing so, the number of splits in the graph is reduced
2568
0
    ggml_build_forward_expand(gf, q_cur);
2569
2570
0
    if (k_cur) {
2571
0
        ggml_build_forward_expand(gf, k_cur);
2572
0
    }
2573
2574
0
    if (v_cur) {
2575
0
        ggml_build_forward_expand(gf, v_cur);
2576
0
    }
2577
2578
0
    const auto * mctx_iswa = inp->mctx;
2579
2580
0
    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
2581
2582
    // optionally store to KV cache
2583
0
    if (k_cur) {
2584
0
        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
2585
2586
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2587
0
    }
2588
2589
0
    if (v_cur) {
2590
0
        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
2591
2592
0
        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2593
0
    }
2594
2595
0
    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
2596
2597
0
    ggml_tensor * q = q_cur;
2598
0
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2599
0
    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2600
2601
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2602
0
    cb(cur, "kqv_out", il);
2603
2604
0
    if (v_rot) {
2605
0
        cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
2606
0
    }
2607
2608
0
    if (wo) {
2609
0
        cur = build_lora_mm(wo, cur, wo_s);
2610
0
    }
2611
2612
0
    if (wo_b) {
2613
        //cb(cur, "kqv_wo", il);
2614
0
    }
2615
2616
0
    if (wo_b) {
2617
0
        cur = ggml_add(ctx0, cur, wo_b);
2618
0
    }
2619
2620
0
    return cur;
2621
0
}
2622
2623
0
llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
2624
0
    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
2625
2626
0
    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
2627
2628
    // flash attention requires an f16 mask
2629
0
    const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
2630
2631
0
    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1);
2632
0
    ggml_set_input(inp->cross_kq_mask);
2633
2634
0
    inp->cross_kq_mask_cnv = inp->cross_kq_mask;
2635
2636
0
    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
2637
0
}
2638
2639
ggml_tensor * llm_graph_context::build_attn(
2640
        llm_graph_input_attn_cross * inp,
2641
        ggml_tensor * wo,
2642
        ggml_tensor * wo_b,
2643
        ggml_tensor * wo_s,
2644
        ggml_tensor * q_cur,
2645
        ggml_tensor * k_cur,
2646
        ggml_tensor * v_cur,
2647
        ggml_tensor * kq_b,
2648
        ggml_tensor * sinks,
2649
        ggml_tensor * v_mla,
2650
            float     kq_scale,
2651
0
            int       il) const {
2652
    // these nodes are added to the graph together so that they are not reordered
2653
    // by doing so, the number of splits in the graph is reduced
2654
0
    ggml_build_forward_expand(gf, q_cur);
2655
0
    ggml_build_forward_expand(gf, k_cur);
2656
0
    ggml_build_forward_expand(gf, v_cur);
2657
2658
0
    const auto & kq_mask = inp->get_kq_mask_cross();
2659
2660
0
    ggml_tensor * q = q_cur;
2661
0
    ggml_tensor * k = k_cur;
2662
0
    ggml_tensor * v = v_cur;
2663
2664
0
    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2665
0
    cb(cur, "kqv_out", il);
2666
2667
0
    if (wo) {
2668
0
        cur = build_lora_mm(wo, cur, wo_s);
2669
0
    }
2670
2671
0
    if (wo_b) {
2672
        //cb(cur, "kqv_wo", il);
2673
0
    }
2674
2675
0
    if (wo_b) {
2676
0
        cur = ggml_add(ctx0, cur, wo_b);
2677
0
    }
2678
2679
0
    return cur;
2680
0
}
2681
2682
0
llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const {
2683
0
    const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx);
2684
2685
0
    auto inp = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, mctx_cur);
2686
2687
0
    {
2688
0
        inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch);
2689
2690
0
        inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams);
2691
0
        inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla;
2692
0
    }
2693
2694
0
    {
2695
0
        inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch);
2696
2697
        // ensure F32 mask
2698
0
        auto cparams_copy = cparams;
2699
0
        cparams_copy.flash_attn = false;
2700
2701
0
        inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy);
2702
0
        inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid;
2703
2704
0
        inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0);
2705
0
    }
2706
2707
0
    return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp));
2708
0
}
2709
2710
// TODO: maybe separate the inner implementation into a separate function
2711
//       like with the non-sliding window equivalent
2712
//       once sliding-window hybrid caches are a thing.
2713
0
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
2714
0
    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
2715
2716
0
    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
2717
2718
0
    {
2719
0
        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
2720
0
        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
2721
2722
0
        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
2723
0
        inp->self_kq_mask_cnv = inp->self_kq_mask;
2724
0
    }
2725
2726
0
    {
2727
0
        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
2728
2729
0
        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
2730
0
        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
2731
2732
0
        inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
2733
0
        inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
2734
0
    }
2735
2736
0
    inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
2737
0
    inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
2738
2739
0
    inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
2740
0
    inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
2741
2742
0
    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
2743
0
}
2744
2745
ggml_tensor * llm_graph_context::build_rs(
2746
        ggml_tensor * s,
2747
        ggml_tensor * state_copy_main,
2748
        ggml_tensor * state_copy_extra,
2749
            int32_t   state_size,
2750
            int32_t   n_seqs,
2751
           uint32_t   n_rs,
2752
           uint32_t   rs_head,
2753
           uint32_t   rs_size,
2754
            int32_t   rs_zero,
2755
0
        const llm_graph_get_rows_fn & get_state_rows) const {
2756
2757
0
    GGML_UNUSED(rs_size);
2758
0
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]);
2759
2760
    // Clear a single state which will then be copied to the other cleared states.
2761
    // Note that this is a no-op when the view is zero-sized.
2762
0
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
2763
0
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
2764
2765
    // copy states
2766
    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
2767
    // {state_size, rs_size} -> {state_size, n_seqs}
2768
0
    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
2769
0
    ggml_build_forward_expand(gf, output_states);
2770
2771
    // copy extra states which won't be changed further (between n_seqs and n_rs)
2772
0
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
2773
0
    ggml_build_forward_expand(gf,
2774
0
        ggml_cpy(ctx0,
2775
0
            states_extra,
2776
0
            ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1])));
2777
2778
0
    return output_states;
2779
0
}
2780
2781
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
2782
           ggml_context * ctx0,
2783
     const llama_ubatch & ubatch,
2784
0
    const llama_memory_recurrent_context * mctx_cur) {
2785
2786
0
    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
2787
2788
0
    const int64_t n_rs   = mctx_cur->get_n_rs();
2789
0
    const int64_t n_seqs = ubatch.n_seqs;
2790
2791
0
    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
2792
0
    ggml_set_input(inp->s_copy);
2793
2794
0
    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
2795
0
    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
2796
2797
0
    inp->head = mctx_cur->get_head();
2798
0
    inp->rs_z = mctx_cur->get_rs_z();
2799
2800
0
    return inp;
2801
0
}
2802
2803
0
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
2804
0
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2805
2806
0
    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2807
2808
0
    return (llm_graph_input_rs *) res->add_input(std::move(inp));
2809
0
}
2810
2811
ggml_tensor * llm_graph_context::build_rs(
2812
        llm_graph_input_rs * inp,
2813
        ggml_tensor * s,
2814
            int32_t   state_size,
2815
            int32_t   n_seqs,
2816
0
        const llm_graph_get_rows_fn & get_state_rows) const {
2817
0
    const auto * kv_state = inp->mctx;
2818
2819
0
    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2820
0
                    kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2821
0
                    get_state_rows);
2822
0
}
2823
2824
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
2825
    llm_graph_input_rs * inp,
2826
    const llama_ubatch & ubatch,
2827
0
                   int   il) const {
2828
0
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2829
2830
0
    const auto token_shift_count = hparams.token_shift_count;
2831
2832
0
    const int64_t n_seqs  = ubatch.n_seqs;
2833
2834
0
    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
2835
2836
0
    ggml_tensor * token_shift = build_rs(
2837
0
            inp, token_shift_all,
2838
0
            hparams.n_embd_r(), n_seqs);
2839
2840
0
    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
2841
2842
0
    return token_shift;
2843
0
}
2844
2845
ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
2846
         ggml_tensor * token_shift,
2847
  const llama_ubatch & ubatch,
2848
0
                 int   il) const {
2849
0
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2850
2851
0
    const auto token_shift_count = hparams.token_shift_count;
2852
0
    const auto n_embd = hparams.n_embd;
2853
2854
0
    const int64_t n_seqs = ubatch.n_seqs;
2855
2856
0
    const auto kv_head = mctx_cur->get_head();
2857
2858
0
    return ggml_cpy(
2859
0
        ctx0,
2860
0
        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2861
0
        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
2862
0
    );
2863
0
}
2864
2865
0
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2866
0
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2867
2868
0
    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
2869
0
    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2870
2871
0
    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2872
2873
0
    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2874
0
}
2875
2876
0
llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2877
0
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2878
2879
0
    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
2880
0
    auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2881
2882
0
    auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2883
2884
0
    return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2885
0
}
2886
2887
0
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2888
0
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2889
2890
0
    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2891
2892
    // build iswa attention input
2893
0
    const auto * attn_ctx = mctx_cur->get_attn();
2894
2895
0
    auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2896
2897
0
    {
2898
0
        inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2899
0
        inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2900
2901
0
        inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
2902
0
        inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask;
2903
0
    }
2904
2905
0
    {
2906
0
        inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2907
0
        inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2908
2909
0
        inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
2910
0
        inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa;
2911
0
    }
2912
2913
0
    auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2914
2915
0
    return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2916
0
}
2917
2918
void llm_graph_context::build_dense_out(
2919
    ggml_tensor * dense_2,
2920
    ggml_tensor * dense_2_b,
2921
0
    ggml_tensor * dense_3) const {
2922
0
    if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
2923
0
        return;
2924
0
    }
2925
0
    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2926
0
    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2927
2928
0
    if (dense_2) {
2929
0
        cur = ggml_mul_mat(ctx0, dense_2, cur);
2930
0
    }
2931
0
    if (dense_2_b) {
2932
0
        cur = ggml_add(ctx0, cur, dense_2_b);
2933
0
    }
2934
0
    if (dense_3) {
2935
0
        cur = ggml_mul_mat(ctx0, dense_3, cur);
2936
0
    }
2937
0
    cb(cur, "result_embd_pooled", -1);
2938
0
    res->t_embd_pooled = cur;
2939
0
    ggml_build_forward_expand(gf, cur);
2940
0
}
2941
2942
2943
void llm_graph_context::build_pooling(
2944
        ggml_tensor * cls,
2945
        ggml_tensor * cls_b,
2946
        ggml_tensor * cls_out,
2947
        ggml_tensor * cls_out_b,
2948
0
        ggml_tensor * cls_norm) const {
2949
0
    if (!cparams.embeddings) {
2950
0
        return;
2951
0
    }
2952
2953
0
    ggml_tensor * inp = res->t_embd;
2954
2955
    //// find result_norm tensor for input
2956
    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
2957
    //    inp = ggml_graph_node(gf, i);
2958
    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
2959
    //        break;
2960
    //    }
2961
2962
    //    inp = nullptr;
2963
    //}
2964
2965
0
    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
2966
2967
0
    ggml_tensor * cur;
2968
2969
0
    switch (pooling_type) {
2970
0
        case LLAMA_POOLING_TYPE_NONE:
2971
0
            {
2972
0
                cur = inp;
2973
0
            } break;
2974
0
        case LLAMA_POOLING_TYPE_MEAN:
2975
0
            {
2976
0
                ggml_tensor * inp_mean = build_inp_mean();
2977
0
                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2978
0
            } break;
2979
0
        case LLAMA_POOLING_TYPE_CLS:
2980
0
        case LLAMA_POOLING_TYPE_LAST:
2981
0
            {
2982
0
                ggml_tensor * inp_cls = build_inp_cls();
2983
0
                cur = ggml_get_rows(ctx0, inp, inp_cls);
2984
0
            } break;
2985
0
        case LLAMA_POOLING_TYPE_RANK:
2986
0
            {
2987
0
                if (arch == LLM_ARCH_MODERN_BERT) {
2988
                    // modern bert gte reranker builds mean first then applies prediction head and classifier
2989
                    // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
2990
0
                    ggml_tensor * inp_mean = build_inp_mean();
2991
0
                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2992
0
                } else {
2993
0
                    ggml_tensor * inp_cls = build_inp_cls();
2994
0
                    cur = ggml_get_rows(ctx0, inp, inp_cls);
2995
0
                }
2996
2997
                // classification head
2998
                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2999
0
                if (cls) {
3000
0
                    cur = ggml_mul_mat(ctx0, cls, cur);
3001
0
                    if (cls_b) {
3002
0
                        cur = ggml_add(ctx0, cur, cls_b);
3003
0
                    }
3004
0
                    if (arch == LLM_ARCH_MODERN_BERT) {
3005
0
                        cur = ggml_gelu(ctx0, cur);
3006
0
                    } else {
3007
0
                        cur = ggml_tanh(ctx0, cur);
3008
0
                    }
3009
0
                    if (cls_norm) {
3010
                        // head norm
3011
0
                        cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
3012
0
                    }
3013
0
                }
3014
3015
                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
3016
                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
3017
                // Single layer classification head (direct projection)
3018
                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
3019
0
                if (cls_out) {
3020
0
                    cur = ggml_mul_mat(ctx0, cls_out, cur);
3021
0
                    if (cls_out_b) {
3022
0
                        cur = ggml_add(ctx0, cur, cls_out_b);
3023
0
                    }
3024
0
                }
3025
3026
                // softmax for qwen3 reranker
3027
0
                if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
3028
0
                    cur = ggml_soft_max(ctx0, cur);
3029
0
                }
3030
0
            } break;
3031
0
        default:
3032
0
            {
3033
0
                GGML_ABORT("unknown pooling type");
3034
0
            }
3035
0
    }
3036
3037
0
    cb(cur, "result_embd_pooled", -1);
3038
0
    res->t_embd_pooled = cur;
3039
3040
0
    ggml_build_forward_expand(gf, cur);
3041
0
}
3042
3043
0
void llm_graph_context::build_sampling() const {
3044
0
    if (samplers.empty() || !res->t_logits) {
3045
0
        return;
3046
0
    }
3047
3048
0
    std::array<ggml_tensor *, 2> outs;
3049
0
    outs[0] = res->t_logits;
3050
3051
0
    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
3052
0
    res->add_input(std::move(inp_sampling));
3053
3054
0
    std::map<llama_seq_id, int32_t> seq_to_logit_row;
3055
0
    int32_t logit_row_idx = 0;
3056
3057
0
    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
3058
0
        if (ubatch.output[i]) {
3059
0
            llama_seq_id seq_id = ubatch.seq_id[i][0];
3060
0
            seq_to_logit_row[seq_id] = logit_row_idx;
3061
0
            logit_row_idx++;
3062
0
        }
3063
0
    }
3064
3065
    // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
3066
0
    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
3067
3068
    // add a dummy row of logits
3069
    // this trick makes the graph static, regardless of which samplers are activated
3070
    // this is important in order to minimize graph reallocations
3071
0
    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
3072
3073
0
    for (const auto & [seq_id, sampler] : samplers) {
3074
0
        const auto it = seq_to_logit_row.find(seq_id);
3075
3076
        // inactive samplers always work on the first row
3077
0
        const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
3078
0
        const int i_out    = it != seq_to_logit_row.end() ? 1          : 0;
3079
3080
0
        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
3081
0
        ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
3082
3083
0
        struct llama_sampler_data data = {
3084
0
            /*.logits      =*/ logits_seq,
3085
0
            /*.probs       =*/ nullptr,
3086
0
            /*.sampled     =*/ nullptr,
3087
0
            /*.candidates  =*/ nullptr,
3088
0
        };
3089
3090
0
        assert(sampler->iface->backend_apply);
3091
0
        sampler->iface->backend_apply(sampler, ctx0, gf, &data);
3092
3093
0
        if (data.sampled != nullptr) {
3094
0
            res->t_sampled[seq_id] = data.sampled;
3095
0
            outs[1] = data.sampled;
3096
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3097
0
        }
3098
3099
0
        if (data.probs != nullptr) {
3100
0
            res->t_sampled_probs[seq_id] = data.probs;
3101
0
            outs[1] = data.probs;
3102
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3103
0
        }
3104
3105
0
        if (data.logits != nullptr) {
3106
0
            res->t_sampled_logits[seq_id] = data.logits;
3107
0
            outs[1] = data.logits;
3108
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3109
0
        }
3110
3111
0
        if (data.candidates != nullptr) {
3112
0
            res->t_candidates[seq_id] = data.candidates;
3113
0
            outs[1] = data.candidates;
3114
0
            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
3115
0
        }
3116
0
    }
3117
3118
    // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
3119
    /*
3120
    for (const auto & [seq_id, sampler] : samplers) {
3121
        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
3122
            ggml_tensor * selected_token = it->second;
3123
            if (selected_token != nullptr) {
3124
                llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
3125
            }
3126
        }
3127
    }
3128
    */
3129
0
}
3130
3131
0
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
3132
    // TODO move to hparams if a T5 variant appears that uses a different value
3133
0
    const int64_t max_distance = 128;
3134
3135
0
    if (bidirectional) {
3136
0
        n_buckets >>= 1;
3137
0
    }
3138
3139
0
    const int64_t max_exact = n_buckets >> 1;
3140
3141
0
    int32_t relative_position = x - y;
3142
0
    int32_t relative_bucket = 0;
3143
3144
0
    if (bidirectional) {
3145
0
        relative_bucket += (relative_position > 0) * n_buckets;
3146
0
        relative_position = std::abs(relative_position);
3147
0
    } else {
3148
0
        relative_position = -std::min<int32_t>(relative_position, 0);
3149
0
    }
3150
3151
0
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
3152
0
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
3153
0
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
3154
3155
0
    return relative_bucket;
3156
0
}