Coverage Report

Created: 2026-04-12 06:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-graph.h
Line
Count
Source
1
#pragma once
2
3
#include "llama-arch.h"
4
#include "llama-batch.h"
5
#include "llama-hparams.h"
6
#include "llama-adapter.h"
7
8
#include <cstdint>
9
#include <vector>
10
#include <memory>
11
#include <set>
12
#include <functional>
13
#include <map>
14
15
struct ggml_cgraph;
16
struct ggml_context;
17
struct ggml_tensor;
18
19
struct llama_cparams;
20
21
struct llama_memory_context_i;
22
23
class llama_kv_cache_context;
24
class llama_kv_cache_iswa_context;
25
class llama_memory_recurrent_context;
26
class llama_memory_hybrid_context;
27
class llama_memory_hybrid_iswa_context;
28
29
// certain models (typically multi-modal) can produce different types of graphs
30
enum llm_graph_type {
31
    LLM_GRAPH_TYPE_DEFAULT,
32
    LLM_GRAPH_TYPE_ENCODER,
33
    LLM_GRAPH_TYPE_DECODER,
34
};
35
36
enum llm_ffn_op_type {
37
    LLM_FFN_SILU,
38
    LLM_FFN_GELU,
39
    LLM_FFN_RELU,
40
    LLM_FFN_RELU_SQR,
41
    LLM_FFN_SWIGLU,
42
    LLM_FFN_GEGLU,
43
    LLM_FFN_REGLU,
44
    LLM_FFN_SWIGLU_OAI_MOE,
45
};
46
47
enum llm_ffn_gate_type {
48
    LLM_FFN_SEQ,
49
    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
50
};
51
52
enum llm_norm_type {
53
    LLM_NORM,
54
    LLM_NORM_RMS,
55
    LLM_NORM_GROUP,
56
};
57
58
// TODO: tmp - need something better to pass the data from the encoder to the decoder
59
struct llama_cross {
60
    // the output embeddings from the encoder as a ggml tensor
61
    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
62
    //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
63
    //ggml_tensor * t_embd = nullptr;
64
65
    int64_t n_embd = 0;
66
    int64_t n_enc  = 0;
67
68
    // embeddings data copied to host memory (tmp)
69
    std::vector<float> v_embd;
70
71
    // needed to construct the cross-attention mask in the decoder
72
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
73
};
74
75
struct llm_graph_params;
76
77
//
78
// llm_graph_input
79
//
80
81
class llm_graph_input_i {
82
public:
83
0
    llm_graph_input_i() {
84
0
        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
85
0
        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
86
0
    }
87
88
0
    virtual ~llm_graph_input_i() = default;
89
90
    virtual void set_input(const llama_ubatch * ubatch) = 0;
91
92
    // return true if the resulting input tensors using the provided graph parameters would be
93
    //   the same as the previous input tensors that we have currently stored in the object
94
0
    virtual bool can_reuse(const llm_graph_params & params) {
95
        // returning false here by default will prevent from reusing the graph if the check
96
        //   for the input type has not been implemented yet
97
0
        GGML_UNUSED(params);
98
0
        return false;
99
0
    }
100
protected:
101
    // env: LLAMA_GRAPH_INPUT_DEBUG
102
    int debug = 0;
103
};
104
105
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
106
107
class llm_graph_input_embd : public llm_graph_input_i {
108
public:
109
0
    llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
110
    virtual ~llm_graph_input_embd() = default;
111
112
    void set_input(const llama_ubatch * ubatch) override;
113
114
    bool can_reuse(const llm_graph_params & params) override;
115
116
    ggml_tensor * tokens = nullptr; // I32 [n_batch]
117
    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
118
119
    const int64_t n_embd = 0;
120
};
121
122
class llm_graph_input_pos : public llm_graph_input_i {
123
public:
124
0
    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
125
    virtual ~llm_graph_input_pos() = default;
126
127
    void set_input(const llama_ubatch * ubatch) override;
128
129
    bool can_reuse(const llm_graph_params & params) override;
130
131
    ggml_tensor * pos = nullptr; // I32 [n_batch]
132
133
    const uint32_t n_pos_per_embd = 1;
134
};
135
136
// temperature tuning, used by llama4
137
class llm_graph_input_attn_temp : public llm_graph_input_i {
138
public:
139
    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
140
0
        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
141
    virtual ~llm_graph_input_attn_temp() = default;
142
143
    void set_input(const llama_ubatch * ubatch) override;
144
145
    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
146
147
    const uint32_t n_attn_temp_floor_scale;
148
    const float    f_attn_temp_scale;
149
    const float    f_attn_temp_offset;
150
};
151
152
class llm_graph_input_pos_bucket : public llm_graph_input_i {
153
public:
154
0
    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
155
    virtual ~llm_graph_input_pos_bucket() = default;
156
157
    void set_input(const llama_ubatch * ubatch) override;
158
159
    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
160
161
    const llama_hparams hparams;
162
};
163
164
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
165
public:
166
    llm_graph_input_pos_bucket_kv(
167
            const llama_hparams & hparams,
168
0
            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
169
    virtual ~llm_graph_input_pos_bucket_kv() = default;
170
171
    void set_input(const llama_ubatch * ubatch) override;
172
173
    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
174
175
    const llama_hparams hparams;
176
177
    const llama_kv_cache_context * mctx;
178
};
179
180
class llm_graph_input_out_ids : public llm_graph_input_i {
181
public:
182
    llm_graph_input_out_ids(
183
            const llama_hparams & hparams,
184
            const llama_cparams & cparams,
185
0
            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
186
    virtual ~llm_graph_input_out_ids() = default;
187
188
    void set_input(const llama_ubatch * ubatch) override;
189
190
    bool can_reuse(const llm_graph_params & params) override;
191
192
    ggml_tensor * out_ids; // I32 [n_outputs]
193
194
    const llama_hparams hparams;
195
    const llama_cparams cparams;
196
197
    const uint32_t n_outputs;
198
};
199
200
class llm_graph_input_mean : public llm_graph_input_i {
201
public:
202
0
    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
203
    virtual ~llm_graph_input_mean() = default;
204
205
    void set_input(const llama_ubatch * ubatch) override;
206
207
    ggml_tensor * mean; // F32 [n_batch, n_batch]
208
209
    const llama_cparams cparams;
210
};
211
212
class llm_graph_input_cls : public llm_graph_input_i {
213
public:
214
0
    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
215
    virtual ~llm_graph_input_cls() = default;
216
217
    void set_input(const llama_ubatch * ubatch) override;
218
219
    ggml_tensor * cls; // I32 [n_batch]
220
221
    const llama_cparams cparams;
222
    const llm_arch arch;
223
};
224
225
class llm_graph_input_rs : public llm_graph_input_i {
226
public:
227
0
    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
228
    virtual ~llm_graph_input_rs() = default;
229
230
    void set_input(const llama_ubatch * ubatch) override;
231
232
    bool can_reuse(const llm_graph_params & params) override;
233
234
    ggml_tensor * s_copy;  // I32 [n_rs]
235
236
    // views of s_copy, computed once per graph
237
    // and shared across layers which use build_rs
238
    ggml_tensor * s_copy_main;   // I32 [n_seqs]
239
    ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
240
241
    const llama_memory_recurrent_context * mctx;
242
243
    // used in view offsets, need to match for valid graph reuse
244
    uint32_t head;
245
    int32_t rs_z;
246
};
247
248
class llm_graph_input_cross_embd : public llm_graph_input_i {
249
public:
250
    llm_graph_input_cross_embd(
251
0
            const llama_cross * cross) : cross(cross) {}
252
    virtual ~llm_graph_input_cross_embd() = default;
253
254
    void set_input(const llama_ubatch * ubatch) override;
255
256
    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
257
258
    const llama_cross * cross;
259
};
260
261
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
262
public:
263
    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
264
0
        hparams(hparams),
265
0
        cparams(cparams) {
266
0
    }
267
    ~llm_graph_input_attn_no_cache() = default;
268
269
    void set_input(const llama_ubatch * ubatch) override;
270
271
0
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
272
0
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
273
274
    // n_tokens == n_batch
275
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
276
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
277
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
278
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
279
280
    const llama_hparams hparams;
281
    const llama_cparams cparams;
282
};
283
284
class llm_graph_input_attn_kv : public llm_graph_input_i {
285
public:
286
    llm_graph_input_attn_kv(
287
            const llama_hparams & hparams,
288
            const llama_cparams & cparams,
289
            const llama_kv_cache_context * mctx) :
290
0
        hparams(hparams),
291
0
        cparams(cparams),
292
0
        mctx(mctx) {
293
0
    }
294
    ~llm_graph_input_attn_kv() = default;
295
296
    void set_input(const llama_ubatch * ubatch) override;
297
298
    bool can_reuse(const llm_graph_params & params) override;
299
300
0
    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
301
0
    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
302
303
0
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
304
305
    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
306
    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
307
308
    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
309
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
310
311
    // note: assumes v_rot^2 == I
312
    ggml_tensor * self_k_rot = nullptr;
313
    ggml_tensor * self_v_rot = nullptr;
314
315
    // note: these have to be copies because in order to be able to reuse a graph, its inputs
316
    //       need to carry these parameters with them. otherwise, they can point to freed
317
    //       llm_graph_params from a previous batch, causing stack-use-after-return
318
    const llama_hparams hparams;
319
    const llama_cparams cparams;
320
321
    const llama_kv_cache_context * mctx;
322
};
323
324
// V-less input for the KV cache
325
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
326
class llm_graph_input_attn_k : public llm_graph_input_i {
327
public:
328
    llm_graph_input_attn_k(
329
            const llama_hparams & hparams,
330
            const llama_cparams & cparams,
331
            const llama_kv_cache_context * mctx) :
332
0
        hparams(hparams),
333
0
        cparams(cparams),
334
0
        mctx(mctx) {
335
0
    }
336
    ~llm_graph_input_attn_k() = default;
337
338
    void set_input(const llama_ubatch * ubatch) override;
339
340
    bool can_reuse(const llm_graph_params & params) override;
341
342
0
    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
343
344
0
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
345
346
    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
347
348
    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
349
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
350
351
    const llama_hparams hparams;
352
    const llama_cparams cparams;
353
354
    const llama_kv_cache_context * mctx;
355
};
356
357
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
358
public:
359
    llm_graph_input_attn_kv_iswa(
360
            const llama_hparams & hparams,
361
            const llama_cparams & cparams,
362
            const llama_kv_cache_iswa_context * mctx) :
