Coverage Report

Created: 2025-12-28 06:26

/src/llama.cpp/src/llama-batch.h
Uncovered (execution count 0): the bodies of llama_ubatch::equal_seqs() and llama_ubatch::is_pos_2d(). No other lines in this header have recorded counts.
#pragma once

#include "llama.h"

#include "llama-cparams.h"

#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>

// keep this struct lightweight
struct llama_ubatch {
    bool equal_seqs() const {
        return b_equal_seqs != 0;
    }

    // typical for M-RoPE cases:
    //   0 - sequential position of the tokens/embeddings in the sequence
    //   1 - y position in the image
    //   2 - x position in the image
    //   3 - other
    bool is_pos_2d() const {
        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
        return n_pos >= 3;
    }

    uint32_t b_equal_seqs; // note: this is a boolean, but we use a uint32_t for alignment,
                           //       otherwise the address sanitizer complains
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence set
    uint32_t n_seqs;       // sequence sets in the ubatch
    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
    uint32_t n_pos;        // number of position inputs for each token/embedding

    // seq_id_unq: unique sequence ids in the ubatch
    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
    //             used for extracting sequence pooled embeddings

    //                          // size               | idx | val
    llama_token  *  token;      // [n_tokens]         | i   | id, token
    float        *  embd;       // [n_embd, n_tokens] | i   | embd
    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
    int8_t       *  output;     // [n_tokens]         | i   | -

    struct data_t {
        std::vector<llama_token>    token;
        std::vector<float>          embd;
        std::vector<llama_pos>      pos;
        std::vector<int32_t>        n_seq_id;
        std::vector<llama_seq_id *> seq_id;      // these point into the seq_id_data below
        std::vector<llama_seq_id>   seq_id_unq;
        std::vector<int32_t>        seq_idx;
        std::vector<int8_t>         output;

        std::vector<llama_seq_id> seq_id_data;
    };

    // the llama_ubatch pointers above point into this data if it is set; otherwise they point to external, non-owning data
    std::shared_ptr<data_t> data;
};

// a helper for sanitizing, completing, and splitting a batch
class llama_batch_allocr {
public:
    llama_batch_allocr(uint32_t n_pos_per_embd);

    // sanitize and auto-gen missing data in the input batch
    // memory is optional. if provided, it will be used to check for sequence continuity and to determine the positions
    bool init(
            const llama_batch & batch_inp,
            const llama_vocab & vocab,
            const llama_memory_i * memory,
            uint32_t n_embd,
            uint32_t n_seq_max,
            bool output_all);

    const llama_batch & get_batch() const;

    uint32_t get_n_tokens()  const;
    uint32_t get_n_outputs() const;
    uint32_t get_n_used()    const;

    // the array of output indices in the order they were encountered during the ubatch splitting
    std::vector<int32_t> & get_out_ids();

    // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
    llama_pos seq_pos_max(llama_seq_id seq_id) const;

    // call once before splitting the batch to reset the internal state
    void split_reset();

    // simple split, unknown number of sequence sets of unequal lengths
    llama_ubatch split_simple(uint32_t n_ubatch);

    // make ubatches of equal-length sequence sets
    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

    // sequence-set-wise split - each ubatch contains a single sequence-set
    llama_ubatch split_seq(uint32_t n_ubatch);

    // a helper method for creating a well-defined ubatch of tokens
    // TODO: support embeddings if needed in the future
    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);

private:
    void clear();

    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);

    // for debugging, start with LLAMA_BATCH_DEBUG=2
    void ubatch_print(const llama_ubatch & ubatch, int debug);

    llama_batch batch;

    // only for debugging purposes
    const llama_vocab * vocab;

    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
    //       ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
    const uint32_t n_pos_per_embd;

    uint32_t n_embd;
    uint32_t n_seq_max;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id

    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<llama_seq_id>   seq_id_unq;
    std::vector<int32_t>        seq_idx;
    std::vector<int8_t>         output;

    using pos_set_t = std::set<llama_pos>;
    using seq_cpl_t = std::vector<bool>;

    // helper flag to quickly determine if there are any coupled sequences in the batch
    bool has_cpl = false;

    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

    using idx_vec_t = std::vector<int32_t>;
    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i

    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears

    // batch indices of the output
    std::vector<int32_t> out_ids;

    uint32_t n_used;

    // used[i] indicates if token i has already been used in a previous ubatch
    std::vector<bool> used;

    int debug;
};
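
For context, the sketch below shows how the splitter API declared above typically fits together: init() sanitizes the incoming llama_batch, split_reset() clears the internal split state, and repeated split_simple() calls yield ubatches until one comes back with n_tokens == 0. This is inferred only from the declarations and comments in this header; llama_batch_allocr is an internal llama.cpp class, and the numeric values (n_embd, n_seq_max, n_ubatch) as well as the function name run_splits are placeholder assumptions, not part of the library.

// Illustrative sketch only - inferred from the header above, not copied from llama.cpp.
// All numeric parameters are placeholders; passing memory = nullptr relies on the
// "memory is optional" note on init().
#include "llama-batch.h"

static void run_splits(const llama_batch & batch_inp, const llama_vocab & vocab) {
    llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);

    // sanitize and auto-generate missing data in the input batch
    if (!balloc.init(batch_inp, vocab, /*memory =*/ nullptr,
                     /*n_embd     =*/ 4096,
                     /*n_seq_max  =*/ 1,
                     /*output_all =*/ false)) {
        return; // the batch failed validation
    }

    balloc.split_reset(); // reset internal state before producing ubatches

    for (;;) {
        // simple split: up to n_ubatch tokens, sequence sets of unequal lengths
        llama_ubatch ubatch = balloc.split_simple(/*n_ubatch =*/ 512);
        if (ubatch.n_tokens == 0) {
            break; // the entire batch has been consumed
        }

        // consume the ubatch here, e.g. ubatch.token[i], ubatch.pos[i], ubatch.seq_id[i][0]
    }
}

split_equal() and split_seq() drop into the same loop shape; they differ only in how tokens are grouped into sequence sets, as described in their comments above.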