Coverage Report

Created: 2026-03-07 06:35

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-graph.h
Line
Count
Source
1
#pragma once
2
3
#include "llama-arch.h"
4
#include "llama-batch.h"
5
#include "llama-hparams.h"
6
#include "llama-adapter.h"
7
8
#include <cstdint>
9
#include <vector>
10
#include <memory>
11
#include <set>
12
#include <functional>
13
#include <map>
14
15
struct ggml_cgraph;
16
struct ggml_context;
17
struct ggml_tensor;
18
19
struct llama_cparams;
20
21
struct llama_memory_context_i;
22
23
class llama_kv_cache_context;
24
class llama_kv_cache_iswa_context;
25
class llama_memory_recurrent_context;
26
class llama_memory_hybrid_context;
27
class llama_memory_hybrid_iswa_context;
28
29
// certain models (typically multi-modal) can produce different types of graphs
30
enum llm_graph_type {
31
    LLM_GRAPH_TYPE_DEFAULT,
32
    LLM_GRAPH_TYPE_ENCODER,
33
    LLM_GRAPH_TYPE_DECODER,
34
};
35
36
enum llm_ffn_op_type {
37
    LLM_FFN_SILU,
38
    LLM_FFN_GELU,
39
    LLM_FFN_RELU,
40
    LLM_FFN_RELU_SQR,
41
    LLM_FFN_SWIGLU,
42
    LLM_FFN_GEGLU,
43
    LLM_FFN_REGLU,
44
    LLM_FFN_SWIGLU_OAI_MOE,
45
};
46
47
enum llm_ffn_gate_type {
48
    LLM_FFN_SEQ,
49
    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
50
};
51
52
enum llm_norm_type {
53
    LLM_NORM,
54
    LLM_NORM_RMS,
55
    LLM_NORM_GROUP,
56
};
57
58
// TODO: tmp - need something better to pass the data from the encoder to the decoder
59
struct llama_cross {
60
    // the output embeddings from the encoder as a ggml tensor
61
    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
62
    //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
63
    //ggml_tensor * t_embd = nullptr;
64
65
    int64_t n_embd = 0;
66
    int64_t n_enc  = 0;
67
68
    // embeddings data copied to host memory (tmp)
69
    std::vector<float> v_embd;
70
71
    // needed to construct the cross-attention mask in the decoder
72
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
73
};
74
75
struct llm_graph_params;
76
77
//
78
// llm_graph_input
79
//
80
81
class llm_graph_input_i {
82
public:
83
0
    llm_graph_input_i() {
84
0
        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
85
0
        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
86
0
    }
87
88
0
    virtual ~llm_graph_input_i() = default;
89
90
    virtual void set_input(const llama_ubatch * ubatch) = 0;
91
92
    // return true if the resulting input tensors using the provided graph parameters would be
93
    //   the same as the previous input tensors that we have currently stored in the object
94
0
    virtual bool can_reuse(const llm_graph_params & params) {
95
        // returning false here by default will prevent from reusing the graph if the check
96
        //   for the input type has not been implemented yet
97
0
        GGML_UNUSED(params);
98
0
        return false;
99
0
    }
100
protected:
101
    // env: LLAMA_GRAPH_INPUT_DEBUG
102
    int debug = 0;
103
};
104
105
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
106
107
class llm_graph_input_embd : public llm_graph_input_i {
108
public:
109
0
    llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
110
    virtual ~llm_graph_input_embd() = default;
111
112
    void set_input(const llama_ubatch * ubatch) override;
113
114
    bool can_reuse(const llm_graph_params & params) override;
115
116
    ggml_tensor * tokens = nullptr; // I32 [n_batch]
117
    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
118
119
    const int64_t n_embd = 0;
120
};
121
122
class llm_graph_input_pos : public llm_graph_input_i {
123
public:
124
0
    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
125
    virtual ~llm_graph_input_pos() = default;
126
127
    void set_input(const llama_ubatch * ubatch) override;
128
129
    bool can_reuse(const llm_graph_params & params) override;
130
131
    ggml_tensor * pos = nullptr; // I32 [n_batch]
132
133
    const uint32_t n_pos_per_embd = 1;
134
};
135
136
// temperature tuning, used by llama4
137
class llm_graph_input_attn_temp : public llm_graph_input_i {
138
public:
139
    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
140
0
        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
141
    virtual ~llm_graph_input_attn_temp() = default;
142
143
    void set_input(const llama_ubatch * ubatch) override;
144
145
    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
146
147
    const uint32_t n_attn_temp_floor_scale;
148
    const float    f_attn_temp_scale;
149
    const float    f_attn_temp_offset;
150
};
151
152
class llm_graph_input_pos_bucket : public llm_graph_input_i {
153
public:
154
0
    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
155
    virtual ~llm_graph_input_pos_bucket() = default;
156
157
    void set_input(const llama_ubatch * ubatch) override;
158
159
    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
160
161
    const llama_hparams hparams;
162
};
163
164
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
165
public:
166
    llm_graph_input_pos_bucket_kv(
167
            const llama_hparams & hparams,
168
0
            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
169
    virtual ~llm_graph_input_pos_bucket_kv() = default;
170
171
    void set_input(const llama_ubatch * ubatch) override;
172
173
    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
174
175
    const llama_hparams hparams;
176
177
    const llama_kv_cache_context * mctx;
178
};
179
180
class llm_graph_input_out_ids : public llm_graph_input_i {
181
public:
182
    llm_graph_input_out_ids(
183
            const llama_hparams & hparams,
184
            const llama_cparams & cparams,
185
0
            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
186
    virtual ~llm_graph_input_out_ids() = default;
187
188
    void set_input(const llama_ubatch * ubatch) override;
189
190
    bool can_reuse(const llm_graph_params & params) override;
191
192
    ggml_tensor * out_ids; // I32 [n_outputs]
193
194
    const llama_hparams hparams;
195
    const llama_cparams cparams;
196
197
    const uint32_t n_outputs;
198
};
199
200
class llm_graph_input_mean : public llm_graph_input_i {
201
public:
202
0
    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
203
    virtual ~llm_graph_input_mean() = default;
204
205
    void set_input(const llama_ubatch * ubatch) override;
206
207
    ggml_tensor * mean; // F32 [n_batch, n_batch]
208
209
    const llama_cparams cparams;
210
};
211
212
class llm_graph_input_cls : public llm_graph_input_i {
213
public:
214
0
    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
215
    virtual ~llm_graph_input_cls() = default;
216
217
    void set_input(const llama_ubatch * ubatch) override;
218
219
    ggml_tensor * cls; // I32 [n_batch]
220
221
    const llama_cparams cparams;
222
    const llm_arch arch;
223
};
224
225
class llm_graph_input_rs : public llm_graph_input_i {
226
public:
227
0
    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
228
    virtual ~llm_graph_input_rs() = default;
229
230
    void set_input(const llama_ubatch * ubatch) override;
231
232
    bool can_reuse(const llm_graph_params & params) override;
233
234
    ggml_tensor * s_copy;  // I32 [n_rs]
235
236
    // views of s_copy, computed once per graph
237
    // and shared across layers which use build_rs
238
    ggml_tensor * s_copy_main;   // I32 [n_seqs]
239
    ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
240
241
    const llama_memory_recurrent_context * mctx;
242
243
    // used in view offsets, need to match for valid graph reuse
244
    uint32_t head;
245
    int32_t rs_z;
246
};
247
248
class llm_graph_input_cross_embd : public llm_graph_input_i {
249
public:
250
    llm_graph_input_cross_embd(
251
0
            const llama_cross * cross) : cross(cross) {}
252
    virtual ~llm_graph_input_cross_embd() = default;
253
254
    void set_input(const llama_ubatch * ubatch) override;
255
256
    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
257
258
    const llama_cross * cross;
259
};
260
261
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
262
public:
263
    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
264
0
        hparams(hparams),
265
0
        cparams(cparams) {
266
0
    }
267
    ~llm_graph_input_attn_no_cache() = default;
268
269
    void set_input(const llama_ubatch * ubatch) override;
270
271
0
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
272
0
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
273
274
    // n_tokens == n_batch
275
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
276
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
277
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
278
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
279
280
    const llama_hparams hparams;
281
    const llama_cparams cparams;
282
};
283
284
class llm_graph_input_attn_kv : public llm_graph_input_i {
285
public:
286
    llm_graph_input_attn_kv(
287
            const llama_hparams & hparams,
288
            const llama_cparams & cparams,
289
            const llama_kv_cache_context * mctx) :
290
0
        hparams(hparams),
291
0
        cparams(cparams),
292
0
        mctx(mctx) {
293
0
    }
294
    ~llm_graph_input_attn_kv() = default;
295
296
    void set_input(const llama_ubatch * ubatch) override;
297
298
    bool can_reuse(const llm_graph_params & params) override;
299
300
0
    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
301
0
    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
302
303
0
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
304
305
    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
306
    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
307
308
    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
309
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
310
311
    // note: these have to be copies because in order to be able to reuse a graph, its inputs
312
    //       need to carry these parameters with them. otherwise, they can point to freed
313
    //       llm_graph_params from a previous batch, causing stack-use-after-return
314
    const llama_hparams hparams;
315
    const llama_cparams cparams;
316
317
    const llama_kv_cache_context * mctx;
318
};
319
320
// V-less input for the KV cache
321
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
322
class llm_graph_input_attn_k : public llm_graph_input_i {
323
public:
324
    llm_graph_input_attn_k(
325
            const llama_hparams & hparams,
326
            const llama_cparams & cparams,
327
            const llama_kv_cache_context * mctx) :
328
0
        hparams(hparams),
329
0
        cparams(cparams),
330
0
        mctx(mctx) {
331
0
    }
332
    ~llm_graph_input_attn_k() = default;
333
334
    void set_input(const llama_ubatch * ubatch) override;
335
336
    bool can_reuse(const llm_graph_params & params) override;
337
338
0
    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
339
340
0
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
341
342
    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
343
344
    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
345
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
346
347
    const llama_hparams hparams;
348
    const llama_cparams cparams;
349
350
    const llama_kv_cache_context * mctx;
351
};
352
353
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
354
public:
355
    llm_graph_input_attn_kv_iswa(
356
            const llama_hparams & hparams,
357
            const llama_cparams & cparams,
358
            const llama_kv_cache_iswa_context * mctx) :
