Coverage Report

Created: 2025-12-14 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/llama-sampling.cpp
Line
Count
Source
1
#include "llama-sampling.h"
2
3
#include "llama-impl.h"
4
#include "llama-vocab.h"
5
#include "llama-grammar.h"
6
7
#include <array>
8
#include <algorithm>
9
#include <cassert>
10
#include <cfloat>
11
#include <chrono>
12
#include <cmath>
13
#include <cstdlib>
14
#include <cstring>
15
#include <ctime>
16
#include <numeric>
17
#include <random>
18
#include <unordered_map>
19
#include <stdexcept>
20
21
// the ring buffer works similarly to std::deque, but with a fixed capacity
22
template<typename T>
23
struct ring_buffer {
24
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
25
26
0
    T & front() {
27
0
        if (sz == 0) {
28
0
            throw std::runtime_error("ring buffer is empty");
29
0
        }
30
0
        return data[first];
31
0
    }
32
33
    const T & front() const {
34
        if (sz == 0) {
35
            throw std::runtime_error("ring buffer is empty");
36
        }
37
        return data[first];
38
    }
39
40
    T & back() {
41
        if (sz == 0) {
42
            throw std::runtime_error("ring buffer is empty");
43
        }
44
        return data[pos];
45
    }
46
47
    const T & back() const {
48
        if (sz == 0) {
49
            throw std::runtime_error("ring buffer is empty");
50
        }
51
        return data[pos];
52
    }
53
54
    void push_back(const T & value) {
55
        if (capacity == 0) {
56
            throw std::runtime_error("ring buffer: capacity is zero");
57
        }
58
59
        if (sz == capacity) {
60
            // advance the start when buffer is full
61
            first = (first + 1) % capacity;
62
        } else {
63
            sz++;
64
        }
65
        data[pos] = value;
66
        pos = (pos + 1) % capacity;
67
    }
68
69
    T pop_front() {
70
        if (sz == 0) {
71
            throw std::runtime_error("ring buffer is empty");
72
        }
73
        T value = data[first];
74
        first = (first + 1) % capacity;
75
        sz--;
76
        return value;
77
    }
78
79
    //T & operator[](size_t i) {
80
    //    if (i >= sz) {
81
    //        throw std::runtime_error("ring buffer: index out of bounds");
82
    //    }
83
    //    return data[(first + i) % capacity];
84
    //}
85
86
    //const T & at(size_t i) const {
87
    //    if (i >= sz) {
88
    //        throw std::runtime_error("ring buffer: index out of bounds");
89
    //    }
90
    //    return data[(first + i) % capacity];
91
    //}
92
93
    const T & rat(size_t i) const {
94
        if (i >= sz) {
95
            throw std::runtime_error("ring buffer: index out of bounds");
96
        }
97
        return data[(first + sz - i - 1) % capacity];
98
    }
99
100
    std::vector<T> to_vector() const {
101
        std::vector<T> result;
102
        result.reserve(sz);
103
        for (size_t i = 0; i < sz; i++) {
104
            result.push_back(data[(first + i) % capacity]);
105
        }
106
        return result;
107
    }
108
109
    void clear() {
110
        // here only reset the status of the buffer
111
        sz = 0;
112
        first = 0;
113
        pos = 0;
114
    }
115
116
    bool empty() const {
117
        return sz == 0;
118
    }
119
120
    size_t size() const {
121
        return sz;
122
    }
123
124
    size_t capacity = 0;
125
    size_t sz = 0;
126
    size_t first = 0;
127
    size_t pos = 0;
128
129
    std::vector<T> data;
130
};
131
132
// writes result in res, does not mutate cur
133
0
static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
134
0
    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
135
0
        return a.logit > b.logit;
136
0
    };
137
138
0
    constexpr int   nbuckets     = 128;
139
0
    constexpr float bucket_low   = -10.0f;
140
0
    constexpr float bucket_high  =  10.0f;
141
0
    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
142
0
    constexpr float bucket_inter = -bucket_low * bucket_scale;
143
144
0
    std::vector<int> bucket_idx;
145
0
    std::vector<int> histo(nbuckets, 0);
146
147
0
    std::vector<llama_token_data*> bucket_ptrs;
148
149
0
    bucket_idx.reserve(cur.size);
150
151
0
    for (int i = 0; i < (int)cur.size; ++i) {
152
0
        const float val = cur.data[i].logit;
153
0
        int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
154
0
        ib = std::max(0, std::min(nbuckets - 1, ib));
155
0
        bucket_idx.push_back(ib);
156
0
        ++histo[ib];
157
0
    }
158
0
    int nhave = 0;
159
0
    int ib = nbuckets - 1;
160
0
    for ( ; ib >= 0; --ib) {
161
0
        nhave += histo[ib];
162
0
        if (nhave >= npartial) {
163
0
            break;
164
0
        }
165
0
    }
166
0
    res.resize(nhave);
167
0
    auto * ptr = res.data();
168
0
    bucket_ptrs.reserve(nbuckets - ib);
169
0
    for (int j = nbuckets - 1; j >= ib; --j) {
170
0
        bucket_ptrs.push_back(ptr);
171
0
        ptr += histo[j];
172
0
    }
173
0
    for (int i = 0; i < (int)cur.size; ++i) {
174
0
        int j = bucket_idx[i];
175
0
        if (j >= ib) {
176
0
            *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
177
0
        }
178
0
    }
179
180
0
    ptr = res.data();
181
0
    int ndone = 0;
182
0
    for (int j = nbuckets - 1; j > ib; --j) {
183
0
        std::sort(ptr, ptr + histo[j], comp);
184
0
        ptr += histo[j];
185
0
        ndone += histo[j];
186
0
    }
187
0
    std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
188
0
}
189
190
// reduces the size of cur_p to npartial, keeping only the top npartial elements
191
0
static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
192
0
    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
193
0
        return a.logit > b.logit;
194
0
    };
195
196
0
    if (npartial <= 128) {
197
0
        std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
198
199
0
        cur_p->size = npartial;
200
0
        cur_p->sorted = true;
201
202
0
        return;
203
0
    }
204
205
0
    std::vector<llama_token_data> tmp;
206
207
0
    llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
208
209
0
    std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
210
211
0
    cur_p->size = npartial;
212
0
    cur_p->sorted = true;
213
0
}
214
215
0
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
216
    // iterator for the probabilities
217
0
#ifdef __GNUC__
218
0
    #pragma GCC diagnostic push
219
0
    #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
220
0
#endif
221
222
0
    struct probs_iterator {
223
0
        typedef std::input_iterator_tag iterator_category;
224
0
        typedef float value_type;
225
0
        typedef float * pointer;
226
0
        typedef float & reference;
227
0
        typedef ptrdiff_t difference_type;
228
229
0
        const llama_token_data * data;
230
231
0
        bool operator==(const probs_iterator & other) const { return data == other.data; }
232
0
        bool operator!=(const probs_iterator & other) const { return data != other.data; }
233
0
        const float & operator*() const { return data->p; }
234
0
        probs_iterator & operator++() { ++data; return *this; }
235
0
        probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
236
0
    };
237
238
0
#ifdef __GNUC__
239
0
    #pragma GCC diagnostic pop
240
0
#endif
241
242
0
    std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
243
244
0
    return dist(rng);
245
0
}
246
247
/*
248
static void llama_log_softmax(float * array, size_t size) {
249
    float max_l = *std::max_element(array, array + size);
250
    float sum = 0.f;
251
    for (size_t i = 0; i < size; ++i) {
252
        float p = expf(array[i] - max_l);
253
        sum += p;
254
        array[i] = p;
255
    }
256
257
    for (size_t i = 0; i < size; ++i) {
258
        array[i] = logf(array[i] / sum);
259
    }
260
}
261
*/
262
263
0
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
264
0
    if (temp <= 0.0f) {
265
        // find the token with the highest logit and set the rest to -inf
266
0
        size_t max_i = 0;
267
0
        float  max_l = cur_p->data[0].logit;
268
269
0
        for (size_t i = 1; i < cur_p->size; ++i) {
270
0
            if (cur_p->data[i    ].logit > max_l) {
271
0
                cur_p->data[max_i].logit = -INFINITY;
272
0
                max_i = i;
273
0
                max_l = cur_p->data[i].logit;
274
0
            } else {
275
0
                cur_p->data[i].logit = -INFINITY;
276
0
            }
277
0
        }
278
279
0
        return;
280
0
    }
281
282
0
    for (size_t i = 0; i < cur_p->size; ++i) {
283
0
        cur_p->data[i].logit /= temp;
284
0
    }
285
0
}
286
287
0
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
288
0
    GGML_ASSERT(cur_p->size > 0);
289
290
    // Sort the logits in descending order if requested
291
0
    if (do_sort && !cur_p->sorted) {
292
0
        llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
293
0
    }
294
295
0
    float max_l = cur_p->data[0].logit;
296
0
    if (!cur_p->sorted) {
297
0
        for (size_t i = 1; i < cur_p->size; ++i) {
298
0
            max_l = std::max(max_l, cur_p->data[i].logit);
299
0
        }
300
0
    }
301
302
0
    float cum_sum = 0.0f;
303
304
0
    for (size_t i = 0; i < cur_p->size; ++i) {
305
0
        float p = expf(cur_p->data[i].logit - max_l);
306
0
        cur_p->data[i].p = p;
307
0
        cum_sum += p;
308
0
    }
309
310
0
    for (size_t i = 0; i < cur_p->size; ++i) {
311
0
        cur_p->data[i].p /= cum_sum;
312
0
    }
313
0
}
314
315
0
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
316
    // if (k >= (int32_t)cur_p->size) {
317
    //     return;
318
    // }
319
320
0
    if (k <= 0) {
321
0
        return;
322
0
    }
323
324
0
    k = std::min(k, (int) cur_p->size);
325
326
    // Sort scores in descending order
327
0
    if (!cur_p->sorted) {
328
0
        llama_token_data_array_partial_sort_inplace(cur_p, k);
329
0
    }
330
331
0
    cur_p->size = k;
332
0
}
333
334
0
static uint32_t get_rng_seed(uint32_t seed) {
335
0
    if (seed == LLAMA_DEFAULT_SEED) {
336
        // use system clock if std::random_device is not a true RNG
337
0
        static bool is_rd_prng = std::random_device().entropy() == 0;
338
0
        if (is_rd_prng) {
339
0
            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
340
0
        }
341
0
        std::random_device rd;
342
0
        return rd();
343
0
    }
344
0
    return seed;
345
0
}
346
347
// llama_sampler API
348
349
0
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
350
0
    return new llama_sampler {
351
0
        /* .iface = */ iface,
352
0
        /* .ctx   = */ ctx,
353
0
    };