363
0
        hparams(hparams),
364
0
        cparams(cparams),
365
0
        mctx(mctx) {
366
0
    }
367
    ~llm_graph_input_attn_kv_iswa() = default;
368
369
    void set_input(const llama_ubatch * ubatch) override;
370
371
    bool can_reuse(const llm_graph_params & params) override;
372
373
0
    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
374
0
    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
375
0
    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
376
0
    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
377
378
0
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
379
0
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
380
381
    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
382
    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
383
    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
384
    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
385
386
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
387
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
388
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
389
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
390
391
    ggml_tensor * self_k_rot = nullptr;
392
    ggml_tensor * self_v_rot = nullptr;
393
394
    ggml_tensor * self_k_rot_swa = nullptr;
395
    ggml_tensor * self_v_rot_swa = nullptr;
396
397
    const llama_hparams hparams;
398
    const llama_cparams cparams;
399
400
    const llama_kv_cache_iswa_context * mctx;
401
};
402
403
class llm_graph_input_attn_cross : public llm_graph_input_i {
404
public:
405
0
    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
406
    ~llm_graph_input_attn_cross() = default;
407
408
    void set_input(const llama_ubatch * ubatch) override;
409
410
0
    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
411
412
    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
413
    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
414
415
    const llama_cross * cross = nullptr;
416
};
417
418
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
419
public:
420
    llm_graph_input_mem_hybrid(
421
            const llama_cparams & cparams,
422
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
423
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
424
            const llama_memory_hybrid_context *      mctx) :
