/src/llama.cpp/src/llama-context.cpp
Line | Count | Source |
1 | | #include "llama-context.h" |
2 | | |
3 | | #include "llama-arch.h" |
4 | | #include "llama-impl.h" |
5 | | #include "llama-batch.h" |
6 | | #include "llama-io.h" |
7 | | #include "llama-memory.h" |
8 | | #include "llama-mmap.h" |
9 | | #include "llama-model.h" |
10 | | #include "llama-ext.h" |
11 | | |
12 | | #include <cinttypes> |
13 | | #include <cmath> |
14 | | #include <cstring> |
15 | | #include <limits> |
16 | | #include <stdexcept> |
17 | | |
18 | | // |
19 | | // llama_context |
20 | | // |
21 | | |
22 | | llama_context::llama_context( |
23 | | const llama_model & model, |
24 | | llama_context_params params) : |
25 | 0 | model(model), |
26 | 0 | cvec(std::make_unique<llama_adapter_cvec>()), |
27 | 0 | loras(std::make_unique<llama_adapter_loras>()), |
28 | 0 | balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) { |
29 | | // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, |
30 | | // may need to be backend-dependent |
31 | 0 | LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); |
32 | |
|
33 | 0 | t_start_us = model.t_start_us; |
34 | 0 | t_load_us = model.t_load_us; |
35 | |
|
36 | 0 | const auto & hparams = model.hparams; |
37 | |
|
38 | 0 | cparams.n_seq_max = std::max(1u, params.n_seq_max); |
39 | 0 | if (cparams.n_seq_max > LLAMA_MAX_SEQ) { |
40 | 0 | throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); |
41 | 0 | } |
42 | | |
43 | 0 | cparams.n_threads = params.n_threads; |
44 | 0 | cparams.n_threads_batch = params.n_threads_batch; |
45 | 0 | cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; |
46 | 0 | cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; |
47 | 0 | cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; |
48 | 0 | cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; |
49 | 0 | cparams.embeddings = params.embeddings; |
50 | 0 | cparams.offload_kqv = params.offload_kqv; |
51 | 0 | cparams.no_perf = params.no_perf; |
52 | 0 | cparams.pooling_type = params.pooling_type; |
53 | 0 | cparams.warmup = false; |
54 | |
|
55 | 0 | cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; |
56 | 0 | cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; |
57 | 0 | cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; |
58 | |
|
59 | 0 | cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : |
60 | 0 | hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : |
61 | 0 | hparams.n_ctx_train; |
62 | |
|
63 | 0 | cparams.cb_eval = params.cb_eval; |
64 | 0 | cparams.cb_eval_user_data = params.cb_eval_user_data; |
65 | | |
66 | | // Initialize backend samplers here so they are part of the sampling graph |
67 | | // before the reserve passes run later in this function. This avoids a later |
68 | | // re-reserve when graph nodes change. |
69 | 0 | if (params.samplers != nullptr && params.n_samplers > 0) { |
70 | 0 | for (size_t i = 0; i < params.n_samplers; ++i) { |
71 | 0 | const auto & config = params.samplers[i]; |
72 | |
|
73 | 0 | if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { |
74 | 0 | throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); |
75 | 0 | } |
76 | | |
77 | 0 | if (set_sampler(config.seq_id, config.sampler)) { |
78 | 0 | const int n_samplers = llama_sampler_chain_n(config.sampler); |
79 | |
|
80 | 0 | LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); |
81 | 0 | } |
82 | 0 | } |
83 | 0 | } |
84 | | |
85 | 0 | auto rope_scaling_type = params.rope_scaling_type; |
86 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { |
87 | 0 | rope_scaling_type = hparams.rope_scaling_type_train; |
88 | 0 | } |
89 | |
|
90 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { |
91 | 0 | cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none |
92 | 0 | } |
93 | |
|
94 | 0 | if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' |
95 | 0 | cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; |
96 | 0 | } |
97 | |
|
98 | 0 | if (cparams.yarn_ext_factor != 0) { |
99 | 0 | static auto get_mscale = [](float scale, float mscale) { |
100 | 0 | return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); |
101 | 0 | }; |
102 | |
|
103 | 0 | const float factor = 1.0f / cparams.rope_freq_scale; |
104 | | |
105 | | // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348 |
106 | 0 | if (hparams.rope_yarn_log_mul != 0.0f) { |
107 | | // note: here we assume `mscale == 1.0f` |
108 | | // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f |
109 | 0 | float mscale = 1.0f; |
110 | 0 | const float mscale_all_dims = hparams.rope_yarn_log_mul; |
111 | | |
112 | | // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] |
113 | | // special-case DEEPSEEK v2: |
114 | | // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43 |
115 | 0 | if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) { |
116 | 0 | mscale = mscale_all_dims; |
117 | 0 | } |
118 | |
|
119 | 0 | cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); |
120 | |
|
121 | 0 | LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n", |
122 | 0 | __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims); |
123 | 0 | } else { |
124 | 0 | cparams.yarn_attn_factor = get_mscale(factor, 1.0f); |
125 | 0 | } |
126 | | |
127 | | // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor: |
128 | | // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544 |
129 | | // |
130 | | // ref: https://github.com/ggml-org/llama.cpp/discussions/7416 |
131 | | // https://github.com/ggml-org/llama.cpp/pull/17945 |
132 | 0 | cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor)); |
133 | 0 | } |
134 | |
|
135 | 0 | cparams.yarn_attn_factor *= hparams.rope_attn_factor; |
136 | |
|
137 | 0 | if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
138 | 0 | if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
139 | 0 | cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; |
140 | 0 | } else { |
141 | 0 | cparams.pooling_type = hparams.pooling_type; |
142 | 0 | } |
143 | 0 | } |
144 | |
|
145 | 0 | if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { |
146 | 0 | cparams.causal_attn = hparams.causal_attn; |
147 | 0 | } else { |
148 | 0 | cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; |
149 | 0 | } |
150 | |
|
151 | 0 | cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; |
152 | 0 | cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; |
153 | |
|
154 | 0 | cparams.fused_gdn_ar = true; |
155 | 0 | cparams.fused_gdn_ch = true; |
156 | 0 | cparams.auto_fgdn = true; |
157 | | |
158 | | // with causal attention, the batch size is limited by the context size |
159 | 0 | cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; |
160 | |
|
161 | 0 | cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); |
162 | |
|
163 | 0 | cparams.op_offload = params.op_offload; |
164 | 0 | cparams.kv_unified = params.kv_unified; |
165 | | |
166 | | // initialized later |
167 | 0 | cparams.pipeline_parallel = false; |
168 | |
|
169 | 0 | { |
170 | 0 | const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); |
171 | 0 | graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; |
172 | |
|
173 | 0 | if (graph_reuse_disable) { |
174 | 0 | LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); |
175 | 0 | } |
176 | 0 | } |
177 | | |
178 | | // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 |
179 | 0 | cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); |
180 | |
|
181 | 0 | if (cparams.kv_unified) { |
182 | 0 | cparams.n_ctx_seq = cparams.n_ctx; |
183 | 0 | } else { |
184 | 0 | cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; |
185 | 0 | cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); |
186 | |
|
187 | 0 | if (cparams.n_ctx_seq == 0) { |
188 | 0 | throw std::runtime_error("n_ctx_seq == 0"); |
189 | 0 | } |
190 | | |
191 | 0 | if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) { |
192 | 0 | cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max; |
193 | 0 | LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx); |
194 | 0 | } |
195 | 0 | } |
196 | | |
197 | 0 | LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); |
198 | 0 | LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); |
199 | 0 | LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq); |
200 | 0 | LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); |
201 | 0 | LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); |
202 | 0 | LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); |
203 | 0 | LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); |
204 | 0 | LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); |
205 | 0 | LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); |
206 | 0 | LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |
207 | |
|
208 | 0 | if (cparams.n_ctx_seq < hparams.n_ctx_train) { |
209 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", |
210 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
211 | 0 | } |
212 | |
|
213 | 0 | if (cparams.n_ctx_seq > hparams.n_ctx_train) { |
214 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", |
215 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
216 | 0 | } |
217 | |
|
218 | 0 | if (!hparams.vocab_only) { |
219 | | // GPU backends |
220 | 0 | for (auto * dev : model.devices) { |
221 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
222 | 0 | if (backend == nullptr) { |
223 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
224 | 0 | } |
225 | 0 | backends.emplace_back(backend); |
226 | 0 | } |
227 | | |
228 | | // add ACCEL backends (such as BLAS) |
229 | 0 | for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
230 | 0 | ggml_backend_dev_t dev = ggml_backend_dev_get(i); |
231 | 0 | if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { |
232 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
233 | 0 | if (backend == nullptr) { |
234 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
235 | 0 | } |
236 | 0 | backends.emplace_back(backend); |
237 | 0 | } |
238 | 0 | } |
239 | | |
240 | | // add CPU backend |
241 | 0 | backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); |
242 | 0 | if (backend_cpu == nullptr) { |
243 | 0 | throw std::runtime_error("failed to initialize CPU backend"); |
244 | 0 | } |
245 | 0 | backends.emplace_back(backend_cpu); |
246 | | |
247 | | // create a list of the set_n_threads functions in the backends |
248 | 0 | for (auto & backend : backends) { |
249 | 0 | ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); |
250 | 0 | ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; |
251 | 0 | if (reg) { |
252 | 0 | auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); |
253 | 0 | if (ggml_backend_set_n_threads_fn) { |
254 | 0 | set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); |
255 | 0 | } |
256 | 0 | } |
257 | 0 | } |
258 | |
|
259 | 0 | llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); |
260 | | |
261 | | // graph outputs buffer |
262 | 0 | { |
263 | 0 | if (output_reserve(params.n_seq_max) < params.n_seq_max) { |
264 | 0 | throw std::runtime_error("failed to reserve initial output buffer"); |
265 | 0 | } |
266 | | |
267 | 0 | LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, |
268 | 0 | ggml_backend_buffer_name (buf_output.get()), |
269 | 0 | ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); |
270 | 0 | } |
271 | 0 | } |
272 | | |
273 | | // init the memory module |
274 | 0 | if (!hparams.vocab_only) { |
275 | 0 | llama_memory_params params_mem = { |
276 | 0 | /*.type_k =*/ params.type_k, |
277 | 0 | /*.type_v =*/ params.type_v, |
278 | 0 | /*.swa_full =*/ params.swa_full, |
279 | 0 | }; |
280 | |
|
281 | 0 | memory.reset(model.create_memory(params_mem, cparams)); |
282 | 0 | } |
283 | | |
284 | | // init backends |
285 | 0 | if (!hparams.vocab_only) { |
286 | 0 | LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); |
287 | |
|
288 | 0 | backend_buft.clear(); |
289 | 0 | backend_ptrs.clear(); |
290 | 0 | backend_buf_exp_size.clear(); |
291 | |
|
292 | 0 | for (auto & backend : backends) { |
293 | 0 | auto * buft = ggml_backend_get_default_buffer_type(backend.get()); |
294 | 0 | auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
295 | |
|
296 | 0 | if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { |
297 | | // use the host buffer of the first device CPU for faster transfer of the intermediate state |
298 | 0 | auto * dev = model.devices[0]; |
299 | 0 | auto * host_buft = ggml_backend_dev_host_buffer_type(dev); |
300 | 0 | if (host_buft) { |
301 | 0 | buft = host_buft; |
302 | 0 | } |
303 | 0 | } |
304 | |
|
305 | 0 | backend_buft.push_back(buft); |
306 | 0 | backend_ptrs.push_back(backend.get()); |
307 | 0 | backend_buf_exp_size.push_back(0); |
308 | 0 | } |
309 | |
|
310 | 0 | LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); |
311 | | |
312 | | // TODO: move these checks to ggml_backend_sched |
313 | | // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary |
314 | 0 | bool pipeline_parallel = |
315 | 0 | model.n_devices() > 1 && |
316 | 0 | model.n_gpu_layers() > model.hparams.n_layer && |
317 | 0 | model.split_mode() == LLAMA_SPLIT_MODE_LAYER && |
318 | 0 | cparams.offload_kqv && |
319 | 0 | !model.has_tensor_overrides(); |
320 | | |
321 | | // pipeline parallelism requires support for async compute and events in all devices |
322 | 0 | if (pipeline_parallel) { |
323 | 0 | for (auto & backend : backends) { |
324 | 0 | auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
325 | 0 | if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { |
326 | | // ignore CPU backend |
327 | | // TODO: should we ignore ACCEL types too? |
328 | 0 | continue; |
329 | 0 | } |
330 | 0 | auto * dev = ggml_backend_get_device(backend.get()); |
331 | 0 | ggml_backend_dev_props props; |
332 | 0 | ggml_backend_dev_get_props(dev, &props); |
333 | 0 | if (!props.caps.async || !props.caps.events) { |
334 | | // device does not support async compute or events |
335 | 0 | pipeline_parallel = false; |
336 | 0 | break; |
337 | 0 | } |
338 | 0 | } |
339 | 0 | } |
340 | |
|
341 | 0 | cparams.pipeline_parallel = pipeline_parallel; |
342 | |
|
343 | 0 | if (cparams.pipeline_parallel) { |
344 | 0 | LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); |
345 | |
|
346 | 0 | if (!graph_reuse_disable) { |
347 | | // TODO: figure out a way to make graph reuse work with pipeline parallelism |
348 | | // ref: https://github.com/ggml-org/llama.cpp/pull/20463 |
349 | 0 | LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__); |
350 | |
|
351 | 0 | graph_reuse_disable = true; |
352 | 0 | } |
353 | 0 | } |
354 | |
|
355 | 0 | sched_reserve(); |
356 | |
|
357 | 0 | if (!cparams.flash_attn) { |
358 | 0 | if (ggml_is_quantized(params.type_v)) { |
359 | 0 | throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); |
360 | 0 | } |
361 | 0 | } |
362 | 0 | } |
363 | | |
364 | | // Initialize the full vocabulary token ids for backend samplers. |
365 | 0 | { |
366 | 0 | const int n_vocab = model.vocab.n_tokens(); |
367 | |
|
368 | 0 | sampling.token_ids_full_vocab.resize(n_vocab); |
369 | 0 | for (int i = 0; i < n_vocab; ++i) { |
370 | 0 | sampling.token_ids_full_vocab[i] = i; |
371 | 0 | } |
372 | 0 | } |
373 | 0 | } |
374 | | |
375 | 0 | llama_context::~llama_context() { |
376 | 0 | if (!model.hparams.no_alloc) { |
377 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
378 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
379 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
380 | |
|
381 | 0 | const size_t size_exp = backend_buf_exp_size[i]; |
382 | 0 | const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
383 | 0 | if (size_exp == size_act) { |
384 | 0 | LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", |
385 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
386 | 0 | } else { |
387 | 0 | LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", |
388 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
389 | 0 | } |
390 | 0 | } |
391 | 0 | } |
392 | 0 | ggml_opt_free(opt_ctx); |
393 | 0 | } |
394 | | |
395 | 0 | void llama_context::sched_reserve() { |
396 | 0 | if (!sched_need_reserve) { |
397 | 0 | return; |
398 | 0 | } |
399 | | |
400 | 0 | sched_need_reserve = false; |
401 | |
|
402 | 0 | LLAMA_LOG_INFO("%s: reserving ...\n", __func__); |
403 | |
|
404 | 0 | synchronize(); |
405 | |
|
406 | 0 | const int64_t t_start_us = ggml_time_us(); |
407 | |
|
408 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
409 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
410 | |
|
411 | 0 | const size_t max_nodes = this->graph_max_nodes(n_tokens); |
412 | |
|
413 | 0 | LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); |
414 | |
|
415 | 0 | gf_res_prev.reset(new llm_graph_result(max_nodes)); |
416 | 0 | gf_res_reserve.reset(new llm_graph_result(max_nodes)); |
417 | |
|
418 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); |
419 | |
|
420 | 0 | llama_memory_context_ptr mctx; |
421 | 0 | if (memory) { |
422 | 0 | LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); |
423 | 0 | mctx = memory->init_full(); |
424 | 0 | if (!mctx) { |
425 | 0 | throw std::runtime_error("failed to initialize memory module"); |
426 | 0 | } |
427 | 0 | } |
428 | | |
429 | | // avoid reserving graphs with zero outputs - assume one output per sequence |
430 | 0 | const int n_outputs = n_seqs; |
431 | |
|
432 | 0 | LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); |
433 | | |
434 | | // resolve automatic Flash Attention use |
435 | 0 | if (cparams.auto_fa) { |
436 | 0 | auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); |
437 | 0 | if (!gf) { |
438 | 0 | throw std::runtime_error("failed to reserve graph for Flash Attention check"); |
439 | 0 | } |
440 | | |
441 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; |
442 | 0 | bool fa_device_mismatch = false; |
443 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
444 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
445 | 0 | if (n->op != GGML_OP_FLASH_ATTN_EXT) { |
446 | 0 | continue; |
447 | 0 | } |
448 | 0 | ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
449 | | |
450 | | // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer |
451 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); |
452 | 0 | const int il = std::stoi(n->name + prefix_len); |
453 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
454 | 0 | if (device_fa != device_kv) { |
455 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " |
456 | 0 | "is assigned to device %s (usually due to missing support)\n", |
457 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); |
458 | | // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways |
459 | 0 | fa_device_mismatch = true; |
460 | 0 | break; |
461 | 0 | } |
462 | 0 | } |
463 | |
|
464 | 0 | if (fa_device_mismatch) { |
465 | 0 | cparams.flash_attn = false; |
466 | 0 | LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); |
467 | 0 | } else { |
468 | 0 | cparams.flash_attn = true; |
469 | 0 | LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); |
470 | 0 | } |
471 | |
|
472 | 0 | cparams.auto_fa = false; |
473 | 0 | } |
474 | | |
475 | 0 | if (cparams.auto_fgdn) { |
476 | 0 | LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__); |
477 | |
|
478 | 0 | if (cparams.fused_gdn_ar) { |
479 | 0 | auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); |
480 | 0 | if (!gf) { |
481 | 0 | throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)"); |
482 | 0 | } |
483 | | |
484 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1; |
485 | 0 | bool gdn_device_mismatch = false; |
486 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
487 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
488 | 0 | if (n->op != GGML_OP_GATED_DELTA_NET) { |
489 | 0 | continue; |
490 | 0 | } |
491 | 0 | ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
492 | |
|
493 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0); |
494 | 0 | const int il = std::stoi(n->name + prefix_len); |
495 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
496 | 0 | if (device_gdn != device_kv) { |
497 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " |
498 | 0 | "is assigned to device %s (usually due to missing support)\n", |
499 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); |
500 | 0 | gdn_device_mismatch = true; |
501 | 0 | break; |
502 | 0 | } |
503 | 0 | } |
504 | |
|
505 | 0 | if (gdn_device_mismatch) { |
506 | 0 | cparams.fused_gdn_ar = false; |
507 | 0 | LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__); |
508 | 0 | } else { |
509 | 0 | LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__); |
510 | 0 | } |
511 | 0 | } |
512 | | |
513 | 0 | if (cparams.fused_gdn_ch) { |
514 | | // more than one token in the batch per sequence in order to take the chunked path |
515 | | // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, |
516 | | // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies |
517 | | // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, |
518 | | // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). |
519 | 0 | const uint32_t n_tokens_ch = 16*n_seqs; |
520 | 0 | auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); |
521 | 0 | if (!gf) { |
522 | 0 | throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)"); |
523 | 0 | } |
524 | | |
525 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1; |
526 | 0 | bool gdn_device_mismatch = false; |
527 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
528 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
529 | 0 | if (n->op != GGML_OP_GATED_DELTA_NET) { |
530 | 0 | continue; |
531 | 0 | } |
532 | 0 | ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
533 | |
|
534 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0); |
535 | 0 | const int il = std::stoi(n->name + prefix_len); |
536 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
537 | 0 | if (device_gdn != device_kv) { |
538 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " |
539 | 0 | "is assigned to device %s (usually due to missing support)\n", |
540 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); |
541 | 0 | gdn_device_mismatch = true; |
542 | 0 | break; |
543 | 0 | } |
544 | 0 | } |
545 | |
|
546 | 0 | if (gdn_device_mismatch) { |
547 | 0 | cparams.fused_gdn_ch = false; |
548 | 0 | LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__); |
549 | 0 | } else { |
550 | 0 | LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__); |
551 | 0 | } |
552 | 0 | } |
553 | | |
554 | 0 | cparams.auto_fgdn = false; |
555 | 0 | } |
556 | | |
557 | | // reserve worst-case graph |
558 | 0 | int n_splits_pp = -1; |
559 | 0 | int n_nodes_pp = -1; |
560 | |
|
561 | 0 | int n_splits_tg = -1; |
562 | 0 | int n_nodes_tg = -1; |
563 | | |
564 | | // reserve pp (prompt processing) graph first so that buffers are only allocated once |
565 | 0 | { |
566 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), |
567 | 0 | model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); |
568 | 0 | if (!gf) { |
569 | 0 | if (cparams.pipeline_parallel) { |
570 | 0 | LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); |
571 | 0 | cparams.pipeline_parallel = false; |
572 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); |
573 | 0 | gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
574 | 0 | } |
575 | 0 | if (!gf) { |
576 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
577 | 0 | } |
578 | 0 | } |
579 | | |
580 | 0 | n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); |
581 | 0 | n_nodes_pp = ggml_graph_n_nodes(gf); |
582 | 0 | } |
583 | | |
584 | | // reserve with tg (token generation) graph to get the number of splits and nodes |
585 | 0 | { |
586 | 0 | auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); |
587 | 0 | if (!gf) { |
588 | 0 | throw std::runtime_error("failed to allocate compute tg buffers"); |
589 | 0 | } |
590 | | |
591 | 0 | n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); |
592 | 0 | n_nodes_tg = ggml_graph_n_nodes(gf); |
593 | 0 | } |
594 | | |
595 | | // reserve again with pp graph to avoid ggml-alloc reallocations during inference |
596 | 0 | { |
597 | | // TODO: not sure if the following graph would be worster case for multi-stream KV caches: |
598 | | // |
599 | | // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); |
600 | | // |
601 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); |
602 | 0 | if (!gf) { |
603 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
604 | 0 | } |
605 | 0 | } |
606 | | |
607 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
608 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
609 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
610 | 0 | if (!model.hparams.no_alloc) { |
611 | 0 | backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
612 | 0 | } |
613 | 0 | if (backend_buf_exp_size[i] > 1) { |
614 | 0 | LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, |
615 | 0 | ggml_backend_buft_name(buft), |
616 | 0 | backend_buf_exp_size[i] / 1024.0 / 1024.0); |
617 | 0 | } |
618 | 0 | } |
619 | |
|
620 | 0 | if (n_nodes_pp == n_nodes_tg) { |
621 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); |
622 | 0 | } else { |
623 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); |
624 | 0 | } |
625 | |
|
626 | 0 | if (n_splits_pp == n_splits_tg) { |
627 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); |
628 | 0 | } else { |
629 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); |
630 | 0 | } |
631 | |
|
632 | 0 | const int64_t t_end_us = ggml_time_us(); |
633 | |
|
634 | 0 | LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", |
635 | 0 | __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); |
636 | 0 | } |
637 | | |
638 | 0 | void llama_context::synchronize() { |
639 | 0 | if (!sched) { |
640 | 0 | return; |
641 | 0 | } |
642 | | |
643 | 0 | ggml_backend_sched_synchronize(sched.get()); |
644 | | |
645 | | // FIXME: if multiple single tokens are evaluated without a synchronization, |
646 | | // the stats will be added to the prompt evaluation stats |
647 | | // this should only happen when using batch size 1 to evaluate a batch |
648 | | |
649 | | // add the evaluation to the stats |
650 | 0 | if (n_queued_tokens == 1) { |
651 | 0 | if (!cparams.no_perf) { |
652 | 0 | t_eval_us += ggml_time_us() - t_compute_start_us; |
653 | 0 | } |
654 | 0 | n_eval++; |
655 | 0 | } else if (n_queued_tokens > 1) { |
656 | 0 | if (!cparams.no_perf) { |
657 | 0 | t_p_eval_us += ggml_time_us() - t_compute_start_us; |
658 | 0 | } |
659 | 0 | n_p_eval += n_queued_tokens; |
660 | 0 | } |
661 | | |
662 | | // get a more accurate load time, upon first eval |
663 | 0 | if (n_queued_tokens > 0 && !has_evaluated_once) { |
664 | 0 | t_load_us = ggml_time_us() - t_start_us; |
665 | 0 | has_evaluated_once = true; |
666 | 0 | } |
667 | |
|
668 | 0 | n_queued_tokens = 0; |
669 | 0 | t_compute_start_us = 0; |
670 | 0 | } |
671 | | |
672 | 0 | const llama_model & llama_context::get_model() const { |
673 | 0 | return model; |
674 | 0 | } |
675 | | |
676 | 0 | const llama_cparams & llama_context::get_cparams() const { |
677 | 0 | return cparams; |
678 | 0 | } |
679 | | |
680 | 0 | ggml_backend_sched_t llama_context::get_sched() const { |
681 | 0 | return sched.get(); |
682 | 0 | } |
683 | | |
684 | 0 | uint32_t llama_context::n_ctx() const { |
685 | 0 | return cparams.n_ctx; |
686 | 0 | } |
687 | | |
688 | 0 | uint32_t llama_context::n_ctx_seq() const { |
689 | 0 | return cparams.n_ctx_seq; |
690 | 0 | } |
691 | | |
692 | 0 | uint32_t llama_context::n_batch() const { |
693 | 0 | return cparams.n_batch; |
694 | 0 | } |
695 | | |
696 | 0 | uint32_t llama_context::n_ubatch() const { |
697 | 0 | return cparams.n_ubatch; |
698 | 0 | } |
699 | | |
700 | 0 | uint32_t llama_context::n_seq_max() const { |
701 | 0 | return cparams.n_seq_max; |
702 | 0 | } |
703 | | |
704 | 0 | uint32_t llama_context::n_threads() const { |
705 | 0 | return cparams.n_threads; |
706 | 0 | } |
707 | | |
708 | 0 | uint32_t llama_context::n_threads_batch() const { |
709 | 0 | return cparams.n_threads_batch; |
710 | 0 | } |
711 | | |
712 | 0 | llama_memory_t llama_context::get_memory() const { |
713 | 0 | return memory.get(); |
714 | 0 | } |
715 | | |
716 | 0 | bool llama_context::memory_update(bool optimize) { |
717 | 0 | if (!memory) { |
718 | 0 | return false; |
719 | 0 | } |
720 | | |
721 | 0 | { |
722 | 0 | const auto mctx = memory->init_update(this, optimize); |
723 | 0 | switch (mctx->get_status()) { |
724 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
725 | 0 | { |
726 | | // noop |
727 | 0 | } break; |
728 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
729 | 0 | { |
730 | | // no updates need to be performed |
731 | 0 | return false; |
732 | 0 | } |
733 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
734 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
735 | 0 | { |
736 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); |
737 | 0 | return false; |
738 | 0 | } |
739 | 0 | } |
740 | | |
741 | | // reset the previous graph result to make sure that it won't be reused |
742 | | // TODO: change the mctx->apply() to return information if a graph reserve is needed |
743 | | // reset the graph result only if the memory module did reset the scheduler |
744 | 0 | gf_res_prev->reset(); |
745 | |
|
746 | 0 | if (!mctx->apply()) { |
747 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); |
748 | 0 | } |
749 | 0 | } |
750 | | |
751 | | // if the memory module did any computation, we have to reserve a new worst-case graph |
752 | 0 | { |
753 | 0 | const auto mctx = memory->init_full(); |
754 | 0 | if (!mctx) { |
755 | 0 | throw std::runtime_error("failed to initialize memory context"); |
756 | 0 | } |
757 | | |
758 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
759 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
760 | |
|
761 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
762 | 0 | if (!gf) { |
763 | 0 | LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); |
764 | 0 | } |
765 | 0 | } |
766 | | |
767 | 0 | return true; |
768 | 0 | } |
769 | | |
770 | 0 | enum llama_pooling_type llama_context::pooling_type() const { |
771 | 0 | return cparams.pooling_type; |
772 | 0 | } |
773 | | |
774 | 0 | float * llama_context::get_logits() { |
775 | 0 | output_reorder(); |
776 | |
|
777 | 0 | return logits.data; |
778 | 0 | } |
779 | | |
780 | 0 | int64_t llama_context::output_resolve_row(int32_t i) const { |
781 | 0 | int64_t j = -1; |
782 | | |
783 | | // support negative indices (last output row) |
784 | 0 | if (i < 0) { |
785 | 0 | j = n_outputs + i; |
786 | 0 | if (j < 0) { |
787 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
788 | 0 | } |
789 | 0 | } else if ((size_t) i >= output_ids.size()) { |
790 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
791 | 0 | } else { |
792 | | // use output_ids to translate the batch token index into a row number |
793 | | // that holds this token's data. |
794 | 0 | j = output_ids[i]; |
795 | 0 | } |
796 | | |
797 | 0 | if (j < 0) { |
798 | | // the batch token was not configured to output anything |
799 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
800 | 0 | } |
801 | | |
802 | 0 | if (j >= n_outputs) { |
803 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
804 | 0 | } |
805 | | |
806 | 0 | return j; |
807 | 0 | } |
808 | | |
809 | 0 | float * llama_context::get_logits_ith(int32_t i) { |
810 | 0 | output_reorder(); |
811 | |
|
812 | 0 | try { |
813 | 0 | if (logits.data == nullptr) { |
814 | 0 | throw std::runtime_error("no logits"); |
815 | 0 | } |
816 | | |
817 | 0 | const int64_t j = output_resolve_row(i); |
818 | 0 | return logits.data + j*model.vocab.n_tokens(); |
819 | 0 | } catch (const std::exception & err) { |
820 | 0 | LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); |
821 | | #ifndef NDEBUG |
822 | | GGML_ABORT("fatal error"); |
823 | | #else |
824 | 0 | return nullptr; |
825 | 0 | #endif |
826 | 0 | } |
827 | 0 | } |
828 | | |
829 | 0 | float * llama_context::get_embeddings() { |
830 | 0 | output_reorder(); |
831 | |
|
832 | 0 | return embd.data; |
833 | 0 | } |
834 | | |
835 | 0 | llama_token * llama_context::get_sampled_tokens() const{ |
836 | 0 | return sampling.sampled.data; |
837 | 0 | } |
838 | | |
839 | 0 | float * llama_context::get_embeddings_ith(int32_t i) { |
840 | 0 | output_reorder(); |
841 | |
|
842 | 0 | try { |
843 | 0 | if (embd.data == nullptr) { |
844 | 0 | throw std::runtime_error("no embeddings"); |
845 | 0 | } |
846 | | |
847 | 0 | const int64_t j = output_resolve_row(i); |
848 | 0 | const uint32_t n_embd_out = model.hparams.n_embd_out(); |
849 | 0 | return embd.data + j*n_embd_out; |
850 | 0 | } catch (const std::exception & err) { |
851 | 0 | LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); |
852 | | #ifndef NDEBUG |
853 | | GGML_ABORT("fatal error"); |
854 | | #else |
855 | 0 | return nullptr; |
856 | 0 | #endif |
857 | 0 | } |
858 | 0 | } |
859 | | |
860 | 0 | float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { |
861 | 0 | auto it = embd_seq.find(seq_id); |
862 | 0 | if (it == embd_seq.end()) { |
863 | 0 | return nullptr; |
864 | 0 | } |
865 | | |
866 | 0 | return it->second.data(); |
867 | 0 | } |
868 | | |
869 | 0 | llama_token llama_context::get_sampled_token_ith(int32_t idx) { |
870 | 0 | output_reorder(); |
871 | |
|
872 | 0 | if (!sampling.sampled.has_data()) { |
873 | 0 | return LLAMA_TOKEN_NULL; |
874 | 0 | } |
875 | | |
876 | 0 | try { |
877 | 0 | const int64_t row = output_resolve_row(idx); |
878 | 0 | GGML_ASSERT(row < (int64_t) sampling.sampled.size); |
879 | 0 | return sampling.sampled.data[row]; |
880 | 0 | } catch (const std::exception & err) { |
881 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); |
882 | 0 | return LLAMA_TOKEN_NULL; |
883 | 0 | } |
884 | 0 | } |
885 | | |
886 | 0 | float * llama_context::get_sampled_probs_ith(int32_t idx) { |
887 | 0 | output_reorder(); |
888 | |
|
889 | 0 | if (!sampling.probs.has_data()) { |
890 | 0 | return nullptr; |
891 | 0 | } |
892 | | |
893 | 0 | try { |
894 | 0 | const int64_t row = output_resolve_row(idx); |
895 | 0 | if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { |
896 | 0 | return nullptr; |
897 | 0 | } |
898 | 0 | return sampling.probs.data + row*model.vocab.n_tokens(); |
899 | 0 | } catch (const std::exception & err) { |
900 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); |
901 | 0 | return nullptr; |
902 | 0 | } |
903 | 0 | } |
904 | | |
905 | 0 | float * llama_context::get_sampled_logits_ith(int32_t idx) { |
906 | 0 | output_reorder(); |
907 | |
|
908 | 0 | if (!sampling.logits.has_data()) { |
909 | 0 | return nullptr; |
910 | 0 | } |
911 | | |
912 | 0 | try { |
913 | 0 | const int64_t row = output_resolve_row(idx); |
914 | 0 | if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { |
915 | 0 | return nullptr; |
916 | 0 | } |
917 | 0 | return sampling.logits.data + row*model.vocab.n_tokens(); |
918 | 0 | } catch (const std::exception & err) { |
919 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); |
920 | 0 | return nullptr; |
921 | 0 | } |
922 | 0 | } |
923 | | |
924 | 0 | const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { |
925 | 0 | output_reorder(); |
926 | |
|
927 | 0 | try { |
928 | 0 | const int64_t row = output_resolve_row(idx); |
929 | 0 | if (sampling.candidates.has_data() && |
930 | 0 | (size_t) row < sampling.candidates_count.size() && |
931 | 0 | sampling.candidates_count[row] > 0) { |
932 | 0 | return sampling.candidates.data + row*model.vocab.n_tokens(); |
933 | 0 | } |
934 | 0 | } catch (const std::exception & err) { |
935 | | // fallback to full vocab list |
936 | 0 | GGML_UNUSED(err); |
937 | 0 | } |
938 | | |
939 | 0 | return sampling.token_ids_full_vocab.data(); |
940 | 0 | } |
941 | | |
942 | 0 | size_t llama_context::get_sampled_candidates_count(int32_t idx) { |
943 | 0 | output_reorder(); |
944 | |
|
945 | 0 | if (!sampling.candidates.has_data()) { |
946 | 0 | return 0; |
947 | 0 | } |
948 | | |
949 | 0 | try { |
950 | 0 | const int64_t row = output_resolve_row(idx); |
951 | 0 | if ((size_t) row >= sampling.candidates_count.size()) { |
952 | 0 | return 0; |
953 | 0 | } |
954 | 0 | return sampling.candidates_count[row]; |
955 | 0 | } catch (const std::exception & err) { |
956 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); |
957 | 0 | return 0; |
958 | 0 | } |
959 | 0 | } |
960 | | |
961 | 0 | size_t llama_context::get_sampled_logits_count(int32_t idx) { |
962 | 0 | output_reorder(); |
963 | |
|
964 | 0 | if (!sampling.logits.has_data()) { |
965 | 0 | return model.vocab.n_tokens(); |
966 | 0 | } |
967 | | |
968 | 0 | try { |
969 | 0 | const int64_t row = output_resolve_row(idx); |
970 | 0 | if ((size_t) row >= sampling.logits_count.size()) { |
971 | 0 | return 0; |
972 | 0 | } |
973 | 0 | return sampling.logits_count[row]; |
974 | 0 | } catch (const std::exception & err) { |
975 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); |
976 | 0 | return 0; |
977 | 0 | } |
978 | 0 | } |
979 | | |
980 | 0 | size_t llama_context::get_sampled_probs_count(int32_t idx) { |
981 | 0 | output_reorder(); |
982 | |
|
983 | 0 | if (!sampling.probs.has_data()) { |
984 | 0 | return 0; |
985 | 0 | } |
986 | | |
987 | 0 | try { |
988 | 0 | const int64_t row = output_resolve_row(idx); |
989 | 0 | if ((size_t) row >= sampling.probs_count.size()) { |
990 | 0 | return 0; |
991 | 0 | } |
992 | 0 | return sampling.probs_count[row]; |
993 | 0 | } catch (const std::exception & err) { |
994 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); |
995 | 0 | return 0; |
996 | 0 | } |
997 | 0 | } |
998 | | |
999 | | |
1000 | | void llama_context::attach_threadpool( |
1001 | | ggml_threadpool_t threadpool, |
1002 | 0 | ggml_threadpool_t threadpool_batch) { |
1003 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
1004 | |
|
1005 | 0 | this->threadpool = threadpool; |
1006 | 0 | this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; |
1007 | 0 | } |
1008 | | |
1009 | 0 | void llama_context::detach_threadpool() { |
1010 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
1011 | |
|
1012 | 0 | this->threadpool = nullptr; |
1013 | 0 | this->threadpool_batch = nullptr; |
1014 | 0 | } |
1015 | | |
1016 | 0 | void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { |
1017 | 0 | LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); |
1018 | |
|
1019 | 0 | cparams.n_threads = n_threads; |
1020 | 0 | cparams.n_threads_batch = n_threads_batch; |
1021 | 0 | } |
1022 | | |
1023 | 0 | void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { |
1024 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
1025 | |
|
1026 | 0 | this->abort_callback = abort_callback; |
1027 | 0 | this->abort_callback_data = abort_callback_data; |
1028 | |
|
1029 | 0 | for (auto & backend : backends) { |
1030 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); |
1031 | 0 | auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); |
1032 | 0 | if (set_abort_callback_fn) { |
1033 | 0 | set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); |
1034 | 0 | } |
1035 | 0 | } |
1036 | 0 | } |
1037 | | |
1038 | 0 | void llama_context::set_embeddings(bool value) { |
1039 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
1040 | |
|
1041 | 0 | cparams.embeddings = value; |
1042 | | |
1043 | | // TODO: not sure yet if we want to reserve here |
1044 | | //sched_need_reserve = true; |
1045 | 0 | } |
1046 | | |
1047 | 0 | void llama_context::set_causal_attn(bool value) { |
1048 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
1049 | |
|
1050 | 0 | if (cparams.causal_attn == value) { |
1051 | 0 | return; |
1052 | 0 | } |
1053 | | |
1054 | 0 | cparams.causal_attn = value; |
1055 | |
|
1056 | 0 | sched_need_reserve = true; |
1057 | 0 | } |
1058 | | |
1059 | 0 | void llama_context::set_warmup(bool value) { |
1060 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
1061 | |
|
1062 | 0 | if (cparams.warmup == value) { |
1063 | 0 | return; |
1064 | 0 | } |
1065 | | |
1066 | 0 | cparams.warmup = value; |
1067 | | |
1068 | | // warmups are usually with small batches, so no need to reserve |
1069 | | //sched_need_reserve = true; |
1070 | 0 | } |
1071 | | |
1072 | 0 | bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { |
1073 | 0 | if (!sampler && sampling.samplers.count(seq_id) == 0) { |
1074 | 0 | return true; |
1075 | 0 | } |
1076 | | |
1077 | 0 | LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); |
1078 | |
|
1079 | 0 | const bool can_offload = |
1080 | 0 | sampler && |
1081 | 0 | sampler->iface->backend_init && |
1082 | 0 | sampler->iface->backend_apply && |
1083 | 0 | llama_sampler_chain_n(sampler) > 0; |
1084 | |
|
1085 | 0 | if (sampler && can_offload) { |
1086 | 0 | auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); |
1087 | |
|
1088 | 0 | sampler->iface->backend_init(sampler, buft); |
1089 | |
|
1090 | 0 | sampling.samplers[seq_id] = sampler; |
1091 | |
|
1092 | 0 | sched_need_reserve = true; |
1093 | |
|
1094 | 0 | return true; |
1095 | 0 | } |
1096 | | |
1097 | 0 | if (sampler && !can_offload) { |
1098 | 0 | LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); |
1099 | |
|
1100 | 0 | if (sampling.samplers.count(seq_id) > 0) { |
1101 | 0 | sched_need_reserve = true; |
1102 | 0 | } |
1103 | |
|
1104 | 0 | sampling.samplers.erase(seq_id); |
1105 | |
|
1106 | 0 | return false; |
1107 | 0 | } |
1108 | | |
1109 | 0 | sampling.samplers.erase(seq_id); |
1110 | |
|
1111 | 0 | sched_need_reserve = true; |
1112 | |
|
1113 | 0 | return true; |
1114 | 0 | } |
1115 | | |
1116 | 0 | void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { |
1117 | 0 | LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); |
1118 | |
|
1119 | 0 | if (adapters_lora_are_same(adapters, n_adapters, scales)) { |
1120 | 0 | return; |
1121 | 0 | } |
1122 | | |
1123 | 0 | loras.reset(new llama_adapter_loras()); |
1124 | |
|
1125 | 0 | for (size_t i = 0; i < n_adapters; i ++) { |
1126 | 0 | if (scales[i] != 0.0f) { |
1127 | 0 | loras->insert({adapters[i], scales[i]}); |
1128 | 0 | } |
1129 | 0 | } |
1130 | |
|
1131 | 0 | sched_need_reserve = true; |
1132 | 0 | } |
1133 | | |
1134 | 0 | bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { |
1135 | 0 | LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); |
1136 | | |
1137 | | // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison. |
1138 | 0 | size_t n_non_zero = 0; |
1139 | |
|
1140 | 0 | for (size_t i = 0; i < n_adapters; i ++) { |
1141 | 0 | if (scales[i] == 0.0f) { |
1142 | 0 | continue; |
1143 | 0 | } |
1144 | 0 | n_non_zero++; |
1145 | |
|
1146 | 0 | auto it = loras->find(adapters[i]); |
1147 | |
|
1148 | 0 | if (it == loras->end() || it->second != scales[i]) { |
1149 | 0 | return false; |
1150 | 0 | } |
1151 | 0 | } |
1152 | | |
1153 | 0 | if (n_non_zero != loras->size()) { |
1154 | 0 | return false; |
1155 | 0 | } |
1156 | | |
1157 | 0 | return true; |
1158 | 0 | } |
1159 | | |
1160 | | bool llama_context::set_adapter_cvec( |
1161 | | const float * data, |
1162 | | size_t len, |
1163 | | int32_t n_embd, |
1164 | | int32_t il_start, |
1165 | 0 | int32_t il_end) { |
1166 | 0 | LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); |
1167 | |
|
1168 | 0 | bool res = cvec->apply(model, data, len, n_embd, il_start, il_end); |
1169 | |
|
1170 | 0 | sched_need_reserve = true; |
1171 | |
|
1172 | 0 | return res; |
1173 | 0 | } |
1174 | | |
1175 | 0 | llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { |
1176 | 0 | if (mctx && !mctx->apply()) { |
1177 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); |
1178 | 0 | ret = GGML_STATUS_FAILED; |
1179 | 0 | return nullptr; |
1180 | 0 | } |
1181 | | |
1182 | 0 | auto * res = gf_res_prev.get(); |
1183 | 0 | auto * gf = res->get_gf(); |
1184 | | |
1185 | | // the new graph parameters |
1186 | | // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters |
1187 | 0 | const auto gparams = graph_params(res, ubatch, mctx, gtype); |
1188 | |
|
1189 | 0 | if (!graph_reuse_disable && res->can_reuse(gparams)) { |
1190 | | //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); |
1191 | |
|
1192 | 0 | n_reused++; |
1193 | 0 | } else { |
1194 | 0 | res->reset(); |
1195 | |
|
1196 | 0 | ggml_backend_sched_reset(sched.get()); |
1197 | 0 | ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); |
1198 | | |
1199 | | //const auto t_start_us = ggml_time_us(); |
1200 | |
|
1201 | 0 | gf = model.build_graph(gparams); |
1202 | | |
1203 | | //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
1204 | |
|
1205 | 0 | if (!gf) { |
1206 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); |
1207 | 0 | ret = GGML_STATUS_FAILED; |
1208 | 0 | return nullptr; |
1209 | 0 | } |
1210 | | |
1211 | 0 | if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { |
1212 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); |
1213 | 0 | ret = GGML_STATUS_ALLOC_FAILED; |
1214 | 0 | return nullptr; |
1215 | 0 | } |
1216 | 0 | } |
1217 | | |
1218 | | // set the input data for the input tensors |
1219 | 0 | { |
1220 | | //const auto t_start_us = ggml_time_us(); |
1221 | | |
1222 | | // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated |
1223 | 0 | res->set_inputs(&ubatch); |
1224 | | |
1225 | | //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
1226 | 0 | } |
1227 | |
|
1228 | 0 | const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); |
1229 | 0 | if (status != GGML_STATUS_SUCCESS) { |
1230 | 0 | LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); |
1231 | 0 | ret = status; |
1232 | 0 | return nullptr; |
1233 | 0 | } |
1234 | | |
1235 | 0 | ret = GGML_STATUS_SUCCESS; |
1236 | |
|
1237 | 0 | return res; |
1238 | 0 | } |
1239 | | |
1240 | 0 | int llama_context::encode(const llama_batch & batch_inp) { |
1241 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1242 | |
|
1243 | 0 | if (batch_inp.n_tokens == 0) { |
1244 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1245 | 0 | return -1; |
1246 | 0 | } |
1247 | | |
1248 | 0 | const auto & hparams = model.hparams; |
1249 | |
|
1250 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1251 | 0 | const int64_t n_vocab = model.vocab.n_tokens(); |
1252 | | |
1253 | | // note: during encode, we always pass the full sequence starting from pos = 0 |
1254 | 0 | if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
1255 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1256 | 0 | return -1; |
1257 | 0 | } |
1258 | | |
1259 | 0 | const uint32_t n_tokens = balloc->get_n_tokens(); |
1260 | | |
1261 | | // [TAG_NO_CACHE_PAD] |
1262 | | // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true |
1263 | 0 | const llama_ubatch ubatch = balloc->split_simple(n_tokens); |
1264 | | |
1265 | | // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot |
1266 | 0 | GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); |
1267 | |
|
1268 | 0 | if (t_compute_start_us == 0) { |
1269 | 0 | t_compute_start_us = ggml_time_us(); |
1270 | 0 | } |
1271 | | |
1272 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
1273 | 0 | embd_seq.clear(); |
1274 | |
|
1275 | 0 | sched_reserve(); |
1276 | |
|
1277 | 0 | n_queued_tokens += n_tokens; |
1278 | | |
1279 | | // reserve output buffer |
1280 | 0 | if (output_reserve(n_tokens) < n_tokens) { |
1281 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); |
1282 | 0 | return -2; |
1283 | 0 | }; |
1284 | |
|
1285 | 0 | for (uint32_t i = 0; i < n_tokens; ++i) { |
1286 | 0 | output_ids[i] = i; |
1287 | 0 | } |
1288 | |
|
1289 | 0 | n_outputs = n_tokens; |
1290 | |
|
1291 | 0 | const auto causal_attn_org = cparams.causal_attn; |
1292 | | |
1293 | | // always use non-causal attention for encoder graphs |
1294 | | // TODO: this is a tmp solution until we have a proper way to support enc-dec models |
1295 | | // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 |
1296 | 0 | cparams.causal_attn = false; |
1297 | |
|
1298 | 0 | ggml_status status; |
1299 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); |
1300 | |
|
1301 | 0 | cparams.causal_attn = causal_attn_org; |
1302 | |
|
1303 | 0 | if (!res) { |
1304 | 0 | switch (status) { |
1305 | 0 | case GGML_STATUS_ABORTED: return 2; |
1306 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1307 | 0 | case GGML_STATUS_FAILED: return -3; |
1308 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1309 | 0 | } |
1310 | 0 | } |
1311 | | |
1312 | 0 | auto * t_logits = res->get_logits(); |
1313 | 0 | auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); |
1314 | | |
1315 | | // extract logits |
1316 | 0 | if (logits.data && t_logits) { |
1317 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1318 | 0 | GGML_ASSERT(backend_res != nullptr); |
1319 | 0 | GGML_ASSERT(logits.data != nullptr); |
1320 | |
|
1321 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); |
1322 | 0 | } |
1323 | | |
1324 | | // extract embeddings |
1325 | 0 | if (embd.data && t_embd) { |
1326 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1327 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1328 | |
|
1329 | 0 | switch (cparams.pooling_type) { |
1330 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1331 | 0 | { |
1332 | | // extract token embeddings |
1333 | 0 | GGML_ASSERT(embd.data != nullptr); |
1334 | 0 | const uint32_t n_embd_out = hparams.n_embd_out(); |
1335 | |
|
1336 | 0 | GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); |
1337 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); |
1338 | 0 | } break; |
1339 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1340 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1341 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1342 | 0 | { |
1343 | | // extract sequence embeddings |
1344 | 0 | auto & embd_seq_out = embd_seq; |
1345 | |
|
1346 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1347 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1348 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1349 | |
|
1350 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1351 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1352 | 0 | } |
1353 | 0 | } break; |
1354 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1355 | 0 | { |
1356 | | // extract the rerank score - n_cls_out floats per sequence |
1357 | 0 | auto & embd_seq_out = embd_seq; |
1358 | |
|
1359 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1360 | |
|
1361 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1362 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1363 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1364 | |
|
1365 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1366 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1367 | 0 | } |
1368 | 0 | } break; |
1369 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1370 | 0 | { |
1371 | 0 | GGML_ABORT("unknown pooling type"); |
1372 | 0 | } |
1373 | 0 | } |
1374 | 0 | } |
1375 | | |
1376 | | // TODO: hacky solution |
1377 | 0 | if (model.arch == LLM_ARCH_T5 && t_embd) { |
1378 | | //cross.t_embd = t_embd; |
1379 | |
|
1380 | 0 | synchronize(); |
1381 | |
|
1382 | 0 | cross.n_embd = t_embd->ne[0]; |
1383 | 0 | cross.n_enc = t_embd->ne[1]; |
1384 | 0 | cross.v_embd.resize(cross.n_embd*cross.n_enc); |
1385 | 0 | memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); |
1386 | |
|
1387 | 0 | const auto & batch = balloc->get_batch(); |
1388 | | |
1389 | | // remember the sequence ids used during the encoding - needed for cross attention later |
1390 | 0 | cross.seq_ids_enc.resize(n_tokens); |
1391 | 0 | for (uint32_t i = 0; i < n_tokens; i++) { |
1392 | 0 | cross.seq_ids_enc[i].clear(); |
1393 | |
|
1394 | 0 | for (int s = 0; s < batch.n_seq_id[i]; s++) { |
1395 | 0 | const llama_seq_id seq_id = batch.seq_id[i][s]; |
1396 | |
|
1397 | 0 | cross.seq_ids_enc[i].insert(seq_id); |
1398 | 0 | } |
1399 | 0 | } |
1400 | 0 | } |
1401 | |
|
1402 | 0 | return 0; |
1403 | 0 | } |
1404 | | |
1405 | 0 | static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { |
1406 | 0 | std::map<llama_seq_id, uint32_t> seq_to_row; |
1407 | | // how many output tokens we have seen so far for this ubatch. |
1408 | 0 | uint32_t local = 0; |
1409 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1410 | | // skip tokens that are not output. |
1411 | 0 | if (!ubatch.output[i]) { |
1412 | 0 | continue; |
1413 | 0 | } |
1414 | | |
1415 | 0 | const llama_seq_id seq_id = ubatch.seq_id[i][0]; |
1416 | | // row_offset is the number of output tokens before this ubatch. |
1417 | 0 | seq_to_row[seq_id] = row_offset + local; |
1418 | 0 | ++local; |
1419 | 0 | } |
1420 | 0 | return seq_to_row; |
1421 | 0 | } |
1422 | | |
1423 | | static void copy_tensor_async_ints( |
1424 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1425 | | const buffer_view<llama_token> & sampled, |
1426 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1427 | 0 | ggml_backend_sched_t sched) { |
1428 | 0 | if (!sampled.has_data()) { |
1429 | 0 | return; |
1430 | 0 | } |
1431 | | |
1432 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1433 | 0 | auto it = seq_to_row.find(seq_id); |
1434 | 0 | if (it == seq_to_row.end()) { |
1435 | 0 | continue; |
1436 | 0 | } |
1437 | | |
1438 | 0 | const uint32_t row = it->second; |
1439 | 0 | GGML_ASSERT(row < sampled.size); |
1440 | |
|
1441 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); |
1442 | |
|
1443 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1444 | 0 | ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); |
1445 | 0 | } |
1446 | 0 | } |
1447 | | |
1448 | | static void copy_tensor_async_floats( |
1449 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1450 | | const buffer_view<float> & dst, |
1451 | | size_t stride, |
1452 | | std::vector<uint32_t> & counts, |
1453 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1454 | 0 | ggml_backend_sched_t sched) { |
1455 | 0 | if (!dst.has_data()) { |
1456 | 0 | return; |
1457 | 0 | } |
1458 | | |
1459 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1460 | 0 | auto it = seq_to_row.find(seq_id); |
1461 | 0 | if (it == seq_to_row.end()) { |
1462 | 0 | continue; |
1463 | 0 | } |
1464 | | |
1465 | 0 | const uint32_t row = it->second; |
1466 | 0 | GGML_ASSERT(row < counts.size()); |
1467 | |
|
1468 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); |
1469 | |
|
1470 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1471 | 0 | float * row_ptr = dst.data + (size_t) row * stride; |
1472 | 0 | ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); |
1473 | | |
1474 | | // Update the actual number of logits/probabilities that were written for this row. |
1475 | 0 | counts[row] = ggml_nelements(tensor); |
1476 | 0 | } |
1477 | 0 | } |
1478 | | |
1479 | | static void copy_tensor_async_candidates( |
1480 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1481 | | const buffer_view<llama_token> & dst, |
1482 | | size_t stride, |
1483 | | std::vector<uint32_t> & counts, |
1484 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1485 | 0 | ggml_backend_sched_t sched) { |
1486 | 0 | if (!dst.has_data()) { |
1487 | 0 | return; |
1488 | 0 | } |
1489 | | |
1490 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1491 | 0 | auto it = seq_to_row.find(seq_id); |
1492 | 0 | if (it == seq_to_row.end()) { |
1493 | 0 | continue; |
1494 | 0 | } |
1495 | | |
1496 | 0 | const uint32_t row = it->second; |
1497 | 0 | GGML_ASSERT(row < counts.size()); |
1498 | |
|
1499 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); |
1500 | |
|
1501 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1502 | 0 | llama_token * row_ptr = dst.data + (size_t) row * stride; |
1503 | 0 | ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); |
1504 | | |
1505 | | // Update the actual number of candidates that were written. |
1506 | 0 | counts[row] = ggml_nelements(tensor); |
1507 | 0 | } |
1508 | 0 | } |
1509 | | |
1510 | 0 | static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) { |
1511 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1512 | 0 | if (!ubatch.output[i]) { |
1513 | 0 | continue; |
1514 | 0 | } |
1515 | | |
1516 | | // Check if the output token has at least one sequence without a backend sampler. |
1517 | 0 | for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { |
1518 | 0 | llama_seq_id seq_id = ubatch.seq_id[i][j]; |
1519 | 0 | if (samplers.find(seq_id) == samplers.end()) { |
1520 | 0 | return true; |
1521 | 0 | } |
1522 | 0 | } |
1523 | 0 | } |
1524 | 0 | return false; // all sequences use backend sampling |
1525 | 0 | } |
1526 | | |
1527 | 0 | int llama_context::decode(const llama_batch & batch_inp) { |
1528 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1529 | |
|
1530 | 0 | if (!memory) { |
1531 | 0 | LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); |
1532 | 0 | return encode(batch_inp); |
1533 | 0 | } |
1534 | | |
1535 | 0 | if (batch_inp.n_tokens == 0) { |
1536 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1537 | 0 | return -1; |
1538 | 0 | } |
1539 | | |
1540 | 0 | const auto & vocab = model.vocab; |
1541 | 0 | const auto & hparams = model.hparams; |
1542 | |
|
1543 | 0 | const int64_t n_vocab = vocab.n_tokens(); |
1544 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1545 | | |
1546 | | // when computing embeddings, all tokens are output |
1547 | 0 | const bool output_all = cparams.embeddings; |
1548 | 0 | const bool has_samplers = !sampling.samplers.empty(); |
1549 | |
|
1550 | 0 | const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; |
1551 | | |
1552 | | // TODO: avoid this workaround in the future |
1553 | 0 | if (has_samplers && batch_inp.logits) { |
1554 | 0 | std::vector<int32_t> seq_output_count(n_seq_max, 0); |
1555 | |
|
1556 | 0 | for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { |
1557 | 0 | if (batch_inp.logits[i] == 0) { |
1558 | 0 | continue; |
1559 | 0 | } |
1560 | | |
1561 | 0 | const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1; |
1562 | |
|
1563 | 0 | for (int32_t s = 0; s < ns; ++s) { |
1564 | 0 | const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0; |
1565 | |
|
1566 | 0 | seq_output_count[seq_id]++; |
1567 | 0 | if (seq_output_count[seq_id] > 1) { |
1568 | 0 | LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", |
1569 | 0 | __func__, seq_id, seq_output_count[seq_id]); |
1570 | 0 | return -1; |
1571 | 0 | } |
1572 | 0 | } |
1573 | 0 | } |
1574 | 0 | } |
1575 | | |
1576 | 0 | if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { |
1577 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1578 | 0 | return -1; |
1579 | 0 | } |
1580 | | |
1581 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
1582 | 0 | const uint32_t n_outputs_all = balloc->get_n_outputs(); |
1583 | |
|
1584 | 0 | if (output_all) { |
1585 | | // require that all tokens are output |
1586 | 0 | if (n_outputs_all != n_tokens_all) { |
1587 | 0 | LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", |
1588 | 0 | __func__, n_outputs_all, n_tokens_all); |
1589 | 0 | return -1; |
1590 | 0 | } |
1591 | 0 | } |
1592 | | |
1593 | 0 | GGML_ASSERT(n_tokens_all <= cparams.n_batch); |
1594 | |
|
1595 | 0 | GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); |
1596 | |
|
1597 | 0 | if (t_compute_start_us == 0) { |
1598 | 0 | t_compute_start_us = ggml_time_us(); |
1599 | 0 | } |
1600 | 0 | n_queued_tokens += n_tokens_all; |
1601 | | |
1602 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
1603 | 0 | embd_seq.clear(); |
1604 | 0 | output_swaps.clear(); |
1605 | |
|
1606 | 0 | sched_reserve(); |
1607 | |
|
1608 | 0 | bool did_optimize = false; |
1609 | | |
1610 | | // handle any pending shifts/copies |
1611 | 0 | memory_update(false); |
1612 | |
|
1613 | 0 | llama_memory_context_ptr mctx; |
1614 | |
|
1615 | 0 | while (true) { |
1616 | 0 | mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); |
1617 | 0 | if (!mctx) { |
1618 | 0 | return -2; |
1619 | 0 | } |
1620 | | |
1621 | 0 | switch (mctx->get_status()) { |
1622 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
1623 | 0 | { |
1624 | 0 | } break; |
1625 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
1626 | 0 | { |
1627 | 0 | LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); |
1628 | |
|
1629 | 0 | return -2; |
1630 | 0 | } |
1631 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
1632 | 0 | { |
1633 | 0 | if (!did_optimize) { |
1634 | 0 | did_optimize = true; |
1635 | |
|
1636 | 0 | if (memory_update(true)) { |
1637 | 0 | LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); |
1638 | |
|
1639 | 0 | continue; |
1640 | 0 | } |
1641 | 0 | } |
1642 | | |
1643 | 0 | LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); |
1644 | |
|
1645 | 0 | return 1; |
1646 | 0 | } |
1647 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
1648 | 0 | { |
1649 | 0 | LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); |
1650 | |
|
1651 | 0 | return -2; |
1652 | 0 | } |
1653 | 0 | } |
1654 | | |
1655 | 0 | break; |
1656 | 0 | } |
1657 | | |
1658 | | // reserve output buffer |
1659 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
1660 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
1661 | 0 | return -2; |
1662 | 0 | }; |
1663 | |
|
1664 | 0 | int64_t n_outputs_prev = 0; |
1665 | |
|
1666 | 0 | do { |
1667 | 0 | const auto & ubatch = mctx->get_ubatch(); |
1668 | | |
1669 | | // count the outputs in this ubatch |
1670 | 0 | { |
1671 | 0 | int32_t n_outputs_new = 0; |
1672 | |
|
1673 | 0 | if (n_outputs_all == n_tokens_all) { |
1674 | 0 | n_outputs_new = ubatch.n_tokens; |
1675 | 0 | } else { |
1676 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1677 | 0 | n_outputs_new += (int32_t) (ubatch.output[i] != 0); |
1678 | 0 | } |
1679 | 0 | } |
1680 | | |
1681 | | // needs to happen before the graph is built |
1682 | 0 | n_outputs = n_outputs_new; |
1683 | 0 | } |
1684 | |
|
1685 | 0 | ggml_status status; |
1686 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); |
1687 | |
|
1688 | 0 | if (!res) { |
1689 | | // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module |
1690 | 0 | llama_pos pos_min[LLAMA_MAX_SEQ]; |
1691 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1692 | 0 | pos_min[s] = std::numeric_limits<llama_pos>::max(); |
1693 | 0 | } |
1694 | |
|
1695 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1696 | 0 | const auto & seq_id = ubatch.seq_id[i][0]; |
1697 | |
|
1698 | 0 | pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); |
1699 | 0 | } |
1700 | |
|
1701 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1702 | 0 | if (pos_min[s] == std::numeric_limits<llama_pos>::max()) { |
1703 | 0 | continue; |
1704 | 0 | } |
1705 | | |
1706 | 0 | LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); |
1707 | |
|
1708 | 0 | memory->seq_rm(s, pos_min[s], -1); |
1709 | 0 | } |
1710 | |
|
1711 | 0 | switch (status) { |
1712 | 0 | case GGML_STATUS_ABORTED: return 2; |
1713 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1714 | 0 | case GGML_STATUS_FAILED: return -3; |
1715 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1716 | 0 | } |
1717 | 0 | } |
1718 | | |
1719 | | // plot the computation graph in dot format (for debugging purposes) |
1720 | | //if (n_past%100 == 0) { |
1721 | | // ggml_graph_dump_dot(gf, NULL, "llama.dot"); |
1722 | | //} |
1723 | | |
1724 | 0 | auto * t_logits = res->get_logits(); |
1725 | 0 | auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; |
1726 | |
|
1727 | 0 | if (t_embd && res->get_embd_pooled()) { |
1728 | 0 | t_embd = res->get_embd_pooled(); |
1729 | 0 | } |
1730 | | |
1731 | | // extract logits |
1732 | 0 | if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { |
1733 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1734 | 0 | GGML_ASSERT(backend_res != nullptr); |
1735 | 0 | GGML_ASSERT(logits.data != nullptr); |
1736 | |
|
1737 | 0 | float * logits_out = logits.data + n_outputs_prev*n_vocab; |
1738 | |
|
1739 | 0 | if (n_outputs) { |
1740 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1741 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); |
1742 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); |
1743 | 0 | } |
1744 | 0 | } |
1745 | | |
1746 | | // extract embeddings |
1747 | 0 | if (embd.data && t_embd && n_outputs > 0) { |
1748 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1749 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1750 | |
|
1751 | 0 | switch (cparams.pooling_type) { |
1752 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1753 | 0 | { |
1754 | | // extract token embeddings |
1755 | 0 | GGML_ASSERT(embd.data != nullptr); |
1756 | 0 | const uint32_t n_embd_out = hparams.n_embd_out(); |
1757 | 0 | float * embd_out = embd.data + n_outputs_prev*n_embd_out; |
1758 | |
|
1759 | 0 | if (n_outputs) { |
1760 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1761 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); |
1762 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); |
1763 | 0 | } |
1764 | 0 | } break; |
1765 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1766 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1767 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1768 | 0 | { |
1769 | | // extract sequence embeddings (cleared before processing each batch) |
1770 | 0 | auto & embd_seq_out = embd_seq; |
1771 | |
|
1772 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1773 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1774 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1775 | |
|
1776 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1777 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1778 | 0 | } |
1779 | 0 | } break; |
1780 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1781 | 0 | { |
1782 | | // extract the rerank score - n_cls_out floats per sequence |
1783 | 0 | auto & embd_seq_out = embd_seq; |
1784 | |
|
1785 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1786 | |
|
1787 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1788 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1789 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1790 | |
|
1791 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1792 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1793 | 0 | } |
1794 | 0 | } break; |
1795 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1796 | 0 | { |
1797 | 0 | GGML_ABORT("unknown pooling type"); |
1798 | 0 | } |
1799 | 0 | } |
1800 | 0 | } |
1801 | | |
1802 | | // Copy backend sampling output if this ubatch produced any sampling tensors. |
1803 | 0 | if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { |
1804 | 0 | const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); |
1805 | 0 | const auto stride = n_vocab; |
1806 | | |
1807 | | // async copy the sampling data from the backend to the host |
1808 | 0 | copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); |
1809 | |
|
1810 | 0 | copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); |
1811 | 0 | copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); |
1812 | 0 | copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get()); |
1813 | 0 | } |
1814 | |
|
1815 | 0 | n_outputs_prev += n_outputs; |
1816 | 0 | } while (mctx->next()); |
1817 | | |
1818 | | // set to total number of outputs in the batch, for use in llama_get_logits_ith |
1819 | 0 | n_outputs = n_outputs_all; |
1820 | | |
1821 | | // set output mappings |
1822 | 0 | if (n_outputs > 0) { |
1823 | 0 | bool sorted_output = true; |
1824 | |
|
1825 | 0 | auto & out_ids = balloc->get_out_ids(); |
1826 | |
|
1827 | 0 | GGML_ASSERT(out_ids.size() == (size_t) n_outputs); |
1828 | |
|
1829 | 0 | for (int64_t i = 0; i < n_outputs; ++i) { |
1830 | 0 | int64_t out_id = out_ids[i]; |
1831 | 0 | output_ids[out_id] = i; |
1832 | 0 | if (out_id != i) { |
1833 | 0 | sorted_output = false; |
1834 | 0 | } |
1835 | 0 | } |
1836 | | |
1837 | | // make the outputs have the same order they had in the user-provided batch |
1838 | | // note: this is mostly relevant for recurrent models atm |
1839 | 0 | if (!sorted_output && n_outputs > 1) { |
1840 | 0 | GGML_ASSERT((size_t) n_outputs == out_ids.size()); |
1841 | | |
1842 | | // TODO: is there something more efficient which also minimizes swaps? |
1843 | | // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) |
1844 | 0 | for (uint32_t i = 0; i < n_outputs - 1; ++i) { |
1845 | 0 | uint32_t j_min = i; |
1846 | 0 | for (uint32_t j = i + 1; j < n_outputs; ++j) { |
1847 | 0 | if (out_ids[j] < out_ids[j_min]) { |
1848 | 0 | j_min = j; |
1849 | 0 | } |
1850 | 0 | } |
1851 | 0 | if (j_min == i) { |
1852 | 0 | continue; |
1853 | 0 | } |
1854 | 0 | std::swap(out_ids[i], out_ids[j_min]); |
1855 | | |
1856 | | // remember the swaps and apply them lazily upon logits/embeddings access |
1857 | 0 | output_swaps.push_back({ i, j_min }); |
1858 | 0 | } |
1859 | |
|
1860 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1861 | |
|
1862 | 0 | for (uint32_t i = 0; i < n_outputs; ++i) { |
1863 | 0 | output_ids[out_ids[i]] = i; |
1864 | 0 | } |
1865 | 0 | } |
1866 | 0 | } |
1867 | | |
1868 | | // wait for the computation to finish (automatically done when obtaining the model output) |
1869 | | //synchronize(); |
1870 | |
|
1871 | 0 | return 0; |
1872 | 0 | } |
1873 | | |
1874 | | // |
1875 | | // output |
1876 | | // |
1877 | | |
1878 | 0 | uint32_t llama_context::output_reserve(int32_t n_outputs) { |
1879 | 0 | const auto & hparams = model.hparams; |
1880 | 0 | const auto & vocab = model.vocab; |
1881 | |
|
1882 | 0 | const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max()); |
1883 | |
|
1884 | 0 | const auto n_batch = cparams.n_batch; |
1885 | 0 | const auto n_vocab = vocab.n_tokens(); |
1886 | 0 | const auto n_embd_out = hparams.n_embd_out(); |
1887 | |
|
1888 | 0 | bool has_logits = true; |
1889 | 0 | bool has_embd = cparams.embeddings; |
1890 | | |
1891 | | // TODO: hacky enc-dec support |
1892 | 0 | if (model.arch == LLM_ARCH_T5) { |
1893 | 0 | has_logits = true; |
1894 | 0 | has_embd = true; |
1895 | 0 | } |
1896 | | |
1897 | |
|
1898 | 0 | size_t backend_float_count = 0; |
1899 | 0 | size_t backend_token_count = 0; |
1900 | |
|
1901 | 0 | logits.size = has_logits ? n_vocab*n_outputs_max : 0; |
1902 | 0 | embd.size = has_embd ? n_embd_out*n_outputs_max : 0; |
1903 | | |
1904 | | // Allocate backend sampling output buffers if there are backend samplers configured. |
1905 | 0 | const bool has_sampling = !sampling.samplers.empty(); |
1906 | 0 | if (has_sampling) { |
1907 | 0 | backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs |
1908 | 0 | backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates |
1909 | 0 | } |
1910 | |
|
1911 | 0 | if (output_ids.empty()) { |
1912 | | // init, never resized afterwards |
1913 | 0 | output_ids.resize(n_batch); |
1914 | 0 | } |
1915 | |
|
1916 | 0 | const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; |
1917 | 0 | const size_t new_size = |
1918 | 0 | (logits.size + embd.size + backend_float_count) * sizeof(float) + |
1919 | 0 | ( backend_token_count) * sizeof(llama_token); |
1920 | | |
1921 | | // alloc only when more than the current capacity is required |
1922 | | // TODO: also consider shrinking the buffer |
1923 | 0 | if (!buf_output || prev_size < new_size) { |
1924 | 0 | if (buf_output) { |
1925 | | #ifndef NDEBUG |
1926 | | // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) |
1927 | | LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); |
1928 | | #endif |
1929 | 0 | synchronize(); |
1930 | | |
1931 | | // TODO: not needed? |
1932 | 0 | buf_output = nullptr; |
1933 | 0 | logits.data = nullptr; |
1934 | 0 | embd.data = nullptr; |
1935 | 0 | } |
1936 | |
|
1937 | 0 | auto * buft = ggml_backend_cpu_buffer_type(); |
1938 | | // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory |
1939 | 0 | auto * output_dev = model.dev_output(); |
1940 | 0 | auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; |
1941 | 0 | if (output_dev_host_buft) { |
1942 | 0 | buft = output_dev_host_buft; |
1943 | 0 | } |
1944 | 0 | buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); |
1945 | 0 | if (buf_output == nullptr) { |
1946 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); |
1947 | 0 | return 0; |
1948 | 0 | } |
1949 | 0 | ggml_backend_buffer_clear(buf_output.get(), 0); |
1950 | 0 | } |
1951 | | |
1952 | 0 | float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); |
1953 | |
|
1954 | 0 | size_t offset = 0; |
1955 | 0 | uint8_t * base = (uint8_t *) output_base; |
1956 | |
|
1957 | 0 | logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0}; |
1958 | 0 | offset += logits.size * sizeof(float); |
1959 | |
|
1960 | 0 | embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0}; |
1961 | 0 | offset += embd.size * sizeof(float); |
1962 | |
|
1963 | 0 | if (has_sampling) { |
1964 | 0 | sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; |
1965 | 0 | offset += sampling.logits.size * sizeof(float); |
1966 | |
|
1967 | 0 | sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; |
1968 | 0 | offset += sampling.probs.size * sizeof(float); |
1969 | |
|
1970 | 0 | sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; |
1971 | 0 | offset += sampling.sampled.size * sizeof(llama_token); |
1972 | |
|
1973 | 0 | sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; |
1974 | 0 | offset += sampling.candidates.size * sizeof(llama_token); |
1975 | | |
1976 | | // The count vectors keep track of the actual number of logits/probs/candidates |
1977 | | // copied from the backend for each output row. |
1978 | |
|
1979 | 0 | sampling.logits_count.resize(n_outputs_max); |
1980 | 0 | sampling.probs_count.resize(n_outputs_max); |
1981 | 0 | sampling.candidates_count.resize(n_outputs_max); |
1982 | |
|
1983 | 0 | std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); |
1984 | 0 | std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); |
1985 | 0 | std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); |
1986 | |
|
1987 | 0 | std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); |
1988 | 0 | } else { |
1989 | 0 | sampling.logits = {nullptr, 0}; |
1990 | 0 | sampling.probs = {nullptr, 0}; |
1991 | 0 | sampling.sampled = {nullptr, 0}; |
1992 | 0 | sampling.candidates = {nullptr, 0}; |
1993 | |
|
1994 | 0 | sampling.logits_count.clear(); |
1995 | 0 | sampling.probs_count.clear(); |
1996 | 0 | sampling.candidates_count.clear(); |
1997 | 0 | } |
1998 | | |
1999 | | // set all ids as invalid (negative) |
2000 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
2001 | |
|
2002 | 0 | this->n_outputs = 0; |
2003 | |
|
2004 | 0 | return n_outputs_max; |
2005 | 0 | } |
2006 | | |
2007 | 0 | void llama_context::output_reorder() { |
2008 | 0 | const uint64_t n_vocab = model.vocab.n_tokens(); |
2009 | 0 | const uint64_t n_embd = model.hparams.n_embd; |
2010 | |
|
2011 | 0 | for (size_t s = 0; s < output_swaps.size(); ++s) { |
2012 | 0 | const uint64_t i0 = output_swaps[s].i0; |
2013 | 0 | const uint64_t i1 = output_swaps[s].i1; |
2014 | |
|
2015 | 0 | if (logits.size > 0) { |
2016 | 0 | for (uint64_t k = 0; k < n_vocab; k++) { |
2017 | 0 | std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); |
2018 | 0 | } |
2019 | 0 | } |
2020 | |
|
2021 | 0 | if (embd.size > 0) { |
2022 | 0 | for (uint64_t k = 0; k < n_embd; k++) { |
2023 | 0 | std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); |
2024 | 0 | } |
2025 | 0 | } |
2026 | |
|
2027 | 0 | if (!sampling.samplers.empty()) { |
2028 | 0 | assert(sampling.logits.size > 0); |
2029 | 0 | assert(sampling.probs.size > 0); |
2030 | 0 | assert(sampling.candidates.size > 0); |
2031 | 0 | assert(sampling.sampled.size > 0); |
2032 | 0 | assert(sampling.logits_count.size() > 0); |
2033 | 0 | assert(sampling.probs_count.size() > 0); |
2034 | 0 | assert(sampling.candidates_count.size() > 0); |
2035 | |
|
2036 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
2037 | 0 | std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); |
2038 | 0 | } |
2039 | |
|
2040 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
2041 | 0 | std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); |
2042 | 0 | } |
2043 | |
|
2044 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
2045 | 0 | std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); |
2046 | 0 | } |
2047 | |
|
2048 | 0 | std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); |
2049 | 0 | std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); |
2050 | 0 | std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); |
2051 | 0 | std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); |
2052 | 0 | } |
2053 | 0 | } |
2054 | |
|
2055 | 0 | output_swaps.clear(); |
2056 | 0 | } |
2057 | | |
2058 | | // |
2059 | | // graph |
2060 | | // |
2061 | | |
2062 | 0 | uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { |
2063 | 0 | if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { |
2064 | 0 | return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors()); |
2065 | 0 | } |
2066 | 0 | uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors()); |
2067 | 0 | for (const auto & lora : model.loras) { |
2068 | 0 | res += lora->get_n_nodes(); |
2069 | 0 | } |
2070 | 0 | return res; |
2071 | 0 | } |
2072 | | |
2073 | 0 | llm_graph_result * llama_context::get_gf_res_reserve() const { |
2074 | 0 | return static_cast<llm_graph_result *>(gf_res_reserve.get()); |
2075 | 0 | } |
2076 | | |
2077 | | ggml_cgraph * llama_context::graph_reserve( |
2078 | 0 | uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) { |
2079 | 0 | LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); |
2080 | 0 | GGML_ASSERT(n_outputs >= 1); |
2081 | |
|
2082 | 0 | if (n_tokens % n_seqs != 0) { |
2083 | 0 | n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs |
2084 | 0 | n_outputs = std::max(n_outputs, n_tokens); |
2085 | |
|
2086 | 0 | LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); |
2087 | 0 | } |
2088 | |
|
2089 | 0 | ggml_backend_sched_reset(sched.get()); |
2090 | | |
2091 | | // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that |
2092 | 0 | gf_res_prev->reset(); |
2093 | | |
2094 | | // store the n_outputs as it is, and restore it afterwards |
2095 | | // TODO: not sure if needed, might simplify in the future by removing this |
2096 | 0 | const auto save_n_outputs = this->n_outputs; |
2097 | |
|
2098 | 0 | this->n_outputs = n_outputs; |
2099 | |
|
2100 | 0 | llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); |
2101 | 0 | llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); |
2102 | | |
2103 | | // set one output token per sequence in order to activate all backend samplers |
2104 | 0 | std::vector<llama_seq_id> seq_ids(n_seqs); |
2105 | 0 | for (uint32_t i = 0; i < n_seqs; ++i) { |
2106 | 0 | seq_ids[i] = i; |
2107 | 0 | ubatch.n_seq_id[i] = 1; |
2108 | 0 | ubatch.seq_id[i] = &seq_ids[i]; |
2109 | 0 | ubatch.output[i] = true; |
2110 | 0 | } |
2111 | |
|
2112 | 0 | auto * res = gf_res_reserve.get(); |
2113 | |
|
2114 | 0 | const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); |
2115 | |
|
2116 | 0 | res->reset(); |
2117 | |
|
2118 | 0 | auto * gf = model.build_graph(gparams); |
2119 | |
|
2120 | 0 | this->n_outputs = save_n_outputs; |
2121 | | |
2122 | | // initialize scheduler with the specified graph |
2123 | 0 | if (split_only) { |
2124 | 0 | if (sizes) { |
2125 | 0 | ggml_backend_sched_reserve_size(sched.get(), gf, sizes); |
2126 | 0 | } else { |
2127 | 0 | ggml_backend_sched_split_graph(sched.get(), gf); |
2128 | 0 | } |
2129 | 0 | } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { |
2130 | 0 | GGML_ASSERT(!sizes); |
2131 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); |
2132 | 0 | return nullptr; |
2133 | 0 | } |
2134 | | |
2135 | 0 | return gf; |
2136 | 0 | } |
2137 | | |
// Assemble the parameter bundle that is passed to the graph builder for one ubatch.
//
// res    - graph result object the graph is going to be built into
// ubatch - the ubatch the graph is built for
// mctx   - memory context to use for the graph
// gtype  - type of the graph to build
//
// All remaining fields are taken from the context itself (model, cparams, scheduler,
// CPU backend, control vector, LoRA adapters, cross state, samplers, the current
// n_outputs and the graph callback). The initializers are positional, so their order
// has to follow the field order of llm_graph_params.
llm_graph_params llama_context::graph_params(
                        llm_graph_result * res,
                      const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
                          llm_graph_type   gtype) const {
    return {
        /*.arch        =*/ model.arch,
        /*.hparams     =*/ model.hparams,
        /*.cparams     =*/ cparams,
        /*.ubatch      =*/ ubatch,
        /*.gtype       =*/ gtype,
        /*.sched       =*/ sched.get(),
        /*.backend_cpu =*/ backend_cpu,
        /*.cvec        =*/ cvec.get(),
        /*.loras       =*/ loras.get(),
        /*.mctx        =*/ mctx,
        /*.cross       =*/ &cross,
        /*.samplers    =*/ sampling.samplers,
        /*.n_outputs   =*/ n_outputs, // current value of the context member at call time
        /*.cb          =*/ graph_get_cb(),
        /*.res         =*/ res,
    };
}
2161 | | |
2162 | | ggml_status llama_context::graph_compute( |
2163 | | ggml_cgraph * gf, |
2164 | 0 | bool batched) { |
2165 | 0 | int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; |
2166 | 0 | ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; |
2167 | |
|
2168 | 0 | if (backend_cpu != nullptr) { |
2169 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); |
2170 | 0 | auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); |
2171 | 0 | if (set_threadpool_fn) { |
2172 | 0 | set_threadpool_fn(backend_cpu, tp); |
2173 | 0 | } |
2174 | 0 | } |
2175 | | |
2176 | | // set the number of threads for all the backends |
2177 | 0 | for (const auto & set_n_threads_fn : set_n_threads_fns) { |
2178 | 0 | set_n_threads_fn.second(set_n_threads_fn.first, n_threads); |
2179 | 0 | } |
2180 | |
|
2181 | 0 | auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); |
2182 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2183 | 0 | LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); |
2184 | 0 | } |
2185 | | |
2186 | | // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); |
2187 | |
|
2188 | 0 | return status; |
2189 | 0 | } |
2190 | | |
2191 | 0 | llm_graph_cb llama_context::graph_get_cb() const { |
2192 | 0 | return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { |
2193 | 0 | if (il >= 0) { |
2194 | 0 | ggml_format_name(cur, "%s-%d", name, il); |
2195 | 0 | } else { |
2196 | 0 | ggml_set_name(cur, name); |
2197 | 0 | } |
2198 | | |
2199 | | // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends |
2200 | | // FIXME: fix in ggml_backend_sched |
2201 | 0 | const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; |
2202 | 0 | if (ubatch.n_tokens < 32 || full_offload) { |
2203 | 0 | if (il != -1 && strcmp(name, "norm") == 0) { |
2204 | 0 | const auto & dev_layer = model.dev_layer(il); |
2205 | 0 | for (const auto & backend : backends) { |
2206 | 0 | if (ggml_backend_get_device(backend.get()) == dev_layer) { |
2207 | 0 | if (ggml_backend_supports_op(backend.get(), cur)) { |
2208 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); |
2209 | 0 | } |
2210 | 0 | } |
2211 | 0 | } |
2212 | 0 | } |
2213 | 0 | } |
2214 | 0 | }; |
2215 | 0 | } |
2216 | | |
2217 | | // |
2218 | | // state save/load |
2219 | | // |
2220 | | |
2221 | | class llama_io_write_dummy : public llama_io_write_i { |
2222 | | public: |
2223 | 0 | llama_io_write_dummy() = default; |
2224 | | |
2225 | 0 | void write(const void * /* src */, size_t size) override { |
2226 | 0 | size_written += size; |
2227 | 0 | } |
2228 | | |
2229 | 0 | void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { |
2230 | 0 | size_written += size; |
2231 | 0 | } |
2232 | | |
2233 | 0 | size_t n_bytes() override { |
2234 | 0 | return size_written; |
2235 | 0 | } |
2236 | | |
2237 | | private: |
2238 | | size_t size_written = 0; |
2239 | | }; |
2240 | | |
2241 | | class llama_io_write_buffer : public llama_io_write_i { |
2242 | | public: |
2243 | | llama_io_write_buffer( |
2244 | 0 | uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
2245 | | |
2246 | 0 | void write(const void * src, size_t size) override { |
2247 | 0 | if (size > buf_size) { |
2248 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2249 | 0 | } |
2250 | 0 | memcpy(ptr, src, size); |
2251 | 0 | ptr += size; |
2252 | 0 | size_written += size; |
2253 | 0 | buf_size -= size; |
2254 | 0 | } |
2255 | | |
2256 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
2257 | 0 | if (size > buf_size) { |
2258 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2259 | 0 | } |
2260 | 0 | ggml_backend_tensor_get(tensor, ptr, offset, size); |
2261 | 0 | ptr += size; |
2262 | 0 | size_written += size; |
2263 | 0 | buf_size -= size; |
2264 | 0 | } |
2265 | | |
2266 | 0 | size_t n_bytes() override { |
2267 | 0 | return size_written; |
2268 | 0 | } |
2269 | | |
2270 | | private: |
2271 | | uint8_t * ptr; |
2272 | | size_t buf_size = 0; |
2273 | | size_t size_written = 0; |
2274 | | }; |
2275 | | |
2276 | | class llama_io_read_buffer : public llama_io_read_i { |
2277 | | public: |
2278 | 0 | llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
2279 | | |
2280 | 0 | const uint8_t * read(size_t size) override { |
2281 | 0 | const uint8_t * base_ptr = ptr; |
2282 | 0 | if (size > buf_size) { |
2283 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2284 | 0 | } |
2285 | 0 | ptr += size; |
2286 | 0 | size_read += size; |
2287 | 0 | buf_size -= size; |
2288 | 0 | return base_ptr; |
2289 | 0 | } |
2290 | | |
2291 | 0 | void read_to(void * dst, size_t size) override { |
2292 | 0 | memcpy(dst, read(size), size); |
2293 | 0 | } |
2294 | | |
2295 | 0 | size_t n_bytes() override { |
2296 | 0 | return size_read; |
2297 | 0 | } |
2298 | | |
2299 | | private: |
2300 | | const uint8_t * ptr; |
2301 | | size_t buf_size = 0; |
2302 | | size_t size_read = 0; |
2303 | | }; |
2304 | | |
2305 | | class llama_io_write_file : public llama_io_write_i { |
2306 | | public: |
2307 | 0 | llama_io_write_file(llama_file * f) : file(f) {} |
2308 | | |
2309 | 0 | void write(const void * src, size_t size) override { |
2310 | 0 | file->write_raw(src, size); |
2311 | 0 | size_written += size; |
2312 | 0 | } |
2313 | | |
2314 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
2315 | 0 | temp_buffer.resize(size); |
2316 | 0 | ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); |
2317 | 0 | write(temp_buffer.data(), temp_buffer.size()); |
2318 | 0 | } |
2319 | | |
2320 | 0 | size_t n_bytes() override { |
2321 | 0 | return size_written; |
2322 | 0 | } |
2323 | | |
2324 | | private: |
2325 | | llama_file * file; |
2326 | | size_t size_written = 0; |
2327 | | std::vector<uint8_t> temp_buffer; |
2328 | | }; |
2329 | | |
2330 | | class llama_io_read_file : public llama_io_read_i { |
2331 | | public: |
2332 | 0 | llama_io_read_file(llama_file * f) : file(f) {} |
2333 | | |
2334 | 0 | void read_to(void * dst, size_t size) override { |
2335 | 0 | file->read_raw(dst, size); |
2336 | 0 | size_read += size; |
2337 | 0 | } |
2338 | | |
2339 | 0 | const uint8_t * read(size_t size) override { |
2340 | 0 | temp_buffer.resize(size); |
2341 | 0 | read_to(temp_buffer.data(), size); |
2342 | 0 | return temp_buffer.data(); |
2343 | 0 | } |
2344 | | |
2345 | 0 | size_t n_bytes() override { |
2346 | 0 | return size_read; |
2347 | 0 | } |
2348 | | |
2349 | | private: |
2350 | | llama_file * file; |
2351 | | size_t size_read = 0; |
2352 | | std::vector<uint8_t> temp_buffer; |
2353 | | }; |
2354 | | |
2355 | 0 | size_t llama_context::state_get_size() { |
2356 | 0 | llama_io_write_dummy io; |
2357 | 0 | try { |
2358 | 0 | return state_write_data(io); |
2359 | 0 | } catch (const std::exception & err) { |
2360 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
2361 | 0 | return 0; |
2362 | 0 | } |
2363 | 0 | } |
2364 | | |
2365 | 0 | size_t llama_context::state_get_data(uint8_t * dst, size_t size) { |
2366 | 0 | llama_io_write_buffer io(dst, size); |
2367 | 0 | try { |
2368 | 0 | return state_write_data(io); |
2369 | 0 | } catch (const std::exception & err) { |
2370 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
2371 | 0 | return 0; |
2372 | 0 | } |
2373 | 0 | } |
2374 | | |
2375 | 0 | size_t llama_context::state_set_data(const uint8_t * src, size_t size) { |
2376 | 0 | llama_io_read_buffer io(src, size); |
2377 | 0 | try { |
2378 | 0 | return state_read_data(io); |
2379 | 0 | } catch (const std::exception & err) { |
2380 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
2381 | 0 | return 0; |
2382 | 0 | } |
2383 | 0 | } |
2384 | | |
2385 | 0 | size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { |
2386 | 0 | llama_io_write_dummy io; |
2387 | 0 | try { |
2388 | 0 | return state_seq_write_data(io, seq_id, flags); |
2389 | 0 | } catch (const std::exception & err) { |
2390 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
2391 | 0 | return 0; |
2392 | 0 | } |
2393 | 0 | } |
2394 | | |
2395 | 0 | size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { |
2396 | 0 | llama_io_write_buffer io(dst, size); |
2397 | 0 | try { |
2398 | 0 | return state_seq_write_data(io, seq_id, flags); |
2399 | 0 | } catch (const std::exception & err) { |
2400 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
2401 | 0 | return 0; |
2402 | 0 | } |
2403 | 0 | } |
2404 | | |
2405 | 0 | size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { |
2406 | 0 | llama_io_read_buffer io(src, size); |
2407 | 0 | try { |
2408 | 0 | return state_seq_read_data(io, seq_id, flags); |
2409 | 0 | } catch (const std::exception & err) { |
2410 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
2411 | 0 | return 0; |
2412 | 0 | } |
2413 | 0 | } |
2414 | | |
2415 | 0 | bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2416 | 0 | llama_file file(filepath, "rb"); |
2417 | | |
2418 | | // sanity checks |
2419 | 0 | { |
2420 | 0 | const uint32_t magic = file.read_u32(); |
2421 | 0 | const uint32_t version = file.read_u32(); |
2422 | |
|
2423 | 0 | if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { |
2424 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); |
2425 | 0 | return false; |
2426 | 0 | } |
2427 | 0 | } |
2428 | | |
2429 | | // load the prompt |
2430 | 0 | { |
2431 | 0 | const uint32_t n_token_count = file.read_u32(); |
2432 | |
|
2433 | 0 | if (n_token_count > n_token_capacity) { |
2434 | 0 | LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
2435 | 0 | return false; |
2436 | 0 | } |
2437 | | |
2438 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
2439 | 0 | *n_token_count_out = n_token_count; |
2440 | 0 | } |
2441 | | |
2442 | | // restore the context state |
2443 | 0 | { |
2444 | 0 | const size_t n_state_size_cur = file.size() - file.tell(); |
2445 | |
|
2446 | 0 | llama_io_read_file io( &file); |
2447 | 0 | const size_t n_read = state_read_data(io); |
2448 | |
|
2449 | 0 | if (n_read != n_state_size_cur) { |
2450 | 0 | LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); |
2451 | 0 | return false; |
2452 | 0 | } |
2453 | 0 | } |
2454 | | |
2455 | 0 | return true; |
2456 | 0 | } |
2457 | | |
2458 | 0 | bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { |
2459 | 0 | llama_file file(filepath, "wb"); |
2460 | |
|
2461 | 0 | file.write_u32(LLAMA_SESSION_MAGIC); |
2462 | 0 | file.write_u32(LLAMA_SESSION_VERSION); |
2463 | | |
2464 | | // save the prompt |
2465 | 0 | file.write_u32((uint32_t) n_token_count); |
2466 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
2467 | | |
2468 | | // save the context state using stream saving |
2469 | 0 | llama_io_write_file io(&file); |
2470 | 0 | state_write_data(io); |
2471 | |
|
2472 | 0 | return true; |
2473 | 0 | } |
2474 | | |
2475 | 0 | size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2476 | 0 | llama_file file(filepath, "rb"); |
2477 | | |
2478 | | // version checks |
2479 | 0 | { |
2480 | 0 | const uint32_t magic = file.read_u32(); |
2481 | 0 | const uint32_t version = file.read_u32(); |
2482 | |
|
2483 | 0 | if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { |
2484 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); |
2485 | 0 | return 0; |
2486 | 0 | } |
2487 | 0 | } |
2488 | | |
2489 | | // load the prompt |
2490 | 0 | { |
2491 | 0 | const uint32_t n_token_count = file.read_u32(); |
2492 | |
|
2493 | 0 | if (n_token_count > n_token_capacity) { |
2494 | 0 | LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
2495 | 0 | return 0; |
2496 | 0 | } |
2497 | | |
2498 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
2499 | 0 | *n_token_count_out = n_token_count; |
2500 | 0 | } |
2501 | | |
2502 | | // restore the context state |
2503 | 0 | { |
2504 | 0 | const size_t state_size = file.size() - file.tell(); |
2505 | 0 | llama_io_read_file io(&file); |
2506 | 0 | const size_t nread = state_seq_read_data(io, seq_id, 0); |
2507 | 0 | if (!nread) { |
2508 | 0 | LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); |
2509 | 0 | return 0; |
2510 | 0 | } |
2511 | 0 | GGML_ASSERT(nread <= state_size); |
2512 | 0 | GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); |
2513 | 0 | } |
2514 | | |
2515 | 0 | return file.tell(); |
2516 | 0 | } |
2517 | | |
2518 | 0 | size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { |
2519 | 0 | llama_file file(filepath, "wb"); |
2520 | |
|
2521 | 0 | file.write_u32(LLAMA_STATE_SEQ_MAGIC); |
2522 | 0 | file.write_u32(LLAMA_STATE_SEQ_VERSION); |
2523 | | |
2524 | | // save the prompt |
2525 | 0 | file.write_u32((uint32_t) n_token_count); |
2526 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
2527 | | |
2528 | | // save the context state using stream saving |
2529 | 0 | llama_io_write_file io(&file); |
2530 | 0 | state_seq_write_data(io, seq_id, 0); |
2531 | |
|
2532 | 0 | const size_t res = file.tell(); |
2533 | 0 | GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); |
2534 | |
|
2535 | 0 | return res; |
2536 | 0 | } |
2537 | | |
2538 | 0 | size_t llama_context::state_write_data(llama_io_write_i & io) { |
2539 | 0 | LLAMA_LOG_DEBUG("%s: writing state\n", __func__); |
2540 | | |
2541 | | // write model info |
2542 | 0 | { |
2543 | 0 | LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); |
2544 | |
|
2545 | 0 | const std::string arch_str = llm_arch_name(model.arch); |
2546 | 0 | io.write_string(arch_str); |
2547 | | // TODO: add more model-specific info which should prevent loading the session file if not identical |
2548 | 0 | } |
2549 | |
|
2550 | 0 | if (memory != nullptr) { |
2551 | 0 | LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); |
2552 | 0 | memory->state_write(io); |
2553 | 0 | } |
2554 | |
|
2555 | 0 | return io.n_bytes(); |
2556 | 0 | } |
2557 | | |
2558 | 0 | size_t llama_context::state_read_data(llama_io_read_i & io) { |
2559 | 0 | LLAMA_LOG_DEBUG("%s: reading state\n", __func__); |
2560 | | |
2561 | | // read model info |
2562 | 0 | { |
2563 | 0 | LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); |
2564 | |
|
2565 | 0 | const std::string cur_arch_str = llm_arch_name(model.arch); |
2566 | |
|
2567 | 0 | std::string arch_str; |
2568 | 0 | io.read_string(arch_str); |
2569 | 0 | if (cur_arch_str != arch_str) { |
2570 | 0 | throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); |
2571 | 0 | } |
2572 | | // TODO: add more info which needs to be identical but which is not verified otherwise |
2573 | 0 | } |
2574 | | |
2575 | 0 | if (memory) { |
2576 | 0 | LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); |
2577 | |
|
2578 | 0 | memory->state_read(io); |
2579 | 0 | } |
2580 | |
|
2581 | 0 | return io.n_bytes(); |
2582 | 0 | } |
2583 | | |
2584 | 0 | size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2585 | 0 | GGML_UNUSED(seq_id); |
2586 | |
|
2587 | 0 | if (memory) { |
2588 | 0 | memory->state_write(io, seq_id, flags); |
2589 | 0 | } |
2590 | |
|
2591 | 0 | return io.n_bytes(); |
2592 | 0 | } |
2593 | | |
2594 | 0 | size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2595 | 0 | GGML_UNUSED(seq_id); |
2596 | |
|
2597 | 0 | if (memory) { |
2598 | 0 | memory->state_read(io, seq_id, flags); |
2599 | 0 | } |
2600 | |
|
2601 | 0 | return io.n_bytes(); |
2602 | 0 | } |
2603 | | |
2604 | | // |
2605 | | // perf |
2606 | | // |
2607 | | |
2608 | 0 | llama_perf_context_data llama_context::perf_get_data() const { |
2609 | 0 | llama_perf_context_data data = {}; |
2610 | |
|
2611 | 0 | data.t_start_ms = 1e-3 * t_start_us; |
2612 | 0 | data.t_load_ms = 1e-3 * t_load_us; |
2613 | 0 | data.t_p_eval_ms = 1e-3 * t_p_eval_us; |
2614 | 0 | data.t_eval_ms = 1e-3 * t_eval_us; |
2615 | 0 | data.n_p_eval = std::max(1, n_p_eval); |
2616 | 0 | data.n_eval = std::max(1, n_eval); |
2617 | 0 | data.n_reused = std::max(0, n_reused); |
2618 | |
|
2619 | 0 | return data; |
2620 | 0 | } |
2621 | | |
2622 | 0 | void llama_context::perf_reset() { |
2623 | 0 | t_start_us = ggml_time_us(); |
2624 | 0 | t_eval_us = n_eval = 0; |
2625 | 0 | t_p_eval_us = n_p_eval = 0; |
2626 | 0 | n_reused = 0; |
2627 | 0 | } |
2628 | | |
2629 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const { |
2630 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret; |
2631 | 0 | for (const auto & [buft, size] : model.memory_breakdown()) { |
2632 | 0 | ret[buft].model += size; |
2633 | 0 | } |
2634 | 0 | if (memory) { |
2635 | 0 | for (const auto & [buft, size] : memory->memory_breakdown()) { |
2636 | 0 | ret[buft].context += size; |
2637 | 0 | } |
2638 | 0 | } |
2639 | 0 | if (model.hparams.no_alloc) { |
2640 | 0 | for (size_t i = 0; i < backends.size(); ++i) { |
2641 | 0 | ggml_backend_t backend = backends[i].get(); |
2642 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2643 | 0 | ret[buft].compute += backend_buf_exp_size[i]; |
2644 | 0 | } |
2645 | 0 | } else { |
2646 | 0 | for (const auto & backend_ptr : backends) { |
2647 | 0 | ggml_backend_t backend = backend_ptr.get(); |
2648 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2649 | 0 | ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); |
2650 | 0 | } |
2651 | 0 | } |
2652 | 0 | return ret; |
2653 | 0 | } |
2654 | | |
2655 | | // |
2656 | | // training |
2657 | | // |
2658 | | |
2659 | 0 | static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { |
2660 | 0 | if (!tensor || tensor->type != GGML_TYPE_F32) { |
2661 | 0 | return; |
2662 | 0 | } |
2663 | 0 | if (!param_filter(tensor, userdata)) { |
2664 | 0 | return; |
2665 | 0 | } |
2666 | 0 | if (strcmp(tensor->name, "token_embd.weight") == 0) { |
2667 | 0 | return; // FIXME |
2668 | 0 | } |
2669 | 0 | if (strcmp(tensor->name, "rope_freqs.weight") == 0) { |
2670 | 0 | return; // FIXME |
2671 | 0 | } |
2672 | 0 | ggml_set_param(tensor); |
2673 | 0 | } |
2674 | | |
2675 | 0 | void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { |
2676 | 0 | GGML_ASSERT(!opt_ctx); |
2677 | 0 | model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); |
2678 | 0 | const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); |
2679 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2680 | 0 | GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); |
2681 | 0 | GGML_ASSERT(n_batch % n_ubatch == 0); |
2682 | |
|
2683 | 0 | ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); |
2684 | 0 | opt_params.opt_period = n_batch / n_ubatch; |
2685 | 0 | opt_params.get_opt_pars = lopt_params.get_opt_pars; |
2686 | 0 | opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; |
2687 | 0 | opt_params.optimizer = lopt_params.optimizer_type; |
2688 | 0 | opt_ctx = ggml_opt_init(opt_params); |
2689 | |
|
2690 | 0 | llama_opt_param_filter param_filter = lopt_params.param_filter; |
2691 | 0 | void * param_filter_ud = lopt_params.param_filter_ud; |
2692 | | |
2693 | | //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME |
2694 | 0 | llama_set_param(model->type_embd, param_filter, param_filter_ud); |
2695 | 0 | llama_set_param(model->pos_embd, param_filter, param_filter_ud); |
2696 | 0 | llama_set_param(model->tok_norm, param_filter, param_filter_ud); |
2697 | 0 | llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); |
2698 | 0 | llama_set_param(model->output_norm, param_filter, param_filter_ud); |
2699 | 0 | llama_set_param(model->output_norm_b, param_filter, param_filter_ud); |
2700 | 0 | llama_set_param(model->output, param_filter, param_filter_ud); |
2701 | 0 | llama_set_param(model->output_b, param_filter, param_filter_ud); |
2702 | 0 | llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); |
2703 | 0 | llama_set_param(model->cls, param_filter, param_filter_ud); |
2704 | 0 | llama_set_param(model->cls_b, param_filter, param_filter_ud); |
2705 | 0 | llama_set_param(model->cls_out, param_filter, param_filter_ud); |
2706 | 0 | llama_set_param(model->cls_out_b, param_filter, param_filter_ud); |
2707 | 0 | llama_set_param(model->cls_norm, param_filter, param_filter_ud); |
2708 | |
|
2709 | 0 | for (struct llama_layer & layer : model->layers) { |
2710 | 0 | for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { |
2711 | 0 | llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud); |
2712 | 0 | } |
2713 | 0 | } |
2714 | 0 | } |
2715 | | |
2716 | | void llama_context::opt_epoch_iter( |
2717 | | ggml_opt_dataset_t dataset, |
2718 | | ggml_opt_result_t result, |
2719 | | const std::vector<llama_token> & tokens, |
2720 | | const std::vector<llama_token> & labels_sparse, |
2721 | | llama_batch & batch, |
2722 | | ggml_opt_epoch_callback callback, |
2723 | | bool train, |
2724 | | int64_t idata_in_loop, |
2725 | | int64_t ndata_in_loop, |
2726 | 0 | int64_t t_loop_start) { |
2727 | 0 | GGML_ASSERT(opt_ctx); |
2728 | 0 | const uint32_t n_ctx = llama_model_n_ctx_train(&model); |
2729 | 0 | const uint32_t n_batch = std::min(this->n_batch(), n_ctx); |
2730 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2731 | |
|
2732 | 0 | memory->clear(true); |
2733 | |
|
2734 | 0 | for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { |
2735 | 0 | batch.n_tokens = n_batch; |
2736 | 0 | for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { |
2737 | 0 | batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; |
2738 | 0 | batch.pos [pos_batch] = pos_ctx + pos_batch; |
2739 | 0 | batch.n_seq_id[pos_batch] = 1; |
2740 | 0 | batch.seq_id [pos_batch][0] = 0; |
2741 | 0 | batch.logits [pos_batch] = true; |
2742 | 0 | } |
2743 | |
|
2744 | 0 | if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
2745 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
2746 | 0 | return; |
2747 | 0 | } |
2748 | | |
2749 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
2750 | |
|
2751 | 0 | n_queued_tokens += n_tokens_all; |
2752 | |
|
2753 | 0 | embd_seq.clear(); |
2754 | |
|
2755 | 0 | uint32_t n_outputs_all = n_tokens_all; |
2756 | |
|
2757 | 0 | auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); |
2758 | 0 | if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { |
2759 | 0 | LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); |
2760 | 0 | break; |
2761 | 0 | } |
2762 | | |
2763 | | // reserve output buffer |
2764 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
2765 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
2766 | 0 | GGML_ABORT("TODO: handle this error"); |
2767 | 0 | }; |
2768 | |
|
2769 | 0 | uint32_t pos_batch = 0; |
2770 | 0 | do { |
2771 | 0 | const auto & ubatch = mctx->get_ubatch(); |
2772 | |
|
2773 | 0 | n_outputs = ubatch.n_tokens; |
2774 | |
|
2775 | 0 | if (!mctx->apply()) { |
2776 | 0 | LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); |
2777 | 0 | break; |
2778 | 0 | } |
2779 | | |
2780 | 0 | auto * res = gf_res_prev.get(); |
2781 | |
|
2782 | 0 | const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); |
2783 | |
|
2784 | 0 | res->reset(); |
2785 | |
|
2786 | 0 | auto * gf = model.build_graph(gparams); |
2787 | |
|
2788 | 0 | struct ggml_context * ctx_compute_opt; |
2789 | 0 | { |
2790 | 0 | const size_t size_gf = ggml_graph_size(gf); |
2791 | 0 | const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); |
2792 | 0 | struct ggml_init_params params = { |
2793 | 0 | /*.mem_size =*/ size_meta, |
2794 | 0 | /*.mem_buffer =*/ nullptr, |
2795 | 0 | /*.no_alloc =*/ true, |
2796 | 0 | }; |
2797 | 0 | ctx_compute_opt = ggml_init(params); |
2798 | 0 | } |
2799 | 0 | ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); |
2800 | 0 | ggml_opt_alloc(opt_ctx, train); |
2801 | |
|
2802 | 0 | res->set_inputs(&ubatch); |
2803 | 0 | { |
2804 | 0 | struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); |
2805 | 0 | GGML_ASSERT(labels->ne[1] == n_ubatch); |
2806 | 0 | ggml_set_zero(labels); |
2807 | 0 | const float onef = 1.0f; |
2808 | 0 | for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { |
2809 | 0 | const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; |
2810 | 0 | GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); |
2811 | 0 | ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); |
2812 | 0 | } |
2813 | 0 | } |
2814 | 0 | ggml_opt_eval(opt_ctx, result); |
2815 | 0 | if (callback) { |
2816 | 0 | callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); |
2817 | 0 | } |
2818 | 0 | ggml_free(ctx_compute_opt); |
2819 | |
|
2820 | 0 | pos_batch += ubatch.n_tokens; |
2821 | 0 | } while (mctx->next()); |
2822 | 0 | } |
2823 | 0 | } |
2824 | | |
2825 | | void llama_context::opt_epoch( |
2826 | | ggml_opt_dataset_t dataset, |
2827 | | ggml_opt_result_t result_train, |
2828 | | ggml_opt_result_t result_eval, |
2829 | | int64_t idata_split, |
2830 | | ggml_opt_epoch_callback callback_train, |
2831 | 0 | ggml_opt_epoch_callback callback_eval) { |
2832 | 0 | const uint32_t n_ctx = this->n_ctx(); |
2833 | 0 | const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); |
2834 | 0 | const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); |
2835 | 0 | const int64_t ndata = ggml_opt_dataset_ndata(dataset); |
2836 | |
|
2837 | 0 | GGML_ASSERT(idata_split >= 0); |
2838 | 0 | GGML_ASSERT(idata_split <= ndata); |
2839 | |
|
2840 | 0 | const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; |
2841 | |
|
2842 | 0 | struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |
2843 | 0 | std::vector<llama_token> tokens(n_ctx); |
2844 | 0 | std::vector<llama_token> labels_sparse(n_ctx); |
2845 | |
|
2846 | 0 | int64_t idata = 0; |
2847 | |
|
2848 | 0 | int64_t t_loop_start = ggml_time_us(); |
2849 | 0 | int64_t ndata_in_loop = idata_split*ubatch_per_ctx; |
2850 | 0 | for (; idata < idata_split; ++idata) { |
2851 | 0 | constexpr bool train = true; |
2852 | 0 | const int64_t idata_in_loop = idata*ubatch_per_ctx; |
2853 | |
|
2854 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2855 | 0 | opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, |
2856 | 0 | callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2857 | 0 | } |
2858 | |
|
2859 | 0 | t_loop_start = ggml_time_us(); |
2860 | 0 | ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; |
2861 | 0 | for (; idata < ndata; ++idata) { |
2862 | 0 | constexpr bool train = false; |
2863 | 0 | const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; |
2864 | |
|
2865 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2866 | 0 | opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, |
2867 | 0 | callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2868 | 0 | } |
2869 | |
|
2870 | 0 | llama_batch_free(batch); |
2871 | 0 | } |
2872 | | |
2873 | | // |
2874 | | // interface implementation |
2875 | | // |
2876 | | |
2877 | 0 | llama_context_params llama_context_default_params() { |
2878 | 0 | llama_context_params result = { |
2879 | 0 | /*.n_ctx =*/ 512, |
2880 | 0 | /*.n_batch =*/ 2048, |
2881 | 0 | /*.n_ubatch =*/ 512, |
2882 | 0 | /*.n_seq_max =*/ 1, |
2883 | 0 | /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default |
2884 | 0 | /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, |
2885 | 0 | /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, |
2886 | 0 | /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, |
2887 | 0 | /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, |
2888 | 0 | /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, |
2889 | 0 | /*.rope_freq_base =*/ 0.0f, |
2890 | 0 | /*.rope_freq_scale =*/ 0.0f, |
2891 | 0 | /*.yarn_ext_factor =*/ -1.0f, |
2892 | 0 | /*.yarn_attn_factor =*/ -1.0f, |
2893 | 0 | /*.yarn_beta_fast =*/ -1.0f, |
2894 | 0 | /*.yarn_beta_slow =*/ -1.0f, |
2895 | 0 | /*.yarn_orig_ctx =*/ 0, |
2896 | 0 | /*.defrag_thold =*/ -1.0f, |
2897 | 0 | /*.cb_eval =*/ nullptr, |
2898 | 0 | /*.cb_eval_user_data =*/ nullptr, |
2899 | 0 | /*.type_k =*/ GGML_TYPE_F16, |
2900 | 0 | /*.type_v =*/ GGML_TYPE_F16, |
2901 | 0 | /*.abort_callback =*/ nullptr, |
2902 | 0 | /*.abort_callback_data =*/ nullptr, |
2903 | 0 | /*.embeddings =*/ false, |
2904 | 0 | /*.offload_kqv =*/ true, |
2905 | 0 | /*.no_perf =*/ true, |
2906 | 0 | /*.op_offload =*/ true, |
2907 | 0 | /*.swa_full =*/ true, |
2908 | 0 | /*.kv_unified =*/ false, |
2909 | 0 | /*.sampler =*/ nullptr, |
2910 | 0 | /*.n_sampler =*/ 0, |
2911 | 0 | }; |
2912 | |
|
2913 | 0 | return result; |
2914 | 0 | } |
2915 | | |
2916 | | llama_context * llama_init_from_model( |
2917 | | llama_model * model, |
2918 | 0 | llama_context_params params) { |
2919 | 0 | if (!model) { |
2920 | 0 | LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); |
2921 | 0 | return nullptr; |
2922 | 0 | } |
2923 | | |
2924 | 0 | if (params.n_batch == 0 && params.n_ubatch == 0) { |
2925 | 0 | LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); |
2926 | 0 | return nullptr; |
2927 | 0 | } |
2928 | | |
2929 | 0 | if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { |
2930 | 0 | LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); |
2931 | 0 | return nullptr; |
2932 | 0 | } |
2933 | | |
2934 | 0 | if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { |
2935 | 0 | LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); |
2936 | 0 | params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; |
2937 | 0 | } |
2938 | |
|
2939 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { |
2940 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_k); |
2941 | 0 | for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { |
2942 | 0 | if (model->hparams.n_embd_head_k(il) % blck_size != 0) { |
2943 | 0 | LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", |
2944 | 0 | __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); |
2945 | 0 | return nullptr; |
2946 | 0 | } |
2947 | 0 | } |
2948 | 0 | } |
2949 | | |
2950 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { |
2951 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_v); |
2952 | 0 | for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { |
2953 | 0 | if (model->hparams.n_embd_head_v(il) % blck_size != 0) { |
2954 | 0 | LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", |
2955 | 0 | __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); |
2956 | 0 | return nullptr; |
2957 | 0 | } |
2958 | 0 | } |
2959 | 0 | } |
2960 | | |
2961 | 0 | if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { |
2962 | 0 | LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); |
2963 | 0 | return nullptr; |
2964 | 0 | } |
2965 | | |
2966 | 0 | if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && |
2967 | 0 | params.pooling_type != model->hparams.pooling_type) { |
2968 | | //user-specified pooling-type is different from the model default |
2969 | 0 | LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, |
2970 | 0 | model->hparams.pooling_type, params.pooling_type); |
2971 | 0 | } |
2972 | |
|
2973 | 0 | try { |
2974 | 0 | auto * ctx = new llama_context(*model, params); |
2975 | 0 | return ctx; |
2976 | 0 | } catch (const std::exception & err) { |
2977 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); |
2978 | 0 | } |
2979 | | |
2980 | 0 | return nullptr; |
2981 | 0 | } |
2982 | | |
2983 | | // deprecated |
2984 | | llama_context * llama_new_context_with_model( |
2985 | | llama_model * model, |
2986 | 0 | llama_context_params params) { |
2987 | 0 | return llama_init_from_model(model, params); |
2988 | 0 | } |
2989 | | |
2990 | 0 | void llama_free(llama_context * ctx) { |
2991 | 0 | delete ctx; |
2992 | 0 | } |
2993 | | |
2994 | 0 | uint32_t llama_n_ctx(const llama_context * ctx) { |
2995 | 0 | return ctx->n_ctx(); |
2996 | 0 | } |
2997 | | |
2998 | 0 | uint32_t llama_n_ctx_seq(const llama_context * ctx) { |
2999 | 0 | return ctx->n_ctx_seq(); |
3000 | 0 | } |
3001 | | |
3002 | 0 | uint32_t llama_n_batch(const llama_context * ctx) { |
3003 | 0 | return ctx->n_batch(); |
3004 | 0 | } |
3005 | | |
3006 | 0 | uint32_t llama_n_ubatch(const llama_context * ctx) { |
3007 | 0 | return ctx->n_ubatch(); |
3008 | 0 | } |
3009 | | |
3010 | 0 | uint32_t llama_n_seq_max(const llama_context * ctx) { |
3011 | 0 | return ctx->n_seq_max(); |
3012 | 0 | } |
3013 | | |
3014 | 0 | const llama_model * llama_get_model(const llama_context * ctx) { |
3015 | 0 | return &ctx->get_model(); |
3016 | 0 | } |
3017 | | |
3018 | 0 | enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { |
3019 | 0 | return ctx->pooling_type(); |
3020 | 0 | } |
3021 | | |
3022 | | void llama_attach_threadpool( |
3023 | | llama_context * ctx, |
3024 | | ggml_threadpool_t threadpool, |
3025 | 0 | ggml_threadpool_t threadpool_batch) { |
3026 | 0 | ctx->attach_threadpool(threadpool, threadpool_batch); |
3027 | 0 | } |
3028 | | |
3029 | 0 | void llama_detach_threadpool(llama_context * ctx) { |
3030 | 0 | ctx->detach_threadpool(); |
3031 | 0 | } |
3032 | | |
3033 | 0 | void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { |
3034 | 0 | ctx->set_n_threads(n_threads, n_threads_batch); |
3035 | 0 | } |
3036 | | |
3037 | 0 | int32_t llama_n_threads(llama_context * ctx) { |
3038 | 0 | return ctx->n_threads(); |
3039 | 0 | } |
3040 | | |
3041 | 0 | int32_t llama_n_threads_batch(llama_context * ctx) { |
3042 | 0 | return ctx->n_threads_batch(); |
3043 | 0 | } |
3044 | | |
3045 | 0 | void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { |
3046 | 0 | ctx->set_abort_callback(abort_callback, abort_callback_data); |
3047 | 0 | } |
3048 | | |
3049 | 0 | void llama_set_embeddings(llama_context * ctx, bool embeddings) { |
3050 | 0 | ctx->set_embeddings(embeddings); |
3051 | 0 | } |
3052 | | |
3053 | 0 | void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { |
3054 | 0 | ctx->set_causal_attn(causal_attn); |
3055 | 0 | } |
3056 | | |
3057 | 0 | void llama_set_warmup(llama_context * ctx, bool warmup) { |
3058 | 0 | ctx->set_warmup(warmup); |
3059 | 0 | } |
3060 | | |
3061 | 0 | void llama_synchronize(llama_context * ctx) { |
3062 | 0 | ctx->synchronize(); |
3063 | 0 | } |
3064 | | |
3065 | 0 | float * llama_get_logits(llama_context * ctx) { |
3066 | 0 | ctx->synchronize(); |
3067 | |
|
3068 | 0 | return ctx->get_logits(); |
3069 | 0 | } |
3070 | | |
3071 | 0 | float * llama_get_logits_ith(llama_context * ctx, int32_t i) { |
3072 | 0 | ctx->synchronize(); |
3073 | |
|
3074 | 0 | float * res = nullptr; |
3075 | |
|
3076 | 0 | res = ctx->get_sampled_logits_ith(i); |
3077 | |
|
3078 | 0 | if (!res) { |
3079 | 0 | res = ctx->get_logits_ith(i); |
3080 | 0 | } |
3081 | |
|
3082 | 0 | return res; |
3083 | 0 | } |
3084 | | |
3085 | 0 | float * llama_get_embeddings(llama_context * ctx) { |
3086 | 0 | ctx->synchronize(); |
3087 | |
|
3088 | 0 | return ctx->get_embeddings(); |
3089 | 0 | } |
3090 | | |
3091 | 0 | float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { |
3092 | 0 | ctx->synchronize(); |
3093 | |
|
3094 | 0 | return ctx->get_embeddings_ith(i); |
3095 | 0 | } |
3096 | | |
3097 | 0 | float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { |
3098 | 0 | ctx->synchronize(); |
3099 | |
|
3100 | 0 | return ctx->get_embeddings_seq(seq_id); |
3101 | 0 | } |
3102 | | |
3103 | 0 | bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { |
3104 | 0 | return ctx->set_sampler(seq_id, smpl); |
3105 | 0 | } |
3106 | | |
3107 | 0 | llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { |
3108 | 0 | ctx->synchronize(); |
3109 | |
|
3110 | 0 | return ctx->get_sampled_token_ith(i); |
3111 | 0 | } |
3112 | | |
3113 | 0 | float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { |
3114 | 0 | ctx->synchronize(); |
3115 | |
|
3116 | 0 | return ctx->get_sampled_probs_ith(i); |
3117 | 0 | } |
3118 | | |
3119 | 0 | float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { |
3120 | 0 | ctx->synchronize(); |
3121 | |
|
3122 | 0 | return ctx->get_sampled_logits_ith(i); |
3123 | 0 | } |
3124 | | |
3125 | 0 | llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { |
3126 | 0 | ctx->synchronize(); |
3127 | |
|
3128 | 0 | return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i)); |
3129 | 0 | } |
3130 | | |
3131 | 0 | uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { |
3132 | 0 | ctx->synchronize(); |
3133 | |
|
3134 | 0 | return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i)); |
3135 | 0 | } |
3136 | | |
3137 | 0 | uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { |
3138 | 0 | ctx->synchronize(); |
3139 | |
|
3140 | 0 | return static_cast<uint32_t>(ctx->get_sampled_logits_count(i)); |
3141 | 0 | } |
3142 | | |
3143 | 0 | uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { |
3144 | 0 | ctx->synchronize(); |
3145 | |
|
3146 | 0 | return static_cast<uint32_t>(ctx->get_sampled_probs_count(i)); |
3147 | 0 | } |
3148 | | |
3149 | | struct ggml_cgraph * llama_graph_reserve( |
3150 | | struct llama_context * ctx, |
3151 | | uint32_t n_tokens, |
3152 | | uint32_t n_seqs, |
3153 | 0 | uint32_t n_outputs) { |
3154 | 0 | auto * memory = ctx->get_memory(); |
3155 | 0 | llama_memory_context_ptr mctx; |
3156 | 0 | if (memory) { |
3157 | 0 | mctx = memory->init_full(); |
3158 | 0 | } |
3159 | 0 | return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get()); |
3160 | 0 | } |
3161 | | |
3162 | | // llama adapter API |
3163 | | |
3164 | | int32_t llama_set_adapters_lora( |
3165 | | llama_context * ctx, |
3166 | | llama_adapter_lora ** adapters, |
3167 | | size_t n_adapters, |
3168 | 0 | float * scales) { |
3169 | 0 | if (adapters == nullptr || scales == nullptr) { |
3170 | 0 | GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); |
3171 | 0 | } |
3172 | |
|
3173 | 0 | ctx->set_adapters_lora(adapters, n_adapters, scales); |
3174 | |
|
3175 | 0 | return 0; |
3176 | 0 | } |
3177 | | |
3178 | | int32_t llama_set_adapter_cvec( |
3179 | | llama_context * ctx, |
3180 | | const float * data, |
3181 | | size_t len, |
3182 | | int32_t n_embd, |
3183 | | int32_t il_start, |
3184 | 0 | int32_t il_end) { |
3185 | 0 | bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); |
3186 | |
|
3187 | 0 | return res ? 0 : -1; |
3188 | 0 | } |
3189 | | |
3190 | | // |
3191 | | // memory |
3192 | | // |
3193 | | |
3194 | 0 | llama_memory_t llama_get_memory(const struct llama_context * ctx) { |
3195 | 0 | return ctx->get_memory(); |
3196 | 0 | } |
3197 | | |
3198 | 0 | void llama_memory_clear(llama_memory_t mem, bool data) { |
3199 | 0 | if (!mem) { |
3200 | 0 | return; |
3201 | 0 | } |
3202 | | |
3203 | 0 | mem->clear(data); |
3204 | 0 | } |
3205 | | |
3206 | | bool llama_memory_seq_rm( |
3207 | | llama_memory_t mem, |
3208 | | llama_seq_id seq_id, |
3209 | | llama_pos p0, |
3210 | 0 | llama_pos p1) { |
3211 | 0 | if (!mem) { |
3212 | 0 | return true; |
3213 | 0 | } |
3214 | | |
3215 | 0 | return mem->seq_rm(seq_id, p0, p1); |
3216 | 0 | } |
3217 | | |
3218 | | void llama_memory_seq_cp( |
3219 | | llama_memory_t mem, |
3220 | | llama_seq_id seq_id_src, |
3221 | | llama_seq_id seq_id_dst, |
3222 | | llama_pos p0, |
3223 | 0 | llama_pos p1) { |
3224 | 0 | if (!mem) { |
3225 | 0 | return; |
3226 | 0 | } |
3227 | | |
3228 | 0 | mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
3229 | 0 | } |
3230 | | |
3231 | | void llama_memory_seq_keep( |
3232 | | llama_memory_t mem, |
3233 | 0 | llama_seq_id seq_id) { |
3234 | 0 | if (!mem) { |
3235 | 0 | return; |
3236 | 0 | } |
3237 | | |
3238 | 0 | mem->seq_keep(seq_id); |
3239 | 0 | } |
3240 | | |
3241 | | void llama_memory_seq_add( |
3242 | | llama_memory_t mem, |
3243 | | llama_seq_id seq_id, |
3244 | | llama_pos p0, |
3245 | | llama_pos p1, |
3246 | 0 | llama_pos delta) { |
3247 | 0 | if (!mem) { |
3248 | 0 | return; |
3249 | 0 | } |
3250 | | |
3251 | 0 | mem->seq_add(seq_id, p0, p1, delta); |
3252 | 0 | } |
3253 | | |
3254 | | void llama_memory_seq_div( |
3255 | | llama_memory_t mem, |
3256 | | llama_seq_id seq_id, |
3257 | | llama_pos p0, |
3258 | | llama_pos p1, |
3259 | 0 | int d) { |
3260 | 0 | if (!mem) { |
3261 | 0 | return; |
3262 | 0 | } |
3263 | | |
3264 | 0 | mem->seq_div(seq_id, p0, p1, d); |
3265 | 0 | } |
3266 | | |
3267 | | llama_pos llama_memory_seq_pos_min( |
3268 | | llama_memory_t mem, |
3269 | 0 | llama_seq_id seq_id) { |
3270 | 0 | if (!mem) { |
3271 | 0 | return -1; |
3272 | 0 | } |
3273 | | |
3274 | 0 | return mem->seq_pos_min(seq_id); |
3275 | 0 | } |
3276 | | |
3277 | | llama_pos llama_memory_seq_pos_max( |
3278 | | llama_memory_t mem, |
3279 | 0 | llama_seq_id seq_id) { |
3280 | 0 | if (!mem) { |
3281 | 0 | return -1; |
3282 | 0 | } |
3283 | | |
3284 | 0 | return mem->seq_pos_max(seq_id); |
3285 | 0 | } |
3286 | | |
3287 | 0 | bool llama_memory_can_shift(llama_memory_t mem) { |
3288 | 0 | if (!mem) { |
3289 | 0 | return false; |
3290 | 0 | } |
3291 | | |
3292 | 0 | return mem->get_can_shift(); |
3293 | 0 | } |
3294 | | |
3295 | | // llama state API |
3296 | | |
3297 | | // deprecated |
3298 | 0 | size_t llama_get_state_size(llama_context * ctx) { |
3299 | 0 | return llama_state_get_size(ctx); |
3300 | 0 | } |
3301 | | |
3302 | | // deprecated |
3303 | 0 | size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { |
3304 | 0 | return llama_state_get_data(ctx, dst, -1); |
3305 | 0 | } |
3306 | | |
3307 | | // deprecated |
3308 | 0 | size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { |
3309 | 0 | return llama_state_set_data(ctx, src, -1); |
3310 | 0 | } |
3311 | | |
3312 | | // deprecated |
3313 | 0 | bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3314 | 0 | return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); |
3315 | 0 | } |
3316 | | |
3317 | | // deprecated |
3318 | 0 | bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
3319 | 0 | return llama_state_save_file(ctx, path_session, tokens, n_token_count); |
3320 | 0 | } |
3321 | | |
3322 | | // Returns the *actual* size of the state. |
3323 | | // Intended to be used when saving to state to a buffer. |
3324 | 0 | size_t llama_state_get_size(llama_context * ctx) { |
3325 | 0 | return ctx->state_get_size(); |
3326 | 0 | } |
3327 | | |
3328 | 0 | size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { |
3329 | 0 | ctx->synchronize(); |
3330 | |
|
3331 | 0 | return ctx->state_get_data(dst, size); |
3332 | 0 | } |
3333 | | |
3334 | | // Sets the state reading from the specified source address |
3335 | 0 | size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { |
3336 | 0 | ctx->synchronize(); |
3337 | |
|
3338 | 0 | return ctx->state_set_data(src, size); |
3339 | 0 | } |
3340 | | |
3341 | 0 | bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3342 | 0 | ctx->synchronize(); |
3343 | |
|
3344 | 0 | try { |
3345 | 0 | return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); |
3346 | 0 | } catch (const std::exception & err) { |
3347 | 0 | LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); |
3348 | 0 | return false; |
3349 | 0 | } |
3350 | 0 | } |
3351 | | |
3352 | 0 | bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
3353 | 0 | ctx->synchronize(); |
3354 | |
|
3355 | 0 | try { |
3356 | 0 | return ctx->state_save_file(path_session, tokens, n_token_count); |
3357 | 0 | } catch (const std::exception & err) { |
3358 | 0 | LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); |
3359 | 0 | return false; |
3360 | 0 | } |
3361 | 0 | } |
3362 | | |
3363 | 0 | size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { |
3364 | 0 | return llama_state_seq_get_size_ext(ctx, seq_id, 0); |
3365 | 0 | } |
3366 | | |
3367 | 0 | size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { |
3368 | 0 | return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); |
3369 | 0 | } |
3370 | | |
3371 | 0 | size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { |
3372 | 0 | return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); |
3373 | 0 | } |
3374 | | |
3375 | 0 | size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3376 | 0 | return ctx->state_seq_get_size(seq_id, flags); |
3377 | 0 | } |
3378 | | |
3379 | 0 | size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3380 | 0 | ctx->synchronize(); |
3381 | |
|
3382 | 0 | return ctx->state_seq_get_data(seq_id, dst, size, flags); |
3383 | 0 | } |
3384 | | |
3385 | 0 | size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3386 | 0 | ctx->synchronize(); |
3387 | |
|
3388 | 0 | return ctx->state_seq_set_data(seq_id, src, size, flags); |
3389 | 0 | } |
3390 | | |
3391 | 0 | size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { |
3392 | 0 | ctx->synchronize(); |
3393 | |
|
3394 | 0 | try { |
3395 | 0 | return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); |
3396 | 0 | } catch (const std::exception & err) { |
3397 | 0 | LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); |
3398 | 0 | return 0; |
3399 | 0 | } |
3400 | 0 | } |
3401 | | |
3402 | 0 | size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3403 | 0 | ctx->synchronize(); |
3404 | |
|
3405 | 0 | try { |
3406 | 0 | return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); |
3407 | 0 | } catch (const std::exception & err) { |
3408 | 0 | LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); |
3409 | 0 | return 0; |
3410 | 0 | } |
3411 | 0 | } |
3412 | | |
3413 | | /// |
3414 | | |
3415 | | int32_t llama_encode( |
3416 | | llama_context * ctx, |
3417 | 0 | llama_batch batch) { |
3418 | 0 | const int ret = ctx->encode(batch); |
3419 | 0 | if (ret != 0) { |
3420 | 0 | LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |
3421 | 0 | } |
3422 | |
|
3423 | 0 | return ret; |
3424 | 0 | } |
3425 | | |
3426 | | int32_t llama_decode( |
3427 | | llama_context * ctx, |
3428 | 0 | llama_batch batch) { |
3429 | 0 | const int ret = ctx->decode(batch); |
3430 | 0 | if (ret != 0 && ret != 1) { |
3431 | 0 | LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |
3432 | 0 | } |
3433 | |
|
3434 | 0 | return ret; |
3435 | 0 | } |
3436 | | |
3437 | | // |
3438 | | // perf |
3439 | | // |
3440 | | |
3441 | 0 | llama_perf_context_data llama_perf_context(const llama_context * ctx) { |
3442 | 0 | llama_perf_context_data data = {}; |
3443 | |
|
3444 | 0 | if (ctx == nullptr) { |
3445 | 0 | return data; |
3446 | 0 | } |
3447 | | |
3448 | 0 | data = ctx->perf_get_data(); |
3449 | |
|
3450 | 0 | return data; |
3451 | 0 | } |
3452 | | |
3453 | 0 | void llama_perf_context_print(const llama_context * ctx) { |
3454 | 0 | const auto data = llama_perf_context(ctx); |
3455 | |
|
3456 | 0 | const double t_end_ms = 1e-3 * ggml_time_us(); |
3457 | |
|
3458 | 0 | LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); |
3459 | 0 | LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", |
3460 | 0 | __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); |
3461 | 0 | LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", |
3462 | 0 | __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); |
3463 | 0 | LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); |
3464 | 0 | LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); |
3465 | 0 | } |
3466 | | |
3467 | 0 | void llama_perf_context_reset(llama_context * ctx) { |
3468 | 0 | ctx->perf_reset(); |
3469 | 0 | } |
3470 | | |
3471 | 0 | void llama_memory_breakdown_print(const struct llama_context * ctx) { |
3472 | 0 | const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices; |
3473 | |
|
3474 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); |
3475 | |
|
3476 | 0 | std::vector<std::array<std::string, 9>> table_data; |
3477 | 0 | table_data.reserve(devices.size()); |
3478 | 0 | const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; |
3479 | 0 | const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; |
3480 | 0 | const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; |
3481 | |
|
3482 | 0 | table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); |
3483 | |
|
3484 | 0 | constexpr size_t MiB = 1024 * 1024; |
3485 | 0 | const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; |
3486 | | |
3487 | | // track seen buffer types to avoid double counting: |
3488 | 0 | std::set<ggml_backend_buffer_type_t> seen_buffer_types; |
3489 | | |
3490 | | // accumulative memory breakdown for each device and for host: |
3491 | 0 | std::vector<llama_memory_breakdown_data> mb_dev(devices.size()); |
3492 | 0 | llama_memory_breakdown_data mb_host; |
3493 | |
|
3494 | 0 | for (const auto & buft_mb : memory_breakdown) { |
3495 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
3496 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
3497 | 0 | if (ggml_backend_buft_is_host(buft)) { |
3498 | 0 | mb_host.model += mb.model; |
3499 | 0 | mb_host.context += mb.context; |
3500 | 0 | mb_host.compute += mb.compute; |
3501 | 0 | seen_buffer_types.insert(buft); |
3502 | 0 | continue; |
3503 | 0 | } |
3504 | 0 | ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); |
3505 | 0 | if (dev) { |
3506 | 0 | int i_dev = -1; |
3507 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
3508 | 0 | if (devices[i] == dev) { |
3509 | 0 | i_dev = i; |
3510 | 0 | break; |
3511 | 0 | } |
3512 | 0 | } |
3513 | 0 | if (i_dev != -1) { |
3514 | 0 | mb_dev[i_dev].model += mb.model; |
3515 | 0 | mb_dev[i_dev].context += mb.context; |
3516 | 0 | mb_dev[i_dev].compute += mb.compute; |
3517 | 0 | seen_buffer_types.insert(buft); |
3518 | 0 | continue; |
3519 | 0 | } |
3520 | 0 | } |
3521 | 0 | } |
3522 | | |
3523 | | // print memory breakdown for each device: |
3524 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
3525 | 0 | ggml_backend_dev_t dev = devices[i]; |
3526 | 0 | llama_memory_breakdown_data mb = mb_dev[i]; |
3527 | |
|
3528 | 0 | const std::string name = ggml_backend_dev_name(dev); |
3529 | 0 | std::string desc = ggml_backend_dev_description(dev); |
3530 | 0 | for (const std::string & prefix : desc_prefixes_strip) { |
3531 | 0 | if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { |
3532 | 0 | desc = desc.substr(prefix.length()); |
3533 | 0 | } |
3534 | 0 | } |
3535 | |
|
3536 | 0 | size_t free, total; |
3537 | 0 | ggml_backend_dev_memory(dev, &free, &total); |
3538 | |
|
3539 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
3540 | 0 | const size_t unaccounted = total - self - free; |
3541 | |
|
3542 | 0 | table_data.push_back({ |
3543 | 0 | template_gpu, |
3544 | 0 | " - " + name + " (" + desc + ")", |
3545 | 0 | std::to_string(total / MiB), |
3546 | 0 | std::to_string(free / MiB), |
3547 | 0 | std::to_string(self / MiB), |
3548 | 0 | std::to_string(mb.model / MiB), |
3549 | 0 | std::to_string(mb.context / MiB), |
3550 | 0 | std::to_string(mb.compute / MiB), |
3551 | 0 | std::to_string(unaccounted / MiB)}); |
3552 | 0 | } |
3553 | | |
3554 | | // print memory breakdown for host: |
3555 | 0 | { |
3556 | 0 | const size_t self = mb_host.model + mb_host.context + mb_host.compute; |
3557 | 0 | table_data.push_back({ |
3558 | 0 | template_other, |
3559 | 0 | " - Host", |
3560 | 0 | "", // total |
3561 | 0 | "", // free |
3562 | 0 | std::to_string(self / MiB), |
3563 | 0 | std::to_string(mb_host.model / MiB), |
3564 | 0 | std::to_string(mb_host.context / MiB), |
3565 | 0 | std::to_string(mb_host.compute / MiB), |
3566 | 0 | ""}); // unaccounted |
3567 | 0 | } |
3568 | | |
3569 | | // print memory breakdown for all remaining buffer types: |
3570 | 0 | for (const auto & buft_mb : memory_breakdown) { |
3571 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
3572 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
3573 | 0 | if (seen_buffer_types.count(buft) == 1) { |
3574 | 0 | continue; |
3575 | 0 | } |
3576 | 0 | const std::string name = ggml_backend_buft_name(buft); |
3577 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
3578 | 0 | table_data.push_back({ |
3579 | 0 | template_other, |
3580 | 0 | " - " + name, |
3581 | 0 | "", // total |
3582 | 0 | "", // free |
3583 | 0 | std::to_string(self / MiB), |
3584 | 0 | std::to_string(mb.model / MiB), |
3585 | 0 | std::to_string(mb.context / MiB), |
3586 | 0 | std::to_string(mb.compute / MiB), |
3587 | 0 | ""}); // unaccounted |
3588 | 0 | seen_buffer_types.insert(buft); |
3589 | 0 | } |
3590 | |
|
3591 | 0 | for (size_t j = 1; j < table_data[0].size(); j++) { |
3592 | 0 | size_t max_len = 0; |
3593 | 0 | for (const auto & td : table_data) { |
3594 | 0 | max_len = std::max(max_len, td[j].length()); |
3595 | 0 | } |
3596 | 0 | for (auto & td : table_data) { |
3597 | 0 | td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); |
3598 | 0 | } |
3599 | 0 | } |
3600 | 0 | for (const auto & td : table_data) { |
3601 | 0 | LLAMA_LOG_INFO(td[0].c_str(), |
3602 | 0 | __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), |
3603 | 0 | td[6].c_str(), td[7].c_str(), td[8].c_str()); |
3604 | 0 | } |
3605 | 0 | } |
3606 | | |
3607 | | // |
3608 | | // training |
3609 | | // |
3610 | | |
3611 | 0 | bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { |
3612 | 0 | GGML_UNUSED(tensor); |
3613 | 0 | GGML_UNUSED(userdata); |
3614 | 0 | return true; |
3615 | 0 | } |
3616 | | |
3617 | 0 | void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { |
3618 | 0 | ctx->opt_init(model, lopt_params); |
3619 | 0 | } |
3620 | | |
3621 | | void llama_opt_epoch( |
3622 | | struct llama_context * ctx, |
3623 | | ggml_opt_dataset_t dataset, |
3624 | | ggml_opt_result_t result_train, |
3625 | | ggml_opt_result_t result_eval, |
3626 | | int64_t idata_split, |
3627 | | ggml_opt_epoch_callback callback_train, |
3628 | 0 | ggml_opt_epoch_callback callback_eval) { |
3629 | 0 | ctx->opt_epoch( |
3630 | 0 | dataset, |
3631 | 0 | result_train, |
3632 | 0 | result_eval, |
3633 | 0 | idata_split, |
3634 | 0 | callback_train, |
3635 | 0 | callback_eval); |
3636 | 0 | } |