/src/llama.cpp/src/llama-sampler.cpp
Line | Count | Source |
1 | | #include "llama-sampler.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-vocab.h" |
5 | | #include "llama-grammar.h" |
6 | | |
7 | | #include "ggml-cpp.h" |
8 | | |
9 | | #include <array> |
10 | | #include <algorithm> |
11 | | #include <cassert> |
12 | | #include <cfloat> |
13 | | #include <chrono> |
14 | | #include <cmath> |
15 | | #include <cstdlib> |
16 | | #include <cstring> |
17 | | #include <ctime> |
18 | | #include <numeric> |
19 | | #include <random> |
20 | | #include <unordered_map> |
21 | | #include <stdexcept> |
22 | | |
23 | | // the ring buffer works similarly to std::deque, but with a fixed capacity |
// the ring buffer works similarly to std::deque, but with a fixed capacity
template<typename T>
struct ring_buffer {
    // pre-allocates storage for `cap` elements; a zero capacity is allowed,
    // but any subsequent push_back() will throw
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // most recently pushed element
    // note: `pos` always points at the NEXT write slot, so the last written
    // element is one position behind it (previously this returned data[pos],
    // which is the next/oldest slot, not the newest element)
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // append an element, evicting the oldest one when the buffer is full
    void push_back(const T & value) {
        if (capacity == 0) {
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // i-th element counting backwards from the newest: rat(0) == back()
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // contents in insertion order (oldest first)
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer; stored elements are not destroyed
        sz    = 0;
        first = 0;
        pos   = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz       = 0; // number of valid elements
    size_t first    = 0; // index of the oldest element
    size_t pos      = 0; // index of the next write slot

    std::vector<T> data;
};
133 | | |
134 | | // writes result in res, does not mutate cur |
// writes result in res, does not mutate cur
//
// bucket-based partial sort: places at least the top `npartial` entries of
// `cur` into `res` in descending logit order. Tokens below the cut-off bucket
// are dropped, so res.size() (== nhave) can exceed npartial but is usually
// much smaller than cur.size.
// NOTE(review): assumes npartial <= cur.size — otherwise `ib` underflows to -1
// and histo[ib] is out of bounds; callers are expected to clamp.
static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
    // descending order by logit
    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit > b.logit;
    };

    // histogram buckets over the logit range [-10, 10]; out-of-range logits
    // are clamped into the first/last bucket
    constexpr int   nbuckets     = 128;
    constexpr float bucket_low   = -10.0f;
    constexpr float bucket_high  =  10.0f;
    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
    constexpr float bucket_inter = -bucket_low * bucket_scale;

    std::vector<int> bucket_idx;          // bucket assignment per token
    std::vector<int> histo(nbuckets, 0);  // tokens per bucket

    std::vector<llama_token_data*> bucket_ptrs; // write cursor per kept bucket

    bucket_idx.reserve(cur.size);

    // pass 1: assign each token to a bucket and build the histogram
    for (int i = 0; i < (int)cur.size; ++i) {
        const float val = cur.data[i].logit;
        int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
        ib = std::max(0, std::min(nbuckets - 1, ib));
        bucket_idx.push_back(ib);
        ++histo[ib];
    }
    // walk buckets from the highest logits down until at least npartial tokens are covered
    int nhave = 0;
    int ib = nbuckets - 1;
    for ( ; ib >= 0; --ib) {
        nhave += histo[ib];
        if (nhave >= npartial) {
            break;
        }
    }
    res.resize(nhave);
    auto * ptr = res.data();
    bucket_ptrs.reserve(nbuckets - ib);
    // lay kept buckets out contiguously in res, highest bucket first
    for (int j = nbuckets - 1; j >= ib; --j) {
        bucket_ptrs.push_back(ptr);
        ptr += histo[j];
    }
    // pass 2: scatter the kept tokens into their bucket's slice of res
    for (int i = 0; i < (int)cur.size; ++i) {
        int j = bucket_idx[i];
        if (j >= ib) {
            *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
        }
    }

    // fully sort every bucket above the cut-off; the boundary bucket `ib` only
    // needs enough order to complete the top-npartial prefix
    ptr = res.data();
    int ndone = 0;
    for (int j = nbuckets - 1; j > ib; --j) {
        std::sort(ptr, ptr + histo[j], comp);
        ptr += histo[j];
        ndone += histo[j];
    }
    std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
}
191 | | |
192 | | // reduces the size of cur_p to npartial, keeping only the top npartial elements |
193 | 0 | static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) { |
194 | 0 | static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { |
195 | 0 | return a.logit > b.logit; |
196 | 0 | }; |
197 | |
|
198 | 0 | if (npartial <= 128) { |
199 | 0 | std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp); |
200 | |
|
201 | 0 | cur_p->size = npartial; |
202 | 0 | cur_p->sorted = true; |
203 | |
|
204 | 0 | return; |
205 | 0 | } |
206 | | |
207 | 0 | std::vector<llama_token_data> tmp; |
208 | |
|
209 | 0 | llama_token_data_array_partial_sort(*cur_p, npartial, tmp); |
210 | |
|
211 | 0 | std::copy(tmp.data(), tmp.data() + npartial, cur_p->data); |
212 | |
|
213 | 0 | cur_p->size = npartial; |
214 | 0 | cur_p->sorted = true; |
215 | 0 | } |
216 | | |
// draw an index in [0, cur_p->size) according to the (already computed)
// probabilities cur_p->data[i].p
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
    // iterator for the probabilities
#ifdef __GNUC__
    #pragma GCC diagnostic push
    #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#endif

    // minimal input iterator over the `p` field of the token data, so
    // std::discrete_distribution can read the weights without copying them
    // into a separate array
    struct probs_iterator {
        typedef std::input_iterator_tag iterator_category;
        typedef float value_type;
        typedef float * pointer;
        typedef float & reference;
        typedef ptrdiff_t difference_type;

        const llama_token_data * data;

        bool operator==(const probs_iterator & other) const { return data == other.data; }
        bool operator!=(const probs_iterator & other) const { return data != other.data; }
        const float & operator*() const { return data->p; }
        probs_iterator & operator++() { ++data; return *this; }
        probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
    };

#ifdef __GNUC__
    #pragma GCC diagnostic pop