354
0
}
355
356
0
const char * llama_sampler_name(const struct llama_sampler * smpl) {
357
0
    if (!smpl->iface) {
358
0
        return "(null)";
359
0
    }
360
361
0
    return smpl->iface->name(smpl);
362
0
}
363
364
0
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
365
0
    if (smpl->iface->accept) {
366
0
        smpl->iface->accept(smpl, token);
367
0
    }
368
0
}
369
370
0
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
371
0
    GGML_ASSERT(smpl->iface->apply);
372
0
    smpl->iface->apply(smpl, cur_p);
373
0
}
374
375
0
void llama_sampler_reset(struct llama_sampler * smpl) {
376
0
    if (smpl->iface->reset) {
377
0
        smpl->iface->reset(smpl);
378
0
    }
379
0
}
380
381
0
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
382
0
    if (smpl->iface->clone) {
383
0
        return smpl->iface->clone(smpl);
384
0
    }
385
386
0
    if (smpl->ctx == nullptr) {
387
0
        return llama_sampler_init(
388
0
            /* .iface = */ smpl->iface,
389
0
            /* .ctx   = */ nullptr
390
0
        );
391
0
    }
392
393
0
    GGML_ABORT("the sampler does not support cloning");
394
0
}
395
396
0
void llama_sampler_free(struct llama_sampler * smpl) {
397
0
    if (smpl == nullptr) {
398
0
        return;
399
0
    }
400
401
0
    if (smpl->iface->free) {
402
0
        smpl->iface->free(smpl);
403
0
    }
404
405
0
    delete smpl;
406
0
}
407
408
0
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
409
0
    const auto * logits = llama_get_logits_ith(ctx, idx);
410
411
0
    const llama_model * model = llama_get_model(ctx);
412
0
    const llama_vocab * vocab = llama_model_get_vocab(model);
413
414
0
    const int n_vocab = llama_vocab_n_tokens(vocab);
415
416
    // TODO: do not allocate each time
417
0
    std::vector<llama_token_data> cur;
418
0
    cur.reserve(n_vocab);
419
0
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
420
0
        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
421
0
    }
422
423
0
    llama_token_data_array cur_p = {
424
0
        /* .data       = */ cur.data(),
425
0
        /* .size       = */ cur.size(),
426
0
        /* .selected   = */ -1,
427
0
        /* .sorted     = */ false,
428
0
    };
429
430
0
    llama_sampler_apply(smpl, &cur_p);
431
432
0
    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
433
434
0
    auto token = cur_p.data[cur_p.selected].id;
435
436
0
    llama_sampler_accept(smpl, token);
437
438
0
    return token;
439
0
}
440
441
// sampler chain
442
443
0
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
444
0
    return "chain";
445
0
}
446
447
0
static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
448
0
    auto * chain = (llama_sampler_chain *) smpl->ctx;
449
450
0
    time_meas tm(chain->t_sample_us, chain->params.no_perf);
451
452
0
    for (auto * smpl : chain->samplers) {
453
0
        llama_sampler_accept(smpl, token);
454
0
    }
455
456
0
    chain->n_sample++;
457
0
}
458
459
0
static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
460
0
    auto * chain = (llama_sampler_chain *) smpl->ctx;
461
462
0
    time_meas tm(chain->t_sample_us, chain->params.no_perf);
463
464
0
    for (auto * smpl : chain->samplers) {
465
0
        llama_sampler_apply(smpl, cur_p);
466
0
    }
467
0
}
468
469
0
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
470
0
    auto * chain = (llama_sampler_chain *) smpl->ctx;
471
472
0
    for (auto * smpl : chain->samplers) {
473
0
        llama_sampler_reset(smpl);
474
0
    }
475
0
}
476
477
0
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
478
0
    const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
479
480
0
    auto * result = llama_sampler_chain_init(chain_src->params);
481
482
0
    for (auto * smpl : chain_src->samplers) {
483
0
        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
484
0
    }
485
486
0
    return result;
487
0
}
488
489
0
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
490
0
    auto * chain = (llama_sampler_chain *) smpl->ctx;
491
492
0
    for (auto * smpl : chain->samplers) {
493
0
        llama_sampler_free(smpl);
494
0
    }
495
496
0
    delete chain;
497
0
}
498
499
static struct llama_sampler_i llama_sampler_chain_i = {
500
    /* .name   = */ llama_sampler_chain_name,
501
    /* .accept = */ llama_sampler_chain_accept,
502
    /* .apply  = */ llama_sampler_chain_apply,
503
    /* .reset  = */ llama_sampler_chain_reset,
504
    /* .clone  = */ llama_sampler_chain_clone,
505
    /* .free   = */ llama_sampler_chain_free,
506
};
507
508
0
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
509
0
    return llama_sampler_init(
510
0
        /* .iface = */ &llama_sampler_chain_i,
511
0
        /* .ctx   = */ new llama_sampler_chain {
512
0
            /* .params      = */ params,
513
0
            /* .samplers    = */ {},
514
0
            /* .t_sample_us = */ 0,
515
0
            /* .n_sample    = */ 0,
516
0
        }
517
0
    );
518
0
}
519
520
0
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
521
0
    auto * p = (llama_sampler_chain *) chain->ctx;
522
0
    p->samplers.push_back(smpl);
523
0
}
524
525
0
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
526
0
    const auto * p = (const llama_sampler_chain *) chain->ctx;
527
528
0
    if (i < 0 || (size_t) i >= p->samplers.size()) {
529
0
        return nullptr;
530
0
    }
531
532
0
    return p->samplers[i];
533
0
}
534
535
0
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
536
0
    auto * p = (llama_sampler_chain *) chain->ctx;
537
538
0
    if (i < 0 || (size_t) i >= p->samplers.size()) {
539
0
        return nullptr;
540
0
    }
541
542
0
    auto * result = p->samplers[i];
543
0
    p->samplers.erase(p->samplers.begin() + i);
544
545
0
    return result;
546
0
}
547
548
0
int llama_sampler_chain_n(const struct llama_sampler * chain) {
549
0
    const auto * p = (const llama_sampler_chain *) chain->ctx;
550
551
0
    return p->samplers.size();
552
0
}
553
554
//
555
// samplers
556
//
557
558
// greedy
559
560
0
static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
561
0
    return "greedy";
562
0
}
563
564
0
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
565
0
    cur_p->selected = 0;
566
0
    for (size_t i = 1; i < cur_p->size; ++i) {
567
0
        if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
568
0
            cur_p->selected = i;
569
0
        }
570
0
    }
571
0
}
572
573
static struct llama_sampler_i llama_sampler_greedy_i = {
574
    /* .name   = */ llama_sampler_greedy_name,
575
    /* .accept = */ nullptr,
576
    /* .apply  = */ llama_sampler_greedy_apply,
577
    /* .reset  = */ nullptr,
578
    /* .clone  = */ nullptr,
579
    /* .free   = */ nullptr,
580
};
581
582
0
struct llama_sampler * llama_sampler_init_greedy() {
583
0
    return llama_sampler_init(
584
0
        /* .iface = */ &llama_sampler_greedy_i,
585
0
        /* .ctx   = */ nullptr
586
0
    );
587
0
}
588
589
// dist
590
591
struct llama_sampler_dist {
592
    const uint32_t seed;
593
          uint32_t seed_cur;
594
595
    std::mt19937 rng;
596
};
597
598
0
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
599
0
    return "dist";
600
0
}
601
602
0
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
603
0
    auto * ctx = (llama_sampler_dist *) smpl->ctx;
604
605
    // edge cases
606
0
    if (cur_p->size == 0) {
607
0
        cur_p->selected = -1;
608
0
        return;
609
0
    }
610
611
0
    cur_p->selected = 0;
612
613
0
    if (cur_p->size == 1) {
614
0
        cur_p->data[0].p = 1.0f;
615
0
        return;
616
0
    }
617
618
    // max logit for numerical stability
619
0
    float max_l = cur_p->data[0].logit;
620
0
    if (!cur_p->sorted) {
621
0
        for (size_t i = 1; i < cur_p->size; ++i) {
622
0
            max_l = std::max(max_l, cur_p->data[i].logit);
623
0
        }
624
0
    }
625
626
    // apply softmax to obtain the probabilities
627
0
    double sum_cum = 0.0f;
628
0
    for (size_t i = 0; i < cur_p->size; ++i) {
629
0
        float p = expf(cur_p->data[i].logit - max_l);
630
0
        cur_p->data[i].p = p;
631
0
        sum_cum += p;
632
0
    }
633
634
0
#if 1
635
    // sample from the obtained probabilities and normalize the probs in a single pass
636
    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
637
    //
638
0
    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
639
0
    const double rnd = dist(ctx->rng);
640
641
0
          double sum_run = 0.0f;
642
0
    const double sum_tgt = sum_cum*rnd;
643
644
0
    bool found = false;
645
0
    for (size_t i = 0; i < cur_p->size; ++i) {
646
0
        if (!found) {
647
            // accumulate probs until we reach the target sum
648
0
            sum_run += cur_p->data[i].p;
649
0
            if (sum_run >= sum_tgt) {
650
0
                cur_p->selected = i;
651
0
                found = true;
652
0
            }
653
0
        }
654
655
        // normalize probs
656
0
        cur_p->data[i].p /= sum_cum;
657
0
    }
658
659
    // fallback to the last token (don't think this can happen)
660
0
    assert(found);
661
0
    if (!found) {
662
0
        cur_p->selected = cur_p->size - 1;
663
0
    }
664
#else
665
    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
666
    for (size_t i = 0; i < cur_p->size; ++i) {
667
        cur_p->data[i].p /= sum_cum;
668
    }
669
670
    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
671
#endif
672
0
}
673
674
0
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
675
0
    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
676
0
    auto * result = llama_sampler_init_dist(ctx->seed);
677
678
    // copy the state
679
0
    {
680
0
        auto * result_ctx = (llama_sampler_dist *) result->ctx;
681
682
0
        result_ctx->rng = ctx->rng;
683
0
    }
684
685
0
    return result;
686
0
}
687
688
0
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
689
0
    auto * ctx = (llama_sampler_dist *) smpl->ctx;