425
0
        inp_attn(std::move(inp_attn)),
426
0
        inp_rs(std::move(inp_rs)),
427
0
        cparams(cparams),
428
0
        mctx(mctx) { }
429
0
    virtual ~llm_graph_input_mem_hybrid() = default;
430
431
    void set_input(const llama_ubatch * ubatch) override;
432
433
    bool can_reuse(const llm_graph_params & params) override;
434
435
    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
436
    std::unique_ptr<llm_graph_input_rs>      inp_rs;
437
438
0
    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
439
0
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
440
441
    const llama_cparams cparams;
442
443
    const llama_memory_hybrid_context * mctx;
444
};
445
446
class llm_graph_input_mem_hybrid_k : public llm_graph_input_i {
447
public:
448
    llm_graph_input_mem_hybrid_k(
449
            const llama_cparams & cparams,
450
            std::unique_ptr<llm_graph_input_attn_k> inp_attn,
451
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
452
            const llama_memory_hybrid_context *      mctx) :
453
0
        inp_attn(std::move(inp_attn)),
454
0
        inp_rs(std::move(inp_rs)),
455
0
        cparams(cparams),
456
0
        mctx(mctx) { }
457
0
    virtual ~llm_graph_input_mem_hybrid_k() = default;
458
459
    void set_input(const llama_ubatch * ubatch) override;
460
461
    bool can_reuse(const llm_graph_params & params) override;
462
463
    std::unique_ptr<llm_graph_input_attn_k> inp_attn;
464
    std::unique_ptr<llm_graph_input_rs>      inp_rs;
465
466
0
    llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); }
