Coverage Report

Created: 2026-01-09 06:17

/src/llama.cpp/src/llama-batch.cpp
Execution count: 0 for every instrumented line below (the file was not exercised by this run). Source follows.
#include "llama-batch.h"

#include "llama-impl.h"
#include "llama-vocab.h"
#include "llama-memory.h"

#include <cassert>
#include <cstring>
#include <algorithm>
#include <sstream>

llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;

    seq_pos.resize(LLAMA_MAX_SEQ);
    seq_cpl.resize(LLAMA_MAX_SEQ);
    for (auto & cur : seq_cpl) {
        cur.resize(LLAMA_MAX_SEQ);
    }

    seq_idx.resize(LLAMA_MAX_SEQ, -1);
}

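The constructor above reads the LLAMA_BATCH_DEBUG environment variable: a value of 1 or higher enables the debug dumps in init() and ubatch_add(), and a value of 2 or higher additionally prints per-token details in ubatch_print(). A minimal sketch of enabling it from a host program follows; the surrounding program is hypothetical and setenv() is assumed to be available (POSIX).

#include <cstdlib>

int main() {
    // Enable level-2 batch debugging before any llama_batch_allocr is constructed
    // (level 1 = batch/ubatch summaries, level 2 = additional per-token dump).
    setenv("LLAMA_BATCH_DEBUG", "2", /*overwrite=*/1); // POSIX; assumed host
    // ... create the llama context and submit batches as usual ...
    return 0;
}
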
bool llama_batch_allocr::init(
        const llama_batch & batch_inp,
        const llama_vocab & vocab,
        const llama_memory_i * memory,
        uint32_t n_embd,
        uint32_t n_seq_max,
        bool output_all) {
    clear();

    batch = batch_inp;

    this->vocab = &vocab;

    GGML_ASSERT(batch.n_tokens > 0);

    //
    // validate input batch
    //

    if (n_seq_max > LLAMA_MAX_SEQ) {
        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
        return false;
    }

    if (batch.token) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return false;
            }
        }
    }

    if (batch.seq_id) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                    return false;
                }
            }
        }
    }

    //
    // auto-generate missing fields
    //

    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }

    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }

    if (!batch.pos) {
        pos.resize(batch.n_tokens);

        // initialize the starting position for each sequence based on the positions in the memory
        llama_pos p0[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (!memory) {
                // if no memory -> start from 0
                p0[s] = 0;
            } else {
                p0[s] = memory->seq_pos_max(s) + 1;
            }
        }

        for (int32_t i = 0; i < batch.n_tokens; i++) {
            const llama_seq_id seq_id = batch.seq_id[i][0];

            pos[i] = p0[seq_id];

            // update the starting position for all sequences that are assigned to this token
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                p0[seq_id] = pos[i] + 1;
            }
        }

        batch.pos = pos.data();
    }

    if (!batch.logits) {
        if (output_all) {
            // return the output for all tokens
            output.resize(batch.n_tokens, true);
        } else {
            // return the output only for the last token
            output.resize(batch.n_tokens, false);
            output[output.size() - 1] = true;
        }

        batch.logits = output.data();
    } else if (output_all) {
        bool warn = false;

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.logits[i] == 0) {
                warn = true;
            }
        }

        if (warn) {
            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);

            output.resize(batch.n_tokens, true);
            batch.logits = output.data();
        }
    }

    //
    // compute stats
    //

    this->n_embd    = n_embd;
    this->n_seq_max = n_seq_max;

    // count the outputs in this batch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
    }

    has_cpl = false;

    // determine coupled sequences
    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        const llama_seq_id s0 = batch.seq_id[i][0];

        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
            const llama_seq_id s1 = batch.seq_id[i][s];

            seq_pos[s1].insert(batch.pos[i]);

            if (s > 0) {
                // mark that sequence s1 is coupled to s0
                seq_cpl[s1][s0] = true;

                // note: tracking the other way around is not necessary for now
                //seq_cpl[s0][s1] = true;

                has_cpl = true;
            }
        }
    }

    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
    {
        seq_set_t seq_set_unq;

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            seq_set_t cur;
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cur        .set(seq_id);
                seq_set_unq.set(seq_id);
            }

            seq_set.push_back(cur);
            seq_set_map[cur].push_back(i);
        }

        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_set_unq.test(s)) {
                seq_idx[s] = seq_id_unq.size();
                seq_id_unq.push_back(s);
            }
        }
    }

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);

        llama_ubatch ubatch {
            /*.b_equal_seqs =*/ false,
            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
            /*.n_seq_tokens =*/ (uint32_t) 1,
            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
            /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
            /*.n_pos        =*/ n_pos_per_embd,
            /*.token        =*/ batch.token,
            /*.embd         =*/ batch.embd,
            /*.pos          =*/ batch.pos,
            /*.n_seq_id     =*/ batch.n_seq_id,
            /*.seq_id       =*/ batch.seq_id,
            /*.seq_id_unq   =*/ this->seq_id_unq.data(),
            /*.seq_idx      =*/ this->seq_idx.data(),
            /*.output       =*/ batch.logits,
            /*.data         =*/ {},
        };

        ubatch_print(ubatch, debug);

        LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__);
        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
            if (seq_pos[s0].empty()) {
                continue;
            }

            std::stringstream ss;
            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
                if (seq_cpl[s0][s1]) {
                    ss << s1 << " ";
                }
            }

            LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
        }
        LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
    }

    //
    // consistency checks
    //

    if (n_pos_per_embd > 1) {
        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_pos[s].empty()) {
                continue;
            }

            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

            if (batch.token) {
                if (p0 >= 0 && p0 >= seq_pos_min(s)) {
                    LLAMA_LOG_ERROR(
                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                            " for M-RoPE, it is required that the position satisfies: X < Y\n",
                            __func__, s, s, p0, s, seq_pos_min(s));

                    return false;
                }
            } else {
                // embedding inputs can have overlapping positions
                if (p0 >= 0 && p0 > seq_pos_min(s)) {
                    LLAMA_LOG_ERROR(
                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                            " for M-RoPE, it is required that the position satisfies: X <= Y\n",
                            __func__, s, s, p0, s, seq_pos_min(s));

                    return false;
                }
            }
        }
    } else {
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_pos[s].empty()) {
                continue;
            }

            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

            if (p0 >= 0) {
                bool ok = true;

                if (seq_pos_min(s) != p0 + 1) {
                    ok = false;
                }

                if (!ok) {
                    LLAMA_LOG_ERROR(
                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
                            __func__, s, s, p0, s, seq_pos_min(s));

                    return false;
                }
            }

            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                return false;
            }
        }
    }

    if (memory) {
        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                if (seq_cpl[s0][s1]) {
                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but they have diverged\n", __func__, s0, s1);
                        return false;
                    }
                }
            }
        }
    }

    // disallow partial sequence sub-sets:
    //
    // invalid:          x
    //            i: 0 1 2 ...
    // ---------------------------------------
    // seq_id[i][0]: 0 0 1
    // seq_id[i][1]: 1 1 2
    // seq_id[i][2]: 2
    //
    // disallow decreasing sequence positions:
    //
    // invalid:                  x
    //            i: 0 1 2 3 4 5 6 ...
    // ---------------------------------------
    //       pos[i]: 4 5 0 1 6 2 3
    // seq_id[i][0]: 0 0 1 1 0 1 0
    //
    {
        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_set[s].set();
        }

        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            const llama_pos pos = batch.pos[i];

            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cur_seq_set[seq_id] &= seq_set[i];

                if (cur_seq_set[seq_id].none()) {
                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
                    return false;
                }

                if (pos < cur_seq_pos[seq_id]) {
                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
                    return false;
                }
            }
        }
    }

    split_reset();

    return true;
}

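For reference, a minimal caller-side sketch of how an almost-empty batch can rely on the auto-generated fields above: with pos, n_seq_id, seq_id and logits left null, init() assigns positions continuing from the memory module, attaches every token to sequence 0 via seq_id_0, and (with output_all == false) marks only the last token for output. The names vocab, memory, n_embd and n_seq_max stand in for values coming from the loaded model and context; they are assumptions, not part of this file.

// Hypothetical caller-side sketch: the batch produced by llama_batch_get_one()
// (defined at the bottom of this file) has only token/n_tokens set; everything
// else is filled in by llama_batch_allocr::init().
llama_token tokens[3] = { 1, 2, 3 };               // placeholder token ids
llama_batch batch = llama_batch_get_one(tokens, 3);

llama_batch_allocr balloc(/*n_pos_per_embd=*/1);
if (!balloc.init(batch, vocab, memory, n_embd, n_seq_max, /*output_all=*/false)) {
    // invalid batch: bad token ids, out-of-range seq_ids, or inconsistent positions
}
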
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
    const uint32_t n_tokens = n_seq_tokens*n_seqs;

    clear();
    split_reset();

    auto udata = std::make_shared<llama_ubatch::data_t>();

    udata->token     .resize(n_tokens);
    udata->embd      .clear();
    udata->pos       .resize(n_tokens);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        udata->seq_idx[s] = s;
        udata->seq_id_unq.push_back(s);
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ true,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_seq_tokens,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ n_seqs,
        /*.n_pos        =*/ n_pos_per_embd,

        /*.token        =*/ udata->token.data(),
        /*.embd         =*/ nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    return res;
}

const llama_batch & llama_batch_allocr::get_batch() const {
    return batch;
}

uint32_t llama_batch_allocr::get_n_tokens() const {
    return batch.n_tokens;
}

uint32_t llama_batch_allocr::get_n_outputs() const {
    return n_outputs;
}

uint32_t llama_batch_allocr::get_n_used() const {
    return n_used;
}

std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
    return out_ids;
}

llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
}

llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
}

void llama_batch_allocr::split_reset() {
    out_ids.clear();

    n_used = 0;

    used.clear();
    used.resize(get_n_tokens(), false);
}

llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        ++cur_idx;

        if (cur_idx >= used.size()) {
            break;
        }

        if (idxs.size() >= n_ubatch) {
            break;
        }
    }

    return ubatch_add(idxs, idxs.size(), false);
}

llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    if (sequential && has_cpl) {
        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);

        return {};
    }

    std::vector<seq_set_t> cur_seq_set;

    llama_seq_id last_seq_id = -1;

    // determine the non-overlapping sequence sets participating in this ubatch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (used[i]) {
            continue;
        }

        bool add = true;

        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
            // no overlap with existing sequence sets:
            if (!(cur_seq_set[s] & seq_set[i]).none()) {
                add = false;
                break;
            }
        }

        // accept only increasing sequence ids
        if (sequential) {
            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
        }

        if (add) {
            cur_seq_set.push_back(seq_set[i]);

            last_seq_id = batch.seq_id[i][0];

            if (cur_seq_set.size() > n_ubatch) {
                break;
            }
        }
    }

    const uint32_t n_seqs = cur_seq_set.size();

    // we are done
    if (n_seqs == 0) {
        return {};
    }

    // the current batch index of each sequence set
    std::vector<int32_t> cur_idx(n_seqs, 0);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
            ++cur_idx[s];
        }
    }

    // the list of batch indices for each sequence set
    // at the end we will concat these to get the final ubatch
    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

    while (true) {
        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
        //   if we haven't reached n_ubatch
        bool can_expand = true;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                can_expand = false;
                break;
            }
        }

        if (!can_expand) {
            break;
        }

        for (uint32_t s = 0; s < n_seqs; ++s) {
            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];

            idxs_per_seq[s].push_back(idx);

            used[idx] = true;
            ++n_used;

            ++cur_idx[s];
        }

        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }

    // concat the per-sequence-set lists
    std::vector<int32_t> idxs;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    return ubatch_add(idxs, n_seqs, true);
}

llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    // this is the starting sequence set
    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        if (idxs.size() >= n_ubatch) {
            break;
        }

        do {
            ++cur_idx;
        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));

        if (cur_idx == get_n_tokens()) {
            break;
        }

        cur_seq_set = seq_set[cur_idx];
    }

    return ubatch_add(idxs, 1, true);
}

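The three split_* methods above all drain the same used[] bookkeeping and return an empty ubatch once every token has been consumed. A minimal consumption-loop sketch, assuming only what is visible in this file (split_reset() is already invoked at the end of init(), and a returned llama_ubatch with n_tokens == 0 signals completion); balloc and n_ubatch are hypothetical names for the allocator instance and the micro-batch size:

// Hypothetical consumption loop: pull micro-batches until the input batch is drained.
while (true) {
    llama_ubatch ubatch = balloc.split_simple(n_ubatch);   // or split_equal / split_seq
    if (ubatch.n_tokens == 0) {
        break;                                             // all tokens of the batch were used
    }
    // ... process the ubatch (e.g. feed it to the decoder) ...
}
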
void llama_batch_allocr::clear() {
    n_outputs = 0;

    batch = {};

    pos       .clear();
    n_seq_id  .clear();
    seq_id    .clear();
    seq_id_unq.clear();
    output    .clear();

    for (auto & cur : seq_pos) {
        cur.clear();
    }

    for (auto & cur : seq_cpl) {
        std::fill(cur.begin(), cur.end(), false);
    }

    seq_set.clear();

    seq_set_map.clear();

    std::fill(seq_idx.begin(), seq_idx.end(), -1);
}

llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
    const uint32_t n_tokens = idxs.size();

    assert(n_tokens%n_seqs == 0);

    auto udata = std::make_shared<llama_ubatch::data_t>();

    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_per_embd;

    udata->token     .resize(n_tokens);
    udata->embd      .resize(n_embd_all);
    udata->pos       .resize(n_pos_all);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    udata->seq_id_data.reserve(n_tokens);

    seq_set_t seq_set_unq;

    for (size_t i = 0; i < idxs.size(); ++i) {
        if (batch.token) {
            udata->token[i] = batch.token[idxs[i]];
        }

        if (batch.embd) {
            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }

        for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
            // if we are using M-RoPE
            //     if the current batch is text, we need to broadcast the same position across all RoPE sections
            //     otherwise, the input batch is image embeddings, we copy the positions as-is
            // if we are not using M-RoPE, there is only one position per token (this loop runs only once)
            size_t src_off = batch.token ? 0 : j*batch.n_tokens;
            udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
        }

        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
        udata->output[i]   = batch.logits[idxs[i]];

        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
            const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];

            udata->seq_id_data.push_back(seq_id);
            seq_set_unq.set(seq_id);
        }

        if (udata->output[i]) {
            out_ids.push_back(idxs[i]);
        }
    }

    llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
    for (size_t i = 0; i < idxs.size(); ++i) {
        udata->seq_id[i] = seq_id_ptr;
        seq_id_ptr += udata->n_seq_id[i];
    }

    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
            udata->seq_idx[s] = udata->seq_id_unq.size();
            udata->seq_id_unq.push_back(s);
        }
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
        /*.n_pos        =*/ n_pos_per_embd,

        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

        ubatch_print(res, debug);
    }

    return res;
}

void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
        LLAMA_LOG_DEBUG("%s:   n_tokens     = %d\n", __func__, ubatch.n_tokens);
        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
        LLAMA_LOG_DEBUG("%s:   n_seqs       = %d\n", __func__, ubatch.n_seqs);
        LLAMA_LOG_DEBUG("%s:   n_seqs_unq   = %d\n", __func__, ubatch.n_seqs_unq);

        std::stringstream ss_seq_id_unq;
        std::stringstream ss_seq_idx;

        ss_seq_id_unq << "[ ";
        ss_seq_idx << "[";

        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
            ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
        }

        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (ubatch.seq_idx[s] >= 0) {
                ss_seq_idx << ubatch.seq_idx[s]%10;
            } else {
                ss_seq_idx << ".";
            }
        }

        ss_seq_id_unq << "]";
        ss_seq_idx    << "]";

        LLAMA_LOG_DEBUG("%s:   token      = %p\n", __func__, (void *) ubatch.token);
        LLAMA_LOG_DEBUG("%s:   embd       = %p\n", __func__, (void *) ubatch.embd);
        LLAMA_LOG_DEBUG("%s:   pos        = %p\n", __func__, (void *) ubatch.pos);
        LLAMA_LOG_DEBUG("%s:   n_seq_id   = %p\n", __func__, (void *) ubatch.n_seq_id);
        LLAMA_LOG_DEBUG("%s:   seq_id     = %p\n", __func__, (void *) ubatch.seq_id);
        LLAMA_LOG_DEBUG("%s:   seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
        LLAMA_LOG_DEBUG("%s:   seq_idx    = %s\n", __func__, ss_seq_idx.str().c_str());
        LLAMA_LOG_DEBUG("%s:   output     = %p\n", __func__, (void *) ubatch.output);
        LLAMA_LOG_DEBUG("%s:   n_outputs  = %d\n", __func__, n_outputs);

        if (debug > 1) {
            int seq_id_max = 0;
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
                    seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
                }
            }
            ++seq_id_max;

            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                std::vector<int8_t> seq_id(seq_id_max);

                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
                    seq_id[ubatch.seq_id[i][s]] = 1;
                }

                std::stringstream ss;
                for (int s = 0; s < seq_id_max; ++s) {
                    if (seq_id[s]) {
                        ss << s%10;
                    } else {
                        ss << ".";
                    }
                }

                if (ubatch.token) {
                    LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
                            __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
                            ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
                } else {
                    LLAMA_LOG_DEBUG("%s:  %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
                            __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
                }
            }
            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
        }
    }
}

//
// interface implementation
//

struct llama_batch llama_batch_get_one(
             llama_token * tokens,
                 int32_t   n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*tokens   =*/ tokens,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };
}

struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };

    if (embd) {
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }

    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);

    return batch;
}

void llama_batch_free(struct llama_batch batch) {
    if (batch.token)    free(batch.token);
    if (batch.embd)     free(batch.embd);
    if (batch.pos)      free(batch.pos);
    if (batch.n_seq_id) free(batch.n_seq_id);
    if (batch.seq_id) {
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
    }
    if (batch.logits)   free(batch.logits);
}
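
To close, a minimal caller-side sketch of the C API implemented above: allocate a token batch with llama_batch_init, fill the fields it allocates, and release it with llama_batch_free. The token values and counts are placeholders, and the snippet assumes the public llama.h header is included.

// Hypothetical usage of the batch C API defined above: room for up to 8 tokens
// in a single sequence, 4 of them filled, output requested only for the last one.
llama_batch batch = llama_batch_init(/*n_tokens_alloc=*/8, /*embd=*/0, /*n_seq_max=*/1);

for (int32_t i = 0; i < 4; ++i) {
    batch.token   [i]    = 42 + i;      // placeholder token ids
    batch.pos     [i]    = i;           // consecutive positions
    batch.n_seq_id[i]    = 1;           // each token belongs to one sequence ...
    batch.seq_id  [i][0] = 0;           // ... sequence 0
    batch.logits  [i]    = (i == 3);    // request output only for the last token
}
batch.n_tokens = 4;

// ... decode with the batch ...

llama_batch_free(batch);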