690
0
    ctx->seed_cur = get_rng_seed(ctx->seed);
691
0
    ctx->rng.seed(ctx->seed_cur);
692
0
}
693
694
0
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
695
0
    delete (llama_sampler_dist *) smpl->ctx;
696
0
}
697
698
static struct llama_sampler_i llama_sampler_dist_i = {
699
    /* .name   = */ llama_sampler_dist_name,
700
    /* .accept = */ nullptr,
701
    /* .apply  = */ llama_sampler_dist_apply,
702
    /* .reset  = */ llama_sampler_dist_reset,
703
    /* .clone  = */ llama_sampler_dist_clone,
704
    /* .free   = */ llama_sampler_dist_free,
705
};
706
707
0
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
708
0
    auto seed_cur = get_rng_seed(seed);
709
0
    return llama_sampler_init(
710
0
        /* .iface = */ &llama_sampler_dist_i,
711
0
        /* .ctx   = */ new llama_sampler_dist {
712
0
            /* .seed     = */ seed,
713
0
            /* .seed_cur = */ seed_cur,
714
0
            /* .rng      = */ std::mt19937(seed_cur),
715
0
        }
716
0
    );
717
0
}
718
719
// top-k
720
721
struct llama_sampler_top_k {
722
    const int32_t k;
723
};
724
725
0
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
726
0
    return "top-k";
727
0
}
728
729
0
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
730
0
    auto * ctx = (llama_sampler_top_k *) smpl->ctx;
731
0
    llama_sampler_top_k_impl(cur_p, ctx->k);
732
0
}
733
734
0
static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
735
0
    const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
736
0
    return llama_sampler_init_top_k(ctx->k);
737
0
}
738
739
0
static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
740
0
    delete (llama_sampler_top_k *) smpl->ctx;
741
0
}
742
743
static struct llama_sampler_i llama_sampler_top_k_i = {
744
    /* .name   = */ llama_sampler_top_k_name,
745
    /* .accept = */ nullptr,
746
    /* .apply  = */ llama_sampler_top_k_apply,
747
    /* .reset  = */ nullptr,
748
    /* .clone  = */ llama_sampler_top_k_clone,
749
    /* .free   = */ llama_sampler_top_k_free,
750
};
751
752
0
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
753
0
    return llama_sampler_init(
754
0
        /* .iface = */ &llama_sampler_top_k_i,
755
0
        /* .ctx   = */ new llama_sampler_top_k {
756
0
            /* .k = */ k,
757
0
        }
758
0
    );
759
0
}
760
761
// top-p
762
763
struct llama_sampler_top_p {
764
    const float  p;
765
    const size_t min_keep;
766
767
    std::vector<llama_token_data> buf_sort;
768
};
769
770
0
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
771
0
    return "top-p";
772
0
}
773
774
0
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
775
0
    auto * ctx = (llama_sampler_top_p *) smpl->ctx;
776
777
0
    if (ctx->p >= 1.0f) {
778
0
        return;
779
0
    }
780
781
0
    llama_sampler_softmax_impl(cur_p, false);
782
783
0
    size_t k = cur_p->size;
784
0
    auto * pdata = cur_p->data;
785
786
0
    auto & buf_sort = ctx->buf_sort;
787
788
    // if not sorted, try adaptive top-k sorting
789
0
    if (!cur_p->sorted && cur_p->size > 1024) {
790
0
        k = std::min<size_t>(256, cur_p->size);
791
0
        llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
792
0
        pdata = buf_sort.data();
793
0
    } else if (!cur_p->sorted) {
794
        // small candidates -> sort inplace
795
0
        llama_token_data_array_partial_sort_inplace(cur_p, k);
796
0
    }
797
798
    // Compute the cumulative probabilities
799
0
    float cum_sum = 0.0f;
800
0
    size_t last_idx = cur_p->size;
801
802
0
    for (size_t i = 0; i < cur_p->size; ++i) {
803
0
        cum_sum += pdata[i].p;
804
805
        // Check if the running sum is at least p or if we have kept at least min_keep tokens
806
        // we set the last index to i+1 to indicate that the current iterate should be included in the set
807
0
        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
808
0
            last_idx = i + 1;
809
0
            break;
810
0
        }
811
812
        // we exceeded the current top-k heuristic -> increase k and continue
813
0
        if (!cur_p->sorted && i == k - 1) {
814
0
            k = cur_p->size;
815
0
            llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
816
0
            pdata = buf_sort.data();
817
0
        }
818
0
    }
819
820
    // Resize the output vector to keep only the top-p tokens
821
0
    if (!cur_p->sorted) {
822
0
        std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
823
0
        cur_p->sorted = true;
824
0
    }
825
826
0
    cur_p->size = last_idx;
827
0
}
828
829
0
static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
830
0
    const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
831
0
    return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
832
0
}
833
834
0
static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
835
0
    delete (llama_sampler_top_p *) smpl->ctx;
836
0
}
837
838
static struct llama_sampler_i llama_sampler_top_p_i = {
839
    /* .name   = */ llama_sampler_top_p_name,
840
    /* .accept = */ nullptr,
841
    /* .apply  = */ llama_sampler_top_p_apply,
842
    /* .reset  = */ nullptr,
843
    /* .clone  = */ llama_sampler_top_p_clone,
844
    /* .free   = */ llama_sampler_top_p_free,
845
};
846
847
0
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
848
0
    return llama_sampler_init(
849
0
        /* .iface = */ &llama_sampler_top_p_i,
850
0
        /* .ctx   = */ new llama_sampler_top_p {
851
0
            /* .p        = */ p,
852
0
            /* .min_keep = */ min_keep,
853
0
            /* .buf_sort = */ {},
854
0
        }
855
0
    );
856
0
}
857
858
// min-p
859
860
struct llama_sampler_min_p {
861
    const float  p;
862
    const size_t min_keep;
863
};
864
865
0
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
866
0
    return "min-p";
867
0
}
868
869
0
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
870
0
    auto * ctx = (llama_sampler_min_p *) smpl->ctx;
871
872
0
    if (ctx->p <= 0.0f || !cur_p->size) {
873
0
        return;
874
0
    }
875
876
0
    bool min_p_applied = false;
877
878
    // if the cur_p aren't sorted, try the unsorted implementation first
879
0
    if (!cur_p->sorted) {
880
0
        std::vector<llama_token_data> filtered_tokens;
881
882
0
        float max_logit = -FLT_MAX;
883
0
        for (size_t i = 0; i < cur_p->size; ++i) {
884
0
            max_logit = std::max(max_logit, cur_p->data[i].logit);
885
0
        }
886
0
        const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
887
888
0
        for (size_t i = 0; i < cur_p->size; ++i) {
889
0
            if (cur_p->data[i].logit >= min_logit) {
890
0
                filtered_tokens.push_back(cur_p->data[i]);
891
0
            }
892
0
        }
893
894
        // if we have enough values the operation was a success
895
0
        if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
896
0
            std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
897
0
            cur_p->size = filtered_tokens.size();
898
0
            min_p_applied = true;
899
0
        }
900
0
    }
901
902
    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
903
0
    if (!min_p_applied) {
904
        // Sort the logits in descending order
905
0
        if (!cur_p->sorted) {
906
0
            llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
907
0
        }
908
909
0
        const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
910
0
        size_t i = 1; // first token always matches
911
912
0
        for (; i < cur_p->size; ++i) {
913
0
            if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
914
0
                break; // prob too small
915
0
            }
916
0
        }
917
918
        // Resize the output vector to keep only the matching tokens
919
0
        cur_p->size = i;
920
0
    }
921
0
}
922
923
0
static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
924
0
    const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
925
0
    return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
926
0
}
927
928
0
static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
929
0
    delete (llama_sampler_min_p *) smpl->ctx;
930
0
}
931
932
static struct llama_sampler_i llama_sampler_min_p_i = {
933
    /* .name   = */ llama_sampler_min_p_name,
934
    /* .accept = */ nullptr,
935
    /* .apply  = */ llama_sampler_min_p_apply,
936
    /* .reset  = */ nullptr,
937
    /* .clone  = */ llama_sampler_min_p_clone,
938
    /* .free   = */ llama_sampler_min_p_free,
939
};
940
941
0
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
942
0
    return llama_sampler_init(
943
0
        /* .iface = */ &llama_sampler_min_p_i,
944
0
        /* .ctx   = */ new llama_sampler_min_p {
945
0
            /* .p        = */ p,
946
0
            /* .min_keep = */ min_keep,
947
0
        }
948
0
    );
949
0
}
950
951
// typical
952
953
struct llama_sampler_typical {
954
    const float  p;
955
    const size_t min_keep;
956
};
957
958
0
static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
959
0
    return "typical";
960
0
}
961
962
0
// Locally-typical sampling: keep the candidates whose surprisal (-log p) is
// closest to the entropy of the whole distribution, accumulating them in order
// of "typicality" until their cumulative probability exceeds ctx->p (subject
// to the min_keep lower bound). The surviving tokens are written back into
// cur_p in typicality order and cur_p->sorted is cleared.
static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_typical *) smpl->ctx;

    // Reference implementation:
    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
    if (ctx->p >= 1.0f) {
        return; // a threshold of 1 (or more) disables the filter entirely
    }

    // Compute the softmax of logits and calculate entropy
    llama_sampler_softmax_impl(cur_p, true);

    float entropy = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
    }

    // Compute the absolute difference between negative log probability and entropy for each candidate
    std::vector<float> shifted_scores;
    for (size_t i = 0; i < cur_p->size; ++i) {
        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
        shifted_scores.push_back(shifted_score);
    }

    // Sort tokens based on the shifted_scores and their corresponding indices
    // (indirect sort so cur_p->data can be gathered afterwards)
    std::vector<size_t> indices(cur_p->size);
    std::iota(indices.begin(), indices.end(), 0);

    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
        return shifted_scores[a] < shifted_scores[b];
    });

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = indices.size();

    for (size_t i = 0; i < indices.size(); ++i) {
        size_t idx = indices[i];
        cum_sum += cur_p->data[idx].p;

        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
        if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
            last_idx = i + 1;
            break;
        }
    }

    // Resize the output vector to keep only the locally typical tokens
    std::vector<llama_token_data> cur_p_new;
    for (size_t i = 0; i < last_idx; ++i) {
        size_t idx = indices[i];
        cur_p_new.push_back(cur_p->data[idx]);
    }

    // Replace the data in cur_p with the cur_p_new data
    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
    cur_p->size = cur_p_new.size();
    cur_p->sorted = false; // tokens are now ordered by typicality, not by logit
}
1021
1022
0
static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
1023
0
    const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