467
0
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
468
469
    const llama_cparams cparams;
470
471
    const llama_memory_hybrid_context * mctx;
472
};
473
474
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
475
public:
476
    llm_graph_input_mem_hybrid_iswa(
477
            const llama_cparams & cparams,
478
            std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
479
            std::unique_ptr<llm_graph_input_rs>          inp_rs,
480
            const llama_memory_hybrid_iswa_context *     mctx) :
481
0
        inp_attn(std::move(inp_attn)),
482
0
        inp_rs(std::move(inp_rs)),
483
0
        cparams(cparams),
484
0
        mctx(mctx) { }
485
0
    virtual ~llm_graph_input_mem_hybrid_iswa() = default;
486
487
    void set_input(const llama_ubatch * ubatch) override;
488
489
    bool can_reuse(const llm_graph_params & params) override;
490
491
    std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
492
    std::unique_ptr<llm_graph_input_rs>          inp_rs;
493
494
0
    llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
495
0
    llm_graph_input_rs           * get_recr() const { return inp_rs.get(); }
496
497
    const llama_cparams cparams;
498
499
    const llama_memory_hybrid_iswa_context * mctx;
500
};
501
502
class llm_graph_input_sampling : public llm_graph_input_i {
503
public:
504
    llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
505
0
        samplers(std::move(samplers)) { }
506
0
    virtual ~llm_graph_input_sampling() = default;
507
508
    void set_input(const llama_ubatch * ubatch) override;
509
    bool can_reuse(const llm_graph_params & params) override;
510
511
    std::map<llama_seq_id, llama_sampler *> samplers;
512
};
513
514
//
515
// llm_graph_result
516
//
517
518
// these objects deliver the result from the graph build process back to the llama_context
519
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
520
//   specific data, by calling the set_inputs() method
521
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
522
//   these are used by the llama_context to extract the relevant data, based on the compute parameters
523
524
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
525
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
526
527
class llm_graph_result;
528
529
struct llm_graph_params {
530
    llm_arch arch = LLM_ARCH_UNKNOWN;
531
532
    llama_hparams hparams;
533
    llama_cparams cparams;
534
535
    llama_ubatch ubatch; // note: intentionally make a copy
536
537
    llm_graph_type gtype;
538
539
    ggml_backend_sched_t sched;
540
    ggml_backend_t backend_cpu;
541
542
    const llama_adapter_cvec     * cvec;
543
    const llama_adapter_loras    * loras;
544
    const llama_memory_context_i * mctx;
545
    const llama_cross            * cross;
546
547
    std::map<llama_seq_id, llama_sampler *> samplers;
548
549
    static bool samplers_equal(
550
          const std::map<llama_seq_id, llama_sampler *> & lhs,
551
0
          const std::map<llama_seq_id, llama_sampler *> & rhs) {
552
0
        if (lhs.size() != rhs.size()) {
553
0
            return false;
554
0
        }
555
0
        for (const auto & [seq_id, sampler] : lhs) {
556
0
            auto it = rhs.find(seq_id);
557
0
            if (it == rhs.end() || it->second != sampler) {
558
0
                return false;
559
0
            }
560
0
        }
561
0
        return true;
562
0
    }
563
564
    uint32_t n_outputs;
565
566
    llm_graph_cb cb;
567
568
    llm_graph_result * res;
569
570
    // return true if the "other" params would result in a graph with the same topology as with the current params
571
    //   having the same topology allows us to reuse the graph in some cases
572
0
    bool allow_reuse(const llm_graph_params & other) const {
573
        // first check the ubatch
574
0
        bool can_reuse_ubatch =
575
0
            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
576
0
            ubatch.n_tokens     == other.ubatch.n_tokens &&
577
0
            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
578
0
            ubatch.n_seqs       == other.ubatch.n_seqs &&
579
0
            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
580
0
            (
581
0
                (!ubatch.token && !other.ubatch.token) ||
582
0
                (!ubatch.embd  && !other.ubatch.embd)
583
0
            );
584
585
        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
586
        //   the reason is because the set of attention streams would be different for different sequences
587
0
        if (can_reuse_ubatch && ubatch.equal_seqs()) {
588
0
            if (!ubatch.data) {
589
                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
590
                //   therefore we cannot perform the sequence id check. normally should never happen
591
0
                can_reuse_ubatch = false;
592
0
            } else {
593
0
                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
594
0
                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
595
0
                }
596
0
            }
597
0
        }
598
599
0
        if (!can_reuse_ubatch) {
600
0
            return false;
601
0
        }
602
603
0
        if (n_outputs != other.n_outputs) {
604
0
            return false;
605
0
        }
606
607
0
        if (!samplers_equal(samplers, other.samplers)) {
608
0
            return false;
609
0
        }
610
611
0
        if (samplers.size() > 0) {
612
0
            if (!ubatch.data || !other.ubatch.data) {
613
0
                return false;
614
0
            }
615
616
            // check that the outputs are the same for all samplers
617
0
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
618
0
                if (ubatch.output[i]    != other.ubatch.output[i] ||
619
0
                    ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
620
0
                    return false;
621
0
                }
622
0
            }
623
0
        }