#endif

    // note: discrete_distribution normalizes the weights internally
    std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});

    return dist(rng);
}
248 | | |
249 | | /* |
250 | | static void llama_log_softmax(float * array, size_t size) { |
251 | | float max_l = *std::max_element(array, array + size); |
252 | | float sum = 0.f; |
253 | | for (size_t i = 0; i < size; ++i) { |
254 | | float p = expf(array[i] - max_l); |
255 | | sum += p; |
256 | | array[i] = p; |
257 | | } |
258 | | |
259 | | for (size_t i = 0; i < size; ++i) { |
260 | | array[i] = logf(array[i] / sum); |
261 | | } |
262 | | } |
263 | | */ |
264 | | |
265 | 0 | static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { |
266 | 0 | if (temp <= 0.0f) { |
267 | | // find the token with the highest logit and set the rest to -inf |
268 | 0 | size_t max_i = 0; |
269 | 0 | float max_l = cur_p->data[0].logit; |
270 | |
|
271 | 0 | for (size_t i = 1; i < cur_p->size; ++i) { |
272 | 0 | if (cur_p->data[i ].logit > max_l) { |
273 | 0 | cur_p->data[max_i].logit = -INFINITY; |
274 | 0 | max_i = i; |
275 | 0 | max_l = cur_p->data[i].logit; |
276 | 0 | } else { |
277 | 0 | cur_p->data[i].logit = -INFINITY; |
278 | 0 | } |
279 | 0 | } |
280 | |
|
281 | 0 | return; |
282 | 0 | } |
283 | | |
284 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
285 | 0 | cur_p->data[i].logit /= temp; |
286 | 0 | } |
287 | 0 | } |
288 | | |
289 | 0 | static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) { |
290 | 0 | GGML_ASSERT(cur_p->size > 0); |
291 | | |
292 | | // Sort the logits in descending order if requested |
293 | 0 | if (do_sort && !cur_p->sorted) { |
294 | 0 | llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); |
295 | 0 | } |
296 | |
|
297 | 0 | float max_l = cur_p->data[0].logit; |
298 | 0 | if (!cur_p->sorted) { |
299 | 0 | for (size_t i = 1; i < cur_p->size; ++i) { |
300 | 0 | max_l = std::max(max_l, cur_p->data[i].logit); |
301 | 0 | } |
302 | 0 | } |
303 | |
|
304 | 0 | float cum_sum = 0.0f; |
305 | |
|
306 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
307 | 0 | float p = expf(cur_p->data[i].logit - max_l); |
308 | 0 | cur_p->data[i].p = p; |
309 | 0 | cum_sum += p; |
310 | 0 | } |
311 | |
|
312 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
313 | 0 | cur_p->data[i].p /= cum_sum; |
314 | 0 | } |
315 | 0 | } |
316 | | |
317 | 0 | static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { |
318 | | // if (k >= (int32_t)cur_p->size) { |
319 | | // return; |
320 | | // } |
321 | |
|
322 | 0 | if (k <= 0) { |
323 | 0 | return; |
324 | 0 | } |
325 | | |
326 | 0 | k = std::min(k, (int) cur_p->size); |
327 | | |
328 | | // Sort scores in descending order |
329 | 0 | if (!cur_p->sorted) { |
330 | 0 | llama_token_data_array_partial_sort_inplace(cur_p, k); |
331 | 0 | } |
332 | |
|
333 | 0 | cur_p->size = k; |
334 | 0 | } |
335 | | |
336 | 0 | static uint32_t get_rng_seed(uint32_t seed) { |
337 | 0 | if (seed == LLAMA_DEFAULT_SEED) { |
338 | | // use system clock if std::random_device is not a true RNG |
339 | 0 | static bool is_rd_prng = std::random_device().entropy() == 0; |
340 | 0 | if (is_rd_prng) { |
341 | 0 | return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); |
342 | 0 | } |
343 | 0 | std::random_device rd; |
344 | 0 | return rd(); |
345 | 0 | } |
346 | 0 | return seed; |
347 | 0 | } |
348 | | |
349 | | // llama_sampler API |
350 | | |
351 | | struct llama_sampler * llama_sampler_init( |
352 | | struct llama_sampler_i * iface, |
353 | 0 | llama_sampler_context_t ctx) { |
354 | 0 | return new llama_sampler { |
355 | 0 | /* .iface = */ iface, |
356 | 0 | /* .ctx = */ ctx, |
357 | 0 | }; |
358 | 0 | } |
359 | | |
360 | 0 | const char * llama_sampler_name(const struct llama_sampler * smpl) { |
361 | 0 | if (!smpl->iface) { |
362 | 0 | return "(null)"; |
363 | 0 | } |
364 | | |
365 | 0 | return smpl->iface->name(smpl); |
366 | 0 | } |
367 | | |
368 | 0 | void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { |
369 | 0 | if (!smpl) { |
370 | 0 | return; |
371 | 0 | } |
372 | | |
373 | 0 | if (smpl->iface->accept) { |
374 | 0 | smpl->iface->accept(smpl, token); |
375 | 0 | } |
376 | 0 | } |
377 | | |
378 | 0 | void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { |
379 | 0 | if (!smpl) { |
380 | 0 | return; |
381 | 0 | } |
382 | | |
383 | 0 | GGML_ASSERT(smpl->iface->apply); |
384 | 0 | smpl->iface->apply(smpl, cur_p); |
385 | 0 | } |
386 | | |
387 | 0 | void llama_sampler_reset(struct llama_sampler * smpl) { |
388 | 0 | if (!smpl) { |
389 | 0 | return; |
390 | 0 | } |
391 | | |
392 | 0 | if (smpl->iface->reset) { |
393 | 0 | smpl->iface->reset(smpl); |
394 | 0 | } |
395 | 0 | } |
396 | | |
397 | 0 | struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { |
398 | 0 | if (!smpl) { |
399 | 0 | return nullptr; |
400 | 0 | } |
401 | | |
402 | 0 | if (smpl->iface->clone) { |
403 | 0 | return smpl->iface->clone(smpl); |
404 | 0 | } |
405 | | |
406 | 0 | if (smpl->ctx == nullptr) { |
407 | 0 | return llama_sampler_init( |
408 | 0 | /* .iface = */ smpl->iface, |
409 | 0 | /* .ctx = */ nullptr |
410 | 0 | ); |
411 | 0 | } |
412 | | |
413 | 0 | GGML_ABORT("the sampler does not support cloning"); |
414 | 0 | } |
415 | | |
416 | 0 | void llama_sampler_free(struct llama_sampler * smpl) { |
417 | 0 | if (smpl == nullptr) { |
418 | 0 | return; |
419 | 0 | } |
420 | | |
421 | 0 | if (smpl->iface->free) { |
422 | 0 | smpl->iface->free(smpl); |
423 | 0 | } |
424 | |
|
425 | 0 | delete smpl; |
426 | 0 | } |
427 | | |
428 | | // empty sampler |
429 | | |
// state of the no-op sampler: only carries the display name
struct llama_sampler_empty {
    const char * name;
};
433 | | |
434 | | static struct llama_sampler * llama_sampler_init_empty(const char * name); |
435 | | |
436 | 0 | static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) { |
437 | 0 | auto * ctx = (llama_sampler_empty *) smpl->ctx; |
438 | 0 | return ctx->name; |
439 | 0 | } |
440 | | |
441 | 0 | static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) { |
442 | 0 | GGML_UNUSED(smpl); |
443 | 0 | GGML_UNUSED(token); |
444 | 0 | } |
445 | | |
446 | 0 | static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
447 | 0 | GGML_UNUSED(smpl); |
448 | 0 | GGML_UNUSED(cur_p); |
449 | 0 | } |
450 | | |
// no-op: nothing to reset
static void llama_sampler_empty_reset(struct llama_sampler * /*smpl*/) {
}
454 | | |
455 | 0 | static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) { |
456 | 0 | auto * ctx = (llama_sampler_empty *) smpl->ctx; |
457 | 0 | return llama_sampler_init_empty(ctx->name); |
458 | 0 | } |
459 | | |
460 | 0 | static void llama_sampler_empty_free(struct llama_sampler * smpl) { |
461 | 0 | delete (llama_sampler_empty *) smpl->ctx; |
462 | 0 | } |
463 | | |
464 | | static bool llama_sampler_empty_backend_init( |
465 | | struct llama_sampler * smpl, |
466 | 0 | ggml_backend_buffer_type_t buft) { |
467 | 0 | GGML_UNUSED(smpl); |
468 | 0 | GGML_UNUSED(buft); |
469 | |
|
470 | 0 | return true; |
471 | 0 | } |
472 | | |
473 | | static void llama_sampler_empty_backend_accept( |
474 | | struct llama_sampler * smpl, |
475 | | ggml_context * ctx, |
476 | | ggml_cgraph * gf, |
477 | 0 | struct ggml_tensor * selected_token) { |
478 | 0 | GGML_UNUSED(smpl); |
479 | 0 | GGML_UNUSED(ctx); |
480 | 0 | GGML_UNUSED(gf); |
481 | 0 | GGML_UNUSED(selected_token); |
482 | 0 | } |
483 | | |
// no-op: adds no operations to the backend graph
static void llama_sampler_empty_backend_apply(
        struct llama_sampler * /*smpl*/,
        struct ggml_context * /*ctx*/,
        struct ggml_cgraph * /*gf*/,
        struct llama_sampler_data * /*data*/) {
}
494 | | |
// no-op: the empty sampler has no backend inputs to fill
static void llama_sampler_empty_backend_set_input(struct llama_sampler * /*smpl*/) {
}
498 | | |
// interface table for the no-op sampler: every callback is implemented but
// does nothing (except name/clone/free, which manage the stored name)
static struct llama_sampler_i llama_sampler_empty_i = {
    /* .name              = */ llama_sampler_empty_name,
    /* .accept            = */ llama_sampler_empty_accept,
    /* .apply             = */ llama_sampler_empty_apply,
    /* .reset             = */ llama_sampler_empty_reset,
    /* .clone             = */ llama_sampler_empty_clone,
    /* .free              = */ llama_sampler_empty_free,
    /* .backend_init      = */ llama_sampler_empty_backend_init,
    /* .backend_accept    = */ llama_sampler_empty_backend_accept,
    /* .backend_apply     = */ llama_sampler_empty_backend_apply,
    /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
};
511 | | |
512 | 0 | struct llama_sampler * llama_sampler_init_empty(const char * name) { |
513 | 0 | return llama_sampler_init( |
514 | 0 | /* .iface = */ &llama_sampler_empty_i, |
515 | 0 | /* .ctx = */ new llama_sampler_empty { |
516 | 0 | /* .name = */ name, |
517 | 0 | } |
518 | 0 | ); |
519 | 0 | } |
520 | | |
521 | | // common backend sampler functionality |
522 | | // |
523 | | // +name : means that the sampler is support and will run on the backend |
524 | | // -name : means that a ggml operator is not supported by the backend |
525 | | // |
526 | | struct llama_sampler_backend { |
527 | 0 | llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {} |
528 | | |
529 | 0 | const char * get_name() { |
530 | 0 | if (!is_init) { |
531 | 0 | return name.c_str(); |
532 | 0 | } |
533 | | |
534 | 0 | if (support) { |
535 | 0 | name_ext = "+" + name; |
536 | 0 | } else { |
537 | 0 | name_ext = "-" + name; |
538 | 0 | } |
539 | |
|
540 | 0 | return name_ext.c_str(); |
541 | 0 | } |
542 | | |
543 | 0 | void init(bool support) { |
544 | 0 | GGML_ASSERT(this->is_init == false); |
545 | |
|
546 | 0 | this->is_init = true; |
547 | 0 | this->support = support; |
548 | 0 | } |
549 | | |
550 | | private: |
551 | | std::string name; |
552 | | std::string name_ext; |
553 | | |
554 | | bool is_init; |
555 | | bool support; |
556 | | }; |
557 | | |
558 | | // check if all ggml ops used by the sampler are supported by the backend |
// check if all ggml ops used by the sampler are supported by the backend
// strategy: build a throw-away graph by running the sampler's backend_apply()
// on dummy no-alloc tensors, then ask the device about every node in it
static bool llama_sampler_backend_support(
        llama_sampler * smpl,
        ggml_backend_buffer_type_t buft) {
    auto * device = ggml_backend_buft_get_device(buft);
    if (!device) {
        // CPU backend always supported
        return true;
    }

    // metadata-only context: no_alloc means no tensor data is allocated
    ggml_init_params params = {
        /*.mem_size   =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx_ptr { ggml_init(params) };
    if (!ctx_ptr) {
        throw std::runtime_error(format("failed to create ggml context"));
    }

    ggml_context * ctx = ctx_ptr.get();

    // dummy candidate count; the size only shapes the probe tensors
    const int64_t n = 1024*1024;

    llama_sampler_data data = {
        /*.logits     = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
        /*.probs      = */ nullptr,
        /*.sampled    = */ nullptr,
        /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
    };

    ggml_cgraph * gf = ggml_new_graph(ctx);

    smpl->iface->backend_apply(smpl, ctx, gf, &data);

    // expand every output the sampler produced so all its ops land in the graph
    if (data.logits) {
        ggml_build_forward_expand(gf, data.logits);
    }

    if (data.probs) {
        ggml_build_forward_expand(gf, data.probs);
    }

    if (data.sampled) {
        ggml_build_forward_expand(gf, data.sampled);
    }

    if (data.candidates) {
        ggml_build_forward_expand(gf, data.candidates);
    }

    // one unsupported op disqualifies the whole sampler from the backend
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * op = ggml_graph_node(gf, i);

        if (!ggml_backend_dev_supports_op(device, op)) {
            LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
                    __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));

            return false;
        }
    }

    return true;
}
623 | | |
624 | | // sampler chain |
625 | | |
// the chain sampler always reports a fixed name
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
    return "chain";
}
629 | | |
630 | 0 | static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { |
631 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
632 | |
|
633 | 0 | time_meas tm(chain->t_sample_us, chain->params.no_perf); |
634 | |
|
635 | 0 | for (auto & smpl : chain->samplers) { |
636 | 0 | llama_sampler_accept(smpl.ptr, token); |
637 | 0 | } |
638 | |
|
639 | 0 | chain->n_sample++; |
640 | 0 | } |
641 | | |
642 | 0 | static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
643 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
644 | |
|
645 | 0 | time_meas tm(chain->t_sample_us, chain->params.no_perf); |
646 | |
|
647 | 0 | bool is_backend = chain->is_init; |
648 | |
|
649 | 0 | for (auto & smpl : chain->samplers) { |
650 | 0 | if (is_backend && smpl.is_backend) { |
651 | 0 | continue; |
652 | 0 | } |
653 | | |
654 | 0 | is_backend = false; |
655 | |
|
656 | 0 | if (smpl.ptr->iface->apply == nullptr) { |
657 | 0 | continue; |
658 | 0 | } |
659 | | |
660 | 0 | llama_sampler_apply(smpl.ptr, cur_p); |
661 | 0 | } |
662 | 0 | } |
663 | | |
664 | 0 | static void llama_sampler_chain_reset(struct llama_sampler * smpl) { |
665 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
666 | |
|
667 | 0 | for (auto & smpl : chain->samplers) { |
668 | 0 | llama_sampler_reset(smpl.ptr); |
669 | 0 | } |
670 | 0 | } |
671 | | |
672 | 0 | static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { |
673 | 0 | const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; |
674 | |
|
675 | 0 | auto * result = llama_sampler_chain_init(chain_src->params); |
676 | |
|
677 | 0 | for (const auto & smpl : chain_src->samplers) { |
678 | 0 | llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr)); |
679 | 0 | } |
680 | |
|
681 | 0 | return result; |
682 | 0 | } |
683 | | |
684 | 0 | static void llama_sampler_chain_free(struct llama_sampler * smpl) { |
685 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
686 | |
|
687 | 0 | for (auto & smpl : chain->samplers) { |
688 | 0 | llama_sampler_free(smpl.ptr); |
689 | 0 | } |
690 | |
|
691 | 0 | delete chain; |
692 | 0 | } |
693 | | |
694 | | static bool llama_sampler_chain_backend_init( |
695 | | struct llama_sampler * smpl, |
696 | 0 | ggml_backend_buffer_type_t buft) { |
697 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
698 | |
|
699 | 0 | GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice"); |
700 | |
|
701 | 0 | chain->is_init = true; |
702 | |
|
703 | 0 | bool res = true; |
704 | |
|
705 | 0 | for (auto & smpl : chain->samplers) { |
706 | 0 | bool res_cur = true; |
707 | | |
708 | | // to be able to run a sampler on the backend, it has to: |
709 | | // - have the .backend_init() API implemented |
710 | | // - return true during .backend_init() |
711 | 0 | if (smpl.ptr->iface->backend_init) { |
712 | 0 | if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) { |
713 | 0 | res_cur = false; |
714 | 0 | } |
715 | 0 | } else { |
716 | 0 | res_cur = false; |
717 | 0 | } |
718 | |
|
719 | 0 | smpl.is_backend = res_cur; |
720 | |
|
721 | 0 | res = res && res_cur; |
722 | 0 | } |
723 | |
|
724 | 0 | return res; |
725 | 0 | } |
726 | | |
727 | | static void llama_sampler_chain_backend_accept( |
728 | | struct llama_sampler * smpl, |
729 | | ggml_context * ctx, |
730 | | ggml_cgraph * gf, |
731 | 0 | struct ggml_tensor * selected_token) { |
732 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
733 | |
|
734 | 0 | for (auto & smpl : chain->samplers) { |
735 | 0 | if (!smpl.is_backend) { |
736 | 0 | break; |
737 | 0 | } |
738 | | |
739 | 0 | if (smpl.ptr->iface->backend_accept) { |
740 | 0 | smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token); |
741 | 0 | } |
742 | 0 | } |
743 | 0 | } |
744 | | |
745 | | static void llama_sampler_chain_backend_apply( |
746 | | struct llama_sampler * smpl, |
747 | | struct ggml_context * ctx, |
748 | | struct ggml_cgraph * gf, |
749 | 0 | struct llama_sampler_data * data) { |
750 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
751 | |
|
752 | 0 | GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called"); |
753 | |
|
754 | 0 | for (auto & smpl : chain->samplers) { |
755 | 0 | if (!smpl.is_backend) { |
756 | 0 | break; |
757 | 0 | } |
758 | | |
759 | 0 | if (smpl.ptr->iface->backend_apply) { |
760 | 0 | smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data); |
761 | 0 | } |
762 | 0 | } |
763 | 0 | } |
764 | | |
765 | 0 | static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) { |
766 | 0 | auto * chain = (llama_sampler_chain *) smpl->ctx; |
767 | |
|
768 | 0 | for (auto & smpl : chain->samplers) { |
769 | 0 | if (!smpl.is_backend) { |
770 | 0 | break; |
771 | 0 | } |
772 | | |
773 | 0 | if (smpl.ptr->iface->backend_set_input) { |
774 | 0 | smpl.ptr->iface->backend_set_input(smpl.ptr); |
775 | 0 | } |
776 | 0 | } |
777 | 0 | } |
778 | | |
// interface table for the chain sampler: every call fans out to the children
static struct llama_sampler_i llama_sampler_chain_i = {
    /* .name              = */ llama_sampler_chain_name,
    /* .accept            = */ llama_sampler_chain_accept,
    /* .apply             = */ llama_sampler_chain_apply,
    /* .reset             = */ llama_sampler_chain_reset,
    /* .clone             = */ llama_sampler_chain_clone,
    /* .free              = */ llama_sampler_chain_free,
    /* .backend_init      = */ llama_sampler_chain_backend_init,
    /* .backend_accept    = */ llama_sampler_chain_backend_accept,
    /* .backend_apply     = */ llama_sampler_chain_backend_apply,
    /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
};
791 | | |
// create an empty sampler chain; samplers are added with llama_sampler_chain_add()
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_chain_i,
        /* .ctx   = */ new llama_sampler_chain {
            /* .params      = */ params,
            /* .is_init     = */ false,
            /* .samplers    = */ {},
            /* .cur         = */ {},
            /* .t_sample_us = */ 0,
            /* .n_sample    = */ 0,
        }
    );
}
805 | | |
// sample a token for output `idx`: if a backend (GPU) sampler already picked
// one, return it directly; otherwise build a candidate array from whatever the
// backend produced (probs > logits > full-vocab logits, in that order of
// preference) and run the CPU sampler(s) on it
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
    // results a backend sampler may have already produced for this output
    const llama_token   sampled_token  = llama_get_sampled_token_ith     (ctx, idx);
    const float       * sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
    const float       * sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
    const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);

    // If a backend sampler has already sampled a token, return it.
    if (sampled_token != LLAMA_TOKEN_NULL) {
        LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
        return sampled_token;
    }

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int n_vocab = llama_vocab_n_tokens(vocab);

    // use pre-allocated buffer from chain if available, otherwise allocate locally
    std::vector<llama_token_data> * cur_ptr;
    std::vector<llama_token_data>   cur_local;

    if (smpl->iface == &llama_sampler_chain_i) {
        auto * chain = (llama_sampler_chain *) smpl->ctx;
        cur_ptr = &chain->cur;
    } else {
        cur_ptr = &cur_local;
    }

    auto & cur = *cur_ptr;

    if (sampled_probs) {
        // backend produced a (possibly truncated) candidate set with probabilities
        const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
        cur.resize(sampled_probs_count);
        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
        }
    } else if (sampled_logits) {
        // backend produced logits for a candidate subset, but no probabilities
        const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
        cur.resize(sampled_logits_count);
        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
        }
    } else {
        // no backend output: fall back to the full-vocabulary logits
        const auto * logits = llama_get_logits_ith(ctx, idx);
        GGML_ASSERT(logits != nullptr);
        cur.resize(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }
    }

    llama_token_data_array cur_p = {
        /* .data       = */ cur.data(),
        /* .size       = */ cur.size(),
        /* .selected   = */ -1,
        /* .sorted     = */ false,
    };

    llama_sampler_apply(smpl, &cur_p);

    // the sampler (chain) must have selected a valid candidate
    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);

    auto token = cur_p.data[cur_p.selected].id;

    // feed the choice back so stateful samplers (penalties, mirostat, ...) update
    llama_sampler_accept(smpl, token);

    return token;
}
874 | | |
875 | | |
876 | 0 | void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { |
877 | 0 | auto * p = (llama_sampler_chain *) chain->ctx; |
878 | 0 | p->samplers.push_back({ |
879 | 0 | /* .is_backend = */ false, |
880 | 0 | /* .ptr = */ smpl, |
881 | 0 | }); |
882 | 0 | } |
883 | | |
884 | 0 | struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) { |
885 | 0 | if (chain == nullptr) { |
886 | 0 | return nullptr; |
887 | 0 | } |
888 | | |
889 | 0 | if (chain->iface != &llama_sampler_chain_i) { |
890 | 0 | return nullptr; |
891 | 0 | } |
892 | | |
893 | 0 | if (i == -1) { |
894 | 0 | return chain; |
895 | 0 | } |
896 | | |
897 | 0 | const auto * p = (const llama_sampler_chain *) chain->ctx; |
898 | |
|
899 | 0 | if (i < 0 || (size_t) i >= p->samplers.size()) { |
900 | 0 | return nullptr; |
901 | 0 | } |
902 | | |
903 | 0 | return p->samplers[i].ptr; |
904 | 0 | } |
905 | | |
906 | 0 | struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { |
907 | 0 | auto * p = (llama_sampler_chain *) chain->ctx; |
908 | |
|
909 | 0 | if (i < 0 || (size_t) i >= p->samplers.size()) { |
910 | 0 | return nullptr; |
911 | 0 | } |
912 | | |
913 | 0 | auto * result = p->samplers[i].ptr; |
914 | 0 | p->samplers.erase(p->samplers.begin() + i); |
915 | |
|
916 | 0 | return result; |
917 | 0 | } |
918 | | |
919 | 0 | int llama_sampler_chain_n(const struct llama_sampler * chain) { |
920 | 0 | const auto * p = (const llama_sampler_chain *) chain->ctx; |
921 | |
|
922 | 0 | return p->samplers.size(); |
923 | 0 | } |
924 | | |
925 | | // |
926 | | // samplers |
927 | | // |
928 | | |
929 | | // greedy |
930 | | |
// greedy sampler state: no fields of its own, only the backend-support
// bookkeeping inherited from llama_sampler_backend
struct llama_sampler_greedy : public llama_sampler_backend {
};
933 | | |
934 | 0 | static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) { |
935 | 0 | auto * sctx = (llama_sampler_greedy *) smpl->ctx; |
936 | 0 | return sctx->get_name(); |
937 | 0 | } |
938 | | |
939 | 0 | static void llama_sampler_greedy_reset(struct llama_sampler * smpl) { |
940 | 0 | auto * ctx = (llama_sampler_greedy *) smpl->ctx; |
941 | 0 | GGML_UNUSED(ctx); |
942 | 0 | } |
943 | | |
944 | 0 | static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) { |
945 | 0 | const auto * ctx = (const llama_sampler_greedy *) smpl->ctx; |
946 | 0 | auto * result = llama_sampler_init_greedy(); |
947 | | |
948 | | // copy the state |
949 | 0 | { |
950 | 0 | auto * result_ctx = (llama_sampler_greedy *) result->ctx; |
951 | |
|
952 | 0 | GGML_UNUSED(ctx); |
953 | 0 | GGML_UNUSED(result_ctx); |
954 | 0 | } |
955 | |
|
956 | 0 | return result; |
957 | 0 | } |
958 | | |
959 | 0 | static void llama_sampler_greedy_free(struct llama_sampler * smpl) { |
960 | 0 | delete (llama_sampler_greedy *) smpl->ctx; |
961 | 0 | } |
962 | | |
963 | 0 | static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { |
964 | 0 | cur_p->selected = 0; |
965 | 0 | for (size_t i = 1; i < cur_p->size; ++i) { |
966 | 0 | if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { |
967 | 0 | cur_p->selected = i; |
968 | 0 | } |
969 | 0 | } |
970 | 0 | } |
971 | | |
972 | | static bool llama_sampler_greedy_backend_init( |
973 | | struct llama_sampler * smpl, |
974 | 0 | ggml_backend_buffer_type_t buft) { |
975 | 0 | auto * sctx = (llama_sampler_greedy *) smpl->ctx; |
976 | |
|
977 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
978 | |
|
979 | 0 | sctx->init(res); |
980 | |
|
981 | 0 | return res; |
982 | 0 | } |
983 | | |
984 | | static void llama_sampler_greedy_backend_apply( |
985 | | struct llama_sampler * smpl, |
986 | | struct ggml_context * ctx, |
987 | | struct ggml_cgraph * gf, |
988 | 0 | struct llama_sampler_data * data) { |
989 | 0 | GGML_UNUSED(gf); |
990 | 0 | GGML_UNUSED(smpl); |
991 | |
|
992 | 0 | struct ggml_tensor * curl = ggml_argmax(ctx, data->logits); |
993 | 0 | ggml_set_name(curl, "greedy_argmax"); |
994 | |
|
995 | 0 | data->sampled = curl; |
996 | 0 | } |
997 | | |
// interface table for the greedy sampler; hooks it does not need are null
static struct llama_sampler_i llama_sampler_greedy_i = {
    /* .name = */ llama_sampler_greedy_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_greedy_apply,
    /* .reset = */ llama_sampler_greedy_reset,
    /* .clone = */ llama_sampler_greedy_clone,
    /* .free = */ llama_sampler_greedy_free,
    /* .backend_init = */ llama_sampler_greedy_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_greedy_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1010 | | |
1011 | 0 | struct llama_sampler * llama_sampler_init_greedy() { |
1012 | 0 | return llama_sampler_init( |
1013 | 0 | /* .iface = */ &llama_sampler_greedy_i, |
1014 | 0 | /* .ctx = */ new llama_sampler_greedy { |
1015 | 0 | ("greedy"), |
1016 | 0 | } |
1017 | 0 | ); |
1018 | 0 | } |
1019 | | |
1020 | | // dist |
1021 | | |
// state for the "dist" (categorical distribution) sampler
struct llama_sampler_dist : public llama_sampler_backend {
    const uint32_t seed;     // seed as requested at init time
    uint32_t seed_cur;       // seed actually in use (re-derived via get_rng_seed on reset)

    std::mt19937 rng;        // RNG driving both the CPU and the backend path

    ggml_tensor * inp_uniform; // graph input holding one uniform random value (backend path)
};
1030 | | |
1031 | 0 | static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { |
1032 | 0 | auto * sctx = (llama_sampler_dist *) smpl->ctx; |
1033 | 0 | return sctx->get_name(); |
1034 | 0 | } |
1035 | | |
// sample a token index from the softmax distribution over the candidate logits
// also normalizes cur_p->data[i].p in the same pass
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;