1024
0
    return llama_sampler_init_typical(ctx->p, ctx->min_keep);
1025
0
}
1026
1027
0
static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1028
0
    delete (llama_sampler_typical *) smpl->ctx;
1029
0
}
1030
1031
// vtable for the typical sampler; stateless between tokens, so no accept/reset
static struct llama_sampler_i llama_sampler_typical_i = {
    /* .name   = */ llama_sampler_typical_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_typical_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_typical_clone,
    /* .free   = */ llama_sampler_typical_free,
};
1039
1040
0
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1041
0
    return llama_sampler_init(
1042
0
        /* .iface = */ &llama_sampler_typical_i,
1043
0
        /* .ctx   = */ new llama_sampler_typical {
1044
0
            /* .p        = */ p,
1045
0
            /* .min_keep = */ min_keep,
1046
0
        }
1047
0
    );
1048
0
}
1049
1050
// temp
1051
1052
// state for the plain temperature sampler
struct llama_sampler_temp {
    const float temp; // temperature passed to llama_sampler_temp_impl on every apply
};
1055
1056
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "temp";
    return name;
}
1059
1060
0
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1061
0
    const auto * ctx = (llama_sampler_temp *) smpl->ctx;
1062
1063
0
    llama_sampler_temp_impl(cur_p, ctx->temp);
1064
0
}
1065
1066
0
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
1067
0
    const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
1068
0
    return llama_sampler_init_temp(ctx->temp);
1069
0
}
1070
1071
0
static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1072
0
    delete (llama_sampler_temp *) smpl->ctx;
1073
0
}
1074
1075
// vtable for the fixed-temperature sampler; stateless, so no accept/reset
static struct llama_sampler_i llama_sampler_temp_i = {
    /* .name   = */ llama_sampler_temp_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_temp_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_temp_clone,
    /* .free   = */ llama_sampler_temp_free,
};
1083
1084
0
struct llama_sampler * llama_sampler_init_temp(float temp) {
1085
0
    return llama_sampler_init(
1086
0
        /* .iface = */ &llama_sampler_temp_i,
1087
0
        /* .ctx   = */ new llama_sampler_temp {
1088
0
            /*.temp = */ temp,
1089
0
        }
1090
0
    );
1091
0
}
1092
1093
// temp-ext
1094
1095
// state for the extended (entropy-adaptive) temperature sampler
struct llama_sampler_temp_ext {
    const float temp;     // base temperature
    const float delta;    // half-width of the dynamic range [temp-delta, temp+delta]; <= 0 disables the dynamic path
    const float exponent; // exponent applied to the normalized entropy when mapping it to a temperature
};
1100
1101
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "temp-ext";
    return name;
}
1104
1105
0
// Entropy-adaptive temperature scaling: when delta > 0, the effective
// temperature is interpolated inside [temp - delta, temp + delta] based on how
// close the distribution's entropy is to its maximum (uniform) entropy,
// shaped by `exponent`. When delta <= 0 this degenerates to plain temperature
// scaling. Probabilities are re-normalized (in double precision) after the
// dynamic scaling.
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
    if (ctx->delta > 0) {
        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
        const float max_temp = ctx->temp + ctx->delta;

        float exponent_val = ctx->exponent;

        // no need to do anything if there is only one (or zero) candidates
        if (cur_p->size <= 1) {
            return;
        }

        // Calculate maximum possible entropy (entropy of the uniform distribution)
        float max_entropy = -logf(1.0f / cur_p->size);

        llama_sampler_softmax_impl(cur_p, true);

        // Calculate entropy of the softmax probabilities
        float entropy = 0.0f;
        for (size_t i = 0; i < cur_p->size; ++i) {
            float prob = cur_p->data[i].p;
            if (prob > 0.0f) { // Ensure no log(0)
                entropy -= prob * logf(prob);
            }
        }

        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
        float normalized_entropy = entropy / max_entropy;

        // Map the normalized entropy to the desired temperature range using the power function
        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);

    #ifdef DEBUG
        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
    #endif

        // Apply the dynamically calculated temperature scaling
        llama_sampler_temp_impl(cur_p, dyn_temp);

        // Re-compute softmax probabilities after scaling logits with dynamic temperature
        // (double precision to reduce accumulation error in the normalizer)
        const double max_l_double = cur_p->data[0].logit;

        double cum_sum_double = 0.0;
        for (size_t i = 0; i < cur_p->size; ++i) {
            double p = exp(cur_p->data[i].logit - max_l_double);
            cur_p->data[i].p = p; // Store the scaled probability
            cum_sum_double += p;
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
        }

    #ifdef DEBUG
        // Print the updated top 25 probabilities after temperature scaling
        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
        }
    #endif
    } else {
        llama_sampler_temp_impl(cur_p, ctx->temp);
    }
}
1175
1176
0
static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1177
0
    const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1178
0
    return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
1179
0
}
1180
1181
0
static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1182
0
    delete (llama_sampler_temp_ext *) smpl->ctx;
1183
0
}
1184
1185
// vtable for the extended-temperature sampler; stateless, so no accept/reset
static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name   = */ llama_sampler_temp_ext_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_temp_ext_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_temp_ext_clone,
    /* .free   = */ llama_sampler_temp_ext_free,
};
1193
1194
0
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1195
0
    return llama_sampler_init(
1196
0
        /* .iface = */ &llama_sampler_temp_ext_i,
1197
0
        /* .ctx   = */ new llama_sampler_temp_ext {
1198
0
            /* .temp     = */ temp,
1199
0
            /* .delta    = */ delta,
1200
0
            /* .exponent = */ exponent,
1201
0
        }
1202
0
    );
1203
0
}
1204
1205
// xtc
1206
1207
// state for the XTC ("exclude top choices") sampler
struct llama_sampler_xtc {
    const float    probability; // chance per apply() call that the truncation is performed
    const float    threshold;   // tokens with p >= threshold are candidates for removal
    const size_t   min_keep;    // lower bound on the number of surviving tokens

    const uint32_t seed;        // user-provided seed
    uint32_t       seed_cur;    // resolved seed currently loaded into rng

    std::mt19937    rng;        // drives the per-call probability roll
};
1217
1218
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "xtc";
    return name;
}
1221
1222
0
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1223
0
    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1224
1225
0
    if (ctx->probability <= 0.0f
1226
0
        || ctx->threshold > 0.5f
1227
0
        || cur_p->size < 2) {
1228
0
        return;
1229
0
    }
1230
1231
0
    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1232
0
    float chance = distribution(ctx->rng);
1233
0
    if (chance > ctx->probability) {
1234
0
        return;
1235
0
    }
1236
1237
0
    llama_sampler_softmax_impl(cur_p, true);
1238
1239
0
    int pos_last = 0;
1240
1241
0
    for (size_t i = 0; i < cur_p->size; ++i) {
1242
0
        if (cur_p->data[i].p >= ctx->threshold) {
1243
0
            pos_last = i;
1244
0
        } else {
1245
0
            break;
1246
0
        }
1247
0
    }
1248
1249
0
    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1250
0
        cur_p->data += pos_last;
1251
0
        cur_p->size -= pos_last;
1252
0
    }
1253
0
}
1254
1255
0
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1256
0
    const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1257
0
    auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1258
1259
    // copy the state
1260
0
    {
1261
0
        auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1262
1263
0
        result_ctx->rng = ctx->rng;
1264
0
    }
1265
1266
0
    return result;
1267
0
}
1268
1269
0
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1270
0
    delete (llama_sampler_xtc *) smpl->ctx;
1271
0
}
1272
1273
0
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1274
0
    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1275
0
    ctx->seed_cur = get_rng_seed(ctx->seed);
1276
0
    ctx->rng.seed(ctx->seed_cur);
1277
0
}
1278
1279
// vtable for the XTC sampler; holds an RNG, so reset is needed but accept is not
static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name   = */ llama_sampler_xtc_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sample_xtc_apply,
    /* .reset  = */ llama_sampler_xtc_reset,
    /* .clone  = */ llama_sampler_xtc_clone,
    /* .free   = */ llama_sampler_xtc_free,
};
1287
1288
0
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1289
0
    auto seed_cur = get_rng_seed(seed);
1290
0
    return llama_sampler_init(
1291
0
        /* .iface = */ &llama_sampler_xtc_i,
1292
0
        /* .ctx   = */ new llama_sampler_xtc {
1293
0
            /* .probability   = */ p,
1294
0
            /* .threshold     = */ t,
1295
0
            /* .min_keep      = */ min_keep,
1296
0
            /* .seed          = */ seed,
1297
0
            /* .seed_cur      = */ seed_cur,
1298
0
            /* .rng           = */ std::mt19937(seed_cur),
1299
0
        }
1300
0
    );
1301
0
}
1302
1303
// mirostat
1304
1305
// state for the mirostat (v1) sampler
struct llama_sampler_mirostat {
    const int32_t n_vocab;   // vocabulary size, used in the k estimation formula

    const uint32_t seed;     // user-provided seed
          uint32_t seed_cur; // resolved seed currently loaded into rng

    const float tau;         // target surprise (cross-entropy)
    const float eta;         // learning rate for the mu update

    const int32_t m;         // number of top tokens used to estimate s_hat

    float mu;                // dynamic truncation parameter, initialized to 2*tau

    std::mt19937    rng;     // drives the final token draw
};
1320
1321
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "mirostat";
    return name;
}
1324
1325
0
// Mirostat (v1) sampling: estimate the Zipf exponent s_hat from the top-m
// tokens, derive a top-k cutoff from s_hat and the current mu, draw a token
// from the truncated distribution, then adjust mu toward the target surprise
// tau with learning rate eta.
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, true);

    // Estimate s_hat using the most probable m tokens
    float s_hat = 0.0;
    float sum_ti_bi = 0.0;
    float sum_ti_sq = 0.0;
    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
        float t_i = logf(float(i + 2) / float(i + 1));
        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
        sum_ti_bi += t_i * b_i;
        sum_ti_sq += t_i * t_i;
    }
    // NOTE(review): if cur_p->size < 2 (or m < 2) the loop above runs zero
    // iterations and this is 0/0 = NaN; presumably callers always supply
    // multiple candidates — confirm.
    s_hat = sum_ti_bi / sum_ti_sq;

    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);

    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));

    // re-normalize after the top-k truncation
    llama_sampler_softmax_impl(cur_p, true);

    const int idx = llama_sample_dist(cur_p, ctx->rng);

    cur_p->selected = idx;

    float observed_surprise = -log2f(cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
1360
1361
0
static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1362
0
    const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1363
0
    auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
1364
1365
    // copy the state
1366
0
    {
1367
0
        auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
1368
1369
0
        result_ctx->mu  = ctx->mu;
1370
0
        result_ctx->rng = ctx->rng;
1371
0
    }
1372
1373
0
    return result;
1374
0
}
1375
1376
0
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1377
0
    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1378
