/src/llama.cpp/src/llama-context.cpp
Line | Count | Source |
1 | | #include "llama-context.h" |
2 | | |
3 | | #include "llama-arch.h" |
4 | | #include "llama-impl.h" |
5 | | #include "llama-batch.h" |
6 | | #include "llama-io.h" |
7 | | #include "llama-memory.h" |
8 | | #include "llama-mmap.h" |
9 | | #include "llama-model.h" |
10 | | |
11 | | #include <cinttypes> |
12 | | #include <cmath> |
13 | | #include <cstring> |
14 | | #include <limits> |
15 | | #include <stdexcept> |
16 | | |
17 | | // |
18 | | // llama_context |
19 | | // |
20 | | |
21 | | llama_context::llama_context( |
22 | | const llama_model & model, |
23 | | llama_context_params params) : |
24 | 0 | model(model), |
25 | 0 | cvec(std::make_unique<llama_adapter_cvec>()), |
26 | 0 | loras(std::make_unique<llama_adapter_loras>()), |
27 | 0 | balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) { |
28 | | // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, |
29 | | // may need to be backend-dependent |
30 | 0 | LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); |
31 | |
|
32 | 0 | t_start_us = model.t_start_us; |
33 | 0 | t_load_us = model.t_load_us; |
34 | |
|
35 | 0 | const auto & hparams = model.hparams; |
36 | |
|
37 | 0 | cparams.n_seq_max = std::max(1u, params.n_seq_max); |
38 | 0 | if (cparams.n_seq_max > LLAMA_MAX_SEQ) { |
39 | 0 | throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); |
40 | 0 | } |
41 | | |
42 | 0 | cparams.n_threads = params.n_threads; |
43 | 0 | cparams.n_threads_batch = params.n_threads_batch; |
44 | 0 | cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; |
45 | 0 | cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; |
46 | 0 | cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; |
47 | 0 | cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; |
48 | 0 | cparams.embeddings = params.embeddings; |
49 | 0 | cparams.offload_kqv = params.offload_kqv; |
50 | 0 | cparams.no_perf = params.no_perf; |
51 | 0 | cparams.pooling_type = params.pooling_type; |
52 | 0 | cparams.warmup = false; |
53 | |
|
54 | 0 | cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; |
55 | 0 | cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; |
56 | 0 | cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; |
57 | |
|
58 | 0 | cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : |
59 | 0 | hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : |
60 | 0 | hparams.n_ctx_train; |
61 | |
|
62 | 0 | cparams.cb_eval = params.cb_eval; |
63 | 0 | cparams.cb_eval_user_data = params.cb_eval_user_data; |
64 | | |
65 | | // Initialize backend samplers here so they are part of the sampling graph |
66 | | // before the reserve passes run later in this function. This avoids a later |
67 | | // re-reserve when graph nodes change. |
68 | 0 | if (params.samplers != nullptr && params.n_samplers > 0) { |
69 | 0 | for (size_t i = 0; i < params.n_samplers; ++i) { |
70 | 0 | const auto & config = params.samplers[i]; |
71 | |
|
72 | 0 | if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { |
73 | 0 | throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); |
74 | 0 | } |
75 | | |
76 | 0 | if (set_sampler(config.seq_id, config.sampler)) { |
77 | 0 | const int n_samplers = llama_sampler_chain_n(config.sampler); |
78 | |
|
79 | 0 | LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); |
80 | 0 | } |
81 | 0 | } |
82 | 0 | } |
83 | | |
84 | 0 | auto rope_scaling_type = params.rope_scaling_type; |
85 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { |
86 | 0 | rope_scaling_type = hparams.rope_scaling_type_train; |
87 | 0 | } |
88 | |
|
89 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { |
90 | 0 | cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none |
91 | 0 | } |
92 | |
|
93 | 0 | if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' |
94 | 0 | cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; |
95 | 0 | } |
96 | |
|
97 | 0 | if (cparams.yarn_ext_factor != 0) { |
98 | 0 | static auto get_mscale = [](float scale, float mscale) { |
99 | 0 | return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); |
100 | 0 | }; |
101 | |
|
102 | 0 | const float factor = 1.0f / cparams.rope_freq_scale; |
103 | | |
104 | | // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348 |
105 | 0 | if (hparams.rope_yarn_log_mul != 0.0f) { |
106 | | // note: here we assume `mscale == 1.0f` |
107 | | // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f |
108 | 0 | float mscale = 1.0f; |
109 | 0 | const float mscale_all_dims = hparams.rope_yarn_log_mul; |
110 | | |
111 | | // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] |
112 | | // special-case DEEPSEEK v2: |
113 | | // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43 |
114 | 0 | if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) { |
115 | 0 | mscale = mscale_all_dims; |
116 | 0 | } |
117 | |
|
118 | 0 | cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); |
119 | |
|
120 | 0 | LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n", |
121 | 0 | __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims); |
122 | 0 | } else { |
123 | 0 | cparams.yarn_attn_factor = get_mscale(factor, 1.0f); |
124 | 0 | } |
125 | | |
126 | | // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor: |
127 | | // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544 |
128 | | // |
129 | | // ref: https://github.com/ggml-org/llama.cpp/discussions/7416 |
130 | | // https://github.com/ggml-org/llama.cpp/pull/17945 |
131 | 0 | cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor)); |
132 | 0 | } |
133 | |
|
134 | 0 | cparams.yarn_attn_factor *= hparams.rope_attn_factor; |
135 | |
|
136 | 0 | if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
137 | 0 | if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
138 | 0 | cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; |
139 | 0 | } else { |
140 | 0 | cparams.pooling_type = hparams.pooling_type; |
141 | 0 | } |
142 | 0 | } |
143 | |
|
144 | 0 | if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { |
145 | 0 | cparams.causal_attn = hparams.causal_attn; |
146 | 0 | } else { |
147 | 0 | cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; |
148 | 0 | } |
149 | |
|
150 | 0 | cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; |
151 | 0 | cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; |
152 | | |
153 | | // with causal attention, the batch size is limited by the context size |
154 | 0 | cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; |
155 | |
|
156 | 0 | cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); |
157 | |
|
158 | 0 | cparams.op_offload = params.op_offload; |
159 | 0 | cparams.kv_unified = params.kv_unified; |
160 | | |
161 | | // intialized later |
162 | 0 | cparams.pipeline_parallel = false; |
163 | |
|
164 | 0 | { |
165 | 0 | const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); |
166 | 0 | graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; |
167 | |
|
168 | 0 | if (graph_reuse_disable) { |
169 | 0 | LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); |
170 | 0 | } |
171 | 0 | } |
172 | | |
173 | | // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 |
174 | 0 | cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); |
175 | |
|
176 | 0 | if (cparams.kv_unified) { |
177 | 0 | cparams.n_ctx_seq = cparams.n_ctx; |
178 | 0 | } else { |
179 | 0 | cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; |
180 | 0 | cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); |
181 | |
|
182 | 0 | if (cparams.n_ctx_seq == 0) { |
183 | 0 | throw std::runtime_error("n_ctx_seq == 0"); |
184 | 0 | } |
185 | | |
186 | 0 | if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) { |
187 | 0 | cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max; |
188 | 0 | LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx); |
189 | 0 | } |
190 | 0 | } |
191 | | |
192 | 0 | LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); |
193 | 0 | LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); |
194 | 0 | LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq); |
195 | 0 | LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); |
196 | 0 | LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); |
197 | 0 | LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); |
198 | 0 | LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); |
199 | 0 | LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); |
200 | 0 | LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); |
201 | 0 | LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |
202 | |
|
203 | 0 | if (cparams.n_ctx_seq < hparams.n_ctx_train) { |
204 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", |
205 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
206 | 0 | } |
207 | |
|
208 | 0 | if (cparams.n_ctx_seq > hparams.n_ctx_train) { |
209 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", |
210 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
211 | 0 | } |
212 | |
|
213 | 0 | if (!hparams.vocab_only) { |
214 | | // GPU backends |
215 | 0 | for (auto * dev : model.devices) { |
216 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
217 | 0 | if (backend == nullptr) { |
218 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
219 | 0 | } |
220 | 0 | backends.emplace_back(backend); |
221 | 0 | } |
222 | | |
223 | | // add ACCEL backends (such as BLAS) |
224 | 0 | for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
225 | 0 | ggml_backend_dev_t dev = ggml_backend_dev_get(i); |
226 | 0 | if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { |
227 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
228 | 0 | if (backend == nullptr) { |
229 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
230 | 0 | } |
231 | 0 | backends.emplace_back(backend); |
232 | 0 | } |
233 | 0 | } |
234 | | |
235 | | // add CPU backend |
236 | 0 | backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); |
237 | 0 | if (backend_cpu == nullptr) { |
238 | 0 | throw std::runtime_error("failed to initialize CPU backend"); |
239 | 0 | } |
240 | 0 | backends.emplace_back(backend_cpu); |
241 | | |
242 | | // create a list of the set_n_threads functions in the backends |
243 | 0 | for (auto & backend : backends) { |
244 | 0 | ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); |
245 | 0 | ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; |
246 | 0 | if (reg) { |
247 | 0 | auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); |
248 | 0 | if (ggml_backend_set_n_threads_fn) { |
249 | 0 | set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); |
250 | 0 | } |
251 | 0 | } |
252 | 0 | } |
253 | |
|
254 | 0 | llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); |
255 | | |
256 | | // graph outputs buffer |
257 | 0 | { |
258 | 0 | if (output_reserve(params.n_seq_max) < params.n_seq_max) { |
259 | 0 | throw std::runtime_error("failed to reserve initial output buffer"); |
260 | 0 | } |
261 | | |
262 | 0 | LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, |
263 | 0 | ggml_backend_buffer_name (buf_output.get()), |
264 | 0 | ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); |
265 | 0 | } |
266 | 0 | } |
267 | | |
268 | | // init the memory module |
269 | 0 | if (!hparams.vocab_only) { |
270 | 0 | llama_memory_params params_mem = { |
271 | 0 | /*.type_k =*/ params.type_k, |
272 | 0 | /*.type_v =*/ params.type_v, |
273 | 0 | /*.swa_full =*/ params.swa_full, |
274 | 0 | }; |
275 | |
|
276 | 0 | memory.reset(model.create_memory(params_mem, cparams)); |
277 | 0 | } |
278 | | |
279 | | // init backends |
280 | 0 | if (!hparams.vocab_only) { |
281 | 0 | LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); |
282 | |
|
283 | 0 | backend_buft.clear(); |
284 | 0 | backend_ptrs.clear(); |
285 | 0 | backend_buf_exp_size.clear(); |
286 | |
|
287 | 0 | for (auto & backend : backends) { |
288 | 0 | auto * buft = ggml_backend_get_default_buffer_type(backend.get()); |
289 | 0 | auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
290 | |
|
291 | 0 | if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { |
292 | | // use the host buffer of the first device CPU for faster transfer of the intermediate state |
293 | 0 | auto * dev = model.devices[0]; |
294 | 0 | auto * host_buft = ggml_backend_dev_host_buffer_type(dev); |
295 | 0 | if (host_buft) { |
296 | 0 | buft = host_buft; |
297 | 0 | } |
298 | 0 | } |
299 | |
|
300 | 0 | backend_buft.push_back(buft); |
301 | 0 | backend_ptrs.push_back(backend.get()); |
302 | 0 | backend_buf_exp_size.push_back(0); |
303 | 0 | } |
304 | |
|
305 | 0 | LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); |
306 | | |
307 | | // TODO: move these checks to ggml_backend_sched |
308 | | // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary |
309 | 0 | bool pipeline_parallel = |
310 | 0 | model.n_devices() > 1 && |
311 | 0 | model.n_gpu_layers() > model.hparams.n_layer && |
312 | 0 | model.split_mode() == LLAMA_SPLIT_MODE_LAYER && |
313 | 0 | cparams.offload_kqv && |
314 | 0 | !model.has_tensor_overrides(); |
315 | | |
316 | | // pipeline parallelism requires support for async compute and events in all devices |
317 | 0 | if (pipeline_parallel) { |
318 | 0 | for (auto & backend : backends) { |
319 | 0 | auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
320 | 0 | if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { |
321 | | // ignore CPU backend |
322 | | // TODO: should we ignore ACCEL types too? |
323 | 0 | continue; |
324 | 0 | } |
325 | 0 | auto * dev = ggml_backend_get_device(backend.get()); |
326 | 0 | ggml_backend_dev_props props; |
327 | 0 | ggml_backend_dev_get_props(dev, &props); |
328 | 0 | if (!props.caps.async || !props.caps.events) { |
329 | | // device does not support async compute or events |
330 | 0 | pipeline_parallel = false; |
331 | 0 | break; |
332 | 0 | } |
333 | 0 | } |
334 | 0 | } |
335 | |
|
336 | 0 | cparams.pipeline_parallel = pipeline_parallel; |
337 | |
|
338 | 0 | if (cparams.pipeline_parallel) { |
339 | 0 | LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); |
340 | 0 | } |
341 | |
|
342 | 0 | sched_reserve(); |
343 | |
|
344 | 0 | if (!cparams.flash_attn) { |
345 | 0 | if (ggml_is_quantized(params.type_v)) { |
346 | 0 | throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); |
347 | 0 | } |
348 | 0 | } |
349 | 0 | } |
350 | | |
351 | | // Initialize the full vocabulary token ids for backend samplers. |
352 | 0 | { |
353 | 0 | const int n_vocab = model.vocab.n_tokens(); |
354 | |
|
355 | 0 | sampling.token_ids_full_vocab.resize(n_vocab); |
356 | 0 | for (int i = 0; i < n_vocab; ++i) { |
357 | 0 | sampling.token_ids_full_vocab[i] = i; |
358 | 0 | } |
359 | 0 | } |
360 | 0 | } |
361 | | |
362 | 0 | llama_context::~llama_context() { |
363 | 0 | if (!model.hparams.no_alloc) { |
364 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
365 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
366 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
367 | |
|
368 | 0 | const size_t size_exp = backend_buf_exp_size[i]; |
369 | 0 | const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
370 | 0 | if (size_exp == size_act) { |
371 | 0 | LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", |
372 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
373 | 0 | } else { |
374 | 0 | LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", |
375 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
376 | 0 | } |
377 | 0 | } |
378 | 0 | } |
379 | 0 | ggml_opt_free(opt_ctx); |
380 | 0 | } |
381 | | |
382 | 0 | void llama_context::sched_reserve() { |
383 | 0 | if (!sched_need_reserve) { |
384 | 0 | return; |
385 | 0 | } |
386 | | |
387 | 0 | sched_need_reserve = false; |
388 | |
|
389 | 0 | LLAMA_LOG_INFO("%s: reserving ...\n", __func__); |
390 | |
|
391 | 0 | synchronize(); |
392 | |
|
393 | 0 | const int64_t t_start_us = ggml_time_us(); |
394 | |
|
395 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
396 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
397 | |
|
398 | 0 | const size_t max_nodes = this->graph_max_nodes(n_tokens); |
399 | |
|
400 | 0 | LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); |
401 | |
|
402 | 0 | gf_res_prev.reset(new llm_graph_result(max_nodes)); |
403 | 0 | gf_res_reserve.reset(new llm_graph_result(max_nodes)); |
404 | |
|
405 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); |
406 | |
|
407 | 0 | llama_memory_context_ptr mctx; |
408 | 0 | if (memory) { |
409 | 0 | LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); |
410 | 0 | mctx = memory->init_full(); |
411 | 0 | if (!mctx) { |
412 | 0 | throw std::runtime_error("failed to initialize memory module"); |
413 | 0 | } |
414 | 0 | } |
415 | | |
416 | | // avoid reserving graphs with zero outputs - assume one output per sequence |
417 | 0 | const int n_outputs = n_seqs; |
418 | |
|
419 | 0 | LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); |
420 | | |
421 | | // resolve automatic Flash Attention use |
422 | 0 | if (cparams.auto_fa) { |
423 | 0 | auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); |
424 | 0 | if (!gf) { |
425 | 0 | throw std::runtime_error("failed to split graph for Flash Attention check"); |
426 | 0 | } |
427 | | |
428 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; |
429 | 0 | bool fa_device_mismatch = false; |
430 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
431 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
432 | 0 | if (n->op != GGML_OP_FLASH_ATTN_EXT) { |
433 | 0 | continue; |
434 | 0 | } |
435 | 0 | ggml_backend_dev_t device_fa = ggml_backend_get_device( |
436 | 0 | ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
437 | | |
438 | | // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer |
439 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); |
440 | 0 | const int il = std::stoi(n->name + prefix_len); |
441 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
442 | 0 | if (device_fa != device_kv) { |
443 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " |
444 | 0 | "is assigned to device %s (usually due to missing support)\n", |
445 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); |
446 | | // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways |
447 | 0 | fa_device_mismatch = true; |
448 | 0 | break; |
449 | 0 | } |
450 | 0 | } |
451 | 0 | if (fa_device_mismatch) { |
452 | 0 | cparams.flash_attn = false; |
453 | 0 | LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); |
454 | 0 | } else { |
455 | 0 | cparams.flash_attn = true; |
456 | 0 | LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); |
457 | 0 | } |
458 | |
|
459 | 0 | cparams.auto_fa = false; |
460 | 0 | } |
461 | | |
462 | | // reserve worst-case graph |
463 | 0 | int n_splits_pp = -1; |
464 | 0 | int n_nodes_pp = -1; |
465 | |
|
466 | 0 | int n_splits_tg = -1; |
467 | 0 | int n_nodes_tg = -1; |
468 | | |
469 | | // reserve pp (prompt processing) graph first so that buffers are only allocated once |
470 | 0 | { |
471 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), |
472 | 0 | model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); |
473 | 0 | if (!gf) { |
474 | 0 | if (cparams.pipeline_parallel) { |
475 | 0 | LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); |
476 | 0 | cparams.pipeline_parallel = false; |
477 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); |
478 | 0 | gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
479 | 0 | } |
480 | 0 | if (!gf) { |
481 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
482 | 0 | } |
483 | 0 | } |
484 | | |
485 | 0 | n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); |
486 | 0 | n_nodes_pp = ggml_graph_n_nodes(gf); |
487 | 0 | } |
488 | | |
489 | | // reserve with tg (token generation) graph to get the number of splits and nodes |
490 | 0 | { |
491 | 0 | auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); |
492 | 0 | if (!gf) { |
493 | 0 | throw std::runtime_error("failed to allocate compute tg buffers"); |
494 | 0 | } |
495 | | |
496 | 0 | n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); |
497 | 0 | n_nodes_tg = ggml_graph_n_nodes(gf); |
498 | 0 | } |
499 | | |
500 | | // reserve again with pp graph to avoid ggml-alloc reallocations during inference |
501 | 0 | { |
502 | | // TODO: not sure if the following graph would be worster case for multi-stream KV caches: |
503 | | // |
504 | | // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); |
505 | | // |
506 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); |
507 | 0 | if (!gf) { |
508 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
509 | 0 | } |
510 | 0 | } |
511 | | |
512 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
513 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
514 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
515 | 0 | if (!model.hparams.no_alloc) { |
516 | 0 | backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
517 | 0 | } |
518 | 0 | if (backend_buf_exp_size[i] > 1) { |
519 | 0 | LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, |
520 | 0 | ggml_backend_buft_name(buft), |
521 | 0 | backend_buf_exp_size[i] / 1024.0 / 1024.0); |
522 | 0 | } |
523 | 0 | } |
524 | |
|
525 | 0 | if (n_nodes_pp == n_nodes_tg) { |
526 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); |
527 | 0 | } else { |
528 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); |
529 | 0 | } |
530 | |
|
531 | 0 | if (n_splits_pp == n_splits_tg) { |
532 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); |
533 | 0 | } else { |
534 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); |
535 | 0 | } |
536 | |
|
537 | 0 | const int64_t t_end_us = ggml_time_us(); |
538 | |
|
539 | 0 | LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", |
540 | 0 | __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); |
541 | 0 | } |
542 | | |
543 | 0 | void llama_context::synchronize() { |
544 | 0 | if (!sched) { |
545 | 0 | return; |
546 | 0 | } |
547 | | |
548 | 0 | ggml_backend_sched_synchronize(sched.get()); |
549 | | |
550 | | // FIXME: if multiple single tokens are evaluated without a synchronization, |
551 | | // the stats will be added to the prompt evaluation stats |
552 | | // this should only happen when using batch size 1 to evaluate a batch |
553 | | |
554 | | // add the evaluation to the stats |
555 | 0 | if (n_queued_tokens == 1) { |
556 | 0 | if (!cparams.no_perf) { |
557 | 0 | t_eval_us += ggml_time_us() - t_compute_start_us; |
558 | 0 | } |
559 | 0 | n_eval++; |
560 | 0 | } else if (n_queued_tokens > 1) { |
561 | 0 | if (!cparams.no_perf) { |
562 | 0 | t_p_eval_us += ggml_time_us() - t_compute_start_us; |
563 | 0 | } |
564 | 0 | n_p_eval += n_queued_tokens; |
565 | 0 | } |
566 | | |
567 | | // get a more accurate load time, upon first eval |
568 | 0 | if (n_queued_tokens > 0 && !has_evaluated_once) { |
569 | 0 | t_load_us = ggml_time_us() - t_start_us; |
570 | 0 | has_evaluated_once = true; |
571 | 0 | } |
572 | |
|
573 | 0 | n_queued_tokens = 0; |
574 | 0 | t_compute_start_us = 0; |
575 | 0 | } |
576 | | |
577 | 0 | const llama_model & llama_context::get_model() const { |
578 | 0 | return model; |
579 | 0 | } |
580 | | |
581 | 0 | const llama_cparams & llama_context::get_cparams() const { |
582 | 0 | return cparams; |
583 | 0 | } |
584 | | |
585 | 0 | ggml_backend_sched_t llama_context::get_sched() const { |
586 | 0 | return sched.get(); |
587 | 0 | } |
588 | | |
589 | 0 | uint32_t llama_context::n_ctx() const { |
590 | 0 | return cparams.n_ctx; |
591 | 0 | } |
592 | | |
593 | 0 | uint32_t llama_context::n_ctx_seq() const { |
594 | 0 | return cparams.n_ctx_seq; |
595 | 0 | } |
596 | | |
597 | 0 | uint32_t llama_context::n_batch() const { |
598 | 0 | return cparams.n_batch; |
599 | 0 | } |
600 | | |
601 | 0 | uint32_t llama_context::n_ubatch() const { |
602 | 0 | return cparams.n_ubatch; |
603 | 0 | } |
604 | | |
605 | 0 | uint32_t llama_context::n_seq_max() const { |
606 | 0 | return cparams.n_seq_max; |
607 | 0 | } |
608 | | |
609 | 0 | uint32_t llama_context::n_threads() const { |
610 | 0 | return cparams.n_threads; |
611 | 0 | } |
612 | | |
613 | 0 | uint32_t llama_context::n_threads_batch() const { |
614 | 0 | return cparams.n_threads_batch; |
615 | 0 | } |
616 | | |
617 | 0 | llama_memory_t llama_context::get_memory() const { |
618 | 0 | return memory.get(); |
619 | 0 | } |
620 | | |
621 | 0 | bool llama_context::memory_update(bool optimize) { |
622 | 0 | if (!memory) { |
623 | 0 | return false; |
624 | 0 | } |
625 | | |
626 | 0 | { |
627 | 0 | const auto mctx = memory->init_update(this, optimize); |
628 | 0 | switch (mctx->get_status()) { |
629 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
630 | 0 | { |
631 | | // noop |
632 | 0 | } break; |
633 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
634 | 0 | { |
635 | | // no updates need to be performed |
636 | 0 | return false; |
637 | 0 | } |
638 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
639 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
640 | 0 | { |
641 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); |
642 | 0 | return false; |
643 | 0 | } |
644 | 0 | } |
645 | | |
646 | | // reset the previous graph result to make sure that it won't be reused |
647 | | // TODO: change the mctx->apply() to return information if a graph reserve is needed |
648 | | // reset the graph result only if the memory module did reset the scheduler |
649 | 0 | gf_res_prev->reset(); |
650 | |
|
651 | 0 | if (!mctx->apply()) { |
652 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); |
653 | 0 | } |
654 | 0 | } |
655 | | |
656 | | // if the memory module did any computation, we have to reserve a new worst-case graph |
657 | 0 | { |
658 | 0 | const auto mctx = memory->init_full(); |
659 | 0 | if (!mctx) { |
660 | 0 | throw std::runtime_error("failed to initialize memory context"); |
661 | 0 | } |
662 | | |
663 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
664 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
665 | |
|
666 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
667 | 0 | if (!gf) { |
668 | 0 | LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); |
669 | 0 | } |
670 | 0 | } |
671 | | |
672 | 0 | return true; |
673 | 0 | } |
674 | | |
675 | 0 | enum llama_pooling_type llama_context::pooling_type() const { |
676 | 0 | return cparams.pooling_type; |
677 | 0 | } |
678 | | |
679 | 0 | float * llama_context::get_logits() { |
680 | 0 | output_reorder(); |
681 | |
|
682 | 0 | return logits.data; |
683 | 0 | } |
684 | | |
685 | 0 | int64_t llama_context::output_resolve_row(int32_t i) const { |
686 | 0 | int64_t j = -1; |
687 | | |
688 | | // support negative indices (last output row) |
689 | 0 | if (i < 0) { |
690 | 0 | j = n_outputs + i; |
691 | 0 | if (j < 0) { |
692 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
693 | 0 | } |
694 | 0 | } else if ((size_t) i >= output_ids.size()) { |
695 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
696 | 0 | } else { |
697 | | // use output_ids to translate the batch token index into a row number |
698 | | // that holds this token's data. |
699 | 0 | j = output_ids[i]; |
700 | 0 | } |
701 | | |
702 | 0 | if (j < 0) { |
703 | | // the batch token was not configured to output anything |
704 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
705 | 0 | } |
706 | | |
707 | 0 | if (j >= n_outputs) { |
708 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
709 | 0 | } |
710 | | |
711 | 0 | return j; |
712 | 0 | } |
713 | | |
714 | 0 | float * llama_context::get_logits_ith(int32_t i) { |
715 | 0 | output_reorder(); |
716 | |
|
717 | 0 | try { |
718 | 0 | if (logits.data == nullptr) { |
719 | 0 | throw std::runtime_error("no logits"); |
720 | 0 | } |
721 | | |
722 | 0 | const int64_t j = output_resolve_row(i); |
723 | 0 | return logits.data + j*model.vocab.n_tokens(); |
724 | 0 | } catch (const std::exception & err) { |
725 | 0 | LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); |
726 | | #ifndef NDEBUG |
727 | | GGML_ABORT("fatal error"); |
728 | | #else |
729 | 0 | return nullptr; |
730 | 0 | #endif |
731 | 0 | } |
732 | 0 | } |
733 | | |
734 | 0 | float * llama_context::get_embeddings() { |
735 | 0 | output_reorder(); |
736 | |
|
737 | 0 | return embd.data; |
738 | 0 | } |
739 | | |
740 | 0 | llama_token * llama_context::get_sampled_tokens() const{ |
741 | 0 | return sampling.sampled.data; |
742 | 0 | } |
743 | | |
744 | 0 | float * llama_context::get_embeddings_ith(int32_t i) { |
745 | 0 | output_reorder(); |
746 | |
|
747 | 0 | try { |
748 | 0 | if (embd.data == nullptr) { |
749 | 0 | throw std::runtime_error("no embeddings"); |
750 | 0 | } |
751 | | |
752 | 0 | const int64_t j = output_resolve_row(i); |
753 | 0 | const uint32_t n_embd_out = model.hparams.n_embd_out(); |
754 | 0 | return embd.data + j*n_embd_out; |
755 | 0 | } catch (const std::exception & err) { |
756 | 0 | LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); |
757 | | #ifndef NDEBUG |
758 | | GGML_ABORT("fatal error"); |
759 | | #else |
760 | 0 | return nullptr; |
761 | 0 | #endif |
762 | 0 | } |
763 | 0 | } |
764 | | |
765 | 0 | float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { |
766 | 0 | auto it = embd_seq.find(seq_id); |
767 | 0 | if (it == embd_seq.end()) { |
768 | 0 | return nullptr; |
769 | 0 | } |
770 | | |
771 | 0 | return it->second.data(); |
772 | 0 | } |
773 | | |
774 | 0 | llama_token llama_context::get_sampled_token_ith(int32_t idx) { |
775 | 0 | output_reorder(); |
776 | |
|
777 | 0 | if (!sampling.sampled.has_data()) { |
778 | 0 | return LLAMA_TOKEN_NULL; |
779 | 0 | } |
780 | | |
781 | 0 | try { |
782 | 0 | const int64_t row = output_resolve_row(idx); |
783 | 0 | GGML_ASSERT(row < (int64_t) sampling.sampled.size); |
784 | 0 | return sampling.sampled.data[row]; |
785 | 0 | } catch (const std::exception & err) { |
786 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); |
787 | 0 | return LLAMA_TOKEN_NULL; |
788 | 0 | } |
789 | 0 | } |
790 | | |
791 | 0 | float * llama_context::get_sampled_probs_ith(int32_t idx) { |
792 | 0 | output_reorder(); |
793 | |
|
794 | 0 | if (!sampling.probs.has_data()) { |
795 | 0 | return nullptr; |
796 | 0 | } |
797 | | |
798 | 0 | try { |
799 | 0 | const int64_t row = output_resolve_row(idx); |
800 | 0 | if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { |
801 | 0 | return nullptr; |
802 | 0 | } |
803 | 0 | return sampling.probs.data + row*model.vocab.n_tokens(); |
804 | 0 | } catch (const std::exception & err) { |
805 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); |
806 | 0 | return nullptr; |
807 | 0 | } |
808 | 0 | } |
809 | | |
810 | 0 | float * llama_context::get_sampled_logits_ith(int32_t idx) { |
811 | 0 | output_reorder(); |
812 | |
|
813 | 0 | if (!sampling.logits.has_data()) { |
814 | 0 | return nullptr; |
815 | 0 | } |
816 | | |
817 | 0 | try { |
818 | 0 | const int64_t row = output_resolve_row(idx); |
819 | 0 | if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { |
820 | 0 | return nullptr; |
821 | 0 | } |
822 | 0 | return sampling.logits.data + row*model.vocab.n_tokens(); |
823 | 0 | } catch (const std::exception & err) { |
824 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); |
825 | 0 | return nullptr; |
826 | 0 | } |
827 | 0 | } |
828 | | |
829 | 0 | const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { |
830 | 0 | output_reorder(); |
831 | |
|
832 | 0 | try { |
833 | 0 | const int64_t row = output_resolve_row(idx); |
834 | 0 | if (sampling.candidates.has_data() && |
835 | 0 | (size_t) row < sampling.candidates_count.size() && |
836 | 0 | sampling.candidates_count[row] > 0) { |
837 | 0 | return sampling.candidates.data + row*model.vocab.n_tokens(); |
838 | 0 | } |
839 | 0 | } catch (const std::exception & err) { |
840 | | // fallback to full vocab list |
841 | 0 | GGML_UNUSED(err); |
842 | 0 | } |
843 | | |
844 | 0 | return sampling.token_ids_full_vocab.data(); |
845 | 0 | } |
846 | | |
847 | 0 | size_t llama_context::get_sampled_candidates_count(int32_t idx) { |
848 | 0 | output_reorder(); |
849 | |
|
850 | 0 | if (!sampling.candidates.has_data()) { |
851 | 0 | return 0; |
852 | 0 | } |
853 | | |
854 | 0 | try { |
855 | 0 | const int64_t row = output_resolve_row(idx); |
856 | 0 | if ((size_t) row >= sampling.candidates_count.size()) { |
857 | 0 | return 0; |
858 | 0 | } |
859 | 0 | return sampling.candidates_count[row]; |
860 | 0 | } catch (const std::exception & err) { |
861 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); |
862 | 0 | return 0; |
863 | 0 | } |
864 | 0 | } |
865 | | |
866 | 0 | size_t llama_context::get_sampled_logits_count(int32_t idx) { |
867 | 0 | output_reorder(); |
868 | |
|
869 | 0 | if (!sampling.logits.has_data()) { |
870 | 0 | return model.vocab.n_tokens(); |
871 | 0 | } |
872 | | |
873 | 0 | try { |
874 | 0 | const int64_t row = output_resolve_row(idx); |
875 | 0 | if ((size_t) row >= sampling.logits_count.size()) { |
876 | 0 | return 0; |
877 | 0 | } |
878 | 0 | return sampling.logits_count[row]; |
879 | 0 | } catch (const std::exception & err) { |
880 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); |
881 | 0 | return 0; |
882 | 0 | } |
883 | 0 | } |
884 | | |
885 | 0 | size_t llama_context::get_sampled_probs_count(int32_t idx) { |
886 | 0 | output_reorder(); |
887 | |
|
888 | 0 | if (!sampling.probs.has_data()) { |
889 | 0 | return 0; |
890 | 0 | } |
891 | | |
892 | 0 | try { |
893 | 0 | const int64_t row = output_resolve_row(idx); |
894 | 0 | if ((size_t) row >= sampling.probs_count.size()) { |
895 | 0 | return 0; |
896 | 0 | } |
897 | 0 | return sampling.probs_count[row]; |
898 | 0 | } catch (const std::exception & err) { |
899 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); |
900 | 0 | return 0; |
901 | 0 | } |
902 | 0 | } |
903 | | |
904 | | |
905 | | void llama_context::attach_threadpool( |
906 | | ggml_threadpool_t threadpool, |
907 | 0 | ggml_threadpool_t threadpool_batch) { |
908 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
909 | |
|
910 | 0 | this->threadpool = threadpool; |
911 | 0 | this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; |
912 | 0 | } |
913 | | |
914 | 0 | void llama_context::detach_threadpool() { |
915 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
916 | |
|
917 | 0 | this->threadpool = nullptr; |
918 | 0 | this->threadpool_batch = nullptr; |
919 | 0 | } |
920 | | |
921 | 0 | void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { |
922 | 0 | LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); |
923 | |
|
924 | 0 | cparams.n_threads = n_threads; |
925 | 0 | cparams.n_threads_batch = n_threads_batch; |
926 | 0 | } |
927 | | |
928 | 0 | void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { |
929 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
930 | |
|
931 | 0 | this->abort_callback = abort_callback; |
932 | 0 | this->abort_callback_data = abort_callback_data; |
933 | |
|
934 | 0 | for (auto & backend : backends) { |
935 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); |
936 | 0 | auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); |
937 | 0 | if (set_abort_callback_fn) { |
938 | 0 | set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); |
939 | 0 | } |
940 | 0 | } |
941 | 0 | } |
942 | | |
943 | 0 | void llama_context::set_embeddings(bool value) { |
944 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
945 | |
|
946 | 0 | cparams.embeddings = value; |
947 | | |
948 | | // TODO: not sure yet if we want to reserve here |
949 | | //sched_need_reserve = true; |
950 | 0 | } |
951 | | |
952 | 0 | void llama_context::set_causal_attn(bool value) { |
953 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
954 | |
|
955 | 0 | if (cparams.causal_attn == value) { |
956 | 0 | return; |
957 | 0 | } |
958 | | |
959 | 0 | cparams.causal_attn = value; |
960 | |
|
961 | 0 | sched_need_reserve = true; |
962 | 0 | } |
963 | | |
964 | 0 | void llama_context::set_warmup(bool value) { |
965 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
966 | |
|
967 | 0 | if (cparams.warmup == value) { |
968 | 0 | return; |
969 | 0 | } |
970 | | |
971 | 0 | cparams.warmup = value; |
972 | | |
973 | | // warmups are usually with small batches, so no need to reserve |
974 | | //sched_need_reserve = true; |
975 | 0 | } |
976 | | |
977 | 0 | bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { |
978 | 0 | if (!sampler && sampling.samplers.count(seq_id) == 0) { |
979 | 0 | return true; |
980 | 0 | } |
981 | | |
982 | 0 | LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); |
983 | |
|
984 | 0 | const bool can_offload = |
985 | 0 | sampler && |
986 | 0 | sampler->iface->backend_init && |
987 | 0 | sampler->iface->backend_apply && |
988 | 0 | llama_sampler_chain_n(sampler) > 0; |
989 | |
|
990 | 0 | if (sampler && can_offload) { |
991 | 0 | auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); |
992 | |
|
993 | 0 | sampler->iface->backend_init(sampler, buft); |
994 | |
|
995 | 0 | sampling.samplers[seq_id] = sampler; |
996 | |
|
997 | 0 | sched_need_reserve = true; |
998 | |
|
999 | 0 | return true; |
1000 | 0 | } |
1001 | | |
1002 | 0 | if (sampler && !can_offload) { |
1003 | 0 | LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); |
1004 | |
|
1005 | 0 | if (sampling.samplers.count(seq_id) > 0) { |
1006 | 0 | sched_need_reserve = true; |
1007 | 0 | } |
1008 | |
|
1009 | 0 | sampling.samplers.erase(seq_id); |
1010 | |
|
1011 | 0 | return false; |
1012 | 0 | } |
1013 | | |
1014 | 0 | sampling.samplers.erase(seq_id); |
1015 | |
|
1016 | 0 | sched_need_reserve = true; |
1017 | |
|
1018 | 0 | return true; |
1019 | 0 | } |
1020 | | |
1021 | 0 | void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { |
1022 | 0 | LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); |
1023 | |
|
1024 | 0 | if (adapters_lora_are_same(adapters, n_adapters, scales)) { |
1025 | 0 | return; |
1026 | 0 | } |
1027 | | |
1028 | 0 | loras.reset(new llama_adapter_loras()); |
1029 | |
|
1030 | 0 | for (size_t i = 0; i < n_adapters; i ++) { |
1031 | 0 | if (scales[i] != 0.0f) { |
1032 | 0 | loras->insert({adapters[i], scales[i]}); |
1033 | 0 | } |
1034 | 0 | } |
1035 | |
|
1036 | 0 | sched_need_reserve = true; |
1037 | 0 | } |
1038 | | |
1039 | 0 | bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { |
1040 | 0 | LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); |
1041 | |
|
1042 | 0 | if (n_adapters != loras->size()) { |
1043 | 0 | return false; |
1044 | 0 | } |
1045 | | |
1046 | 0 | for (size_t i = 0; i < n_adapters; i ++) { |
1047 | 0 | auto it = loras->find(adapters[i]); |
1048 | |
|
1049 | 0 | if (it == loras->end() || it->second != scales[i]) { |
1050 | 0 | return false; |
1051 | 0 | } |
1052 | 0 | } |
1053 | | |
1054 | 0 | return true; |
1055 | 0 | } |
1056 | | |
1057 | | bool llama_context::set_adapter_cvec( |
1058 | | const float * data, |
1059 | | size_t len, |
1060 | | int32_t n_embd, |
1061 | | int32_t il_start, |
1062 | 0 | int32_t il_end) { |
1063 | 0 | LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); |
1064 | | |
1065 | | // TODO: should we reserve? |
1066 | |
|
1067 | 0 | return cvec->apply(model, data, len, n_embd, il_start, il_end); |
1068 | 0 | } |
1069 | | |
1070 | 0 | llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { |
1071 | 0 | if (mctx && !mctx->apply()) { |
1072 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); |
1073 | 0 | ret = GGML_STATUS_FAILED; |
1074 | 0 | return nullptr; |
1075 | 0 | } |
1076 | | |
1077 | 0 | auto * res = gf_res_prev.get(); |
1078 | 0 | auto * gf = res->get_gf(); |
1079 | | |
1080 | | // the new graph parameters |
1081 | | // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters |
1082 | 0 | const auto gparams = graph_params(res, ubatch, mctx, gtype); |
1083 | |
|
1084 | 0 | if (!graph_reuse_disable && res->can_reuse(gparams)) { |
1085 | | //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); |
1086 | |
|
1087 | 0 | n_reused++; |
1088 | 0 | } else { |
1089 | 0 | res->reset(); |
1090 | |
|
1091 | 0 | ggml_backend_sched_reset(sched.get()); |
1092 | 0 | ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); |
1093 | | |
1094 | | //const auto t_start_us = ggml_time_us(); |
1095 | |
|
1096 | 0 | gf = model.build_graph(gparams); |
1097 | | |
1098 | | //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
1099 | |
|
1100 | 0 | if (!gf) { |
1101 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); |
1102 | 0 | ret = GGML_STATUS_FAILED; |
1103 | 0 | return nullptr; |
1104 | 0 | } |
1105 | | |
1106 | 0 | if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { |
1107 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); |
1108 | 0 | ret = GGML_STATUS_ALLOC_FAILED; |
1109 | 0 | return nullptr; |
1110 | 0 | } |
1111 | 0 | } |
1112 | | |
1113 | | // set the input data for the input tensors |
1114 | 0 | { |
1115 | | //const auto t_start_us = ggml_time_us(); |
1116 | |
|
1117 | 0 | res->set_inputs(&ubatch); |
1118 | | |
1119 | | //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
1120 | 0 | } |
1121 | |
|
1122 | 0 | const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); |
1123 | 0 | if (status != GGML_STATUS_SUCCESS) { |
1124 | 0 | LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); |
1125 | 0 | ret = status; |
1126 | 0 | return nullptr; |
1127 | 0 | } |
1128 | | |
1129 | 0 | ret = GGML_STATUS_SUCCESS; |
1130 | |
|
1131 | 0 | return res; |
1132 | 0 | } |
1133 | | |
1134 | 0 | int llama_context::encode(const llama_batch & batch_inp) { |
1135 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1136 | |
|
1137 | 0 | if (batch_inp.n_tokens == 0) { |
1138 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1139 | 0 | return -1; |
1140 | 0 | } |
1141 | | |
1142 | 0 | const auto & hparams = model.hparams; |
1143 | |
|
1144 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1145 | 0 | const int64_t n_vocab = model.vocab.n_tokens(); |
1146 | | |
1147 | | // note: during encode, we always pass the full sequence starting from pos = 0 |
1148 | 0 | if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
1149 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1150 | 0 | return -1; |
1151 | 0 | } |
1152 | | |
1153 | 0 | const uint32_t n_tokens = balloc->get_n_tokens(); |
1154 | | |
1155 | | // [TAG_NO_CACHE_PAD] |
1156 | | // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true |
1157 | 0 | const llama_ubatch ubatch = balloc->split_simple(n_tokens); |
1158 | | |
1159 | | // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot |
1160 | 0 | GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); |
1161 | |
|
1162 | 0 | if (t_compute_start_us == 0) { |
1163 | 0 | t_compute_start_us = ggml_time_us(); |
1164 | 0 | } |
1165 | | |
1166 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
1167 | 0 | embd_seq.clear(); |
1168 | |
|
1169 | 0 | sched_reserve(); |
1170 | |
|
1171 | 0 | n_queued_tokens += n_tokens; |
1172 | | |
1173 | | // reserve output buffer |
1174 | 0 | if (output_reserve(n_tokens) < n_tokens) { |
1175 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); |
1176 | 0 | return -2; |
1177 | 0 | }; |
1178 | |
|
1179 | 0 | for (uint32_t i = 0; i < n_tokens; ++i) { |
1180 | 0 | output_ids[i] = i; |
1181 | 0 | } |
1182 | |
|
1183 | 0 | n_outputs = n_tokens; |
1184 | |
|
1185 | 0 | const auto causal_attn_org = cparams.causal_attn; |
1186 | | |
1187 | | // always use non-causal attention for encoder graphs |
1188 | | // TODO: this is a tmp solution until we have a proper way to support enc-dec models |
1189 | | // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 |
1190 | 0 | cparams.causal_attn = false; |
1191 | |
|
1192 | 0 | ggml_status status; |
1193 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); |
1194 | |
|
1195 | 0 | cparams.causal_attn = causal_attn_org; |
1196 | |
|
1197 | 0 | if (!res) { |
1198 | 0 | switch (status) { |
1199 | 0 | case GGML_STATUS_ABORTED: return 2; |
1200 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1201 | 0 | case GGML_STATUS_FAILED: return -3; |
1202 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1203 | 0 | } |
1204 | 0 | } |
1205 | | |
1206 | 0 | auto * t_logits = res->get_logits(); |
1207 | 0 | auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); |
1208 | | |
1209 | | // extract logits |
1210 | 0 | if (logits.data && t_logits) { |
1211 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1212 | 0 | GGML_ASSERT(backend_res != nullptr); |
1213 | 0 | GGML_ASSERT(logits.data != nullptr); |
1214 | |
|
1215 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); |
1216 | 0 | } |
1217 | | |
1218 | | // extract embeddings |
1219 | 0 | if (embd.data && t_embd) { |
1220 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1221 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1222 | |
|
1223 | 0 | switch (cparams.pooling_type) { |
1224 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1225 | 0 | { |
1226 | | // extract token embeddings |
1227 | 0 | GGML_ASSERT(embd.data != nullptr); |
1228 | 0 | const uint32_t n_embd_out = hparams.n_embd_out(); |
1229 | |
|
1230 | 0 | GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); |
1231 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); |
1232 | 0 | } break; |
1233 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1234 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1235 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1236 | 0 | { |
1237 | | // extract sequence embeddings |
1238 | 0 | auto & embd_seq_out = embd_seq; |
1239 | |
|
1240 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1241 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1242 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1243 | |
|
1244 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1245 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1246 | 0 | } |
1247 | 0 | } break; |
1248 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1249 | 0 | { |
1250 | | // extract the rerank score - n_cls_out floats per sequence |
1251 | 0 | auto & embd_seq_out = embd_seq; |
1252 | |
|
1253 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1254 | |
|
1255 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1256 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1257 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1258 | |
|
1259 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1260 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1261 | 0 | } |
1262 | 0 | } break; |
1263 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1264 | 0 | { |
1265 | 0 | GGML_ABORT("unknown pooling type"); |
1266 | 0 | } |
1267 | 0 | } |
1268 | 0 | } |
1269 | | |
1270 | | // TODO: hacky solution |
1271 | 0 | if (model.arch == LLM_ARCH_T5 && t_embd) { |
1272 | | //cross.t_embd = t_embd; |
1273 | |
|
1274 | 0 | synchronize(); |
1275 | |
|
1276 | 0 | cross.n_embd = t_embd->ne[0]; |
1277 | 0 | cross.n_enc = t_embd->ne[1]; |
1278 | 0 | cross.v_embd.resize(cross.n_embd*cross.n_enc); |
1279 | 0 | memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); |
1280 | |
|
1281 | 0 | const auto & batch = balloc->get_batch(); |
1282 | | |
1283 | | // remember the sequence ids used during the encoding - needed for cross attention later |
1284 | 0 | cross.seq_ids_enc.resize(n_tokens); |
1285 | 0 | for (uint32_t i = 0; i < n_tokens; i++) { |
1286 | 0 | cross.seq_ids_enc[i].clear(); |
1287 | |
|
1288 | 0 | for (int s = 0; s < batch.n_seq_id[i]; s++) { |
1289 | 0 | const llama_seq_id seq_id = batch.seq_id[i][s]; |
1290 | |
|
1291 | 0 | cross.seq_ids_enc[i].insert(seq_id); |
1292 | 0 | } |
1293 | 0 | } |
1294 | 0 | } |
1295 | |
|
1296 | 0 | return 0; |
1297 | 0 | } |
1298 | | |
1299 | 0 | static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { |
1300 | 0 | std::map<llama_seq_id, uint32_t> seq_to_row; |
1301 | | // how many output tokens we have seen so far for this ubatch. |
1302 | 0 | uint32_t local = 0; |
1303 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1304 | | // skip tokens that are not output. |
1305 | 0 | if (!ubatch.output[i]) { |
1306 | 0 | continue; |
1307 | 0 | } |
1308 | | |
1309 | 0 | const llama_seq_id seq_id = ubatch.seq_id[i][0]; |
1310 | | // row_offset is the number of output tokens before this ubatch. |
1311 | 0 | seq_to_row[seq_id] = row_offset + local; |
1312 | 0 | ++local; |
1313 | 0 | } |
1314 | 0 | return seq_to_row; |
1315 | 0 | } |
1316 | | |
1317 | | static void copy_tensor_async_ints( |
1318 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1319 | | const buffer_view<llama_token> & sampled, |
1320 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1321 | 0 | ggml_backend_sched_t sched) { |
1322 | 0 | if (!sampled.has_data()) { |
1323 | 0 | return; |
1324 | 0 | } |
1325 | | |
1326 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1327 | 0 | auto it = seq_to_row.find(seq_id); |
1328 | 0 | if (it == seq_to_row.end()) { |
1329 | 0 | continue; |
1330 | 0 | } |
1331 | | |
1332 | 0 | const uint32_t row = it->second; |
1333 | 0 | GGML_ASSERT(row < sampled.size); |
1334 | |
|
1335 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); |
1336 | |
|
1337 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1338 | 0 | ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); |
1339 | 0 | } |
1340 | 0 | } |
1341 | | |
1342 | | static void copy_tensor_async_floats( |
1343 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1344 | | const buffer_view<float> & dst, |
1345 | | size_t stride, |
1346 | | std::vector<uint32_t> & counts, |
1347 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1348 | 0 | ggml_backend_sched_t sched) { |
1349 | 0 | if (!dst.has_data()) { |
1350 | 0 | return; |
1351 | 0 | } |
1352 | | |
1353 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1354 | 0 | auto it = seq_to_row.find(seq_id); |
1355 | 0 | if (it == seq_to_row.end()) { |
1356 | 0 | continue; |
1357 | 0 | } |
1358 | | |
1359 | 0 | const uint32_t row = it->second; |
1360 | 0 | GGML_ASSERT(row < counts.size()); |
1361 | |
|
1362 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); |
1363 | |
|
1364 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1365 | 0 | float * row_ptr = dst.data + (size_t) row * stride; |
1366 | 0 | ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); |
1367 | | |
1368 | | // Update the actual number of logits/probabilities that were written for this row. |
1369 | 0 | counts[row] = ggml_nelements(tensor); |
1370 | 0 | } |
1371 | 0 | } |
1372 | | |
1373 | | static void copy_tensor_async_candidates( |
1374 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1375 | | const buffer_view<llama_token> & dst, |
1376 | | size_t stride, |
1377 | | std::vector<uint32_t> & counts, |
1378 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1379 | 0 | ggml_backend_sched_t sched) { |
1380 | 0 | if (!dst.has_data()) { |
1381 | 0 | return; |
1382 | 0 | } |
1383 | | |
1384 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1385 | 0 | auto it = seq_to_row.find(seq_id); |
1386 | 0 | if (it == seq_to_row.end()) { |
1387 | 0 | continue; |
1388 | 0 | } |
1389 | | |
1390 | 0 | const uint32_t row = it->second; |
1391 | 0 | GGML_ASSERT(row < counts.size()); |
1392 | |
|
1393 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); |
1394 | |
|
1395 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1396 | 0 | llama_token * row_ptr = dst.data + (size_t) row * stride; |
1397 | 0 | ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); |
1398 | | |
1399 | | // Update the actual number of candidates that were written. |
1400 | 0 | counts[row] = ggml_nelements(tensor); |
1401 | 0 | } |
1402 | 0 | } |
1403 | | |
1404 | 0 | static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) { |
1405 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1406 | 0 | if (!ubatch.output[i]) { |
1407 | 0 | continue; |
1408 | 0 | } |
1409 | | |
1410 | | // Check if the output token has at least one sequence without a backend sampler. |
1411 | 0 | for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { |
1412 | 0 | llama_seq_id seq_id = ubatch.seq_id[i][j]; |
1413 | 0 | if (samplers.find(seq_id) == samplers.end()) { |
1414 | 0 | return true; |
1415 | 0 | } |
1416 | 0 | } |
1417 | 0 | } |
1418 | 0 | return false; // all sequences use backend sampling |
1419 | 0 | } |
1420 | | |
1421 | 0 | int llama_context::decode(const llama_batch & batch_inp) { |
1422 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1423 | |
|
1424 | 0 | if (!memory) { |
1425 | 0 | LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); |
1426 | 0 | return encode(batch_inp); |
1427 | 0 | } |
1428 | | |
1429 | 0 | if (batch_inp.n_tokens == 0) { |
1430 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1431 | 0 | return -1; |
1432 | 0 | } |
1433 | | |
1434 | 0 | const auto & vocab = model.vocab; |
1435 | 0 | const auto & hparams = model.hparams; |
1436 | |
|
1437 | 0 | const int64_t n_vocab = vocab.n_tokens(); |
1438 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1439 | | |
1440 | | // when computing embeddings, all tokens are output |
1441 | 0 | const bool output_all = cparams.embeddings; |
1442 | 0 | const bool has_samplers = !sampling.samplers.empty(); |
1443 | |
|
1444 | 0 | const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; |
1445 | | |
1446 | | // TODO: avoid this workaround in the future |
1447 | 0 | if (has_samplers && batch_inp.logits) { |
1448 | 0 | std::vector<int32_t> seq_output_count(n_seq_max, 0); |
1449 | |
|
1450 | 0 | for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { |
1451 | 0 | if (batch_inp.logits[i] == 0) { |
1452 | 0 | continue; |
1453 | 0 | } |
1454 | | |
1455 | 0 | const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1; |
1456 | |
|
1457 | 0 | for (int32_t s = 0; s < ns; ++s) { |
1458 | 0 | const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0; |
1459 | |
|
1460 | 0 | seq_output_count[seq_id]++; |
1461 | 0 | if (seq_output_count[seq_id] > 1) { |
1462 | 0 | LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", |
1463 | 0 | __func__, seq_id, seq_output_count[seq_id]); |
1464 | 0 | return -1; |
1465 | 0 | } |
1466 | 0 | } |
1467 | 0 | } |
1468 | 0 | } |
1469 | | |
1470 | 0 | if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { |
1471 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1472 | 0 | return -1; |
1473 | 0 | } |
1474 | | |
1475 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
1476 | 0 | const uint32_t n_outputs_all = balloc->get_n_outputs(); |
1477 | |
|
1478 | 0 | if (output_all) { |
1479 | | // require that all tokens are output |
1480 | 0 | if (n_outputs_all != n_tokens_all) { |
1481 | 0 | LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", |
1482 | 0 | __func__, n_outputs_all, n_tokens_all); |
1483 | 0 | return -1; |
1484 | 0 | } |
1485 | 0 | } |
1486 | | |
1487 | 0 | GGML_ASSERT(n_tokens_all <= cparams.n_batch); |
1488 | |
|
1489 | 0 | GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); |
1490 | |
|
1491 | 0 | if (t_compute_start_us == 0) { |
1492 | 0 | t_compute_start_us = ggml_time_us(); |
1493 | 0 | } |
1494 | 0 | n_queued_tokens += n_tokens_all; |
1495 | | |
1496 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
1497 | 0 | embd_seq.clear(); |
1498 | 0 | output_swaps.clear(); |
1499 | |
|
1500 | 0 | sched_reserve(); |
1501 | |
|
1502 | 0 | bool did_optimize = false; |
1503 | | |
1504 | | // handle any pending shifts/copies |
1505 | 0 | memory_update(false); |
1506 | |
|
1507 | 0 | llama_memory_context_ptr mctx; |
1508 | |
|
1509 | 0 | while (true) { |
1510 | 0 | mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); |
1511 | 0 | if (!mctx) { |
1512 | 0 | return -2; |
1513 | 0 | } |
1514 | | |
1515 | 0 | switch (mctx->get_status()) { |
1516 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
1517 | 0 | { |
1518 | 0 | } break; |
1519 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
1520 | 0 | { |
1521 | 0 | LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); |
1522 | |
|
1523 | 0 | return -2; |
1524 | 0 | } |
1525 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
1526 | 0 | { |
1527 | 0 | if (!did_optimize) { |
1528 | 0 | did_optimize = true; |
1529 | |
|
1530 | 0 | if (memory_update(true)) { |
1531 | 0 | LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); |
1532 | |
|
1533 | 0 | continue; |
1534 | 0 | } |
1535 | 0 | } |
1536 | | |
1537 | 0 | LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); |
1538 | |
|
1539 | 0 | return 1; |
1540 | 0 | } |
1541 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
1542 | 0 | { |
1543 | 0 | LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); |
1544 | |
|
1545 | 0 | return -2; |
1546 | 0 | } |
1547 | 0 | } |
1548 | | |
1549 | 0 | break; |
1550 | 0 | } |
1551 | | |
1552 | | // reserve output buffer |
1553 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
1554 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
1555 | 0 | return -2; |
1556 | 0 | }; |
1557 | |
|
1558 | 0 | int64_t n_outputs_prev = 0; |
1559 | |
|
1560 | 0 | do { |
1561 | 0 | const auto & ubatch = mctx->get_ubatch(); |
1562 | | |
1563 | | // count the outputs in this ubatch |
1564 | 0 | { |
1565 | 0 | int32_t n_outputs_new = 0; |
1566 | |
|
1567 | 0 | if (n_outputs_all == n_tokens_all) { |
1568 | 0 | n_outputs_new = ubatch.n_tokens; |
1569 | 0 | } else { |
1570 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1571 | 0 | n_outputs_new += (int32_t) (ubatch.output[i] != 0); |
1572 | 0 | } |
1573 | 0 | } |
1574 | | |
1575 | | // needs to happen before the graph is built |
1576 | 0 | n_outputs = n_outputs_new; |
1577 | 0 | } |
1578 | |
|
1579 | 0 | ggml_status status; |
1580 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); |
1581 | |
|
1582 | 0 | if (!res) { |
1583 | | // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module |
1584 | 0 | llama_pos pos_min[LLAMA_MAX_SEQ]; |
1585 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1586 | 0 | pos_min[s] = std::numeric_limits<llama_pos>::max(); |
1587 | 0 | } |
1588 | |
|
1589 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1590 | 0 | const auto & seq_id = ubatch.seq_id[i][0]; |
1591 | |
|
1592 | 0 | pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); |
1593 | 0 | } |
1594 | |
|
1595 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1596 | 0 | if (pos_min[s] == std::numeric_limits<llama_pos>::max()) { |
1597 | 0 | continue; |
1598 | 0 | } |
1599 | | |
1600 | 0 | LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); |
1601 | |
|
1602 | 0 | memory->seq_rm(s, pos_min[s], -1); |
1603 | 0 | } |
1604 | |
|
1605 | 0 | switch (status) { |
1606 | 0 | case GGML_STATUS_ABORTED: return 2; |
1607 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1608 | 0 | case GGML_STATUS_FAILED: return -3; |
1609 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1610 | 0 | } |
1611 | 0 | } |
1612 | | |
1613 | | // plot the computation graph in dot format (for debugging purposes) |
1614 | | //if (n_past%100 == 0) { |
1615 | | // ggml_graph_dump_dot(gf, NULL, "llama.dot"); |
1616 | | //} |
1617 | | |
1618 | 0 | auto * t_logits = res->get_logits(); |
1619 | 0 | auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; |
1620 | |
|
1621 | 0 | if (t_embd && res->get_embd_pooled()) { |
1622 | 0 | t_embd = res->get_embd_pooled(); |
1623 | 0 | } |
1624 | | |
1625 | | // extract logits |
1626 | 0 | if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { |
1627 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1628 | 0 | GGML_ASSERT(backend_res != nullptr); |
1629 | 0 | GGML_ASSERT(logits.data != nullptr); |
1630 | |
|
1631 | 0 | float * logits_out = logits.data + n_outputs_prev*n_vocab; |
1632 | |
|
1633 | 0 | if (n_outputs) { |
1634 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1635 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); |
1636 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); |
1637 | 0 | } |
1638 | 0 | } |
1639 | | |
1640 | | // extract embeddings |
1641 | 0 | if (embd.data && t_embd && n_outputs > 0) { |
1642 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1643 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1644 | |
|
1645 | 0 | switch (cparams.pooling_type) { |
1646 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1647 | 0 | { |
1648 | | // extract token embeddings |
1649 | 0 | GGML_ASSERT(embd.data != nullptr); |
1650 | 0 | const uint32_t n_embd_out = hparams.n_embd_out(); |
1651 | 0 | float * embd_out = embd.data + n_outputs_prev*n_embd_out; |
1652 | |
|
1653 | 0 | if (n_outputs) { |
1654 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1655 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); |
1656 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); |
1657 | 0 | } |
1658 | 0 | } break; |
1659 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1660 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1661 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1662 | 0 | { |
1663 | | // extract sequence embeddings (cleared before processing each batch) |
1664 | 0 | auto & embd_seq_out = embd_seq; |
1665 | |
|
1666 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1667 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1668 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1669 | |
|
1670 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1671 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1672 | 0 | } |
1673 | 0 | } break; |
1674 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1675 | 0 | { |
1676 | | // extract the rerank score - n_cls_out floats per sequence |
1677 | 0 | auto & embd_seq_out = embd_seq; |
1678 | |
|
1679 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1680 | |
|
1681 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1682 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1683 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1684 | |
|
1685 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1686 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1687 | 0 | } |
1688 | 0 | } break; |
1689 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1690 | 0 | { |
1691 | 0 | GGML_ABORT("unknown pooling type"); |
1692 | 0 | } |
1693 | 0 | } |
1694 | 0 | } |
1695 | | |
1696 | | // Copy backend sampling output if this ubatch produced any sampling tensors. |
1697 | 0 | if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { |
1698 | 0 | const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); |
1699 | 0 | const auto stride = n_vocab; |
1700 | | |
1701 | | // async copy the sampling data from the backend to the host |
1702 | 0 | copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); |
1703 | |
|
1704 | 0 | copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); |
1705 | 0 | copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); |
1706 | 0 | copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get()); |
1707 | 0 | } |
1708 | |
|
1709 | 0 | n_outputs_prev += n_outputs; |
1710 | 0 | } while (mctx->next()); |
1711 | | |
1712 | | // set to total number of outputs in the batch, for use in llama_get_logits_ith |
1713 | 0 | n_outputs = n_outputs_all; |
1714 | | |
1715 | | // set output mappings |
1716 | 0 | if (n_outputs > 0) { |
1717 | 0 | bool sorted_output = true; |
1718 | |
|
1719 | 0 | auto & out_ids = balloc->get_out_ids(); |
1720 | |
|
1721 | 0 | GGML_ASSERT(out_ids.size() == (size_t) n_outputs); |
1722 | |
|
1723 | 0 | for (int64_t i = 0; i < n_outputs; ++i) { |
1724 | 0 | int64_t out_id = out_ids[i]; |
1725 | 0 | output_ids[out_id] = i; |
1726 | 0 | if (out_id != i) { |
1727 | 0 | sorted_output = false; |
1728 | 0 | } |
1729 | 0 | } |
1730 | | |
1731 | | // make the outputs have the same order they had in the user-provided batch |
1732 | | // note: this is mostly relevant for recurrent models atm |
1733 | 0 | if (!sorted_output && n_outputs > 1) { |
1734 | 0 | GGML_ASSERT((size_t) n_outputs == out_ids.size()); |
1735 | | |
1736 | | // TODO: is there something more efficient which also minimizes swaps? |
1737 | | // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) |
1738 | 0 | for (uint32_t i = 0; i < n_outputs - 1; ++i) { |
1739 | 0 | uint32_t j_min = i; |
1740 | 0 | for (uint32_t j = i + 1; j < n_outputs; ++j) { |
1741 | 0 | if (out_ids[j] < out_ids[j_min]) { |
1742 | 0 | j_min = j; |
1743 | 0 | } |
1744 | 0 | } |
1745 | 0 | if (j_min == i) { |
1746 | 0 | continue; |
1747 | 0 | } |
1748 | 0 | std::swap(out_ids[i], out_ids[j_min]); |
1749 | | |
1750 | | // remember the swaps and apply them lazily upon logits/embeddings access |
1751 | 0 | output_swaps.push_back({ i, j_min }); |
1752 | 0 | } |
1753 | |
|
1754 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1755 | |
|
1756 | 0 | for (uint32_t i = 0; i < n_outputs; ++i) { |
1757 | 0 | output_ids[out_ids[i]] = i; |
1758 | 0 | } |
1759 | 0 | } |
1760 | 0 | } |
1761 | | |
1762 | | // wait for the computation to finish (automatically done when obtaining the model output) |
1763 | | //synchronize(); |
1764 | |
|
1765 | 0 | return 0; |
1766 | 0 | } |
1767 | | |
1768 | | // |
1769 | | // output |
1770 | | // |
1771 | | |
1772 | 0 | uint32_t llama_context::output_reserve(int32_t n_outputs) { |
1773 | 0 | const auto & hparams = model.hparams; |
1774 | 0 | const auto & vocab = model.vocab; |
1775 | |
|
1776 | 0 | const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max()); |
1777 | |
|
1778 | 0 | const auto n_batch = cparams.n_batch; |
1779 | 0 | const auto n_vocab = vocab.n_tokens(); |
1780 | 0 | const auto n_embd_out = hparams.n_embd_out(); |
1781 | |
|
1782 | 0 | bool has_logits = true; |
1783 | 0 | bool has_embd = cparams.embeddings; |
1784 | | |
1785 | | // TODO: hacky enc-dec support |
1786 | 0 | if (model.arch == LLM_ARCH_T5) { |
1787 | 0 | has_logits = true; |
1788 | 0 | has_embd = true; |
1789 | 0 | } |
1790 | | |
1791 | |
|
1792 | 0 | size_t backend_float_count = 0; |
1793 | 0 | size_t backend_token_count = 0; |
1794 | |
|
1795 | 0 | logits.size = has_logits ? n_vocab*n_outputs_max : 0; |
1796 | 0 | embd.size = has_embd ? n_embd_out*n_outputs_max : 0; |
1797 | | |
1798 | | // Allocate backend sampling output buffers if there are backend samplers configured. |
1799 | 0 | const bool has_sampling = !sampling.samplers.empty(); |
1800 | 0 | if (has_sampling) { |
1801 | 0 | backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs |
1802 | 0 | backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates |
1803 | 0 | } |
1804 | |
|
1805 | 0 | if (output_ids.empty()) { |
1806 | | // init, never resized afterwards |
1807 | 0 | output_ids.resize(n_batch); |
1808 | 0 | } |
1809 | |
|
1810 | 0 | const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; |
1811 | 0 | const size_t new_size = |
1812 | 0 | (logits.size + embd.size + backend_float_count) * sizeof(float) + |
1813 | 0 | ( backend_token_count) * sizeof(llama_token); |
1814 | | |
1815 | | // alloc only when more than the current capacity is required |
1816 | | // TODO: also consider shrinking the buffer |
1817 | 0 | if (!buf_output || prev_size < new_size) { |
1818 | 0 | if (buf_output) { |
1819 | | #ifndef NDEBUG |
1820 | | // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) |
1821 | | LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); |
1822 | | #endif |
1823 | 0 | synchronize(); |
1824 | | |
1825 | | // TODO: not needed? |
1826 | 0 | buf_output = nullptr; |
1827 | 0 | logits.data = nullptr; |
1828 | 0 | embd.data = nullptr; |
1829 | 0 | } |
1830 | |
|
1831 | 0 | auto * buft = ggml_backend_cpu_buffer_type(); |
1832 | | // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory |
1833 | 0 | auto * output_dev = model.dev_output(); |
1834 | 0 | auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; |
1835 | 0 | if (output_dev_host_buft) { |
1836 | 0 | buft = output_dev_host_buft; |
1837 | 0 | } |
1838 | 0 | buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); |
1839 | 0 | if (buf_output == nullptr) { |
1840 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); |
1841 | 0 | return 0; |
1842 | 0 | } |
1843 | 0 | } |
1844 | | |
1845 | 0 | float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); |
1846 | |
|
1847 | 0 | size_t offset = 0; |
1848 | 0 | uint8_t * base = (uint8_t *) output_base; |
1849 | |
|
1850 | 0 | logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0}; |
1851 | 0 | offset += logits.size * sizeof(float); |
1852 | |
|
1853 | 0 | embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0}; |
1854 | 0 | offset += embd.size * sizeof(float); |
1855 | |
|
1856 | 0 | if (has_sampling) { |
1857 | 0 | sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; |
1858 | 0 | offset += sampling.logits.size * sizeof(float); |
1859 | |
|
1860 | 0 | sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; |
1861 | 0 | offset += sampling.probs.size * sizeof(float); |
1862 | |
|
1863 | 0 | sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; |
1864 | 0 | offset += sampling.sampled.size * sizeof(llama_token); |
1865 | |
|
1866 | 0 | sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; |
1867 | 0 | offset += sampling.candidates.size * sizeof(llama_token); |
1868 | | |
1869 | | // The count vectors keep track of the actual number of logits/probs/candidates |
1870 | | // copied from the backend for each output row. |
1871 | |
|
1872 | 0 | sampling.logits_count.resize(n_outputs_max); |
1873 | 0 | sampling.probs_count.resize(n_outputs_max); |
1874 | 0 | sampling.candidates_count.resize(n_outputs_max); |
1875 | |
|
1876 | 0 | std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); |
1877 | 0 | std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); |
1878 | 0 | std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); |
1879 | |
|
1880 | 0 | std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); |
1881 | 0 | } else { |
1882 | 0 | sampling.logits = {nullptr, 0}; |
1883 | 0 | sampling.probs = {nullptr, 0}; |
1884 | 0 | sampling.sampled = {nullptr, 0}; |
1885 | 0 | sampling.candidates = {nullptr, 0}; |
1886 | |
|
1887 | 0 | sampling.logits_count.clear(); |
1888 | 0 | sampling.probs_count.clear(); |
1889 | 0 | sampling.candidates_count.clear(); |
1890 | 0 | } |
1891 | | |
1892 | | // set all ids as invalid (negative) |
1893 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1894 | |
|
1895 | 0 | this->n_outputs = 0; |
1896 | |
|
1897 | 0 | return n_outputs_max; |
1898 | 0 | } |
1899 | | |
1900 | 0 | void llama_context::output_reorder() { |
1901 | 0 | const uint64_t n_vocab = model.vocab.n_tokens(); |
1902 | 0 | const uint64_t n_embd = model.hparams.n_embd; |
1903 | |
|
1904 | 0 | for (size_t s = 0; s < output_swaps.size(); ++s) { |
1905 | 0 | const uint64_t i0 = output_swaps[s].i0; |
1906 | 0 | const uint64_t i1 = output_swaps[s].i1; |
1907 | |
|
1908 | 0 | if (logits.size > 0) { |
1909 | 0 | for (uint64_t k = 0; k < n_vocab; k++) { |
1910 | 0 | std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); |
1911 | 0 | } |
1912 | 0 | } |
1913 | |
|
1914 | 0 | if (embd.size > 0) { |
1915 | 0 | for (uint64_t k = 0; k < n_embd; k++) { |
1916 | 0 | std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); |
1917 | 0 | } |
1918 | 0 | } |
1919 | |
|
1920 | 0 | if (!sampling.samplers.empty()) { |
1921 | 0 | assert(sampling.logits.size > 0); |
1922 | 0 | assert(sampling.probs.size > 0); |
1923 | 0 | assert(sampling.candidates.size > 0); |
1924 | 0 | assert(sampling.sampled.size > 0); |
1925 | 0 | assert(sampling.logits_count.size() > 0); |
1926 | 0 | assert(sampling.probs_count.size() > 0); |
1927 | 0 | assert(sampling.candidates_count.size() > 0); |
1928 | |
|
1929 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
1930 | 0 | std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); |
1931 | 0 | } |
1932 | |
|
1933 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
1934 | 0 | std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); |
1935 | 0 | } |
1936 | |
|
1937 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
1938 | 0 | std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); |
1939 | 0 | } |
1940 | |
|
1941 | 0 | std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); |
1942 | 0 | std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); |
1943 | 0 | std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); |
1944 | 0 | std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); |
1945 | 0 | } |
1946 | 0 | } |
1947 | |
|
1948 | 0 | output_swaps.clear(); |
1949 | 0 | } |
1950 | | |
1951 | | // |
1952 | | // graph |
1953 | | // |
1954 | | |
1955 | 0 | uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { |
1956 | 0 | if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { |
1957 | 0 | return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors()); |
1958 | 0 | } |
1959 | 0 | uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors()); |
1960 | 0 | for (const auto & lora : model.loras) { |
1961 | 0 | res += lora->get_n_nodes(); |
1962 | 0 | } |
1963 | 0 | return res; |
1964 | 0 | } |
1965 | | |
1966 | 0 | llm_graph_result * llama_context::get_gf_res_reserve() const { |
1967 | 0 | return static_cast<llm_graph_result *>(gf_res_reserve.get()); |
1968 | 0 | } |
1969 | | |
1970 | | ggml_cgraph * llama_context::graph_reserve( |
1971 | 0 | uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) { |
1972 | 0 | LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); |
1973 | 0 | GGML_ASSERT(n_outputs >= 1); |
1974 | |
|
1975 | 0 | if (n_tokens % n_seqs != 0) { |
1976 | 0 | n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs |
1977 | 0 | n_outputs = std::max(n_outputs, n_tokens); |
1978 | |
|
1979 | 0 | LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); |
1980 | 0 | } |
1981 | |
|
1982 | 0 | ggml_backend_sched_reset(sched.get()); |
1983 | | |
1984 | | // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that |
1985 | 0 | gf_res_prev->reset(); |
1986 | | |
1987 | | // store the n_outputs as it is, and restore it afterwards |
1988 | | // TODO: not sure if needed, might simplify in the future by removing this |
1989 | 0 | const auto save_n_outputs = this->n_outputs; |
1990 | |
|
1991 | 0 | this->n_outputs = n_outputs; |
1992 | |
|
1993 | 0 | llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); |
1994 | 0 | llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); |
1995 | | |
1996 | | // set one output token per sequence in order to activate all backend samplers |
1997 | 0 | std::vector<llama_seq_id> seq_ids(n_seqs); |
1998 | 0 | for (uint32_t i = 0; i < n_seqs; ++i) { |
1999 | 0 | seq_ids[i] = i; |
2000 | 0 | ubatch.n_seq_id[i] = 1; |
2001 | 0 | ubatch.seq_id[i] = &seq_ids[i]; |
2002 | 0 | ubatch.output[i] = true; |
2003 | 0 | } |
2004 | |
|
2005 | 0 | auto * res = gf_res_reserve.get(); |
2006 | |
|
2007 | 0 | const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); |
2008 | |
|
2009 | 0 | res->reset(); |
2010 | |
|
2011 | 0 | auto * gf = model.build_graph(gparams); |
2012 | |
|
2013 | 0 | this->n_outputs = save_n_outputs; |
2014 | | |
2015 | | // initialize scheduler with the specified graph |
2016 | 0 | if (split_only) { |
2017 | 0 | if (sizes) { |
2018 | 0 | ggml_backend_sched_reserve_size(sched.get(), gf, sizes); |
2019 | 0 | } else { |
2020 | 0 | ggml_backend_sched_split_graph(sched.get(), gf); |
2021 | 0 | } |
2022 | 0 | } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { |
2023 | 0 | GGML_ASSERT(!sizes); |
2024 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); |
2025 | 0 | return nullptr; |
2026 | 0 | } |
2027 | | |
2028 | 0 | return gf; |
2029 | 0 | } |
2030 | | |
2031 | | llm_graph_params llama_context::graph_params( |
2032 | | llm_graph_result * res, |
2033 | | const llama_ubatch & ubatch, |
2034 | | const llama_memory_context_i * mctx, |
2035 | 0 | llm_graph_type gtype) const { |
2036 | 0 | return { |
2037 | 0 | /*.arch =*/ model.arch, |
2038 | 0 | /*.hparams =*/ model.hparams, |
2039 | 0 | /*.cparams =*/ cparams, |
2040 | 0 | /*.ubatch =*/ ubatch, |
2041 | 0 | /*.gtype =*/ gtype, |
2042 | 0 | /*.sched =*/ sched.get(), |
2043 | 0 | /*.backend_cpu =*/ backend_cpu, |
2044 | 0 | /*.cvec =*/ cvec.get(), |
2045 | 0 | /*.loras =*/ loras.get(), |
2046 | 0 | /*.mctx =*/ mctx, |
2047 | 0 | /*.cross =*/ &cross, |
2048 | 0 | /*.samplers =*/ sampling.samplers, |
2049 | 0 | /*.n_outputs =*/ n_outputs, |
2050 | 0 | /*.cb =*/ graph_get_cb(), |
2051 | 0 | /*.res =*/ res, |
2052 | 0 | }; |
2053 | 0 | } |
2054 | | |
2055 | | ggml_status llama_context::graph_compute( |
2056 | | ggml_cgraph * gf, |
2057 | 0 | bool batched) { |
2058 | 0 | int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; |
2059 | 0 | ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; |
2060 | |
|
2061 | 0 | if (backend_cpu != nullptr) { |
2062 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); |
2063 | 0 | auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); |
2064 | 0 | if (set_threadpool_fn) { |
2065 | 0 | set_threadpool_fn(backend_cpu, tp); |
2066 | 0 | } |
2067 | 0 | } |
2068 | | |
2069 | | // set the number of threads for all the backends |
2070 | 0 | for (const auto & set_n_threads_fn : set_n_threads_fns) { |
2071 | 0 | set_n_threads_fn.second(set_n_threads_fn.first, n_threads); |
2072 | 0 | } |
2073 | |
|
2074 | 0 | auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); |
2075 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2076 | 0 | LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); |
2077 | 0 | } |
2078 | | |
2079 | | // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); |
2080 | |
|
2081 | 0 | return status; |
2082 | 0 | } |
2083 | | |
2084 | 0 | llm_graph_cb llama_context::graph_get_cb() const { |
2085 | 0 | return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { |
2086 | 0 | if (il >= 0) { |
2087 | 0 | ggml_format_name(cur, "%s-%d", name, il); |
2088 | 0 | } else { |
2089 | 0 | ggml_set_name(cur, name); |
2090 | 0 | } |
2091 | | |
2092 | | // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends |
2093 | | // FIXME: fix in ggml_backend_sched |
2094 | 0 | const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; |
2095 | 0 | if (ubatch.n_tokens < 32 || full_offload) { |
2096 | 0 | if (il != -1 && strcmp(name, "norm") == 0) { |
2097 | 0 | const auto & dev_layer = model.dev_layer(il); |
2098 | 0 | for (const auto & backend : backends) { |
2099 | 0 | if (ggml_backend_get_device(backend.get()) == dev_layer) { |
2100 | 0 | if (ggml_backend_supports_op(backend.get(), cur)) { |
2101 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); |
2102 | 0 | } |
2103 | 0 | } |
2104 | 0 | } |
2105 | 0 | } |
2106 | 0 | } |
2107 | 0 | }; |
2108 | 0 | } |
2109 | | |
2110 | | // |
2111 | | // state save/load |
2112 | | // |
2113 | | |
2114 | | class llama_io_write_dummy : public llama_io_write_i { |
2115 | | public: |
2116 | 0 | llama_io_write_dummy() = default; |
2117 | | |
2118 | 0 | void write(const void * /* src */, size_t size) override { |
2119 | 0 | size_written += size; |
2120 | 0 | } |
2121 | | |
2122 | 0 | void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { |
2123 | 0 | size_written += size; |
2124 | 0 | } |
2125 | | |
2126 | 0 | size_t n_bytes() override { |
2127 | 0 | return size_written; |
2128 | 0 | } |
2129 | | |
2130 | | private: |
2131 | | size_t size_written = 0; |
2132 | | }; |
2133 | | |
2134 | | class llama_io_write_buffer : public llama_io_write_i { |
2135 | | public: |
2136 | | llama_io_write_buffer( |
2137 | 0 | uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
2138 | | |
2139 | 0 | void write(const void * src, size_t size) override { |
2140 | 0 | if (size > buf_size) { |
2141 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2142 | 0 | } |
2143 | 0 | memcpy(ptr, src, size); |
2144 | 0 | ptr += size; |
2145 | 0 | size_written += size; |
2146 | 0 | buf_size -= size; |
2147 | 0 | } |
2148 | | |
2149 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
2150 | 0 | if (size > buf_size) { |
2151 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2152 | 0 | } |
2153 | 0 | ggml_backend_tensor_get(tensor, ptr, offset, size); |
2154 | 0 | ptr += size; |
2155 | 0 | size_written += size; |
2156 | 0 | buf_size -= size; |
2157 | 0 | } |
2158 | | |
2159 | 0 | size_t n_bytes() override { |
2160 | 0 | return size_written; |
2161 | 0 | } |
2162 | | |
2163 | | private: |
2164 | | uint8_t * ptr; |
2165 | | size_t buf_size = 0; |
2166 | | size_t size_written = 0; |
2167 | | }; |
2168 | | |
2169 | | class llama_io_read_buffer : public llama_io_read_i { |
2170 | | public: |
2171 | 0 | llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
2172 | | |
2173 | 0 | const uint8_t * read(size_t size) override { |
2174 | 0 | const uint8_t * base_ptr = ptr; |
2175 | 0 | if (size > buf_size) { |
2176 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2177 | 0 | } |
2178 | 0 | ptr += size; |
2179 | 0 | size_read += size; |
2180 | 0 | buf_size -= size; |
2181 | 0 | return base_ptr; |
2182 | 0 | } |
2183 | | |
2184 | 0 | void read_to(void * dst, size_t size) override { |
2185 | 0 | memcpy(dst, read(size), size); |
2186 | 0 | } |
2187 | | |
2188 | 0 | size_t n_bytes() override { |
2189 | 0 | return size_read; |
2190 | 0 | } |
2191 | | |
2192 | | private: |
2193 | | const uint8_t * ptr; |
2194 | | size_t buf_size = 0; |
2195 | | size_t size_read = 0; |
2196 | | }; |
2197 | | |
2198 | | class llama_io_write_file : public llama_io_write_i { |
2199 | | public: |
2200 | 0 | llama_io_write_file(llama_file * f) : file(f) {} |
2201 | | |
2202 | 0 | void write(const void * src, size_t size) override { |
2203 | 0 | file->write_raw(src, size); |
2204 | 0 | size_written += size; |
2205 | 0 | } |
2206 | | |
2207 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
2208 | 0 | temp_buffer.resize(size); |
2209 | 0 | ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); |
2210 | 0 | write(temp_buffer.data(), temp_buffer.size()); |
2211 | 0 | } |
2212 | | |
2213 | 0 | size_t n_bytes() override { |
2214 | 0 | return size_written; |
2215 | 0 | } |
2216 | | |
2217 | | private: |
2218 | | llama_file * file; |
2219 | | size_t size_written = 0; |
2220 | | std::vector<uint8_t> temp_buffer; |
2221 | | }; |
2222 | | |
2223 | | class llama_io_read_file : public llama_io_read_i { |
2224 | | public: |
2225 | 0 | llama_io_read_file(llama_file * f) : file(f) {} |
2226 | | |
2227 | 0 | void read_to(void * dst, size_t size) override { |
2228 | 0 | file->read_raw(dst, size); |
2229 | 0 | size_read += size; |
2230 | 0 | } |
2231 | | |
2232 | 0 | const uint8_t * read(size_t size) override { |
2233 | 0 | temp_buffer.resize(size); |
2234 | 0 | read_to(temp_buffer.data(), size); |
2235 | 0 | return temp_buffer.data(); |
2236 | 0 | } |
2237 | | |
2238 | 0 | size_t n_bytes() override { |
2239 | 0 | return size_read; |
2240 | 0 | } |
2241 | | |
2242 | | private: |
2243 | | llama_file * file; |
2244 | | size_t size_read = 0; |
2245 | | std::vector<uint8_t> temp_buffer; |
2246 | | }; |
2247 | | |
2248 | 0 | size_t llama_context::state_get_size() { |
2249 | 0 | llama_io_write_dummy io; |
2250 | 0 | try { |
2251 | 0 | return state_write_data(io); |
2252 | 0 | } catch (const std::exception & err) { |
2253 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
2254 | 0 | return 0; |
2255 | 0 | } |
2256 | 0 | } |
2257 | | |
2258 | 0 | size_t llama_context::state_get_data(uint8_t * dst, size_t size) { |
2259 | 0 | llama_io_write_buffer io(dst, size); |
2260 | 0 | try { |
2261 | 0 | return state_write_data(io); |
2262 | 0 | } catch (const std::exception & err) { |
2263 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
2264 | 0 | return 0; |
2265 | 0 | } |
2266 | 0 | } |
2267 | | |
2268 | 0 | size_t llama_context::state_set_data(const uint8_t * src, size_t size) { |
2269 | 0 | llama_io_read_buffer io(src, size); |
2270 | 0 | try { |
2271 | 0 | return state_read_data(io); |
2272 | 0 | } catch (const std::exception & err) { |
2273 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
2274 | 0 | return 0; |
2275 | 0 | } |
2276 | 0 | } |
2277 | | |
2278 | 0 | size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { |
2279 | 0 | llama_io_write_dummy io; |
2280 | 0 | try { |
2281 | 0 | return state_seq_write_data(io, seq_id, flags); |
2282 | 0 | } catch (const std::exception & err) { |
2283 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
2284 | 0 | return 0; |
2285 | 0 | } |
2286 | 0 | } |
2287 | | |
2288 | 0 | size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { |
2289 | 0 | llama_io_write_buffer io(dst, size); |
2290 | 0 | try { |
2291 | 0 | return state_seq_write_data(io, seq_id, flags); |
2292 | 0 | } catch (const std::exception & err) { |
2293 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
2294 | 0 | return 0; |
2295 | 0 | } |
2296 | 0 | } |
2297 | | |
2298 | 0 | size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { |
2299 | 0 | llama_io_read_buffer io(src, size); |
2300 | 0 | try { |
2301 | 0 | return state_seq_read_data(io, seq_id, flags); |
2302 | 0 | } catch (const std::exception & err) { |
2303 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
2304 | 0 | return 0; |
2305 | 0 | } |
2306 | 0 | } |
2307 | | |
2308 | 0 | bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2309 | 0 | llama_file file(filepath, "rb"); |
2310 | | |
2311 | | // sanity checks |
2312 | 0 | { |
2313 | 0 | const uint32_t magic = file.read_u32(); |
2314 | 0 | const uint32_t version = file.read_u32(); |
2315 | |
|
2316 | 0 | if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { |
2317 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); |
2318 | 0 | return false; |
2319 | 0 | } |
2320 | 0 | } |
2321 | | |
2322 | | // load the prompt |
2323 | 0 | { |
2324 | 0 | const uint32_t n_token_count = file.read_u32(); |
2325 | |
|
2326 | 0 | if (n_token_count > n_token_capacity) { |
2327 | 0 | LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
2328 | 0 | return false; |
2329 | 0 | } |
2330 | | |
2331 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
2332 | 0 | *n_token_count_out = n_token_count; |
2333 | 0 | } |
2334 | | |
2335 | | // restore the context state |
2336 | 0 | { |
2337 | 0 | const size_t n_state_size_cur = file.size() - file.tell(); |
2338 | |
|
2339 | 0 | llama_io_read_file io( &file); |
2340 | 0 | const size_t n_read = state_read_data(io); |
2341 | |
|
2342 | 0 | if (n_read != n_state_size_cur) { |
2343 | 0 | LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); |
2344 | 0 | return false; |
2345 | 0 | } |
2346 | 0 | } |
2347 | | |
2348 | 0 | return true; |
2349 | 0 | } |
2350 | | |
2351 | 0 | bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { |
2352 | 0 | llama_file file(filepath, "wb"); |
2353 | |
|
2354 | 0 | file.write_u32(LLAMA_SESSION_MAGIC); |
2355 | 0 | file.write_u32(LLAMA_SESSION_VERSION); |
2356 | | |
2357 | | // save the prompt |
2358 | 0 | file.write_u32((uint32_t) n_token_count); |
2359 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
2360 | | |
2361 | | // save the context state using stream saving |
2362 | 0 | llama_io_write_file io(&file); |
2363 | 0 | state_write_data(io); |
2364 | |
|
2365 | 0 | return true; |
2366 | 0 | } |
2367 | | |
2368 | 0 | size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2369 | 0 | llama_file file(filepath, "rb"); |
2370 | | |
2371 | | // version checks |
2372 | 0 | { |
2373 | 0 | const uint32_t magic = file.read_u32(); |
2374 | 0 | const uint32_t version = file.read_u32(); |
2375 | |
|
2376 | 0 | if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { |
2377 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); |
2378 | 0 | return 0; |
2379 | 0 | } |
2380 | 0 | } |
2381 | | |
2382 | | // load the prompt |
2383 | 0 | { |
2384 | 0 | const uint32_t n_token_count = file.read_u32(); |
2385 | |
|
2386 | 0 | if (n_token_count > n_token_capacity) { |
2387 | 0 | LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
2388 | 0 | return 0; |
2389 | 0 | } |
2390 | | |
2391 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
2392 | 0 | *n_token_count_out = n_token_count; |
2393 | 0 | } |
2394 | | |
2395 | | // restore the context state |
2396 | 0 | { |
2397 | 0 | const size_t state_size = file.size() - file.tell(); |
2398 | 0 | llama_io_read_file io(&file); |
2399 | 0 | const size_t nread = state_seq_read_data(io, seq_id, 0); |
2400 | 0 | if (!nread) { |
2401 | 0 | LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); |
2402 | 0 | return 0; |
2403 | 0 | } |
2404 | 0 | GGML_ASSERT(nread <= state_size); |
2405 | 0 | GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); |
2406 | 0 | } |
2407 | | |
2408 | 0 | return file.tell(); |
2409 | 0 | } |
2410 | | |
2411 | 0 | size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { |
2412 | 0 | llama_file file(filepath, "wb"); |
2413 | |
|
2414 | 0 | file.write_u32(LLAMA_STATE_SEQ_MAGIC); |
2415 | 0 | file.write_u32(LLAMA_STATE_SEQ_VERSION); |
2416 | | |
2417 | | // save the prompt |
2418 | 0 | file.write_u32((uint32_t) n_token_count); |
2419 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
2420 | | |
2421 | | // save the context state using stream saving |
2422 | 0 | llama_io_write_file io(&file); |
2423 | 0 | state_seq_write_data(io, seq_id, 0); |
2424 | |
|
2425 | 0 | const size_t res = file.tell(); |
2426 | 0 | GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); |
2427 | |
|
2428 | 0 | return res; |
2429 | 0 | } |
2430 | | |
2431 | 0 | size_t llama_context::state_write_data(llama_io_write_i & io) { |
2432 | 0 | LLAMA_LOG_DEBUG("%s: writing state\n", __func__); |
2433 | | |
2434 | | // write model info |
2435 | 0 | { |
2436 | 0 | LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); |
2437 | |
|
2438 | 0 | const std::string arch_str = llm_arch_name(model.arch); |
2439 | 0 | io.write_string(arch_str); |
2440 | | // TODO: add more model-specific info which should prevent loading the session file if not identical |
2441 | 0 | } |
2442 | |
|
2443 | 0 | if (memory != nullptr) { |
2444 | 0 | LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); |
2445 | 0 | memory->state_write(io); |
2446 | 0 | } |
2447 | |
|
2448 | 0 | return io.n_bytes(); |
2449 | 0 | } |
2450 | | |
2451 | 0 | size_t llama_context::state_read_data(llama_io_read_i & io) { |
2452 | 0 | LLAMA_LOG_DEBUG("%s: reading state\n", __func__); |
2453 | | |
2454 | | // read model info |
2455 | 0 | { |
2456 | 0 | LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); |
2457 | |
|
2458 | 0 | const std::string cur_arch_str = llm_arch_name(model.arch); |
2459 | |
|
2460 | 0 | std::string arch_str; |
2461 | 0 | io.read_string(arch_str); |
2462 | 0 | if (cur_arch_str != arch_str) { |
2463 | 0 | throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); |
2464 | 0 | } |
2465 | | // TODO: add more info which needs to be identical but which is not verified otherwise |
2466 | 0 | } |
2467 | | |
2468 | 0 | if (memory) { |
2469 | 0 | LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); |
2470 | |
|
2471 | 0 | memory->state_read(io); |
2472 | 0 | } |
2473 | |
|
2474 | 0 | return io.n_bytes(); |
2475 | 0 | } |
2476 | | |
2477 | 0 | size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2478 | 0 | GGML_UNUSED(seq_id); |
2479 | |
|
2480 | 0 | if (memory) { |
2481 | 0 | memory->state_write(io, seq_id, flags); |
2482 | 0 | } |
2483 | |
|
2484 | 0 | return io.n_bytes(); |
2485 | 0 | } |
2486 | | |
2487 | 0 | size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2488 | 0 | GGML_UNUSED(seq_id); |
2489 | |
|
2490 | 0 | if (memory) { |
2491 | 0 | memory->state_read(io, seq_id, flags); |
2492 | 0 | } |
2493 | |
|
2494 | 0 | return io.n_bytes(); |
2495 | 0 | } |
2496 | | |
2497 | | // |
2498 | | // perf |
2499 | | // |
2500 | | |
2501 | 0 | llama_perf_context_data llama_context::perf_get_data() const { |
2502 | 0 | llama_perf_context_data data = {}; |
2503 | |
|
2504 | 0 | data.t_start_ms = 1e-3 * t_start_us; |
2505 | 0 | data.t_load_ms = 1e-3 * t_load_us; |
2506 | 0 | data.t_p_eval_ms = 1e-3 * t_p_eval_us; |
2507 | 0 | data.t_eval_ms = 1e-3 * t_eval_us; |
2508 | 0 | data.n_p_eval = std::max(1, n_p_eval); |
2509 | 0 | data.n_eval = std::max(1, n_eval); |
2510 | 0 | data.n_reused = std::max(0, n_reused); |
2511 | |
|
2512 | 0 | return data; |
2513 | 0 | } |
2514 | | |
2515 | 0 | void llama_context::perf_reset() { |
2516 | 0 | t_start_us = ggml_time_us(); |
2517 | 0 | t_eval_us = n_eval = 0; |
2518 | 0 | t_p_eval_us = n_p_eval = 0; |
2519 | 0 | n_reused = 0; |
2520 | 0 | } |
2521 | | |
2522 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const { |
2523 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret; |
2524 | 0 | for (const auto & [buft, size] : model.memory_breakdown()) { |
2525 | 0 | ret[buft].model += size; |
2526 | 0 | } |
2527 | 0 | if (memory) { |
2528 | 0 | for (const auto & [buft, size] : memory->memory_breakdown()) { |
2529 | 0 | ret[buft].context += size; |
2530 | 0 | } |
2531 | 0 | } |
2532 | 0 | if (model.hparams.no_alloc) { |
2533 | 0 | for (size_t i = 0; i < backends.size(); ++i) { |
2534 | 0 | ggml_backend_t backend = backends[i].get(); |
2535 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2536 | 0 | ret[buft].compute += backend_buf_exp_size[i]; |
2537 | 0 | } |
2538 | 0 | } else { |
2539 | 0 | for (const auto & backend_ptr : backends) { |
2540 | 0 | ggml_backend_t backend = backend_ptr.get(); |
2541 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2542 | 0 | ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); |
2543 | 0 | } |
2544 | 0 | } |
2545 | 0 | return ret; |
2546 | 0 | } |
2547 | | |
2548 | | // |
2549 | | // training |
2550 | | // |
2551 | | |
2552 | 0 | static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { |
2553 | 0 | if (!tensor || tensor->type != GGML_TYPE_F32) { |
2554 | 0 | return; |
2555 | 0 | } |
2556 | 0 | if (!param_filter(tensor, userdata)) { |
2557 | 0 | return; |
2558 | 0 | } |
2559 | 0 | if (strcmp(tensor->name, "token_embd.weight") == 0) { |
2560 | 0 | return; // FIXME |
2561 | 0 | } |
2562 | 0 | if (strcmp(tensor->name, "rope_freqs.weight") == 0) { |
2563 | 0 | return; // FIXME |
2564 | 0 | } |
2565 | 0 | ggml_set_param(tensor); |
2566 | 0 | } |
2567 | | |
2568 | 0 | void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { |
2569 | 0 | GGML_ASSERT(!opt_ctx); |
2570 | 0 | model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); |
2571 | 0 | const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); |
2572 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2573 | 0 | GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); |
2574 | 0 | GGML_ASSERT(n_batch % n_ubatch == 0); |
2575 | |
|
2576 | 0 | ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); |
2577 | 0 | opt_params.opt_period = n_batch / n_ubatch; |
2578 | 0 | opt_params.get_opt_pars = lopt_params.get_opt_pars; |
2579 | 0 | opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; |
2580 | 0 | opt_params.optimizer = lopt_params.optimizer_type; |
2581 | 0 | opt_ctx = ggml_opt_init(opt_params); |
2582 | |
|
2583 | 0 | llama_opt_param_filter param_filter = lopt_params.param_filter; |
2584 | 0 | void * param_filter_ud = lopt_params.param_filter_ud; |
2585 | | |
2586 | | //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME |
2587 | 0 | llama_set_param(model->type_embd, param_filter, param_filter_ud); |
2588 | 0 | llama_set_param(model->pos_embd, param_filter, param_filter_ud); |
2589 | 0 | llama_set_param(model->tok_norm, param_filter, param_filter_ud); |
2590 | 0 | llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); |
2591 | 0 | llama_set_param(model->output_norm, param_filter, param_filter_ud); |
2592 | 0 | llama_set_param(model->output_norm_b, param_filter, param_filter_ud); |
2593 | 0 | llama_set_param(model->output, param_filter, param_filter_ud); |
2594 | 0 | llama_set_param(model->output_b, param_filter, param_filter_ud); |
2595 | 0 | llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); |
2596 | 0 | llama_set_param(model->cls, param_filter, param_filter_ud); |
2597 | 0 | llama_set_param(model->cls_b, param_filter, param_filter_ud); |
2598 | 0 | llama_set_param(model->cls_out, param_filter, param_filter_ud); |
2599 | 0 | llama_set_param(model->cls_out_b, param_filter, param_filter_ud); |
2600 | 0 | llama_set_param(model->cls_norm, param_filter, param_filter_ud); |
2601 | |
|
2602 | 0 | for (struct llama_layer & layer : model->layers) { |
2603 | 0 | for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { |
2604 | 0 | llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud); |
2605 | 0 | } |
2606 | 0 | } |
2607 | 0 | } |
2608 | | |
2609 | | void llama_context::opt_epoch_iter( |
2610 | | ggml_opt_dataset_t dataset, |
2611 | | ggml_opt_result_t result, |
2612 | | const std::vector<llama_token> & tokens, |
2613 | | const std::vector<llama_token> & labels_sparse, |
2614 | | llama_batch & batch, |
2615 | | ggml_opt_epoch_callback callback, |
2616 | | bool train, |
2617 | | int64_t idata_in_loop, |
2618 | | int64_t ndata_in_loop, |
2619 | 0 | int64_t t_loop_start) { |
2620 | 0 | GGML_ASSERT(opt_ctx); |
2621 | 0 | const uint32_t n_ctx = llama_model_n_ctx_train(&model); |
2622 | 0 | const uint32_t n_batch = std::min(this->n_batch(), n_ctx); |
2623 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2624 | |
|
2625 | 0 | memory->clear(true); |
2626 | |
|
2627 | 0 | for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { |
2628 | 0 | batch.n_tokens = n_batch; |
2629 | 0 | for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { |
2630 | 0 | batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; |
2631 | 0 | batch.pos [pos_batch] = pos_ctx + pos_batch; |
2632 | 0 | batch.n_seq_id[pos_batch] = 1; |
2633 | 0 | batch.seq_id [pos_batch][0] = 0; |
2634 | 0 | batch.logits [pos_batch] = true; |
2635 | 0 | } |
2636 | |
|
2637 | 0 | if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
2638 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
2639 | 0 | return; |
2640 | 0 | } |
2641 | | |
2642 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
2643 | |
|
2644 | 0 | n_queued_tokens += n_tokens_all; |
2645 | |
|
2646 | 0 | embd_seq.clear(); |
2647 | |
|
2648 | 0 | uint32_t n_outputs_all = n_tokens_all; |
2649 | |
|
2650 | 0 | auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); |
2651 | 0 | if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { |
2652 | 0 | LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); |
2653 | 0 | break; |
2654 | 0 | } |
2655 | | |
2656 | | // reserve output buffer |
2657 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
2658 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
2659 | 0 | GGML_ABORT("TODO: handle this error"); |
2660 | 0 | }; |
2661 | |
|
2662 | 0 | uint32_t pos_batch = 0; |
2663 | 0 | do { |
2664 | 0 | const auto & ubatch = mctx->get_ubatch(); |
2665 | |
|
2666 | 0 | n_outputs = ubatch.n_tokens; |
2667 | |
|
2668 | 0 | if (!mctx->apply()) { |
2669 | 0 | LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); |
2670 | 0 | break; |
2671 | 0 | } |
2672 | | |
2673 | 0 | auto * res = gf_res_prev.get(); |
2674 | |
|
2675 | 0 | const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); |
2676 | |
|
2677 | 0 | res->reset(); |
2678 | |
|
2679 | 0 | auto * gf = model.build_graph(gparams); |
2680 | |
|
2681 | 0 | struct ggml_context * ctx_compute_opt; |
2682 | 0 | { |
2683 | 0 | const size_t size_gf = ggml_graph_size(gf); |
2684 | 0 | const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); |
2685 | 0 | struct ggml_init_params params = { |
2686 | 0 | /*.mem_size =*/ size_meta, |
2687 | 0 | /*.mem_buffer =*/ nullptr, |
2688 | 0 | /*.no_alloc =*/ true, |
2689 | 0 | }; |
2690 | 0 | ctx_compute_opt = ggml_init(params); |
2691 | 0 | } |
2692 | 0 | ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); |
2693 | 0 | ggml_opt_alloc(opt_ctx, train); |
2694 | |
|
2695 | 0 | res->set_inputs(&ubatch); |
2696 | 0 | { |
2697 | 0 | struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); |
2698 | 0 | GGML_ASSERT(labels->ne[1] == n_ubatch); |
2699 | 0 | ggml_set_zero(labels); |
2700 | 0 | const float onef = 1.0f; |
2701 | 0 | for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { |
2702 | 0 | const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; |
2703 | 0 | GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); |
2704 | 0 | ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); |
2705 | 0 | } |
2706 | 0 | } |
2707 | 0 | ggml_opt_eval(opt_ctx, result); |
2708 | 0 | if (callback) { |
2709 | 0 | callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); |
2710 | 0 | } |
2711 | 0 | ggml_free(ctx_compute_opt); |
2712 | |
|
2713 | 0 | pos_batch += ubatch.n_tokens; |
2714 | 0 | } while (mctx->next()); |
2715 | 0 | } |
2716 | 0 | } |
2717 | | |
2718 | | void llama_context::opt_epoch( |
2719 | | ggml_opt_dataset_t dataset, |
2720 | | ggml_opt_result_t result_train, |
2721 | | ggml_opt_result_t result_eval, |
2722 | | int64_t idata_split, |
2723 | | ggml_opt_epoch_callback callback_train, |
2724 | 0 | ggml_opt_epoch_callback callback_eval) { |
2725 | 0 | const uint32_t n_ctx = this->n_ctx(); |
2726 | 0 | const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); |
2727 | 0 | const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); |
2728 | 0 | const int64_t ndata = ggml_opt_dataset_ndata(dataset); |
2729 | |
|
2730 | 0 | GGML_ASSERT(idata_split >= 0); |
2731 | 0 | GGML_ASSERT(idata_split <= ndata); |
2732 | |
|
2733 | 0 | const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; |
2734 | |
|
2735 | 0 | struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |
2736 | 0 | std::vector<llama_token> tokens(n_ctx); |
2737 | 0 | std::vector<llama_token> labels_sparse(n_ctx); |
2738 | |
|
2739 | 0 | int64_t idata = 0; |
2740 | |
|
2741 | 0 | int64_t t_loop_start = ggml_time_us(); |
2742 | 0 | int64_t ndata_in_loop = idata_split*ubatch_per_ctx; |
2743 | 0 | for (; idata < idata_split; ++idata) { |
2744 | 0 | constexpr bool train = true; |
2745 | 0 | const int64_t idata_in_loop = idata*ubatch_per_ctx; |
2746 | |
|
2747 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2748 | 0 | opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, |
2749 | 0 | callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2750 | 0 | } |
2751 | |
|
2752 | 0 | t_loop_start = ggml_time_us(); |
2753 | 0 | ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; |
2754 | 0 | for (; idata < ndata; ++idata) { |
2755 | 0 | constexpr bool train = false; |
2756 | 0 | const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; |
2757 | |
|
2758 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2759 | 0 | opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, |
2760 | 0 | callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2761 | 0 | } |
2762 | |
|
2763 | 0 | llama_batch_free(batch); |
2764 | 0 | } |
2765 | | |
2766 | | // |
2767 | | // interface implementation |
2768 | | // |
2769 | | |
// Returns a llama_context_params filled with library defaults. Callers typically take this value,
// override individual fields, and pass it to llama_init_from_model.
llama_context_params llama_context_default_params() {
    // Positional aggregate initialization: the value order must match the field order of
    // llama_context_params; the /*.name =*/ markers document which field each value fills.
    llama_context_params result = {
        /*.n_ctx                       =*/ 512,
        /*.n_batch                     =*/ 2048,
        /*.n_ubatch                    =*/ 512,
        /*.n_seq_max                   =*/ 1,
        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
        /*.flash_attn_type             =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
        // NOTE(review): 0.0f presumably means "use the model's value" - confirm in the constructor
        /*.rope_freq_base              =*/ 0.0f,
        /*.rope_freq_scale             =*/ 0.0f,
        // negative yarn_* values make the context constructor fall back to the model's hparams
        /*.yarn_ext_factor             =*/ -1.0f,
        /*.yarn_attn_factor            =*/ -1.0f,
        /*.yarn_beta_fast              =*/ -1.0f,
        /*.yarn_beta_slow              =*/ -1.0f,
        /*.yarn_orig_ctx               =*/ 0,
        /*.defrag_thold                =*/ -1.0f,
        /*.cb_eval                     =*/ nullptr,
        /*.cb_eval_user_data           =*/ nullptr,
        /*.type_k                      =*/ GGML_TYPE_F16,
        /*.type_v                      =*/ GGML_TYPE_F16,
        /*.abort_callback              =*/ nullptr,
        /*.abort_callback_data         =*/ nullptr,
        /*.embeddings                  =*/ false,
        /*.offload_kqv                 =*/ true,
        /*.no_perf                     =*/ true,
        /*.op_offload                  =*/ true,
        /*.swa_full                    =*/ true,
        /*.kv_unified                  =*/ false,
        /*.sampler                     =*/ nullptr,
        /*.n_sampler                   =*/ 0,
    };

    return result;
}
2808 | | |
2809 | | llama_context * llama_init_from_model( |
2810 | | llama_model * model, |
2811 | 0 | llama_context_params params) { |
2812 | 0 | if (!model) { |
2813 | 0 | LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); |
2814 | 0 | return nullptr; |
2815 | 0 | } |
2816 | | |
2817 | 0 | if (params.n_batch == 0 && params.n_ubatch == 0) { |
2818 | 0 | LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); |
2819 | 0 | return nullptr; |
2820 | 0 | } |
2821 | | |
2822 | 0 | if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { |
2823 | 0 | LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); |
2824 | 0 | return nullptr; |
2825 | 0 | } |
2826 | | |
2827 | 0 | if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { |
2828 | 0 | LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); |
2829 | 0 | params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; |
2830 | 0 | } |
2831 | |
|
2832 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { |
2833 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_k); |
2834 | 0 | if (model->hparams.n_embd_head_k % blck_size != 0) { |
2835 | 0 | LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", |
2836 | 0 | __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); |
2837 | 0 | return nullptr; |
2838 | 0 | } |
2839 | 0 | } |
2840 | | |
2841 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { |
2842 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_v); |
2843 | 0 | if (model->hparams.n_embd_head_v % blck_size != 0) { |
2844 | 0 | LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n", |
2845 | 0 | __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); |
2846 | 0 | return nullptr; |
2847 | 0 | } |
2848 | 0 | } |
2849 | | |
2850 | 0 | if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { |
2851 | 0 | LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); |
2852 | 0 | return nullptr; |
2853 | 0 | } |
2854 | | |
2855 | 0 | if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && |
2856 | 0 | params.pooling_type != model->hparams.pooling_type) { |
2857 | | //user-specified pooling-type is different from the model default |
2858 | 0 | LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, |
2859 | 0 | model->hparams.pooling_type, params.pooling_type); |
2860 | 0 | } |
2861 | |
|
2862 | 0 | try { |
2863 | 0 | auto * ctx = new llama_context(*model, params); |
2864 | 0 | return ctx; |
2865 | 0 | } catch (const std::exception & err) { |
2866 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); |
2867 | 0 | } |
2868 | | |
2869 | 0 | return nullptr; |
2870 | 0 | } |
2871 | | |
2872 | | // deprecated |
2873 | | llama_context * llama_new_context_with_model( |
2874 | | llama_model * model, |
2875 | 0 | llama_context_params params) { |
2876 | 0 | return llama_init_from_model(model, params); |
2877 | 0 | } |
2878 | | |
2879 | 0 | void llama_free(llama_context * ctx) { |
2880 | 0 | delete ctx; |
2881 | 0 | } |
2882 | | |
2883 | 0 | uint32_t llama_n_ctx(const llama_context * ctx) { |
2884 | 0 | return ctx->n_ctx(); |
2885 | 0 | } |
2886 | | |
2887 | 0 | uint32_t llama_n_ctx_seq(const llama_context * ctx) { |
2888 | 0 | return ctx->n_ctx_seq(); |
2889 | 0 | } |
2890 | | |
2891 | 0 | uint32_t llama_n_batch(const llama_context * ctx) { |
2892 | 0 | return ctx->n_batch(); |
2893 | 0 | } |
2894 | | |
2895 | 0 | uint32_t llama_n_ubatch(const llama_context * ctx) { |
2896 | 0 | return ctx->n_ubatch(); |
2897 | 0 | } |
2898 | | |
2899 | 0 | uint32_t llama_n_seq_max(const llama_context * ctx) { |
2900 | 0 | return ctx->n_seq_max(); |
2901 | 0 | } |
2902 | | |
2903 | 0 | const llama_model * llama_get_model(const llama_context * ctx) { |
2904 | 0 | return &ctx->get_model(); |
2905 | 0 | } |
2906 | | |
2907 | 0 | enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { |
2908 | 0 | return ctx->pooling_type(); |
2909 | 0 | } |
2910 | | |
2911 | | void llama_attach_threadpool( |
2912 | | llama_context * ctx, |
2913 | | ggml_threadpool_t threadpool, |
2914 | 0 | ggml_threadpool_t threadpool_batch) { |
2915 | 0 | ctx->attach_threadpool(threadpool, threadpool_batch); |
2916 | 0 | } |
2917 | | |
2918 | 0 | void llama_detach_threadpool(llama_context * ctx) { |
2919 | 0 | ctx->detach_threadpool(); |
2920 | 0 | } |
2921 | | |
2922 | 0 | void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { |
2923 | 0 | ctx->set_n_threads(n_threads, n_threads_batch); |
2924 | 0 | } |
2925 | | |
2926 | 0 | int32_t llama_n_threads(llama_context * ctx) { |
2927 | 0 | return ctx->n_threads(); |
2928 | 0 | } |
2929 | | |
2930 | 0 | int32_t llama_n_threads_batch(llama_context * ctx) { |
2931 | 0 | return ctx->n_threads_batch(); |
2932 | 0 | } |
2933 | | |
2934 | 0 | void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { |
2935 | 0 | ctx->set_abort_callback(abort_callback, abort_callback_data); |
2936 | 0 | } |
2937 | | |
2938 | 0 | void llama_set_embeddings(llama_context * ctx, bool embeddings) { |
2939 | 0 | ctx->set_embeddings(embeddings); |
2940 | 0 | } |
2941 | | |
2942 | 0 | void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { |
2943 | 0 | ctx->set_causal_attn(causal_attn); |
2944 | 0 | } |
2945 | | |
2946 | 0 | void llama_set_warmup(llama_context * ctx, bool warmup) { |
2947 | 0 | ctx->set_warmup(warmup); |
2948 | 0 | } |
2949 | | |
2950 | 0 | void llama_synchronize(llama_context * ctx) { |
2951 | 0 | ctx->synchronize(); |
2952 | 0 | } |
2953 | | |
2954 | 0 | float * llama_get_logits(llama_context * ctx) { |
2955 | 0 | ctx->synchronize(); |
2956 | |
|
2957 | 0 | return ctx->get_logits(); |
2958 | 0 | } |
2959 | | |
2960 | 0 | float * llama_get_logits_ith(llama_context * ctx, int32_t i) { |
2961 | 0 | ctx->synchronize(); |
2962 | |
|
2963 | 0 | float * res = nullptr; |
2964 | |
|
2965 | 0 | res = ctx->get_sampled_logits_ith(i); |
2966 | |
|
2967 | 0 | if (!res) { |
2968 | 0 | res = ctx->get_logits_ith(i); |
2969 | 0 | } |
2970 | |
|
2971 | 0 | return res; |
2972 | 0 | } |
2973 | | |
2974 | 0 | float * llama_get_embeddings(llama_context * ctx) { |
2975 | 0 | ctx->synchronize(); |
2976 | |
|
2977 | 0 | return ctx->get_embeddings(); |
2978 | 0 | } |
2979 | | |
2980 | 0 | float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { |
2981 | 0 | ctx->synchronize(); |
2982 | |
|
2983 | 0 | return ctx->get_embeddings_ith(i); |
2984 | 0 | } |
2985 | | |
2986 | 0 | float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { |
2987 | 0 | ctx->synchronize(); |
2988 | |
|
2989 | 0 | return ctx->get_embeddings_seq(seq_id); |
2990 | 0 | } |
2991 | | |
2992 | 0 | bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { |
2993 | 0 | return ctx->set_sampler(seq_id, smpl); |
2994 | 0 | } |
2995 | | |
2996 | 0 | llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { |
2997 | 0 | ctx->synchronize(); |
2998 | |
|
2999 | 0 | return ctx->get_sampled_token_ith(i); |
3000 | 0 | } |
3001 | | |
3002 | 0 | float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { |
3003 | 0 | ctx->synchronize(); |
3004 | |
|
3005 | 0 | return ctx->get_sampled_probs_ith(i); |
3006 | 0 | } |
3007 | | |
3008 | 0 | float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { |
3009 | 0 | ctx->synchronize(); |
3010 | |
|
3011 | 0 | return ctx->get_sampled_logits_ith(i); |
3012 | 0 | } |
3013 | | |
3014 | 0 | llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { |
3015 | 0 | ctx->synchronize(); |
3016 | |
|
3017 | 0 | return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i)); |
3018 | 0 | } |
3019 | | |
3020 | 0 | uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { |
3021 | 0 | ctx->synchronize(); |
3022 | |
|
3023 | 0 | return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i)); |
3024 | 0 | } |
3025 | | |
3026 | 0 | uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { |
3027 | 0 | ctx->synchronize(); |
3028 | |
|
3029 | 0 | return static_cast<uint32_t>(ctx->get_sampled_logits_count(i)); |
3030 | 0 | } |
3031 | | |
3032 | 0 | uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { |
3033 | 0 | ctx->synchronize(); |
3034 | |
|
3035 | 0 | return static_cast<uint32_t>(ctx->get_sampled_probs_count(i)); |
3036 | 0 | } |
3037 | | |
3038 | | // llama adapter API |
3039 | | |
3040 | | int32_t llama_set_adapters_lora( |
3041 | | llama_context * ctx, |
3042 | | llama_adapter_lora ** adapters, |
3043 | | size_t n_adapters, |
3044 | 0 | float * scales) { |
3045 | 0 | if (adapters == nullptr || scales == nullptr) { |
3046 | 0 | GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); |
3047 | 0 | } |
3048 | |
|
3049 | 0 | ctx->set_adapters_lora(adapters, n_adapters, scales); |
3050 | |
|
3051 | 0 | return 0; |
3052 | 0 | } |
3053 | | |
3054 | | int32_t llama_set_adapter_cvec( |
3055 | | llama_context * ctx, |
3056 | | const float * data, |
3057 | | size_t len, |
3058 | | int32_t n_embd, |
3059 | | int32_t il_start, |
3060 | 0 | int32_t il_end) { |
3061 | 0 | bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); |
3062 | |
|
3063 | 0 | return res ? 0 : -1; |
3064 | 0 | } |
3065 | | |
3066 | | // |
3067 | | // memory |
3068 | | // |
3069 | | |
3070 | 0 | llama_memory_t llama_get_memory(const struct llama_context * ctx) { |
3071 | 0 | return ctx->get_memory(); |
3072 | 0 | } |
3073 | | |
3074 | 0 | void llama_memory_clear(llama_memory_t mem, bool data) { |
3075 | 0 | if (!mem) { |
3076 | 0 | return; |
3077 | 0 | } |
3078 | | |
3079 | 0 | mem->clear(data); |
3080 | 0 | } |
3081 | | |
3082 | | bool llama_memory_seq_rm( |
3083 | | llama_memory_t mem, |
3084 | | llama_seq_id seq_id, |
3085 | | llama_pos p0, |
3086 | 0 | llama_pos p1) { |
3087 | 0 | if (!mem) { |
3088 | 0 | return true; |
3089 | 0 | } |
3090 | | |
3091 | 0 | return mem->seq_rm(seq_id, p0, p1); |
3092 | 0 | } |
3093 | | |
3094 | | void llama_memory_seq_cp( |
3095 | | llama_memory_t mem, |
3096 | | llama_seq_id seq_id_src, |
3097 | | llama_seq_id seq_id_dst, |
3098 | | llama_pos p0, |
3099 | 0 | llama_pos p1) { |
3100 | 0 | if (!mem) { |
3101 | 0 | return; |
3102 | 0 | } |
3103 | | |
3104 | 0 | mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
3105 | 0 | } |
3106 | | |
3107 | | void llama_memory_seq_keep( |
3108 | | llama_memory_t mem, |
3109 | 0 | llama_seq_id seq_id) { |
3110 | 0 | if (!mem) { |
3111 | 0 | return; |
3112 | 0 | } |
3113 | | |
3114 | 0 | mem->seq_keep(seq_id); |
3115 | 0 | } |
3116 | | |
3117 | | void llama_memory_seq_add( |
3118 | | llama_memory_t mem, |
3119 | | llama_seq_id seq_id, |
3120 | | llama_pos p0, |
3121 | | llama_pos p1, |
3122 | 0 | llama_pos delta) { |
3123 | 0 | if (!mem) { |
3124 | 0 | return; |
3125 | 0 | } |
3126 | | |
3127 | 0 | mem->seq_add(seq_id, p0, p1, delta); |
3128 | 0 | } |
3129 | | |
3130 | | void llama_memory_seq_div( |
3131 | | llama_memory_t mem, |
3132 | | llama_seq_id seq_id, |
3133 | | llama_pos p0, |
3134 | | llama_pos p1, |
3135 | 0 | int d) { |
3136 | 0 | if (!mem) { |
3137 | 0 | return; |
3138 | 0 | } |
3139 | | |
3140 | 0 | mem->seq_div(seq_id, p0, p1, d); |
3141 | 0 | } |
3142 | | |
3143 | | llama_pos llama_memory_seq_pos_min( |
3144 | | llama_memory_t mem, |
3145 | 0 | llama_seq_id seq_id) { |
3146 | 0 | if (!mem) { |
3147 | 0 | return -1; |
3148 | 0 | } |
3149 | | |
3150 | 0 | return mem->seq_pos_min(seq_id); |
3151 | 0 | } |
3152 | | |
3153 | | llama_pos llama_memory_seq_pos_max( |
3154 | | llama_memory_t mem, |
3155 | 0 | llama_seq_id seq_id) { |
3156 | 0 | if (!mem) { |
3157 | 0 | return -1; |
3158 | 0 | } |
3159 | | |
3160 | 0 | return mem->seq_pos_max(seq_id); |
3161 | 0 | } |
3162 | | |
3163 | 0 | bool llama_memory_can_shift(llama_memory_t mem) { |
3164 | 0 | if (!mem) { |
3165 | 0 | return false; |
3166 | 0 | } |
3167 | | |
3168 | 0 | return mem->get_can_shift(); |
3169 | 0 | } |
3170 | | |
3171 | | // llama state API |
3172 | | |
3173 | | // deprecated |
3174 | 0 | size_t llama_get_state_size(llama_context * ctx) { |
3175 | 0 | return llama_state_get_size(ctx); |
3176 | 0 | } |
3177 | | |
3178 | | // deprecated |
3179 | 0 | size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { |
3180 | 0 | return llama_state_get_data(ctx, dst, -1); |
3181 | 0 | } |
3182 | | |
3183 | | // deprecated |
3184 | 0 | size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { |
3185 | 0 | return llama_state_set_data(ctx, src, -1); |
3186 | 0 | } |
3187 | | |
3188 | | // deprecated |
3189 | 0 | bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3190 | 0 | return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); |
3191 | 0 | } |
3192 | | |
3193 | | // deprecated |
3194 | 0 | bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
3195 | 0 | return llama_state_save_file(ctx, path_session, tokens, n_token_count); |
3196 | 0 | } |
3197 | | |
3198 | | // Returns the *actual* size of the state. |
3199 | | // Intended to be used when saving to state to a buffer. |
3200 | 0 | size_t llama_state_get_size(llama_context * ctx) { |
3201 | 0 | return ctx->state_get_size(); |
3202 | 0 | } |
3203 | | |
3204 | 0 | size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { |
3205 | 0 | ctx->synchronize(); |
3206 | |
|
3207 | 0 | return ctx->state_get_data(dst, size); |
3208 | 0 | } |
3209 | | |
3210 | | // Sets the state reading from the specified source address |
3211 | 0 | size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { |
3212 | 0 | ctx->synchronize(); |
3213 | |
|
3214 | 0 | return ctx->state_set_data(src, size); |
3215 | 0 | } |
3216 | | |
3217 | 0 | bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3218 | 0 | ctx->synchronize(); |
3219 | |
|
3220 | 0 | try { |
3221 | 0 | return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); |
3222 | 0 | } catch (const std::exception & err) { |
3223 | 0 | LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); |
3224 | 0 | return false; |
3225 | 0 | } |
3226 | 0 | } |
3227 | | |
3228 | 0 | bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
3229 | 0 | ctx->synchronize(); |
3230 | |
|
3231 | 0 | try { |
3232 | 0 | return ctx->state_save_file(path_session, tokens, n_token_count); |
3233 | 0 | } catch (const std::exception & err) { |
3234 | 0 | LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); |
3235 | 0 | return false; |
3236 | 0 | } |
3237 | 0 | } |
3238 | | |
3239 | 0 | size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { |
3240 | 0 | return llama_state_seq_get_size_ext(ctx, seq_id, 0); |
3241 | 0 | } |
3242 | | |
3243 | 0 | size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { |
3244 | 0 | return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); |
3245 | 0 | } |
3246 | | |
3247 | 0 | size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { |
3248 | 0 | return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); |
3249 | 0 | } |
3250 | | |
3251 | 0 | size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3252 | 0 | return ctx->state_seq_get_size(seq_id, flags); |
3253 | 0 | } |
3254 | | |
3255 | 0 | size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3256 | 0 | ctx->synchronize(); |
3257 | |
|
3258 | 0 | return ctx->state_seq_get_data(seq_id, dst, size, flags); |
3259 | 0 | } |
3260 | | |
3261 | 0 | size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3262 | 0 | ctx->synchronize(); |
3263 | |
|
3264 | 0 | return ctx->state_seq_set_data(seq_id, src, size, flags); |
3265 | 0 | } |
3266 | | |
3267 | 0 | size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { |
3268 | 0 | ctx->synchronize(); |
3269 | |
|
3270 | 0 | try { |
3271 | 0 | return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); |
3272 | 0 | } catch (const std::exception & err) { |
3273 | 0 | LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); |
3274 | 0 | return 0; |
3275 | 0 | } |
3276 | 0 | } |
3277 | | |
3278 | 0 | size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3279 | 0 | ctx->synchronize(); |
3280 | |
|
3281 | 0 | try { |
3282 | 0 | return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); |
3283 | 0 | } catch (const std::exception & err) { |
3284 | 0 | LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); |
3285 | 0 | return 0; |
3286 | 0 | } |
3287 | 0 | } |
3288 | | |
3289 | | /// |
3290 | | |
3291 | | int32_t llama_encode( |
3292 | | llama_context * ctx, |
3293 | 0 | llama_batch batch) { |
3294 | 0 | const int ret = ctx->encode(batch); |
3295 | 0 | if (ret != 0) { |
3296 | 0 | LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |
3297 | 0 | } |
3298 | |
|
3299 | 0 | return ret; |
3300 | 0 | } |
3301 | | |
3302 | | int32_t llama_decode( |
3303 | | llama_context * ctx, |
3304 | 0 | llama_batch batch) { |
3305 | 0 | const int ret = ctx->decode(batch); |
3306 | 0 | if (ret != 0 && ret != 1) { |
3307 | 0 | LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |
3308 | 0 | } |
3309 | |
|
3310 | 0 | return ret; |
3311 | 0 | } |
3312 | | |
3313 | | // |
3314 | | // perf |
3315 | | // |
3316 | | |
3317 | 0 | llama_perf_context_data llama_perf_context(const llama_context * ctx) { |
3318 | 0 | llama_perf_context_data data = {}; |
3319 | |
|
3320 | 0 | if (ctx == nullptr) { |
3321 | 0 | return data; |
3322 | 0 | } |
3323 | | |
3324 | 0 | data = ctx->perf_get_data(); |
3325 | |
|
3326 | 0 | return data; |
3327 | 0 | } |
3328 | | |
3329 | 0 | void llama_perf_context_print(const llama_context * ctx) { |
3330 | 0 | const auto data = llama_perf_context(ctx); |
3331 | |
|
3332 | 0 | const double t_end_ms = 1e-3 * ggml_time_us(); |
3333 | |
|
3334 | 0 | LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); |
3335 | 0 | LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", |
3336 | 0 | __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); |
3337 | 0 | LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", |
3338 | 0 | __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); |
3339 | 0 | LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); |
3340 | 0 | LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); |
3341 | 0 | } |
3342 | | |
3343 | 0 | void llama_perf_context_reset(llama_context * ctx) { |
3344 | 0 | ctx->perf_reset(); |
3345 | 0 | } |
3346 | | |
3347 | 0 | void llama_memory_breakdown_print(const struct llama_context * ctx) { |
3348 | 0 | const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices; |
3349 | |
|
3350 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); |
3351 | |
|
3352 | 0 | std::vector<std::array<std::string, 9>> table_data; |
3353 | 0 | table_data.reserve(devices.size()); |
3354 | 0 | const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; |
3355 | 0 | const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; |
3356 | 0 | const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; |
3357 | |
|
3358 | 0 | table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); |
3359 | |
|
3360 | 0 | constexpr size_t MiB = 1024 * 1024; |
3361 | 0 | const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; |
3362 | | |
3363 | | // track seen buffer types to avoid double counting: |
3364 | 0 | std::set<ggml_backend_buffer_type_t> seen_buffer_types; |
3365 | | |
3366 | | // accumulative memory breakdown for each device and for host: |
3367 | 0 | std::vector<llama_memory_breakdown_data> mb_dev(devices.size()); |
3368 | 0 | llama_memory_breakdown_data mb_host; |
3369 | |
|
3370 | 0 | for (const auto & buft_mb : memory_breakdown) { |
3371 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
3372 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
3373 | 0 | if (ggml_backend_buft_is_host(buft)) { |
3374 | 0 | mb_host.model += mb.model; |
3375 | 0 | mb_host.context += mb.context; |
3376 | 0 | mb_host.compute += mb.compute; |
3377 | 0 | seen_buffer_types.insert(buft); |
3378 | 0 | continue; |
3379 | 0 | } |
3380 | 0 | ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); |
3381 | 0 | if (dev) { |
3382 | 0 | int i_dev = -1; |
3383 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
3384 | 0 | if (devices[i] == dev) { |
3385 | 0 | i_dev = i; |
3386 | 0 | break; |
3387 | 0 | } |
3388 | 0 | } |
3389 | 0 | if (i_dev != -1) { |
3390 | 0 | mb_dev[i_dev].model += mb.model; |
3391 | 0 | mb_dev[i_dev].context += mb.context; |
3392 | 0 | mb_dev[i_dev].compute += mb.compute; |
3393 | 0 | seen_buffer_types.insert(buft); |
3394 | 0 | continue; |
3395 | 0 | } |
3396 | 0 | } |
3397 | 0 | } |
3398 | | |
3399 | | // print memory breakdown for each device: |
3400 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
3401 | 0 | ggml_backend_dev_t dev = devices[i]; |
3402 | 0 | llama_memory_breakdown_data mb = mb_dev[i]; |
3403 | |
|
3404 | 0 | const std::string name = ggml_backend_dev_name(dev); |
3405 | 0 | std::string desc = ggml_backend_dev_description(dev); |
3406 | 0 | for (const std::string & prefix : desc_prefixes_strip) { |
3407 | 0 | if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { |
3408 | 0 | desc = desc.substr(prefix.length()); |
3409 | 0 | } |
3410 | 0 | } |
3411 | |
|
3412 | 0 | size_t free, total; |
3413 | 0 | ggml_backend_dev_memory(dev, &free, &total); |
3414 | |
|
3415 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
3416 | 0 | const size_t unaccounted = total - self - free; |
3417 | |
|
3418 | 0 | table_data.push_back({ |
3419 | 0 | template_gpu, |
3420 | 0 | " - " + name + " (" + desc + ")", |
3421 | 0 | std::to_string(total / MiB), |
3422 | 0 | std::to_string(free / MiB), |
3423 | 0 | std::to_string(self / MiB), |
3424 | 0 | std::to_string(mb.model / MiB), |
3425 | 0 | std::to_string(mb.context / MiB), |
3426 | 0 | std::to_string(mb.compute / MiB), |
3427 | 0 | std::to_string(unaccounted / MiB)}); |
3428 | 0 | } |
3429 | | |
3430 | | // print memory breakdown for host: |
3431 | 0 | { |
3432 | 0 | const size_t self = mb_host.model + mb_host.context + mb_host.compute; |
3433 | 0 | table_data.push_back({ |
3434 | 0 | template_other, |
3435 | 0 | " - Host", |
3436 | 0 | "", // total |
3437 | 0 | "", // free |
3438 | 0 | std::to_string(self / MiB), |
3439 | 0 | std::to_string(mb_host.model / MiB), |
3440 | 0 | std::to_string(mb_host.context / MiB), |
3441 | 0 | std::to_string(mb_host.compute / MiB), |
3442 | 0 | ""}); // unaccounted |
3443 | 0 | } |
3444 | | |
3445 | | // print memory breakdown for all remaining buffer types: |
3446 | 0 | for (const auto & buft_mb : memory_breakdown) { |
3447 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
3448 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
3449 | 0 | if (seen_buffer_types.count(buft) == 1) { |
3450 | 0 | continue; |
3451 | 0 | } |
3452 | 0 | const std::string name = ggml_backend_buft_name(buft); |
3453 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
3454 | 0 | table_data.push_back({ |
3455 | 0 | template_other, |
3456 | 0 | " - " + name, |
3457 | 0 | "", // total |
3458 | 0 | "", // free |
3459 | 0 | std::to_string(self / MiB), |
3460 | 0 | std::to_string(mb.model / MiB), |
3461 | 0 | std::to_string(mb.context / MiB), |
3462 | 0 | std::to_string(mb.compute / MiB), |
3463 | 0 | ""}); // unaccounted |
3464 | 0 | seen_buffer_types.insert(buft); |
3465 | 0 | } |
3466 | |
|
3467 | 0 | for (size_t j = 1; j < table_data[0].size(); j++) { |
3468 | 0 | size_t max_len = 0; |
3469 | 0 | for (const auto & td : table_data) { |
3470 | 0 | max_len = std::max(max_len, td[j].length()); |
3471 | 0 | } |
3472 | 0 | for (auto & td : table_data) { |
3473 | 0 | td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); |
3474 | 0 | } |
3475 | 0 | } |
3476 | 0 | for (const auto & td : table_data) { |
3477 | 0 | LLAMA_LOG_INFO(td[0].c_str(), |
3478 | 0 | __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), |
3479 | 0 | td[6].c_str(), td[7].c_str(), td[8].c_str()); |
3480 | 0 | } |
3481 | 0 | } |
3482 | | |
3483 | | // |
3484 | | // training |
3485 | | // |
3486 | | |
// Parameter filter that accepts every tensor; both arguments are ignored.
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
    (void) tensor;
    (void) userdata;

    const bool accept = true;
    return accept;
}
3492 | | |
3493 | 0 | void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { |
3494 | 0 | ctx->opt_init(model, lopt_params); |
3495 | 0 | } |
3496 | | |
3497 | | void llama_opt_epoch( |
3498 | | struct llama_context * ctx, |
3499 | | ggml_opt_dataset_t dataset, |
3500 | | ggml_opt_result_t result_train, |
3501 | | ggml_opt_result_t result_eval, |
3502 | | int64_t idata_split, |
3503 | | ggml_opt_epoch_callback callback_train, |
3504 | 0 | ggml_opt_epoch_callback callback_eval) { |
3505 | 0 | ctx->opt_epoch( |
3506 | 0 | dataset, |
3507 | 0 | result_train, |
3508 | 0 | result_eval, |
3509 | 0 | idata_split, |
3510 | 0 | callback_train, |
3511 | 0 | callback_eval); |
3512 | 0 | } |