    // edge cases
    if (cur_p->size == 0) {
        cur_p->selected = -1;
        return;
    }

    cur_p->selected = 0;

    if (cur_p->size == 1) {
        cur_p->data[0].p = 1.0f;
        return;
    }

    // max logit for numerical stability
    // when the candidates are sorted (descending), data[0] already holds the max
    float max_l = cur_p->data[0].logit;
    if (!cur_p->sorted) {
        for (size_t i = 1; i < cur_p->size; ++i) {
            max_l = std::max(max_l, cur_p->data[i].logit);
        }
    }

    // apply softmax to obtain the probabilities
    // note: p values are unnormalized here; sum_cum accumulates the normalizer
    double sum_cum = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        float p = expf(cur_p->data[i].logit - max_l);
        cur_p->data[i].p = p;
        sum_cum += p;
    }

#if 1
    // sample from the obtained probabilities and normalize the probs in a single pass
    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
    //
    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
    const double rnd = dist(ctx->rng);

    double sum_run = 0.0f;
    const double sum_tgt = sum_cum*rnd;

    bool found = false;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (!found) {
            // accumulate probs until we reach the target sum
            sum_run += cur_p->data[i].p;
            if (sum_run >= sum_tgt) {
                cur_p->selected = i;
                found = true;
            }
        }

        // normalize probs
        cur_p->data[i].p /= sum_cum;
    }

    // fallback to the last token (don't think this can happen)
    assert(found);
    if (!found) {
        cur_p->selected = cur_p->size - 1;
    }
#else
    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].p /= sum_cum;
    }

    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
#endif
}
1107 | | |
1108 | 0 | static void llama_sampler_dist_reset(struct llama_sampler * smpl) { |
1109 | 0 | auto * ctx = (llama_sampler_dist *) smpl->ctx; |
1110 | 0 | ctx->seed_cur = get_rng_seed(ctx->seed); |
1111 | 0 | ctx->rng.seed(ctx->seed_cur); |
1112 | 0 | } |
1113 | | |
1114 | 0 | static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { |
1115 | 0 | const auto * ctx = (const llama_sampler_dist *) smpl->ctx; |
1116 | 0 | auto * result = llama_sampler_init_dist(ctx->seed); |
1117 | | |
1118 | | // copy the state |
1119 | 0 | { |
1120 | 0 | auto * result_ctx = (llama_sampler_dist *) result->ctx; |
1121 | |
|
1122 | 0 | result_ctx->rng = ctx->rng; |
1123 | 0 | } |
1124 | |
|
1125 | 0 | return result; |
1126 | 0 | } |
1127 | | |
1128 | 0 | static void llama_sampler_dist_free(struct llama_sampler * smpl) { |
1129 | 0 | delete (llama_sampler_dist *) smpl->ctx; |
1130 | 0 | } |
1131 | | |
1132 | | static bool llama_sampler_dist_backend_init( |
1133 | | struct llama_sampler * smpl, |
1134 | 0 | ggml_backend_buffer_type_t buft) { |
1135 | 0 | auto * sctx = (llama_sampler_dist *) smpl->ctx; |
1136 | |
|
1137 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
1138 | |
|
1139 | 0 | sctx->init(res); |
1140 | |
|
1141 | 0 | return res; |
1142 | 0 | } |
1143 | | |
1144 | | static void llama_sampler_dist_backend_apply( |
1145 | | struct llama_sampler * smpl, |
1146 | | struct ggml_context * ctx, |
1147 | | struct ggml_cgraph * gf, |
1148 | 0 | struct llama_sampler_data * data) { |
1149 | 0 | GGML_UNUSED(gf); |
1150 | |
|
1151 | 0 | auto * sctx = (llama_sampler_dist *) smpl->ctx; |
1152 | |
|
1153 | 0 | sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); |
1154 | 0 | ggml_set_name (sctx->inp_uniform, "uniform"); |
1155 | 0 | ggml_set_input(sctx->inp_uniform); |
1156 | |
|
1157 | 0 | struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); |
1158 | 0 | ggml_set_name(probs, "dist_probs"); |
1159 | |
|
1160 | 0 | struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); |
1161 | 0 | ggml_set_name(cumsum, "dist_cumsum"); |
1162 | | |
1163 | | // The uniform tensor has a random value and we subtract this tensor with |
1164 | | // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). |
1165 | | // Recall that each entry in cumsum is the cumulative probability up to that |
1166 | | // index so values stay negative while the cumulative total is below the |
1167 | | // random value, and become zero/positive once the threshold is crossed. |
1168 | 0 | struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform); |
1169 | 0 | ggml_set_name(diff, "dist_cumsum"); |
1170 | | |
1171 | | // The ggml_step function produces a tensor where entries are 1 if the |
1172 | | // corresponding entry in diff is > 0, and 0 otherwise. So all values up to |
1173 | | // the index where the cumulative probability exceeds the random value are 0, |
1174 | | // and all entries after that are 1. |
1175 | 0 | struct ggml_tensor * mask = ggml_step(ctx, diff); |
1176 | 0 | ggml_set_name(mask, "dist_mask"); |
1177 | | |
1178 | | // Taking the sum of the mask gives us the sum of elements after the threshold |
1179 | | // we are interested in. |
1180 | 0 | struct ggml_tensor * idxf = ggml_sum(ctx, mask); |
1181 | 0 | ggml_set_name(idxf, "dist_index_f32"); |
1182 | | |
1183 | | // Use ggml_scale_bias to scale the index value by -1 and then add the size |
1184 | | // of the mask to that value so we get the correct index ((-1 * idxf) + n). |
1185 | 0 | struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); |
1186 | 0 | ggml_set_name(idx, "dist_index_i32"); |
1187 | | |
1188 | | // Map back to original vocab ids if a candidates tensor is available. |
1189 | 0 | struct ggml_tensor * sampled_token = idx; |
1190 | 0 | if (data->candidates != nullptr) { |
1191 | 0 | struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates)); |
1192 | |
|
1193 | 0 | sampled_token = ggml_get_rows(ctx, candidates, idx); |
1194 | 0 | ggml_set_name(sampled_token, "dist_sampled_token"); |
1195 | 0 | } |
1196 | |
|
1197 | 0 | data->sampled = sampled_token; |
1198 | 0 | data->probs = probs; |
1199 | 0 | } |
1200 | | |
// fill the "uniform" graph input with the next random value from the sampler RNG
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
    auto * sctx = (llama_sampler_dist *) smpl->ctx;

    // backend_apply must have created the input tensor before this is called
    GGML_ASSERT(sctx->inp_uniform != nullptr);

    // We sample in double precision and cast to float to match rnd numbers of
    // llama_sampler_dist_apply which uses double precision (sampling from
    // std::uniform_real_distribution<double> and
    // std::uniform_real_distribution<float> with same rng will produce
    // different sequences).
    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
    const float rnd = dist(sctx->rng);

    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
1216 | | |
// interface table for the dist sampler; hooks it does not need are null
static struct llama_sampler_i llama_sampler_dist_i = {
    /* .name = */ llama_sampler_dist_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_dist_apply,
    /* .reset = */ llama_sampler_dist_reset,
    /* .clone = */ llama_sampler_dist_clone,
    /* .free = */ llama_sampler_dist_free,
    /* .backend_init = */ llama_sampler_dist_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_dist_backend_apply,
    /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
};
1229 | | |
1230 | 0 | struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { |
1231 | 0 | auto seed_cur = get_rng_seed(seed); |
1232 | 0 | return llama_sampler_init( |
1233 | 0 | /* .iface = */ &llama_sampler_dist_i, |
1234 | 0 | /* .ctx = */ new llama_sampler_dist { |
1235 | 0 | ("dist"), |
1236 | 0 | /* .seed = */ seed, |
1237 | 0 | /* .seed_cur = */ seed_cur, |
1238 | 0 | /* .rng = */ std::mt19937(seed_cur), |
1239 | 0 | /* .inp_uniform = */ nullptr, |
1240 | 0 | } |
1241 | 0 | ); |
1242 | 0 | } |
1243 | | |
1244 | | // top-k |
1245 | | |
// state for the top-k sampler
struct llama_sampler_top_k : public llama_sampler_backend {
    const int32_t k; // number of candidates to keep (always > 0; see llama_sampler_init_top_k)
};
1249 | | |
1250 | 0 | static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) { |
1251 | 0 | auto * sctx = (llama_sampler_top_k *) smpl->ctx; |
1252 | 0 | return sctx->get_name(); |
1253 | 0 | } |
1254 | | |
1255 | 0 | static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
1256 | 0 | auto * ctx = (llama_sampler_top_k *) smpl->ctx; |
1257 | 0 | llama_sampler_top_k_impl(cur_p, ctx->k); |
1258 | 0 | } |
1259 | | |
1260 | 0 | static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { |
1261 | 0 | const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; |
1262 | 0 | return llama_sampler_init_top_k(ctx->k); |
1263 | 0 | } |
1264 | | |
1265 | 0 | static void llama_sampler_top_k_free(struct llama_sampler * smpl) { |
1266 | 0 | delete (llama_sampler_top_k *) smpl->ctx; |
1267 | 0 | } |
1268 | | |
1269 | | static bool llama_sampler_top_k_backend_init( |
1270 | | struct llama_sampler * smpl, |
1271 | 0 | ggml_backend_buffer_type_t buft) { |
1272 | 0 | auto * sctx = (llama_sampler_top_k *) smpl->ctx; |
1273 | |
|
1274 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
1275 | |
|
1276 | 0 | sctx->init(res); |
1277 | |
|
1278 | 0 | return res; |
1279 | 0 | } |
1280 | | |
1281 | | static void llama_sampler_top_k_backend_apply( |
1282 | | struct llama_sampler * smpl, |
1283 | | struct ggml_context * ctx, |
1284 | | struct ggml_cgraph * gf, |
1285 | 0 | struct llama_sampler_data * data) { |
1286 | 0 | auto * sctx = (llama_sampler_top_k *) smpl->ctx; |
1287 | |
|
1288 | 0 | struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k); |
1289 | 0 | ggml_set_name(top_k, "top_k"); |
1290 | |
|
1291 | 0 | if (data->candidates) { |
1292 | 0 | struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); |
1293 | 0 | data->candidates = ggml_get_rows(ctx, candidates_rows, top_k); |
1294 | 0 | data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k); |
1295 | 0 | ggml_set_name(data->candidates, "top_k_candidates"); |
1296 | 0 | } else { |
1297 | 0 | data->candidates = top_k; |
1298 | 0 | } |
1299 | |
|
1300 | 0 | struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); |
1301 | 0 | struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); |
1302 | 0 | data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); |
1303 | 0 | ggml_set_name(top_k_rows, "top_k_rows"); |
1304 | |
|
1305 | 0 | GGML_UNUSED(gf); |
1306 | 0 | } |
1307 | | |
// interface table for the top-k sampler; hooks it does not need are null
static struct llama_sampler_i llama_sampler_top_k_i = {
    /* .name = */ llama_sampler_top_k_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_top_k_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_k_clone,
    /* .free = */ llama_sampler_top_k_free,
    /* .backend_init = */ llama_sampler_top_k_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_top_k_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1320 | | |
1321 | 0 | struct llama_sampler * llama_sampler_init_top_k(int32_t k) { |
1322 | 0 | const bool is_empty = (k <= 0); |
1323 | |
|
1324 | 0 | if (is_empty) { |
1325 | 0 | return llama_sampler_init_empty("?top-k"); |
1326 | 0 | } |
1327 | | |
1328 | 0 | return llama_sampler_init( |
1329 | 0 | /* .iface = */ &llama_sampler_top_k_i, |
1330 | 0 | /* .ctx = */ new llama_sampler_top_k { |
1331 | 0 | ("top-k"), |
1332 | 0 | /* .k = */ k, |
1333 | 0 | } |
1334 | 0 | ); |
1335 | 0 | } |
1336 | | |
1337 | | // top-p |
1338 | | |
// state for the top-p (nucleus) sampler
struct llama_sampler_top_p : public llama_sampler_backend {
    const float p;          // cumulative-probability threshold (always < 1; see init)
    const size_t min_keep;  // minimum number of candidates to keep