0
    ctx->mu = 2.0f*ctx->tau;
1379
0
    ctx->seed_cur = get_rng_seed(ctx->seed);
1380
0
    ctx->rng.seed(ctx->seed_cur);
1381
0
}
1382
1383
0
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1384
0
    delete (llama_sampler_mirostat *) smpl->ctx;
1385
0
}
1386
1387
// vtable for the mirostat (v1) sampler; mu/rng are mutable, so reset is provided
static struct llama_sampler_i llama_sampler_mirostat_i = {
    /* .name   = */ llama_sampler_mirostat_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_mirostat_apply,
    /* .reset  = */ llama_sampler_mirostat_reset,
    /* .clone  = */ llama_sampler_mirostat_clone,
    /* .free   = */ llama_sampler_mirostat_free,
};
1395
1396
0
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1397
0
    auto seed_cur = get_rng_seed(seed);
1398
0
    return llama_sampler_init(
1399
0
        /* .iface = */ &llama_sampler_mirostat_i,
1400
0
        /* .ctx   = */ new llama_sampler_mirostat {
1401
0
            /* .n_vocab  = */ n_vocab,
1402
0
            /* .seed     = */ seed,
1403
0
            /* .seed_cur = */ seed_cur,
1404
0
            /* .tau      = */ tau,
1405
0
            /* .eta      = */ eta,
1406
0
            /* .m        = */ m,
1407
0
            /* .mu       = */ 2.0f*tau,
1408
0
            /* .rng      = */ std::mt19937(seed_cur),
1409
0
        }
1410
0
    );
1411
0
}
1412
1413
// mirostat v2
1414
1415
// state for the mirostat v2 sampler
struct llama_sampler_mirostat_v2 {
    const uint32_t seed;     // user-provided seed
          uint32_t seed_cur; // resolved seed currently loaded into rng

    const float tau;         // target surprise (cross-entropy)
    const float eta;         // learning rate for the mu update

    float mu;                // dynamic surprise cutoff, initialized to 2*tau

    std::mt19937 rng;        // drives the final token draw
};
1426
1427
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "mirostat-v2";
    return name;
}
1430
1431
0
// Mirostat v2 sampling: truncate all tokens whose surprisal (-log2 p) exceeds
// the current mu, re-normalize, draw a token, then adjust mu toward the target
// surprise tau with learning rate eta.
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, true);

    // Truncate the words with surprise values greater than mu
    // (candidates are sorted by probability, so the first failing token marks the cutoff)
    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
        return -log2f(candidate.p) > ctx->mu;
    }));

    // always keep at least the most probable token
    if (cur_p->size == 0) {
        cur_p->size = 1;
    }

    // Normalize the probabilities of the remaining words
    llama_sampler_softmax_impl(cur_p, true);

    const int idx = llama_sample_dist(cur_p, ctx->rng);

    cur_p->selected = idx;

    float observed_surprise = -log2f(cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
1458
1459
0
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1460
0
    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1461
0
    ctx->mu = 2.0f*ctx->tau;
1462
0
    ctx->seed_cur = get_rng_seed(ctx->seed);
1463
0
    ctx->rng.seed(ctx->seed_cur);
1464
0
}
1465
1466
0
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
1467
0
    const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
1468
1469
0
    auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
1470
1471
    // copy the state
1472
0
    {
1473
0
        auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
1474
1475
0
        result_ctx->mu  = ctx->mu;
1476
0
        result_ctx->rng = ctx->rng;
1477
0
    }
1478
1479
0
    return result;
1480
0
}
1481
1482
0
static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1483
0
    delete (llama_sampler_mirostat_v2 *) smpl->ctx;
1484
0
}
1485
1486
// vtable for the mirostat v2 sampler; mu/rng are mutable, so reset is provided
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name   = */ llama_sampler_mirostat_v2_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_mirostat_v2_apply,
    /* .reset  = */ llama_sampler_mirostat_v2_reset,
    /* .clone  = */ llama_sampler_mirostat_v2_clone,
    /* .free   = */ llama_sampler_mirostat_v2_free,
};
1494
1495
0
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1496
0
    auto seed_cur = get_rng_seed(seed);
1497
0
    return llama_sampler_init(
1498
0
        /* .iface = */ &llama_sampler_mirostat_v2_i,
1499
0
        /* .ctx   = */ new llama_sampler_mirostat_v2 {
1500
0
            /* .seed     = */ seed,
1501
0
            /* .seed_cur = */ seed_cur,
1502
0
            /* .tau      = */ tau,
1503
0
            /* .eta      = */ eta,
1504
0
            /* .mu       = */ 2.0f*tau,
1505
0
            /* .rng      = */ std::mt19937(seed_cur),
1506
0
        }
1507
0
    );
1508
0
}
1509
1510
// grammar
1511
1512
// state for the grammar-constrained sampler
struct llama_sampler_grammar {
    const struct llama_vocab * vocab;  // vocab used to (re)build the grammar

    std::string grammar_str;           // grammar definition text, kept for reset/clone
    std::string grammar_root;          // root rule name, kept for reset/clone

    struct llama_grammar * grammar;    // owned parse state; nullptr when no grammar was given
};
1520
1521
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "grammar";
    return name;
}
1524
1525
0
static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
1526
0
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1527
0
    if (ctx->grammar) {
1528
0
        llama_grammar_accept_impl(*ctx->grammar, token);
1529
0
    }
1530
0
}
1531
1532
0
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1533
0
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1534
0
    if (ctx->grammar) {
1535
0
        llama_grammar_apply_impl(*ctx->grammar, cur_p);
1536
0
    }
1537
0
}
1538
1539
// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
1540
static struct llama_sampler * llama_sampler_init_grammar_impl(
1541
        const struct llama_vocab * vocab,
1542
                      const char * grammar_str,
1543
                      const char * grammar_root,
1544
                              bool lazy,
1545
                     const char ** trigger_words,
1546
                            size_t num_trigger_words,
1547
               const llama_token * trigger_tokens,
1548
                            size_t num_trigger_tokens,
1549
                     const char ** trigger_patterns,
1550
                            size_t num_trigger_patterns);
1551
1552
0
// Reset the grammar to its initial parse state by rebuilding it from the
// stored grammar string/root, preserving the lazy flag and the trigger
// configuration of the current grammar object.
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (!ctx->grammar) {
        return; // sampler was created without a grammar — nothing to reset
    }

    // collect the current trigger patterns as C strings for re-initialization
    std::vector<const char *>  trigger_patterns_c;
    trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
    for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
        trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
    }

    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
                                                 ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                 ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());

    // swap in the fresh grammar only after it has been built
    llama_grammar_free_impl(ctx->grammar);
    ctx->grammar = grammar_new;
}
1571
1572
0
// Clone a grammar sampler: create an empty grammar sampler for the same vocab,
// then deep-copy the grammar string/root and the parse state when a grammar is
// present.
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

    // init with an empty grammar (cannot fail), state is copied below
    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
    GGML_ASSERT(result);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_grammar *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_str  = ctx->grammar_str;
            result_ctx->grammar_root = ctx->grammar_root;

            // deep-copy the grammar, including its current parse position
            result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
        }
    }

    return result;
}
1592
1593
0
static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1594
0
    const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1595
1596
0
    if (ctx->grammar) {
1597
0
        llama_grammar_free_impl(ctx->grammar);
1598
0
    }
1599
1600
0
    delete ctx;
1601
0
}
1602
1603
// vtable for the grammar sampler; accept advances the parse state, reset
// rebuilds the grammar from its stored definition
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name   = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply  = */ llama_sampler_grammar_apply,
    /* .reset  = */ llama_sampler_grammar_reset,
    /* .clone  = */ llama_sampler_grammar_clone,
    /* .free   = */ llama_sampler_grammar_free,
};
1611
1612
// Shared constructor for all grammar samplers.
//
// When grammar_str is non-empty, the grammar is built immediately; returns
// nullptr if grammar construction fails. Legacy trigger_words are regex-escaped
// and folded into a single "[\s\S]*?(w1|w2|...)[\s\S]*" trigger pattern;
// trigger_words and trigger_patterns are mutually exclusive. When grammar_str
// is empty, a pass-through sampler with a null grammar is created (used by
// clone as a shell to copy state into).
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns) {
    auto * ctx = new llama_sampler_grammar;

    if (grammar_str != nullptr && grammar_str[0] != '\0') {
        std::string trigger_pattern;
        llama_grammar * grammar = nullptr;
        // TODO: remove trigger_words support.
        if (trigger_words != nullptr && num_trigger_words > 0) {
            GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
            trigger_pattern = "[\\s\\S]*?(";
            for (size_t i = 0; i < num_trigger_words; ++i) {
                // escape regex metacharacters so each word matches literally
                static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
                if (i > 0) {
                    trigger_pattern += "|";
                }
                trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
            }
            trigger_pattern += ")[\\s\\S]*";

            std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
        } else {
            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
        }
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_str  = */ grammar_str,
            /* .grammar_root = */ grammar_root,
            /* .grammar      = */ grammar,
        };
        // grammar construction failed (e.g. parse error) — signal with nullptr
        if (!ctx->grammar) {
            delete ctx;
            return nullptr;
        }
    } else {
        // no grammar text: create a pass-through sampler with a null grammar
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_str  = */ {},
            /* .grammar_root = */ {},
            /* .grammar      = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_grammar_i,
        /* .ctx   = */ ctx
    );
}
1670
1671
// Create an eager (non-lazy) grammar sampler with no triggers: the grammar
// constrains tokens from the very first sample. Returns nullptr if the grammar
// fails to build.
struct llama_sampler * llama_sampler_init_grammar(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
}
1677
1678
// Create a lazy grammar sampler triggered by literal words (legacy API; the
// words are converted to a single trigger pattern internally). Returns nullptr
// if the grammar fails to build.
struct llama_sampler * llama_sampler_init_grammar_lazy(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
}
1688
1689
// Create a lazy grammar sampler triggered by regex patterns and/or specific
// tokens. Returns nullptr if the grammar fails to build.
struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
}
1699
1700
// penalties
1701
1702
// state for the repetition/frequency/presence penalty sampler
struct llama_sampler_penalties {
    const int32_t penalty_last_n;  // size of the sliding window of recent tokens (0 disables)
    const float   penalty_repeat;  // multiplicative repetition penalty
    const float   penalty_freq;    // per-occurrence (frequency) penalty
    const float   penalty_present; // flat (presence) penalty for any occurrence