359
0
        hparams(hparams),
360
0
        cparams(cparams),
361
0
        mctx(mctx) {
362
0
    }
363
    ~llm_graph_input_attn_kv_iswa() = default;
364
365
    void set_input(const llama_ubatch * ubatch) override;
366
367
    bool can_reuse(const llm_graph_params & params) override;
368
369
0
    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
370
0
    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
371
0
    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
372
0
    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
373
374
0
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
375
0
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
376
377
    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
378
    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
379
    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
380
    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
381
382
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
383
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
384
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
385
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
386
387
    const llama_hparams hparams;
388
    const llama_cparams cparams;
389
390
    const llama_kv_cache_iswa_context * mctx;
391
};
392
393
class llm_graph_input_attn_cross : public llm_graph_input_i {
394
public:
395
0
    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
396
    ~llm_graph_input_attn_cross() = default;
397
398
    void set_input(const llama_ubatch * ubatch) override;
399
400
0
    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
401
402
    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
403
    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
404
405
    const llama_cross * cross = nullptr;
406
};
407
408
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
409
public:
410
    llm_graph_input_mem_hybrid(
411
            const llama_cparams & cparams,
412
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
413
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
414
            const llama_memory_hybrid_context *      mctx) :
415
0
        inp_attn(std::move(inp_attn)),