    std::vector<llama_token_data> buf_sort; // scratch buffer for partial sorting (not part of the sampler's logical state)
};
1345 | | |
1346 | 0 | static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) { |
1347 | 0 | auto * sctx = (llama_sampler_top_p *) smpl->ctx; |
1348 | 0 | return sctx->get_name(); |
1349 | 0 | } |
1350 | | |
// keep the smallest set of candidates whose cumulative probability is >= p
// (but never fewer than min_keep); uses an adaptive partial sort for large
// unsorted candidate arrays to avoid sorting the full vocabulary
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_top_p *) smpl->ctx;

    // p >= 1 keeps everything -> nothing to do
    if (ctx->p >= 1.0f) {
        return;
    }

    llama_sampler_softmax_impl(cur_p, false);

    size_t k = cur_p->size;
    auto * pdata = cur_p->data;

    auto & buf_sort = ctx->buf_sort;

    // if not sorted, try adaptive top-k sorting
    if (!cur_p->sorted && cur_p->size > 1024) {
        // sort only the best 256 into the scratch buffer first; extended below if needed
        k = std::min<size_t>(256, cur_p->size);
        llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
        pdata = buf_sort.data();
    } else if (!cur_p->sorted) {
        // small candidates -> sort inplace
        llama_token_data_array_partial_sort_inplace(cur_p, k);
    }

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = cur_p->size;

    for (size_t i = 0; i < cur_p->size; ++i) {
        cum_sum += pdata[i].p;

        // Check if the running sum is at least p or if we have kept at least min_keep tokens
        // we set the last index to i+1 to indicate that the current iterate should be included in the set
        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
            last_idx = i + 1;
            break;
        }

        // we exceeded the current top-k heuristic -> increase k and continue
        if (!cur_p->sorted && i == k - 1) {
            k = cur_p->size;
            llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
            pdata = buf_sort.data();
        }
    }

    // Resize the output vector to keep only the top-p tokens
    // (when the scratch buffer was used, copy the kept prefix back into cur_p)
    if (!cur_p->sorted) {
        std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
        cur_p->sorted = true;
    }

    cur_p->size = last_idx;
}
1405 | | |
1406 | 0 | static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) { |
1407 | 0 | const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; |
1408 | 0 | return llama_sampler_init_top_p(ctx->p, ctx->min_keep); |
1409 | 0 | } |
1410 | | |
1411 | 0 | static void llama_sampler_top_p_free(struct llama_sampler * smpl) { |
1412 | 0 | delete (llama_sampler_top_p *) smpl->ctx; |
1413 | 0 | } |
1414 | | |
1415 | | static bool llama_sampler_top_p_backend_init( |
1416 | | struct llama_sampler * smpl, |
1417 | 0 | ggml_backend_buffer_type_t buft) { |
1418 | 0 | auto * sctx = (llama_sampler_top_p *) smpl->ctx; |
1419 | |
|
1420 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
1421 | |
|
1422 | 0 | sctx->init(res); |
1423 | |
|
1424 | 0 | return res; |
1425 | 0 | } |
1426 | | |
// backend (graph) version of top-p: sort logits descending, build the CDF, and
// mask out (-INF) every token past the point where the CDF reaches p
// NOTE(review): unlike the CPU path, min_keep does not appear to be enforced
// here - confirm this is intended
static void llama_sampler_top_p_backend_apply(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data) {
    auto * sctx = (llama_sampler_top_p *) smpl->ctx;

    // helper: permute a 1-row tensor `a` by the index tensor `b` (reshape + get_rows)
    auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
        GGML_ASSERT(ggml_nrows(a) == 1);
        struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
        struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
        return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
    };

    // Get the sorted logits in descending order.
    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
    ggml_set_name(sorted_idx, "top_p_sorted_idx");

    // Do the sorting via reshape + get_rows
    struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
    ggml_set_name(sorted_logits, "top_p_sorted_logits");

    struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
    ggml_set_name(softmax, "top_p_softmax");

    // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
    if (data->candidates) {
        data->candidates = ggml_sort(data->candidates, sorted_idx);
    } else {
        data->candidates = sorted_idx;
    }
    ggml_set_name(data->candidates, "top_p_candidates");

    // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
    struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
    ggml_set_name(cdf, "top_p_cdf");

    // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
    struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
    ggml_set_name(cdf_scaled, "top_p_cdf_scaled");

    struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
    ggml_set_name(mask, "top_p_mask");

    // Taking the sum of the mask gives us the sum of elements after the threshold
    // we are interested in.
    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
    ggml_set_name(idxf, "top_p_index_f32");

    // prevent out-of-bounds access
    idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);

    // construct ones tensor to set the value in the mask
    struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
    ggml_set_name(ones, "top_p_ones");

    // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
    struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);

    mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
    mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);

    // Apply -INFINITY bias for masked-out tokens
    // log(1) = 0 (keep), log(0) = -INF (discard)
    struct ggml_tensor * top_p_bias = ggml_log(ctx, mask);
    ggml_set_name(top_p_bias, "top_p_bias");

    data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
    ggml_set_name(data->logits, "top_p_logits");

    GGML_UNUSED(gf);
}
1499 | | |
// interface table for the top-p sampler; hooks it does not need are null
static struct llama_sampler_i llama_sampler_top_p_i = {
    /* .name = */ llama_sampler_top_p_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_top_p_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_p_clone,
    /* .free = */ llama_sampler_top_p_free,
    /* .backend_init = */ llama_sampler_top_p_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_top_p_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1512 | | |
1513 | 0 | struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { |
1514 | 0 | const bool is_empty = p >= 1.0f; |
1515 | |
|
1516 | 0 | if (is_empty) { |
1517 | 0 | return llama_sampler_init_empty("?top-p"); |
1518 | 0 | } |
1519 | | |
1520 | 0 | return llama_sampler_init( |
1521 | 0 | /* .iface = */ &llama_sampler_top_p_i, |
1522 | 0 | /* .ctx = */ new llama_sampler_top_p { |
1523 | 0 | ("top-p"), |
1524 | 0 | /* .p = */ p, |
1525 | 0 | /* .min_keep = */ min_keep, |
1526 | 0 | /* .buf_sort = */ {}, |
1527 | 0 | } |
1528 | 0 | ); |
1529 | 0 | } |
1530 | | |
1531 | | // min-p |
1532 | | |
// state for the min-p sampler
struct llama_sampler_min_p : public llama_sampler_backend {
    const float p;          // keep tokens with p_i >= p * p_max (always > 0; see init)
    const size_t min_keep;  // minimum number of candidates to keep
};
1537 | | |
1538 | 0 | static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) { |
1539 | 0 | auto * sctx = (llama_sampler_min_p *) smpl->ctx; |
1540 | 0 | return sctx->get_name(); |
1541 | 0 | } |
1542 | | |
// discard tokens whose probability is below p * p_max (equivalently, whose
// logit is below max_logit + log(p)); keeps at least min_keep tokens
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_min_p *) smpl->ctx;

    // p <= 0 keeps everything; empty candidate set has nothing to filter
    if (ctx->p <= 0.0f || !cur_p->size) {
        return;
    }

    bool min_p_applied = false;

    // if the cur_p aren't sorted, try the unsorted implementation first
    // (single filtering pass, avoids sorting the full candidate array)
    if (!cur_p->sorted) {
        std::vector<llama_token_data> filtered_tokens;

        float max_logit = -FLT_MAX;
        for (size_t i = 0; i < cur_p->size; ++i) {
            max_logit = std::max(max_logit, cur_p->data[i].logit);
        }
        const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max

        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit >= min_logit) {
                filtered_tokens.push_back(cur_p->data[i]);
            }
        }

        // if we have enough values the operation was a success
        if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
            std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
            cur_p->size = filtered_tokens.size();
            min_p_applied = true;
        }
    }

    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
    // (sorting lets min_keep be honored by walking the descending prefix)
    if (!min_p_applied) {
        // Sort the logits in descending order
        if (!cur_p->sorted) {
            llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
        }

        const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
        size_t i = 1; // first token always matches

        for (; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
                break; // prob too small
            }
        }

        // Resize the output vector to keep only the matching tokens
        cur_p->size = i;
    }
}
1596 | | |
1597 | 0 | static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) { |
1598 | 0 | const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; |
1599 | 0 | return llama_sampler_init_min_p(ctx->p, ctx->min_keep); |
1600 | 0 | } |
1601 | | |
1602 | 0 | static void llama_sampler_min_p_free(struct llama_sampler * smpl) { |
1603 | 0 | delete (llama_sampler_min_p *) smpl->ctx; |
1604 | 0 | } |
1605 | | |
1606 | | static bool llama_sampler_min_p_backend_init( |
1607 | | struct llama_sampler * smpl, |
1608 | 0 | ggml_backend_buffer_type_t buft) { |
1609 | 0 | auto * sctx = (llama_sampler_min_p *) smpl->ctx; |
1610 | |
|
1611 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
1612 | |
|
1613 | 0 | sctx->init(res); |
1614 | |
|
1615 | 0 | return res; |
1616 | 0 | } |
1617 | | |
// backend (graph) version of min-p: mask out (-INF) every logit below
// max_logit + log(p), i.e. every token with p_i < p * p_max
// NOTE(review): unlike the CPU path, min_keep does not appear to be enforced
// here - confirm this is intended
static void llama_sampler_min_p_backend_apply(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data) {
    auto * sctx = (llama_sampler_min_p *) smpl->ctx;

    // locate the maximum logit via argmax + get_rows
    struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
    ggml_set_name(max_idx, "max_idx");

    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
    ggml_set_name(logits_rows, "logits_rows");

    struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
    ggml_set_name(max_logit, "max_logit");

    // Calculate the threshold value.
    struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
    ggml_set_name(threshold, "min_p_threshold");

    // Subtract the threshold from logits.
    struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);

    // Create a mask where logits below the threshold are 0 (discard),
    // and others are 1 (keep).
    struct ggml_tensor * mask = ggml_step(ctx, sub);
    ggml_set_name(mask, "min_p_mask");

    // Apply -INFINITY bias for masked-out tokens
    // log(1) = 0 (keep), log(0) = -INF (discard)
    struct ggml_tensor * min_p_bias = ggml_log(ctx, mask);
    ggml_set_name(min_p_bias, "min_p_bias");

    data->logits = ggml_add(ctx, data->logits, min_p_bias);
    ggml_set_name(data->logits, "min_p_logits");

    GGML_UNUSED(gf);
}
1656 | | |
// interface table for the min-p sampler; hooks it does not need are null
static struct llama_sampler_i llama_sampler_min_p_i = {
    /* .name = */ llama_sampler_min_p_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_min_p_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_min_p_clone,
    /* .free = */ llama_sampler_min_p_free,
    /* .backend_init = */ llama_sampler_min_p_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_min_p_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1669 | | |
1670 | 0 | struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { |
1671 | 0 | const bool is_empty = (p <= 0.0f); |
1672 | |
|
1673 | 0 | if (is_empty) { |
1674 | 0 | return llama_sampler_init_empty("?min-p"); |
1675 | 0 | } |
1676 | | |
1677 | 0 | return llama_sampler_init( |
1678 | 0 | /* .iface = */ &llama_sampler_min_p_i, |
1679 | 0 | /* .ctx = */ new llama_sampler_min_p { |
1680 | 0 | ("min-p"), |
1681 | 0 | /* .p = */ p, |
1682 | 0 | /* .min_keep = */ min_keep, |
1683 | 0 | } |
1684 | 0 | ); |
1685 | 0 | } |
1686 | | |
1687 | | // typical |
1688 | | |
// state for the locally-typical sampler
// note: unlike the samplers above, this one does not derive from llama_sampler_backend
// (no backend/graph implementation is provided for it)
struct llama_sampler_typical {
    const float p;          // cumulative-probability threshold over typicality-sorted tokens
    const size_t min_keep;  // minimum number of candidates to keep
};
1693 | | |
// sampler name (constant - this sampler has no backend base to delegate to)
static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
    return "typical";
}
1697 | | |
// locally typical sampling: rank tokens by |(-log p_i) - entropy| (how "typical"
// their information content is) and keep the smallest such set whose cumulative
// probability reaches p, subject to min_keep
static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_typical *) smpl->ctx;

    // Reference implementation:
    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
    if (ctx->p >= 1.0f) {
        return;
    }

    // Compute the softmax of logits and calculate entropy
    llama_sampler_softmax_impl(cur_p, true);

    float entropy = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
    }

    // Compute the absolute difference between negative log probability and entropy for each candidate
    std::vector<float> shifted_scores;
    for (size_t i = 0; i < cur_p->size; ++i) {
        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
        shifted_scores.push_back(shifted_score);
    }

    // Sort tokens based on the shifted_scores and their corresponding indices
    // (ascending: most "typical" tokens first)
    std::vector<size_t> indices(cur_p->size);
    std::iota(indices.begin(), indices.end(), 0);

    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
        return shifted_scores[a] < shifted_scores[b];
    });

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = indices.size();

    for (size_t i = 0; i < indices.size(); ++i) {
        size_t idx = indices[i];
        cum_sum += cur_p->data[idx].p;

        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
        if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
            last_idx = i + 1;
            break;
        }
    }

    // Resize the output vector to keep only the locally typical tokens
    std::vector<llama_token_data> cur_p_new;
    for (size_t i = 0; i < last_idx; ++i) {
        size_t idx = indices[i];
        cur_p_new.push_back(cur_p->data[idx]);
    }

    // Replace the data in cur_p with the cur_p_new data
    // the result is ordered by typicality, not by logit -> mark as unsorted
    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
    cur_p->size = cur_p_new.size();
    cur_p->sorted = false;
}
1757 | | |
1758 | 0 | static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) { |
1759 | 0 | const auto * ctx = (const llama_sampler_typical *) smpl->ctx; |
1760 | 0 | return llama_sampler_init_typical(ctx->p, ctx->min_keep); |
1761 | 0 | } |
1762 | | |
1763 | 0 | static void llama_sampler_typical_free(struct llama_sampler * smpl) { |
1764 | 0 | delete (llama_sampler_typical *) smpl->ctx; |
1765 | 0 | } |
1766 | | |
// vtable for the typical sampler (CPU-only: no backend callbacks)
static struct llama_sampler_i llama_sampler_typical_i = {
    /* .name              = */ llama_sampler_typical_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_typical_apply,
    /* .reset             = */ nullptr,
    /* .clone             = */ llama_sampler_typical_clone,
    /* .free              = */ llama_sampler_typical_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
1779 | | |
1780 | 0 | struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { |
1781 | 0 | const bool is_empty = (p >= 1.0f); |
1782 | |
|
1783 | 0 | if (is_empty) { |
1784 | 0 | return llama_sampler_init_empty("?typical"); |
1785 | 0 | } |
1786 | | |
1787 | 0 | return llama_sampler_init( |
1788 | 0 | /* .iface = */ &llama_sampler_typical_i, |
1789 | 0 | /* .ctx = */ new llama_sampler_typical { |
1790 | 0 | /* .p = */ p, |
1791 | 0 | /* .min_keep = */ min_keep, |
1792 | 0 | } |
1793 | 0 | ); |
1794 | 0 | } |
1795 | | |
1796 | | // temp |
1797 | | |
// context for the plain temperature sampler; inherits backend-support state
struct llama_sampler_temp : public llama_sampler_backend {
    const float temp; // temperature; <= 0.0f selects greedy (argmax) sampling on the backend path
};
1801 | | |
1802 | 0 | static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) { |
1803 | 0 | auto * sctx = (llama_sampler_temp *) smpl->ctx; |
1804 | 0 | return sctx->get_name(); |
1805 | 0 | } |
1806 | | |
1807 | 0 | static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
1808 | 0 | const auto * ctx = (llama_sampler_temp *) smpl->ctx; |
1809 | |
|
1810 | 0 | llama_sampler_temp_impl(cur_p, ctx->temp); |
1811 | 0 | } |
1812 | | |
1813 | 0 | static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { |
1814 | 0 | const auto * ctx = (const llama_sampler_temp *) smpl->ctx; |
1815 | 0 | return llama_sampler_init_temp(ctx->temp); |
1816 | 0 | } |
1817 | | |
1818 | 0 | static void llama_sampler_temp_free(struct llama_sampler * smpl) { |
1819 | 0 | delete (llama_sampler_temp *) smpl->ctx; |
1820 | 0 | } |
1821 | | |
// Build the ggml graph ops for temperature scaling on the backend.
// temp <= 0.0f degenerates to greedy sampling: only the argmax token (and its
// logit) is kept; otherwise the logits are divided by temp.
static void llama_sampler_backend_temp_sampling(
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data,
        float temp) {
    if (temp <= 0.0f) {
        // Find the most probable token index.
        struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
        ggml_set_name(max_idx, "temp_max_idx");

        if (data->candidates) {
            // narrow an existing candidate list down to the single argmax row
            struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
            data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
        } else {
            data->candidates = max_idx;
        }

        // keep only the logit of the selected token
        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
        data->logits = ggml_get_rows(ctx, logits_rows, max_idx);

        return;
    }

    // standard temperature scaling: logits /= temp
    data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);

    GGML_UNUSED(gf);
}
1849 | | |
1850 | | static bool llama_sampler_temp_backend_init( |
1851 | | struct llama_sampler * smpl, |
1852 | 0 | ggml_backend_buffer_type_t buft) { |
1853 | 0 | auto * sctx = (llama_sampler_temp *) smpl->ctx; |
1854 | |
|
1855 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
1856 | |
|
1857 | 0 | sctx->init(res); |
1858 | |
|
1859 | 0 | return res; |
1860 | 0 | } |
1861 | | |
1862 | | static void llama_sampler_temp_backend_apply( |
1863 | | struct llama_sampler * smpl, |
1864 | | struct ggml_context * ctx, |
1865 | | struct ggml_cgraph * gf, |
1866 | 0 | struct llama_sampler_data * data) { |
1867 | 0 | auto * sctx = (llama_sampler_temp *) smpl->ctx; |
1868 | 0 | llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); |
1869 | 0 | } |
1870 | | |
// vtable for the temperature sampler (has both CPU and backend implementations)
static struct llama_sampler_i llama_sampler_temp_i = {
    /* .name              = */ llama_sampler_temp_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_temp_apply,
    /* .reset             = */ nullptr,
    /* .clone             = */ llama_sampler_temp_clone,
    /* .free              = */ llama_sampler_temp_free,
    /* .backend_init      = */ llama_sampler_temp_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_temp_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1883 | | |
1884 | 0 | struct llama_sampler * llama_sampler_init_temp(float temp) { |
1885 | 0 | const bool is_empty = temp == 1.0f; |
1886 | |
|
1887 | 0 | if (is_empty) { |
1888 | 0 | return llama_sampler_init_empty("?temp"); |
1889 | 0 | } |
1890 | | |
1891 | 0 | return llama_sampler_init( |
1892 | 0 | /* .iface = */ &llama_sampler_temp_i, |
1893 | 0 | /* .ctx = */ new llama_sampler_temp { |
1894 | 0 | ("temp"), |
1895 | 0 | /*.temp = */ temp, |
1896 | 0 | } |
1897 | 0 | ); |
1898 | 0 | } |
1899 | | |
1900 | | // temp-ext |
1901 | | |
// context for the extended (entropy-based dynamic) temperature sampler
struct llama_sampler_temp_ext : public llama_sampler_backend {
    const float temp;     // base temperature
    const float delta;    // dynamic range around temp; <= 0 falls back to plain scaling
    const float exponent; // exponent applied to the normalized entropy
};
1907 | | |
1908 | 0 | static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) { |
1909 | 0 | auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; |
1910 | 0 | return sctx->get_name(); |
1911 | 0 | } |
1912 | | |
// Entropy-based dynamic temperature: when delta > 0 the effective temperature
// is interpolated in [temp - delta, temp + delta] according to the normalized
// entropy of the candidate distribution; otherwise plain scaling is used.
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
    if (ctx->delta > 0) {
        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
        const float max_temp = ctx->temp + ctx->delta;

        float exponent_val = ctx->exponent;

        // no need to do anything if there is only one (or zero) candidates
        if (cur_p->size <= 1) {
            return;
        }

        // Calculate maximum possible entropy (uniform distribution over all candidates)
        float max_entropy = -logf(1.0f / cur_p->size);

        llama_sampler_softmax_impl(cur_p, true);

        // Calculate entropy of the softmax probabilities
        float entropy = 0.0f;
        for (size_t i = 0; i < cur_p->size; ++i) {
            float prob = cur_p->data[i].p;
            if (prob > 0.0f) { // Ensure no log(0)
                entropy -= prob * logf(prob);
            }
        }

        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
        float normalized_entropy = entropy / max_entropy;

        // Map the normalized entropy to the desired temperature range using the power function
        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);

#ifdef DEBUG
        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
#endif

        // Apply the dynamically calculated temperature scaling
        llama_sampler_temp_impl(cur_p, dyn_temp);

        // Re-compute softmax probabilities after scaling logits with dynamic temperature
        // (double-precision accumulation keeps the normalization numerically stable)
        const double max_l_double = cur_p->data[0].logit;

        double cum_sum_double = 0.0;
        for (size_t i = 0; i < cur_p->size; ++i) {
            double p = exp(cur_p->data[i].logit - max_l_double);
            cur_p->data[i].p = p; // Store the scaled probability
            cum_sum_double += p;
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
        }

#ifdef DEBUG
        // Print the updated top 25 probabilities after temperature scaling
        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
        }