624
625
0
        return
626
0
            cparams.embeddings  == other.cparams.embeddings  &&
627
0
            cparams.causal_attn == other.cparams.causal_attn &&
628
0
            arch  == other.arch  &&
629
0
            gtype == other.gtype &&
630
0
            cvec  == other.cvec  &&
631
0
            loras == other.loras &&
632
0
            cross == other.cross;
633
0
    }
634
};
635
636
class llm_graph_result {
637
public:
638
    llm_graph_result(int64_t max_nodes);
639
640
0
    virtual ~llm_graph_result() = default;
641
642
0
    ggml_tensor * get_inp_tokens()  const { return t_inp_tokens; }
643
0
    ggml_tensor * get_logits()      const { return t_logits; }
644
0
    ggml_tensor * get_embd()        const { return t_embd; }
645
0
    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
646
647
0
    ggml_cgraph  * get_gf()  const { return gf; }
648
0
    ggml_context * get_ctx() const { return ctx_compute.get(); }
649
650
    int64_t get_max_nodes() const;
651
652
    void reset();
653
654
    void set_inputs(const llama_ubatch * ubatch);
655
    void set_outputs();
656
657
    // try to update the existing graph result using the new graph parameters in order to reuse it
658
    // this can only be done if we determine that the resulting graph using the new graph parameters
659
    //   would be identical to the existing graph. in that case, we simply have to update the memory
660
    //   contexts of the input tensors of the graph and we can reuse it for another computation
661
    // return true if the graph was updated and can be reused
662
    bool can_reuse(const llm_graph_params & params);
663
664
    llm_graph_input_i * add_input(llm_graph_input_ptr input);
665
666
    void set_params(const llm_graph_params & params);
667
668
    // important graph nodes
669
    ggml_tensor * t_inp_tokens  = nullptr;
670
    ggml_tensor * t_inp_embd    = nullptr; // [n_embd_inp, n_tokens]
671
    ggml_tensor * t_logits      = nullptr;
672
    ggml_tensor * t_embd        = nullptr;
673
    ggml_tensor * t_embd_pooled = nullptr;
674
675
    std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
676
    std::map<llama_seq_id, ggml_tensor*> t_candidates;
677
    std::map<llama_seq_id, ggml_tensor*> t_sampled;
678
    std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
679
680
    std::vector<llm_graph_input_ptr> inputs;
681
682
    ggml_context_ptr ctx_compute;
683
684
    // memory buffers used to evaluate the model
685
    std::vector<uint8_t> buf_compute_meta;
686
687
    ggml_cgraph * gf;
688
689
    int64_t max_nodes;
690
691
private:
692
    // keep a copy of the previous graph parameters
693
    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
694
    // note: these are updated after constructing the new graph
695
    llm_graph_params params;
696
697
    // env: LLAMA_GRAPH_RESULT_DEBUG
698
    int debug = 0;
699
};
700
701
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
702
703
//
704
// llm_graph_context
705
//
706
707
// used in build_rs to properly order writes and avoid unnecessary copies
708
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
709
710
// helper used while constructing the compute graph of a model
// holds read-only references to the model/runtime configuration and provides
//   the build_* methods that assemble the common graph components
struct llm_graph_context {
    const llm_arch arch;

    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch  & ubatch;