416
0
        inp_rs(std::move(inp_rs)),
417
0
        cparams(cparams),
418
0
        mctx(mctx) { }
419
0
    virtual ~llm_graph_input_mem_hybrid() = default;
420
421
    void set_input(const llama_ubatch * ubatch) override;
422
423
    bool can_reuse(const llm_graph_params & params) override;
424
425
    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
426
    std::unique_ptr<llm_graph_input_rs>      inp_rs;
427
428
0
    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
429
0
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
430
431
    const llama_cparams cparams;
432
433
    const llama_memory_hybrid_context * mctx;
434
};
435
436
class llm_graph_input_mem_hybrid_k : public llm_graph_input_i {
437
public:
438
    llm_graph_input_mem_hybrid_k(
439
            const llama_cparams & cparams,
440
            std::unique_ptr<llm_graph_input_attn_k> inp_attn,
441
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
442
            const llama_memory_hybrid_context *      mctx) :
443
0
        inp_attn(std::move(inp_attn)),
444
0
        inp_rs(std::move(inp_rs)),
445
0
        cparams(cparams),
446
0
        mctx(mctx) { }
447
0
    virtual ~llm_graph_input_mem_hybrid_k() = default;
448
449
    void set_input(const llama_ubatch * ubatch) override;
450
451
    bool can_reuse(const llm_graph_params & params) override;
452
453
    std::unique_ptr<llm_graph_input_attn_k> inp_attn;
454
    std::unique_ptr<llm_graph_input_rs>      inp_rs;
455
456
0
    llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); }