#endif
    } else {
        // delta <= 0: plain temperature scaling
        llama_sampler_temp_impl(cur_p, ctx->temp);
    }
}
1983 | | |
1984 | 0 | static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) { |
1985 | 0 | const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; |
1986 | 0 | return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); |
1987 | 0 | } |
1988 | | |
1989 | 0 | static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { |
1990 | 0 | delete (llama_sampler_temp_ext *) smpl->ctx; |
1991 | 0 | } |
1992 | | |
1993 | | static bool llama_sampler_temp_ext_backend_init( |
1994 | | struct llama_sampler * smpl, |
1995 | 0 | ggml_backend_buffer_type_t buft) { |
1996 | 0 | auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; |
1997 | |
|
1998 | 0 | const bool res = llama_sampler_backend_support(smpl, buft); |
1999 | |
|
2000 | 0 | sctx->init(res); |
2001 | |
|
2002 | 0 | return res; |
2003 | 0 | } |
2004 | | |
// Backend (graph) implementation of the extended temperature sampler: builds
// ggml ops computing the entropy-based dynamic temperature and scales the
// logits with it. Mirrors llama_sampler_temp_ext_apply.
static void llama_sampler_temp_ext_backend_apply(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data) {
    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;

    // Revert to standard temperature scaling if delta or temp are non-positive.
    if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
        llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
        return;
    }

    // Calculate min_temp, max_temp, and max_entropy.
    const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
    const float max_temp = sctx->temp + sctx->delta;
    const float max_entropy = logf(data->logits->ne[0]); // entropy of a uniform distribution over the vocab

    // Calculate the probabilities.
    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
    ggml_set_name(probs, "temp_ext_softmax_probs");

    // Clamp probabilities to avoid log(0) which would give -inf
    struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
    ggml_set_name(probs_clamped, "temp_ext_probs_clamped");

    // Calculate the entropy, entropy = -Σ(p * log(p)).
    struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
    struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
    struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
    struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
    ggml_set_name(log_probs, "temp_ext_log_probs");
    ggml_set_name(p_log_p, "temp_ext_p_log_p");
    ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
    ggml_set_name(entropy, "temp_ext_entropy");

    // Normalize the entropy, norm_entropy = entropy / max_entropy
    struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
    ggml_set_name(norm_entropy, "temp_ext_norm_entropy");

    // Calculate the dynamic temperature:
    // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
    //
    // Calculate powf(normalized_entropy, exponent) as
    // norm_entropy^exponent = exp(exponent * log(norm_entropy))
    struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
    struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
    struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
    // With pow_entropy computed we can now compute dyn_temp, scaling by
    // (max_temp - min_temp) and then adding min_temp.
    struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
    ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
    ggml_set_name(scaled_log, "temp_ext_scaled_log");
    ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
    ggml_set_name(dyn_temp, "temp_ext_dyn_temp");

    // Scale the logits by the dynamic temperature
    struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
    ggml_set_name(scaled_logits, "temp_ext_scaled_logits");

    data->logits = scaled_logits;
}
2067 | | |
// vtable for the extended temperature sampler (CPU and backend implementations)
static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name              = */ llama_sampler_temp_ext_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_temp_ext_apply,
    /* .reset             = */ nullptr,
    /* .clone             = */ llama_sampler_temp_ext_clone,
    /* .free              = */ llama_sampler_temp_ext_free,
    /* .backend_init      = */ llama_sampler_temp_ext_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_temp_ext_backend_apply,
    /* .backend_set_input = */ nullptr,
};
2080 | | |
2081 | 0 | struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { |
2082 | 0 | const bool is_empty = temp == 1.0f && delta <= 0.0f; |
2083 | |
|
2084 | 0 | if (is_empty) { |
2085 | 0 | return llama_sampler_init_empty("?temp-ext"); |
2086 | 0 | } |
2087 | | |
2088 | 0 | auto * res = llama_sampler_init( |
2089 | 0 | /* .iface = */ &llama_sampler_temp_ext_i, |
2090 | 0 | /* .ctx = */ new llama_sampler_temp_ext { |
2091 | 0 | ("temp-ext"), |
2092 | 0 | /* .temp = */ temp, |
2093 | 0 | /* .delta = */ delta, |
2094 | 0 | /* .exponent = */ exponent, |
2095 | 0 | } |
2096 | 0 | ); |
2097 | |
|
2098 | 0 | return res; |
2099 | 0 | } |
2100 | | |
2101 | | // xtc |
2102 | | |
// context for the XTC ("exclude top choices") sampler
struct llama_sampler_xtc {
    const float  probability; // chance of applying the XTC cut on a given step
    const float  threshold;   // probability above which a candidate counts as a "top choice"
    const size_t min_keep;    // minimum number of candidates to keep

    const uint32_t seed;     // user-provided seed
    uint32_t       seed_cur; // effective seed currently in use

    std::mt19937 rng; // RNG driving the probabilistic activation
};
2113 | | |
// sampler name reported through the llama_sampler_i interface
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    return "xtc";
}
2117 | | |
// XTC: with probability ctx->probability, drop every candidate whose
// probability meets ctx->threshold except the last such candidate, increasing
// output diversity.
// NOTE(review): named llama_sample_xtc_apply while siblings follow the
// llama_sampler_*_apply pattern — consider renaming (also update the vtable).
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_xtc *) smpl->ctx;

    // disabled configurations or too few candidates -> no-op
    if (ctx->probability <= 0.0f
        || ctx->threshold > 0.5f
        || cur_p->size < 2) {
        return;
    }

    // activate only with the configured probability
    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
    float chance = distribution(ctx->rng);
    if (chance > ctx->probability) {
        return;
    }

    // sort candidates by probability (descending) and compute probabilities
    llama_sampler_softmax_impl(cur_p, true);

    int pos_last = 0;

    // find the index of the last candidate still meeting the threshold
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].p >= ctx->threshold) {
            pos_last = i;
        } else {
            break;
        }
    }

    // drop the leading "top choices" (keeping the last one at pos_last),
    // but only if at least min_keep candidates remain
    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
        cur_p->data += pos_last;
        cur_p->size -= pos_last;
    }
}
2150 | | |
2151 | 0 | static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { |
2152 | 0 | const auto * ctx = (const llama_sampler_xtc *) smpl->ctx; |
2153 | 0 | auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed); |
2154 | | |
2155 | | // copy the state |
2156 | 0 | { |
2157 | 0 | auto * result_ctx = (llama_sampler_xtc *) result->ctx; |
2158 | |
|
2159 | 0 | result_ctx->rng = ctx->rng; |
2160 | 0 | } |
2161 | |
|
2162 | 0 | return result; |
2163 | 0 | } |
2164 | | |
2165 | 0 | static void llama_sampler_xtc_free(struct llama_sampler * smpl) { |
2166 | 0 | delete (llama_sampler_xtc *) smpl->ctx; |
2167 | 0 | } |
2168 | | |
2169 | 0 | static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { |
2170 | 0 | auto * ctx = (llama_sampler_xtc *) smpl->ctx; |
2171 | 0 | ctx->seed_cur = get_rng_seed(ctx->seed); |
2172 | 0 | ctx->rng.seed(ctx->seed_cur); |
2173 | 0 | } |
2174 | | |
// vtable for the XTC sampler (CPU-only: no backend callbacks)
static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name              = */ llama_sampler_xtc_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sample_xtc_apply,
    /* .reset             = */ llama_sampler_xtc_reset,
    /* .clone             = */ llama_sampler_xtc_clone,
    /* .free              = */ llama_sampler_xtc_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2187 | | |
2188 | 0 | struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { |
2189 | 0 | const bool is_empty = (p <= 0.0f || t > 0.5f); |
2190 | |
|
2191 | 0 | if (is_empty) { |
2192 | 0 | return llama_sampler_init_empty("?xtc"); |
2193 | 0 | } |
2194 | | |
2195 | 0 | const auto seed_cur = get_rng_seed(seed); |
2196 | |
|
2197 | 0 | return llama_sampler_init( |
2198 | 0 | /* .iface = */ &llama_sampler_xtc_i, |
2199 | 0 | /* .ctx = */ new llama_sampler_xtc { |
2200 | 0 | /* .probability = */ p, |
2201 | 0 | /* .threshold = */ t, |
2202 | 0 | /* .min_keep = */ min_keep, |
2203 | 0 | /* .seed = */ seed, |
2204 | 0 | /* .seed_cur = */ seed_cur, |
2205 | 0 | /* .rng = */ std::mt19937(seed_cur), |
2206 | 0 | } |
2207 | 0 | ); |
2208 | 0 | } |
2209 | | |
2210 | | // mirostat |
2211 | | |
// context for Mirostat (v1) sampling state
struct llama_sampler_mirostat {
    const int32_t n_vocab; // vocabulary size, used when estimating k

    const uint32_t seed;     // user-provided seed
    uint32_t       seed_cur; // effective seed currently in use

    const float tau; // target surprise (cross-entropy)
    const float eta; // learning rate for the mu update

    const int32_t m; // number of top tokens used to estimate s_hat

    float mu; // dynamic truncation parameter, initialized to 2*tau

    std::mt19937 rng;
};
2227 | | |
// sampler name reported through the llama_sampler_i interface
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat";
}
2231 | | |
// Mirostat (v1): estimate the Zipf exponent s_hat from the top tokens, derive
// a dynamic top-k from it and the current mu, sample, then move mu toward the
// target surprise tau at learning rate eta.
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, true);

    // Estimate s_hat using the most probable m tokens
    float s_hat = 0.0;
    float sum_ti_bi = 0.0;
    float sum_ti_sq = 0.0;
    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
        float t_i = logf(float(i + 2) / float(i + 1));
        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
        sum_ti_bi += t_i * b_i;
        sum_ti_sq += t_i * t_i;
    }
    s_hat = sum_ti_bi / sum_ti_sq;

    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);

    // truncate to the dynamic top-k (at least 1) and re-normalize
    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));

    llama_sampler_softmax_impl(cur_p, true);

    // sample from the truncated distribution
    const int idx = llama_sample_dist(cur_p, ctx->rng);

    cur_p->selected = idx;

    float observed_surprise = -log2f(cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
2267 | | |
2268 | 0 | static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) { |
2269 | 0 | const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; |
2270 | 0 | auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); |
2271 | | |
2272 | | // copy the state |
2273 | 0 | { |
2274 | 0 | auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; |
2275 | |
|
2276 | 0 | result_ctx->mu = ctx->mu; |
2277 | 0 | result_ctx->rng = ctx->rng; |
2278 | 0 | } |
2279 | |
|
2280 | 0 | return result; |
2281 | 0 | } |
2282 | | |
2283 | 0 | static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { |
2284 | 0 | auto * ctx = (llama_sampler_mirostat *) smpl->ctx; |
2285 | 0 | ctx->mu = 2.0f*ctx->tau; |
2286 | 0 | ctx->seed_cur = get_rng_seed(ctx->seed); |
2287 | 0 | ctx->rng.seed(ctx->seed_cur); |
2288 | 0 | } |
2289 | | |
2290 | 0 | static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { |
2291 | 0 | delete (llama_sampler_mirostat *) smpl->ctx; |
2292 | 0 | } |
2293 | | |
// vtable for the Mirostat (v1) sampler (CPU-only: no backend callbacks)
static struct llama_sampler_i llama_sampler_mirostat_i = {
    /* .name              = */ llama_sampler_mirostat_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_mirostat_apply,
    /* .reset             = */ llama_sampler_mirostat_reset,
    /* .clone             = */ llama_sampler_mirostat_clone,
    /* .free              = */ llama_sampler_mirostat_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2306 | | |
2307 | 0 | struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { |
2308 | 0 | const auto seed_cur = get_rng_seed(seed); |
2309 | |
|
2310 | 0 | return llama_sampler_init( |
2311 | 0 | /* .iface = */ &llama_sampler_mirostat_i, |
2312 | 0 | /* .ctx = */ new llama_sampler_mirostat { |
2313 | 0 | /* .n_vocab = */ n_vocab, |
2314 | 0 | /* .seed = */ seed, |
2315 | 0 | /* .seed_cur = */ seed_cur, |
2316 | 0 | /* .tau = */ tau, |
2317 | 0 | /* .eta = */ eta, |
2318 | 0 | /* .m = */ m, |
2319 | 0 | /* .mu = */ 2.0f*tau, |
2320 | 0 | /* .rng = */ std::mt19937(seed_cur), |
2321 | 0 | } |
2322 | 0 | ); |
2323 | 0 | } |
2324 | | |
2325 | | // mirostat v2 |
2326 | | |
// context for Mirostat v2 sampling state (no vocab size / m needed)
struct llama_sampler_mirostat_v2 {
    const uint32_t seed;     // user-provided seed
    uint32_t       seed_cur; // effective seed currently in use

    const float tau; // target surprise (cross-entropy)
    const float eta; // learning rate for the mu update

    float mu; // dynamic surprise threshold, initialized to 2*tau

    std::mt19937 rng;
};
2338 | | |
// sampler name reported through the llama_sampler_i interface
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat-v2";
}
2342 | | |
// Mirostat v2: truncate candidates whose surprise (-log2 p) exceeds mu, sample
// from the remainder, then move mu toward the target surprise tau at rate eta.
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, true);

    // Truncate the words with surprise values greater than mu
    // (candidates are sorted by probability, so surprise is non-decreasing)
    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
        return -log2f(candidate.p) > ctx->mu;
    }));

    // always keep at least one candidate
    if (cur_p->size == 0) {
        cur_p->size = 1;
    }

    // Normalize the probabilities of the remaining words
    llama_sampler_softmax_impl(cur_p, true);

    const int idx = llama_sample_dist(cur_p, ctx->rng);

    cur_p->selected = idx;

    float observed_surprise = -log2f(cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
2370 | | |
2371 | 0 | static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { |
2372 | 0 | auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; |
2373 | 0 | ctx->mu = 2.0f*ctx->tau; |
2374 | 0 | ctx->seed_cur = get_rng_seed(ctx->seed); |
2375 | 0 | ctx->rng.seed(ctx->seed_cur); |
2376 | 0 | } |
2377 | | |
2378 | 0 | static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { |
2379 | 0 | const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; |
2380 | |
|
2381 | 0 | auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); |
2382 | | |
2383 | | // copy the state |
2384 | 0 | { |
2385 | 0 | auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; |
2386 | |
|
2387 | 0 | result_ctx->mu = ctx->mu; |
2388 | 0 | result_ctx->rng = ctx->rng; |
2389 | 0 | } |
2390 | |
|
2391 | 0 | return result; |
2392 | 0 | } |
2393 | | |
2394 | 0 | static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { |
2395 | 0 | delete (llama_sampler_mirostat_v2 *) smpl->ctx; |
2396 | 0 | } |
2397 | | |
// vtable for the Mirostat v2 sampler (CPU-only: no backend callbacks)
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name              = */ llama_sampler_mirostat_v2_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_mirostat_v2_apply,
    /* .reset             = */ llama_sampler_mirostat_v2_reset,
    /* .clone             = */ llama_sampler_mirostat_v2_clone,
    /* .free              = */ llama_sampler_mirostat_v2_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2410 | | |
2411 | 0 | struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { |
2412 | 0 | auto seed_cur = get_rng_seed(seed); |
2413 | 0 | return llama_sampler_init( |
2414 | 0 | /* .iface = */ &llama_sampler_mirostat_v2_i, |
2415 | 0 | /* .ctx = */ new llama_sampler_mirostat_v2 { |
2416 | 0 | /* .seed = */ seed, |
2417 | 0 | /* .seed_cur = */ seed_cur, |
2418 | 0 | /* .tau = */ tau, |
2419 | 0 | /* .eta = */ eta, |
2420 | 0 | /* .mu = */ 2.0f*tau, |
2421 | 0 | /* .rng = */ std::mt19937(seed_cur), |
2422 | 0 | } |
2423 | 0 | ); |
2424 | 0 | } |
2425 | | |
2426 | | // grammar |
2427 | | |
// context for grammar-constrained sampling
struct llama_sampler_grammar {
    const struct llama_vocab * vocab;

    std::string grammar_str;  // original grammar text, kept so reset can rebuild the grammar
    std::string grammar_root; // root rule name

    struct llama_grammar * grammar; // may be null (e.g. a freshly cloned/empty sampler)
};
2436 | | |
// sampler name reported through the llama_sampler_i interface
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    return "grammar";
}
2440 | | |
2441 | 0 | static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) { |
2442 | 0 | auto * ctx = (llama_sampler_grammar *) smpl->ctx; |
2443 | 0 | if (ctx->grammar) { |
2444 | 0 | llama_grammar_accept_impl(*ctx->grammar, token); |
2445 | 0 | } |
2446 | 0 | } |
2447 | | |
2448 | 0 | static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
2449 | 0 | auto * ctx = (llama_sampler_grammar *) smpl->ctx; |
2450 | 0 | if (ctx->grammar) { |
2451 | 0 | llama_grammar_apply_impl(*ctx->grammar, cur_p); |
2452 | 0 | } |
2453 | 0 | } |
2454 | | |
// Fwd declare to break the reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
              const char * grammar_str,
              const char * grammar_root,
                      bool lazy,
             const char ** trigger_words,
                    size_t num_trigger_words,
       const llama_token * trigger_tokens,
                    size_t num_trigger_tokens,
             const char ** trigger_patterns,
                    size_t num_trigger_patterns);
2467 | | |
// Reset grammar parsing to the start state by rebuilding the grammar from the
// stored grammar text/root, preserving the lazy-trigger configuration.
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (!ctx->grammar) {
        return; // empty sampler: nothing to reset
    }

    // collect C-string views of the trigger patterns for re-initialization
    std::vector<const char *> trigger_patterns_c;
    trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
    for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
        trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
    }

    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
                                                 ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                 ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());

    // swap in the fresh grammar and release the old one
    llama_grammar_free_impl(ctx->grammar);
    ctx->grammar = grammar_new;
}
2487 | | |
// Clone the grammar sampler: create a blank sampler first (a null grammar_str
// leaves its grammar unset), then deep-copy the grammar state into it.
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
    GGML_ASSERT(result);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_grammar *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_str = ctx->grammar_str;
            result_ctx->grammar_root = ctx->grammar_root;

            // deep copy of the current parse state
            result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
        }
    }

    return result;
}
2508 | | |
2509 | 0 | static void llama_sampler_grammar_free(struct llama_sampler * smpl) { |
2510 | 0 | const auto * ctx = (llama_sampler_grammar *) smpl->ctx; |
2511 | |
|
2512 | 0 | if (ctx->grammar) { |
2513 | 0 | llama_grammar_free_impl(ctx->grammar); |
2514 | 0 | } |
2515 | |
|
2516 | 0 | delete ctx; |
2517 | 0 | } |
2518 | | |
// vtable wiring the grammar sampler callbacks into the generic sampler
// interface; backend (GPU-side) hooks are not implemented for this sampler
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply = */ llama_sampler_grammar_apply,
    /* .reset = */ llama_sampler_grammar_reset,
    /* .clone = */ llama_sampler_grammar_clone,
    /* .free = */ llama_sampler_grammar_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2531 | | |