    // frequently used dimensions, cached for convenience
    const int64_t n_embd;
    const int64_t n_layer;
    const int64_t n_rot;
    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
    const int64_t n_head;
    const int64_t n_head_kv;
    const int64_t n_embd_head_k;
    const int64_t n_embd_k_gqa;
    const int64_t n_embd_head_v;
    const int64_t n_embd_v_gqa;
    const int64_t n_expert;
    const int64_t n_expert_used;

    // RoPE and normalization parameters
    const float freq_base;
    const float freq_scale;
    const float ext_factor;
    const float attn_factor;
    const float beta_fast;
    const float beta_slow;
    const float norm_eps;
    const float norm_rms_eps;

    const int64_t n_tokens;
    const int64_t n_outputs;
    const int32_t n_ctx_orig; // yarn

    const enum llama_pooling_type pooling_type;
    const enum llama_rope_type    rope_type;

    ggml_backend_sched_t sched;

    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

    // non-owning pointers to external state consulted during graph construction
    const llama_adapter_cvec     * cvec;
    const llama_adapter_loras    * loras;
    const llama_memory_context_i * mctx;
    const llama_cross            * cross;

    // per-sequence samplers (see build_sampling)
    std::map<llama_seq_id, llama_sampler *> samplers;

    // callback invoked via cb() for the created tensors
    const llm_graph_cb & cb_func;

    // the result object that accumulates the graph inputs/outputs (non-owning)
    llm_graph_result * res;

    ggml_context * ctx0 = nullptr;
    ggml_cgraph  * gf   = nullptr;

    llm_graph_context(const llm_graph_params & params);
    virtual ~llm_graph_context() = default;

    // name the tensor `cur` and invoke the graph callback for it (il is the layer index)
    void cb(ggml_tensor * cur, const char * name, int il) const;

    //
    // common
    //

    // apply the control vector (cvec) for layer il
    ggml_tensor * build_cvec(
             ggml_tensor * cur,
                     int   il) const;

    // do mat_mul, while optionally apply lora and per-tensor scale
    ggml_tensor * build_lora_mm(
              ggml_tensor * w,
              ggml_tensor * cur,
              ggml_tensor * w_s = nullptr) const;

    // do mat_mul_id, while optionally apply lora
    ggml_tensor * build_lora_mm_id(
              ggml_tensor * w,   // ggml_tensor * as
              ggml_tensor * cur, // ggml_tensor * b
              ggml_tensor * ids) const;

    // build a normalization of the given llm_norm_type (mw = weight, mb = bias)
    ggml_tensor * build_norm(
             ggml_tensor * cur,
             ggml_tensor * mw,
             ggml_tensor * mb,
           llm_norm_type   type,
                     int   il) const;

    // build a feed-forward block (up/gate/down projections with optional biases and scales)
    ggml_tensor * build_ffn(
             ggml_tensor * cur,
             ggml_tensor * up,
             ggml_tensor * up_b,
             ggml_tensor * up_s,
             ggml_tensor * gate,
             ggml_tensor * gate_b,
             ggml_tensor * gate_s,
             ggml_tensor * down,
             ggml_tensor * down_b,
             ggml_tensor * down_s,
             ggml_tensor * act_scales,
         llm_ffn_op_type   type_op,
       llm_ffn_gate_type   type_gate,
                     int   il) const;