457
0
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
458
459
    const llama_cparams cparams;
460
461
    const llama_memory_hybrid_context * mctx;
462
};
463
464
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
465
public:
466
    llm_graph_input_mem_hybrid_iswa(
467
            const llama_cparams & cparams,
468
            std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
469
            std::unique_ptr<llm_graph_input_rs>          inp_rs,
470
            const llama_memory_hybrid_iswa_context *     mctx) :
471
0
        inp_attn(std::move(inp_attn)),
472
0
        inp_rs(std::move(inp_rs)),
473
0
        cparams(cparams),
474
0
        mctx(mctx) { }
475
0
    virtual ~llm_graph_input_mem_hybrid_iswa() = default;
476
477
    void set_input(const llama_ubatch * ubatch) override;
478
479
    bool can_reuse(const llm_graph_params & params) override;
480
481
    std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
482
    std::unique_ptr<llm_graph_input_rs>          inp_rs;
483
484
0
    llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
485
0
    llm_graph_input_rs           * get_recr() const { return inp_rs.get(); }
486
487
    const llama_cparams cparams;
488
489
    const llama_memory_hybrid_iswa_context * mctx;
490
};
491
492
class llm_graph_input_sampling : public llm_graph_input_i {
493
public:
494
    llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
495
0
        samplers(std::move(samplers)) { }
496
0
    virtual ~llm_graph_input_sampling() = default;
497
498
    void set_input(const llama_ubatch * ubatch) override;
499
    bool can_reuse(const llm_graph_params & params) override;
500
501
    std::map<llama_seq_id, llama_sampler *> samplers;
502
};
503
504
//
505
// llm_graph_result
506
//
507
508
// these objects deliver the result from the graph build process back to the llama_context
509
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
510
//   specific data, by calling the set_inputs() method
511
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
512
//   these are used by the llama_context to extract the relevant data, based on the compute parameters
513
514
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
515
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
516
517
class llm_graph_result;
518
519
struct llm_graph_params {
520
    llm_arch arch = LLM_ARCH_UNKNOWN;
521
522
    llama_hparams hparams;
523
    llama_cparams cparams;
524
525
    llama_ubatch ubatch; // note: intentionally make a copy
526
527
    llm_graph_type gtype;
528
529
    ggml_backend_sched_t sched;
530
    ggml_backend_t backend_cpu;
531
532
    const llama_adapter_cvec     * cvec;
533
    const llama_adapter_loras    * loras;
534
    const llama_memory_context_i * mctx;
535
    const llama_cross            * cross;
536
537
    std::map<llama_seq_id, llama_sampler *> samplers;
538
539
    static bool samplers_equal(
540
          const std::map<llama_seq_id, llama_sampler *> & lhs,
541
0
          const std::map<llama_seq_id, llama_sampler *> & rhs) {
542
0
        if (lhs.size() != rhs.size()) {
543
0
            return false;
544
0
        }
545
0
        for (const auto & [seq_id, sampler] : lhs) {
546
0
            auto it = rhs.find(seq_id);
547
0
            if (it == rhs.end() || it->second != sampler) {
548
0
                return false;
549
0
            }
550
0
        }
551
0
        return true;
552
0
    }
553
554
    uint32_t n_outputs;
555
556
    llm_graph_cb cb;
557
558
    llm_graph_result * res;
559
560
    // return true if the "other" params would result in a graph with the same topology as with the current params
561
    //   having the same topology allows us to reuse the graph in some cases
562
0
    bool allow_reuse(const llm_graph_params & other) const {
563
        // first check the ubatch
564
0
        bool can_reuse_ubatch =
565
0
            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
566
0
            ubatch.n_tokens     == other.ubatch.n_tokens &&
567
0
            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
568
0
            ubatch.n_seqs       == other.ubatch.n_seqs &&
569
0
            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
570
0
            (
571
0
                (!ubatch.token && !other.ubatch.token) ||
572
0
                (!ubatch.embd  && !other.ubatch.embd)
573
0
            );
574
575
        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
576
        //   the reason is because the set of attention streams would be different for different sequences
577
0
        if (can_reuse_ubatch && ubatch.equal_seqs()) {
578
0
            if (!ubatch.data) {
579
                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
580
                //   therefore we cannot perform the sequence id check. normally should never happen
581
0
                can_reuse_ubatch = false;
582
0
            } else {
583
0
                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
584
0
                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
585
0
                }
586
0
            }
587
0
        }
588
589
0
        if (!can_reuse_ubatch) {
590
0
            return false;
591
0
        }
592
593
0
        if (n_outputs != other.n_outputs) {
594
0
            return false;
595
0
        }
596
597
0
        if (!samplers_equal(samplers, other.samplers)) {
598
0
            return false;
599
0
        }
600
601
0
        if (samplers.size() > 0) {
602
0
            if (!ubatch.data || !other.ubatch.data) {
603
0
                return false;
604
0
            }
605
606
            // check that the outputs are the same for all samplers
607
0
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
608
0
                if (ubatch.output[i]    != other.ubatch.output[i] ||
609
0
                    ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
610
0
                    return false;
611
0
                }
612
0
            }
613
0
        }