// Construct a grammar sampler.
// - with a non-empty grammar_str: parses the grammar; returns nullptr on parse failure
// - with a null/empty grammar_str: creates an inert sampler with no grammar attached
// trigger_words (deprecated) are regex-escaped and folded into a single trigger pattern;
// trigger_patterns/trigger_tokens are passed through to the grammar as-is.
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                             bool lazy,
                     const char ** trigger_words,
                           size_t num_trigger_words,
               const llama_token * trigger_tokens,
                           size_t num_trigger_tokens,
                     const char ** trigger_patterns,
                           size_t num_trigger_patterns) {
    auto * ctx = new llama_sampler_grammar;

    if (grammar_str != nullptr && grammar_str[0] != '\0') {
        std::string trigger_pattern;
        llama_grammar * grammar = nullptr;
        // TODO: remove trigger_words support.
        if (trigger_words != nullptr && num_trigger_words > 0) {
            // trigger_words and trigger_patterns are mutually exclusive
            GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
            // build one pattern that matches any of the (regex-escaped) words anywhere in the text
            trigger_pattern = "[\\s\\S]*?(";
            for (size_t i = 0; i < num_trigger_words; ++i) {
                // escape regex metacharacters so words are matched literally
                static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
                if (i > 0) {
                    trigger_pattern += "|";
                }
                trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
            }
            trigger_pattern += ")[\\s\\S]*";

            // trigger_pattern stays alive for the duration of the init call below
            std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
        } else {
            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
        }
        *ctx = {
            /* .vocab = */ vocab,
            /* .grammar_str = */ grammar_str,
            /* .grammar_root = */ grammar_root,
            /* .grammar = */ grammar,
        };
        // grammar parse failure -> signal the error to the caller
        if (!ctx->grammar) {
            delete ctx;
            return nullptr;
        }
    } else {
        // no grammar text: inert sampler (used as the base for clone())
        *ctx = {
            /* .vocab = */ vocab,
            /* .grammar_str = */ {},
            /* .grammar_root = */ {},
            /* .grammar = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_grammar_i,
        /* .ctx = */ ctx
    );
}
2590 | | |
2591 | | struct llama_sampler * llama_sampler_init_grammar( |
2592 | | const struct llama_vocab * vocab, |
2593 | | const char * grammar_str, |
2594 | 0 | const char * grammar_root) { |
2595 | 0 | return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0); |
2596 | 0 | } |
2597 | | |
2598 | | struct llama_sampler * llama_sampler_init_grammar_lazy( |
2599 | | const struct llama_vocab * vocab, |
2600 | | const char * grammar_str, |
2601 | | const char * grammar_root, |
2602 | | const char ** trigger_words, |
2603 | | size_t num_trigger_words, |
2604 | | const llama_token * trigger_tokens, |
2605 | 0 | size_t num_trigger_tokens) { |
2606 | 0 | return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0); |
2607 | 0 | } |
2608 | | |
2609 | | struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( |
2610 | | const struct llama_vocab * vocab, |
2611 | | const char * grammar_str, |
2612 | | const char * grammar_root, |
2613 | | const char ** trigger_patterns, |
2614 | | size_t num_trigger_patterns, |
2615 | | const llama_token * trigger_tokens, |
2616 | 0 | size_t num_trigger_tokens) { |
2617 | 0 | return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns); |
2618 | 0 | } |
2619 | | |
2620 | | // penalties |
2621 | | |
// state for the repetition / frequency / presence penalty sampler
struct llama_sampler_penalties {
    const int32_t penalty_last_n; // how many recent tokens to consider (0 disables the sampler)
    const float penalty_repeat;   // multiplicative repetition penalty (1.0 disables)
    const float penalty_freq;     // linear penalty per occurrence (0.0 disables)
    const float penalty_present;  // flat penalty if the token occurred at all (0.0 disables)

    ring_buffer<llama_token> prev; // the last penalty_last_n accepted tokens

    // a frequency map to count token occurrences
    // invariant: always mirrors the contents of `prev` (maintained by accept())
    std::unordered_map<llama_token, int> token_count;
};
2633 | | |
// Human-readable identifier for this sampler type; independent of the instance.
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "penalties";
    return name;
}
2637 | | |
// Record an accepted token in the recent-token window.
// Maintains the invariant that token_count is an exact frequency map of the
// contents of the ring buffer `prev`: the new token is counted first, then the
// token about to be evicted (when the window is full) is un-counted.
static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
    if (ctx->penalty_last_n == 0) {
        return;
    }

    ctx->token_count[token]++;

    // if the ring buffer is full, remove the oldest token
    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
        const auto old = ctx->prev.front();

        // decrement first, then erase to keep only positive counts in the map
        ctx->token_count[old]--;
        if (ctx->token_count[old] == 0) {
            ctx->token_count.erase(old);
        }
    }

    ctx->prev.push_back(token);

#if 0
    // sanity check
    std::unordered_map<llama_token, int> tmp;
    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
        tmp[ctx->prev.rat(i)]++;
    }

    assert(ctx->token_count == tmp);
#endif
}
2668 | | |
// Penalize the logits of tokens that appear in the recent-token window:
// - repeat penalty: multiplicative scaling of the logit
// - frequency penalty: linear in the occurrence count
// - presence penalty: flat, applied once if the token occurred at all
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;

    // fast path: all penalties at their neutral values -> nothing to do
    if ((ctx->penalty_last_n == 0) ||
        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
        return;
    }

    // Apply frequency and presence penalties to the cur_p
    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
        if (token_iter == ctx->token_count.end()) {
            continue;
        }

        const int count = token_iter->second;

        // accept() erases zero entries, so any hit has 1 <= count <= penalty_last_n
        assert(count > 0 && count <= ctx->penalty_last_n);

        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
        if (cur_p->data[i].logit <= 0) {
            cur_p->data[i].logit *= ctx->penalty_repeat;
        } else {
            cur_p->data[i].logit /= ctx->penalty_repeat;
        }

        // count > 0 always holds here, so the presence penalty applies to every found token
        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
    }

    // logits changed -> any previous sort order is stale
    cur_p->sorted = false;
}
2701 | | |
2702 | 0 | static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { |
2703 | 0 | auto * ctx = (llama_sampler_penalties *) smpl->ctx; |
2704 | 0 | ctx->prev.clear(); |
2705 | 0 | ctx->token_count.clear(); |
2706 | 0 | } |
2707 | | |
2708 | 0 | static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { |
2709 | 0 | const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; |
2710 | 0 | auto * result = llama_sampler_init_penalties( |
2711 | 0 | ctx->penalty_last_n, |
2712 | 0 | ctx->penalty_repeat, |
2713 | 0 | ctx->penalty_freq, |
2714 | 0 | ctx->penalty_present); |
2715 | | |
2716 | | // copy the state |
2717 | 0 | { |
2718 | 0 | auto * result_ctx = (llama_sampler_penalties *) result->ctx; |
2719 | |
|
2720 | 0 | result_ctx->prev = ctx->prev; |
2721 | 0 | } |
2722 | |
|
2723 | 0 | return result; |
2724 | 0 | } |
2725 | | |
2726 | 0 | static void llama_sampler_penalties_free(struct llama_sampler * smpl) { |
2727 | 0 | delete (llama_sampler_penalties *) smpl->ctx; |
2728 | 0 | } |
2729 | | |
// vtable wiring the penalties sampler callbacks into the generic sampler
// interface; backend (GPU-side) hooks are not implemented for this sampler
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
    /* .apply = */ llama_sampler_penalties_apply,
    /* .reset = */ llama_sampler_penalties_reset,
    /* .clone = */ llama_sampler_penalties_clone,
    /* .free = */ llama_sampler_penalties_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2742 | | |
2743 | | struct llama_sampler * llama_sampler_init_penalties( |
2744 | | int32_t penalty_last_n, |
2745 | | float penalty_repeat, |
2746 | | float penalty_freq, |
2747 | 0 | float penalty_present) { |
2748 | 0 | penalty_last_n = std::max(penalty_last_n, 0); |
2749 | |
|
2750 | 0 | const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)); |
2751 | |
|
2752 | 0 | if (is_empty) { |
2753 | 0 | return llama_sampler_init_empty("?penalties"); |
2754 | 0 | } |
2755 | | |
2756 | 0 | return llama_sampler_init( |
2757 | 0 | /* .iface = */ &llama_sampler_penalties_i, |
2758 | 0 | /* .ctx = */ new llama_sampler_penalties { |
2759 | 0 | /* .penalty_last_n = */ penalty_last_n, |
2760 | 0 | /* .penalty_repeat = */ penalty_repeat, |
2761 | 0 | /* .penalty_freq = */ penalty_freq, |
2762 | 0 | /* .penalty_present = */ penalty_present, |
2763 | 0 | /* .prev = */ ring_buffer<llama_token>(penalty_last_n), |
2764 | 0 | /* .token_count = */ {}, |
2765 | 0 | } |
2766 | 0 | ); |
2767 | 0 | } |
2768 | | |
2769 | | // top-n-sigma |
2770 | | |
// state for the top-n-sigma sampler
struct llama_sampler_top_n_sigma {
    const float n; // keep tokens within n standard deviations of the max logit (<= 0 disables)
};
2774 | | |
// Human-readable identifier for this sampler type; independent of the instance.
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "top-n-sigma";
    return name;
}
2778 | | |
2779 | 0 | static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
2780 | 0 | auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; |
2781 | |
|
2782 | 0 | if (ctx->n <= 0.0f || cur_p->size <= 1) { |
2783 | 0 | return; |
2784 | 0 | } |
2785 | | |
2786 | | // find max logit and calculate mean |
2787 | 0 | float max = cur_p->data[0].logit; |
2788 | 0 | float logits_sum = 0; |
2789 | 0 | size_t valid_count = 0; |
2790 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
2791 | | // Only count non-negative infinity values |
2792 | 0 | if (cur_p->data[i].logit != -INFINITY) { |
2793 | 0 | max = std::max(max, cur_p->data[i].logit); |
2794 | 0 | logits_sum += cur_p->data[i].logit; |
2795 | 0 | valid_count++; |
2796 | 0 | } |
2797 | 0 | } |
2798 | 0 | float mean = valid_count > 0 ? logits_sum/valid_count : 0; |
2799 | | |
2800 | | // calculate standard deviation |
2801 | 0 | float acc = 0; |
2802 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
2803 | | // Skip -infinity in std calculation |
2804 | 0 | if (cur_p->data[i].logit != -INFINITY) { |
2805 | 0 | acc += pow(cur_p->data[i].logit - mean, 2); |
2806 | 0 | } |
2807 | 0 | } |
2808 | 0 | float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; |
2809 | | |
2810 | | // apply mask |
2811 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
2812 | 0 | if (cur_p->data[i].logit < max - (ctx->n * std)) { |
2813 | 0 | cur_p->data[i].logit = -INFINITY; |
2814 | 0 | } |
2815 | 0 | } |
2816 | |
|
2817 | 0 | llama_sampler_softmax_impl(cur_p, true); |
2818 | 0 | } |
2819 | | |
2820 | 0 | static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { |
2821 | 0 | const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; |
2822 | 0 | return llama_sampler_init_top_n_sigma(ctx->n); |
2823 | 0 | } |
2824 | | |
2825 | 0 | static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { |
2826 | 0 | delete (llama_sampler_top_n_sigma *) smpl->ctx; |
2827 | 0 | } |
2828 | | |
// vtable for the top-n-sigma sampler; it keeps no history, so accept/reset
// are not needed, and backend (GPU-side) hooks are not implemented
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
    /* .name = */ llama_sampler_top_n_sigma_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_top_n_sigma_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_n_sigma_clone,
    /* .free = */ llama_sampler_top_n_sigma_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2841 | | |
2842 | 0 | struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { |
2843 | 0 | const bool is_empty = (n <= 0.0f); |
2844 | |
|
2845 | 0 | if (is_empty) { |
2846 | 0 | return llama_sampler_init_empty("?top-n-sigma"); |
2847 | 0 | } |
2848 | | |
2849 | 0 | return llama_sampler_init( |
2850 | 0 | /* .iface = */ &llama_sampler_top_n_sigma_i, |
2851 | 0 | /* .ctx = */ new llama_sampler_top_n_sigma { |
2852 | 0 | /* .n = */ n, |
2853 | 0 | } |
2854 | 0 | ); |
2855 | 0 | } |
2856 | | |
2857 | | // DRY |
2858 | | |
// state for the DRY ("don't repeat yourself") sampler
struct llama_sampler_dry {
    int32_t total_context_size; // training context size; upper bound for the scan window

    const float dry_multiplier;       // penalty scale (0.0 disables)
    const float dry_base;             // exponential base for the penalty (< 1.0 disables)
    const int32_t dry_allowed_length; // repeats up to this length are not penalized
    const int32_t dry_penalty_last_n; // window of recent tokens to scan (-1 = whole context, 0 disables)

