Coverage Report

Created: 2026-01-09 06:17

/src/llama.cpp/src/llama-graph.h
Line | Count | Source
1
#pragma once
2
3
#include "llama-arch.h"
4
#include "llama-batch.h"
5
#include "llama-hparams.h"
6
#include "llama-adapter.h"
7
8
#include <cstdint>
9
#include <vector>
10
#include <memory>
11
#include <set>
12
#include <functional>
13
#include <map>
14
15
struct ggml_cgraph;
16
struct ggml_context;
17
struct ggml_tensor;
18
19
struct llama_cparams;
20
21
struct llama_memory_context_i;
22
23
class llama_kv_cache_context;
24
class llama_kv_cache_iswa_context;
25
class llama_memory_recurrent_context;
26
class llama_memory_hybrid_context;
27
28
// certain models (typically multi-modal) can produce different types of graphs
29
enum llm_graph_type {
30
    LLM_GRAPH_TYPE_DEFAULT,
31
    LLM_GRAPH_TYPE_ENCODER,
32
    LLM_GRAPH_TYPE_DECODER,
33
};
34
35
enum llm_ffn_op_type {
36
    LLM_FFN_SILU,
37
    LLM_FFN_GELU,
38
    LLM_FFN_RELU,
39
    LLM_FFN_RELU_SQR,
40
    LLM_FFN_SWIGLU,
41
    LLM_FFN_GEGLU,
42
    LLM_FFN_REGLU,
43
    LLM_FFN_SWIGLU_OAI_MOE,
44
};
45
46
enum llm_ffn_gate_type {
47
    LLM_FFN_SEQ,
48
    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
49
};
50
51
enum llm_norm_type {
52
    LLM_NORM,
53
    LLM_NORM_RMS,
54
    LLM_NORM_GROUP,
55
};
56
57
// TODO: tmp - need something better to pass the data from the encoder to the decoder
58
struct llama_cross {
59
    // the output embeddings from the encoder as a ggml tensor
60
    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
61
    //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
62
    //ggml_tensor * t_embd = nullptr;
63
64
    int64_t n_embd = 0;
65
    int64_t n_enc  = 0;
66
67
    // embeddings data copied to host memory (tmp)
68
    std::vector<float> v_embd;
69
70
    // needed to construct the cross-attention mask in the decoder
71
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
72
};
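
For orientation, here is a minimal sketch (not code from the repository) of how an encoder pass could populate a llama_cross before running the decoder; the helper name and its arguments are assumptions for illustration.

#include "llama-graph.h"

// hypothetical helper: copy the encoder output and its sequence ids into llama_cross
static void llama_cross_fill_sketch(
        llama_cross & cross,
        const std::vector<float> & embd_enc,                     // [n_embd * n_enc], already on the host
        const std::vector<std::set<llama_seq_id>> & seq_ids_enc, // sequence ids active at each encoder position
        int64_t n_embd,
        int64_t n_enc) {
    cross.n_embd      = n_embd;
    cross.n_enc       = n_enc;
    cross.v_embd      = embd_enc;     // tmp: embeddings copied to host memory (see comment above)
    cross.seq_ids_enc = seq_ids_enc;  // used to construct the cross-attention mask in the decoder
}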
73
74
struct llm_graph_params;
75
76
//
77
// llm_graph_input
78
//
79
80
class llm_graph_input_i {
81
public:
82
0
    llm_graph_input_i() {
83
0
        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
84
0
        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
85
0
    }
86
87
0
    virtual ~llm_graph_input_i() = default;
88
89
    virtual void set_input(const llama_ubatch * ubatch) = 0;
90
91
    // return true if the input tensors resulting from the provided graph parameters would be
92
    //   the same as the previous input tensors currently stored in this object
93
0
    virtual bool can_reuse(const llm_graph_params & params) {
94
        // returning false here by default prevents reusing the graph when the check
95
        //   for this input type has not been implemented yet
96
0
        GGML_UNUSED(params);
97
0
        return false;
98
0
    }
99
protected:
100
    // env: LLAMA_GRAPH_INPUT_DEBUG
101
    int debug = 0;
102
};
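
To make the interface concrete, here is a hedged sketch of a derived input type (hypothetical, not one of the classes below): it stores the shape-determining parameter it was built with so that can_reuse() can compare it against new parameters.

#include "llama-graph.h"

// hypothetical input type, only to illustrate the llm_graph_input_i contract
class llm_graph_input_sketch : public llm_graph_input_i {
public:
    explicit llm_graph_input_sketch(uint32_t n_tokens) : n_tokens(n_tokens) {}
    virtual ~llm_graph_input_sketch() = default;

    void set_input(const llama_ubatch * ubatch) override {
        // copy the per-ubatch host data into `t` here,
        // e.g. with ggml_backend_tensor_set(t, data, 0, ggml_nbytes(t))
        GGML_UNUSED(ubatch);
    }

    bool can_reuse(const llm_graph_params & params) override {
        // reuse is only safe if the tensor shape would come out the same
        return params.ubatch.n_tokens == n_tokens;
    }

    ggml_tensor * t = nullptr; // I32 [n_tokens]

    const uint32_t n_tokens;
};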
103
104
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
105
106
class llm_graph_input_embd : public llm_graph_input_i {
107
public:
108
0
    llm_graph_input_embd()          = default;
109
    virtual ~llm_graph_input_embd() = default;
110
111
    void set_input(const llama_ubatch * ubatch) override;
112
113
    bool can_reuse(const llm_graph_params & params) override;
114
115
    ggml_tensor * tokens = nullptr; // I32 [n_batch]
116
    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
117
};
118
119
class llm_graph_input_pos : public llm_graph_input_i {
120
public:
121
0
    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
122
    virtual ~llm_graph_input_pos() = default;
123
124
    void set_input(const llama_ubatch * ubatch) override;
125
126
    bool can_reuse(const llm_graph_params & params) override;
127
128
    ggml_tensor * pos = nullptr; // I32 [n_batch]
129
130
    const uint32_t n_pos_per_embd = 1;
131
};
132
133
// temperature tuning, used by llama4
134
class llm_graph_input_attn_temp : public llm_graph_input_i {
135
public:
136
    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
137
0
        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
138
    virtual ~llm_graph_input_attn_temp() = default;
139
140
    void set_input(const llama_ubatch * ubatch) override;
141
142
    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
143
144
    const uint32_t n_attn_temp_floor_scale;
145
    const float    f_attn_temp_scale;
146
    const float    f_attn_temp_offset;
147
};
148
149
class llm_graph_input_pos_bucket : public llm_graph_input_i {
150
public:
151
0
    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
152
    virtual ~llm_graph_input_pos_bucket() = default;
153
154
    void set_input(const llama_ubatch * ubatch) override;
155
156
    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
157
158
    const llama_hparams hparams;
159
};
160
161
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
162
public:
163
    llm_graph_input_pos_bucket_kv(
164
            const llama_hparams & hparams,
165
0
            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
166
    virtual ~llm_graph_input_pos_bucket_kv() = default;
167
168
    void set_input(const llama_ubatch * ubatch) override;
169
170
    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
171
172
    const llama_hparams hparams;
173
174
    const llama_kv_cache_context * mctx;
175
};
176
177
class llm_graph_input_out_ids : public llm_graph_input_i {
178
public:
179
    llm_graph_input_out_ids(
180
            const llama_hparams & hparams,
181
            const llama_cparams & cparams,
182
0
            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
183
    virtual ~llm_graph_input_out_ids() = default;
184
185
    void set_input(const llama_ubatch * ubatch) override;
186
187
    bool can_reuse(const llm_graph_params & params) override;
188
189
    ggml_tensor * out_ids; // I32 [n_outputs]
190
191
    const llama_hparams hparams;
192
    const llama_cparams cparams;
193
194
    const uint32_t n_outputs;
195
};
196
197
class llm_graph_input_mean : public llm_graph_input_i {
198
public:
199
0
    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
200
    virtual ~llm_graph_input_mean() = default;
201
202
    void set_input(const llama_ubatch * ubatch) override;
203
204
    ggml_tensor * mean; // F32 [n_batch, n_batch]
205
206
    const llama_cparams cparams;
207
};
208
209
class llm_graph_input_cls : public llm_graph_input_i {
210
public:
211
0
    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
212
    virtual ~llm_graph_input_cls() = default;
213
214
    void set_input(const llama_ubatch * ubatch) override;
215
216
    ggml_tensor * cls; // I32 [n_batch]
217
218
    const llama_cparams cparams;
219
    const llm_arch arch;
220
};
221
222
class llm_graph_input_rs : public llm_graph_input_i {
223
public:
224
0
    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
225
    virtual ~llm_graph_input_rs() = default;
226
227
    void set_input(const llama_ubatch * ubatch) override;
228
229
    bool can_reuse(const llm_graph_params & params) override;
230
231
    ggml_tensor * s_copy;  // I32 [n_rs]
232
233
    // views of s_copy, computed once per graph
234
    // and shared across layers which use build_rs
235
    ggml_tensor * s_copy_main;   // I32 [n_seqs]
236
    ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
237
238
    const llama_memory_recurrent_context * mctx;
239
240
    // used in view offsets; these need to match for valid graph reuse
241
    uint32_t head;
242
    int32_t rs_z;
243
};
244
245
class llm_graph_input_cross_embd : public llm_graph_input_i {
246
public:
247
    llm_graph_input_cross_embd(
248
0
            const llama_cross * cross) : cross(cross) {}
249
    virtual ~llm_graph_input_cross_embd() = default;
250
251
    void set_input(const llama_ubatch * ubatch) override;
252
253
    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
254
255
    const llama_cross * cross;
256
};
257
258
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
259
public:
260
    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
261
0
        hparams(hparams),
262
0
        cparams(cparams) {
263
0
    }
264
    ~llm_graph_input_attn_no_cache() = default;
265
266
    void set_input(const llama_ubatch * ubatch) override;
267
268
0
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
269
0
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
270
271
    // n_tokens == n_batch
272
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
273
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
274
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
275
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
276
277
    const llama_hparams hparams;
278
    const llama_cparams cparams;
279
};
280
281
class llm_graph_input_attn_kv : public llm_graph_input_i {
282
public:
283
    llm_graph_input_attn_kv(
284
            const llama_hparams & hparams,
285
            const llama_cparams & cparams,
286
            const llama_kv_cache_context * mctx) :
287
0
        hparams(hparams),
288
0
        cparams(cparams),
289
0
        mctx(mctx) {
290
0
    }
291
    ~llm_graph_input_attn_kv() = default;
292
293
    void set_input(const llama_ubatch * ubatch) override;
294
295
    bool can_reuse(const llm_graph_params & params) override;
296
297
0
    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
298
0
    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
299
300
0
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
301
302
    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
303
    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
304
305
    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
306
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
307
308
    // note: these have to be copies because, in order to reuse a graph, its inputs
309
    //       need to carry these parameters with them. otherwise, they can point to a freed
310
    //       llm_graph_params from a previous batch, causing a stack-use-after-return
311
    const llama_hparams hparams;
312
    const llama_cparams cparams;
313
314
    const llama_kv_cache_context * mctx;
315
};
316
317
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
318
public:
319
    llm_graph_input_attn_kv_iswa(
320
            const llama_hparams & hparams,
321
            const llama_cparams & cparams,
322
            const llama_kv_cache_iswa_context * mctx) :
323
0
        hparams(hparams),
324
0
        cparams(cparams),
325
0
        mctx(mctx) {
326
0
    }
327
    ~llm_graph_input_attn_kv_iswa() = default;
328
329
    void set_input(const llama_ubatch * ubatch) override;
330
331
    bool can_reuse(const llm_graph_params & params) override;
332
333
0
    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
334
0
    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
335
0
    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
336
0
    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
337
338
0
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
339
0
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
340
341
    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
342
    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
343
    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
344
    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
345
346
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
347
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
348
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
349
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
350
351
    const llama_hparams hparams;
352
    const llama_cparams cparams;
353
354
    const llama_kv_cache_iswa_context * mctx;
355
};
356
357
class llm_graph_input_attn_cross : public llm_graph_input_i {
358
public:
359
0
    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
360
    ~llm_graph_input_attn_cross() = default;
361
362
    void set_input(const llama_ubatch * ubatch) override;
363
364
0
    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
365
366
    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
367
    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
368
369
    const llama_cross * cross = nullptr;
370
};
371
372
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
373
public:
374
    llm_graph_input_mem_hybrid(
375
            const llama_cparams & cparams,
376
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
377
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
378
            const llama_memory_hybrid_context *      mctx) :
379
0
        inp_attn(std::move(inp_attn)),
380
0
        inp_rs(std::move(inp_rs)),
381
0
        cparams(cparams),
382
0
        mctx(mctx) { }
383
0
    virtual ~llm_graph_input_mem_hybrid() = default;
384
385
    void set_input(const llama_ubatch * ubatch) override;
386
387
    bool can_reuse(const llm_graph_params & params) override;
388
389
    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
390
    std::unique_ptr<llm_graph_input_rs>      inp_rs;
391
392
0
    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
393
0
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
394
395
    const llama_cparams cparams;
396
397
    const llama_memory_hybrid_context * mctx;
398
};
399
400
class llm_graph_input_sampling : public llm_graph_input_i {
401
public:
402
    llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
403
0
        samplers(std::move(samplers)) { }
404
0
    virtual ~llm_graph_input_sampling() = default;
405
406
    void set_input(const llama_ubatch * ubatch) override;
407
    bool can_reuse(const llm_graph_params & params) override;
408
409
    std::map<llama_seq_id, llama_sampler *> samplers;
410
};
411
412
//
413
// llm_graph_result
414
//
415
416
// these objects deliver the result from the graph build process back to the llama_context
417
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
418
//   specific data, by calling the set_inputs() method
419
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
420
//   these are used by the llama_context to extract the relevant data, based on the compute parameters
421
422
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
423
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
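
As a hedged illustration of the callback shape (a sketch, not the callback llama_context actually installs), a llm_graph_cb can simply tag tensors with a name and layer index for debugging:

#include "ggml.h"
#include "llama-graph.h"

// hypothetical callback: name each tensor after the op and its layer
llm_graph_cb cb_name_tensors = [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
    GGML_UNUSED(ubatch);
    if (il >= 0) {
        ggml_format_name(cur, "%s-%d", name, il); // e.g. "ffn_out-12"
    } else {
        ggml_set_name(cur, name);
    }
};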
424
425
class llm_graph_result;
426
427
struct llm_graph_params {
428
    llm_arch arch = LLM_ARCH_UNKNOWN;
429
430
    llama_hparams hparams;
431
    llama_cparams cparams;
432
433
    llama_ubatch ubatch; // note: intentionally make a copy
434
435
    llm_graph_type gtype;
436
437
    ggml_backend_sched_t sched;
438
    ggml_backend_t backend_cpu;
439
440
    const llama_adapter_cvec     * cvec;
441
    const llama_adapter_loras    * loras;
442
    const llama_memory_context_i * mctx;
443
    const llama_cross            * cross;
444
445
    std::map<llama_seq_id, llama_sampler *> samplers;
446
447
    static bool samplers_equal(
448
          const std::map<llama_seq_id, llama_sampler *> & lhs,
449
0
          const std::map<llama_seq_id, llama_sampler *> & rhs) {
450
0
        if (lhs.size() != rhs.size()) {
451
0
            return false;
452
0
        }
453
0
        for (const auto & [seq_id, sampler] : lhs) {
454
0
            auto it = rhs.find(seq_id);
455
0
            if (it == rhs.end() || it->second != sampler) {
456
0
                return false;
457
0
            }
458
0
        }
459
0
        return true;
460
0
    }
461
462
    uint32_t n_outputs;
463
464
    llm_graph_cb cb;
465
466
    llm_graph_result * res;
467
468
    // return true if the "other" params would result in a graph with the same topology as the current params
469
    //   having the same topology allows us to reuse the graph in some cases
470
0
    bool allow_reuse(const llm_graph_params & other) const {
471
        // first check the ubatch
472
0
        bool can_reuse_ubatch =
473
0
            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
474
0
            ubatch.n_tokens     == other.ubatch.n_tokens &&
475
0
            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
476
0
            ubatch.n_seqs       == other.ubatch.n_seqs &&
477
0
            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
478
0
            (
479
0
                (!ubatch.token && !other.ubatch.token) ||
480
0
                (!ubatch.embd  && !other.ubatch.embd)
481
0
            );
482
483
        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
484
        //   this is because the set of attention streams would be different for different sequences
485
0
        if (can_reuse_ubatch && ubatch.equal_seqs()) {
486
0
            if (!ubatch.data) {
487
                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
488
                //   therefore we cannot perform the sequence id check. this should normally never happen
489
0
                can_reuse_ubatch = false;
490
0
            } else {
491
0
                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
492
0
                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
493
0
                }
494
0
            }
495
0
        }
496
497
0
        if (!can_reuse_ubatch) {
498
0
            return false;
499
0
        }
500
501
0
        if (n_outputs != other.n_outputs) {
502
0
            return false;
503
0
        }
504
505
0
        if (!samplers_equal(samplers, other.samplers)) {
506
0
            return false;
507
0
        }
508
509
0
        if (samplers.size() > 0) {
510
0
            if (!ubatch.data || !other.ubatch.data) {
511
0
                return false;
512
0
            }
513
514
            // check that the outputs are the same for all samplers
515
0
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
516
0
                if (ubatch.output[i]    != other.ubatch.output[i] ||
517
0
                    ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
518
0
                    return false;
519
0
                }
520
0
            }
521
0
        }
522
523
0
        return
524
0
            cparams.embeddings  == other.cparams.embeddings  &&
525
0
            cparams.causal_attn == other.cparams.causal_attn &&
526
0
            arch  == other.arch  &&
527
0
            gtype == other.gtype &&
528
0
            cvec  == other.cvec  &&
529
0
            loras == other.loras &&
530
0
            cross == other.cross;
531
0
    }
532
};
533
534
class llm_graph_result {
535
public:
536
    llm_graph_result(int64_t max_nodes);
537
538
0
    virtual ~llm_graph_result() = default;
539
540
0
    ggml_tensor * get_tokens()      const { return t_tokens; }
541
0
    ggml_tensor * get_logits()      const { return t_logits; }
542
0
    ggml_tensor * get_embd()        const { return t_embd; }
543
0
    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
544
545
0
    ggml_cgraph  * get_gf()  const { return gf; }
546
0
    ggml_context * get_ctx() const { return ctx_compute.get(); }
547
548
    int64_t get_max_nodes() const;
549
550
    void reset();
551
552
    void set_inputs(const llama_ubatch * ubatch);
553
    void set_outputs();
554
555
    // try to update the existing graph result using the new graph parameters in order to reuse it
556
    // this can only be done if we determine that the resulting graph using the new graph parameters
557
    //   would be identical to the existing graph. in that case, we simply have to update the memory
558
    //   contexts of the input tensors of the graph and we can reuse it for another computation
559
    // return true if the graph was updated and can be reused
560
    bool can_reuse(const llm_graph_params & params);
561
562
    llm_graph_input_i * add_input(llm_graph_input_ptr input);
563
564
    void set_params(const llm_graph_params & params);
565
566
    // important graph nodes
567
    ggml_tensor * t_tokens      = nullptr;
568
    ggml_tensor * t_logits      = nullptr;
569
    ggml_tensor * t_embd        = nullptr;
570
    ggml_tensor * t_embd_pooled = nullptr;
571
572
    std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
573
    std::map<llama_seq_id, ggml_tensor*> t_candidates;
574
    std::map<llama_seq_id, ggml_tensor*> t_sampled;
575
    std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
576
577
    std::vector<llm_graph_input_ptr> inputs;
578
579
    ggml_context_ptr ctx_compute;
580
581
    // memory buffers used to evaluate the model
582
    std::vector<uint8_t> buf_compute_meta;
583
584
    ggml_cgraph * gf;
585
586
    int64_t max_nodes;
587
588
private:
589
    // keep a copy of the previous graph parameters
590
    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
591
    // note: these are updated after constructing the new graph
592
    llm_graph_params params;
593
594
    // env: LLAMA_GRAPH_RESULT_DEBUG
595
    int debug = 0;
596
};
597
598
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
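
A hedged sketch of the reuse flow described above (the wrapper function and its name are assumptions): the caller asks the stored result whether it can serve the new parameters and, if so, only refreshes the inputs.

#include "llama-graph.h"

// illustrative only: reuse the previous graph result when the topology is unchanged
static bool llm_graph_try_reuse_sketch(
        llm_graph_result * res,
        const llm_graph_params & params,
        const llama_ubatch & ubatch) {
    if (res->can_reuse(params)) {
        res->set_inputs(&ubatch); // same topology: just repopulate the input tensors
        return true;
    }
    // different topology: the caller has to res->reset() and build a new graph
    return false;
}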
599
600
//
601
// llm_graph_context
602
//
603
604
// used in build_rs to properly order writes and avoid unnecessary copies
605
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
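
A hedged example of this hook (hypothetical, not a hook used in the repository): any callable with this signature can replace ggml_get_rows when calling build_rs, for instance to post-process the gathered states.

#include "ggml.h"
#include "llama-graph.h"

// hypothetical custom gather: fetch the selected recurrent states, then apply an extra op
llm_graph_get_rows_fn get_state_rows_sketch = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
    ggml_tensor * rows = ggml_get_rows(ctx, states, ids);
    return ggml_scale(ctx, rows, 1.0f); // placeholder op; a real hook would do something model-specific
};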
606
607
struct llm_graph_context {
608
    const llm_arch arch;
609
610
    const llama_hparams & hparams;
611
    const llama_cparams & cparams;
612
    const llama_ubatch  & ubatch;
613
614
    const int64_t n_embd;
615
    const int64_t n_layer;
616
    const int64_t n_rot;
617
    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
618
    const int64_t n_head;
619
    const int64_t n_head_kv;
620
    const int64_t n_embd_head_k;
621
    const int64_t n_embd_k_gqa;
622
    const int64_t n_embd_head_v;
623
    const int64_t n_embd_v_gqa;
624
    const int64_t n_expert;
625
    const int64_t n_expert_used;
626
627
    const float freq_base;
628
    const float freq_scale;
629
    const float ext_factor;
630
    const float attn_factor;
631
    const float beta_fast;
632
    const float beta_slow;
633
    const float norm_eps;
634
    const float norm_rms_eps;
635
636
    const int64_t n_tokens;
637
    const int64_t n_outputs;
638
    const int32_t n_ctx_orig; // yarn
639
640
    const enum llama_pooling_type pooling_type;
641
    const enum llama_rope_type    rope_type;
642
643
    ggml_backend_sched_t sched;
644
645
    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
646
647
    const llama_adapter_cvec     * cvec;
648
    const llama_adapter_loras    * loras;
649
    const llama_memory_context_i * mctx;
650
    const llama_cross            * cross;
651
652
    std::map<llama_seq_id, llama_sampler *> samplers;
653
654
    const llm_graph_cb & cb_func;
655
656
    llm_graph_result * res;
657
658
    ggml_context * ctx0 = nullptr;
659
    ggml_cgraph  * gf   = nullptr;
660
661
    llm_graph_context(const llm_graph_params & params);
662
0
    virtual ~llm_graph_context() = default;
663
664
    void cb(ggml_tensor * cur, const char * name, int il) const;
665
666
    //
667
    // common
668
    //
669
670
    ggml_tensor * build_cvec(
671
             ggml_tensor * cur,
672
                     int   il) const;
673
674
    // do mat_mul, while optionally applying lora
675
    ggml_tensor * build_lora_mm(
676
              ggml_tensor * w,
677
              ggml_tensor * cur) const;
678
679
    // do mat_mul_id, while optionally applying lora
680
    ggml_tensor * build_lora_mm_id(
681
              ggml_tensor * w,   // ggml_tensor * as
682
              ggml_tensor * cur, // ggml_tensor * b
683
              ggml_tensor * ids) const;
684
685
    ggml_tensor * build_norm(
686
             ggml_tensor * cur,
687
             ggml_tensor * mw,
688
             ggml_tensor * mb,
689
           llm_norm_type   type,
690
                     int   il) const;
691
692
    ggml_tensor * build_ffn(
693
             ggml_tensor * cur,
694
             ggml_tensor * up,
695
             ggml_tensor * up_b,
696
             ggml_tensor * up_s,
697
             ggml_tensor * gate,
698
             ggml_tensor * gate_b,
699
             ggml_tensor * gate_s,
700
             ggml_tensor * down,
701
             ggml_tensor * down_b,
702
             ggml_tensor * down_s,
703
             ggml_tensor * act_scales,
704
         llm_ffn_op_type   type_op,
705
       llm_ffn_gate_type   type_gate,
706
                     int   il) const;
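
    // example call (illustrative, not tied to a specific model): a SwiGLU FFN with a
    //   parallel gate, passing nullptr for the optional bias/scale tensors:
    //
    //     cur = build_ffn(cur,
    //             ffn_up,   nullptr, nullptr,
    //             ffn_gate, nullptr, nullptr,
    //             ffn_down, nullptr, nullptr,
    //             nullptr, LLM_FFN_SILU, LLM_FFN_PAR, il);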
707
708
    // build MoE FFN without bias tensors
709
    ggml_tensor * build_moe_ffn(
710
             ggml_tensor * cur,
711
             ggml_tensor * gate_inp,
712
             ggml_tensor * up_exps,
713
             ggml_tensor * gate_exps,
714
             ggml_tensor * down_exps,
715
             ggml_tensor * exp_probs_b,
716
                 int64_t   n_expert,
717
                 int64_t   n_expert_used,
718
         llm_ffn_op_type   type_op,
719
                    bool   norm_w,
720
                    bool   scale_w,
721
                   float   w_scale,
722
            llama_expert_gating_func_type gating_op,
723
                     int   il,
724
             ggml_tensor * probs_in = nullptr) const;
725
726
    ggml_tensor * build_moe_ffn(
727
             ggml_tensor * cur,
728
             ggml_tensor * gate_inp,
729
             ggml_tensor * gate_inp_b,
730
             ggml_tensor * up_exps,
731
             ggml_tensor * up_exps_b,
732
             ggml_tensor * gate_exps,
733
             ggml_tensor * gate_exps_b,
734
             ggml_tensor * down_exps,
735
             ggml_tensor * down_exps_b,
736
             ggml_tensor * exp_probs_b,
737
                 int64_t   n_expert,
738
                 int64_t   n_expert_used,
739
         llm_ffn_op_type   type_op,
740
                    bool   norm_w,
741
                    bool   scale_w,
742
                   float   w_scale,
743
            llama_expert_gating_func_type gating_op,
744
                     int   il,
745
             ggml_tensor * probs_in = nullptr) const;
746
747
    //
748
    // inputs
749
    //
750
751
    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
752
    ggml_tensor * build_inp_pos() const;
753
    ggml_tensor * build_inp_attn_scale() const;
754
    ggml_tensor * build_inp_out_ids() const;
755
    ggml_tensor * build_inp_mean() const;
756
    ggml_tensor * build_inp_cls() const;
757
758
    ggml_tensor * build_inp_cross_embd() const;
759
    ggml_tensor * build_inp_pos_bucket_enc() const;
760
    ggml_tensor * build_inp_pos_bucket_dec() const;
761
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
762
763
    //
764
    // attention
765
    //
766
767
    ggml_tensor * build_attn_mha(
768
            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
769
            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
770
            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
771
            ggml_tensor * kq_b,
772
            ggml_tensor * kq_mask,
773
            ggml_tensor * sinks,   // [n_head_q]
774
            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
775
                  float   kq_scale,
776
                    int   il) const;
777
778
    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
779
780
    ggml_tensor * build_attn(
781
            llm_graph_input_attn_no_cache * inp,
782
            ggml_tensor * wo,
783
            ggml_tensor * wo_b,
784
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
785
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
786
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
787
            ggml_tensor * kq_b,
788
            ggml_tensor * sinks, // [n_head_q]
789
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
790
                  float   kq_scale,
791
                    int   il) const;
792
793
    llm_graph_input_attn_kv * build_attn_inp_kv() const;
794
795
    ggml_tensor * build_attn(
796
            llm_graph_input_attn_kv * inp,
797
            ggml_tensor * wo,
798
            ggml_tensor * wo_b,
799
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
800
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
801
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
802
            ggml_tensor * kq_b,
803
            ggml_tensor * sinks, // [n_head_q]
804
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
805
                  float   kq_scale,
806
                    int   il) const;
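
    // example call (illustrative): self-attention against the KV cache, no attention bias,
    //   no sinks, no MLA projection, with the usual 1/sqrt(head_dim) scaling:
    //
    //     cur = build_attn(inp_attn, wo, nullptr,
    //             Qcur, Kcur, Vcur,
    //             nullptr, nullptr, nullptr,
    //             1.0f/sqrtf(float(n_embd_head_k)), il);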
807
808
    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
809
810
    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
811
    ggml_tensor * build_attn(
812
            llm_graph_input_attn_kv_iswa * inp,
813
            ggml_tensor * wo,
814
            ggml_tensor * wo_b,
815
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
816
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
817
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
818
            ggml_tensor * kq_b,
819
            ggml_tensor * sinks, // [n_head_q]
820
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
821
                  float   kq_scale,
822
                    int   il) const;
823
824
    llm_graph_input_attn_cross * build_attn_inp_cross() const;
825
826
    ggml_tensor * build_attn(
827
            llm_graph_input_attn_cross * inp,
828
            ggml_tensor * wo,
829
            ggml_tensor * wo_b,
830
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
831
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
832
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
833
            ggml_tensor * kq_b,
834
            ggml_tensor * sinks, // [n_head_q]
835
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
836
                  float   kq_scale,
837
                    int   il) const;
838
839
    //
840
    // recurrent
841
    //
842
843
    // TODO: move this implementation to llama_memory_recurrent.
844
    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
845
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
846
    //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
847
    //         `llama_memory_recurrent`
848
    ggml_tensor * build_rs(
849
            ggml_tensor * s,
850
            ggml_tensor * state_copy_main,
851
            ggml_tensor * state_copy_extra,
852
                int32_t   state_size,
853
                int32_t   n_seqs,
854
               uint32_t   n_rs,
855
               uint32_t   rs_head,
856
               uint32_t   rs_size,
857
                int32_t   rs_zero,
858
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
859
860
    llm_graph_input_rs * build_rs_inp() const;
861
862
    ggml_tensor * build_rs(
863
            llm_graph_input_rs * inp,
864
            ggml_tensor * s,
865
                int32_t   state_size,
866
                int32_t   n_seqs,
867
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
868
869
    ggml_tensor * build_rwkv_token_shift_load(
870
        llm_graph_input_rs * inp,
871
        const llama_ubatch & ubatch,
872
                       int   il) const;
873
874
    ggml_tensor * build_rwkv_token_shift_store(
875
             ggml_tensor * token_shift,
876
      const llama_ubatch & ubatch,
877
                     int   il) const;
878
    //
879
    // hybrid
880
    //
881
882
    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
883
884
    //
885
    // pooling
886
    //
887
888
    void build_pooling(
889
            ggml_tensor * cls,
890
            ggml_tensor * cls_b,
891
            ggml_tensor * cls_out,
892
            ggml_tensor * cls_out_b) const;
893
894
    //
895
    // sampling (backend sampling)
896
    //
897
898
    void build_sampling() const;
899
900
    //
901
    // dense (out)
902
    //
903
904
    void build_dense_out(
905
            ggml_tensor * dense_2,
906
            ggml_tensor * dense_3) const;
907
};
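
Concrete model builders derive from the struct above; as a rough, hypothetical sketch of the pattern (the builder name and the tensors passed in are assumptions, not a real model in the repository), a minimal builder could look like:

#include "ggml.h"
#include "llama-graph.h"

// hypothetical builder: embed the tokens, apply a final RMS norm, expose the embeddings
struct llm_build_sketch : public llm_graph_context {
    llm_build_sketch(const llm_graph_params & params, ggml_tensor * tok_embd, ggml_tensor * output_norm)
        : llm_graph_context(params) {
        ggml_tensor * cur = build_inp_embd(tok_embd);
        cur = build_norm(cur, output_norm, nullptr, LLM_NORM_RMS, -1);
        cb(cur, "result_norm", -1);

        res->t_embd = cur;                              // picked up later by llama_context
        ggml_build_forward_expand(res->get_gf(), cur);  // make sure the node ends up in the graph
    }
};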
908
909
// TODO: better name
910
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
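
For reference, a hedged sketch of the standard T5-style bucketing this helper is presumably modeled on (an assumption, not a claim about the exact implementation; a fixed max_distance of 128 is also assumed):

#include <algorithm>
#include <cmath>
#include <cstdint>

// sketch of T5-style relative position bucketing (assumed behavior)
static int32_t relative_position_bucket_sketch(int64_t x, int64_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128; // assumed constant

    int64_t rel    = x - y;
    int64_t bucket = 0;

    if (bidirectional) {
        n_buckets /= 2;                           // half of the buckets for each direction
        bucket += (rel > 0) ? (int64_t) n_buckets : 0;
        rel = rel < 0 ? -rel : rel;
    } else {
        rel = -std::min<int64_t>(rel, 0);         // causal: only positions behind the query count
    }

    const int64_t max_exact = (int64_t) n_buckets / 2; // the first half of the buckets is exact
    if (rel < max_exact) {
        bucket += rel;
    } else {
        // the remaining buckets cover distances up to max_distance on a log scale
        int64_t large = max_exact + (int64_t) (std::log((double) rel / max_exact) /
                                               std::log((double) max_distance / max_exact) *
                                               (n_buckets - max_exact));
        bucket += std::min<int64_t>(large, (int64_t) n_buckets - 1);
    }

    return (int32_t) bucket;
}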