    ring_buffer<llama_token> prev; // last penalty_last_n accepted tokens

    // a frequency map to count token occurrences; kept in sync with `prev`
    std::unordered_map<llama_token, int> token_count;
};
1713
1714
0
// Identifier reported through the sampler interface.
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "penalties";
    return name;
}
1717
1718
0
// Record an accepted token in the sliding penalty window: bump its count in
// the frequency map and push it into the fixed-capacity ring buffer, evicting
// (and down-counting) the oldest token when the window is full.
static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
    if (ctx->penalty_last_n == 0) {
        return; // penalties disabled — no history tracking
    }

    ctx->token_count[token]++;

    // if the ring buffer is full, remove the oldest token
    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
        const auto old = ctx->prev.front();

        ctx->token_count[old]--;
        if (ctx->token_count[old] == 0) {
            ctx->token_count.erase(old); // keep the map sparse so apply() iterates less
        }
    }

    ctx->prev.push_back(token);

#if 0
    // sanity check
    std::unordered_map<llama_token, int> tmp;
    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
        tmp[ctx->prev.rat(i)]++;
    }

    assert(ctx->token_count == tmp);
#endif
}
1748
1749
0
// Penalize candidates that appeared in the recent-token window: a
// multiplicative repetition penalty on the logit plus linear frequency and
// presence penalties scaled by the occurrence count. Clears cur_p->sorted
// since logits may change order.
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;

    // all penalties at their neutral values -> nothing to do
    if ((ctx->penalty_last_n == 0) ||
        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
        return;
    }

    // Apply frequency and presence penalties to the cur_p
    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
        if (token_iter == ctx->token_count.end()) {
            continue; // token not in the recent window — no penalty
        }

        const int count = token_iter->second;

        assert(count > 0 && count <= ctx->penalty_last_n);

        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
        if (cur_p->data[i].logit <= 0) {
            cur_p->data[i].logit *= ctx->penalty_repeat;
        } else {
            cur_p->data[i].logit /= ctx->penalty_repeat;
        }

        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
    }

    cur_p->sorted = false; // logits changed, any previous ordering is stale
}
1781
1782
0
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1783
0
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1784
0
    ctx->prev.clear();
1785
0
    ctx->token_count.clear();
1786
0
}
1787
1788
0
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1789
0
    const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1790
0
    auto * result = llama_sampler_init_penalties(
1791
0
            ctx->penalty_last_n,
1792
0
            ctx->penalty_repeat,
1793
0
            ctx->penalty_freq,
1794
0
            ctx->penalty_present);
1795
1796
    // copy the state
1797
0
    {
1798
0
        auto * result_ctx = (llama_sampler_penalties *) result->ctx;
1799
1800
0
        result_ctx->prev = ctx->prev;
1801
0
    }
1802
1803
0
    return result;
1804
0
}
1805
1806
0
static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1807
0
    delete (llama_sampler_penalties *) smpl->ctx;
1808
0
}
1809
1810
// vtable for the repeat/frequency/presence penalties sampler
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name   = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
    /* .apply  = */ llama_sampler_penalties_apply,
    /* .reset  = */ llama_sampler_penalties_reset,
    /* .clone  = */ llama_sampler_penalties_clone,
    /* .free   = */ llama_sampler_penalties_free,
};
1818
1819
struct llama_sampler * llama_sampler_init_penalties(
1820
        int32_t penalty_last_n,
1821
        float penalty_repeat,
1822
        float penalty_freq,
1823
0
        float penalty_present) {
1824
0
    penalty_last_n = std::max(penalty_last_n, 0);
1825
1826
0
    return llama_sampler_init(
1827
0
        /* .iface = */ &llama_sampler_penalties_i,
1828
0
        /* .ctx   = */ new llama_sampler_penalties {
1829
0
            /* .penalty_last_n  = */ penalty_last_n,
1830
0
            /* .penalty_repeat  = */ penalty_repeat,
1831
0
            /* .penalty_freq    = */ penalty_freq,
1832
0
            /* .penalty_present = */ penalty_present,
1833
0
            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
1834
0
            /* .token_count     = */ {},
1835
0
        }
1836
0
    );
1837
0
}
1838
1839
// top-n-sigma

// state for the top-n-sigma sampler
struct llama_sampler_top_n_sigma {
    // threshold in standard deviations below the max logit; values <= 0 disable the sampler
    const float n;
};
1844
1845
// sampler interface: human-readable name
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    return "top-n-sigma";
}
1848
1849
0
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1850
0
    auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1851
1852
0
    if (ctx->n <= 0.0f || cur_p->size <= 1) {
1853
0
        return;
1854
0
    }
1855
1856
    // find max logit and calculate mean
1857
0
    float max = cur_p->data[0].logit;
1858
0
    float logits_sum = 0;
1859
0
    size_t valid_count = 0;
1860
0
    for (size_t i = 0; i < cur_p->size; ++i) {
1861
        // Only count non-negative infinity values
1862
0
        if (cur_p->data[i].logit != -INFINITY) {
1863
0
            if (cur_p->data[i].logit > max) {
1864
0
                max = cur_p->data[i].logit;
1865
0
            }
1866
0
            logits_sum += cur_p->data[i].logit;
1867
0
            valid_count++;
1868
0
        }
1869
0
    }
1870
0
    float mean = valid_count > 0 ? logits_sum/valid_count : 0;
1871
1872
    // calculate standard deviation
1873
0
    float acc = 0;
1874
0
    for (size_t i = 0; i < cur_p->size; ++i) {
1875
        // Skip -infinity in std calculation
1876
0
        if (cur_p->data[i].logit != -INFINITY) {
1877
0
            acc += pow(cur_p->data[i].logit - mean, 2);
1878
0
        }
1879
0
    }
1880
0
    float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
1881
1882
    // apply mask
1883
0
    for (size_t i = 0; i < cur_p->size; ++i) {
1884
0
        if (cur_p->data[i].logit < max - (ctx->n * std)) {
1885
0
            cur_p->data[i].logit = -INFINITY;
1886
0
        }
1887
0
    }
1888
1889
0
    llama_sampler_softmax_impl(cur_p, true);
1890
0
}
1891
1892
0
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
1893
0
    const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1894
0
    return llama_sampler_init_top_n_sigma(ctx->n);
1895
0
}
1896
1897
0
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1898
0
    delete (llama_sampler_top_n_sigma *) smpl->ctx;
1899
0
}
1900
1901
// vtable for the top-n-sigma sampler; it keeps no history, so accept/reset are unused
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
    /* .name   = */ llama_sampler_top_n_sigma_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_top_n_sigma_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_top_n_sigma_clone,
    /* .free   = */ llama_sampler_top_n_sigma_free,
};
1909
1910
0
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
1911
0
    return llama_sampler_init(
1912
0
        /* .iface = */ &llama_sampler_top_n_sigma_i,
1913
0
        /* .ctx   = */ new llama_sampler_top_n_sigma {
1914
0
            /* .n = */ n,
1915
0
        }
1916
0
    );
1917
0
}
1918
1919
// DRY

// state for the DRY ("Don't Repeat Yourself") repetition-penalty sampler
struct llama_sampler_dry {
    int32_t total_context_size;   // training context size; upper bound for the penalty window

    const float   dry_multiplier;       // penalty scale; 0.0f disables DRY
    const float   dry_base;             // exponential base; values < 1.0f disable DRY
    const int32_t dry_allowed_length;   // repeats up to this length go unpenalized
    const int32_t dry_penalty_last_n;   // window of recent tokens to scan; -1 = whole context, 0 = disabled

    // restart ("sequence breaker") sequences, keyed by head token -> tail tokens
    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
    // scratch: per-position repeated-suffix lengths computed by the Z-algorithm pass
    std::vector<int> dry_repeat_count;
    // scratch: longest repeat that each candidate token would extend
    std::unordered_map<llama_token, int> dry_max_token_repeat;
    // recent context tokens, appended by the accept callback
    ring_buffer<llama_token> last_tokens;
};
1934
1935
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// For every vocabulary token, record it as a potential "head" of the breaker string `str`:
//  - if the token's text contains `str` whole, store it with an empty tail;
//  - if the token's text ends with a proper prefix of `str`, tokenize the remainder of
//    `str` and store that as the tail (truncated to max_tail_len tokens when >= 0).
// Duplicate (head, tail) pairs are not inserted twice.
// NOTE(review): assumes `str` is non-empty (str[0] is read unconditionally) — the visible
// callers filter out empty breakers before calling; confirm for any new caller.
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
        std::string word = vocab.detokenize({token_id}, true);
        if (word.find(str) != std::string::npos) {
            // the token contains the full breaker string -> head with empty tail
            token_sequences.emplace(token_id, std::vector<llama_token>());
        } else {
            size_t word_len = word.size();
            size_t str_len = str.size();
            // pos starts at SIZE_MAX on purpose: pos + 1 wraps to 0 on the first find()
            size_t pos = -1;
            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                // check whether word[pos..] matches a prefix of str up to the end of word
                bool match = true;
                size_t i;
                for (i = 1; i < str_len && i + pos < word_len; ++i) {
                    if (word[pos + i] != str[i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    // str[i..] is the part of the breaker not covered by this token
                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                        tokenization.resize(max_tail_len);
                    }

                    // Ensure we don't already have a duplicate matching tokenization
                    auto its = token_sequences.equal_range(token_id);
                    bool found = false;
                    for (auto it = its.first; it != its.second; ++it) {
                        if (tokenization == it->second) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        token_sequences.emplace(token_id, tokenization);
                    }
                }
            }
        }
    }
}
1977
1978
// sampler interface: human-readable name
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    return "dry";
}
1981
1982
0
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1983
0
    auto * ctx = (llama_sampler_dry *) smpl->ctx;
1984
0
    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1985
0
        return;
1986
0
    }
1987
1988
0
    ctx->last_tokens.push_back(token);