    // restart sequences, keyed by head token -> remaining tail tokens
    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
    // scratch: per-position repeat lengths computed by the Z-algorithm pass
    std::vector<int> dry_repeat_count;
    // scratch: longest repeat that each candidate token would extend
    std::unordered_map<llama_token, int> dry_max_token_repeat;
    ring_buffer<llama_token> last_tokens; // recent token history
};
2872 | | |
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// For every vocab token whose text overlaps the breaker string `str`, record
// an entry in `token_sequences`: head token -> tokens of the remaining tail of
// `str` (empty when the token already contains `str` whole). Tails are clamped
// to max_tail_len tokens (when >= 0) to bound later scans.
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
        std::string word = vocab.detokenize({token_id}, true);
        if (word.find(str) != std::string::npos) {
            // the token text contains the whole breaker -> empty tail
            token_sequences.emplace(token_id, std::vector<llama_token>());
        } else {
            size_t word_len = word.size();
            size_t str_len = str.size();
            // intentionally wraps to SIZE_MAX so that pos + 1 starts the search at 0
            size_t pos = -1;
            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                // check whether a suffix of `word` starting at pos is a prefix of `str`
                bool match = true;
                size_t i;
                for (i = 1; i < str_len && i + pos < word_len; ++i) {
                    if (word[pos + i] != str[i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    // `i` is the length of the matched prefix; tokenize the rest of `str`
                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                        tokenization.resize(max_tail_len);
                    }

                    // Ensure we don't already have a duplicate matching tokenization
                    auto its = token_sequences.equal_range(token_id);
                    bool found = false;
                    for (auto it = its.first; it != its.second; ++it) {
                        if (tokenization == it->second) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        token_sequences.emplace(token_id, tokenization);
                    }
                }
            }
        }
    }
}
2915 | | |
// Human-readable identifier for this sampler type; independent of the instance.
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * name = "dry";
    return name;
}
2919 | | |
2920 | 0 | static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) { |
2921 | 0 | auto * ctx = (llama_sampler_dry *) smpl->ctx; |
2922 | 0 | if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { |
2923 | 0 | return; |
2924 | 0 | } |
2925 | | |
2926 | 0 | ctx->last_tokens.push_back(token); |
2927 | 0 | } |
2928 | | |
2929 | | // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) |
// Apply the DRY penalty: detect suffixes of the recent-token history that
// repeat earlier in the history and penalize the candidate tokens that would
// extend such a repeat, with a penalty exponential in the repeat length.
// NOTE(review): last_tokens.rat(i) appears to be reverse-indexed access into
// the ring buffer (rat(0) = most recent) — confirm against the ring_buffer impl.
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;

    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
        return;
    }

    // -1 means "use the whole training context" for the scan window
    int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
    int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);

    // history shorter than the allowed repeat length can never trigger a penalty
    if (last_n_repeat <= ctx->dry_allowed_length) {
        return;
    }

    ctx->dry_repeat_count.assign(last_n_repeat, 0);
    ctx->dry_max_token_repeat.clear();

    // Step 1: Look for restart sequences to limit the maximum repetition length.
    // Work backwards through the context looking for any token that begins a restart sequence.
    //
    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
    // sequences that together comprise a restart sequence. This allows us to quickly check
    // whether each token is the head of a complete sequence. Most restart sequences are actually
    // a single token, and for these the "tail" is an empty vector.
    //
    // If the token is a "head", test all restart sequences that begin with this token
    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
    // longest matching sequence (if any) is used to limit the maximum repetition length.
    //
    // Note that in the case case of a short sequence contained in a longer one, this might fail to
    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
    //
    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
    // With clamping, this scan is O(N) in the context length.

    int rep_limit = last_n_repeat;
    for (int i = 0; i < last_n_repeat; ++i) {
        llama_token token = ctx->last_tokens.rat(i);
        auto its = ctx->dry_processed_breakers.equal_range(token);
        if (its.first == ctx->dry_processed_breakers.end()) {
            continue;
        }
        int longest_match = -1;
        for (auto it = its.first; it != its.second; ++it) {
            // Note that (*it) does not contain the head character, so seq_len will be
            // the restart sequence length minus 1.
            // In the common case of a single-token restart sequence, (*it) will be empty
            // and we will trivially match.
            int seq_len = (int)it->second.size();
            if (seq_len > longest_match && seq_len <= (int)i) {
                bool match = true;
                for (int offset = 0; offset < seq_len; ++offset) {
                    // The -1 when indexing `last_tokens` is because we already matched the head.
                    if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    longest_match = seq_len;
                }
            }
        }
        if (longest_match >= 0) {
            // We found a restart sequence starting `i` tokens from the end and continuing for
            // `longest_match` tokens.
            rep_limit = i - longest_match;
            break;
        }
    }
    // repeats capped below the allowed length can never be penalized -> done
    if (rep_limit < ctx->dry_allowed_length) {
        return;
    }

    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
    //
    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
    //
    // The code below is adapted from the public domain implementation by the same author here:
    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
    //
    // Example:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //                    ^
    //   This `3` means that the last three tokens of the context (a b c) also appear here.
    //
    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
    // ensure that the inner while loops only examine each token in the context once as the outer
    // for loop iterates over the context.

    {
        const int last = last_n_repeat - 1;

        int rt = 0;
        int lt = 0;

        for (int k = 1; k < last_n_repeat; ++k) {
            if (k > rt) {
                // If k is outside the current Z-box, do naive computation.
                int n = 0;
                while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                    ++n;
                }
                ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                if (n > 0) {
                    lt = k;
                    rt = k + n - 1;
                }
            } else {
                // If k is inside the current Z-box, consider two cases.

                int p = k - lt; // Pair index.
                int right_part_len = rt - k + 1;

                if (ctx->dry_repeat_count[last - p] < right_part_len) {
                    int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                } else {
                    int i = rt + 1;
                    while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
                        i += 1;
                    }

                    int n = std::min(i - k, rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                    lt = k;
                    rt = i - 1;
                }
            }
        }
    }

    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
    // that would be generated by emitting each new token that would extend a sequence.
    //
    // Following the same example as above:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //
    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
    // c: 3 -> 4 (from `a b c` to `a b c c`)
    // b: 1 -> 2 (from `c` to `c b`)
    // y: 2 -> 3 (from `b c` to `b c y`)

    for (int i = 0; i < last_n_repeat - 1; ++i) {
        int repeat_len = ctx->dry_repeat_count[i];
        if (repeat_len >= ctx->dry_allowed_length) {
            // This token ends a repeat, so the next token would continue one.
            // By convention, the value of `repeat_len` only includes the tokens currently
            // in the context, not the new token that would be added.
            llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
            // Track the maximum sequence ending in this token.
            const auto& it = ctx->dry_max_token_repeat.find(token);
            if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                ctx->dry_max_token_repeat[token] = repeat_len;
            }
        }
    }

    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.

    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
    const float FLOAT_MAX_LOG = 88.7228391f;
    int max_exponent = 0;
    if (ctx->dry_base > 1.000001f) {
        max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
        if (af_kvp != ctx->dry_max_token_repeat.end()) {
            // Check all sequence breakers starting with this token
            auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
            bool is_single_token_breaker = false;

            for (auto it = range.first; it != range.second; ++it) {
                if (it->second.empty()) {
                    is_single_token_breaker = true;
                    break;
                }
            }

            // Apply penalty only if it's not a single-token sequence breaker
            if (!is_single_token_breaker) {
                int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
                if (max_exponent > 0 && repeat_exp > max_exponent) {
                    repeat_exp = max_exponent;
                }
                float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
                cur_p->data[i].logit -= penalty;
            }
        }
    }

    // logits changed -> any previous sort order is stale
    cur_p->sorted = false;
}
3137 | | |
3138 | 0 | static void llama_sampler_dry_reset(struct llama_sampler * smpl) { |
3139 | 0 | auto * ctx = (llama_sampler_dry *) smpl->ctx; |
3140 | 0 | ctx->last_tokens.clear(); |
3141 | 0 | ctx->dry_repeat_count.clear(); |
3142 | 0 | ctx->dry_max_token_repeat.clear(); |
3143 | 0 | } |
3144 | | |
3145 | 0 | static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { |
3146 | 0 | const auto * ctx = (llama_sampler_dry *) smpl->ctx; |
3147 | |
|
3148 | 0 | llama_vocab dummy_vocab; |
3149 | | |
3150 | | // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying |
3151 | 0 | auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); |
3152 | | |
3153 | | // Copy the state, including the processed breakers |
3154 | 0 | { |
3155 | 0 | auto * result_ctx = (llama_sampler_dry *) result->ctx; |
3156 | 0 | result_ctx->dry_processed_breakers = ctx->dry_processed_breakers; |
3157 | 0 | result_ctx->dry_repeat_count = ctx->dry_repeat_count; |
3158 | 0 | result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat; |
3159 | 0 | result_ctx->last_tokens = ctx->last_tokens; |
3160 | 0 | } |
3161 | |
|
3162 | 0 | return result; |
3163 | 0 | } |
3164 | | |
3165 | 0 | static void llama_sampler_dry_free(struct llama_sampler * smpl) { |
3166 | 0 | delete (llama_sampler_dry *) smpl->ctx; |
3167 | 0 | } |
3168 | | |
// vtable wiring the DRY sampler callbacks into the generic sampler interface;
// backend (GPU-side) hooks are not implemented for this sampler
static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply = */ llama_sampler_dry_apply,
    /* .reset = */ llama_sampler_dry_reset,
    /* .clone = */ llama_sampler_dry_clone,
    /* .free = */ llama_sampler_dry_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
3181 | | |
3182 | 0 | struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { |
3183 | 0 | int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0); |
3184 | 0 | std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers; |
3185 | 0 | const int MAX_CHAR_LEN = 40; |
3186 | 0 | const int MAX_SEQ_LEN = 20; |
3187 | |
|
3188 | 0 | const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); |
3189 | |
|
3190 | 0 | if (!dry_enabled) { |
3191 | 0 | return llama_sampler_init_empty("?dry"); |
3192 | 0 | } |
3193 | | |
3194 | 0 | if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { |
3195 | | // Process sequence breakers |
3196 | 0 | for (size_t i = 0; i < num_breakers; ++i) { |
3197 | 0 | if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { |
3198 | 0 | LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); |
3199 | 0 | continue; |
3200 | 0 | } |
3201 | | |
3202 | 0 | std::string sequence_break(seq_breakers[i]); |
3203 | 0 | if (sequence_break.empty()) { |
3204 | 0 | LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); |
3205 | 0 | continue; |
3206 | 0 | } |
3207 | | |
3208 | 0 | if (sequence_break.size() > MAX_CHAR_LEN) { |
3209 | 0 | LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); |
3210 | 0 | sequence_break.resize(MAX_CHAR_LEN); |
3211 | 0 | } |
3212 | |
|
3213 | 0 | get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); |
3214 | 0 | } |
3215 | 0 | } |
3216 | |
|
3217 | 0 | return llama_sampler_init( |
3218 | 0 | /* .iface = */ &llama_sampler_dry_i, |
3219 | 0 | /* .ctx = */ new llama_sampler_dry { |
3220 | 0 | /* .total_context_size = */ n_ctx_train, |
3221 | 0 | /* .dry_multiplier = */ dry_multiplier, |
3222 | 0 | /* .dry_base = */ dry_base, |
3223 | 0 | /* .dry_allowed_length = */ dry_allowed_length, |
3224 | 0 | /* .dry_penalty_last_n = */ dry_penalty_last_n, |
3225 | 0 | /* .dry_processed_breakers = */ std::move(processed_breakers), |
3226 | 0 | /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{}, |
3227 | 0 | /* .dry_max_token_repeat = */ {}, |
3228 | 0 | /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), |
3229 | 0 | } |
3230 | 0 | ); |
3231 | 0 | } |
3232 | | |
3233 | | // wrapper for test-sampling.cpp |
3234 | 0 | struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) { |
3235 | 0 | llama_vocab dummy_vocab; |
3236 | 0 | auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); |
3237 | 0 | auto * ctx = (llama_sampler_dry *) result->ctx; |
3238 | | |
3239 | | // Process the token-based sequence breakers |
3240 | 0 | ctx->dry_processed_breakers.clear(); |
3241 | 0 | if (seq_breakers.empty()) { |
3242 | 0 | LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n"); |
3243 | 0 | } else { |
3244 | 0 | for (const auto& breaker : seq_breakers) { |
3245 | 0 | if (breaker.empty()) { |
3246 | 0 | LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n"); |
3247 | 0 | continue; |
3248 | 0 | } |
3249 | 0 | llama_token head_token = breaker[0]; |
3250 | 0 | std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end()); |
3251 | 0 | ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens)); |
3252 | 0 | } |
3253 | |
|
3254 | 0 | if (ctx->dry_processed_breakers.empty()) { |
3255 | 0 | LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n"); |
3256 | 0 | } |
3257 | 0 | } |
3258 | |
|
3259 | 0 | return result; |
3260 | 0 | } |
3261 | | |
3262 | | // adaptive-p sampler state |
3263 | | // |
3264 | | // maintains an exponential moving average of the *ORIGINAL* probabilities |
3265 | | // of selected tokens, used to compute an adapted target at each sampling step. |
3266 | | // |
3267 | | // see llama.h for a full description of the sampler |
3268 | | // |
3269 | | // ref: https://github.com/ggml-org/llama.cpp/pull/17927 |
3270 | | // |
struct llama_sampler_adaptive_p {
    const float target; // target probability (0.0 - 1.0; negative = disabled)
    const float decay;  // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
    const uint32_t seed;     // original RNG seed (as passed at init)
    uint32_t seed_cur;       // actual RNG seed (resolved from `seed` by get_rng_seed)
    std::mt19937 rng;        // RNG state used to draw tokens in _apply
    float weighted_sum;      // EMA numerator: sum(p_i * decay^i)
    float total_weight;      // EMA denominator: sum(decay^i), converges to 1/(1-decay)
    std::vector<float> original_probs; // pre-transform probs, cached for EMA update in _accept
    llama_token pending_token_id;  // token ID selected in _apply, awaiting _accept (LLAMA_TOKEN_NULL if none)
    int32_t pending_token_idx;     // index of orig. prob. of selected token in original_probs (-1 if none)
};
3283 | | |
// adaptive probability transformation constants (see llama_sampler_adaptive_p_apply)
static constexpr float DISTRIBUTION_WIDTH = 0.3f;  // probability-distance scale of the transform
static constexpr float PEAK_LOGIT_VALUE   = 5.0f;  // logit assigned to a token exactly at the adapted target
static constexpr float SHARPNESS          = 10.0f; // how quickly logits fall off away from the target
static constexpr float INV_WIDTH          = 1.0f / DISTRIBUTION_WIDTH; // precomputed reciprocal
3289 | | |
static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
    // the name is a compile-time constant, independent of the instance
    static constexpr const char * k_name = "adaptive-p";
    return k_name;
}
3293 | | |
// applies the adaptive-p transform and samples a token:
//  1. softmax the raw logits into probabilities
//  2. compute an adapted target: the configured target mirrored around the EMA
//     of previously accepted tokens' original probabilities (2*target - EMA)
//  3. re-score each candidate by its probability distance to the adapted target
//  4. softmax again and sample from the transformed distribution
static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, false);

    if (ctx->target < 0.0f) {
        // at negative target values, adaptive-p is no-op
        // we simply sample from the existing distribution
        cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
        return;
    }

    // store the original probabilities (needed by _accept to update the EMA)
    ctx->original_probs.resize(cur_p->size);
    for (size_t i = 0; i < cur_p->size; ++i) {
        ctx->original_probs[i] = cur_p->data[i].p;
    }

    // using the EMA, compute the adapted target probability for the current sampling step
    auto target = std::clamp(ctx->target, 0.0f, 1.0f);
    float adapted_target = std::clamp(
        // total_weight == 0 means no history yet -> fall back to the raw target
        ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
        0.0f, 1.0f
    );

    // adaptive probability transform
    //
    // quadratic near target for fine differentiation, transitioning to linear decay in the
    // tails. unbounded negative logits ensure proper suppression of far-from-target tokens
    // after the softmax.
    //
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].logit == -INFINITY) {
            // don't transform logits that are -INFINITY
            // (as masked out by e.g. min-p and top-p when using backend sampling)
            continue;
        }
        // dist is quadratic when small (dist^2/(1+dist) ~ dist^2), linear when large (~ dist)
        float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
        cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
    }

    // softmax and sample from the transformed distribution
    llama_sampler_softmax_impl(cur_p, false);
    const int idx = llama_sample_dist(cur_p, ctx->rng);
    cur_p->selected = idx;

    // store the selected token ID for acceptance later
    ctx->pending_token_id = cur_p->data[idx].id;
    ctx->pending_token_idx = idx;
}
3344 | | |
3345 | 0 | static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { |
3346 | 0 | auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; |
3347 | 0 | if (ctx->pending_token_id == token) { |
3348 | 0 | GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); |
3349 | 0 | GGML_ASSERT(ctx->pending_token_idx != -1); |
3350 | | // update EMA with the original probability of the selected token |
3351 | 0 | ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; |
3352 | 0 | ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; |
3353 | 0 | } |
3354 | 0 | ctx->pending_token_id = LLAMA_TOKEN_NULL; |
3355 | 0 | ctx->pending_token_idx = -1; |
3356 | 0 | } |
3357 | | |
3358 | 0 | static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { |
3359 | 0 | auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; |
3360 | | // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. |
3361 | | // original_probs is completely overwritten on every call to _apply. |
3362 | | // so we only need to reset the EMA state and pending token. |
3363 | 0 | ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); |
3364 | 0 | ctx->total_weight = 1.0f / (1.0f - ctx->decay); |
3365 | 0 | ctx->pending_token_id = LLAMA_TOKEN_NULL; |
3366 | 0 | ctx->pending_token_idx = -1; |
3367 | 0 | ctx->seed_cur = get_rng_seed(ctx->seed); |
3368 | 0 | ctx->rng.seed(ctx->seed_cur); |
3369 | 0 | } |
3370 | | |
3371 | 0 | static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { |
3372 | 0 | const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; |
3373 | 0 | auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); |
3374 | 0 | auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; |
3375 | | |
3376 | | // copy everything (target, decay, seed, and RNG are already set) |
3377 | 0 | result_ctx->weighted_sum = ctx->weighted_sum; |
3378 | 0 | result_ctx->total_weight = ctx->total_weight; |
3379 | 0 | result_ctx->pending_token_id = ctx->pending_token_id; |
3380 | 0 | result_ctx->pending_token_idx = ctx->pending_token_idx; |
3381 | |
|
3382 | 0 | return result; |
3383 | 0 | } |
3384 | | |
3385 | 0 | static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { |
3386 | 0 | delete (llama_sampler_adaptive_p *) smpl->ctx; |
3387 | 0 | } |
3388 | | |
// vtable for the adaptive-p sampler (CPU-only: backend hooks are unimplemented)
static struct llama_sampler_i llama_sampler_adaptive_p_i = {
    /* .name   = */ llama_sampler_adaptive_p_name,
    /* .accept = */ llama_sampler_adaptive_p_accept,
    /* .apply  = */ llama_sampler_adaptive_p_apply,
    /* .reset  = */ llama_sampler_adaptive_p_reset,
    /* .clone  = */ llama_sampler_adaptive_p_clone,
    /* .free   = */ llama_sampler_adaptive_p_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
3401 | | |
3402 | | struct llama_sampler * llama_sampler_init_adaptive_p( |
3403 | | float target, |
3404 | | float decay, |
3405 | | uint32_t seed |
3406 | 0 | ) { |
3407 | 0 | auto seed_cur = get_rng_seed(seed); |
3408 | 0 | float clamped_decay = std::clamp(decay, 0.0f, 0.99f); |
3409 | 0 | return llama_sampler_init( |
3410 | 0 | /* .iface = */ &llama_sampler_adaptive_p_i, |
3411 | 0 | /* .ctx = */ new llama_sampler_adaptive_p { |
3412 | 0 | /* .target = */ target, |
3413 | 0 | /* .decay = */ clamped_decay, |
3414 | 0 | /* .seed = */ seed, |
3415 | 0 | /* .seed_cur = */ seed_cur, |
3416 | 0 | /* .rng = */ std::mt19937(seed_cur), |
3417 | 0 | /* .weighted_sum = */ target / (1.0f - clamped_decay), |
3418 | 0 | /* .total_weight = */ 1.0f / (1.0f - clamped_decay), |
3419 | 0 | /* .original_probs = */ {}, |
3420 | 0 | /* .pending_token_id = */ LLAMA_TOKEN_NULL, |
3421 | 0 | /* .pending_token_idx = */ -1 |
3422 | 0 | } |
3423 | 0 | ); |
3424 | 0 | } |
3425 | | |
3426 | | // logit-bias |
3427 | | |
struct llama_sampler_logit_bias : public llama_sampler_backend {
    const int32_t n_vocab; // vocabulary size; used to validate bias token ids in backend_set_input

    const std::vector<llama_logit_bias> logit_bias; // (token, bias) pairs, fixed at init

    std::vector<llama_logit_bias> to_search; // scratch: biases not applied via the fast path in _apply