    // build MoE FFN without bias tensors
    ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
             ggml_tensor * up_exps,
             ggml_tensor * gate_exps,
             ggml_tensor * down_exps,
             ggml_tensor * exp_probs_b,
                 int64_t   n_expert,
                 int64_t   n_expert_used,
         llm_ffn_op_type   type_op,
                    bool   norm_w,
                   float   w_scale,
            llama_expert_gating_func_type gating_op,
                     int   il,
             ggml_tensor * probs_in = nullptr,
             ggml_tensor * gate_up_exps = nullptr,
             ggml_tensor * up_exps_s = nullptr,
             ggml_tensor * gate_exps_s = nullptr,
             ggml_tensor * down_exps_s = nullptr) const;

    // build MoE FFN with per-expert bias tensors
    ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
             ggml_tensor * gate_inp_b,
             ggml_tensor * up_exps,
             ggml_tensor * up_exps_b,
             ggml_tensor * gate_exps,
             ggml_tensor * gate_exps_b,
             ggml_tensor * down_exps,
             ggml_tensor * down_exps_b,
             ggml_tensor * exp_probs_b,
                 int64_t   n_expert,
                 int64_t   n_expert_used,
         llm_ffn_op_type   type_op,
                    bool   norm_w,
                   float   w_scale,
            llama_expert_gating_func_type gating_op,
                     int   il,
             ggml_tensor * probs_in = nullptr,
             ggml_tensor * gate_up_exps = nullptr,
             ggml_tensor * gate_up_exps_b = nullptr,
             ggml_tensor * up_exps_s = nullptr,
             ggml_tensor * gate_exps_s = nullptr,
             ggml_tensor * down_exps_s = nullptr) const;

    //
    // inputs
    //

    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
    ggml_tensor * build_inp_cls() const;

    // encoder/decoder cross-attention inputs
    ggml_tensor * build_inp_cross_embd() const;
    ggml_tensor * build_inp_pos_bucket_enc() const;
    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    //
    // attention
    //

    ggml_tensor * build_attn_mha(
            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
            ggml_tensor * kq_b,
            ggml_tensor * kq_mask,
            ggml_tensor * sinks,   // [n_head_q]
            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

    // attention without KV cache
    ggml_tensor * build_attn(
            llm_graph_input_attn_no_cache * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_kv * build_attn_inp_kv() const;

    // attention with KV cache
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_k  * build_attn_inp_k() const;

    // attention with K-only cache input
    ggml_tensor * build_attn(
            llm_graph_input_attn_k * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_iswa * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_cross * build_attn_inp_cross() const;

    // cross-attention (decoder attending to encoder output)
    ggml_tensor * build_attn(
            llm_graph_input_attn_cross * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    //
    // recurrent
    //

    // TODO: move this implementation to llama_memory_recurrent.
    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    //         `llama_memory_recurrent`
    ggml_tensor * build_rs(
            ggml_tensor * s,
            ggml_tensor * state_copy_main,
            ggml_tensor * state_copy_extra,
                int32_t   state_size,
                int32_t   n_seqs,
               uint32_t   n_rs,
               uint32_t   rs_head,
               uint32_t   rs_size,
                int32_t   rs_zero,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    llm_graph_input_rs * build_rs_inp() const;

    // convenience overload that takes the recurrent-state input built by build_rs_inp()
    ggml_tensor * build_rs(
            llm_graph_input_rs * inp,
            ggml_tensor * s,
                int32_t   state_size,
                int32_t   n_seqs,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    // load/store the RWKV token-shift state for layer il
    ggml_tensor * build_rwkv_token_shift_load(
        llm_graph_input_rs * inp,
        const llama_ubatch & ubatch,
                       int   il) const;

    ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,
      const llama_ubatch & ubatch,
                     int   il) const;
    //
    // hybrid
    //

    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
    llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const;

    llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;

    //
    // pooling
    //

    void build_pooling(
            ggml_tensor * cls,
            ggml_tensor * cls_b,
            ggml_tensor * cls_out,
            ggml_tensor * cls_out_b,
            ggml_tensor * cls_norm) const;

    //
    // sampling (backend sampling)
    //

    void build_sampling() const;

    //
    // dense (out)
    //

    void build_dense_out(
            ggml_tensor * dense_2,
            ggml_tensor * dense_2_b,
            ggml_tensor * dense_3) const;
};
// TODO: better name
// compute the bucket index for the relative position of x and y
// NOTE(review): presumably the T5-style relative-position bucketing used by build_inp_pos_bucket_* - confirm against the implementation
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);