1989
0
}
1990
1991
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// Apply the DRY penalty: every candidate token that would extend a repeated suffix of
// the recent context is penalized by dry_multiplier * dry_base^(repeat_len - dry_allowed_length).
// NOTE(review): last_tokens.rat(i) appears to index from the most recent token backwards
// (reverse-at on the ring buffer defined earlier in this file) — confirm against ring_buffer.
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;

    // sampler disabled -> leave the candidates untouched
    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
        return;
    }

    // -1 means "use the whole training context" as the penalty window
    int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
    int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);

    // not enough context to contain a penalizable repeat
    if (last_n_repeat <= ctx->dry_allowed_length) {
        return;
    }

    ctx->dry_repeat_count.assign(last_n_repeat, 0);
    ctx->dry_max_token_repeat.clear();

    // Step 1: Look for restart sequences to limit the maximum repetition length.
    // Work backwards through the context looking for any token that begins a restart sequence.
    //
    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
    // sequences that together comprise a restart sequence. This allows us to quickly check
    // whether each token is the head of a complete sequence. Most restart sequences are actually
    // a single token, and for these the "tail" is an empty vector.
    //
    // If the token is a "head", test all restart sequences that begin with this token
    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
    // longest matching sequence (if any) is used to limit the maximum repetition length.
    //
    // Note that in the case of a short sequence contained in a longer one, this might fail to
    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
    //
    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
    // With clamping, this scan is O(N) in the context length.

    int rep_limit = last_n_repeat;
    for (int i = 0; i < last_n_repeat; ++i) {
        llama_token token = ctx->last_tokens.rat(i);
        auto its = ctx->dry_processed_breakers.equal_range(token);
        // equal_range yields (end, end) when the token is not a breaker head
        if (its.first == ctx->dry_processed_breakers.end()) {
            continue;
        }
        int longest_match = -1;
        for (auto it = its.first; it != its.second; ++it) {
            // Note that (*it) does not contain the head character, so seq_len will be
            // the restart sequence length minus 1.
            // In the common case of a single-token restart sequence, (*it) will be empty
            // and we will trivially match.
            int seq_len = (int)it->second.size();
            if (seq_len > longest_match && seq_len <= (int)i) {
                bool match = true;
                for (int offset = 0; offset < seq_len; ++offset) {
                    // The -1 when indexing `last_tokens` is because we already matched the head.
                    if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    longest_match = seq_len;
                }
            }
        }
        if (longest_match >= 0) {
            // We found a restart sequence starting `i` tokens from the end and continuing for
            // `longest_match` tokens.
            rep_limit = i - longest_match;
            break;
        }
    }
    if (rep_limit < ctx->dry_allowed_length) {
        return;
    }

    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
    //
    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
    //
    // The code below is adapted from the public domain implementation by the same author here:
    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
    //
    // Example:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //                    ^
    //   This `3` means that the last three tokens of the context (a b c) also appear here.
    //
    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
    // ensure that the inner while loops only examine each token in the context once as the outer
    // for loop iterates over the context.

    {
        const int last = last_n_repeat - 1;

        int rt = 0;
        int lt = 0;

        for (int k = 1; k < last_n_repeat; ++k) {
            if (k > rt) {
                // If k is outside the current Z-box, do naive computation.
                int n = 0;
                while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                    ++n;
                }
                ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                if (n > 0) {
                    lt = k;
                    rt = k + n - 1;
                }
            } else {
                // If k is inside the current Z-box, consider two cases.

                int p = k - lt; // Pair index.
                int right_part_len = rt - k + 1;

                if (ctx->dry_repeat_count[last - p] < right_part_len) {
                    // fully determined by the mirrored position - no rescanning needed
                    int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                } else {
                    // match may extend past the current Z-box - scan onward from rt + 1
                    int i = rt + 1;
                    while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
                        i += 1;
                    }

                    int n = std::min(i - k, rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                    lt = k;
                    rt = i - 1;
                }
            }
        }
    }

    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
    // that would be generated by emitting each new token that would extend a sequence.
    //
    // Following the same example as above:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //
    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
    // c: 3 -> 4 (from `a b c` to `a b c c`)
    // b: 1 -> 2 (from `c` to `c b`)
    // y: 2 -> 3 (from `b c` to `b c y`)

    for (int i = 0; i < last_n_repeat - 1; ++i) {
        int repeat_len = ctx->dry_repeat_count[i];
        if (repeat_len >= ctx->dry_allowed_length) {
            // This token ends a repeat, so the next token would continue one.
            // By convention, the value of `repeat_len` only includes the tokens currently
            // in the context, not the new token that would be added.
            llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
            // Track the maximum sequence ending in this token.
            const auto& it = ctx->dry_max_token_repeat.find(token);
            if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                ctx->dry_max_token_repeat[token] = repeat_len;
            }
        }
    }

    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.

    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
    const float FLOAT_MAX_LOG = 88.7228391f;
    int max_exponent = 0;
    if (ctx->dry_base > 1.000001f) {
        max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
        if (af_kvp != ctx->dry_max_token_repeat.end()) {
            // Check all sequence breakers starting with this token
            auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
            bool is_single_token_breaker = false;

            for (auto it = range.first; it != range.second; ++it) {
                if (it->second.empty()) {
                    is_single_token_breaker = true;
                    break;
                }
            }

            // Apply penalty only if it's not a single-token sequence breaker
            if (!is_single_token_breaker) {
                int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
                if (max_exponent > 0 && repeat_exp > max_exponent) {
                    repeat_exp = max_exponent;
                }
                float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
                cur_p->data[i].logit -= penalty;
            }
        }
    }

    // logits changed -> ordering is no longer guaranteed
    cur_p->sorted = false;
}
2199
2200
0
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
2201
0
    auto * ctx = (llama_sampler_dry *) smpl->ctx;
2202
0
    ctx->last_tokens.clear();
2203
0
    ctx->dry_repeat_count.clear();
2204
0
    ctx->dry_max_token_repeat.clear();
2205
0
}
2206
2207
0
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
2208
0
    const auto * ctx = (llama_sampler_dry *) smpl->ctx;
2209
2210
0
    llama_vocab dummy_vocab;
2211
2212
    // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
2213
0
    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
2214
2215
    // Copy the state, including the processed breakers
2216
0
    {
2217
0
        auto * result_ctx = (llama_sampler_dry *) result->ctx;
2218
0
        result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
2219
0
        result_ctx->dry_repeat_count = ctx->dry_repeat_count;
2220
0
        result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
2221
0
        result_ctx->last_tokens = ctx->last_tokens;
2222
0
    }
2223
2224
0
    return result;
2225
0
}
2226
2227
0
static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2228
0
    delete (llama_sampler_dry *) smpl->ctx;
2229
0
}
2230
2231
// vtable for the DRY sampler
static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name   = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply  = */ llama_sampler_dry_apply,
    /* .reset  = */ llama_sampler_dry_reset,
    /* .clone  = */ llama_sampler_dry_clone,
    /* .free   = */ llama_sampler_dry_free,
};
2239
2240
0
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
2241
0
    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
2242
0
    std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
2243
0
    const int MAX_CHAR_LEN = 40;
2244
0
    const int MAX_SEQ_LEN = 20;
2245
2246
0
    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2247
2248
0
    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2249
        // Process sequence breakers
2250
0
        for (size_t i = 0; i < num_breakers; ++i) {
2251
0
            if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
2252
0
                LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
2253
0
                continue;
2254
0
            }
2255
2256
0
            std::string sequence_break(seq_breakers[i]);
2257
0
            if (sequence_break.empty()) {
2258
0
                LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
2259
0
                continue;
2260
0
            }
2261
2262
0
            if (sequence_break.size() > MAX_CHAR_LEN) {
2263
0
                LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
2264
0
                sequence_break.resize(MAX_CHAR_LEN);
2265
0
            }
2266
2267
0
            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
2268
0
        }
2269
0
    }
2270
2271
0
    return llama_sampler_init(
2272
0
        /* .iface = */ &llama_sampler_dry_i,
2273
0
        /* .ctx   = */ new llama_sampler_dry {
2274
0
            /* .total_context_size     = */ n_ctx_train,
2275
0
            /* .dry_multiplier         = */ dry_multiplier,
2276
0
            /* .dry_base               = */ dry_base,
2277
0
            /* .dry_allowed_length     = */ dry_allowed_length,
2278
0
            /* .dry_penalty_last_n     = */ dry_penalty_last_n,
2279
0
            /* .dry_processed_breakers = */ std::move(processed_breakers),
2280
0
            /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
2281
0
            /* .dry_max_token_repeat   = */ {},
2282
0
            /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2283
0
        }
2284
0
    );
2285
0
}
2286
2287
// wrapper for test-sampling.cpp
2288
0
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
2289
0
    llama_vocab dummy_vocab;
2290
0
    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
2291
0
    auto * ctx = (llama_sampler_dry *) result->ctx;
2292
2293
    // Process the token-based sequence breakers
2294
0
    ctx->dry_processed_breakers.clear();
2295
0
    if (seq_breakers.empty()) {
2296
0
        LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
2297
0
    } else {
2298
0
        for (const auto& breaker : seq_breakers) {
2299
0
            if (breaker.empty()) {
2300
0
                LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
2301
0
                continue;
2302
0
            }
2303
0
            llama_token head_token = breaker[0];
2304
0
            std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
2305
0
            ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
2306
0
        }
2307
2308
0
        if (ctx->dry_processed_breakers.empty()) {
2309
0
            LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
2310
0
        }
2311
0
    }
2312
2313
0
    return result;
2314
0
}
2315
2316
// logit-bias

// state for the logit-bias sampler
struct llama_sampler_logit_bias {
    const int32_t n_vocab;   // vocabulary size the biases were created for

    // fixed list of (token, bias) pairs to add to the candidate logits
    const std::vector<llama_logit_bias> logit_bias;

    // scratch: biases whose token could not be applied via direct indexing
    std::vector<llama_logit_bias> to_search;
};
2325
2326
// sampler interface: human-readable name
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
    return "logit-bias";
}
2329
2330
0
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2331
0
    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
2332
2333
0
    if (ctx->logit_bias.empty()) {
2334
0
        return;
2335
0
    }
2336
2337
0
    ctx->to_search.clear();
2338
2339
    // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
2340
0
    for (const auto & lb : ctx->logit_bias) {
2341
0
        if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
2342
0
            cur_p->data[lb.token].logit += lb.bias;
2343
0
        } else {
2344
0
            ctx->to_search.push_back(lb);
2345
0
        }