614
615
0
        return
616
0
            cparams.embeddings  == other.cparams.embeddings  &&
617
0
            cparams.causal_attn == other.cparams.causal_attn &&
618
0
            arch  == other.arch  &&
619
0
            gtype == other.gtype &&
620
0
            cvec  == other.cvec  &&
621
0
            loras == other.loras &&
622
0
            cross == other.cross;
623
0
    }
624
};
625
626
class llm_graph_result {
627
public:
628
    llm_graph_result(int64_t max_nodes);
629
630
0
    virtual ~llm_graph_result() = default;
631
632
0
    ggml_tensor * get_inp_tokens()  const { return t_inp_tokens; }
633
0
    ggml_tensor * get_logits()      const { return t_logits; }
634
0
    ggml_tensor * get_embd()        const { return t_embd; }
635
0
    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
636
637
0
    ggml_cgraph  * get_gf()  const { return gf; }
638
0
    ggml_context * get_ctx() const { return ctx_compute.get(); }
639
640
    int64_t get_max_nodes() const;
641
642
    void reset();
643
644
    void set_inputs(const llama_ubatch * ubatch);
645
    void set_outputs();
646
647
    // try to update the existing graph result using the new graph parameters in order to reuse it
648
    // this can only be done if we determine that the resulting graph using the new graph parameters
649
    //   would be identical to the existing graph. in that case, we simply have to update the memory
650
    //   contexts of the input tensors of the graph and we can reuse it for another computation
651
    // return true if the graph was updated and can be reused
652
    bool can_reuse(const llm_graph_params & params);
653
654
    llm_graph_input_i * add_input(llm_graph_input_ptr input);
655
656
    void set_params(const llm_graph_params & params);
657
658
    // important graph nodes
659
    ggml_tensor * t_inp_tokens  = nullptr;
660
    ggml_tensor * t_inp_embd    = nullptr; // [n_embd_inp, n_tokens]
661
    ggml_tensor * t_logits      = nullptr;
662
    ggml_tensor * t_embd        = nullptr;
663
    ggml_tensor * t_embd_pooled = nullptr;
664
665
    std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
666
    std::map<llama_seq_id, ggml_tensor*> t_candidates;
667
    std::map<llama_seq_id, ggml_tensor*> t_sampled;
668
    std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
669
670
    std::vector<llm_graph_input_ptr> inputs;
671
672
    ggml_context_ptr ctx_compute;
673
674
    // memory buffers used to evaluate the model
675
    std::vector<uint8_t> buf_compute_meta;
676
677
    ggml_cgraph * gf;
678
679
    int64_t max_nodes;
680
681
private:
682
    // keep a copy of the previous graph parameters
683
    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
684
    // note: these are updated after constructing the new graph
685
    llm_graph_params params;
686
687
    // env: LLAMA_GRAPH_RESULT_DEBUG
688
    int debug = 0;
689
};
690
691
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
692
693
//
694
// llm_graph_context
695
//
696
697
// used in build_rs to properly order writes and avoid unnecessary copies
698
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
699
700
struct llm_graph_context {
701
    const llm_arch arch;
702
703
    const llama_hparams & hparams;
704
    const llama_cparams & cparams;
705
    const llama_ubatch  & ubatch;
706
707
    const int64_t n_embd;
708
    const int64_t n_layer;
709
    const int64_t n_rot;
710
    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
711
    const int64_t n_head;
712
    const int64_t n_head_kv;
713
    const int64_t n_embd_head_k;
714
    const int64_t n_embd_k_gqa;
715
    const int64_t n_embd_head_v;
716
    const int64_t n_embd_v_gqa;
717
    const int64_t n_expert;
718
    const int64_t n_expert_used;
719
720
    const float freq_base;
721
    const float freq_scale;
722
    const float ext_factor;
723
    const float attn_factor;
724
    const float beta_fast;
725
    const float beta_slow;
726
    const float norm_eps;
727
    const float norm_rms_eps;
728
729
    const int64_t n_tokens;
730
    const int64_t n_outputs;
731
    const int32_t n_ctx_orig; // yarn
732
733
    const enum llama_pooling_type pooling_type;
734
    const enum llama_rope_type    rope_type;
735
736
    ggml_backend_sched_t sched;
737
738
    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
739
740
    const llama_adapter_cvec     * cvec;
741
    const llama_adapter_loras    * loras;
742
    const llama_memory_context_i * mctx;
743
    const llama_cross            * cross;
744
745
    std::map<llama_seq_id, llama_sampler *> samplers;
746
747
    const llm_graph_cb & cb_func;
748
749
    llm_graph_result * res;
750
751
    ggml_context * ctx0 = nullptr;
752
    ggml_cgraph  * gf   = nullptr;
753
754
    llm_graph_context(const llm_graph_params & params);
755
0
    virtual ~llm_graph_context() = default;
756
757
    void cb(ggml_tensor * cur, const char * name, int il) const;
758
759
    //
760
    // common
761
    //
762
763
    ggml_tensor * build_cvec(
764
             ggml_tensor * cur,
765
                     int   il) const;
766
767
    // do mat_mul, while optionally apply lora
768
    ggml_tensor * build_lora_mm(
769
              ggml_tensor * w,
770
              ggml_tensor * cur) const;
771
772
    // do mat_mul_id, while optionally apply lora
773
    ggml_tensor * build_lora_mm_id(
774
              ggml_tensor * w,   // ggml_tensor * as
775
              ggml_tensor * cur, // ggml_tensor * b
776
              ggml_tensor * ids) const;
777
778
    ggml_tensor * build_norm(
779
             ggml_tensor * cur,
780
             ggml_tensor * mw,
781
             ggml_tensor * mb,
782
           llm_norm_type   type,
783
                     int   il) const;
784
785
    ggml_tensor * build_ffn(
786
             ggml_tensor * cur,
787
             ggml_tensor * up,
788
             ggml_tensor * up_b,
789
             ggml_tensor * up_s,
790
             ggml_tensor * gate,
791
             ggml_tensor * gate_b,
792
             ggml_tensor * gate_s,
793
             ggml_tensor * down,
794
             ggml_tensor * down_b,
795
             ggml_tensor * down_s,
796
             ggml_tensor * act_scales,
797
         llm_ffn_op_type   type_op,
798
       llm_ffn_gate_type   type_gate,
799
                     int   il) const;
800
801
    // build MoE FFN without bias tensors
802
    ggml_tensor * build_moe_ffn(
803
             ggml_tensor * cur,
804
             ggml_tensor * gate_inp,
805
             ggml_tensor * up_exps,
806
             ggml_tensor * gate_exps,
807
             ggml_tensor * down_exps,
808
             ggml_tensor * exp_probs_b,
809
                 int64_t   n_expert,
810
                 int64_t   n_expert_used,
811
         llm_ffn_op_type   type_op,
812
                    bool   norm_w,
813
                    bool   scale_w,
814
                   float   w_scale,
815
            llama_expert_gating_func_type gating_op,
816
                     int   il,
817
             ggml_tensor * probs_in = nullptr,
818
             ggml_tensor * gate_up_exps = nullptr) const;
819
820
    ggml_tensor * build_moe_ffn(
821
             ggml_tensor * cur,
822
             ggml_tensor * gate_inp,
823
             ggml_tensor * gate_inp_b,
824
             ggml_tensor * up_exps,
825
             ggml_tensor * up_exps_b,
826
             ggml_tensor * gate_exps,
827
             ggml_tensor * gate_exps_b,
828
             ggml_tensor * down_exps,
829
             ggml_tensor * down_exps_b,
830
             ggml_tensor * exp_probs_b,
831
                 int64_t   n_expert,
832
                 int64_t   n_expert_used,
833
         llm_ffn_op_type   type_op,
834
                    bool   norm_w,
835
                    bool   scale_w,
836
                   float   w_scale,
837
            llama_expert_gating_func_type gating_op,
838
                     int   il,
839
             ggml_tensor * probs_in = nullptr,
840
             ggml_tensor * gate_up_exps = nullptr,
841
             ggml_tensor * gate_up_exps_b = nullptr) const;
842
843
    //
844
    // inputs
845
    //
846
847
    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
848
    ggml_tensor * build_inp_pos() const;
849
    ggml_tensor * build_inp_attn_scale() const;
850
    ggml_tensor * build_inp_out_ids() const;
851
    ggml_tensor * build_inp_mean() const;
852
    ggml_tensor * build_inp_cls() const;
853
854
    ggml_tensor * build_inp_cross_embd() const;
855
    ggml_tensor * build_inp_pos_bucket_enc() const;
856
    ggml_tensor * build_inp_pos_bucket_dec() const;
857
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
858
859
    //
860
    // attention
861
    //
862
863
    ggml_tensor * build_attn_mha(
864
            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
865
            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
866
            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
867
            ggml_tensor * kq_b,
868
            ggml_tensor * kq_mask,
869
            ggml_tensor * sinks,   // [n_head_q]
870
            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
871
                  float   kq_scale,
872
                    int   il) const;
873
874
    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
875
876
    ggml_tensor * build_attn(
877
            llm_graph_input_attn_no_cache * inp,
878
            ggml_tensor * wo,
879
            ggml_tensor * wo_b,
880
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
881
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
882
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
883
            ggml_tensor * kq_b,
884
            ggml_tensor * sinks, // [n_head_q]
885
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
886
                  float   kq_scale,
887
                    int   il) const;
888
889
    llm_graph_input_attn_kv * build_attn_inp_kv() const;
890
891
    ggml_tensor * build_attn(
892
            llm_graph_input_attn_kv * inp,
893
            ggml_tensor * wo,
894
            ggml_tensor * wo_b,
895
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
896
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
897
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
898
            ggml_tensor * kq_b,
899
            ggml_tensor * sinks, // [n_head_q]
900
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
901
                  float   kq_scale,
902
                    int   il) const;
903
904
    llm_graph_input_attn_k  * build_attn_inp_k() const;
905
906
    ggml_tensor * build_attn(
907
            llm_graph_input_attn_k * inp,
908
            ggml_tensor * wo,
909
            ggml_tensor * wo_b,
910
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
911
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
912
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
913
            ggml_tensor * kq_b,
914
            ggml_tensor * sinks, // [n_head_q]
915
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
916
                  float   kq_scale,
917
                    int   il) const;
918
919
    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
920
921
    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
922
    ggml_tensor * build_attn(
923
            llm_graph_input_attn_kv_iswa * inp,
924
            ggml_tensor * wo,
925
            ggml_tensor * wo_b,
926
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
927
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
928
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
929
            ggml_tensor * kq_b,
930
            ggml_tensor * sinks, // [n_head_q]
931
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
932
                  float   kq_scale,
933
                    int   il) const;
934
935
    llm_graph_input_attn_cross * build_attn_inp_cross() const;
936
937
    ggml_tensor * build_attn(
938
            llm_graph_input_attn_cross * inp,
939
            ggml_tensor * wo,
940
            ggml_tensor * wo_b,
941
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
942
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
943
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
944
            ggml_tensor * kq_b,
945
            ggml_tensor * sinks, // [n_head_q]
946
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
947
                  float   kq_scale,
948
                    int   il) const;
949
950
    //
951
    // recurrent
952
    //
953
954
    // TODO: move this implementation to llama_memory_recurrent.
955
    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
956
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
957
    //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
958
    //         `llama_memory_recurrent`
959
    ggml_tensor * build_rs(
960
            ggml_tensor * s,
961
            ggml_tensor * state_copy_main,
962
            ggml_tensor * state_copy_extra,
963
                int32_t   state_size,
964
                int32_t   n_seqs,
965
               uint32_t   n_rs,
966
               uint32_t   rs_head,
967
               uint32_t   rs_size,
968
                int32_t   rs_zero,
969
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
970
971
    llm_graph_input_rs * build_rs_inp() const;
972
973
    ggml_tensor * build_rs(
974
            llm_graph_input_rs * inp,
975
            ggml_tensor * s,
976
                int32_t   state_size,
977
                int32_t   n_seqs,
978
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
979
980
    ggml_tensor * build_rwkv_token_shift_load(
981
        llm_graph_input_rs * inp,
982
        const llama_ubatch & ubatch,
983
                       int   il) const;
984
985
    ggml_tensor * build_rwkv_token_shift_store(
986
             ggml_tensor * token_shift,
987
      const llama_ubatch & ubatch,
988
                     int   il) const;
989
    //
990
    // hybrid
991
    //
992
993
    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
994
    llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const;
995
996
    llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
997
998
    //
999
    // pooling
1000
    //
1001
1002
    void build_pooling(
1003
            ggml_tensor * cls,
1004
            ggml_tensor * cls_b,
1005
            ggml_tensor * cls_out,
1006
            ggml_tensor * cls_out_b,
1007
            ggml_tensor * cls_norm) const;
1008
1009
    //
1010
    // sampling (backend sampling)
1011
    //
1012
1013
    void build_sampling() const;
1014
1015
    //
1016
    // dense (out)
1017
    //
1018
1019
    void build_dense_out(
1020
            ggml_tensor * dense_2,
1021
            ggml_tensor * dense_2_b,
1022
            ggml_tensor * dense_3) const;
1023
};
1024
1025
// TODO: better name
1026
// map the relative position of `x` with respect to `y` into one of `n_buckets` indices;
// presumably the T5-style relative attention bucketing used with build_pos_bias — confirm against the implementation
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);