    struct ggml_tensor * inp_logit_bias; // backend input: bias values, created in backend_apply
    struct ggml_tensor * inp_logit_idxs; // backend input: token indices, created in backend_apply
};
3438 | | |
3439 | 0 | static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { |
3440 | 0 | auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; |
3441 | 0 | return ctx->get_name(); |
3442 | 0 | } |
3443 | | |
3444 | 0 | static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
3445 | 0 | auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; |
3446 | |
|
3447 | 0 | if (ctx->logit_bias.empty()) { |
3448 | 0 | return; |
3449 | 0 | } |
3450 | | |
3451 | 0 | ctx->to_search.clear(); |
3452 | | |
3453 | | // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) |
3454 | 0 | for (const auto & lb : ctx->logit_bias) { |
3455 | 0 | if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { |
3456 | 0 | cur_p->data[lb.token].logit += lb.bias; |
3457 | 0 | } else { |
3458 | 0 | ctx->to_search.push_back(lb); |
3459 | 0 | } |
3460 | 0 | } |
3461 | |
|
3462 | 0 | if (ctx->to_search.empty()) { |
3463 | 0 | return; |
3464 | 0 | } |
3465 | | |
3466 | | // search for the remaining candidates that were not found in the previous step |
3467 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
3468 | 0 | for (const auto & lb : ctx->to_search) { |
3469 | 0 | if (cur_p->data[i].id == lb.token) { |
3470 | 0 | cur_p->data[i].logit += lb.bias; |
3471 | 0 | break; |
3472 | 0 | } |
3473 | 0 | } |
3474 | 0 | } |
3475 | 0 | } |
3476 | | |
3477 | 0 | static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { |
3478 | 0 | const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; |
3479 | 0 | return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); |
3480 | 0 | } |
3481 | | |
3482 | 0 | static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { |
3483 | 0 | delete (llama_sampler_logit_bias *) smpl->ctx; |
3484 | 0 | } |
3485 | | |
3486 | | static void llama_sampler_logit_bias_backend_apply( |
3487 | | struct llama_sampler * smpl, |
3488 | | struct ggml_context * ctx, |
3489 | | struct ggml_cgraph * gf, |
3490 | 0 | struct llama_sampler_data * data) { |
3491 | 0 | GGML_UNUSED(gf); |
3492 | 0 | GGML_UNUSED(ctx); |
3493 | |
|
3494 | 0 | auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; |
3495 | 0 | if (sctx->logit_bias.empty()) { |
3496 | 0 | return; |
3497 | 0 | } |
3498 | | |
3499 | 0 | const size_t n = sctx->logit_bias.size(); |
3500 | |
|
3501 | 0 | sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); |
3502 | 0 | ggml_set_name(sctx->inp_logit_bias, "logit_bias"); |
3503 | 0 | ggml_set_input(sctx->inp_logit_bias); |
3504 | |
|
3505 | 0 | sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); |
3506 | 0 | ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); |
3507 | 0 | ggml_set_input(sctx->inp_logit_idxs); |
3508 | |
|
3509 | 0 | ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); |
3510 | |
|
3511 | 0 | cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); |
3512 | 0 | cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs); |
3513 | 0 | cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur)); |
3514 | |
|
3515 | 0 | data->logits = ggml_add(ctx, data->logits, cur); |
3516 | 0 | } |
3517 | | |
3518 | 0 | static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { |
3519 | 0 | auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; |
3520 | 0 | if (sctx->logit_bias.empty()) { |
3521 | 0 | return; |
3522 | 0 | } |
3523 | | |
3524 | 0 | GGML_ASSERT(sctx->inp_logit_bias != nullptr); |
3525 | 0 | GGML_ASSERT(sctx->inp_logit_idxs != nullptr); |
3526 | |
|
3527 | 0 | const size_t n = sctx->logit_bias.size(); |
3528 | |
|
3529 | 0 | std::vector<float> data_logit_bias(n, 0.0f); |
3530 | 0 | std::vector<int32_t> data_logit_idxs(n, 0); |
3531 | 0 | for (size_t i = 0; i < n; ++i) { |
3532 | 0 | const auto & lb = sctx->logit_bias[i]; |
3533 | 0 | GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); |
3534 | 0 | data_logit_bias[i] = lb.bias; |
3535 | 0 | data_logit_idxs[i] = lb.token; |
3536 | 0 | } |
3537 | |
|
3538 | 0 | ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); |
3539 | 0 | ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs)); |
3540 | 0 | } |
3541 | | |
3542 | | static bool llama_sampler_logit_bias_backend_init( |
3543 | | struct llama_sampler * smpl, |
3544 | 0 | ggml_backend_buffer_type_t buft) { |
3545 | 0 | GGML_UNUSED(buft); |
3546 | |
|
3547 | 0 | auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; |
3548 | |
|
3549 | 0 | sctx->init(true); |
3550 | |
|
3551 | 0 | if (sctx->logit_bias.empty()) { |
3552 | 0 | return true; |
3553 | 0 | } |
3554 | | |
3555 | 0 | return true; |
3556 | 0 | } |
3557 | | |
// vtable for the logit-bias sampler; supports both CPU apply and backend (graph) apply
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name   = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_logit_bias_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_logit_bias_clone,
    /* .free   = */ llama_sampler_logit_bias_free,
    /* .backend_init      = */ llama_sampler_logit_bias_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_logit_bias_backend_apply,
    /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
};
3570 | | |
// Creates a sampler that adds fixed biases to the logits of specific tokens.
// n_logit_bias <= 0 yields a pass-through sampler.
struct llama_sampler * llama_sampler_init_logit_bias(
        int32_t n_vocab,
        int32_t n_logit_bias,
        const llama_logit_bias * logit_bias) {
    const bool is_empty = n_logit_bias <= 0;

    // no biases -> pass-through sampler, no state allocated
    if (is_empty) {
        return llama_sampler_init_empty("?logit-bias");
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_logit_bias_i,
        /* .ctx = */ new llama_sampler_logit_bias {
            ("logit-bias"), // llama_sampler_backend base: sampler name
            /* .n_vocab = */ n_vocab,
            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
            /* .to_search = */ {},
            /* .inp_logit_bias = */ nullptr,
            /* .inp_logit_idxs = */ nullptr,
        }
    );
}
3593 | | |
3594 | | // infill |
3595 | | |
3596 | | //#define GGML_DEBUG_SAMPLER_INFILL |
3597 | | |
struct llama_sampler_infill {
    const struct llama_vocab * vocab; // used for EOG checks and token -> piece conversion

    // scratch buffers for comparing token text pieces (grown on demand in _apply)
    std::vector<char> buf0;
    std::vector<char> buf1;
};
3604 | | |
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    // the name is a compile-time constant, independent of the instance
    static constexpr const char * k_name = "infill";
    return k_name;
}
3608 | | |
3609 | 0 | static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { |
3610 | 0 | auto * ctx = (llama_sampler_infill *) smpl->ctx; |
3611 | |
|
3612 | 0 | llama_sampler_softmax_impl(cur_p, true); |
3613 | |
|
3614 | | #if defined(GGML_DEBUG_SAMPLER_INFILL) |
3615 | | #define LOG_DBG_CUR LLAMA_LOG_DEBUG |
3616 | | #else |
3617 | 0 | #define LOG_DBG_CUR(...) |
3618 | 0 | #endif |
3619 | |
|
3620 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
3621 | 0 | LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); |
3622 | 0 | } |
3623 | |
|
3624 | 0 | float p_txt_sum = 0.0f; |
3625 | 0 | float p_eog_sum = 0.0f; |
3626 | |
|
3627 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
3628 | 0 | if (ctx->vocab->is_eog(cur_p->data[i].id)) { |
3629 | 0 | p_eog_sum += cur_p->data[i].p; |
3630 | 0 | } else { |
3631 | 0 | p_txt_sum += cur_p->data[i].p; |
3632 | 0 | } |
3633 | 0 | } |
3634 | |
|
3635 | 0 | const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat); |
3636 | |
|
3637 | 0 | LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size); |
3638 | |
|
3639 | 0 | if (3*p_eog_sum*cur_p->size > p_txt_sum) { |
3640 | 0 | LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum); |
3641 | | |
3642 | | // keep just the EOG tokens |
3643 | 0 | const auto size_org = cur_p->size; |
3644 | |
|
3645 | 0 | cur_p->size = 0; |
3646 | |
|
3647 | 0 | float p_sum = 0.0f; |
3648 | |
|
3649 | 0 | for (size_t i = 0; i < size_org; ++i) { |
3650 | 0 | if (ctx->vocab->is_eog(cur_p->data[i].id)) { |
3651 | 0 | p_sum += cur_p->data[i].p; |
3652 | |
|
3653 | 0 | cur_p->data[cur_p->size++] = cur_p->data[i]; |
3654 | 0 | } |
3655 | 0 | } |
3656 | | |
3657 | | // normalize probs |
3658 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
3659 | 0 | cur_p->data[i].p /= p_sum; |
3660 | 0 | } |
3661 | |
|
3662 | 0 | return; |
3663 | 0 | } |
3664 | | |
3665 | 0 | size_t n_combined = 0; GGML_UNUSED(n_combined); |
3666 | | |
3667 | | // combine tokens with common prefix |
3668 | 0 | for (size_t i0 = 0; i0 < cur_p->size; ++i0) { |
3669 | 0 | for (size_t i1 = 0; i1 < cur_p->size; ++i1) { |
3670 | 0 | if (cur_p->data[i0].logit == -INFINITY) { |
3671 | 0 | break; |
3672 | 0 | } |
3673 | | |
3674 | 0 | if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) { |
3675 | 0 | continue; |
3676 | 0 | } |
3677 | | |
3678 | 0 | int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); |
3679 | 0 | if (len0 < 0) { |
3680 | 0 | ctx->buf0.resize(len0); |
3681 | 0 | len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); |
3682 | 0 | assert(len0 > 0); |
3683 | 0 | } |
3684 | |
|
3685 | 0 | int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); |
3686 | 0 | if (len1 < 0) { |
3687 | 0 | ctx->buf1.resize(len1); |
3688 | 0 | len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); |
3689 | 0 | assert(len1 > 0); |
3690 | 0 | } |
3691 | | |
3692 | | // token i0 is a prefix of token i1 |
3693 | 0 | if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) { |
3694 | 0 | int dst = i0; |
3695 | 0 | int src = i1; |
3696 | | |
3697 | | // merge into the token with higher probability |
3698 | 0 | if (cur_p->data[i1].p > cur_p->data[i0].p) { |
3699 | 0 | std::swap(dst, src); |
3700 | 0 | } |
3701 | |
|
3702 | 0 | cur_p->data[dst].p += cur_p->data[src].p; |
3703 | 0 | cur_p->data[src].logit = -INFINITY; |
3704 | 0 | cur_p->data[src].p = 0.0f; |
3705 | |
|
3706 | 0 | n_combined++; |
3707 | 0 | } |
3708 | 0 | } |
3709 | 0 | } |
3710 | |
|
3711 | 0 | size_t n_non_eog = 0; |
3712 | |
|
3713 | 0 | size_t size_org = cur_p->size; |
3714 | |
|
3715 | 0 | float p_sum = 0.0f; |
3716 | 0 | float thold = 0.2f; |
3717 | |
|
3718 | 0 | cur_p->size = 0; |
3719 | |
|
3720 | 0 | LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); |
3721 | |
|
3722 | 0 | for (size_t i = 0; i < size_org; ++i) { |
3723 | 0 | const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); |
3724 | |
|
3725 | 0 | if (cur_p->data[i].p < thold && !is_eog) { |
3726 | 0 | continue; |
3727 | 0 | } |
3728 | | |
3729 | 0 | if (!is_eog) { |
3730 | 0 | ++n_non_eog; |
3731 | 0 | } |
3732 | |
|
3733 | 0 | p_sum += cur_p->data[i].p; |
3734 | | |
3735 | | // keep this token |
3736 | 0 | cur_p->data[cur_p->size++] = cur_p->data[i]; |
3737 | 0 | } |
3738 | |
|
3739 | 0 | LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog); |
3740 | | |
3741 | | // if no non-EOG tokens are left -> reduce cur_p to single EOT token |
3742 | 0 | if (n_non_eog == 0) { |
3743 | 0 | cur_p->size = 1; |
3744 | 0 | cur_p->data[0].id = ctx->vocab->token_eot(); |
3745 | 0 | if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { |
3746 | 0 | cur_p->data[0].id = ctx->vocab->token_eos(); |
3747 | 0 | } |
3748 | 0 | cur_p->data[0].logit = 1.0f; |
3749 | |
|
3750 | 0 | GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); |
3751 | |
|
3752 | 0 | return; |
3753 | 0 | } |
3754 | | |
3755 | | // normalize probs |
3756 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
3757 | 0 | cur_p->data[i].p /= p_sum; |
3758 | |
|
3759 | 0 | LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); |
3760 | 0 | } |
3761 | |
|
3762 | 0 | size_org = cur_p->size; |
3763 | 0 | p_sum = 0.0f; |
3764 | 0 | thold = 1.0/(n_non_eog + 1); |
3765 | |
|
3766 | 0 | cur_p->size = 0; |
3767 | |
|
3768 | 0 | LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); |
3769 | |
|
3770 | 0 | for (size_t i = 0; i < size_org; ++i) { |
3771 | 0 | const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); |
3772 | |
|
3773 | 0 | if (cur_p->data[i].p < thold && !is_eog) { |
3774 | 0 | continue; |
3775 | 0 | } |
3776 | | |
3777 | 0 | p_sum += cur_p->data[i].p; |
3778 | |
|
3779 | 0 | cur_p->data[cur_p->size++] = cur_p->data[i]; |
3780 | 0 | } |
3781 | | |
3782 | | // normalize probs |
3783 | 0 | for (size_t i = 0; i < cur_p->size; ++i) { |
3784 | 0 | cur_p->data[i].p /= p_sum; |
3785 | |
|
3786 | 0 | LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); |
3787 | 0 | } |
3788 | |
|
3789 | 0 | #undef LOG_DBG_CUR |
3790 | 0 | } |
3791 | | |
3792 | 0 | static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { |
3793 | 0 | const auto * ctx = (const llama_sampler_infill *) smpl->ctx; |
3794 | 0 | return llama_sampler_init_infill(ctx->vocab); |
3795 | 0 | } |
3796 | | |
3797 | 0 | static void llama_sampler_infill_free(struct llama_sampler * smpl) { |
3798 | 0 | delete (llama_sampler_infill *) smpl->ctx; |
3799 | 0 | } |
3800 | | |
// vtable for the infill sampler (CPU-only: backend hooks are unimplemented).
// NOTE(review): the backend field labels below previously read .backend_apply,
// .backend_accept, .backend_set_input, .backend_init — a different order than
// every other llama_sampler_i initializer in this file (init, accept, apply,
// set_input). all four values are nullptr, so only the comments were wrong;
// they are corrected here to match the actual field order.
static struct llama_sampler_i llama_sampler_infill_i = {
    /* .name   = */ llama_sampler_infill_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_infill_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_infill_clone,
    /* .free   = */ llama_sampler_infill_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
3813 | | |
3814 | 0 | struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { |
3815 | 0 | return llama_sampler_init( |
3816 | 0 | /* .iface = */ &llama_sampler_infill_i, |
3817 | 0 | /* .ctx = */ new llama_sampler_infill { |
3818 | 0 | /* .vocab = */ vocab, |
3819 | 0 | /* .buf0 = */ std::vector<char>(512), |
3820 | 0 | /* .buf1 = */ std::vector<char>(512), |
3821 | 0 | } |
3822 | 0 | ); |
3823 | 0 | } |
3824 | | |
3825 | | // utils |
3826 | | |
3827 | 0 | uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { |
3828 | 0 | if (smpl->iface == &llama_sampler_dist_i) { |
3829 | 0 | return ((const llama_sampler_dist *) smpl->ctx)->seed_cur; |
3830 | 0 | } |
3831 | | |
3832 | 0 | if (smpl->iface == &llama_sampler_mirostat_i) { |
3833 | 0 | return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur; |
3834 | 0 | } |
3835 | | |
3836 | 0 | if (smpl->iface == &llama_sampler_mirostat_v2_i) { |
3837 | 0 | return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur; |
3838 | 0 | } |
3839 | | |
3840 | 0 | if (smpl->iface == &llama_sampler_chain_i) { |
3841 | 0 | const auto * ctx = (const llama_sampler_chain *) smpl->ctx; |
3842 | 0 | for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { |
3843 | 0 | const uint32_t seed = llama_sampler_get_seed(it->ptr); |
3844 | 0 | if (seed != LLAMA_DEFAULT_SEED) { |
3845 | 0 | return seed; |
3846 | 0 | } |
3847 | 0 | } |
3848 | 0 | } |
3849 | | |
3850 | 0 | return LLAMA_DEFAULT_SEED; |
3851 | 0 | } |
3852 | | |
3853 | | // perf |
3854 | | |
3855 | 0 | struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) { |
3856 | 0 | struct llama_perf_sampler_data data = {}; |
3857 | |
|
3858 | 0 | if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { |
3859 | 0 | GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); |
3860 | 0 | } |
3861 | |
|
3862 | 0 | const auto * ctx = (const struct llama_sampler_chain *) chain->ctx; |
3863 | |
|
3864 | 0 | data.t_sample_ms = 1e-3 * ctx->t_sample_us; |
3865 | 0 | data.n_sample = std::max(0, ctx->n_sample); |
3866 | |
|
3867 | 0 | return data; |
3868 | 0 | } |
3869 | | |
3870 | 0 | void llama_perf_sampler_print(const struct llama_sampler * chain) { |
3871 | 0 | const auto data = llama_perf_sampler(chain); |
3872 | |
|
3873 | 0 | LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample); |
3874 | 0 | } |
3875 | | |
3876 | 0 | void llama_perf_sampler_reset(struct llama_sampler * chain) { |
3877 | 0 | if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { |
3878 | 0 | GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); |
3879 | 0 | } |
3880 | |
|
3881 | 0 | auto * ctx = (struct llama_sampler_chain *) chain->ctx; |
3882 | |
|
3883 | 0 | ctx->t_sample_us = 0; |
3884 | 0 | ctx->n_sample = 0; |
3885 | 0 | } |