2346
0
    }
2347
2348
0
    if (ctx->to_search.empty()) {
2349
0
        return;
2350
0
    }
2351
2352
    // search for the remaining candidates that were not found in the previous step
2353
0
    for (size_t i = 0; i < cur_p->size; ++i) {
2354
0
        for (const auto & lb : ctx->to_search) {
2355
0
            if (cur_p->data[i].id == lb.token) {
2356
0
                cur_p->data[i].logit += lb.bias;
2357
0
                break;
2358
0
            }
2359
0
        }
2360
0
    }
2361
0
}
2362
2363
0
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
2364
0
    const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
2365
0
    return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
2366
0
}
2367
2368
0
static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2369
0
    delete (llama_sampler_logit_bias *) smpl->ctx;
2370
0
}
2371
2372
// vtable for the logit-bias sampler; it keeps no history, so accept/reset are unused
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name   = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_logit_bias_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_logit_bias_clone,
    /* .free   = */ llama_sampler_logit_bias_free,
};
2380
2381
struct llama_sampler * llama_sampler_init_logit_bias(
2382
                         int32_t   n_vocab,
2383
                         int32_t   n_logit_bias,
2384
0
          const llama_logit_bias * logit_bias) {
2385
0
    return llama_sampler_init(
2386
0
        /* .iface = */ &llama_sampler_logit_bias_i,
2387
0
        /* .ctx   = */ new llama_sampler_logit_bias {
2388
0
            /* .n_vocab    = */ n_vocab,
2389
0
            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2390
0
            /* .to_search  = */ {},
2391
0
        }
2392
0
    );
2393
0
}
2394
2395
// infill

//#define GGML_DEBUG_SAMPLER_INFILL

// state for the infill (fill-in-the-middle) sampler
struct llama_sampler_infill {
    const struct llama_vocab * vocab;   // used to classify candidates as EOG vs text tokens

    // scratch byte buffers — presumably for detokenized token text; their use is
    // outside this view (later in infill_apply) — TODO confirm
    std::vector<char> buf0;
    std::vector<char> buf1;
};
2405
2406
// sampler interface: human-readable name
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    return "infill";
}
2409
2410
0
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2411
0
    auto * ctx = (llama_sampler_infill *) smpl->ctx;
2412
2413
0
    llama_sampler_softmax_impl(cur_p, true);
2414
2415
#if defined(GGML_DEBUG_SAMPLER_INFILL)
2416
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
2417
#else
2418
0
#define LOG_DBG_CUR(...)
2419
0
#endif
2420
2421
0
    for (size_t i = 0; i < cur_p->size; ++i) {
2422
0
        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2423
0
    }
2424
2425
0
    float p_txt_sum = 0.0f;
2426
0
    float p_eog_sum = 0.0f;
2427
2428
0
    for (size_t i = 0; i < cur_p->size; ++i) {
2429
0
        if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2430
0
            p_eog_sum += cur_p->data[i].p;
2431
0
        } else {
2432
0
            p_txt_sum += cur_p->data[i].p;
2433
0
        }
2434
0
    }
2435
2436
0
    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
2437
2438
0
    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2439
2440
0
    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2441
0
        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2442
2443
        // keep just the EOG tokens
2444
0
        const auto size_org = cur_p->size;
2445
2446
0
        cur_p->size = 0;
2447
2448
0
        float p_sum = 0.0f;
2449
2450
0
        for (size_t i = 0; i < size_org; ++i) {
2451
0
            if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2452
0
                p_sum += cur_p->data[i].p;
2453
2454
0
                cur_p->data[cur_p->size++] = cur_p->data[i];
2455
0
            }
2456
0
        }
2457
2458
        // normalize probs
2459
0
        for (size_t i = 0; i < cur_p->size; ++i) {
2460
0
            cur_p->data[i].p /= p_sum;
2461
0
        }
2462
2463
0
        return;
2464
0
    }
2465
2466
0
    size_t n_combined = 0; GGML_UNUSED(n_combined);
2467
2468
    // combine tokens with common prefix
2469
0
    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2470
0
        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2471
0
            if (cur_p->data[i0].logit == -INFINITY) {
2472
0
                break;
2473
0
            }
2474
2475
0
            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2476
0
                continue;
2477
0
            }
2478
2479
0
            int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2480
0
            if (len0 < 0) {
2481
0
                ctx->buf0.resize(len0);
2482
0
                len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2483
0
                assert(len0 > 0);
2484
0
            }
2485
2486
0
            int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2487
0
            if (len1 < 0) {
2488
0
                ctx->buf1.resize(len1);
2489
0
                len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2490
0
                assert(len1 > 0);
2491
0
            }
2492
2493
            // token i0 is a prefix of token i1
2494
0
            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2495
0
                int dst = i0;
2496
0
                int src = i1;
2497
2498
                // merge into the token with higher probability
2499
0
                if (cur_p->data[i1].p > cur_p->data[i0].p) {
2500
0
                    std::swap(dst, src);
2501
0
                }
2502
2503
0
                cur_p->data[dst].p += cur_p->data[src].p;
2504
0
                cur_p->data[src].logit = -INFINITY;
2505
0
                cur_p->data[src].p     = 0.0f;
2506
2507
0
                n_combined++;
2508
0
            }
2509
0
        }
2510
0
    }
2511
2512
0
    size_t n_non_eog = 0;
2513
2514
0
    size_t size_org = cur_p->size;
2515
2516
0
    float p_sum = 0.0f;
2517
0
    float thold = 0.2f;
2518
2519
0
    cur_p->size = 0;
2520
2521
0
    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2522
2523
0
    for (size_t i = 0; i < size_org; ++i) {
2524
0
        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
2525
2526
0
        if (cur_p->data[i].p < thold && !is_eog) {
2527
0
            continue;
2528
0
        }
2529
2530
0
        if (!is_eog) {
2531
0
            ++n_non_eog;
2532
0
        }
2533
2534
0
        p_sum += cur_p->data[i].p;
2535
2536
        // keep this token
2537
0
        cur_p->data[cur_p->size++] = cur_p->data[i];
2538
0
    }
2539
2540
0
    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2541
2542
    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2543
0
    if (n_non_eog == 0) {
2544
0
        cur_p->size = 1;
2545
0
        cur_p->data[0].id = ctx->vocab->token_eot();
2546
0
        if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
2547
0
            cur_p->data[0].id = ctx->vocab->token_eos();
2548
0
        }
2549
0
        cur_p->data[0].logit = 1.0f;
2550
2551
0
        GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
2552
2553
0
        return;
2554
0
    }
2555
2556
    // normalize probs
2557
0
    for (size_t i = 0; i < cur_p->size; ++i) {
2558
0
        cur_p->data[i].p /= p_sum;
2559
2560
0
        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2561
0
    }
2562
2563
0
    size_org = cur_p->size;
2564
0
    p_sum = 0.0f;
2565
0
    thold = 1.0/(n_non_eog + 1);
2566
2567
0
    cur_p->size = 0;
2568
2569
0
    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2570
2571
0
    for (size_t i = 0; i < size_org; ++i) {
2572
0
        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
2573
2574
0
        if (cur_p->data[i].p < thold && !is_eog) {
2575
0
            continue;
2576
0
        }
2577
2578
0
        p_sum += cur_p->data[i].p;
2579
2580
0
        cur_p->data[cur_p->size++] = cur_p->data[i];
2581
0
    }
2582
2583
    // normalize probs
2584
0
    for (size_t i = 0; i < cur_p->size; ++i) {
2585
0
        cur_p->data[i].p /= p_sum;
2586
2587
0
        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2588
0
    }
2589
2590
0
#undef LOG_DBG_CUR
2591
0
}
2592
2593
0
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2594
0
    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2595
0
    return llama_sampler_init_infill(ctx->vocab);
2596
0
}
2597
2598
0
// release the context allocated by llama_sampler_init_infill
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
    delete (llama_sampler_infill *) smpl->ctx;
}
2601
2602
// vtable for the infill sampler; stateless between tokens, so no
// accept/reset callbacks
static struct llama_sampler_i llama_sampler_infill_i = {
    /* .name   = */ llama_sampler_infill_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_infill_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_infill_clone,
    /* .free   = */ llama_sampler_infill_free,
};
2610
2611
0
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2612
0
    return llama_sampler_init(
2613
0
        /* .iface = */ &llama_sampler_infill_i,
2614
0
        /* .ctx   = */ new llama_sampler_infill {
2615
0
            /* .vocab = */ vocab,
2616
0
            /* .buf0  = */ std::vector<char>(512),
2617
0
            /* .buf1  = */ std::vector<char>(512),
2618
0
        }
2619
0
    );
2620
0
}
2621
2622
// utils
2623
2624
0
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2625
0
    if (smpl->iface == &llama_sampler_dist_i) {
2626
0
        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
2627
0
    }
2628
2629
0
    if (smpl->iface == &llama_sampler_mirostat_i) {
2630
0
        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
2631
0
    }
2632
2633
0
    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
2634
0
        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
2635
0
    }
2636
2637
0
    if (smpl->iface == &llama_sampler_chain_i) {
2638
0
        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2639
0
        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2640
0
            const uint32_t seed = llama_sampler_get_seed(*it);
2641
0
            if (seed != LLAMA_DEFAULT_SEED) {
2642
0
                return seed;
2643
0
            }
2644
0
        }
2645
0
    }
2646
2647
0
    return LLAMA_DEFAULT_SEED;
2648
0
}
2649
2650
// perf
2651
2652
0
struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
2653
0
    struct llama_perf_sampler_data data = {};
2654
2655
0
    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2656
0
        GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2657
0
    }
2658
2659
0
    const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
2660
2661
0
    data.t_sample_ms = 1e-3 * ctx->t_sample_us;
2662
0
    data.n_sample    = std::max(0, ctx->n_sample);
2663
2664
0
    return data;
2665
0
}
2666
2667
0
// log the sampling performance counters of a sampler chain
void llama_perf_sampler_print(const struct llama_sampler * chain) {
    const auto data = llama_perf_sampler(chain);

    LLAMA_LOG_INFO("%s:    samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
}
2672
2673
0
void llama_perf_sampler_reset(struct llama_sampler * chain) {
2674
0
    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2675
0
        GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2676
0
    }
2677
2678
0
    auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2679
2680
0
    ctx->t_sample_us = 0;
2681
0
    ctx->n_sample    = 0;
2682
0
}