Coverage Report

Created: 2025-12-28 06:25

/src/llama.cpp/src/llama-graph.h
Line | Count | Source
   1 |       | #pragma once
   2 |       |
   3 |       | #include "llama-arch.h"
   4 |       | #include "llama-batch.h"
   5 |       | #include "llama-hparams.h"
   6 |       | #include "llama-adapter.h"
   7 |       |
   8 |       | #include <cstdint>
   9 |       | #include <vector>
  10 |       | #include <memory>
  11 |       | #include <set>
  12 |       | #include <functional>
  13 |       |
  14 |       | struct ggml_cgraph;
  15 |       | struct ggml_context;
  16 |       | struct ggml_tensor;
  17 |       |
  18 |       | struct llama_cparams;
  19 |       |
  20 |       | struct llama_memory_context_i;
  21 |       |
  22 |       | class llama_kv_cache_context;
  23 |       | class llama_kv_cache_iswa_context;
  24 |       | class llama_memory_recurrent_context;
  25 |       | class llama_memory_hybrid_context;
  26 |       |
  27 |       | // certain models (typically multi-modal) can produce different types of graphs
  28 |       | enum llm_graph_type {
  29 |       |     LLM_GRAPH_TYPE_DEFAULT,
  30 |       |     LLM_GRAPH_TYPE_ENCODER,
  31 |       |     LLM_GRAPH_TYPE_DECODER,
  32 |       | };
  33 |       |
  34 |       | enum llm_ffn_op_type {
  35 |       |     LLM_FFN_SILU,
  36 |       |     LLM_FFN_GELU,
  37 |       |     LLM_FFN_RELU,
  38 |       |     LLM_FFN_RELU_SQR,
  39 |       |     LLM_FFN_SWIGLU,
  40 |       |     LLM_FFN_GEGLU,
  41 |       |     LLM_FFN_REGLU,
  42 |       |     LLM_FFN_SWIGLU_OAI_MOE,
  43 |       | };
  44 |       |
  45 |       | enum llm_ffn_gate_type {
  46 |       |     LLM_FFN_SEQ,
  47 |       |     LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
  48 |       | };
  49 |       |
  50 |       | enum llm_norm_type {
  51 |       |     LLM_NORM,
  52 |       |     LLM_NORM_RMS,
  53 |       |     LLM_NORM_GROUP,
  54 |       | };
  55 |       |
  56 |       | // TODO: tmp - need something better to pass the data from the encoder to the decoder
  57 |       | struct llama_cross {
  58 |       |     // the output embeddings from the encoder as a ggml tensor
  59 |       |     // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  60 |       |     //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  61 |       |     //ggml_tensor * t_embd = nullptr;
  62 |       |
  63 |       |     int64_t n_embd = 0;
  64 |       |     int64_t n_enc  = 0;
  65 |       |
  66 |       |     // embeddings data copied to host memory (tmp)
  67 |       |     std::vector<float> v_embd;
  68 |       |
  69 |       |     // needed to construct the cross-attention mask in the decoder
  70 |       |     std::vector<std::set<llama_seq_id>> seq_ids_enc;
  71 |       | };
  72 |       |
  73 |       | struct llm_graph_params;
  74 |       |
  75 |       | //
  76 |       | // llm_graph_input
  77 |       | //
  78 |       |
  79 |       | class llm_graph_input_i {
  80 |       | public:
  81 |     0 |     llm_graph_input_i() {
  82 |     0 |         const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
  83 |     0 |         debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
  84 |     0 |     }
  85 |       |
  86 |     0 |     virtual ~llm_graph_input_i() = default;
  87 |       |
  88 |       |     virtual void set_input(const llama_ubatch * ubatch) = 0;
  89 |       |
  90 |       |     // return true if the resulting input tensors using the provided graph parameters would be
  91 |       |     //   the same as the previous input tensors that we have currently stored in the object
  92 |     0 |     virtual bool can_reuse(const llm_graph_params & params) {
  93 |       |         // returning false here by default will prevent reusing the graph if the check
  94 |       |         //   for the input type has not been implemented yet
  95 |     0 |         GGML_UNUSED(params);
  96 |     0 |         return false;
  97 |     0 |     }
  98 |       | protected:
  99 |       |     // env: LLAMA_GRAPH_INPUT_DEBUG
 100 |       |     int debug = 0;
 101 |       | };
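For orientation, here is a minimal sketch of what a concrete graph input could look like; it is not part of llama-graph.h, the class name and the per-token payload are invented, and it only exercises the interface declared above: set_input() fills the pre-allocated tensor from the ubatch, and can_reuse() checks whether the stored tensor still matches the new parameters. Such an object would be registered with llm_graph_result::add_input() (declared further down).

// illustrative sketch - hypothetical class, not present in the header above
class llm_graph_input_example : public llm_graph_input_i {
public:
    ggml_tensor * inp = nullptr; // I32 [n_batch], allocated during graph build

    void set_input(const llama_ubatch * ubatch) override {
        // copy one int32 per token from the ubatch into the graph input tensor
        GGML_ASSERT(inp != nullptr);
        std::vector<int32_t> buf(ubatch->n_tokens);
        for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
            buf[i] = (int32_t) ubatch->pos[i];
        }
        ggml_backend_tensor_set(inp, buf.data(), 0, buf.size()*sizeof(int32_t));
    }

    bool can_reuse(const llm_graph_params & params) override {
        // the tensor shape only depends on the number of tokens in the ubatch
        return inp != nullptr && (int64_t) params.ubatch.n_tokens == inp->ne[0];
    }
};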
 102 |       |
 103 |       | using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
 104 |       |
 105 |       | class llm_graph_input_embd : public llm_graph_input_i {
 106 |       | public:
 107 |     0 |     llm_graph_input_embd()          = default;
 108 |       |     virtual ~llm_graph_input_embd() = default;
 109 |       |
 110 |       |     void set_input(const llama_ubatch * ubatch) override;
 111 |       |
 112 |       |     bool can_reuse(const llm_graph_params & params) override;
 113 |       |
 114 |       |     ggml_tensor * tokens = nullptr; // I32 [n_batch]
 115 |       |     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
 116 |       | };
 117 |       |
 118 |       | class llm_graph_input_pos : public llm_graph_input_i {
 119 |       | public:
 120 |     0 |     llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
 121 |       |     virtual ~llm_graph_input_pos() = default;
 122 |       |
 123 |       |     void set_input(const llama_ubatch * ubatch) override;
 124 |       |
 125 |       |     bool can_reuse(const llm_graph_params & params) override;
 126 |       |
 127 |       |     ggml_tensor * pos = nullptr; // I32 [n_batch]
 128 |       |
 129 |       |     const uint32_t n_pos_per_embd = 1;
 130 |       | };
 131 |       |
 132 |       | // temperature tuning, used by llama4
 133 |       | class llm_graph_input_attn_temp : public llm_graph_input_i {
 134 |       | public:
 135 |       |     llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
 136 |     0 |         : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
 137 |       |     virtual ~llm_graph_input_attn_temp() = default;
 138 |       |
 139 |       |     void set_input(const llama_ubatch * ubatch) override;
 140 |       |
 141 |       |     ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
 142 |       |
 143 |       |     const uint32_t n_attn_temp_floor_scale;
 144 |       |     const float    f_attn_temp_scale;
 145 |       |     const float    f_attn_temp_offset;
 146 |       | };
 147 |       |
 148 |       | class llm_graph_input_pos_bucket : public llm_graph_input_i {
 149 |       | public:
 150 |     0 |     llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
 151 |       |     virtual ~llm_graph_input_pos_bucket() = default;
 152 |       |
 153 |       |     void set_input(const llama_ubatch * ubatch) override;
 154 |       |
 155 |       |     ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
 156 |       |
 157 |       |     const llama_hparams hparams;
 158 |       | };
 159 |       |
 160 |       | class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 161 |       | public:
 162 |       |     llm_graph_input_pos_bucket_kv(
 163 |       |             const llama_hparams & hparams,
 164 |     0 |             const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
 165 |       |     virtual ~llm_graph_input_pos_bucket_kv() = default;
 166 |       |
 167 |       |     void set_input(const llama_ubatch * ubatch) override;
 168 |       |
 169 |       |     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 170 |       |
 171 |       |     const llama_hparams hparams;
 172 |       |
 173 |       |     const llama_kv_cache_context * mctx;
 174 |       | };
 175 |       |
 176 |       | class llm_graph_input_out_ids : public llm_graph_input_i {
 177 |       | public:
 178 |       |     llm_graph_input_out_ids(
 179 |       |             const llama_hparams & hparams,
 180 |       |             const llama_cparams & cparams,
 181 |     0 |             uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
 182 |       |     virtual ~llm_graph_input_out_ids() = default;
 183 |       |
 184 |       |     void set_input(const llama_ubatch * ubatch) override;
 185 |       |
 186 |       |     bool can_reuse(const llm_graph_params & params) override;
 187 |       |
 188 |       |     ggml_tensor * out_ids; // I32 [n_outputs]
 189 |       |
 190 |       |     const llama_hparams hparams;
 191 |       |     const llama_cparams cparams;
 192 |       |
 193 |       |     const uint32_t n_outputs;
 194 |       | };
 195 |       |
 196 |       | class llm_graph_input_mean : public llm_graph_input_i {
 197 |       | public:
 198 |     0 |     llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
 199 |       |     virtual ~llm_graph_input_mean() = default;
 200 |       |
 201 |       |     void set_input(const llama_ubatch * ubatch) override;
 202 |       |
 203 |       |     ggml_tensor * mean; // F32 [n_batch, n_batch]
 204 |       |
 205 |       |     const llama_cparams cparams;
 206 |       | };
 207 |       |
 208 |       | class llm_graph_input_cls : public llm_graph_input_i {
 209 |       | public:
 210 |     0 |     llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
 211 |       |     virtual ~llm_graph_input_cls() = default;
 212 |       |
 213 |       |     void set_input(const llama_ubatch * ubatch) override;
 214 |       |
 215 |       |     ggml_tensor * cls; // I32 [n_batch]
 216 |       |
 217 |       |     const llama_cparams cparams;
 218 |       |     const llm_arch arch;
 219 |       | };
 220 |       |
 221 |       | class llm_graph_input_rs : public llm_graph_input_i {
 222 |       | public:
 223 |     0 |     llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
 224 |       |     virtual ~llm_graph_input_rs() = default;
 225 |       |
 226 |       |     void set_input(const llama_ubatch * ubatch) override;
 227 |       |
 228 |       |     bool can_reuse(const llm_graph_params & params) override;
 229 |       |
 230 |       |     ggml_tensor * s_copy;  // I32 [n_rs]
 231 |       |
 232 |       |     // views of s_copy, computed once per graph
 233 |       |     // and shared across layers which use build_rs
 234 |       |     ggml_tensor * s_copy_main;   // I32 [n_seqs]
 235 |       |     ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
 236 |       |
 237 |       |     const llama_memory_recurrent_context * mctx;
 238 |       |
 239 |       |     // used in view offsets, need to match for valid graph reuse
 240 |       |     uint32_t head;
 241 |       |     int32_t rs_z;
 242 |       | };
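The two views mentioned above (s_copy_main and s_copy_extra) can be pictured as 1-D slices of s_copy. The following is only an illustration of that relationship using ggml view ops, under the assumption that the first n_seqs entries drive the main rows; it is not the code that actually builds them.

// illustrative sketch - ctx, n_seqs and n_rs are placeholders
inp->s_copy_main  = ggml_view_1d(ctx, inp->s_copy, n_seqs, 0);
inp->s_copy_extra = ggml_view_1d(ctx, inp->s_copy, n_rs - n_seqs,
                                 n_seqs * ggml_element_size(inp->s_copy));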
 243 |       |
 244 |       | class llm_graph_input_cross_embd : public llm_graph_input_i {
 245 |       | public:
 246 |       |     llm_graph_input_cross_embd(
 247 |     0 |             const llama_cross * cross) : cross(cross) {}
 248 |       |     virtual ~llm_graph_input_cross_embd() = default;
 249 |       |
 250 |       |     void set_input(const llama_ubatch * ubatch) override;
 251 |       |
 252 |       |     ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
 253 |       |
 254 |       |     const llama_cross * cross;
 255 |       | };
 256 |       |
 257 |       | class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 258 |       | public:
 259 |       |     llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
 260 |     0 |         hparams(hparams),
 261 |     0 |         cparams(cparams) {
 262 |     0 |     }
 263 |       |     ~llm_graph_input_attn_no_cache() = default;
 264 |       |
 265 |       |     void set_input(const llama_ubatch * ubatch) override;
 266 |       |
 267 |     0 |     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
 268 |     0 |     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 269 |       |
 270 |       |     // n_tokens == n_batch
 271 |       |     ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
 272 |       |     ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
 273 |       |     ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
 274 |       |     ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
 275 |       |
 276 |       |     const llama_hparams hparams;
 277 |       |     const llama_cparams cparams;
 278 |       | };
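For context, a hedged sketch of how a KQ mask of this shape is usually consumed: following the common ggml convention, visible positions hold 0.0f and masked positions -INFINITY, and the mask is applied inside the fused softmax. The tensor names below are placeholders, not a quote from the implementation.

// illustrative sketch - placeholder tensors, not the actual implementation
ggml_tensor * kq = ggml_mul_mat(ctx, k, q);                          // [n_kv, n_tokens]
kq = ggml_soft_max_ext(ctx, kq, inp->get_kq_mask(), kq_scale, 0.0f); // mask + scale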
 279 |       |
 280 |       | class llm_graph_input_attn_kv : public llm_graph_input_i {
 281 |       | public:
 282 |       |     llm_graph_input_attn_kv(
 283 |       |             const llama_hparams & hparams,
 284 |       |             const llama_cparams & cparams,
 285 |       |             const llama_kv_cache_context * mctx) :
 286 |     0 |         hparams(hparams),
 287 |     0 |         cparams(cparams),
 288 |     0 |         mctx(mctx) {
 289 |     0 |     }
 290 |       |     ~llm_graph_input_attn_kv() = default;
 291 |       |
 292 |       |     void set_input(const llama_ubatch * ubatch) override;
 293 |       |
 294 |       |     bool can_reuse(const llm_graph_params & params) override;
 295 |       |
 296 |     0 |     ggml_tensor * get_k_idxs() const { return self_k_idxs; }
 297 |     0 |     ggml_tensor * get_v_idxs() const { return self_v_idxs; }
 298 |       |
 299 |     0 |     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 300 |       |
 301 |       |     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
 302 |       |     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 303 |       |
 304 |       |     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
 305 |       |     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 306 |       |
 307 |       |     // note: these have to be copies because in order to be able to reuse a graph, its inputs
 308 |       |     //       need to carry these parameters with them. otherwise, they can point to freed
 309 |       |     //       llm_graph_params from a previous batch, causing stack-use-after-return
 310 |       |     const llama_hparams hparams;
 311 |       |     const llama_cparams cparams;
 312 |       |
 313 |       |     const llama_kv_cache_context * mctx;
 314 |       | };
 315 |       |
 316 |       | class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 317 |       | public:
 318 |       |     llm_graph_input_attn_kv_iswa(
 319 |       |             const llama_hparams & hparams,
 320 |       |             const llama_cparams & cparams,
 321 |       |             const llama_kv_cache_iswa_context * mctx) :
 322 |     0 |         hparams(hparams),
 323 |     0 |         cparams(cparams),
 324 |     0 |         mctx(mctx) {
 325 |     0 |     }
 326 |       |     ~llm_graph_input_attn_kv_iswa() = default;
 327 |       |
 328 |       |     void set_input(const llama_ubatch * ubatch) override;
 329 |       |
 330 |       |     bool can_reuse(const llm_graph_params & params) override;
 331 |       |
 332 |     0 |     ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
 333 |     0 |     ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
 334 |     0 |     ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
 335 |     0 |     ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
 336 |       |
 337 |     0 |     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
 338 |     0 |     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 339 |       |
 340 |       |     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
 341 |       |     ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 342 |       |     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
 343 |       |     ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 344 |       |
 345 |       |     ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
 346 |       |     ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 347 |       |     ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
 348 |       |     ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 349 |       |
 350 |       |     const llama_hparams hparams;
 351 |       |     const llama_cparams cparams;
 352 |       |
 353 |       |     const llama_kv_cache_iswa_context * mctx;
 354 |       | };
 355 |       |
 356 |       | class llm_graph_input_attn_cross : public llm_graph_input_i {
 357 |       | public:
 358 |     0 |     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
 359 |       |     ~llm_graph_input_attn_cross() = default;
 360 |       |
 361 |       |     void set_input(const llama_ubatch * ubatch) override;
 362 |       |
 363 |     0 |     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 364 |       |
 365 |       |     ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 366 |       |     ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 367 |       |
 368 |       |     const llama_cross * cross = nullptr;
 369 |       | };
 370 |       |
 371 |       | class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 372 |       | public:
 373 |       |     llm_graph_input_mem_hybrid(
 374 |       |             const llama_cparams & cparams,
 375 |       |             std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
 376 |       |             std::unique_ptr<llm_graph_input_rs>      inp_rs,
 377 |       |             const llama_memory_hybrid_context *      mctx) :
 378 |     0 |         inp_attn(std::move(inp_attn)),
 379 |     0 |         inp_rs(std::move(inp_rs)),
 380 |     0 |         cparams(cparams),
 381 |     0 |         mctx(mctx) { }
 382 |     0 |     virtual ~llm_graph_input_mem_hybrid() = default;
 383 |       |
 384 |       |     void set_input(const llama_ubatch * ubatch) override;
 385 |       |
 386 |       |     bool can_reuse(const llm_graph_params & params) override;
 387 |       |
 388 |       |     std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
 389 |       |     std::unique_ptr<llm_graph_input_rs>      inp_rs;
 390 |       |
 391 |     0 |     llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
 392 |     0 |     llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 393 |       |
 394 |       |     const llama_cparams cparams;
 395 |       |
 396 |       |     const llama_memory_hybrid_context * mctx;
 397 |       | };
 398 |       |
 399 |       | //
 400 |       | // llm_graph_result
 401 |       | //
 402 |       |
 403 |       | // these objects deliver the result from the graph build process back to the llama_context
 404 |       | // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
 405 |       | //   specific data, by calling the set_inputs() method
 406 |       | // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
 407 |       | //   these are used by the llama_context to extract the relevant data, based on the compute parameters
 408 |       |
 409 |       | // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 410 |       | using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 411 |       |
 412 |       | class llm_graph_result;
 413 |       |
 414 |       | struct llm_graph_params {
 415 |       |     llm_arch arch = LLM_ARCH_UNKNOWN;
 416 |       |
 417 |       |     llama_hparams hparams;
 418 |       |     llama_cparams cparams;
 419 |       |
 420 |       |     llama_ubatch ubatch; // note: intentionally make a copy
 421 |       |
 422 |       |     llm_graph_type gtype;
 423 |       |
 424 |       |     ggml_backend_sched_t sched;
 425 |       |     ggml_backend_t backend_cpu;
 426 |       |
 427 |       |     const llama_adapter_cvec     * cvec;
 428 |       |     const llama_adapter_loras    * loras;
 429 |       |     const llama_memory_context_i * mctx;
 430 |       |     const llama_cross            * cross;
 431 |       |
 432 |       |     uint32_t n_outputs;
 433 |       |
 434 |       |     llm_graph_cb cb;
 435 |       |
 436 |       |     llm_graph_result * res;
 437 |       |
 438 |       |     // return true if the "other" params would result in a graph with the same topology as with the current params
 439 |       |     //   having the same topology allows us to reuse the graph in some cases
 440 |     0 |     bool allow_reuse(const llm_graph_params & other) const {
 441 |       |         // first check the ubatch
 442 |     0 |         bool can_reuse_ubatch =
 443 |     0 |             ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
 444 |     0 |             ubatch.n_tokens     == other.ubatch.n_tokens &&
 445 |     0 |             ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
 446 |     0 |             ubatch.n_seqs       == other.ubatch.n_seqs &&
 447 |     0 |             ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
 448 |     0 |             (
 449 |     0 |                 (!ubatch.token && !other.ubatch.token) ||
 450 |     0 |                 (!ubatch.embd  && !other.ubatch.embd)
 451 |     0 |             );
 452 |       |
 453 |       |         // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
 454 |       |         //   this is because the set of attention streams would be different for different sequences
 455 |     0 |         if (can_reuse_ubatch && ubatch.equal_seqs()) {
 456 |     0 |             if (!ubatch.data) {
 457 |       |                 // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
 458 |       |                 //   therefore we cannot perform the sequence id check. normally should never happen
 459 |     0 |                 can_reuse_ubatch = false;
 460 |     0 |             } else {
 461 |     0 |                 for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
 462 |     0 |                     can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
 463 |     0 |                 }
 464 |     0 |             }
 465 |     0 |         }
 466 |       |
 467 |     0 |         if (!can_reuse_ubatch) {
 468 |     0 |             return false;
 469 |     0 |         }
 470 |       |
 471 |     0 |         return
 472 |     0 |             cparams.embeddings  == other.cparams.embeddings  &&
 473 |     0 |             cparams.causal_attn == other.cparams.causal_attn &&
 474 |     0 |             arch      == other.arch  &&
 475 |     0 |             gtype     == other.gtype &&
 476 |     0 |             cvec      == other.cvec  &&
 477 |     0 |             loras     == other.loras &&
 478 |     0 |             cross     == other.cross &&
 479 |     0 |             n_outputs == other.n_outputs;
 480 |     0 |     }
 481 |       | };
 482 |       |
 483 |       | class llm_graph_result {
 484 |       | public:
 485 |       |     llm_graph_result(int64_t max_nodes);
 486 |       |
 487 |     0 |     virtual ~llm_graph_result() = default;
 488 |       |
 489 |     0 |     ggml_tensor * get_tokens()      const { return t_tokens; }
 490 |     0 |     ggml_tensor * get_logits()      const { return t_logits; }
 491 |     0 |     ggml_tensor * get_embd()        const { return t_embd; }
 492 |     0 |     ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
 493 |       |
 494 |     0 |     ggml_cgraph  * get_gf()  const { return gf; }
 495 |     0 |     ggml_context * get_ctx() const { return ctx_compute.get(); }
 496 |       |
 497 |       |     int64_t get_max_nodes() const;
 498 |       |
 499 |       |     void reset();
 500 |       |
 501 |       |     void set_inputs(const llama_ubatch * ubatch);
 502 |       |
 503 |       |     // try to update the existing graph result using the new graph parameters in order to reuse it
 504 |       |     // this can only be done if we determine that the resulting graph using the new graph parameters
 505 |       |     //   would be identical to the existing graph. in that case, we simply have to update the memory
 506 |       |     //   contexts of the input tensors of the graph and we can reuse it for another computation
 507 |       |     // return true if the graph was updated and can be reused
 508 |       |     bool can_reuse(const llm_graph_params & params);
 509 |       |
 510 |       |     llm_graph_input_i * add_input(llm_graph_input_ptr input);
 511 |       |
 512 |       |     void set_params(const llm_graph_params & params);
 513 |       |
 514 |       |     // important graph nodes
 515 |       |     ggml_tensor * t_tokens      = nullptr;
 516 |       |     ggml_tensor * t_logits      = nullptr;
 517 |       |     ggml_tensor * t_embd        = nullptr;
 518 |       |     ggml_tensor * t_embd_pooled = nullptr;
 519 |       |
 520 |       |     std::vector<llm_graph_input_ptr> inputs;
 521 |       |
 522 |       |     ggml_context_ptr ctx_compute;
 523 |       |
 524 |       |     // memory buffers used to evaluate the model
 525 |       |     std::vector<uint8_t> buf_compute_meta;
 526 |       |
 527 |       |     ggml_cgraph * gf;
 528 |       |
 529 |       |     int64_t max_nodes;
 530 |       |
 531 |       | private:
 532 |       |     // keep a copy of the previous graph parameters
 533 |       |     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
 534 |       |     // note: these are updated after constructing the new graph
 535 |       |     llm_graph_params params;
 536 |       |
 537 |       |     // env: LLAMA_GRAPH_RESULT_DEBUG
 538 |       |     int debug = 0;
 539 |       | };
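To make the reuse machinery concrete, here is a hedged sketch of how a caller might drive it; res and params are placeholder variables and the rebuild branch is elided. It only uses the methods declared above (can_reuse(), set_inputs(), reset()).

// illustrative sketch - not a quote from llama_context
if (res->can_reuse(params)) {
    // same topology as the previous graph: only refresh the input tensors
    res->set_inputs(&params.ubatch);
} else {
    // topology changed (arch, gtype, n_outputs, ...): reset and rebuild
    res->reset();
    // ... rebuild the graph via llm_graph_context(params) and the model ...
}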
 540 |       |
 541 |       | using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
 542 |       |
 543 |       | //
 544 |       | // llm_graph_context
 545 |       | //
 546 |       |
 547 |       | // used in build_rs to properly order writes and avoid unnecessary copies
 548 |       | using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
 549 |       |
 550 |       | struct llm_graph_context {
 551 |       |     const llm_arch arch;
 552 |       |
 553 |       |     const llama_hparams & hparams;
 554 |       |     const llama_cparams & cparams;
 555 |       |     const llama_ubatch  & ubatch;
 556 |       |
 557 |       |     const int64_t n_embd;
 558 |       |     const int64_t n_layer;
 559 |       |     const int64_t n_rot;
 560 |       |     const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
 561 |       |     const int64_t n_head;
 562 |       |     const int64_t n_head_kv;
 563 |       |     const int64_t n_embd_head_k;
 564 |       |     const int64_t n_embd_k_gqa;
 565 |       |     const int64_t n_embd_head_v;
 566 |       |     const int64_t n_embd_v_gqa;
 567 |       |     const int64_t n_expert;
 568 |       |     const int64_t n_expert_used;
 569 |       |
 570 |       |     const float freq_base;
 571 |       |     const float freq_scale;
 572 |       |     const float ext_factor;
 573 |       |     const float attn_factor;
 574 |       |     const float beta_fast;
 575 |       |     const float beta_slow;
 576 |       |     const float norm_eps;
 577 |       |     const float norm_rms_eps;
 578 |       |
 579 |       |     const int64_t n_tokens;
 580 |       |     const int64_t n_outputs;
 581 |       |     const int32_t n_ctx_orig; // yarn
 582 |       |
 583 |       |     const enum llama_pooling_type pooling_type;
 584 |       |     const enum llama_rope_type    rope_type;
 585 |       |
 586 |       |     ggml_backend_sched_t sched;
 587 |       |
 588 |       |     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 589 |       |
 590 |       |     const llama_adapter_cvec     * cvec;
 591 |       |     const llama_adapter_loras    * loras;
 592 |       |     const llama_memory_context_i * mctx;
 593 |       |     const llama_cross            * cross;
 594 |       |
 595 |       |     const llm_graph_cb & cb_func;
 596 |       |
 597 |       |     llm_graph_result * res;
 598 |       |
 599 |       |     ggml_context * ctx0 = nullptr;
 600 |       |     ggml_cgraph  * gf   = nullptr;
 601 |       |
 602 |       |     llm_graph_context(const llm_graph_params & params);
 603 |     0 |     virtual ~llm_graph_context() = default;
 604 |       |
 605 |       |     void cb(ggml_tensor * cur, const char * name, int il) const;
 606 |       |
 607 |       |     //
 608 |       |     // common
 609 |       |     //
 610 |       |
 611 |       |     ggml_tensor * build_cvec(
 612 |       |              ggml_tensor * cur,
 613 |       |                      int   il) const;
 614 |       |
 615 |       |     // do mat_mul, while optionally applying lora
 616 |       |     ggml_tensor * build_lora_mm(
 617 |       |               ggml_tensor * w,
 618 |       |               ggml_tensor * cur) const;
 619 |       |
 620 |       |     // do mat_mul_id, while optionally applying lora
 621 |       |     ggml_tensor * build_lora_mm_id(
 622 |       |               ggml_tensor * w,   // ggml_tensor * as
 623 |       |               ggml_tensor * cur, // ggml_tensor * b
 624 |       |               ggml_tensor * ids) const;
 625 |       |
 626 |       |     ggml_tensor * build_norm(
 627 |       |              ggml_tensor * cur,
 628 |       |              ggml_tensor * mw,
 629 |       |              ggml_tensor * mb,
 630 |       |            llm_norm_type   type,
 631 |       |                      int   il) const;
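As a usage illustration, a typical RMS-norm call through this helper might look as follows; the weight tensor name is hypothetical and no bias is passed.

// illustrative call - model.layers[il].attn_norm is a placeholder weight name
cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);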
 632 |       |
 633 |       |     ggml_tensor * build_ffn(
 634 |       |              ggml_tensor * cur,
 635 |       |              ggml_tensor * up,
 636 |       |              ggml_tensor * up_b,
 637 |       |              ggml_tensor * up_s,
 638 |       |              ggml_tensor * gate,
 639 |       |              ggml_tensor * gate_b,
 640 |       |              ggml_tensor * gate_s,
 641 |       |              ggml_tensor * down,
 642 |       |              ggml_tensor * down_b,
 643 |       |              ggml_tensor * down_s,
 644 |       |              ggml_tensor * act_scales,
 645 |       |          llm_ffn_op_type   type_op,
 646 |       |        llm_ffn_gate_type   type_gate,
 647 |       |                      int   il) const;
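For reference, a hedged example of how a standard SwiGLU feed-forward block could be expressed with this helper, with no biases, scales or activation scales; the weight tensor names are hypothetical.

// illustrative call - weight names are placeholders
cur = build_ffn(cur,
        model.layers[il].ffn_up,   nullptr, nullptr,
        model.layers[il].ffn_gate, nullptr, nullptr,
        model.layers[il].ffn_down, nullptr, nullptr,
        nullptr,                             // act_scales
        LLM_FFN_SILU, LLM_FFN_PAR, il);      // SwiGLU = SiLU gate in parallel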
 648 |       |
 649 |       |     // build MoE FFN without bias tensors
 650 |       |     ggml_tensor * build_moe_ffn(
 651 |       |              ggml_tensor * cur,
 652 |       |              ggml_tensor * gate_inp,
 653 |       |              ggml_tensor * up_exps,
 654 |       |              ggml_tensor * gate_exps,
 655 |       |              ggml_tensor * down_exps,
 656 |       |              ggml_tensor * exp_probs_b,
 657 |       |                  int64_t   n_expert,
 658 |       |                  int64_t   n_expert_used,
 659 |       |          llm_ffn_op_type   type_op,
 660 |       |                     bool   norm_w,
 661 |       |                     bool   scale_w,
 662 |       |                    float   w_scale,
 663 |       |             llama_expert_gating_func_type gating_op,
 664 |       |                      int   il,
 665 |       |              ggml_tensor * probs_in = nullptr) const;
 666 |       |
 667 |       |     ggml_tensor * build_moe_ffn(
 668 |       |              ggml_tensor * cur,
 669 |       |              ggml_tensor * gate_inp,
 670 |       |              ggml_tensor * gate_inp_b,
 671 |       |              ggml_tensor * up_exps,
 672 |       |              ggml_tensor * up_exps_b,
 673 |       |              ggml_tensor * gate_exps,
 674 |       |              ggml_tensor * gate_exps_b,
 675 |       |              ggml_tensor * down_exps,
 676 |       |              ggml_tensor * down_exps_b,
 677 |       |              ggml_tensor * exp_probs_b,
 678 |       |                  int64_t   n_expert,
 679 |       |                  int64_t   n_expert_used,
 680 |       |          llm_ffn_op_type   type_op,
 681 |       |                     bool   norm_w,
 682 |       |                     bool   scale_w,
 683 |       |                    float   w_scale,
 684 |       |             llama_expert_gating_func_type gating_op,
 685 |       |                      int   il,
 686 |       |              ggml_tensor * probs_in = nullptr) const;
 687 |       |
 688 |       |     //
 689 |       |     // inputs
 690 |       |     //
 691 |       |
 692 |       |     ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
 693 |       |     ggml_tensor * build_inp_pos() const;
 694 |       |     ggml_tensor * build_inp_attn_scale() const;
 695 |       |     ggml_tensor * build_inp_out_ids() const;
 696 |       |     ggml_tensor * build_inp_mean() const;
 697 |       |     ggml_tensor * build_inp_cls() const;
 698 |       |
 699 |       |     ggml_tensor * build_inp_cross_embd() const;
 700 |       |     ggml_tensor * build_inp_pos_bucket_enc() const;
 701 |       |     ggml_tensor * build_inp_pos_bucket_dec() const;
 702 |       |     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 703 |       |
 704 |       |     //
 705 |       |     // attention
 706 |       |     //
 707 |       |
 708 |       |     ggml_tensor * build_attn_mha(
 709 |       |             ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
 710 |       |             ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
 711 |       |             ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
 712 |       |             ggml_tensor * kq_b,
 713 |       |             ggml_tensor * kq_mask,
 714 |       |             ggml_tensor * sinks,   // [n_head_q]
 715 |       |             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
 716 |       |                   float   kq_scale,
 717 |       |                     int   il) const;
 718 |       |
 719 |       |     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 720 |       |
 721 |       |     ggml_tensor * build_attn(
 722 |       |             llm_graph_input_attn_no_cache * inp,
 723 |       |             ggml_tensor * wo,
 724 |       |             ggml_tensor * wo_b,
 725 |       |             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 726 |       |             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
 727 |       |             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
 728 |       |             ggml_tensor * kq_b,
 729 |       |             ggml_tensor * sinks, // [n_head_q]
 730 |       |             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
 731 |       |                   float   kq_scale,
 732 |       |                     int   il) const;
 733 |       |
 734 |       |     llm_graph_input_attn_kv * build_attn_inp_kv() const;
 735 |       |
 736 |       |     ggml_tensor * build_attn(
 737 |       |             llm_graph_input_attn_kv * inp,
 738 |       |             ggml_tensor * wo,
 739 |       |             ggml_tensor * wo_b,
 740 |       |             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 741 |       |             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
 742 |       |             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
 743 |       |             ggml_tensor * kq_b,
 744 |       |             ggml_tensor * sinks, // [n_head_q]
 745 |       |             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
 746 |       |                   float   kq_scale,
 747 |       |                     int   il) const;
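And a hedged sketch of a self-attention call through the KV-cache path, with no attention bias, sinks or MLA projection; q_cur/k_cur/v_cur and the output weight name are placeholders, not a quote from the model code.

// illustrative call - placeholder tensors and weight names
auto * inp_attn = build_attn_inp_kv();
// ... per layer, after computing q_cur, k_cur, v_cur ...
cur = build_attn(inp_attn,
        model.layers[il].wo, nullptr,
        q_cur, k_cur, v_cur,
        nullptr, nullptr, nullptr,           // kq_b, sinks, v_mla
        1.0f/sqrtf(float(n_embd_head_k)), il);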
 748 |       |
 749 |       |     llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 750 |       |
 751 |       |     // note: if k_cur or v_cur are not provided, they will not be stored in memory
 752 |       |     ggml_tensor * build_attn(
 753 |       |             llm_graph_input_attn_kv_iswa * inp,
 754 |       |             ggml_tensor * wo,
 755 |       |             ggml_tensor * wo_b,
 756 |       |             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 757 |       |             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
 758 |       |             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
 759 |       |             ggml_tensor * kq_b,
 760 |       |             ggml_tensor * sinks, // [n_head_q]
 761 |       |             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
 762 |       |                   float   kq_scale,
 763 |       |                     int   il) const;
 764 |       |
 765 |       |     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 766 |       |
 767 |       |     ggml_tensor * build_attn(
 768 |       |             llm_graph_input_attn_cross * inp,
 769 |       |             ggml_tensor * wo,
 770 |       |             ggml_tensor * wo_b,
 771 |       |             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 772 |       |             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
 773 |       |             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
 774 |       |             ggml_tensor * kq_b,
 775 |       |             ggml_tensor * sinks, // [n_head_q]
 776 |       |             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
 777 |       |                   float   kq_scale,
 778 |       |                     int   il) const;
 779 |       |
 780 |       |     //
 781 |       |     // recurrent
 782 |       |     //
 783 |       |
 784 |       |     // TODO: move this implementation to llama_memory_recurrent.
 785 |       |     //       this is analogous to llama_kv_cache::cpy_k / cpy_v
 786 |       |     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
 787 |       |     //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
 788 |       |     //         `llama_memory_recurrent`
 789 |       |     ggml_tensor * build_rs(
 790 |       |             ggml_tensor * s,
 791 |       |             ggml_tensor * state_copy_main,
 792 |       |             ggml_tensor * state_copy_extra,
 793 |       |                 int32_t   state_size,
 794 |       |                 int32_t   n_seqs,
 795 |       |                uint32_t   n_rs,
 796 |       |                uint32_t   rs_head,
 797 |       |                uint32_t   rs_size,
 798 |       |                 int32_t   rs_zero,
 799 |       |             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 800 |       |
 801 |       |     llm_graph_input_rs * build_rs_inp() const;
 802 |       |
 803 |       |     ggml_tensor * build_rs(
 804 |       |             llm_graph_input_rs * inp,
 805 |       |             ggml_tensor * s,
 806 |       |                 int32_t   state_size,
 807 |       |                 int32_t   n_seqs,
 808 |       |             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 809 |       |
 810 |       |     ggml_tensor * build_rwkv_token_shift_load(
 811 |       |         llm_graph_input_rs * inp,
 812 |       |         const llama_ubatch & ubatch,
 813 |       |                        int   il) const;
 814 |       |
 815 |       |     ggml_tensor * build_rwkv_token_shift_store(
 816 |       |              ggml_tensor * token_shift,
 817 |       |       const llama_ubatch & ubatch,
 818 |       |                      int   il) const;
 819 |       |     //
 820 |       |     // hybrid
 821 |       |     //
 822 |       |
 823 |       |     llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
 824 |       |
 825 |       |     //
 826 |       |     // pooling
 827 |       |     //
 828 |       |
 829 |       |     void build_pooling(
 830 |       |             ggml_tensor * cls,
 831 |       |             ggml_tensor * cls_b,
 832 |       |             ggml_tensor * cls_out,
 833 |       |             ggml_tensor * cls_out_b) const;
 834 |       |
 835 |       |     //
 836 |       |     // dense (out)
 837 |       |     //
 838 |       |
 839 |       |     void build_dense_out(
 840 |       |             ggml_tensor * dense_2,
 841 |       |             ggml_tensor * dense_3) const;
 842 |       | };
 843 |       |
 844 |       | // TODO: better name
 845 |       | int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);